diff --git a/x-pack/plugins/observability_ai_assistant/server/service/client/index.test.ts b/x-pack/plugins/observability_ai_assistant/server/service/client/index.test.ts
index 7cffaa64d16d31..a0e0a40a4ac89a 100644
--- a/x-pack/plugins/observability_ai_assistant/server/service/client/index.test.ts
+++ b/x-pack/plugins/observability_ai_assistant/server/service/client/index.test.ts
@@ -7,9 +7,11 @@
 import type { ActionsClient } from '@kbn/actions-plugin/server/actions_client';
 import type { ElasticsearchClient, Logger } from '@kbn/core/server';
 import type { DeeplyMockedKeys } from '@kbn/utility-types-jest';
-import { merge } from 'lodash';
+import { waitFor } from '@testing-library/react';
+import { last, merge, repeat } from 'lodash';
+import { ChatCompletionResponseMessage } from 'openai';
 import { Subject } from 'rxjs';
-import { PassThrough, type Readable } from 'stream';
+import { EventEmitter, PassThrough, type Readable } from 'stream';
 import { finished } from 'stream/promises';
 import { ObservabilityAIAssistantClient } from '.';
 import { createResourceNamesMap } from '..';
@@ -70,7 +72,7 @@ function createLlmSimulator() {
   };
 }
 
-describe('Observability AI Assistant service', () => {
+describe('Observability AI Assistant client', () => {
   let client: ObservabilityAIAssistantClient;
 
   const actionsClientMock: DeeplyMockedKeys<ActionsClient> = {
@@ -84,14 +86,8 @@ describe('Observability AI Assistant service', () => {
   } as any;
 
   const currentUserEsClientMock: DeeplyMockedKeys<ElasticsearchClient> = {
-    search: jest.fn().mockResolvedValue({
-      hits: {
-        hits: [],
-      },
-    }),
-    fieldCaps: jest.fn().mockResolvedValue({
-      fields: [],
-    }),
+    search: jest.fn(),
+    fieldCaps: jest.fn(),
   } as any;
 
   const knowledgeBaseServiceMock: DeeplyMockedKeys<KnowledgeBaseService> = {
@@ -107,16 +103,29 @@ describe('Observability AI Assistant service', () => {
   const functionClientMock: DeeplyMockedKeys<ChatFunctionClient> = {
     executeFunction: jest.fn(),
-    getFunctions: jest.fn().mockReturnValue([]),
-    hasFunction: jest.fn().mockImplementation((name) => {
-      return name !== 'recall';
-    }),
+    getFunctions: jest.fn(),
+    hasFunction: jest.fn(),
   } as any;
 
   let llmSimulator: LlmSimulator;
 
   function createClient() {
-    jest.clearAllMocks();
+    jest.resetAllMocks();
+
+    functionClientMock.getFunctions.mockReturnValue([]);
+    functionClientMock.hasFunction.mockImplementation((name) => {
+      return name !== 'recall';
+    });
+
+    currentUserEsClientMock.search.mockResolvedValue({
+      hits: {
+        hits: [],
+      },
+    } as any);
+
+    currentUserEsClientMock.fieldCaps.mockResolvedValue({
+      fields: [],
+    } as any);
 
     return new ObservabilityAIAssistantClient({
       actionsClient: actionsClientMock,
@@ -158,6 +167,10 @@ describe('Observability AI Assistant service', () => {
     );
   }
 
+  beforeEach(() => {
+    jest.clearAllMocks();
+  });
+
   describe('when completing a conversation without an initial conversation id', () => {
     let stream: Readable;
 
@@ -1148,4 +1161,197 @@ describe('Observability AI Assistant service', () => {
       });
     });
   });
+
+  describe('when the LLM keeps on calling a function and the limit has been exceeded', () => {
+    let stream: Readable;
+
+    let dataHandler: jest.Mock;
+
+    beforeEach(async () => {
+      client = createClient();
+
+      const onLlmCall = new EventEmitter();
+
+      function waitForNextLlmCall() {
+        return new Promise<void>((resolve) => onLlmCall.addListener('next', resolve));
+      }
+
+      actionsClientMock.execute.mockImplementation(async () => {
+        llmSimulator = createLlmSimulator();
+        onLlmCall.emit('next');
+        return {
+          actionId: '',
+          status: 'ok',
+          data: llmSimulator.stream,
+        };
+      });
+
+      functionClientMock.getFunctions.mockImplementation(() => [
+        {
+          definition: {
+            name: 'get_top_alerts',
+            contexts: ['core'],
+            description: '',
+            parameters: {},
+          },
+          respond: async () => {
+            return { content: 'Call this function again' };
+          },
+        },
+      ]);
+
+      functionClientMock.hasFunction.mockImplementation((name) => name === 'get_top_alerts');
+      functionClientMock.executeFunction.mockImplementation(async () => ({
+        content: 'Call this function again',
+      }));
+
+      stream = await client.complete({
+        connectorId: 'foo',
+        messages: [system('This is a system message'), user('How many alerts do I have?')],
+        functionClient: functionClientMock,
+        signal: new AbortController().signal,
+        title: 'My predefined title',
+        persist: true,
+      });
+
+      dataHandler = jest.fn();
+
+      stream.on('data', dataHandler);
+
+      async function requestAlertsFunctionCall() {
+        const body = JSON.parse(
+          (actionsClientMock.execute.mock.lastCall![0].params as any).subActionParams.body
+        );
+
+        let nextLlmCallPromise: Promise<void>;
+
+        if (body.functions?.length) {
+          nextLlmCallPromise = waitForNextLlmCall();
+          await llmSimulator.next({ function_call: { name: 'get_top_alerts' } });
+        } else {
+          nextLlmCallPromise = Promise.resolve();
+          await llmSimulator.next({ content: 'Looks like we are done here' });
+        }
+
+        await llmSimulator.complete();
+
+        await nextLlmCallPromise;
+      }
+
+      await requestAlertsFunctionCall();
+
+      await requestAlertsFunctionCall();
+
+      await requestAlertsFunctionCall();
+
+      await requestAlertsFunctionCall();
+
+      await finished(stream);
+    });
+
+    it('executed the function no more than three times', () => {
+      expect(functionClientMock.executeFunction).toHaveBeenCalledTimes(3);
+    });
+
+    it('does not give the LLM the choice to call a function anymore', () => {
+      const firstBody = JSON.parse(
+        (actionsClientMock.execute.mock.calls[0][0].params as any).subActionParams.body
+      );
+      const body = JSON.parse(
+        (actionsClientMock.execute.mock.lastCall![0].params as any).subActionParams.body
+      );
+
+      expect(firstBody.functions.length).toBe(1);
+
+      expect(body.functions).toBeUndefined();
+    });
+  });
+
+  describe('when the function response exceeds the max no of tokens for one', () => {
+    let stream: Readable;
+
+    let dataHandler: jest.Mock;
+
+    beforeEach(async () => {
+      client = createClient();
+
+      let functionResponsePromiseResolve: Function | undefined;
+
+      actionsClientMock.execute.mockImplementation(async () => {
+        llmSimulator = createLlmSimulator();
+        return {
+          actionId: '',
+          status: 'ok',
+          data: llmSimulator.stream,
+        };
+      });
+
+      functionClientMock.getFunctions.mockImplementation(() => [
+        {
+          definition: {
+            name: 'get_top_alerts',
+            contexts: ['core'],
+            description: '',
+            parameters: {},
+          },
+          respond: async () => {
+            return { content: '' };
+          },
+        },
+      ]);
+
+      functionClientMock.hasFunction.mockImplementation((name) => name === 'get_top_alerts');
+
+      functionClientMock.executeFunction.mockImplementation(() => {
+        return new Promise((resolve) => {
+          functionResponsePromiseResolve = resolve;
+        });
+      });
+
+      stream = await client.complete({
+        connectorId: 'foo',
+        messages: [system('This is a system message'), user('How many alerts do I have?')],
+        functionClient: functionClientMock,
+        signal: new AbortController().signal,
+        title: 'My predefined title',
+        persist: true,
+      });
+
+      dataHandler = jest.fn();
+
+      stream.on('data', dataHandler);
+
+      await llmSimulator.next({ function_call: { name: 'get_top_alerts' } });
+
+      await llmSimulator.complete();
+
+      await waitFor(() => functionResponsePromiseResolve !== undefined);
+
+      functionResponsePromiseResolve!({
+        content: repeat('word ', 10000),
+      });
+
+      await waitFor(() => actionsClientMock.execute.mock.calls.length > 1);
+
+      await llmSimulator.next({ content: 'Looks like this was truncated' });
+
+      await llmSimulator.complete();
+
+      await finished(stream);
+    });
+
+    it('truncates the message', () => {
+      const body = JSON.parse(
+        (actionsClientMock.execute.mock.lastCall![0].params as any).subActionParams.body
+      );
+
+      const parsed = JSON.parse(last(body.messages as ChatCompletionResponseMessage[])!.content!);
+
+      expect(parsed).toEqual({
+        message: 'Function response exceeded the maximum length allowed and was truncated',
+        truncated: expect.any(String),
+      });
+
+      expect(parsed.truncated.includes('word ')).toBe(true);
+    });
+  });
 });
diff --git a/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts b/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts
index fafb7606a27695..c158d40fbc8a09 100644
--- a/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts
+++ b/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts
@@ -10,12 +10,13 @@ import type { ActionsClient } from '@kbn/actions-plugin/server';
 import type { ElasticsearchClient } from '@kbn/core/server';
 import type { Logger } from '@kbn/logging';
 import type { PublicMethodsOf } from '@kbn/utility-types';
-import { compact, isEmpty, last, merge, omit, pick } from 'lodash';
 import type {
   ChatCompletionRequestMessage,
   CreateChatCompletionRequest,
   CreateChatCompletionResponse,
 } from 'openai';
+import { decode, encode } from 'gpt-tokenizer';
+import { compact, isEmpty, last, merge, omit, pick, take } from 'lodash';
 import { isObservable, lastValueFrom } from 'rxjs';
 import { PassThrough, Readable } from 'stream';
 import { v4 } from 'uuid';
@@ -176,6 +177,11 @@ export class ObservabilityAIAssistantClient {
       });
     }
 
+    let numFunctionsCalled: number = 0;
+
+    const MAX_FUNCTION_CALLS = 3;
+    const MAX_FUNCTION_RESPONSE_TOKEN_COUNT = 4000;
+
     const next = async (nextMessages: Message[]): Promise<void> => {
       const lastMessage = last(nextMessages);
 
@@ -222,9 +228,12 @@ export class ObservabilityAIAssistantClient {
           connectorId,
           stream: true,
           signal,
-          functions: functionClient
-            .getFunctions()
-            .map((fn) => pick(fn.definition, 'name', 'description', 'parameters')),
+          functions:
+            numFunctionsCalled >= MAX_FUNCTION_CALLS
+              ? []
+              : functionClient
+                  .getFunctions()
+                  .map((fn) => pick(fn.definition, 'name', 'description', 'parameters')),
         })
       ).pipe(processOpenAiStream()),
     });
@@ -232,22 +241,52 @@ export class ObservabilityAIAssistantClient {
       }
 
      if (isAssistantMessageWithFunctionRequest) {
-        const functionResponse = await functionClient
-          .executeFunction({
-            connectorId,
-            name: lastMessage.message.function_call!.name,
-            messages: nextMessages,
-            args: lastMessage.message.function_call!.arguments,
-            signal,
-          })
-          .catch((error): FunctionResponse => {
-            return {
-              content: {
-                message: error.toString(),
-                error,
-              },
-            };
-          });
+        const functionResponse =
+          numFunctionsCalled >= MAX_FUNCTION_CALLS
+            ? {
+                content: {
+                  error: {},
+                  message: 'Function limit exceeded, ask the user what to do next',
+                },
+              }
+            : await functionClient
+                .executeFunction({
+                  connectorId,
+                  name: lastMessage.message.function_call!.name,
+                  messages: nextMessages,
+                  args: lastMessage.message.function_call!.arguments,
+                  signal,
+                })
+                .then((response) => {
+                  if (isObservable(response)) {
+                    return response;
+                  }
+
+                  const encoded = encode(JSON.stringify(response.content || {}));
+
+                  if (encoded.length <= MAX_FUNCTION_RESPONSE_TOKEN_COUNT) {
+                    return response;
+                  }
+
+                  return {
+                    data: response.data,
+                    content: {
+                      message:
+                        'Function response exceeded the maximum length allowed and was truncated',
+                      truncated: decode(take(encoded, MAX_FUNCTION_RESPONSE_TOKEN_COUNT)),
+                    },
+                  };
+                })
+                .catch((error): FunctionResponse => {
+                  return {
+                    content: {
+                      message: error.toString(),
+                      error,
+                    },
+                  };
+                });
+
+        numFunctionsCalled++;
 
       if (signal.aborted) {
         return;
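
For readers who want the truncation rule from `index.ts` in isolation: below is a minimal standalone sketch of the same token-budget technique, not the Kibana implementation itself. `truncateFunctionResponse` and `MAX_TOKENS` are hypothetical names used only for this illustration; `encode`/`decode` from `gpt-tokenizer` and `take` from `lodash` are the same calls the patch uses.

```ts
import { decode, encode } from 'gpt-tokenizer';
import { take } from 'lodash';

// Hypothetical stand-in for MAX_FUNCTION_RESPONSE_TOKEN_COUNT in the patch.
const MAX_TOKENS = 4000;

// Measure the serialized function response in tokens; if it exceeds the
// budget, decode only the first MAX_TOKENS tokens back into a string and
// wrap it in an explanatory message, mirroring the `.then()` handler above.
function truncateFunctionResponse(content: unknown) {
  const encoded = encode(JSON.stringify(content || {}));

  if (encoded.length <= MAX_TOKENS) {
    return content;
  }

  return {
    message: 'Function response exceeded the maximum length allowed and was truncated',
    truncated: decode(take(encoded, MAX_TOKENS)),
  };
}
```

Counting in tokens rather than characters keeps the truncated payload aligned with how the model meters its context window, which is why the patch round-trips through `encode`/`decode` instead of slicing the string.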