[8.12] [Obs AI Assistant] Add guardrails (#174060) #174117

Merged · 1 commit · Jan 2, 2024
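
This backport adds two guardrails to the Observability AI Assistant client: the LLM may request at most three function calls per completion (once the limit is reached, functions are no longer offered and a further function request is answered with an error message asking the user what to do next), and function responses longer than 4,000 tokens are truncated with gpt-tokenizer before being handed back to the LLM. The diff touches two files: the client's unit tests and the client implementation.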
@@ -7,9 +7,11 @@
import type { ActionsClient } from '@kbn/actions-plugin/server/actions_client';
import type { ElasticsearchClient, Logger } from '@kbn/core/server';
import type { DeeplyMockedKeys } from '@kbn/utility-types-jest';
-import { merge } from 'lodash';
+import { waitFor } from '@testing-library/react';
+import { last, merge, repeat } from 'lodash';
import { ChatCompletionResponseMessage } from 'openai';
import { Subject } from 'rxjs';
-import { PassThrough, type Readable } from 'stream';
+import { EventEmitter, PassThrough, type Readable } from 'stream';
+import { finished } from 'stream/promises';
import { ObservabilityAIAssistantClient } from '.';
import { createResourceNamesMap } from '..';
@@ -70,7 +72,7 @@ function createLlmSimulator() {
};
}

-describe('Observability AI Assistant service', () => {
+describe('Observability AI Assistant client', () => {
let client: ObservabilityAIAssistantClient;

const actionsClientMock: DeeplyMockedKeys<ActionsClient> = {
@@ -84,14 +86,8 @@ describe('Observability AI Assistant service', () => {
} as any;

const currentUserEsClientMock: DeeplyMockedKeys<ElasticsearchClient> = {
-    search: jest.fn().mockResolvedValue({
-      hits: {
-        hits: [],
-      },
-    }),
-    fieldCaps: jest.fn().mockResolvedValue({
-      fields: [],
-    }),
+    search: jest.fn(),
+    fieldCaps: jest.fn(),
} as any;

const knowledgeBaseServiceMock: DeeplyMockedKeys<KnowledgeBaseService> = {
@@ -107,16 +103,29 @@ describe('Observability AI Assistant service', () => {

const functionClientMock: DeeplyMockedKeys<ChatFunctionClient> = {
executeFunction: jest.fn(),
-    getFunctions: jest.fn().mockReturnValue([]),
-    hasFunction: jest.fn().mockImplementation((name) => {
-      return name !== 'recall';
-    }),
+    getFunctions: jest.fn(),
+    hasFunction: jest.fn(),
} as any;

let llmSimulator: LlmSimulator;

function createClient() {
-    jest.clearAllMocks();
+    jest.resetAllMocks();
+
+    functionClientMock.getFunctions.mockReturnValue([]);
+    functionClientMock.hasFunction.mockImplementation((name) => {
+      return name !== 'recall';
+    });
+
+    currentUserEsClientMock.search.mockResolvedValue({
+      hits: {
+        hits: [],
+      },
+    } as any);
+
+    currentUserEsClientMock.fieldCaps.mockResolvedValue({
+      fields: [],
+    } as any);

return new ObservabilityAIAssistantClient({
actionsClient: actionsClientMock,
@@ -158,6 +167,10 @@ describe('Observability AI Assistant service', () => {
);
}

+  beforeEach(() => {
+    jest.clearAllMocks();
+  });
+
describe('when completing a conversation without an initial conversation id', () => {
let stream: Readable;

@@ -1148,4 +1161,197 @@ describe('Observability AI Assistant service', () => {
});
});
});

describe('when the LLM keeps on calling a function and the limit has been exceeded', () => {
let stream: Readable;

let dataHandler: jest.Mock;

beforeEach(async () => {
client = createClient();

const onLlmCall = new EventEmitter();

function waitForNextLlmCall() {
return new Promise<void>((resolve) => onLlmCall.addListener('next', resolve));
}

actionsClientMock.execute.mockImplementation(async () => {
llmSimulator = createLlmSimulator();
onLlmCall.emit('next');
return {
actionId: '',
status: 'ok',
data: llmSimulator.stream,
};
});

functionClientMock.getFunctions.mockImplementation(() => [
{
definition: {
name: 'get_top_alerts',
contexts: ['core'],
description: '',
parameters: {},
},
respond: async () => {
return { content: 'Call this function again' };
},
},
]);

functionClientMock.hasFunction.mockImplementation((name) => name === 'get_top_alerts');
functionClientMock.executeFunction.mockImplementation(async () => ({
content: 'Call this function again',
}));

stream = await client.complete({
connectorId: 'foo',
messages: [system('This is a system message'), user('How many alerts do I have?')],
functionClient: functionClientMock,
signal: new AbortController().signal,
title: 'My predefined title',
persist: true,
});

dataHandler = jest.fn();

stream.on('data', dataHandler);

async function requestAlertsFunctionCall() {
const body = JSON.parse(
(actionsClientMock.execute.mock.lastCall![0].params as any).subActionParams.body
);

let nextLlmCallPromise: Promise<void>;

if (body.functions?.length) {
nextLlmCallPromise = waitForNextLlmCall();
await llmSimulator.next({ function_call: { name: 'get_top_alerts' } });
} else {
nextLlmCallPromise = Promise.resolve();
await llmSimulator.next({ content: 'Looks like we are done here' });
}

await llmSimulator.complete();

await nextLlmCallPromise;
}

await requestAlertsFunctionCall();

await requestAlertsFunctionCall();

await requestAlertsFunctionCall();

await requestAlertsFunctionCall();

await finished(stream);
});

it('executed the function no more than three times', () => {
expect(functionClientMock.executeFunction).toHaveBeenCalledTimes(3);
});

it('does not give the LLM the choice to call a function anymore', () => {
const firstBody = JSON.parse(
(actionsClientMock.execute.mock.calls[0][0].params as any).subActionParams.body
);
const body = JSON.parse(
(actionsClientMock.execute.mock.lastCall![0].params as any).subActionParams.body
);

expect(firstBody.functions.length).toBe(1);

expect(body.functions).toBeUndefined();
});
});

describe('when the function response exceeds the max no of tokens for one', () => {
let stream: Readable;

let dataHandler: jest.Mock;

beforeEach(async () => {
client = createClient();

let functionResponsePromiseResolve: Function | undefined;

actionsClientMock.execute.mockImplementation(async () => {
llmSimulator = createLlmSimulator();
return {
actionId: '',
status: 'ok',
data: llmSimulator.stream,
};
});

functionClientMock.getFunctions.mockImplementation(() => [
{
definition: {
name: 'get_top_alerts',
contexts: ['core'],
description: '',
parameters: {},
},
respond: async () => {
return { content: '' };
},
},
]);

functionClientMock.hasFunction.mockImplementation((name) => name === 'get_top_alerts');

functionClientMock.executeFunction.mockImplementation(() => {
return new Promise((resolve) => {
functionResponsePromiseResolve = resolve;
});
});

stream = await client.complete({
connectorId: 'foo',
messages: [system('This is a system message'), user('How many alerts do I have?')],
functionClient: functionClientMock,
signal: new AbortController().signal,
title: 'My predefined title',
persist: true,
});

dataHandler = jest.fn();

stream.on('data', dataHandler);

await llmSimulator.next({ function_call: { name: 'get_top_alerts' } });

await llmSimulator.complete();

await waitFor(() => functionResponsePromiseResolve !== undefined);

functionResponsePromiseResolve!({
content: repeat('word ', 10000),
});

await waitFor(() => actionsClientMock.execute.mock.calls.length > 1);

await llmSimulator.next({ content: 'Looks like this was truncated' });

await llmSimulator.complete();

await finished(stream);
});
it('truncates the message', () => {
const body = JSON.parse(
(actionsClientMock.execute.mock.lastCall![0].params as any).subActionParams.body
);

const parsed = JSON.parse(last(body.messages as ChatCompletionResponseMessage[])!.content!);

expect(parsed).toEqual({
message: 'Function response exceeded the maximum length allowed and was truncated',
truncated: expect.any(String),
});

expect(parsed.truncated.includes('word ')).toBe(true);
});
});
});
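
The tests above exercise the new guardrails; the second file in the diff contains the client implementation that enforces them: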
@@ -10,12 +10,13 @@ import type { ActionsClient } from '@kbn/actions-plugin/server';
import type { ElasticsearchClient } from '@kbn/core/server';
import type { Logger } from '@kbn/logging';
import type { PublicMethodsOf } from '@kbn/utility-types';
-import { compact, isEmpty, last, merge, omit, pick } from 'lodash';
import type {
ChatCompletionRequestMessage,
CreateChatCompletionRequest,
CreateChatCompletionResponse,
} from 'openai';
+import { decode, encode } from 'gpt-tokenizer';
+import { compact, isEmpty, last, merge, omit, pick, take } from 'lodash';
import { isObservable, lastValueFrom } from 'rxjs';
import { PassThrough, Readable } from 'stream';
import { v4 } from 'uuid';
@@ -176,6 +177,11 @@ export class ObservabilityAIAssistantClient {
});
}

+    let numFunctionsCalled: number = 0;
+
+    const MAX_FUNCTION_CALLS = 3;
+    const MAX_FUNCTION_RESPONSE_TOKEN_COUNT = 4000;

const next = async (nextMessages: Message[]): Promise<void> => {
const lastMessage = last(nextMessages);

@@ -222,32 +228,65 @@
connectorId,
stream: true,
signal,
-        functions: functionClient
-          .getFunctions()
-          .map((fn) => pick(fn.definition, 'name', 'description', 'parameters')),
+        functions:
+          numFunctionsCalled >= MAX_FUNCTION_CALLS
+            ? []
+            : functionClient
+                .getFunctions()
+                .map((fn) => pick(fn.definition, 'name', 'description', 'parameters')),
})
).pipe(processOpenAiStream()),
});
return await next(nextMessages.concat({ message, '@timestamp': new Date().toISOString() }));
}

if (isAssistantMessageWithFunctionRequest) {
-      const functionResponse = await functionClient
-        .executeFunction({
-          connectorId,
-          name: lastMessage.message.function_call!.name,
-          messages: nextMessages,
-          args: lastMessage.message.function_call!.arguments,
-          signal,
-        })
-        .catch((error): FunctionResponse => {
-          return {
-            content: {
-              message: error.toString(),
-              error,
-            },
-          };
-        });
+      const functionResponse =
+        numFunctionsCalled >= MAX_FUNCTION_CALLS
+          ? {
+              content: {
+                error: {},
+                message: 'Function limit exceeded, ask the user what to do next',
+              },
+            }
+          : await functionClient
+              .executeFunction({
+                connectorId,
+                name: lastMessage.message.function_call!.name,
+                messages: nextMessages,
+                args: lastMessage.message.function_call!.arguments,
+                signal,
+              })
+              .then((response) => {
+                if (isObservable(response)) {
+                  return response;
+                }
+
+                const encoded = encode(JSON.stringify(response.content || {}));
+
+                if (encoded.length <= MAX_FUNCTION_RESPONSE_TOKEN_COUNT) {
+                  return response;
+                }
+
+                return {
+                  data: response.data,
+                  content: {
+                    message:
+                      'Function response exceeded the maximum length allowed and was truncated',
+                    truncated: decode(take(encoded, MAX_FUNCTION_RESPONSE_TOKEN_COUNT)),
+                  },
+                };
+              })
+              .catch((error): FunctionResponse => {
+                return {
+                  content: {
+                    message: error.toString(),
+                    error,
+                  },
+                };
+              });
+
+      numFunctionsCalled++;

if (signal.aborted) {
return;
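
For quick reference, a minimal standalone sketch of the two guardrails, distilled from the diff above; the FunctionResponse shape and the helper names are illustrative, not part of the actual change:

import { decode, encode } from 'gpt-tokenizer';
import { take } from 'lodash';

// Same limits as in the diff above.
const MAX_FUNCTION_CALLS = 3;
const MAX_FUNCTION_RESPONSE_TOKEN_COUNT = 4000;

// Illustrative shape; the real client uses its own FunctionResponse type.
interface FunctionResponse {
  content?: unknown;
  data?: unknown;
}

// Guardrail 1: once the call budget is spent, stop offering functions to the
// LLM, so it can only answer with plain text.
function functionsForNextRequest<T>(numFunctionsCalled: number, allFunctions: T[]): T[] {
  return numFunctionsCalled >= MAX_FUNCTION_CALLS ? [] : allFunctions;
}

// Guardrail 2: count tokens in the serialized response; if it exceeds the
// budget, keep only the first MAX_FUNCTION_RESPONSE_TOKEN_COUNT tokens and
// flag the truncation so the LLM knows the content is incomplete.
function truncateFunctionResponse(response: FunctionResponse): FunctionResponse {
  const encoded = encode(JSON.stringify(response.content || {}));

  if (encoded.length <= MAX_FUNCTION_RESPONSE_TOKEN_COUNT) {
    return response;
  }

  return {
    data: response.data,
    content: {
      message: 'Function response exceeded the maximum length allowed and was truncated',
      truncated: decode(take(encoded, MAX_FUNCTION_RESPONSE_TOKEN_COUNT)),
    },
  };
}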