Skip to content

Commit

Permalink
[ML] Cardinality validation API integration tests (elastic#65971)
Browse files Browse the repository at this point in the history
* [ML] refactor validate_cardinality to TS

* [ML] cardinality api integration tests

* [ML] resolve PR comments, validateJobObject as TS guard
  • Loading branch information
darnautov committed May 11, 2020
1 parent a86ea8a commit c204104
Show file tree
Hide file tree
Showing 6 changed files with 223 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,8 @@ export class DataVisualizer {
aggregatableFields: string[],
samplerShardSize: number,
timeFieldName: string,
earliestMs: number,
latestMs: number
earliestMs?: number,
latestMs?: number
) {
const index = indexPatternTitle;
const size = 0;
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,22 @@
* you may not use this file except in compliance with the Elastic License.
*/

import _ from 'lodash';

import { APICaller } from 'kibana/server';
import { DataVisualizer } from '../data_visualizer';

import { validateJobObject } from './validate_job_object';
import { CombinedJob } from '../../../common/types/anomaly_detection_jobs';
import { Detector } from '../../../common/types/anomaly_detection_jobs';

function isValidCategorizationConfig(job, fieldName) {
function isValidCategorizationConfig(job: CombinedJob, fieldName: string): boolean {
return (
typeof job.analysis_config.categorization_field_name !== 'undefined' &&
fieldName === 'mlcategory'
);
}

function isScriptField(job, fieldName) {
const scriptFields = Object.keys(_.get(job, 'datafeed_config.script_fields', {}));
function isScriptField(job: CombinedJob, fieldName: string): boolean {
const scriptFields = Object.keys(job.datafeed_config.script_fields ?? {});
return scriptFields.includes(fieldName);
}

Expand All @@ -30,10 +31,21 @@ const PARTITION_FIELD_CARDINALITY_THRESHOLD = 1000;
const BY_FIELD_CARDINALITY_THRESHOLD = 1000;
const MODEL_PLOT_THRESHOLD_HIGH = 100;

const validateFactory = (callWithRequest, job) => {
type Messages = Array<{ id: string; fieldName?: string }>;

type Validator = (obj: {
type: string;
isInvalid: (cardinality: number) => boolean;
messageId?: string;
}) => Promise<{
modelPlotCardinality: number;
messages: Messages;
}>;

const validateFactory = (callWithRequest: APICaller, job: CombinedJob): Validator => {
const dv = new DataVisualizer(callWithRequest);

const modelPlotConfigTerms = _.get(job, ['model_plot_config', 'terms'], '');
const modelPlotConfigTerms = job?.model_plot_config?.terms ?? '';
const modelPlotConfigFieldCount =
modelPlotConfigTerms.length > 0 ? modelPlotConfigTerms.split(',').length : 0;

Expand All @@ -42,8 +54,11 @@ const validateFactory = (callWithRequest, job) => {
// if model_plot_config.terms is used, it doesn't count the real cardinality of the field
// but adds only the count of fields used in model_plot_config.terms
let modelPlotCardinality = 0;
const messages = [];
const fieldName = `${type}_field_name`;
const messages: Messages = [];
const fieldName = `${type}_field_name` as keyof Pick<
Detector,
'by_field_name' | 'over_field_name' | 'partition_field_name'
>;

const detectors = job.analysis_config.detectors;
const relevantDetectors = detectors.filter(detector => {
Expand All @@ -52,15 +67,15 @@ const validateFactory = (callWithRequest, job) => {

if (relevantDetectors.length > 0) {
try {
const uniqueFieldNames = _.uniq(relevantDetectors.map(f => f[fieldName]));
const uniqueFieldNames = [...new Set(relevantDetectors.map(f => f[fieldName]))] as string[];

// use fieldCaps endpoint to get data about whether fields are aggregatable
const fieldCaps = await callWithRequest('fieldCaps', {
index: job.datafeed_config.indices.join(','),
fields: uniqueFieldNames,
});

let aggregatableFieldNames = [];
let aggregatableFieldNames: string[] = [];
// parse fieldCaps to return an array of just the fields which are aggregatable
if (typeof fieldCaps === 'object' && typeof fieldCaps.fields === 'object') {
aggregatableFieldNames = uniqueFieldNames.filter(field => {
Expand All @@ -81,12 +96,14 @@ const validateFactory = (callWithRequest, job) => {
);

uniqueFieldNames.forEach(uniqueFieldName => {
const field = _.find(stats.aggregatableExistsFields, { fieldName: uniqueFieldName });
if (typeof field === 'object') {
const field = stats.aggregatableExistsFields.find(
fieldData => fieldData.fieldName === uniqueFieldName
);
if (field !== undefined && typeof field === 'object' && field.stats) {
modelPlotCardinality +=
modelPlotConfigFieldCount > 0 ? modelPlotConfigFieldCount : field.stats.cardinality;
modelPlotConfigFieldCount > 0 ? modelPlotConfigFieldCount : field.stats.cardinality!;

if (isInvalid(field.stats.cardinality)) {
if (isInvalid(field.stats.cardinality!)) {
messages.push({
id: messageId || `cardinality_${type}_field`,
fieldName: uniqueFieldName,
Expand Down Expand Up @@ -115,7 +132,7 @@ const validateFactory = (callWithRequest, job) => {
if (relevantDetectors.length === 1) {
messages.push({
id: 'field_not_aggregatable',
fieldName: relevantDetectors[0][fieldName],
fieldName: relevantDetectors[0][fieldName]!,
});
} else {
messages.push({ id: 'fields_not_aggregatable' });
Expand All @@ -129,25 +146,30 @@ const validateFactory = (callWithRequest, job) => {
};
};

export async function validateCardinality(callWithRequest, job) {
export async function validateCardinality(
callWithRequest: APICaller,
job?: CombinedJob
): Promise<Array<{ id: string; modelPlotCardinality?: number; fieldName?: string }>> | never {
const messages = [];

validateJobObject(job);
if (!validateJobObject(job)) {
// required for TS type casting, validateJobObject throws an error internally.
throw new Error();
}

// find out if there are any relevant detector field names
// where cardinality checks could be run against.
const numDetectorsWithFieldNames = job.analysis_config.detectors.filter(d => {
return d.by_field_name || d.over_field_name || d.partition_field_name;
});
if (numDetectorsWithFieldNames.length === 0) {
return Promise.resolve([]);
return [];
}

// validate({ type, isInvalid }) asynchronously returns an array of validation messages
const validate = validateFactory(callWithRequest, job);

const modelPlotEnabled =
(job.model_plot_config && job.model_plot_config.enabled === true) || false;
const modelPlotEnabled = job.model_plot_config?.enabled ?? false;

// check over fields (population analysis)
const validateOverFieldsLow = validate({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import { i18n } from '@kbn/i18n';
import { CombinedJob } from '../../../common/types/anomaly_detection_jobs';

export function validateJobObject(job: CombinedJob | null) {
export function validateJobObject(job: CombinedJob | null | undefined): job is CombinedJob | never {
if (job === null || typeof job !== 'object') {
throw new Error(
i18n.translate('xpack.ml.models.jobValidation.validateJobObject.jobIsNotObjectErrorMessage', {
Expand Down Expand Up @@ -93,4 +93,5 @@ export function validateJobObject(job: CombinedJob | null) {
)
);
}
return true;
}
175 changes: 175 additions & 0 deletions x-pack/test/api_integration/apis/ml/job_validation/cardinality.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
import expect from '@kbn/expect';
import { FtrProviderContext } from '../../../ftr_provider_context';
import { USER } from '../../../../functional/services/machine_learning/security_common';

const COMMON_HEADERS = {
'kbn-xsrf': 'some-xsrf-token',
};

// eslint-disable-next-line import/no-default-export
export default ({ getService }: FtrProviderContext) => {
const esArchiver = getService('esArchiver');
const supertest = getService('supertestWithoutAuth');
const ml = getService('ml');

describe('ValidateCardinality', function() {
before(async () => {
await esArchiver.loadIfNeeded('ml/ecommerce');
await ml.testResources.setKibanaTimeZoneToUTC();
});

after(async () => {
await ml.api.cleanMlIndices();
});

it(`should recognize a valid cardinality`, async () => {
const requestBody = {
job_id: '',
description: '',
groups: [],
analysis_config: {
bucket_span: '10m',
detectors: [
{
function: 'mean',
field_name: 'products.base_price',
partition_field_name: 'geoip.city_name',
},
],
influencers: ['geoip.city_name'],
},
data_description: { time_field: 'order_date' },
analysis_limits: { model_memory_limit: '12MB' },
model_plot_config: { enabled: true },
datafeed_config: {
datafeed_id: 'datafeed-',
job_id: '',
indices: ['ft_ecommerce'],
query: { bool: { must: [{ match_all: {} }], filter: [], must_not: [] } },
},
};

const { body } = await supertest
.post('/api/ml/validate/cardinality')
.auth(USER.ML_POWERUSER, ml.securityCommon.getPasswordForUser(USER.ML_POWERUSER))
.set(COMMON_HEADERS)
.send(requestBody)
.expect(200);

expect(body).to.eql([{ id: 'success_cardinality' }]);
});

it(`should recognize a high model plot cardinality`, async () => {
const requestBody = {
job_id: '',
description: '',
groups: [],
analysis_config: {
bucket_span: '10m',
detectors: [
{
function: 'mean',
field_name: 'products.base_price',
// some high cardinality field
partition_field_name: 'order_id',
},
],
influencers: ['geoip.city_name'],
},
data_description: { time_field: 'order_date' },
analysis_limits: { model_memory_limit: '11MB' },
model_plot_config: { enabled: true },
datafeed_config: {
datafeed_id: 'datafeed-',
job_id: '',
indices: ['ft_ecommerce'],
query: { bool: { must: [{ match_all: {} }], filter: [], must_not: [] } },
},
};
const { body } = await supertest
.post('/api/ml/validate/cardinality')
.auth(USER.ML_POWERUSER, ml.securityCommon.getPasswordForUser(USER.ML_POWERUSER))
.set(COMMON_HEADERS)
.send(requestBody)
.expect(200);

expect(body).to.eql([
{ id: 'cardinality_model_plot_high', modelPlotCardinality: 4711 },
{ id: 'cardinality_partition_field', fieldName: 'order_id' },
]);
});

it('should not validate cardinality in case request payload is invalid', async () => {
const requestBody = {
job_id: '',
description: '',
groups: [],
// missing analysis_config
data_description: { time_field: 'order_date' },
analysis_limits: { model_memory_limit: '12MB' },
model_plot_config: { enabled: true },
datafeed_config: {
datafeed_id: 'datafeed-',
job_id: '',
indices: ['ft_ecommerce'],
query: { bool: { must: [{ match_all: {} }], filter: [], must_not: [] } },
},
};

const { body } = await supertest
.post('/api/ml/validate/cardinality')
.auth(USER.ML_POWERUSER, ml.securityCommon.getPasswordForUser(USER.ML_POWERUSER))
.set(COMMON_HEADERS)
.send(requestBody)
.expect(400);

expect(body.error).to.eql('Bad Request');
expect(body.message).to.eql(
'[request body.analysis_config.detectors]: expected value of type [array] but got [undefined]'
);
});

it('should not validate cardinality if the user does not have required permissions', async () => {
const requestBody = {
job_id: '',
description: '',
groups: [],
analysis_config: {
bucket_span: '10m',
detectors: [
{
function: 'mean',
field_name: 'products.base_price',
partition_field_name: 'geoip.city_name',
},
],
influencers: ['geoip.city_name'],
},
data_description: { time_field: 'order_date' },
analysis_limits: { model_memory_limit: '12MB' },
model_plot_config: { enabled: true },
datafeed_config: {
datafeed_id: 'datafeed-',
job_id: '',
indices: ['ft_ecommerce'],
query: { bool: { must: [{ match_all: {} }], filter: [], must_not: [] } },
},
};

const { body } = await supertest
.post('/api/ml/validate/cardinality')
.auth(USER.ML_VIEWER, ml.securityCommon.getPasswordForUser(USER.ML_VIEWER))
.set(COMMON_HEADERS)
.send(requestBody)
.expect(404);

expect(body.error).to.eql('Not Found');
expect(body.message).to.eql('Not Found');
});
});
};
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ export default function({ loadTestFile }: FtrProviderContext) {
describe('job validation', function() {
loadTestFile(require.resolve('./bucket_span_estimator'));
loadTestFile(require.resolve('./calculate_model_memory_limit'));
loadTestFile(require.resolve('./cardinality'));
});
}

0 comments on commit c204104

Please sign in to comment.