Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Add integration tests for trained_models API #104819

Merged
merged 14 commits into from
Jul 9, 2021
1 change: 1 addition & 0 deletions x-pack/test/api_integration/apis/ml/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,5 +82,6 @@ export default function ({ getService, loadTestFile }: FtrProviderContext) {
loadTestFile(require.resolve('./results'));
loadTestFile(require.resolve('./saved_objects'));
loadTestFile(require.resolve('./system'));
loadTestFile(require.resolve('./trained_models'));
});
}
66 changes: 66 additions & 0 deletions x-pack/test/api_integration/apis/ml/trained_models/get_models.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import expect from '@kbn/expect';
import { FtrProviderContext } from '../../../ftr_provider_context';
import { USER } from '../../../../functional/services/ml/security_common';
import { COMMON_REQUEST_HEADERS } from '../../../../functional/services/ml/common_api';

export default ({ getService }: FtrProviderContext) => {
const supertest = getService('supertestWithoutAuth');
const ml = getService('ml');

describe('GET trained_models', () => {
let testModelIds: string[] = [];

before(async () => {
await ml.testResources.setKibanaTimeZoneToUTC();
testModelIds = await ml.api.createdTestTrainedModels('regression', 5, true);
});

after(async () => {
await ml.api.cleanMlIndices();
// delete created ingest pipelines
await Promise.all(testModelIds.map((modelId) => ml.api.deleteIngestPipeline(modelId)));
});

it('returns all trained models with associated pipelined', async () => {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo - should just be pipelines

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed in e13787b

const { body } = await supertest
.get(`/api/ml/trained_models?with_pipelines=true`)
.auth(USER.ML_POWERUSER, ml.securityCommon.getPasswordForUser(USER.ML_POWERUSER))
.set(COMMON_REQUEST_HEADERS)
.expect(200);
// Created models + system model
expect(body.length).to.eql(6);

const sampleModel = body.find((v: any) => v.model_id === 'dfa_regression_model_n_0');
expect(Object.keys(sampleModel.pipelines).length).to.eql(1);
});

it('returns models without pipeline in case user does not have required permission', async () => {
const { body } = await supertest
.get(`/api/ml/trained_models?with_pipelines=true`)
.auth(USER.ML_VIEWER, ml.securityCommon.getPasswordForUser(USER.ML_VIEWER))
.set(COMMON_REQUEST_HEADERS)
.expect(200);
// Created models + system model
expect(body.length).to.eql(6);
const sampleModel = body.find((v: any) => v.model_id === 'dfa_regression_model_n_0');
expect(sampleModel.pipelines).to.eql(undefined);
});

it('returns trained model by id', async () => {
peteharverson marked this conversation as resolved.
Show resolved Hide resolved
const { body } = await supertest
.get(`/api/ml/trained_models/dfa_regression_model_n_1`)
.auth(USER.ML_VIEWER, ml.securityCommon.getPasswordForUser(USER.ML_VIEWER))
.set(COMMON_REQUEST_HEADERS)
.expect(200);
expect(body.length).to.eql(1);
expect(body[0].model_id).to.eql('dfa_regression_model_n_1');
});
});
};
14 changes: 14 additions & 0 deletions x-pack/test/api_integration/apis/ml/trained_models/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { FtrProviderContext } from '../../../ftr_provider_context';

