Skip to content

Commit

Permalink
[Obs AI Assistant] Add APM instrumentation for chat/complete
Browse files Browse the repository at this point in the history
  • Loading branch information
dgieselaar committed Jan 29, 2024
1 parent d791cdf commit 28126e5
Show file tree
Hide file tree
Showing 9 changed files with 147 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ describe('createChatService', () => {
}

function chat({ signal }: { signal: AbortSignal } = { signal: new AbortController().signal }) {
return service.chat({ signal, messages: [], connectorId: '' });
return service.chat('my_test', { signal, messages: [], connectorId: '' });
}

beforeEach(async () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,17 +174,20 @@ export async function createChatService({
});
});
},
chat({
connectorId,
messages,
function: callFunctions = 'auto',
signal,
}: {
connectorId: string;
messages: Message[];
function?: 'none' | 'auto';
signal: AbortSignal;
}) {
chat(
name: string,
{
connectorId,
messages,
function: callFunctions = 'auto',
signal,
}: {
connectorId: string;
messages: Message[];
function?: 'none' | 'auto';
signal: AbortSignal;
}
) {
return new Observable<StreamingChatResponseEventWithoutError>((subscriber) => {
const contexts = ['core', 'apm'];

Expand All @@ -193,6 +196,7 @@ export async function createChatService({
client('POST /internal/observability_ai_assistant/chat', {
params: {
body: {
name,
messages,
connectorId,
functions:
Expand Down
15 changes: 9 additions & 6 deletions x-pack/plugins/observability_ai_assistant/public/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,15 @@ export type { PendingMessage };

export interface ObservabilityAIAssistantChatService {
analytics: AnalyticsServiceStart;
chat: (options: {
messages: Message[];
connectorId: string;
function?: 'none' | 'auto';
signal: AbortSignal;
}) => Observable<StreamingChatResponseEventWithoutError>;
chat: (
name: string,
options: {
messages: Message[];
connectorId: string;
function?: 'none' | 'auto';
signal: AbortSignal;
}
) => Observable<StreamingChatResponseEventWithoutError>;
complete: (options: {
messages: Message[];
connectorId: string;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,17 +108,21 @@ export class KibanaClient {
unregister: () => void;
}> = [];

async function chat({
messages,
functions,
functionCall,
}: {
messages: Message[];
functions: FunctionDefinition[];
functionCall?: string;
}) {
async function chat(
name: string,
{
messages,
functions,
functionCall,
}: {
messages: Message[];
functions: FunctionDefinition[];
functionCall?: string;
}
) {
const params: ObservabilityAIAssistantAPIClientRequestParamsOf<'POST /internal/observability_ai_assistant/chat'>['params']['body'] =
{
name,
messages,
connectorId,
functions: functions.map((fn) => pick(fn, 'name', 'description', 'parameters')),
Expand Down Expand Up @@ -157,7 +161,7 @@ export class KibanaClient {
'@timestamp': new Date().toISOString(),
})),
];
return chat({ messages, functions: functionDefinitions });
return chat('chat', { messages, functions: functionDefinitions });
},
complete: async (...args) => {
const messagesArg = args.length === 1 ? args[0] : args[1];
Expand Down Expand Up @@ -219,7 +223,7 @@ export class KibanaClient {
};
},
evaluate: async ({ messages, conversationId }, criteria) => {
const message = await chat({
const message = await chat('evaluate', {
messages: [
{
'@timestamp': new Date().toISOString(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ export function registerEsqlFunction({
];

const source$ = (
await client.chat({
await client.chat('classify_esql', {
connectorId,
messages: withEsqlSystemMessage(
`Use the classify_esql function to classify the user's request
Expand Down Expand Up @@ -198,10 +198,12 @@ export function registerEsqlFunction({

const messagesToInclude = mapValues(pick(esqlDocs, keywords), ({ data }) => data);

const esqlResponse$: Observable<ChatCompletionChunkEvent> = await client.chat({
messages: [
...withEsqlSystemMessage(
`Format every ES|QL query as Markdown:
const esqlResponse$: Observable<ChatCompletionChunkEvent> = await client.chat(
'answer_esql_question',
{
messages: [
...withEsqlSystemMessage(
`Format every ES|QL query as Markdown:
\`\`\`esql
<query>
\`\`\`
Expand All @@ -224,33 +226,34 @@ export function registerEsqlFunction({
\`\`\`
`
),
{
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.Assistant,
content: '',
function_call: {
name: 'get_esql_info',
arguments: JSON.stringify(args),
trigger: MessageRole.Assistant as const,
),
{
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.Assistant,
content: '',
function_call: {
name: 'get_esql_info',
arguments: JSON.stringify(args),
trigger: MessageRole.Assistant as const,
},
},
},
},
{
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.User,
name: 'get_esql_info',
content: JSON.stringify({
documentation: messagesToInclude,
}),
{
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.User,
name: 'get_esql_info',
content: JSON.stringify({
documentation: messagesToInclude,
}),
},
},
},
],
connectorId,
signal,
});
],
connectorId,
signal,
}
);

return esqlResponse$.pipe(
emitWithConcatenatedMessage((msg) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ export function registerGetDatasetInfoFunction({
const relevantFields = await Promise.all(
chunk(fieldNames, 500).map(async (fieldsInChunk) => {
const chunkResponse$ = (
await client.chat({
await client.chat('get_relevent_dataset_names', {
connectorId,
signal,
messages: [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ async function scoreSuggestions({

const response = await lastValueFrom(
(
await client.chat({
await client.chat('score_suggestions', {
connectorId,
messages: [extendedSystemMessage, newUserMessage],
functions: [scoreFunction],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ const chatRoute = createObservabilityAIAssistantServerRoute({
params: t.type({
body: t.intersection([
t.type({
name: t.string,
messages: t.array(messageRt),
connectorId: t.string,
functions: t.array(
Expand All @@ -46,7 +47,7 @@ const chatRoute = createObservabilityAIAssistantServerRoute({
}

const {
body: { messages, connectorId, functions, functionCall },
body: { name, messages, connectorId, functions, functionCall },
} = params;

const controller = new AbortController();
Expand All @@ -55,7 +56,7 @@ const chatRoute = createObservabilityAIAssistantServerRoute({
controller.abort();
});

const response$ = await client.chat({
const response$ = await client.chat(name, {
messages,
connectorId,
signal: controller.signal,
Expand Down
Loading

0 comments on commit 28126e5

Please sign in to comment.