Skip to content

Commit

Permalink
Add observables for inference results and running state
Browse files Browse the repository at this point in the history
  • Loading branch information
jgowdyelastic committed Apr 25, 2022
1 parent 03e1e6e commit 2372bcc
Show file tree
Hide file tree
Showing 18 changed files with 298 additions and 195 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
* 2.0.
*/

import { BehaviorSubject } from 'rxjs';
import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';

import { trainedModelsApiProvider } from '../../../../services/ml_api_service/trained_models';
Expand All @@ -18,6 +19,8 @@ export type FormattedNerResponse = Array<{

export abstract class InferenceBase<TInferResponse> {
protected readonly inputField: string;
public inferenceResult$ = new BehaviorSubject<TInferResponse | null>(null);
public isRunning$ = new BehaviorSubject<boolean>(false);

constructor(
protected trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,47 +5,46 @@
* 2.0.
*/

import React, { FC, useState } from 'react';
import React, { FC, useState, useMemo } from 'react';

import useObservable from 'react-use/lib/useObservable';
import { FormattedMessage } from '@kbn/i18n-react';
import { EuiSpacer, EuiButton, EuiTabs, EuiTab } from '@elastic/eui';

import { LangIdentInference } from './lang_ident';
import { NerInference } from './ner';
import {
TextClassificationInference,
ZeroShotClassificationInference,
FillMaskInference,
} from './text_classification';

import type { FormattedLangIdentResponse } from './lang_ident';
import type { FormattedNerResponse } from './ner';
import type { FormattedTextClassificationResponse } from './text_classification';
import type { FormattedTextEmbeddingResponse } from './text_embedding';

import { MLJobEditor } from '../../../../jobs/jobs_list/components/ml_job_editor';
import { extractErrorMessage } from '../../../../../../common/util/errors';
import { ErrorMessage } from '../inference_error';
import { OutputLoadingContent } from '../output_loading';
import { NerInference } from './ner';
import {
TextClassificationInference,
ZeroShotClassificationInference,
FillMaskInference,
} from './text_classification';
import { TextEmbeddingInference } from './text_embedding';
import { LangIdentInference } from './lang_ident';

type FormattedInferenceResponse =
| FormattedLangIdentResponse
| FormattedNerResponse
| FormattedTextClassificationResponse;

type InferResponse =
| ReturnType<LangIdentInference['infer']>
| ReturnType<NerInference['infer']>
| ReturnType<TextClassificationInference['infer']>
| ReturnType<ZeroShotClassificationInference['infer']>
| ReturnType<FillMaskInference['infer']>;
| FormattedTextClassificationResponse
| FormattedTextEmbeddingResponse;

interface Props {
getOutputComponent(output: FormattedInferenceResponse): JSX.Element;
getInputComponent(): JSX.Element;
inputText: string;
infer(): InferResponse;
isRunning: boolean;
setIsRunning(running: boolean): void;
getOutputComponent(inputText: string): JSX.Element;
getInputComponent: () => { inputComponent: JSX.Element; infer: () => any };
inferrer:
| NerInference
| TextClassificationInference
| TextEmbeddingInference
| ZeroShotClassificationInference
| FillMaskInference
| LangIdentInference;
}

enum TAB {
Expand All @@ -56,46 +55,41 @@ enum TAB {
export const InferenceInputForm: FC<Props> = ({
getOutputComponent,
getInputComponent,
inputText,
infer,
isRunning,
setIsRunning,
inferrer,
}) => {
const [output, setOutput] = useState<FormattedInferenceResponse | null>(null);
const [inputText, setInputText] = useState('');
const [rawOutput, setRawOutput] = useState<string | null>(null);
const [selectedTab, setSelectedTab] = useState(TAB.TEXT);
const [showOutput, setShowOutput] = useState(false);
const [errorText, setErrorText] = useState<string | null>(null);

const isRunning = useObservable(inferrer.isRunning$);
const { inputComponent, infer } = useMemo(getInputComponent, []);

async function run() {
setShowOutput(true);
setOutput(null);
setRawOutput(null);
setIsRunning(true);
setErrorText(null);
try {
const { response, rawResponse } = await infer();
const { response, rawResponse, inputText: inputText2 } = await infer();
setOutput(response);
setInputText(inputText2);
setRawOutput(JSON.stringify(rawResponse, null, 2));
} catch (e) {
setIsRunning(false);
setOutput(null);
setErrorText(extractErrorMessage(e));
setRawOutput(JSON.stringify(e.body ?? e, null, 2));
}
setIsRunning(false);
}

return (
<>
<>{getInputComponent()}</>
<>{inputComponent}</>
<EuiSpacer size="m" />
<div>
<EuiButton
onClick={run}
disabled={isRunning === true || inputText === ''}
fullWidth={false}
>
<EuiButton onClick={run} disabled={isRunning === true} fullWidth={false}>
<FormattedMessage
id="xpack.ml.trainedModels.testModelsFlyout.inferenceInputForm.runButton"
defaultMessage="Test"
Expand Down Expand Up @@ -133,9 +127,9 @@ export const InferenceInputForm: FC<Props> = ({
{errorText !== null ? (
<ErrorMessage errorText={errorText} />
) : output === null ? (
<OutputLoadingContent text={inputText} />
<OutputLoadingContent text={''} />
) : (
<>{getOutputComponent(output)}</>
<>{getOutputComponent(inputText)}</>
)}
</>
) : (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ export type FormattedLangIdentResponse = Array<{
}>;

interface InferResponse {
inputText: string;
response: FormattedLangIdentResponse;
rawResponse: estypes.IngestSimulateResponse;
}

export class LangIdentInference extends InferenceBase<InferResponse> {
public async infer(inputText: string) {
this.isRunning$.next(true);
const payload: estypes.IngestSimulateRequest['body'] = {
pipeline: {
processors: [
Expand Down Expand Up @@ -54,15 +56,21 @@ export class LangIdentInference extends InferenceBase<InferResponse> {
if (resp.docs.length) {
const topClasses = resp.docs[0].doc?._source._ml?.lang_ident?.top_classes ?? [];

return {
const r = {
response: topClasses.map((t: any) => ({
className: t.class_name,
classProbability: t.class_probability,
classScore: t.class_score,
})),
rawResponse: resp,
inputText,
};
this.inferenceResult$.next(r);
return r;
}
return { response: [], rawResponse: resp };
this.isRunning$.next(false);
const r = { response: [], rawResponse: resp, inputText };
this.inferenceResult$.next(r);
return r;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,31 @@
*/

import React, { FC } from 'react';
import useObservable from 'react-use/lib/useObservable';
import { i18n } from '@kbn/i18n';
import { EuiSpacer, EuiBasicTable, EuiTitle } from '@elastic/eui';

import type { FormattedLangIdentResponse } from './lang_ident_inference';
import type { LangIdentInference } from './lang_ident_inference';
import { getLanguage } from './lang_codes';

const PROBABILITY_SIG_FIGS = 3;

export const getLangIdentOutputComponent = (output: FormattedLangIdentResponse) => (
<LangIdentOutput result={output} />
);
export const getLangIdentOutputComponent = (inferrer: LangIdentInference) => () =>
<LangIdentOutput inferrer={inferrer} />;

const LangIdentOutput: FC<{ result: FormattedLangIdentResponse }> = ({ result }) => {
if (result.length === 0) {
const LangIdentOutput: FC<{ inferrer: LangIdentInference }> = ({ inferrer }) => {
const result = useObservable(inferrer.inferenceResult$);
if (!result) {
return null;
}

const lang = getLanguage(result[0].className);
if (result.response.length === 0) {
return null;
}

const lang = getLanguage(result.response[0].className);

const items = result.map(({ className, classProbability }, i) => {
const items = result.response.map(({ className, classProbability }, i) => {
return {
noa: `${i + 1}`,
className: getLanguage(className),
Expand Down Expand Up @@ -74,7 +79,7 @@ const LangIdentOutput: FC<{ result: FormattedLangIdentResponse }> = ({ result })
})
: i18n.translate('xpack.ml.trainedModels.testModelsFlyout.langIdent.output.titleUnknown', {
defaultMessage: 'Language code unknown: {code}',
values: { code: result[0].className },
values: { code: result.response[0].className },
});

return (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,20 @@ export type FormattedNerResponse = Array<{
// Result shape produced by NerInference.infer and published on inferenceResult$.
interface InferResponse {
  // Parsed entity spans (value/entity pairs) extracted from the raw response.
  response: FormattedNerResponse;
  // Unmodified infer-trained-model response from Elasticsearch.
  rawResponse: estypes.MlInferTrainedModelDeploymentResponse;
  // The raw text that was submitted for inference.
  inputText: string;
}

export class NerInference extends InferenceBase<InferResponse> {
  /**
   * Runs NER inference on the supplied text, publishes the processed result on
   * `inferenceResult$`, and returns it. `isRunning$` is set to true for the
   * duration of the call.
   *
   * Fix: the running flag is now reset in a `finally` block, so a rejected
   * `inferTrainedModel` call no longer leaves `isRunning$` stuck at `true`
   * (which would permanently disable the UI's run button). The error itself
   * still propagates to the caller unchanged.
   */
  public async infer(inputText: string) {
    this.isRunning$.next(true);
    try {
      const payload = { docs: { [this.inputField]: inputText } };
      const resp = await this.trainedModelsApi.inferTrainedModel(
        this.model.model_id,
        payload,
        '30s'
      );

      const processedResponse = { response: parseResponse(resp), rawResponse: resp, inputText };
      this.inferenceResult$.next(processedResponse);
      return processedResponse;
    } finally {
      // Always clear the running flag, on success and on failure alike.
      this.isRunning$.next(false);
    }
  }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import React, { FC, ReactNode } from 'react';
import useObservable from 'react-use/lib/useObservable';
import { FormattedMessage } from '@kbn/i18n-react';
import {
EuiHorizontalRule,
Expand All @@ -21,7 +22,7 @@ import {
useCurrentEuiTheme,
EuiThemeType,
} from '../../../../../components/color_range_legend/use_color_range';
import type { FormattedNerResponse } from './ner_inference';
import type { NerInference } from './ner_inference';

const ICON_PADDING = '2px';
const PROBABILITY_SIG_FIGS = 3;
Expand Down Expand Up @@ -60,14 +61,19 @@ const UNKNOWN_ENTITY_TYPE = {
borderColor: 'euiColorVis5',
};

export const getNerOutputComponent = (output: FormattedNerResponse) => (
<NerOutput result={output} />
);
export const getNerOutputComponent = (inferrer: NerInference) => () =>
<NerOutput inferrer={inferrer} />;

const NerOutput: FC<{ result: FormattedNerResponse }> = ({ result }) => {
const NerOutput: FC<{ inferrer: NerInference }> = ({ inferrer }) => {
const { euiTheme } = useCurrentEuiTheme();
const result = useObservable(inferrer.inferenceResult$);

if (!result) {
return null;
}

const lineSplit: JSX.Element[] = [];
result.forEach(({ value, entity }) => {
result.response.forEach(({ value, entity }) => {
if (entity === null) {
const lines = value
.split(/(\n)/)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@ export type FormattedTextClassificationResponse = Array<{
}>;

// Shared result shape for the text-classification family of inferrers
// (text classification, zero-shot, fill-mask); built by processResponse.
export interface InferResponse {
  // The raw text that was submitted for inference.
  inputText: string;
  // Classification predictions (value/probability pairs) derived from rawResponse.
  response: FormattedTextClassificationResponse;
  // Unmodified response from the inference API.
  rawResponse: TextClassificationResponse;
}

export function processResponse(
resp: TextClassificationResponse,
model: estypes.MlTrainedModelConfig
model: estypes.MlTrainedModelConfig,
inputText: string
): InferResponse {
const labels: string[] =
// @ts-expect-error inference config is wrong
Expand Down Expand Up @@ -77,5 +79,6 @@ export function processResponse(
.sort((a, b) => a.predictionProbability - b.predictionProbability)
.reverse(),
rawResponse: resp,
inputText,
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import { processResponse } from './common';

export class FillMaskInference extends InferenceBase<InferResponse> {
public async infer(inputText: string) {
this.isRunning$.next(true);
const payload = {
docs: { [this.inputField]: inputText },
inference_config: { fill_mask: { num_top_classes: 5 } },
Expand All @@ -21,6 +22,10 @@ export class FillMaskInference extends InferenceBase<InferResponse> {
'30s'
)) as unknown as TextClassificationResponse;

return processResponse(resp, this.model);
const processedResponse = processResponse(resp, this.model, inputText);
this.inferenceResult$.next(processedResponse);
this.isRunning$.next(false);

return processedResponse;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,28 @@
*/

import React, { FC } from 'react';
import useObservable from 'react-use/lib/useObservable';
import { EuiFlexGroup, EuiFlexItem, EuiSpacer, EuiProgress, EuiTitle } from '@elastic/eui';

import type { FormattedTextClassificationResponse } from './common';
import type { FillMaskInference } from '.';

const MASK = '[MASK]';

export const getFillMaskOutputComponent =
(inputText: string) => (output: FormattedTextClassificationResponse) =>
<FillMaskOutput result={output} inputText={inputText} />;
export const getFillMaskOutputComponent = (inferrer: FillMaskInference) => (inputText: string) =>
<FillMaskOutput inputText={inputText} inferrer={inferrer} />;

const FillMaskOutput: FC<{
result: FormattedTextClassificationResponse;
inputText: string;
}> = ({ result, inputText }) => {
const title = result[0]?.value ? inputText.replace(MASK, result[0].value) : inputText;
inferrer: FillMaskInference;
}> = ({ inferrer, inputText }) => {
const result = useObservable(inferrer.inferenceResult$);
if (!result) {
return null;
}

const title = result.response[0]?.value
? inputText.replace(MASK, result.response[0].value)
: inputText;
return (
<>
<EuiTitle size="xs">
Expand All @@ -29,7 +36,7 @@ const FillMaskOutput: FC<{

<EuiSpacer />

{result.map(({ value, predictionProbability }) => (
{result.response.map(({ value, predictionProbability }) => (
<>
<EuiProgress value={predictionProbability * 100} max={100} size="m" />
<EuiSpacer size="s" />
Expand Down
Loading

0 comments on commit 2372bcc

Please sign in to comment.