[ML] Adding UI for question_answering model testing #132033

Merged

Changes from all commits
1 change: 1 addition & 0 deletions x-pack/plugins/ml/common/constants/trained_models.ts
@@ -22,6 +22,7 @@ export type TrainedModelType = typeof TRAINED_MODEL_TYPE[keyof typeof TRAINED_MO

export const SUPPORTED_PYTORCH_TASKS = {
NER: 'ner',
QUESTION_ANSWERING: 'question_answering',
ZERO_SHOT_CLASSIFICATION: 'zero_shot_classification',
TEXT_CLASSIFICATION: 'text_classification',
TEXT_EMBEDDING: 'text_embedding',
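
The new task id is the same key Elasticsearch expects under inference_config when calling the _infer API for this task. A hedged sketch of the request body the new inferrer builds further down in this PR (the doc field name and the sample values are placeholders):

const body = {
  docs: [{ text_field: 'The Louvre is located in Paris, France.' }],
  inference_config: {
    [SUPPORTED_PYTORCH_TASKS.QUESTION_ANSWERING]: {
      question: 'Where is the Louvre?',
      // num_top_classes is only added when the user asks for top classes
    },
  },
};
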
@@ -6,6 +6,7 @@
*/

import { NerInference } from './ner';
import { QuestionAnsweringInference } from './question_answering';
import {
TextClassificationInference,
ZeroShotClassificationInference,
@@ -16,6 +17,7 @@ import { TextEmbeddingInference } from './text_embedding';

export type InferrerType =
| NerInference
| QuestionAnsweringInference
| TextClassificationInference
| TextEmbeddingInference
| ZeroShotClassificationInference
@@ -0,0 +1,13 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

export type {
FormattedQuestionAnsweringResponse,
QuestionAnsweringResponse,
} from './question_answering_inference';
export { QuestionAnsweringInference } from './question_answering_inference';
export { getQuestionAnsweringOutputComponent } from './question_answering_output';
@@ -0,0 +1,135 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import { BehaviorSubject } from 'rxjs';

import { InferenceBase, InferResponse } from '../inference_base';
import { getQuestionAnsweringInput } from './question_answering_input';
import { getQuestionAnsweringOutputComponent } from './question_answering_output';
import { SUPPORTED_PYTORCH_TASKS } from '../../../../../../../common/constants/trained_models';

export interface RawQuestionAnsweringResponse {
inference_results: Array<{
predicted_value: string;
prediction_probability: number;
start_offset: number;
end_offset: number;
top_classes?: Array<{
end_offset: number;
score: number;
start_offset: number;
answer: string;
}>;
}>;
}
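
// Illustrative raw _infer response shape that the interface above models
// (example values only):
// {
//   "inference_results": [{
//     "predicted_value": "Paris",
//     "prediction_probability": 0.97,
//     "start_offset": 25,
//     "end_offset": 30,
//     "top_classes": [
//       { "answer": "Paris", "score": 0.97, "start_offset": 25, "end_offset": 30 }
//     ]
//   }]
// }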

export interface FormattedQuestionAnsweringResult {
value: string;
predictionProbability: number;
startOffset: number;
endOffset: number;
}

export type FormattedQuestionAnsweringResponse = FormattedQuestionAnsweringResult[];

export type QuestionAnsweringResponse = InferResponse<
FormattedQuestionAnsweringResponse,
RawQuestionAnsweringResponse
>;

export class QuestionAnsweringInference extends InferenceBase<QuestionAnsweringResponse> {
protected inferenceType = SUPPORTED_PYTORCH_TASKS.QUESTION_ANSWERING;

public questionText$ = new BehaviorSubject<string>('');

public async infer() {
try {
this.setRunning();
const inputText = this.inputText$.getValue();
const questionText = this.questionText$.getValue();
const numTopClassesConfig = this.getNumTopClassesConfig()?.inference_config;

const payload = {
docs: [{ [this.inputField]: inputText }],
inference_config: {
[this.inferenceType]: {
question: questionText,
...(numTopClassesConfig
? {
num_top_classes: numTopClassesConfig[this.inferenceType].num_top_classes,
}
: {}),
},
},
};
const resp = (await this.trainedModelsApi.inferTrainedModel(
this.model.model_id,
payload,
'30s'
)) as unknown as RawQuestionAnsweringResponse;

const processedResponse: QuestionAnsweringResponse = processResponse(
resp,
this.model,
inputText
);

this.inferenceResult$.next(processedResponse);
this.setFinished();

return processedResponse;
} catch (error) {
this.setFinishedWithErrors(error);
throw error;
}
}

public getInputComponent(): JSX.Element {
return getQuestionAnsweringInput(this);
}

public getOutputComponent(): JSX.Element {
return getQuestionAnsweringOutputComponent(this);
}
}

function processResponse(
resp: RawQuestionAnsweringResponse,
model: estypes.MlTrainedModelConfig,
inputText: string
) {
const {
inference_results: [inferenceResults],
} = resp;

let formattedResponse = [
{
value: inferenceResults.predicted_value,
predictionProbability: inferenceResults.prediction_probability,
startOffset: inferenceResults.start_offset,
endOffset: inferenceResults.end_offset,
},
];

if (inferenceResults.top_classes !== undefined) {
formattedResponse = inferenceResults.top_classes.map((topClass) => {
return {
value: topClass.answer,
predictionProbability: topClass.score,
startOffset: topClass.start_offset,
endOffset: topClass.end_offset,
};
});
}

return {
response: formattedResponse,
rawResponse: resp,
inputText,
};
}
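
For reference, a hedged sketch of how processResponse above reshapes a response that includes top_classes; the identifiers come from the file above and the sample values are illustrative:

const raw: RawQuestionAnsweringResponse = {
  inference_results: [
    {
      predicted_value: 'Paris',
      prediction_probability: 0.97,
      start_offset: 25,
      end_offset: 30,
      top_classes: [
        { answer: 'Paris', score: 0.97, start_offset: 25, end_offset: 30 },
        { answer: 'France', score: 0.02, start_offset: 32, end_offset: 38 },
      ],
    },
  ],
};

// processResponse(raw, model, inputText).response yields one entry per candidate:
// [
//   { value: 'Paris',  predictionProbability: 0.97, startOffset: 25, endOffset: 30 },
//   { value: 'France', predictionProbability: 0.02, startOffset: 32, endOffset: 38 },
// ]
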
@@ -0,0 +1,59 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import React, { FC, useEffect, useState } from 'react';
import useObservable from 'react-use/lib/useObservable';
import { i18n } from '@kbn/i18n';

import { EuiSpacer, EuiFieldText, EuiFormRow } from '@elastic/eui';

import { TextInput } from '../text_input';
import { QuestionAnsweringInference } from './question_answering_inference';
import { RUNNING_STATE } from '../inference_base';

const QuestionInput: FC<{
inferrer: QuestionAnsweringInference;
}> = ({ inferrer }) => {
const [questionText, setQuestionText] = useState('');

useEffect(() => {
inferrer.questionText$.next(questionText);
}, [questionText]);

const runningState = useObservable(inferrer.runningState$);
return (
<EuiFormRow
fullWidth
label={i18n.translate(
'xpack.ml.trainedModels.testModelsFlyout.questionAnswering.questionInput',
{
defaultMessage: 'Question',
}
)}
>
<EuiFieldText
value={questionText}
disabled={runningState === RUNNING_STATE.RUNNING}
fullWidth
onChange={(e) => {
setQuestionText(e.target.value);
}}
/>
</EuiFormRow>
);
};

export const getQuestionAnsweringInput = (
inferrer: QuestionAnsweringInference,
placeholder?: string
) => (
<>
<TextInput placeholder={placeholder} inferrer={inferrer} />
<EuiSpacer />
<QuestionInput inferrer={inferrer} />
</>
);
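
The question field follows the same observable pattern as the main text input: each keystroke is pushed into a BehaviorSubject, and infer() reads the latest value when the test is run. A minimal standalone sketch of that flow, outside React:

import { BehaviorSubject } from 'rxjs';

const questionText$ = new BehaviorSubject<string>('');
questionText$.next('Where is the Louvre?'); // what the onChange handler does
const question = questionText$.getValue();  // what infer() does on submit
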
@@ -0,0 +1,74 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import React, { FC, ReactNode } from 'react';
import useObservable from 'react-use/lib/useObservable';

import { EuiBadge } from '@elastic/eui';

import { useCurrentEuiTheme } from '../../../../../components/color_range_legend/use_color_range';

import type {
QuestionAnsweringInference,
FormattedQuestionAnsweringResult,
} from './question_answering_inference';

const ICON_PADDING = '2px';
const TRIM_CHAR_COUNT = 200;

export const getQuestionAnsweringOutputComponent = (inferrer: QuestionAnsweringInference) => (
<QuestionAnsweringOutput inferrer={inferrer} />
);

const QuestionAnsweringOutput: FC<{ inferrer: QuestionAnsweringInference }> = ({ inferrer }) => {
const result = useObservable(inferrer.inferenceResult$);
if (!result || result.response.length === 0) {
return null;
}

const bestResult = result.response[0];
const { inputText } = result;

return <>{insertHighlighting(bestResult, inputText)}</>;
};

function insertHighlighting(result: FormattedQuestionAnsweringResult, inputText: string) {
const start = inputText.slice(0, result.startOffset);
const end = inputText.slice(result.endOffset, inputText.length);
const truncatedStart =
start.length > TRIM_CHAR_COUNT
? `...${start.slice(start.length - TRIM_CHAR_COUNT, start.length)}`
: start;
const truncatedEnd = end.length > TRIM_CHAR_COUNT ? `${end.slice(0, TRIM_CHAR_COUNT)}...` : end;

return (
<div style={{ lineHeight: '24px' }}>
{truncatedStart}
<ResultBadge>{result.value}</ResultBadge>
{truncatedEnd}
</div>
);
}

const ResultBadge = ({ children }: { children: ReactNode }) => {
const { euiTheme } = useCurrentEuiTheme();
return (
<EuiBadge
color={euiTheme.euiColorVis5_behindText}
style={{
marginRight: ICON_PADDING,
marginTop: `-${ICON_PADDING}`,
border: `1px solid ${euiTheme.euiColorVis5}`,
fontSize: euiTheme.euiFontSizeXS,
padding: '0px 6px',
pointerEvents: 'none',
}}
>
{children}
</EuiBadge>
);
};
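
To make the offset arithmetic above concrete, a small worked example (the text and offsets are illustrative):

const inputText = 'The Louvre is located in Paris, France.';
const result = { value: 'Paris', predictionProbability: 0.98, startOffset: 25, endOffset: 30 };

inputText.slice(0, result.startOffset); // 'The Louvre is located in '
inputText.slice(result.endOffset);      // ', France.'
// insertHighlighting renders: start + <ResultBadge>Paris</ResultBadge> + end,
// trimming start/end to TRIM_CHAR_COUNT characters on either side of the answer.
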
@@ -9,6 +9,7 @@ import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import React, { FC } from 'react';

import { NerInference } from './models/ner';
import { QuestionAnsweringInference } from './models/question_answering';

import {
TextClassificationInference,
@@ -64,6 +65,11 @@ export const SelectedModel: FC<Props> = ({ model }) => {
const inferrer = new FillMaskInference(trainedModels, model);
return <InferenceInputForm inferrer={inferrer} />;
}

if (Object.keys(model.inference_config)[0] === SUPPORTED_PYTORCH_TASKS.QUESTION_ANSWERING) {
const inferrer = new QuestionAnsweringInference(trainedModels, model);
return <InferenceInputForm inferrer={inferrer} />;
}
}
if (model.model_type === TRAINED_MODEL_TYPE.LANG_IDENT) {
const inferrer = new LangIdentInference(trainedModels, model);
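
The new branch keys off the first property name of the model's inference_config, matching how the other task types are detected. For illustration only, a deployed extractive QA model's config might look roughly like this (the exact options are an assumption, not taken from this PR):

// model.inference_config ≈ { question_answering: { max_answer_length: 15, ... } }
// Object.keys(model.inference_config)[0] === SUPPORTED_PYTORCH_TASKS.QUESTION_ANSWERING // true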