From 118c2046d2e7a74a23b479dc28761cf6dc5ad207 Mon Sep 17 00:00:00 2001 From: Dario Gieselaar Date: Mon, 5 Feb 2024 10:40:54 +0100 Subject: [PATCH] [Obs AI Assistant] Add APM instrumentation for chat/complete (#175799) Adds APM instrumentation for function calls and interactions with the LLM. Closes https://github.com/elastic/obs-ai-assistant-team/issues/120 --- .../service/create_chat_service.test.ts | 2 +- .../public/service/create_chat_service.ts | 26 ++--- .../public/types.ts | 15 +-- .../scripts/evaluation/kibana_client.ts | 26 ++--- .../server/functions/esql/index.ts | 61 ++++++------ .../server/functions/get_dataset_info.ts | 2 +- .../server/functions/recall.ts | 2 +- .../server/routes/chat/route.ts | 5 +- .../server/service/client/index.ts | 96 ++++++++++++++----- .../tests/chat/chat.spec.ts | 3 + 10 files changed, 150 insertions(+), 88 deletions(-) diff --git a/x-pack/plugins/observability_ai_assistant/public/service/create_chat_service.test.ts b/x-pack/plugins/observability_ai_assistant/public/service/create_chat_service.test.ts index b5b86fa4f15b3c..683792d5cf7081 100644 --- a/x-pack/plugins/observability_ai_assistant/public/service/create_chat_service.test.ts +++ b/x-pack/plugins/observability_ai_assistant/public/service/create_chat_service.test.ts @@ -57,7 +57,7 @@ describe('createChatService', () => { } function chat({ signal }: { signal: AbortSignal } = { signal: new AbortController().signal }) { - return service.chat({ signal, messages: [], connectorId: '' }); + return service.chat('my_test', { signal, messages: [], connectorId: '' }); } beforeEach(async () => { diff --git a/x-pack/plugins/observability_ai_assistant/public/service/create_chat_service.ts b/x-pack/plugins/observability_ai_assistant/public/service/create_chat_service.ts index 06b521ca045505..1a001e7b10b1a0 100644 --- a/x-pack/plugins/observability_ai_assistant/public/service/create_chat_service.ts +++ b/x-pack/plugins/observability_ai_assistant/public/service/create_chat_service.ts @@ -174,17 +174,20 @@ export async function createChatService({ }); }); }, - chat({ - connectorId, - messages, - function: callFunctions = 'auto', - signal, - }: { - connectorId: string; - messages: Message[]; - function?: 'none' | 'auto'; - signal: AbortSignal; - }) { + chat( + name: string, + { + connectorId, + messages, + function: callFunctions = 'auto', + signal, + }: { + connectorId: string; + messages: Message[]; + function?: 'none' | 'auto'; + signal: AbortSignal; + } + ) { return new Observable((subscriber) => { const contexts = ['core', 'apm']; @@ -193,6 +196,7 @@ export async function createChatService({ client('POST /internal/observability_ai_assistant/chat', { params: { body: { + name, messages, connectorId, functions: diff --git a/x-pack/plugins/observability_ai_assistant/public/types.ts b/x-pack/plugins/observability_ai_assistant/public/types.ts index d535a7cfae27a9..ce2ea829081db7 100644 --- a/x-pack/plugins/observability_ai_assistant/public/types.ts +++ b/x-pack/plugins/observability_ai_assistant/public/types.ts @@ -50,12 +50,15 @@ export type { PendingMessage }; export interface ObservabilityAIAssistantChatService { analytics: AnalyticsServiceStart; - chat: (options: { - messages: Message[]; - connectorId: string; - function?: 'none' | 'auto'; - signal: AbortSignal; - }) => Observable; + chat: ( + name: string, + options: { + messages: Message[]; + connectorId: string; + function?: 'none' | 'auto'; + signal: AbortSignal; + } + ) => Observable; complete: (options: { messages: Message[]; connectorId: string; diff --git a/x-pack/plugins/observability_ai_assistant/scripts/evaluation/kibana_client.ts b/x-pack/plugins/observability_ai_assistant/scripts/evaluation/kibana_client.ts index 7ed447c6c907a4..d77e37a2b55a85 100644 --- a/x-pack/plugins/observability_ai_assistant/scripts/evaluation/kibana_client.ts +++ b/x-pack/plugins/observability_ai_assistant/scripts/evaluation/kibana_client.ts @@ -186,17 +186,21 @@ export class KibanaClient { unregister: () => void; }> = []; - async function chat({ - messages, - functions, - functionCall, - }: { - messages: Message[]; - functions: FunctionDefinition[]; - functionCall?: string; - }) { + async function chat( + name: string, + { + messages, + functions, + functionCall, + }: { + messages: Message[]; + functions: FunctionDefinition[]; + functionCall?: string; + } + ) { const params: ObservabilityAIAssistantAPIClientRequestParamsOf<'POST /internal/observability_ai_assistant/chat'>['params']['body'] = { + name, messages, connectorId, functions: functions.map((fn) => pick(fn, 'name', 'description', 'parameters')), @@ -235,7 +239,7 @@ export class KibanaClient { '@timestamp': new Date().toISOString(), })), ]; - return chat({ messages, functions: functionDefinitions }); + return chat('chat', { messages, functions: functionDefinitions }); }, complete: async (...args) => { const messagesArg = args.length === 1 ? args[0] : args[1]; @@ -298,7 +302,7 @@ export class KibanaClient { }; }, evaluate: async ({ messages, conversationId }, criteria) => { - const message = await chat({ + const message = await chat('evaluate', { messages: [ { '@timestamp': new Date().toISOString(), diff --git a/x-pack/plugins/observability_ai_assistant/server/functions/esql/index.ts b/x-pack/plugins/observability_ai_assistant/server/functions/esql/index.ts index a44bf0f5f235cb..a8f3ad22ebe5ba 100644 --- a/x-pack/plugins/observability_ai_assistant/server/functions/esql/index.ts +++ b/x-pack/plugins/observability_ai_assistant/server/functions/esql/index.ts @@ -124,7 +124,7 @@ export function registerEsqlFunction({ ]; const source$ = ( - await client.chat({ + await client.chat('classify_esql', { connectorId, messages: withEsqlSystemMessage( `Use the classify_esql function to classify the user's request @@ -198,10 +198,12 @@ export function registerEsqlFunction({ const messagesToInclude = mapValues(pick(esqlDocs, keywords), ({ data }) => data); - const esqlResponse$: Observable = await client.chat({ - messages: [ - ...withEsqlSystemMessage( - `Format every ES|QL query as Markdown: + const esqlResponse$: Observable = await client.chat( + 'answer_esql_question', + { + messages: [ + ...withEsqlSystemMessage( + `Format every ES|QL query as Markdown: \`\`\`esql \`\`\` @@ -224,33 +226,34 @@ export function registerEsqlFunction({ \`\`\` ` - ), - { - '@timestamp': new Date().toISOString(), - message: { - role: MessageRole.Assistant, - content: '', - function_call: { - name: 'get_esql_info', - arguments: JSON.stringify(args), - trigger: MessageRole.Assistant as const, + ), + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.Assistant, + content: '', + function_call: { + name: 'get_esql_info', + arguments: JSON.stringify(args), + trigger: MessageRole.Assistant as const, + }, }, }, - }, - { - '@timestamp': new Date().toISOString(), - message: { - role: MessageRole.User, - name: 'get_esql_info', - content: JSON.stringify({ - documentation: messagesToInclude, - }), + { + '@timestamp': new Date().toISOString(), + message: { + role: MessageRole.User, + name: 'get_esql_info', + content: JSON.stringify({ + documentation: messagesToInclude, + }), + }, }, - }, - ], - connectorId, - signal, - }); + ], + connectorId, + signal, + } + ); return esqlResponse$.pipe( emitWithConcatenatedMessage((msg) => { diff --git a/x-pack/plugins/observability_ai_assistant/server/functions/get_dataset_info.ts b/x-pack/plugins/observability_ai_assistant/server/functions/get_dataset_info.ts index f1d763359aa401..4969ddfb7e4023 100644 --- a/x-pack/plugins/observability_ai_assistant/server/functions/get_dataset_info.ts +++ b/x-pack/plugins/observability_ai_assistant/server/functions/get_dataset_info.ts @@ -115,7 +115,7 @@ export function registerGetDatasetInfoFunction({ const relevantFields = await Promise.all( chunk(fieldNames, 500).map(async (fieldsInChunk) => { const chunkResponse$ = ( - await client.chat({ + await client.chat('get_relevent_dataset_names', { connectorId, signal, messages: [ diff --git a/x-pack/plugins/observability_ai_assistant/server/functions/recall.ts b/x-pack/plugins/observability_ai_assistant/server/functions/recall.ts index 03ff6edbb0d641..7e966fa0e55086 100644 --- a/x-pack/plugins/observability_ai_assistant/server/functions/recall.ts +++ b/x-pack/plugins/observability_ai_assistant/server/functions/recall.ts @@ -248,7 +248,7 @@ async function scoreSuggestions({ const response = await lastValueFrom( ( - await client.chat({ + await client.chat('score_suggestions', { connectorId, messages: [extendedSystemMessage, newUserMessage], functions: [scoreFunction], diff --git a/x-pack/plugins/observability_ai_assistant/server/routes/chat/route.ts b/x-pack/plugins/observability_ai_assistant/server/routes/chat/route.ts index 7cc57b769ece86..517cc48f9f27c7 100644 --- a/x-pack/plugins/observability_ai_assistant/server/routes/chat/route.ts +++ b/x-pack/plugins/observability_ai_assistant/server/routes/chat/route.ts @@ -21,6 +21,7 @@ const chatRoute = createObservabilityAIAssistantServerRoute({ params: t.type({ body: t.intersection([ t.type({ + name: t.string, messages: t.array(messageRt), connectorId: t.string, functions: t.array( @@ -46,7 +47,7 @@ const chatRoute = createObservabilityAIAssistantServerRoute({ } const { - body: { messages, connectorId, functions, functionCall }, + body: { name, messages, connectorId, functions, functionCall }, } = params; const controller = new AbortController(); @@ -55,7 +56,7 @@ const chatRoute = createObservabilityAIAssistantServerRoute({ controller.abort(); }); - const response$ = await client.chat({ + const response$ = await client.chat(name, { messages, connectorId, signal: controller.signal, diff --git a/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts b/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts index 66832c0fd394d0..ad208927b76361 100644 --- a/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts +++ b/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts @@ -10,6 +10,7 @@ import type { ActionsClient } from '@kbn/actions-plugin/server'; import type { ElasticsearchClient } from '@kbn/core/server'; import type { Logger } from '@kbn/logging'; import type { PublicMethodsOf } from '@kbn/utility-types'; +import apm from 'elastic-apm-node'; import { decode, encode } from 'gpt-tokenizer'; import { compact, isEmpty, last, merge, noop, omit, pick, take } from 'lodash'; import type OpenAI from 'openai'; @@ -191,17 +192,22 @@ export class ObservabilityAIAssistantClient { return await next(nextMessages.concat(addedMessage)); } else if (isUserMessage) { const response$ = ( - await this.chat({ - messages: nextMessages, - connectorId, - signal, - functions: - numFunctionsCalled >= MAX_FUNCTION_CALLS - ? [] - : functionClient - .getFunctions() - .map((fn) => pick(fn.definition, 'name', 'description', 'parameters')), - }) + await this.chat( + lastMessage.message.name && lastMessage.message.name !== 'recall' + ? 'function_response' + : 'user_message', + { + messages: nextMessages, + connectorId, + signal, + functions: + numFunctionsCalled >= MAX_FUNCTION_CALLS + ? [] + : functionClient + .getFunctions() + .map((fn) => pick(fn.definition, 'name', 'description', 'parameters')), + } + ) ).pipe(emitWithConcatenatedMessage(), shareReplay()); response$.subscribe({ @@ -226,6 +232,14 @@ export class ObservabilityAIAssistantClient { } if (isAssistantMessageWithFunctionRequest) { + const span = apm.startSpan( + `execute_function ${lastMessage.message.function_call!.name}` + ); + + span?.addLabels({ + ai_assistant_args: JSON.stringify(lastMessage.message.function_call!.arguments ?? {}), + }); + const functionResponse = numFunctionsCalled >= MAX_FUNCTION_CALLS ? { @@ -247,6 +261,8 @@ export class ObservabilityAIAssistantClient { return response; } + span?.setOutcome('success'); + const encoded = encode(JSON.stringify(response.content || {})); if (encoded.length <= MAX_FUNCTION_RESPONSE_TOKEN_COUNT) { @@ -263,6 +279,7 @@ export class ObservabilityAIAssistantClient { }; }) .catch((error): FunctionResponse => { + span?.setOutcome('failure'); return { content: { message: error.toString(), @@ -322,8 +339,13 @@ export class ObservabilityAIAssistantClient { ) ); + span?.end(); + return await next(nextMessages.concat(messageEvents.map((event) => event.message))); } + + span?.end(); + return await next(nextMessages); } @@ -401,19 +423,24 @@ export class ObservabilityAIAssistantClient { ).pipe(shareReplay()); }; - chat = async ({ - messages, - connectorId, - functions, - functionCall, - signal, - }: { - messages: Message[]; - connectorId: string; - functions?: Array<{ name: string; description: string; parameters: CompatibleJSONSchema }>; - functionCall?: string; - signal: AbortSignal; - }): Promise> => { + chat = async ( + name: string, + { + messages, + connectorId, + functions, + functionCall, + signal, + }: { + messages: Message[]; + connectorId: string; + functions?: Array<{ name: string; description: string; parameters: CompatibleJSONSchema }>; + functionCall?: string; + signal: AbortSignal; + } + ): Promise> => { + const span = apm.startSpan(`chat ${name}`); + const messagesForOpenAI: Array< Omit & { role: MessageRole; @@ -481,7 +508,24 @@ export class ObservabilityAIAssistantClient { signal.addEventListener('abort', () => response.destroy()); - return streamIntoObservable(response).pipe(processOpenAiStream(), shareReplay()); + const observable = streamIntoObservable(response).pipe(processOpenAiStream(), shareReplay()); + + if (span) { + lastValueFrom(observable) + .then( + () => { + span.setOutcome('success'); + }, + () => { + span.setOutcome('failure'); + } + ) + .finally(() => { + span.end(); + }); + } + + return observable; }; find = async (options?: { query?: string }): Promise<{ conversations: Conversation[] }> => { @@ -541,7 +585,7 @@ export class ObservabilityAIAssistantClient { connectorId: string; signal: AbortSignal; }) => { - const response$ = await this.chat({ + const response$ = await this.chat('generate_title', { messages: [ { '@timestamp': new Date().toISOString(), diff --git a/x-pack/test/observability_ai_assistant_api_integration/tests/chat/chat.spec.ts b/x-pack/test/observability_ai_assistant_api_integration/tests/chat/chat.spec.ts index f258f72769fe11..ef2fd5b6f96075 100644 --- a/x-pack/test/observability_ai_assistant_api_integration/tests/chat/chat.spec.ts +++ b/x-pack/test/observability_ai_assistant_api_integration/tests/chat/chat.spec.ts @@ -69,6 +69,7 @@ export default function ApiTest({ getService }: FtrProviderContext) { .post(CHAT_API_URL) .set('kbn-xsrf', 'foo') .send({ + name: 'my_api_call', messages, connectorId: 'does not exist', functions: [], @@ -96,6 +97,7 @@ export default function ApiTest({ getService }: FtrProviderContext) { .set('kbn-xsrf', 'foo') .on('error', reject) .send({ + name: 'my_api_call', messages, connectorId, functions: [], @@ -136,6 +138,7 @@ export default function ApiTest({ getService }: FtrProviderContext) { .post(CHAT_API_URL) .set('kbn-xsrf', 'foo') .send({ + name: 'my_api_call', messages, connectorId, functions: [],