[Obs AI Assistant] Add APM instrumentation for chat/complete (elastic#175799)

Adds APM instrumentation for function calls and interactions with the
LLM.

Closes elastic/obs-ai-assistant-team#120
dgieselaar authored and fkanout committed Mar 4, 2024
1 parent 091d59a commit 21af56b
Showing 10 changed files with 150 additions and 88 deletions.
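In practice, the change threads an operation name through every `chat` call so the server can attribute work to a specific interaction (for example `'classify_esql'` or `'score_suggestions'` in the diffs below). A minimal sketch of the new client-side call shape, based on the updated `ObservabilityAIAssistantChatService` type; the import path, caller, and connector id are placeholders:

```ts
import type { ObservabilityAIAssistantChatService } from './types'; // illustrative path

// Sketch only: `service` and `connectorId` come from the caller; 'classify_esql'
// is one of the operation names introduced in this commit.
function classifyQuery(service: ObservabilityAIAssistantChatService, connectorId: string) {
  const controller = new AbortController();

  // `chat` now takes the operation name as its first argument and the existing
  // options object as its second; it still returns an Observable of streaming
  // chat response events.
  return service.chat('classify_esql', {
    connectorId,
    messages: [],
    signal: controller.signal,
  });
}
```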
@@ -57,7 +57,7 @@ describe('createChatService', () => {
}

function chat({ signal }: { signal: AbortSignal } = { signal: new AbortController().signal }) {
return service.chat({ signal, messages: [], connectorId: '' });
return service.chat('my_test', { signal, messages: [], connectorId: '' });
}

beforeEach(async () => {
@@ -174,17 +174,20 @@ export async function createChatService({
});
});
},
chat({
connectorId,
messages,
function: callFunctions = 'auto',
signal,
}: {
connectorId: string;
messages: Message[];
function?: 'none' | 'auto';
signal: AbortSignal;
}) {
chat(
name: string,
{
connectorId,
messages,
function: callFunctions = 'auto',
signal,
}: {
connectorId: string;
messages: Message[];
function?: 'none' | 'auto';
signal: AbortSignal;
}
) {
return new Observable<StreamingChatResponseEventWithoutError>((subscriber) => {
const contexts = ['core', 'apm'];

@@ -193,6 +196,7 @@
client('POST /internal/observability_ai_assistant/chat', {
params: {
body: {
name,
messages,
connectorId,
functions:
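On the wire, the only visible difference is an extra `name` field in the body of `POST /internal/observability_ai_assistant/chat`, matching the `name: t.string` added to the route schema further down. A sketch of the request body the browser-side service now sends; values are placeholders and `functions` is left empty for brevity:

```ts
// Hypothetical body for POST /internal/observability_ai_assistant/chat after this change.
const body = {
  name: 'classify_esql',          // new: identifies the operation for APM
  messages: [],                   // Message[] as before
  connectorId: 'my-connector-id', // placeholder connector id
  functions: [],                  // optional function definitions, unchanged
};
```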
15 changes: 9 additions & 6 deletions x-pack/plugins/observability_ai_assistant/public/types.ts
@@ -50,12 +50,15 @@ export type { PendingMessage };

export interface ObservabilityAIAssistantChatService {
analytics: AnalyticsServiceStart;
chat: (options: {
messages: Message[];
connectorId: string;
function?: 'none' | 'auto';
signal: AbortSignal;
}) => Observable<StreamingChatResponseEventWithoutError>;
chat: (
name: string,
options: {
messages: Message[];
connectorId: string;
function?: 'none' | 'auto';
signal: AbortSignal;
}
) => Observable<StreamingChatResponseEventWithoutError>;
complete: (options: {
messages: Message[];
connectorId: string;
@@ -186,17 +186,21 @@ export class KibanaClient {
unregister: () => void;
}> = [];

async function chat({
messages,
functions,
functionCall,
}: {
messages: Message[];
functions: FunctionDefinition[];
functionCall?: string;
}) {
async function chat(
name: string,
{
messages,
functions,
functionCall,
}: {
messages: Message[];
functions: FunctionDefinition[];
functionCall?: string;
}
) {
const params: ObservabilityAIAssistantAPIClientRequestParamsOf<'POST /internal/observability_ai_assistant/chat'>['params']['body'] =
{
name,
messages,
connectorId,
functions: functions.map((fn) => pick(fn, 'name', 'description', 'parameters')),
@@ -235,7 +239,7 @@
'@timestamp': new Date().toISOString(),
})),
];
return chat({ messages, functions: functionDefinitions });
return chat('chat', { messages, functions: functionDefinitions });
},
complete: async (...args) => {
const messagesArg = args.length === 1 ? args[0] : args[1];
@@ -298,7 +302,7 @@
};
},
evaluate: async ({ messages, conversationId }, criteria) => {
const message = await chat({
const message = await chat('evaluate', {
messages: [
{
'@timestamp': new Date().toISOString(),
@@ -124,7 +124,7 @@ export function registerEsqlFunction({
];

const source$ = (
await client.chat({
await client.chat('classify_esql', {
connectorId,
messages: withEsqlSystemMessage(
`Use the classify_esql function to classify the user's request
@@ -198,10 +198,12 @@

const messagesToInclude = mapValues(pick(esqlDocs, keywords), ({ data }) => data);

const esqlResponse$: Observable<ChatCompletionChunkEvent> = await client.chat({
messages: [
...withEsqlSystemMessage(
`Format every ES|QL query as Markdown:
const esqlResponse$: Observable<ChatCompletionChunkEvent> = await client.chat(
'answer_esql_question',
{
messages: [
...withEsqlSystemMessage(
`Format every ES|QL query as Markdown:
\`\`\`esql
<query>
\`\`\`
@@ -224,33 +226,34 @@
\`\`\`
`
),
{
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.Assistant,
content: '',
function_call: {
name: 'get_esql_info',
arguments: JSON.stringify(args),
trigger: MessageRole.Assistant as const,
),
{
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.Assistant,
content: '',
function_call: {
name: 'get_esql_info',
arguments: JSON.stringify(args),
trigger: MessageRole.Assistant as const,
},
},
},
},
{
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.User,
name: 'get_esql_info',
content: JSON.stringify({
documentation: messagesToInclude,
}),
{
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.User,
name: 'get_esql_info',
content: JSON.stringify({
documentation: messagesToInclude,
}),
},
},
},
],
connectorId,
signal,
});
],
connectorId,
signal,
}
);

return esqlResponse$.pipe(
emitWithConcatenatedMessage((msg) => {
@@ -115,7 +115,7 @@ export function registerGetDatasetInfoFunction({
const relevantFields = await Promise.all(
chunk(fieldNames, 500).map(async (fieldsInChunk) => {
const chunkResponse$ = (
await client.chat({
await client.chat('get_relevent_dataset_names', {
connectorId,
signal,
messages: [
@@ -248,7 +248,7 @@ async function scoreSuggestions({

const response = await lastValueFrom(
(
await client.chat({
await client.chat('score_suggestions', {
connectorId,
messages: [extendedSystemMessage, newUserMessage],
functions: [scoreFunction],
@@ -21,6 +21,7 @@ const chatRoute = createObservabilityAIAssistantServerRoute({
params: t.type({
body: t.intersection([
t.type({
name: t.string,
messages: t.array(messageRt),
connectorId: t.string,
functions: t.array(
@@ -46,7 +47,7 @@
}

const {
body: { messages, connectorId, functions, functionCall },
body: { name, messages, connectorId, functions, functionCall },
} = params;

const controller = new AbortController();
@@ -55,7 +56,7 @@
controller.abort();
});

const response$ = await client.chat({
const response$ = await client.chat(name, {
messages,
connectorId,
signal: controller.signal,
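The server-side client that actually opens the APM spans is not included in this excerpt (the remaining changed files were not expanded), so the exact instrumentation is not shown here. As a rough sketch of what passing `name` through to the server enables, assuming Kibana's `elastic-apm-node` agent; the wrapper name and structure are illustrative:

```ts
import apm from 'elastic-apm-node';

// Sketch only: wraps an LLM interaction in an APM span named after the
// operation ('classify_esql', 'score_suggestions', ...). The real code in the
// server-side client may differ in structure and span metadata.
async function withChatSpan<T>(name: string, doChat: () => Promise<T>): Promise<T> {
  const span = apm.startSpan(name);
  try {
    return await doChat();
  } catch (error) {
    span?.setOutcome('failure');
    throw error;
  } finally {
    span?.end();
  }
}
```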