export default function ({ loadTestFile }: FtrProviderContext) {
describe('trained models', function () {
loadTestFile(require.resolve('./get_models'));
});
}
84 changes: 84 additions & 0 deletions x-pack/test/functional/services/ml/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import { estypes } from '@elastic/elasticsearch';
import expect from '@kbn/expect';
import { ProvidedType } from '@kbn/test';
import fs from 'fs';
import path from 'path';
import { Calendar } from '../../../../plugins/ml/server/models/calendar/index';
import { Annotation } from '../../../../plugins/ml/common/types/annotations';
import { DataFrameAnalyticsConfig } from '../../../../plugins/ml/public/application/data_frame_analytics/common';
Expand All @@ -25,6 +27,8 @@ import {
import { COMMON_REQUEST_HEADERS } from '../../../functional/services/ml/common_api';
import { PutTrainedModelConfig } from '../../../../plugins/ml/common/types/trained_models';

type ModelType = 'regression' | 'classification';

export function MachineLearningAPIProvider({ getService }: FtrProviderContext) {
const es = getService('es');
const log = getService('log');
Expand Down Expand Up @@ -943,5 +947,85 @@ export function MachineLearningAPIProvider({ getService }: FtrProviderContext) {
log.debug('> Trained model crated');
return model;
},

async createdTestTrainedModels(
modelType: ModelType,
count: number = 10,
withIngestPipelines = false
) {
const compressedDefinition = this.getCompressedModelDefinition(modelType);

const modelIds = new Array(count).fill(null).map((v, i) => `dfa_${modelType}_model_n_${i}`);

const models = modelIds.map((id) => {
return {
model_id: id,
body: {
compressed_definition: compressedDefinition,
inference_config: {
[modelType]: {},
},
input: {
field_names: ['common_field'],
},
} as PutTrainedModelConfig,
};
});

for (const model of models) {
await this.createTrainedModel(model.model_id, model.body);
if (withIngestPipelines) {
await this.createIngestPipeline(model.model_id);
}
}

return modelIds;
},

/**
* Retrieves compressed model definition from the test resources.
* @param modelType
*/
getCompressedModelDefinition(modelType: ModelType) {
return fs.readFileSync(
path.resolve(
__dirname,
'resources',
'trained_model_definitions',
`minimum_valid_config_${modelType}.json.gz.b64`
),
'utf-8'
);
},

/**
* Creates ingest pipelines for trained model
* @param modelId
*/
async createIngestPipeline(modelId: string) {
log.debug(`Creating ingest pipeline for trained model with id "${modelId}"`);
const ingestPipeline = await esSupertest
.put(`/_ingest/pipeline/pipeline_${modelId}`)
.send({
processors: [
{
inference: {
model_id: modelId,
},
},
],
})
.expect(200)
.then((res) => res.body);

log.debug('> Ingest pipeline crated');
return ingestPipeline;
},

async deleteIngestPipeline(modelId: string) {
log.debug(`Deleting ingest pipeline for trained model with id "${modelId}"`);
await esSupertest.delete(`/_ingest/pipeline/pipeline_${modelId}`).expect(200);
log.debug('> Ingest pipeline deleted');
},
};
}
36 changes: 1 addition & 35 deletions x-pack/test/functional/services/ml/trained_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,9 @@
* 2.0.
*/

import fs from 'fs';
import path from 'path';
import expect from '@kbn/expect';
import { FtrProviderContext } from '../../ftr_provider_context';
import { MlApi } from './api';
import { PutTrainedModelConfig } from '../../../../plugins/ml/common/types/trained_models';
import { MlCommonUI } from './common_ui';

type ModelType = 'regression' | 'classification';
Expand All @@ -24,38 +21,7 @@ export function TrainedModelsProvider(

return {
async createdTestTrainedModels(modelType: ModelType, count: number = 10) {
const compressedDefinition = this.getCompressedModelDefinition(modelType);

const models = new Array(count).fill(null).map((v, i) => {
return {
model_id: `dfa_${modelType}_model_n_${i}`,
body: {
compressed_definition: compressedDefinition,
inference_config: {
[modelType]: {},
},
input: {
field_names: ['common_field'],
},
} as PutTrainedModelConfig,
};
});

for (const model of models) {
await mlApi.createTrainedModel(model.model_id, model.body);
}
},

getCompressedModelDefinition(modelType: ModelType) {
return fs.readFileSync(
path.resolve(
__dirname,
'resources',
'trained_model_definitions',
`minimum_valid_config_${modelType}.json.gz.b64`
),
'utf-8'
);
await mlApi.createdTestTrainedModels(modelType, count);
},

async assertStats(expectedTotalCount: number) {
Expand Down