Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Cardinality validation API integration tests #65971

Merged
merged 3 commits into from
May 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,8 @@ export class DataVisualizer {
aggregatableFields: string[],
samplerShardSize: number,
timeFieldName: string,
earliestMs: number,
latestMs: number
earliestMs?: number,
latestMs?: number
) {
const index = indexPatternTitle;
const size = 0;
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,22 @@
* you may not use this file except in compliance with the Elastic License.
*/

import _ from 'lodash';

import { APICaller } from 'kibana/server';
import { DataVisualizer } from '../data_visualizer';

import { validateJobObject } from './validate_job_object';
import { CombinedJob } from '../../../common/types/anomaly_detection_jobs';
import { Detector } from '../../../common/types/anomaly_detection_jobs';

function isValidCategorizationConfig(job, fieldName) {
/**
 * Tells whether `fieldName` refers to the `mlcategory` field of a job whose
 * analysis config defines a categorization field name.
 *
 * @param job - The job configuration to inspect.
 * @param fieldName - The detector field name to check.
 * @returns `true` only if the job has `analysis_config.categorization_field_name`
 *          defined and `fieldName` is exactly `'mlcategory'`.
 */
function isValidCategorizationConfig(job: CombinedJob, fieldName: string): boolean {
  const hasCategorizationField = job.analysis_config.categorization_field_name !== undefined;
  return hasCategorizationField && fieldName === 'mlcategory';
}

function isScriptField(job, fieldName) {
const scriptFields = Object.keys(_.get(job, 'datafeed_config.script_fields', {}));
/**
 * Tells whether `fieldName` is one of the script fields declared in the
 * job's datafeed configuration.
 *
 * @param job - The job configuration to inspect.
 * @param fieldName - The field name to look up.
 * @returns `true` if `datafeed_config.script_fields` has a key equal to
 *          `fieldName`; `false` otherwise, including when no script fields exist.
 */
function isScriptField(job: CombinedJob, fieldName: string): boolean {
  // A missing (null/undefined) script_fields entry is treated as "no script fields".
  const declaredScriptFields = job.datafeed_config.script_fields ?? {};
  return Object.keys(declaredScriptFields).some(name => name === fieldName);
}

Expand All @@ -30,10 +31,21 @@ const PARTITION_FIELD_CARDINALITY_THRESHOLD = 1000;
const BY_FIELD_CARDINALITY_THRESHOLD = 1000;
const MODEL_PLOT_THRESHOLD_HIGH = 100;

const validateFactory = (callWithRequest, job) => {
// List of validation messages. Each entry carries a message `id` and,
// when the message concerns one specific field, that field's name.
type Messages = Array<{ id: string; fieldName?: string }>;

/**
 * Signature of the function returned by `validateFactory`.
 *
 * Input object:
 * - `type`: prefix used to build the detector property name `${type}_field_name`.
 * - `isInvalid`: predicate called with a field's cardinality; returning `true`
 *   causes a message to be recorded for that field.
 * - `messageId`: optional message id used instead of the default
 *   `cardinality_${type}_field`.
 *
 * Resolves with the accumulated model plot cardinality and the collected messages.
 */
type Validator = (obj: {
type: string;
isInvalid: (cardinality: number) => boolean;
messageId?: string;
}) => Promise<{
modelPlotCardinality: number;
messages: Messages;
}>;

const validateFactory = (callWithRequest: APICaller, job: CombinedJob): Validator => {
const dv = new DataVisualizer(callWithRequest);

const modelPlotConfigTerms = _.get(job, ['model_plot_config', 'terms'], '');
const modelPlotConfigTerms = job?.model_plot_config?.terms ?? '';
const modelPlotConfigFieldCount =
modelPlotConfigTerms.length > 0 ? modelPlotConfigTerms.split(',').length : 0;

Expand All @@ -42,8 +54,11 @@ const validateFactory = (callWithRequest, job) => {
// if model_plot_config.terms is used, it doesn't count the real cardinality of the field
// but adds only the count of fields used in model_plot_config.terms
let modelPlotCardinality = 0;
const messages = [];
const fieldName = `${type}_field_name`;
const messages: Messages = [];
const fieldName = `${type}_field_name` as keyof Pick<
Detector,
'by_field_name' | 'over_field_name' | 'partition_field_name'
>;

const detectors = job.analysis_config.detectors;
const relevantDetectors = detectors.filter(detector => {
Expand All @@ -52,15 +67,15 @@ const validateFactory = (callWithRequest, job) => {

if (relevantDetectors.length > 0) {
try {
const uniqueFieldNames = _.uniq(relevantDetectors.map(f => f[fieldName]));
const uniqueFieldNames = [...new Set(relevantDetectors.map(f => f[fieldName]))] as string[];

// use fieldCaps endpoint to get data about whether fields are aggregatable
const fieldCaps = await callWithRequest('fieldCaps', {
index: job.datafeed_config.indices.join(','),
fields: uniqueFieldNames,
});

let aggregatableFieldNames = [];
let aggregatableFieldNames: string[] = [];
// parse fieldCaps to return an array of just the fields which are aggregatable
if (typeof fieldCaps === 'object' && typeof fieldCaps.fields === 'object') {
aggregatableFieldNames = uniqueFieldNames.filter(field => {
Expand All @@ -81,12 +96,14 @@ const validateFactory = (callWithRequest, job) => {
);

uniqueFieldNames.forEach(uniqueFieldName => {
const field = _.find(stats.aggregatableExistsFields, { fieldName: uniqueFieldName });
if (typeof field === 'object') {
const field = stats.aggregatableExistsFields.find(
fieldData => fieldData.fieldName === uniqueFieldName
);
if (field !== undefined && typeof field === 'object' && field.stats) {
modelPlotCardinality +=
modelPlotConfigFieldCount > 0 ? modelPlotConfigFieldCount : field.stats.cardinality;
modelPlotConfigFieldCount > 0 ? modelPlotConfigFieldCount : field.stats.cardinality!;

if (isInvalid(field.stats.cardinality)) {
if (isInvalid(field.stats.cardinality!)) {
messages.push({
id: messageId || `cardinality_${type}_field`,
fieldName: uniqueFieldName,
Expand Down Expand Up @@ -115,7 +132,7 @@ const validateFactory = (callWithRequest, job) => {
if (relevantDetectors.length === 1) {
messages.push({
id: 'field_not_aggregatable',
fieldName: relevantDetectors[0][fieldName],
fieldName: relevantDetectors[0][fieldName]!,
});
} else {
messages.push({ id: 'fields_not_aggregatable' });
Expand All @@ -129,25 +146,30 @@ const validateFactory = (callWithRequest, job) => {
};
};

export async function validateCardinality(callWithRequest, job) {
export async function validateCardinality(
callWithRequest: APICaller,
job?: CombinedJob
): Promise<Array<{ id: string; modelPlotCardinality?: number; fieldName?: string }>> | never {
const messages = [];

validateJobObject(job);
if (!validateJobObject(job)) {
// required for TS type casting, validateJobObject throws an error internally.
throw new Error();
}

// find out if there are any relevant detector field names
// where cardinality checks could be run against.
const numDetectorsWithFieldNames = job.analysis_config.detectors.filter(d => {
return d.by_field_name || d.over_field_name || d.partition_field_name;
});
if (numDetectorsWithFieldNames.length === 0) {
return Promise.resolve([]);
return [];
}

// validate({ type, isInvalid, messageId? }) asynchronously resolves to an object holding
// the model plot cardinality and an array of validation messages
const validate = validateFactory(callWithRequest, job);

const modelPlotEnabled =
(job.model_plot_config && job.model_plot_config.enabled === true) || false;
const modelPlotEnabled = job.model_plot_config?.enabled ?? false;

// check over fields (population analysis)
const validateOverFieldsLow = validate({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import { i18n } from '@kbn/i18n';
import { CombinedJob } from '../../../common/types/anomaly_detection_jobs';

export function validateJobObject(job: CombinedJob | null) {
export function validateJobObject(job: CombinedJob | null | undefined): job is CombinedJob | never {
if (job === null || typeof job !== 'object') {
throw new Error(
i18n.translate('xpack.ml.models.jobValidation.validateJobObject.jobIsNotObjectErrorMessage', {
Expand Down Expand Up @@ -93,4 +93,5 @@ export function validateJobObject(job: CombinedJob | null) {
)
);
}
return true;
}
175 changes: 175 additions & 0 deletions x-pack/test/api_integration/apis/ml/job_validation/cardinality.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
import expect from '@kbn/expect';
import { FtrProviderContext } from '../../../ftr_provider_context';
import { USER } from '../../../../functional/services/machine_learning/security_common';

// Headers attached to every request issued by this test suite.
const COMMON_HEADERS = { 'kbn-xsrf': 'some-xsrf-token' };

// eslint-disable-next-line import/no-default-export
export default ({ getService }: FtrProviderContext) => {
  const esArchiver = getService('esArchiver');
  const supertest = getService('supertestWithoutAuth');
  const ml = getService('ml');

  // Users this suite authenticates as.
  type SuiteUser = typeof USER.ML_POWERUSER | typeof USER.ML_VIEWER;

  /**
   * Builds the job payload sent to the cardinality validation endpoint.
   *
   * @param modelMemoryLimit - Value for `analysis_limits.model_memory_limit`.
   * @param partitionFieldName - Partition field of the single `mean` detector.
   *   When omitted, the payload is built without any `analysis_config`.
   */
  const buildJobConfig = (modelMemoryLimit: string, partitionFieldName?: string) => {
    const analysisConfig =
      partitionFieldName === undefined
        ? {}
        : {
            analysis_config: {
              bucket_span: '10m',
              detectors: [
                {
                  function: 'mean',
                  field_name: 'products.base_price',
                  partition_field_name: partitionFieldName,
                },
              ],
              influencers: ['geoip.city_name'],
            },
          };

    return {
      job_id: '',
      description: '',
      groups: [],
      ...analysisConfig,
      data_description: { time_field: 'order_date' },
      analysis_limits: { model_memory_limit: modelMemoryLimit },
      model_plot_config: { enabled: true },
      datafeed_config: {
        datafeed_id: 'datafeed-',
        job_id: '',
        indices: ['ft_ecommerce'],
        query: { bool: { must: [{ match_all: {} }], filter: [], must_not: [] } },
      },
    };
  };

  /**
   * Posts `requestBody` to the cardinality validation endpoint as `user`,
   * asserts the HTTP status code and returns the parsed response body.
   */
  const postValidateCardinality = async (
    user: SuiteUser,
    requestBody: ReturnType<typeof buildJobConfig>,
    expectedStatusCode: number
  ) => {
    const response = await supertest
      .post('/api/ml/validate/cardinality')
      .auth(user, ml.securityCommon.getPasswordForUser(user))
      .set(COMMON_HEADERS)
      .send(requestBody)
      .expect(expectedStatusCode);

    return response.body;
  };

  describe('ValidateCardinality', function() {
    before(async () => {
      await esArchiver.loadIfNeeded('ml/ecommerce');
      await ml.testResources.setKibanaTimeZoneToUTC();
    });

    after(async () => {
      await ml.api.cleanMlIndices();
    });

    it(`should recognize a valid cardinality`, async () => {
      const body = await postValidateCardinality(
        USER.ML_POWERUSER,
        buildJobConfig('12MB', 'geoip.city_name'),
        200
      );

      expect(body).to.eql([{ id: 'success_cardinality' }]);
    });

    it(`should recognize a high model plot cardinality`, async () => {
      const body = await postValidateCardinality(
        USER.ML_POWERUSER,
        // some high cardinality field
        buildJobConfig('11MB', 'order_id'),
        200
      );

      expect(body).to.eql([
        { id: 'cardinality_model_plot_high', modelPlotCardinality: 4711 },
        { id: 'cardinality_partition_field', fieldName: 'order_id' },
      ]);
    });

    it('should not validate cardinality in case request payload is invalid', async () => {
      const body = await postValidateCardinality(
        USER.ML_POWERUSER,
        // missing analysis_config
        buildJobConfig('12MB'),
        400
      );

      expect(body.error).to.eql('Bad Request');
      expect(body.message).to.eql(
        '[request body.analysis_config.detectors]: expected value of type [array] but got [undefined]'
      );
    });

    it('should not validate cardinality if the user does not have required permissions', async () => {
      const body = await postValidateCardinality(
        USER.ML_VIEWER,
        buildJobConfig('12MB', 'geoip.city_name'),
        404
      );

      expect(body.error).to.eql('Not Found');
      expect(body.message).to.eql('Not Found');
    });
  });
};
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ export default function({ loadTestFile }: FtrProviderContext) {
describe('job validation', function() {
loadTestFile(require.resolve('./bucket_span_estimator'));
loadTestFile(require.resolve('./calculate_model_memory_limit'));
loadTestFile(require.resolve('./cardinality'));
});
}