From 929c91d4c49da176aa17eafd0f614ceb689506b9 Mon Sep 17 00:00:00 2001 From: Wallace Breza Date: Wed, 17 Apr 2019 16:43:31 -0700 Subject: [PATCH] feat: CNTK Export Provider (#771) Adds CNTK export provider into v2 Resolves #754 --- src/common/localization/en-us.ts | 11 +- src/common/localization/es-cl.ts | 11 +- src/common/strings.ts | 11 +- .../export/azureCustomVision.test.ts | 2 +- src/providers/export/cntk.json | 28 +++ src/providers/export/cntk.test.ts | 167 ++++++++++++++++++ src/providers/export/cntk.ts | 105 +++++++++++ src/providers/export/cntk.ui.json | 5 + src/providers/export/pascalVOC.json | 4 +- src/providers/export/pascalVOC.ts | 5 +- src/providers/export/tensorFlowRecords.ts | 2 +- src/providers/export/vottJson.test.ts | 2 +- src/redux/actions/projectActions.test.ts | 1 + src/redux/actions/projectActions.ts | 21 ++- src/registerProviders.ts | 6 + 15 files changed, 353 insertions(+), 28 deletions(-) create mode 100644 src/providers/export/cntk.json create mode 100644 src/providers/export/cntk.test.ts create mode 100644 src/providers/export/cntk.ts create mode 100644 src/providers/export/cntk.ui.json diff --git a/src/common/localization/en-us.ts b/src/common/localization/en-us.ts index 58c37cc5ce..5f8f10b6fd 100644 --- a/src/common/localization/en-us.ts +++ b/src/common/localization/en-us.ts @@ -287,6 +287,10 @@ export const english: IAppStrings = { tagged: "Only tagged Assets", }, }, + testTrainSplit: { + title: "Test / Train Split", + description: "The test train split to use for exported data", + }, }, }, vottJson: { @@ -344,15 +348,14 @@ export const english: IAppStrings = { }, pascalVoc: { displayName: "Pascal VOC", - testTrainSplit: { - title: "Test / Train Split", - description: "The test train split to use for exported data", - }, exportUnassigned: { title: "Export Unassigned", description: "Whether or not to include unassigned tags in exported data", }, }, + cntk: { + displayName: "Microsoft Cognitive Toolkit (CNTK)", + }, }, messages: { saveSuccess: "Successfully saved export settings", diff --git a/src/common/localization/es-cl.ts b/src/common/localization/es-cl.ts index 8c0c9114ba..1fc212fa24 100644 --- a/src/common/localization/es-cl.ts +++ b/src/common/localization/es-cl.ts @@ -289,6 +289,10 @@ export const spanish: IAppStrings = { tagged: "Solo activos etiquetados", }, }, + testTrainSplit: { + title: "La división para entrenar y comprobar", + description: "La división de datos para utilizar entre el entrenamiento y la comprobación", + }, }, }, vottJson: { @@ -346,15 +350,14 @@ export const spanish: IAppStrings = { }, pascalVoc: { displayName: "Pascal VOC", - testTrainSplit: { - title: "Prueba/tren Split", - description: "La división del tren de prueba que se utilizará para los datos exportados", - }, exportUnassigned: { title: "Exportar sin asignar", description: "Si se incluyen o no etiquetas no asignadas en los datos exportados", }, }, + cntk: { + displayName: "Microsoft Cognitive Toolkit (CNTK)", + }, }, messages: { saveSuccess: "Configuración de exportación guardada correctamente", diff --git a/src/common/strings.ts b/src/common/strings.ts index 7a7a20991a..a3c5bf89e4 100644 --- a/src/common/strings.ts +++ b/src/common/strings.ts @@ -285,6 +285,10 @@ export interface IAppStrings { tagged: string, }, }, + testTrainSplit: { + title: string, + description: string, + }, }, }, vottJson: { @@ -342,15 +346,14 @@ export interface IAppStrings { }, pascalVoc: { displayName: string, - testTrainSplit: { - title: string, - description: string, - }, exportUnassigned: { title: string, description: string, }, }, + cntk: { + displayName: string, + }, }, messages: { saveSuccess: string; diff --git a/src/providers/export/azureCustomVision.test.ts b/src/providers/export/azureCustomVision.test.ts index 3a4f23bfac..36c1f21fa7 100644 --- a/src/providers/export/azureCustomVision.test.ts +++ b/src/providers/export/azureCustomVision.test.ts @@ -6,7 +6,7 @@ import { ExportProviderFactory } from "./exportProviderFactory"; import MockFactory from "../../common/mockFactory"; import { IProject, AssetState, IAsset, IAssetMetadata, - RegionType, IRegion, IExportProviderOptions, AssetType, + RegionType, IRegion, IExportProviderOptions, } from "../../models/applicationState"; import { ExportAssetState } from "./exportProvider"; jest.mock("./azureCustomVision/azureCustomVisionService"); diff --git a/src/providers/export/cntk.json b/src/providers/export/cntk.json new file mode 100644 index 0000000000..b0adc18fe7 --- /dev/null +++ b/src/providers/export/cntk.json @@ -0,0 +1,28 @@ +{ + "type": "object", + "title": "${strings.export.providers.cntk.displayName}", + "properties": { + "assetState": { + "type": "string", + "title": "${strings.export.providers.common.properties.assetState.title}", + "description": "${strings.export.providers.common.properties.assetState.description}", + "enum": [ + "all", + "visited", + "tagged" + ], + "default": "visited", + "enumNames": [ + "${strings.export.providers.common.properties.assetState.options.all}", + "${strings.export.providers.common.properties.assetState.options.visited}", + "${strings.export.providers.common.properties.assetState.options.tagged}" + ] + }, + "testTrainSplit": { + "title": "${strings.export.providers.common.properties.testTrainSplit.title}", + "description": "${strings.export.providers.common.properties.testTrainSplit.description}", + "type": "number", + "default": 80 + } + } +} diff --git a/src/providers/export/cntk.test.ts b/src/providers/export/cntk.test.ts new file mode 100644 index 0000000000..3d4d66c9aa --- /dev/null +++ b/src/providers/export/cntk.test.ts @@ -0,0 +1,167 @@ +import _ from "lodash"; +import os from "os"; +import { CntkExportProvider, ICntkExportProviderOptions } from "./cntk"; +import { IProject, AssetState, IAssetMetadata } from "../../models/applicationState"; +import { AssetProviderFactory } from "../storage/assetProviderFactory"; +import { ExportAssetState } from "./exportProvider"; +import MockFactory from "../../common/mockFactory"; +import registerMixins from "../../registerMixins"; +import registerProviders from "../../registerProviders"; +import { ExportProviderFactory } from "./exportProviderFactory"; +jest.mock("../../services/assetService"); +import { AssetService } from "../../services/assetService"; + +jest.mock("../storage/localFileSystemProxy"); +import { LocalFileSystemProxy } from "../storage/localFileSystemProxy"; +import HtmlFileReader from "../../common/htmlFileReader"; +import { appInfo } from "../../common/appInfo"; + +describe("CNTK Export Provider", () => { + const testAssets = MockFactory.createTestAssets(10, 1); + let testProject: IProject = null; + + const defaultOptions: ICntkExportProviderOptions = { + assetState: ExportAssetState.Tagged, + testTrainSplit: 80, + }; + + function createProvider(project: IProject): CntkExportProvider { + return new CntkExportProvider( + project, + project.exportFormat.providerOptions as ICntkExportProviderOptions, + ); + } + + beforeAll(() => { + registerMixins(); + registerProviders(); + + HtmlFileReader.getAssetBlob = jest.fn(() => { + return Promise.resolve(new Blob(["Some binary data"])); + }); + }); + + beforeEach(() => { + jest.resetAllMocks(); + + testAssets.forEach((asset) => { + asset.state = AssetState.Tagged; + }); + + testProject = { + ...MockFactory.createTestProject("TestProject"), + assets: _.keyBy(testAssets, (a) => a.id), + exportFormat: { + providerType: "cntk", + providerOptions: defaultOptions, + }, + }; + + AssetProviderFactory.create = jest.fn(() => { + return { + getAssets: jest.fn(() => Promise.resolve(testAssets)), + }; + }); + + const assetServiceMock = AssetService as jest.Mocked; + assetServiceMock.prototype.getAssetMetadata = jest.fn((asset) => { + const assetMetadata = { + asset: { ...asset }, + regions: [ + MockFactory.createTestRegion("region-1", ["tag1"]), + MockFactory.createTestRegion("region-2", ["tag1"]), + ], + version: appInfo.version, + }; + + return Promise.resolve(assetMetadata); + }); + }); + + it("Is defined", () => { + expect(CntkExportProvider).toBeDefined(); + }); + + it("Can be instantiated through the factory", () => { + const options: ICntkExportProviderOptions = { + assetState: ExportAssetState.All, + testTrainSplit: 80, + }; + const exportProvider = ExportProviderFactory.create("cntk", testProject, options); + expect(exportProvider).not.toBeNull(); + expect(exportProvider).toBeInstanceOf(CntkExportProvider); + }); + + it("Creates correct folder structure", async () => { + const provider = createProvider(testProject); + await provider.export(); + + const storageProviderMock = LocalFileSystemProxy as any; + const createContainerCalls = storageProviderMock.mock.instances[0].createContainer.mock.calls; + const createContainerArgs = createContainerCalls.map((args) => args[0]); + + const expectedFolderPath = "Project-TestProject-CNTK-export"; + expect(createContainerArgs).toContain(expectedFolderPath); + expect(createContainerArgs).toContain(`${expectedFolderPath}/positive`); + expect(createContainerArgs).toContain(`${expectedFolderPath}/negative`); + expect(createContainerArgs).toContain(`${expectedFolderPath}/testImages`); + }); + + it("Writes export files to storage provider", async () => { + const provider = createProvider(testProject); + const getAssetsSpy = jest.spyOn(provider, "getAssetsForExport"); + + await provider.export(); + + const assetsToExport = await getAssetsSpy.mock.results[0].value; + const testSplit = (100 - (defaultOptions.testTrainSplit || 80)) / 100; + const testCount = Math.ceil(assetsToExport.length * testSplit); + const testArray = assetsToExport.slice(0, testCount); + const trainArray = assetsToExport.slice(testCount, assetsToExport.length); + + const storageProviderMock = LocalFileSystemProxy as any; + const writeBinaryCalls = storageProviderMock.mock.instances[0].writeBinary.mock.calls; + const writeTextFileCalls = storageProviderMock.mock.instances[0].writeText.mock.calls; + + expect(writeBinaryCalls).toHaveLength(testAssets.length); + expect(writeTextFileCalls).toHaveLength(testAssets.length * 2); + + testArray.forEach((assetMetadata) => { + const testFolderPath = "Project-TestProject-CNTK-export/testImages"; + assertExportedAsset(testFolderPath, assetMetadata); + }); + + trainArray.forEach((assetMetadata) => { + const trainFolderPath = "Project-TestProject-CNTK-export/positive"; + assertExportedAsset(trainFolderPath, assetMetadata); + }); + }); + + function assertExportedAsset(folderPath: string, assetMetadata: IAssetMetadata) { + const storageProviderMock = LocalFileSystemProxy as any; + const writeBinaryCalls = storageProviderMock.mock.instances[0].writeBinary.mock.calls; + const writeBinaryFilenameArgs = writeBinaryCalls.map((args) => args[0]); + const writeTextFileCalls = storageProviderMock.mock.instances[0].writeText.mock.calls; + const writeTextFilenameArgs = writeTextFileCalls.map((args) => args[0]); + + expect(writeBinaryFilenameArgs).toContain(`${folderPath}/${assetMetadata.asset.name}`); + expect(writeTextFilenameArgs).toContain(`${folderPath}/${assetMetadata.asset.name}.bboxes.labels.tsv`); + expect(writeTextFilenameArgs).toContain(`${folderPath}/${assetMetadata.asset.name}.bboxes.tsv`); + + const writeLabelsCall = writeTextFileCalls + .find((args: string[]) => args[0].indexOf(`${assetMetadata.asset.name}.bboxes.labels.tsv`) >= 0); + + const writeBoxesCall = writeTextFileCalls + .find((args: string[]) => args[0].indexOf(`${assetMetadata.asset.name}.bboxes.tsv`) >= 0); + + const expectedLabelData = `${assetMetadata.regions[0].tags[0]}${os.EOL}${assetMetadata.regions[1].tags[0]}`; + expect(writeLabelsCall[1]).toEqual(expectedLabelData); + + const expectedBoxData = []; + // tslint:disable-next-line:max-line-length + expectedBoxData.push(`${assetMetadata.regions[0].boundingBox.left}\t${assetMetadata.regions[0].boundingBox.left + assetMetadata.regions[0].boundingBox.width}\t${assetMetadata.regions[0].boundingBox.top}\t${assetMetadata.regions[0].boundingBox.top + assetMetadata.regions[0].boundingBox.height}`); + // tslint:disable-next-line:max-line-length + expectedBoxData.push(`${assetMetadata.regions[1].boundingBox.left}\t${assetMetadata.regions[1].boundingBox.left + assetMetadata.regions[1].boundingBox.width}\t${assetMetadata.regions[1].boundingBox.top}\t${assetMetadata.regions[1].boundingBox.top + assetMetadata.regions[1].boundingBox.height}`); + expect(writeBoxesCall[1]).toEqual(expectedBoxData.join(os.EOL)); + } +}); diff --git a/src/providers/export/cntk.ts b/src/providers/export/cntk.ts new file mode 100644 index 0000000000..2d14864328 --- /dev/null +++ b/src/providers/export/cntk.ts @@ -0,0 +1,105 @@ +import os from "os"; +import { ExportProvider, IExportResults } from "./exportProvider"; +import { IAssetMetadata, IExportProviderOptions, IProject } from "../../models/applicationState"; +import HtmlFileReader from "../../common/htmlFileReader"; +import Guard from "../../common/guard"; + +enum ExportSplit { + Test, + Train, +} + +/** + * Export options for CNTK export provider + */ +export interface ICntkExportProviderOptions extends IExportProviderOptions { + /** The test / train split ratio for exporting data */ + testTrainSplit?: number; +} + +/** + * CNTK Export provider + */ +export class CntkExportProvider extends ExportProvider { + private exportFolderName: string; + + constructor(project: IProject, options: ICntkExportProviderOptions) { + super(project, options); + Guard.null(options); + + this.exportFolderName = `${this.project.name.replace(/\s/g, "-")}-CNTK-export`; + } + + public async export(): Promise { + await this.createFolderStructure(); + const assetsToExport = await this.getAssetsForExport(); + const testSplit = (100 - (this.options.testTrainSplit || 80)) / 100; + const testCount = Math.ceil(assetsToExport.length * testSplit); + const testArray = assetsToExport.slice(0, testCount); + + const results = await assetsToExport.mapAsync(async (assetMetadata) => { + try { + const exportSplit = testArray.find((am) => am.asset.id === assetMetadata.asset.id) + ? ExportSplit.Test + : ExportSplit.Train; + + await this.exportAssetFrame(assetMetadata, exportSplit); + return { + asset: assetMetadata, + success: true, + }; + } catch (e) { + return { + asset: assetMetadata, + success: false, + error: e, + }; + } + }); + + return { + completed: results.filter((r) => r.success), + errors: results.filter((r) => !r.success), + count: results.length, + }; + } + + private async exportAssetFrame(assetMetadata: IAssetMetadata, exportSplit: ExportSplit) { + const labelData = []; + const boundingBoxData = []; + + assetMetadata.regions.forEach((region) => { + region.tags.forEach((tagName) => { + labelData.push(tagName); + // tslint:disable-next-line:max-line-length + boundingBoxData.push(`${region.boundingBox.left}\t${region.boundingBox.left + region.boundingBox.width}\t${region.boundingBox.top}\t${region.boundingBox.top + region.boundingBox.height}`); + }); + }); + + const folderName = exportSplit === ExportSplit.Train ? "positive" : "testImages"; + const labelsPath = `${this.exportFolderName}/${folderName}/${assetMetadata.asset.name}.bboxes.labels.tsv`; + const boundingBoxPath = `${this.exportFolderName}/${folderName}/${assetMetadata.asset.name}.bboxes.tsv`; + const binaryPath = `${this.exportFolderName}/${folderName}/${assetMetadata.asset.name}`; + + const buffer = await HtmlFileReader.getAssetArray(assetMetadata.asset); + + await Promise.all([ + this.storageProvider.writeText(labelsPath, labelData.join(os.EOL)), + this.storageProvider.writeText(boundingBoxPath, boundingBoxData.join(os.EOL)), + this.storageProvider.writeBinary(binaryPath, Buffer.from(buffer)), + ]); + } + + private async createFolderStructure(): Promise { + const positiveFolder = `${this.exportFolderName}/positive`; + const negativeFolder = `${this.exportFolderName}/negative`; + const testImagesFolder = `${this.exportFolderName}/testImages`; + + await this.storageProvider.createContainer(this.exportFolderName); + + await [positiveFolder, negativeFolder, testImagesFolder] + .forEachAsync(async (folderPath) => { + await this.storageProvider.createContainer(folderPath); + }); + } +} diff --git a/src/providers/export/cntk.ui.json b/src/providers/export/cntk.ui.json new file mode 100644 index 0000000000..24dffa9463 --- /dev/null +++ b/src/providers/export/cntk.ui.json @@ -0,0 +1,5 @@ +{ + "testTrainSplit": { + "ui:widget": "slider" + } +} diff --git a/src/providers/export/pascalVOC.json b/src/providers/export/pascalVOC.json index 01bc1fa04d..0a5924ac89 100644 --- a/src/providers/export/pascalVOC.json +++ b/src/providers/export/pascalVOC.json @@ -19,8 +19,8 @@ ] }, "testTrainSplit": { - "title": "${strings.export.providers.pascalVoc.testTrainSplit.title}", - "description": "${strings.export.providers.pascalVoc.testTrainSplit.description}", + "title": "${strings.export.providers.common.properties.testTrainSplit.title}", + "description": "${strings.export.providers.common.properties.testTrainSplit.description}", "type": "number", "default": 80 }, diff --git a/src/providers/export/pascalVOC.ts b/src/providers/export/pascalVOC.ts index 97287608ce..26efff9ca0 100644 --- a/src/providers/export/pascalVOC.ts +++ b/src/providers/export/pascalVOC.ts @@ -1,11 +1,10 @@ import _ from "lodash"; import { ExportProvider } from "./exportProvider"; -import { IProject, IAssetMetadata, RegionType, ITag, IExportProviderOptions } from "../../models/applicationState"; +import { IProject, IAssetMetadata, ITag, IExportProviderOptions } from "../../models/applicationState"; import Guard from "../../common/guard"; import HtmlFileReader from "../../common/htmlFileReader"; import { itemTemplate, annotationTemplate, objectTemplate } from "./pascalVOC/pascalVOCTemplates"; import { interpolate } from "../../common/strings"; -import { PlatformType } from "../../common/hostProcess"; import os from "os"; interface IObjectInfo { @@ -53,7 +52,7 @@ export class PascalVOCExportProvider extends ExportProvider assetMetadata.asset.id); // Create Export Folder - const exportFolderName = `${this.project.name.replace(" ", "-")}-PascalVOC-export`; + const exportFolderName = `${this.project.name.replace(/\s/g, "-")}-PascalVOC-export`; await this.storageProvider.createContainer(exportFolderName); await this.exportImages(exportFolderName, allAssets); diff --git a/src/providers/export/tensorFlowRecords.ts b/src/providers/export/tensorFlowRecords.ts index f6c7defdaf..56dd5eacb9 100644 --- a/src/providers/export/tensorFlowRecords.ts +++ b/src/providers/export/tensorFlowRecords.ts @@ -41,7 +41,7 @@ export class TFRecordsExportProvider extends ExportProvider { exportObject.assets = _.keyBy(allAssets, (assetMetadata) => assetMetadata.asset.id); // Create Export Folder - const exportFolderName = `${this.project.name.replace(" ", "-")}-TFRecords-export`; + const exportFolderName = `${this.project.name.replace(/\s/g, "-")}-TFRecords-export`; await this.storageProvider.createContainer(exportFolderName); await this.exportPBTXT(exportFolderName, this.project); diff --git a/src/providers/export/vottJson.test.ts b/src/providers/export/vottJson.test.ts index da23e3e020..bd556c2e9c 100644 --- a/src/providers/export/vottJson.test.ts +++ b/src/providers/export/vottJson.test.ts @@ -3,7 +3,7 @@ import { VottJsonExportProvider, IVottJsonExportProviderOptions } from "./vottJs import registerProviders from "../../registerProviders"; import { ExportAssetState } from "./exportProvider"; import { ExportProviderFactory } from "./exportProviderFactory"; -import { IProject, IAssetMetadata, AssetState, IExportProviderOptions } from "../../models/applicationState"; +import { IProject, IAssetMetadata, AssetState } from "../../models/applicationState"; import MockFactory from "../../common/mockFactory"; jest.mock("../../services/assetService"); diff --git a/src/redux/actions/projectActions.test.ts b/src/redux/actions/projectActions.test.ts index ddf7b12eaa..8a8197ed37 100644 --- a/src/redux/actions/projectActions.test.ts +++ b/src/redux/actions/projectActions.test.ts @@ -96,6 +96,7 @@ describe("Project Redux Actions", () => { providerType: "vottJson", providerOptions: { assetState: ExportAssetState.Visited, + includeImages: true, }, }); }); diff --git a/src/redux/actions/projectActions.ts b/src/redux/actions/projectActions.ts index 99b465f04d..0b9b1b3670 100644 --- a/src/redux/actions/projectActions.ts +++ b/src/redux/actions/projectActions.ts @@ -15,6 +15,8 @@ import { createAction, createPayloadAction, IPayloadAction } from "./actionCreat import { ExportAssetState, IExportResults } from "../../providers/export/exportProvider"; import { appInfo } from "../../common/appInfo"; import { strings } from "../../common/strings"; +import { IExportFormat } from "vott-react"; +import { IVottJsonExportProviderOptions } from "../../providers/export/vottJson"; /** * Actions to be performed in relation to projects @@ -76,11 +78,14 @@ export function saveProject(project: IProject) throw new AppError(ErrorCode.SecurityTokenNotFound, "Security Token Not Found"); } - const defaultExportFormat = { + const defaultExportProviderOptions: IVottJsonExportProviderOptions = { + assetState: ExportAssetState.Visited, + includeImages: true, + }; + + const defaultExportFormat: IExportFormat = { providerType: "vottJson", - providerOptions: { - assetState: ExportAssetState.Visited, - }, + providerOptions: defaultExportProviderOptions, }; const newProject = { @@ -130,7 +135,7 @@ export function deleteProject(project: IProject) */ export function closeProject(): (dispatch: Dispatch) => void { return (dispatch: Dispatch): void => { - dispatch({type: ActionTypes.CLOSE_PROJECT_SUCCESS}); + dispatch({ type: ActionTypes.CLOSE_PROJECT_SUCCESS }); }; } @@ -159,7 +164,7 @@ export function loadAssetMetadata(project: IProject, asset: IAsset): (dispatch: const assetMetadata = await assetService.getAssetMetadata(asset); dispatch(loadAssetMetadataAction(assetMetadata)); - return {...assetMetadata}; + return { ...assetMetadata }; }; } @@ -171,14 +176,14 @@ export function loadAssetMetadata(project: IProject, asset: IAsset): (dispatch: export function saveAssetMetadata( project: IProject, assetMetadata: IAssetMetadata): (dispatch: Dispatch) => Promise { - const newAssetMetadata = {...assetMetadata, version: appInfo.version}; + const newAssetMetadata = { ...assetMetadata, version: appInfo.version }; return async (dispatch: Dispatch) => { const assetService = new AssetService(project); const savedMetadata = await assetService.save(newAssetMetadata); dispatch(saveAssetMetadataAction(savedMetadata)); - return {...savedMetadata}; + return { ...savedMetadata }; }; } diff --git a/src/registerProviders.ts b/src/registerProviders.ts index 0f0cc4baa9..6b2f6f9378 100644 --- a/src/registerProviders.ts +++ b/src/registerProviders.ts @@ -11,6 +11,7 @@ import registerToolbar from "./registerToolbar"; import { strings } from "./common/strings"; import { HostProcessType } from "./common/hostProcess"; import { AzureCustomVisionProvider } from "./providers/export/azureCustomVision"; +import { CntkExportProvider } from "./providers/export/cntk"; /** * Registers storage, asset and export providers, as well as all toolbar items @@ -68,6 +69,11 @@ export default function registerProviders() { displayName: strings.export.providers.azureCV.displayName, factory: (project, options) => new AzureCustomVisionProvider(project, options), }); + ExportProviderFactory.register({ + name: "cntk", + displayName: strings.export.providers.cntk.displayName, + factory: (project, options) => new CntkExportProvider(project, options), + }); registerToolbar(); }