diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/inference_base.ts b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/inference_base.ts index a4fc12aba244ba..4fda485a29ac84 100644 --- a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/inference_base.ts +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/inference_base.ts @@ -5,6 +5,7 @@ * 2.0. */ +import { BehaviorSubject } from 'rxjs'; import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey'; import { trainedModelsApiProvider } from '../../../../services/ml_api_service/trained_models'; @@ -18,6 +19,8 @@ export type FormattedNerResponse = Array<{ export abstract class InferenceBase { protected readonly inputField: string; + public inferenceResult$ = new BehaviorSubject(null); + public isRunning$ = new BehaviorSubject(false); constructor( protected trainedModelsApi: ReturnType, diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/inference_input_form.tsx b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/inference_input_form.tsx index 5b2c6538613aaa..540ba83242c595 100644 --- a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/inference_input_form.tsx +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/inference_input_form.tsx @@ -5,47 +5,46 @@ * 2.0. */ -import React, { FC, useState } from 'react'; +import React, { FC, useState, useMemo } from 'react'; +import useObservable from 'react-use/lib/useObservable'; import { FormattedMessage } from '@kbn/i18n-react'; import { EuiSpacer, EuiButton, EuiTabs, EuiTab } from '@elastic/eui'; -import { LangIdentInference } from './lang_ident'; -import { NerInference } from './ner'; -import { - TextClassificationInference, - ZeroShotClassificationInference, - FillMaskInference, -} from './text_classification'; - import type { FormattedLangIdentResponse } from './lang_ident'; import type { FormattedNerResponse } from './ner'; import type { FormattedTextClassificationResponse } from './text_classification'; +import type { FormattedTextEmbeddingResponse } from './text_embedding'; import { MLJobEditor } from '../../../../jobs/jobs_list/components/ml_job_editor'; import { extractErrorMessage } from '../../../../../../common/util/errors'; import { ErrorMessage } from '../inference_error'; import { OutputLoadingContent } from '../output_loading'; +import { NerInference } from './ner'; +import { + TextClassificationInference, + ZeroShotClassificationInference, + FillMaskInference, +} from './text_classification'; +import { TextEmbeddingInference } from './text_embedding'; +import { LangIdentInference } from './lang_ident'; type FormattedInferenceResponse = | FormattedLangIdentResponse | FormattedNerResponse - | FormattedTextClassificationResponse; - -type InferResponse = - | ReturnType - | ReturnType - | ReturnType - | ReturnType - | ReturnType; + | FormattedTextClassificationResponse + | FormattedTextEmbeddingResponse; interface Props { - getOutputComponent(output: FormattedInferenceResponse): JSX.Element; - getInputComponent(): JSX.Element; - inputText: string; - infer(): InferResponse; - isRunning: boolean; - setIsRunning(running: boolean): void; + getOutputComponent(inputText: string): JSX.Element; + getInputComponent: () => { inputComponent: JSX.Element; infer: () => any }; + inferrer: + | NerInference + | TextClassificationInference + | TextEmbeddingInference + | ZeroShotClassificationInference + | FillMaskInference + | LangIdentInference; } enum TAB { @@ -56,46 +55,41 @@ enum TAB { export const InferenceInputForm: FC = ({ getOutputComponent, getInputComponent, - inputText, - infer, - isRunning, - setIsRunning, + inferrer, }) => { const [output, setOutput] = useState(null); + const [inputText, setInputText] = useState(''); const [rawOutput, setRawOutput] = useState(null); const [selectedTab, setSelectedTab] = useState(TAB.TEXT); const [showOutput, setShowOutput] = useState(false); const [errorText, setErrorText] = useState(null); + const isRunning = useObservable(inferrer.isRunning$); + const { inputComponent, infer } = useMemo(getInputComponent, []); + async function run() { setShowOutput(true); setOutput(null); setRawOutput(null); - setIsRunning(true); setErrorText(null); try { - const { response, rawResponse } = await infer(); + const { response, rawResponse, inputText: inputText2 } = await infer(); setOutput(response); + setInputText(inputText2); setRawOutput(JSON.stringify(rawResponse, null, 2)); } catch (e) { - setIsRunning(false); setOutput(null); setErrorText(extractErrorMessage(e)); setRawOutput(JSON.stringify(e.body ?? e, null, 2)); } - setIsRunning(false); } return ( <> - <>{getInputComponent()} + <>{inputComponent}
- + = ({ {errorText !== null ? ( ) : output === null ? ( - + ) : ( - <>{getOutputComponent(output)} + <>{getOutputComponent(inputText)} )} ) : ( diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/lang_ident/lang_ident_inference.ts b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/lang_ident/lang_ident_inference.ts index ede7104526e512..635cbad348c30e 100644 --- a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/lang_ident/lang_ident_inference.ts +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/lang_ident/lang_ident_inference.ts @@ -16,12 +16,14 @@ export type FormattedLangIdentResponse = Array<{ }>; interface InferResponse { + inputText: string; response: FormattedLangIdentResponse; rawResponse: estypes.IngestSimulateResponse; } export class LangIdentInference extends InferenceBase { public async infer(inputText: string) { + this.isRunning$.next(true); const payload: estypes.IngestSimulateRequest['body'] = { pipeline: { processors: [ @@ -54,15 +56,21 @@ export class LangIdentInference extends InferenceBase { if (resp.docs.length) { const topClasses = resp.docs[0].doc?._source._ml?.lang_ident?.top_classes ?? []; - return { + const r = { response: topClasses.map((t: any) => ({ className: t.class_name, classProbability: t.class_probability, classScore: t.class_score, })), rawResponse: resp, + inputText, }; + this.inferenceResult$.next(r); + return r; } - return { response: [], rawResponse: resp }; + this.isRunning$.next(false); + const r = { response: [], rawResponse: resp, inputText }; + this.inferenceResult$.next(r); + return r; } } diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/lang_ident/lang_ident_output.tsx b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/lang_ident/lang_ident_output.tsx index 346e52a7da3ff0..e44852f2f5586a 100644 --- a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/lang_ident/lang_ident_output.tsx +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/lang_ident/lang_ident_output.tsx @@ -6,26 +6,31 @@ */ import React, { FC } from 'react'; +import useObservable from 'react-use/lib/useObservable'; import { i18n } from '@kbn/i18n'; import { EuiSpacer, EuiBasicTable, EuiTitle } from '@elastic/eui'; -import type { FormattedLangIdentResponse } from './lang_ident_inference'; +import type { LangIdentInference } from './lang_ident_inference'; import { getLanguage } from './lang_codes'; const PROBABILITY_SIG_FIGS = 3; -export const getLangIdentOutputComponent = (output: FormattedLangIdentResponse) => ( - -); +export const getLangIdentOutputComponent = (inferrer: LangIdentInference) => () => + ; -const LangIdentOutput: FC<{ result: FormattedLangIdentResponse }> = ({ result }) => { - if (result.length === 0) { +const LangIdentOutput: FC<{ inferrer: LangIdentInference }> = ({ inferrer }) => { + const result = useObservable(inferrer.inferenceResult$); + if (!result) { return null; } - const lang = getLanguage(result[0].className); + if (result.response.length === 0) { + return null; + } + + const lang = getLanguage(result.response[0].className); - const items = result.map(({ className, classProbability }, i) => { + const items = result.response.map(({ className, classProbability }, i) => { return { noa: `${i + 1}`, className: getLanguage(className), @@ -74,7 +79,7 @@ const LangIdentOutput: FC<{ result: FormattedLangIdentResponse }> = ({ result }) }) : i18n.translate('xpack.ml.trainedModels.testModelsFlyout.langIdent.output.titleUnknown', { defaultMessage: 'Language code unknown: {code}', - values: { code: result[0].className }, + values: { code: result.response[0].className }, }); return ( diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/ner/ner_inference.ts b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/ner/ner_inference.ts index 60c21c72c3c210..666763f48f6683 100644 --- a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/ner/ner_inference.ts +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/ner/ner_inference.ts @@ -17,14 +17,20 @@ export type FormattedNerResponse = Array<{ interface InferResponse { response: FormattedNerResponse; rawResponse: estypes.MlInferTrainedModelDeploymentResponse; + inputText: string; } export class NerInference extends InferenceBase { public async infer(inputText: string) { + this.isRunning$.next(true); const payload = { docs: { [this.inputField]: inputText } }; const resp = await this.trainedModelsApi.inferTrainedModel(this.model.model_id, payload, '30s'); - return { response: parseResponse(resp), rawResponse: resp }; + const processedResponse = { response: parseResponse(resp), rawResponse: resp, inputText }; + this.inferenceResult$.next(processedResponse); + this.isRunning$.next(false); + + return processedResponse; } } diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/ner/ner_output.tsx b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/ner/ner_output.tsx index 284b18404a0e7c..e4ce2384132971 100644 --- a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/ner/ner_output.tsx +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/ner/ner_output.tsx @@ -7,6 +7,7 @@ import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey'; import React, { FC, ReactNode } from 'react'; +import useObservable from 'react-use/lib/useObservable'; import { FormattedMessage } from '@kbn/i18n-react'; import { EuiHorizontalRule, @@ -21,7 +22,7 @@ import { useCurrentEuiTheme, EuiThemeType, } from '../../../../../components/color_range_legend/use_color_range'; -import type { FormattedNerResponse } from './ner_inference'; +import type { NerInference } from './ner_inference'; const ICON_PADDING = '2px'; const PROBABILITY_SIG_FIGS = 3; @@ -60,14 +61,19 @@ const UNKNOWN_ENTITY_TYPE = { borderColor: 'euiColorVis5', }; -export const getNerOutputComponent = (output: FormattedNerResponse) => ( - -); +export const getNerOutputComponent = (inferrer: NerInference) => () => + ; -const NerOutput: FC<{ result: FormattedNerResponse }> = ({ result }) => { +const NerOutput: FC<{ inferrer: NerInference }> = ({ inferrer }) => { const { euiTheme } = useCurrentEuiTheme(); + const result = useObservable(inferrer.inferenceResult$); + + if (!result) { + return null; + } + const lineSplit: JSX.Element[] = []; - result.forEach(({ value, entity }) => { + result.response.forEach(({ value, entity }) => { if (entity === null) { const lines = value .split(/(\n)/) diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_classification/common.ts b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_classification/common.ts index 0634e5bc39fc1c..1f8f685d5e92ae 100644 --- a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_classification/common.ts +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_classification/common.ts @@ -24,13 +24,15 @@ export type FormattedTextClassificationResponse = Array<{ }>; export interface InferResponse { + inputText: string; response: FormattedTextClassificationResponse; rawResponse: TextClassificationResponse; } export function processResponse( resp: TextClassificationResponse, - model: estypes.MlTrainedModelConfig + model: estypes.MlTrainedModelConfig, + inputText: string ): InferResponse { const labels: string[] = // @ts-expect-error inference config is wrong @@ -77,5 +79,6 @@ export function processResponse( .sort((a, b) => a.predictionProbability - b.predictionProbability) .reverse(), rawResponse: resp, + inputText, }; } diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_classification/fill_mask_inference.ts b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_classification/fill_mask_inference.ts index f052a88571829f..720a1b02399602 100644 --- a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_classification/fill_mask_inference.ts +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_classification/fill_mask_inference.ts @@ -11,6 +11,7 @@ import { processResponse } from './common'; export class FillMaskInference extends InferenceBase { public async infer(inputText: string) { + this.isRunning$.next(true); const payload = { docs: { [this.inputField]: inputText }, inference_config: { fill_mask: { num_top_classes: 5 } }, @@ -21,6 +22,10 @@ export class FillMaskInference extends InferenceBase { '30s' )) as unknown as TextClassificationResponse; - return processResponse(resp, this.model); + const processedResponse = processResponse(resp, this.model, inputText); + this.inferenceResult$.next(processedResponse); + this.isRunning$.next(false); + + return processedResponse; } } diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_classification/fill_mask_output.tsx b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_classification/fill_mask_output.tsx index 510841a17859f0..b9624ddc3a67ae 100644 --- a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_classification/fill_mask_output.tsx +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_classification/fill_mask_output.tsx @@ -6,21 +6,28 @@ */ import React, { FC } from 'react'; +import useObservable from 'react-use/lib/useObservable'; import { EuiFlexGroup, EuiFlexItem, EuiSpacer, EuiProgress, EuiTitle } from '@elastic/eui'; -import type { FormattedTextClassificationResponse } from './common'; +import type { FillMaskInference } from '.'; const MASK = '[MASK]'; -export const getFillMaskOutputComponent = - (inputText: string) => (output: FormattedTextClassificationResponse) => - ; +export const getFillMaskOutputComponent = (inferrer: FillMaskInference) => (inputText: string) => + ; const FillMaskOutput: FC<{ - result: FormattedTextClassificationResponse; inputText: string; -}> = ({ result, inputText }) => { - const title = result[0]?.value ? inputText.replace(MASK, result[0].value) : inputText; + inferrer: FillMaskInference; +}> = ({ inferrer, inputText }) => { + const result = useObservable(inferrer.inferenceResult$); + if (!result) { + return null; + } + + const title = result.response[0]?.value + ? inputText.replace(MASK, result.response[0].value) + : inputText; return ( <> @@ -29,7 +36,7 @@ const FillMaskOutput: FC<{ - {result.map(({ value, predictionProbability }) => ( + {result.response.map(({ value, predictionProbability }) => ( <> diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_classification/index.ts b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_classification/index.ts index 828bf95e8f0677..4eeef37519ff27 100644 --- a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_classification/index.ts +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_classification/index.ts @@ -11,7 +11,7 @@ export { TextClassificationInference } from './text_classification_inference'; export { getTextClassificationOutputComponent } from './text_classification_output'; export { ZeroShotClassificationInference } from './zero_shot_classification_inference'; -export { ZeroShotClassificationInput } from './zero_shot_classification_input'; +export { getZeroShotClassificationInput } from './zero_shot_classification_input'; export { FillMaskInference } from './fill_mask_inference'; export { getFillMaskOutputComponent } from './fill_mask_output'; diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_classification/text_classification_inference.ts b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_classification/text_classification_inference.ts index a19f10c0ff1ccc..972e3489e7fe7a 100644 --- a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_classification/text_classification_inference.ts +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_classification/text_classification_inference.ts @@ -11,6 +11,7 @@ import type { InferResponse, TextClassificationResponse } from './common'; export class TextClassificationInference extends InferenceBase { public async infer(inputText: string) { + this.isRunning$.next(true); const payload = { docs: { [this.inputField]: inputText }, }; @@ -20,6 +21,10 @@ export class TextClassificationInference extends InferenceBase { '30s' )) as unknown as TextClassificationResponse; - return processResponse(resp, this.model); + const processedResponse = processResponse(resp, this.model, inputText); + this.inferenceResult$.next(processedResponse); + this.isRunning$.next(false); + + return processedResponse; } } diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_classification/text_classification_output.tsx b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_classification/text_classification_output.tsx index 4daf928a8d011c..7560e110f5bde4 100644 --- a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_classification/text_classification_output.tsx +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_classification/text_classification_output.tsx @@ -6,20 +6,25 @@ */ import React, { FC } from 'react'; +import useObservable from 'react-use/lib/useObservable'; import { EuiFlexGroup, EuiFlexItem, EuiSpacer, EuiProgress } from '@elastic/eui'; -import type { FormattedTextClassificationResponse } from './common'; +import type { TextClassificationInference, ZeroShotClassificationInference } from '.'; -export const getTextClassificationOutputComponent = ( - output: FormattedTextClassificationResponse -) => ; +export const getTextClassificationOutputComponent = + (inferrer: TextClassificationInference | ZeroShotClassificationInference) => () => + ; -const TextClassificationOutput: FC<{ result: FormattedTextClassificationResponse }> = ({ - result, -}) => { +const TextClassificationOutput: FC<{ + inferrer: TextClassificationInference | ZeroShotClassificationInference; +}> = ({ inferrer }) => { + const result = useObservable(inferrer.inferenceResult$); + if (!result) { + return null; + } return ( <> - {result.map(({ value, predictionProbability }) => ( + {result.response.map(({ value, predictionProbability }) => ( <> diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_classification/zero_shot_classification_inference.ts b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_classification/zero_shot_classification_inference.ts index 6f657d472d85a8..e3d9b1671b1e13 100644 --- a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_classification/zero_shot_classification_inference.ts +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_classification/zero_shot_classification_inference.ts @@ -11,6 +11,7 @@ import type { InferResponse, TextClassificationResponse } from './common'; export class ZeroShotClassificationInference extends InferenceBase { public async infer(inputText: string, labelsText?: string) { + this.isRunning$.next(true); const inputLabels = labelsText?.split(',').map((l) => l.trim()); const payload = { docs: { [this.inputField]: inputText }, @@ -27,6 +28,10 @@ export class ZeroShotClassificationInference extends InferenceBase = ({ disabled, inputText, setInputText }) => { + setExternalInputText: (inputText: string) => void; + inferrer: ZeroShotClassificationInference; +}> = ({ setExternalInputText, inferrer }) => { + const [inputText, setInputText] = useState(''); + + useEffect(() => { + setExternalInputText(inputText); + }, [inputText]); + + const isRunning = useObservable(inferrer.isRunning$); return ( { setInputText(e.target.value); @@ -38,18 +46,26 @@ const ClassNameInput: FC<{ ); }; -export const ZeroShotClassificationInput: FC<{ - disabled: boolean; - inputText: string; - inputText2: string; - setInputText(input: string): void; - setInputText2(input: string): void; -}> = ({ disabled, inputText, setInputText, inputText2, setInputText2 }) => { - return ( - <> - - - - - ); -}; +export const getZeroShotClassificationInput = + (inferrer: ZeroShotClassificationInference, placeholder?: string) => () => { + let inputText = ''; + let inputText2 = ''; + + return { + inputComponent: ( + <> + (inputText = txt)} + inferrer={inferrer} + /> + + (inputText2 = txt)} + inferrer={inferrer} + /> + + ), + infer: () => inferrer.infer(inputText, inputText2), + }; + }; diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_embedding/text_embedding_inference.ts b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_embedding/text_embedding_inference.ts index 6de03f4dd87d7d..c3622646316d7f 100644 --- a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_embedding/text_embedding_inference.ts +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_embedding/text_embedding_inference.ts @@ -18,15 +18,17 @@ export interface FormattedTextEmbeddingResponse { } export interface InferResponse { + inputText: string; response: FormattedTextEmbeddingResponse; rawResponse: TextEmbeddingResponse; } export class TextEmbeddingInference extends InferenceBase { public async infer(inputText: string) { + this.isRunning$.next(true); const payload = { docs: { [this.inputField]: inputText }, - inference_config: { fill_mask: { num_top_classes: 5 } }, + // inference_config: { text_embedding: { num_top_classes: 5 } }, }; const resp = (await this.trainedModelsApi.inferTrainedModel( this.model.model_id, @@ -34,11 +36,19 @@ export class TextEmbeddingInference extends InferenceBase { '30s' )) as unknown as TextEmbeddingResponse; - return processResponse(resp, this.model); + const processedResponse = processResponse(resp, this.model, inputText); + this.inferenceResult$.next(processedResponse); + this.isRunning$.next(false); + + return processedResponse; } } -function processResponse(resp: TextEmbeddingResponse, model: estypes.MlTrainedModelConfig) { +function processResponse( + resp: TextEmbeddingResponse, + model: estypes.MlTrainedModelConfig, + inputText: string +) { const predictedValue = resp.predicted_value; - return { response: { predictedValue }, rawResponse: resp }; + return { response: { predictedValue }, rawResponse: resp, inputText }; } diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_embedding/text_embedding_output.tsx b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_embedding/text_embedding_output.tsx index bef0ef70a5fd01..ce4ce41ab9539c 100644 --- a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_embedding/text_embedding_output.tsx +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_embedding/text_embedding_output.tsx @@ -6,17 +6,24 @@ */ import React, { FC } from 'react'; +import useObservable from 'react-use/lib/useObservable'; import { FormattedMessage } from '@kbn/i18n-react'; import { EuiTextArea, EuiCopy, EuiButton } from '@elastic/eui'; -import type { FormattedTextEmbeddingResponse } from './text_embedding_inference'; +import type { TextEmbeddingInference } from './text_embedding_inference'; -export const getTextEmbeddingOutputComponent = (output: FormattedTextEmbeddingResponse) => ( - -); +export const getTextEmbeddingOutputComponent = (inferrer: TextEmbeddingInference) => () => + ; -const TextEmbeddingOutput: FC<{ result: FormattedTextEmbeddingResponse }> = ({ result }) => { - const value = result.predictedValue.toString(); +const TextEmbeddingOutput: FC<{ + inferrer: TextEmbeddingInference; +}> = ({ inferrer }) => { + const result = useObservable(inferrer.inferenceResult$); + if (!result) { + return null; + } + + const value = result.response.predictedValue.toString(); return ( <> diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_input.tsx b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_input.tsx index 3cd9c84eb14992..9785609bac66c5 100644 --- a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_input.tsx +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/models/text_input.tsx @@ -5,16 +5,38 @@ * 2.0. */ -import React, { FC } from 'react'; +import React, { FC, useState, useEffect } from 'react'; import { i18n } from '@kbn/i18n'; +import useObservable from 'react-use/lib/useObservable'; import { EuiTextArea } from '@elastic/eui'; +import { NerInference } from './ner'; +import { + TextClassificationInference, + ZeroShotClassificationInference, + FillMaskInference, +} from './text_classification'; +import { TextEmbeddingInference } from './text_embedding'; +import { LangIdentInference } from './lang_ident'; export const TextInput: FC<{ - disabled: boolean; - inputText: string; - setInputText(input: string): void; placeholder?: string; -}> = ({ disabled, inputText, setInputText, placeholder }) => { + setExternalInputText: (inputText: string) => void; + inferrer: + | NerInference + | TextClassificationInference + | TextEmbeddingInference + | ZeroShotClassificationInference + | FillMaskInference + | LangIdentInference; +}> = ({ placeholder, setExternalInputText, inferrer }) => { + const [inputText, setInputText] = useState(''); + + useEffect(() => { + setExternalInputText(inputText); + }, [inputText]); + + const isRunning = useObservable(inferrer.isRunning$); + return ( { setInputText(e.target.value); @@ -32,3 +54,29 @@ export const TextInput: FC<{ /> ); }; + +export const getGeneralInputComponent = + ( + inferrer: + | NerInference + | TextClassificationInference + | TextEmbeddingInference + | ZeroShotClassificationInference + | FillMaskInference + | LangIdentInference, + placeholder?: string + ) => + () => { + let inputText = ''; + + return { + inputComponent: ( + (inputText = txt)} + inferrer={inferrer} + /> + ), + infer: () => inferrer.infer(inputText), + }; + }; diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/selected_model.tsx b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/selected_model.tsx index 0bc86b1d27517b..3140ed71984306 100644 --- a/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/selected_model.tsx +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/test_models/selected_model.tsx @@ -6,7 +6,7 @@ */ import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey'; -import React, { FC, useState } from 'react'; +import React, { FC } from 'react'; import { i18n } from '@kbn/i18n'; import { getNerOutputComponent, NerInference } from './models/ner'; @@ -17,14 +17,14 @@ import { getTextClassificationOutputComponent, TextClassificationInference, ZeroShotClassificationInference, - ZeroShotClassificationInput, FillMaskInference, getFillMaskOutputComponent, + getZeroShotClassificationInput, } from './models/text_classification'; import { getTextEmbeddingOutputComponent, TextEmbeddingInference } from './models/text_embedding'; -import { TextInput } from './models/text_input'; +import { getGeneralInputComponent } from './models/text_input'; import { TRAINED_MODEL_TYPE, @@ -39,48 +39,38 @@ interface Props { export const SelectedModel: FC = ({ model }) => { const { trainedModels } = useMlApiContext(); - const [inputText, setInputText] = useState(''); - const [inputText2, setInputText2] = useState(''); - const [isRunning, setIsRunning] = useState(false); + // const [inputText, setInputText] = useState(''); + // const [inputText2, setInputText2] = useState(''); + // const [isRunning, setIsRunning] = useState(false); if (model === null) { return null; } - const getComp = (infer: any, getOutputComponent: any, getInputComponent: any) => { - return ( - - ); - }; - - const getGeneralInputComponent = (placeholder?: string) => ( - - ); + // const getComp = (getOutputComponent: any, getInputComponent: any) => { + // return ( + // + // ); + // }; if (model.model_type === TRAINED_MODEL_TYPE.PYTORCH) { if (Object.keys(model.inference_config)[0] === SUPPORTED_PYTORCH_TASKS.NER) { const inferrer = new NerInference(trainedModels, model); - + // eslint-disable-next-line no-console + console.log(222222); + // return <>{getComp(() => getNerOutputComponent, getGeneralInputComponent(inferrer))}; return ( - <> - {getComp( - () => inferrer.infer(inputText), - getNerOutputComponent, - getGeneralInputComponent - )} - + ); } @@ -88,13 +78,11 @@ export const SelectedModel: FC = ({ model }) => { const inferrer = new TextClassificationInference(trainedModels, model); return ( - <> - {getComp( - () => inferrer.infer(inputText), - getTextClassificationOutputComponent, - getGeneralInputComponent - )} - + ); } @@ -103,24 +91,12 @@ export const SelectedModel: FC = ({ model }) => { ) { const inferrer = new ZeroShotClassificationInference(trainedModels, model); - const getZeroShotInputComponent = () => ( - - ); - return ( - <> - {getComp( - () => inferrer.infer(inputText, inputText2), - getTextClassificationOutputComponent, - getZeroShotInputComponent - )} - + ); } @@ -128,13 +104,11 @@ export const SelectedModel: FC = ({ model }) => { const inferrer = new TextEmbeddingInference(trainedModels, model); return ( - <> - {getComp( - () => inferrer.infer(inputText), - getTextEmbeddingOutputComponent, - getGeneralInputComponent - )} - + ); } @@ -149,13 +123,11 @@ export const SelectedModel: FC = ({ model }) => { ); return ( - <> - {getComp( - () => inferrer.infer(inputText), - getFillMaskOutputComponent(inputText), - () => getGeneralInputComponent(placeholder) - )} - + ); } } @@ -163,13 +135,11 @@ export const SelectedModel: FC = ({ model }) => { const inferrer = new LangIdentInference(trainedModels, model); return ( - <> - {getComp( - () => inferrer.infer(inputText), - getLangIdentOutputComponent, - getGeneralInputComponent - )} - + ); }