Skip to content
This repository has been archived by the owner on Dec 7, 2021. It is now read-only.

Commit

Permalink
feat: CNTK Export Provider (#771)
Browse files: browse the repository at this point in the history
Adds a CNTK export provider to v2.

Resolves #754
  • Loading branch information
wbreza committed Apr 29, 2019
1 parent 0fe6386 commit c10c971
Show file tree
Hide file tree
Showing 15 changed files with 353 additions and 28 deletions.
11 changes: 7 additions & 4 deletions src/common/localization/en-us.ts
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,10 @@ export const english: IAppStrings = {
tagged: "Only tagged Assets",
},
},
testTrainSplit: {
title: "Test / Train Split",
description: "The test train split to use for exported data",
},
},
},
vottJson: {
Expand Down Expand Up @@ -344,15 +348,14 @@ export const english: IAppStrings = {
},
pascalVoc: {
displayName: "Pascal VOC",
testTrainSplit: {
title: "Test / Train Split",
description: "The test train split to use for exported data",
},
exportUnassigned: {
title: "Export Unassigned",
description: "Whether or not to include unassigned tags in exported data",
},
},
cntk: {
displayName: "Microsoft Cognitive Toolkit (CNTK)",
},
},
messages: {
saveSuccess: "Successfully saved export settings",
Expand Down
11 changes: 7 additions & 4 deletions src/common/localization/es-cl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,10 @@ export const spanish: IAppStrings = {
tagged: "Solo activos etiquetados",
},
},
testTrainSplit: {
title: "La división para entrenar y comprobar",
description: "La división de datos para utilizar entre el entrenamiento y la comprobación",
},
},
},
vottJson: {
Expand Down Expand Up @@ -346,15 +350,14 @@ export const spanish: IAppStrings = {
},
pascalVoc: {
displayName: "Pascal VOC",
testTrainSplit: {
title: "Prueba/tren Split",
description: "La división del tren de prueba que se utilizará para los datos exportados",
},
exportUnassigned: {
title: "Exportar sin asignar",
description: "Si se incluyen o no etiquetas no asignadas en los datos exportados",
},
},
cntk: {
displayName: "Microsoft Cognitive Toolkit (CNTK)",
},
},
messages: {
saveSuccess: "Configuración de exportación guardada correctamente",
Expand Down
11 changes: 7 additions & 4 deletions src/common/strings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,10 @@ export interface IAppStrings {
tagged: string,
},
},
testTrainSplit: {
title: string,
description: string,
},
},
},
vottJson: {
Expand Down Expand Up @@ -342,15 +346,14 @@ export interface IAppStrings {
},
pascalVoc: {
displayName: string,
testTrainSplit: {
title: string,
description: string,
},
exportUnassigned: {
title: string,
description: string,
},
},
cntk: {
displayName: string,
},
},
messages: {
saveSuccess: string;
Expand Down
2 changes: 1 addition & 1 deletion src/providers/export/azureCustomVision.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { ExportProviderFactory } from "./exportProviderFactory";
import MockFactory from "../../common/mockFactory";
import {
IProject, AssetState, IAsset, IAssetMetadata,
RegionType, IRegion, IExportProviderOptions, AssetType,
RegionType, IRegion, IExportProviderOptions,
} from "../../models/applicationState";
import { ExportAssetState } from "./exportProvider";
jest.mock("./azureCustomVision/azureCustomVisionService");
Expand Down
28 changes: 28 additions & 0 deletions src/providers/export/cntk.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
{
"type": "object",
"title": "${strings.export.providers.cntk.displayName}",
"properties": {
"assetState": {
"type": "string",
"title": "${strings.export.providers.common.properties.assetState.title}",
"description": "${strings.export.providers.common.properties.assetState.description}",
"enum": [
"all",
"visited",
"tagged"
],
"default": "visited",
"enumNames": [
"${strings.export.providers.common.properties.assetState.options.all}",
"${strings.export.providers.common.properties.assetState.options.visited}",
"${strings.export.providers.common.properties.assetState.options.tagged}"
]
},
"testTrainSplit": {
"title": "${strings.export.providers.common.properties.testTrainSplit.title}",
"description": "${strings.export.providers.common.properties.testTrainSplit.description}",
"type": "number",
"default": 80
}
}
}
167 changes: 167 additions & 0 deletions src/providers/export/cntk.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import _ from "lodash";
import os from "os";
import { CntkExportProvider, ICntkExportProviderOptions } from "./cntk";
import { IProject, AssetState, IAssetMetadata } from "../../models/applicationState";
import { AssetProviderFactory } from "../storage/assetProviderFactory";
import { ExportAssetState } from "./exportProvider";
import MockFactory from "../../common/mockFactory";
import registerMixins from "../../registerMixins";
import registerProviders from "../../registerProviders";
import { ExportProviderFactory } from "./exportProviderFactory";
jest.mock("../../services/assetService");
import { AssetService } from "../../services/assetService";

jest.mock("../storage/localFileSystemProxy");
import { LocalFileSystemProxy } from "../storage/localFileSystemProxy";
import HtmlFileReader from "../../common/htmlFileReader";
import { appInfo } from "../../common/appInfo";

/**
 * Jest suite for the CNTK export provider.
 *
 * The storage layer (LocalFileSystemProxy) and the AssetService are module
 * mocks, so the tests only inspect which containers and files the provider
 * asks the storage provider to create and what text content it passes along.
 */
describe("CNTK Export Provider", () => {
    const testAssets = MockFactory.createTestAssets(10, 1);
    const exportRoot = "Project-TestProject-CNTK-export";
    const baseOptions: ICntkExportProviderOptions = {
        assetState: ExportAssetState.Tagged,
        testTrainSplit: 80,
    };
    let testProject: IProject = null;

    /** Builds a provider from the project's own export format options. */
    const buildProvider = (project: IProject): CntkExportProvider => {
        const options = project.exportFormat.providerOptions as ICntkExportProviderOptions;
        return new CntkExportProvider(project, options);
    };

    /** Returns the first auto-mocked storage provider instance that was constructed. */
    function getStorageMock(): any {
        return (LocalFileSystemProxy as any).mock.instances[0];
    }

    /** Extracts the first argument of every recorded mock call. */
    function firstArgs(calls: any[][]): any[] {
        return calls.map((callArgs) => callArgs[0]);
    }

    /**
     * Asserts that the image and both TSV files for one asset were written
     * below `folderPath`, and that the TSV contents match the asset's first
     * two regions.
     */
    function expectAssetExported(folderPath: string, assetMetadata: IAssetMetadata) {
        const storage = getStorageMock();
        const textCalls = storage.writeText.mock.calls;
        const binaryNames = firstArgs(storage.writeBinary.mock.calls);
        const textNames = firstArgs(textCalls);

        const assetPath = `${folderPath}/${assetMetadata.asset.name}`;
        expect(binaryNames).toContain(assetPath);
        expect(textNames).toContain(`${assetPath}.bboxes.labels.tsv`);
        expect(textNames).toContain(`${assetPath}.bboxes.tsv`);

        // The content lookup matches on the file name suffix only, not on the folder.
        const labelsSuffix = `${assetMetadata.asset.name}.bboxes.labels.tsv`;
        const boxesSuffix = `${assetMetadata.asset.name}.bboxes.tsv`;
        const labelsCall = textCalls.find((args: string[]) => args[0].indexOf(labelsSuffix) !== -1);
        const boxesCall = textCalls.find((args: string[]) => args[0].indexOf(boxesSuffix) !== -1);

        const firstTwoRegions = [assetMetadata.regions[0], assetMetadata.regions[1]];

        // One label per line: the first tag of each region.
        const expectedLabels = firstTwoRegions
            .map((region) => `${region.tags[0]}`)
            .join(os.EOL);
        expect(labelsCall[1]).toEqual(expectedLabels);

        // One box per line, tab separated: left, left + width, top, top + height.
        const expectedBoxes = firstTwoRegions
            .map(({ boundingBox }) => [
                `${boundingBox.left}`,
                `${boundingBox.left + boundingBox.width}`,
                `${boundingBox.top}`,
                `${boundingBox.top + boundingBox.height}`,
            ].join("\t"))
            .join(os.EOL);
        expect(boxesCall[1]).toEqual(expectedBoxes);
    }

    beforeAll(() => {
        registerMixins();
        registerProviders();

        // Asset downloads are stubbed with a small in-memory blob.
        HtmlFileReader.getAssetBlob = jest.fn(() => Promise.resolve(new Blob(["Some binary data"])));
    });

    beforeEach(() => {
        // NOTE(review): resetAllMocks() also resets jest.fn() implementations; confirm the
        // getAssetBlob stub installed once in beforeAll is still effective after this call.
        jest.resetAllMocks();

        for (const asset of testAssets) {
            asset.state = AssetState.Tagged;
        }

        const baseProject = MockFactory.createTestProject("TestProject");
        testProject = {
            ...baseProject,
            assets: _.keyBy(testAssets, (asset) => asset.id),
            exportFormat: {
                providerType: "cntk",
                providerOptions: baseOptions,
            },
        };

        // Every asset provider created during export hands back the shared test assets.
        AssetProviderFactory.create = jest.fn(() => ({
            getAssets: jest.fn(() => Promise.resolve(testAssets)),
        }));

        // Every asset gets metadata with two regions, each tagged "tag1".
        const assetServiceMock = AssetService as jest.Mocked<typeof AssetService>;
        assetServiceMock.prototype.getAssetMetadata = jest.fn((asset) => Promise.resolve({
            asset: { ...asset },
            regions: [
                MockFactory.createTestRegion("region-1", ["tag1"]),
                MockFactory.createTestRegion("region-2", ["tag1"]),
            ],
            version: appInfo.version,
        }));
    });

    it("Is defined", () => {
        expect(CntkExportProvider).toBeDefined();
    });

    it("Can be instantiated through the factory", () => {
        const factoryOptions: ICntkExportProviderOptions = {
            assetState: ExportAssetState.All,
            testTrainSplit: 80,
        };

        const created = ExportProviderFactory.create("cntk", testProject, factoryOptions);

        expect(created).not.toBeNull();
        expect(created).toBeInstanceOf(CntkExportProvider);
    });

    it("Creates correct folder structure", async () => {
        await buildProvider(testProject).export();

        const createdContainers = firstArgs(getStorageMock().createContainer.mock.calls);

        expect(createdContainers).toContain(exportRoot);
        expect(createdContainers).toContain(`${exportRoot}/positive`);
        expect(createdContainers).toContain(`${exportRoot}/negative`);
        expect(createdContainers).toContain(`${exportRoot}/testImages`);
    });

    it("Writes export files to storage provider", async () => {
        const provider = buildProvider(testProject);
        const exportAssetsSpy = jest.spyOn(provider, "getAssetsForExport");

        await provider.export();

        // The expectation: the leading (100 - testTrainSplit)% of the exported
        // assets, rounded up, go to the test folder and the rest to the train folder.
        const exportedAssets = await exportAssetsSpy.mock.results[0].value;
        const testFraction = (100 - (baseOptions.testTrainSplit || 80)) / 100;
        const testCount = Math.ceil(exportedAssets.length * testFraction);
        const testPortion = exportedAssets.slice(0, testCount);
        const trainPortion = exportedAssets.slice(testCount, exportedAssets.length);

        // One image plus two TSV files are expected per asset.
        const storage = getStorageMock();
        expect(storage.writeBinary.mock.calls).toHaveLength(testAssets.length);
        expect(storage.writeText.mock.calls).toHaveLength(testAssets.length * 2);

        testPortion.forEach((assetMetadata) => {
            expectAssetExported(`${exportRoot}/testImages`, assetMetadata);
        });

        trainPortion.forEach((assetMetadata) => {
            expectAssetExported(`${exportRoot}/positive`, assetMetadata);
        });
    });
});
Loading

0 comments on commit c10c971

Please sign in to comment.