[Obs AI Assistant] Add guardrails (elastic#174060)
Add guardrails against function looping (max 3 calls in a completion
request) and long function responses (max 4000 tokens).

---------

Co-authored-by: Kibana Machine <42973632+kibanamachine@users.noreply.github.com>
dgieselaar and kibanamachine authored Jan 2, 2024
Commit 725a217 (parent 34935a2)
Showing 2 changed files with 281 additions and 36 deletions.
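
In outline, the guardrails work like this (a minimal sketch, not the client code itself; callLlm, runFunction, and getFunctionDefinitions are hypothetical stubs standing in for the connector round-trip and the function registry):

// guardrails-sketch.ts: a condensed illustration of the two limits this commit adds.
import { encode, decode } from 'gpt-tokenizer';

const MAX_FUNCTION_CALLS = 3;
const MAX_FUNCTION_RESPONSE_TOKEN_COUNT = 4000;

type FunctionCall = { name: string; arguments: string };
type LlmReply = { content?: string; function_call?: FunctionCall };

// Hypothetical stubs; the real client streams these through actionsClient.
declare function callLlm(messages: unknown[], functions: unknown[]): Promise<LlmReply>;
declare function runFunction(call: FunctionCall): Promise<unknown>;
declare function getFunctionDefinitions(): unknown[];

let numFunctionsCalled = 0;

async function next(messages: unknown[]): Promise<void> {
  // Guardrail 1: once the limit is reached, stop offering functions, so the
  // LLM has to answer in plain text instead of looping.
  const functions =
    numFunctionsCalled >= MAX_FUNCTION_CALLS ? [] : getFunctionDefinitions();

  const reply = await callLlm(messages, functions);
  if (!reply.function_call) {
    return; // plain-text answer; the completion round is done
  }

  numFunctionsCalled++;

  // Guardrail 2: cap the function response at a fixed token budget; if it is
  // larger, keep a decoded prefix and tell the LLM the rest was cut off.
  const response = await runFunction(reply.function_call);
  const encoded = encode(JSON.stringify(response ?? {}));
  const content =
    encoded.length <= MAX_FUNCTION_RESPONSE_TOKEN_COUNT
      ? response
      : {
          message: 'Function response exceeded the maximum length allowed and was truncated',
          truncated: decode(encoded.slice(0, MAX_FUNCTION_RESPONSE_TOKEN_COUNT)),
        };

  return next(messages.concat({ role: 'function', content: JSON.stringify(content) }));
}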
File 1 of 2: the ObservabilityAIAssistantClient tests
@@ -7,9 +7,11 @@
import type { ActionsClient } from '@kbn/actions-plugin/server/actions_client';
import type { ElasticsearchClient, Logger } from '@kbn/core/server';
import type { DeeplyMockedKeys } from '@kbn/utility-types-jest';
-import { merge } from 'lodash';
+import { waitFor } from '@testing-library/react';
+import { last, merge, repeat } from 'lodash';
import { ChatCompletionResponseMessage } from 'openai';
import { Subject } from 'rxjs';
-import { PassThrough, type Readable } from 'stream';
+import { EventEmitter, PassThrough, type Readable } from 'stream';
+import { finished } from 'stream/promises';
import { ObservabilityAIAssistantClient } from '.';
import { createResourceNamesMap } from '..';
@@ -70,7 +72,7 @@ function createLlmSimulator() {
};
}

-describe('Observability AI Assistant service', () => {
+describe('Observability AI Assistant client', () => {
let client: ObservabilityAIAssistantClient;

const actionsClientMock: DeeplyMockedKeys<ActionsClient> = {
@@ -84,14 +86,8 @@ describe('Observability AI Assistant service', () => {
} as any;

const currentUserEsClientMock: DeeplyMockedKeys<ElasticsearchClient> = {
-search: jest.fn().mockResolvedValue({
-hits: {
-hits: [],
-},
-}),
-fieldCaps: jest.fn().mockResolvedValue({
-fields: [],
-}),
+search: jest.fn(),
+fieldCaps: jest.fn(),
} as any;

const knowledgeBaseServiceMock: DeeplyMockedKeys<KnowledgeBaseService> = {
@@ -107,16 +103,29 @@

const functionClientMock: DeeplyMockedKeys<ChatFunctionClient> = {
executeFunction: jest.fn(),
-getFunctions: jest.fn().mockReturnValue([]),
-hasFunction: jest.fn().mockImplementation((name) => {
-return name !== 'recall';
-}),
+getFunctions: jest.fn(),
+hasFunction: jest.fn(),
} as any;

let llmSimulator: LlmSimulator;

function createClient() {
-jest.clearAllMocks();
+jest.resetAllMocks();

+functionClientMock.getFunctions.mockReturnValue([]);
+functionClientMock.hasFunction.mockImplementation((name) => {
+return name !== 'recall';
+});
+
+currentUserEsClientMock.search.mockResolvedValue({
+hits: {
+hits: [],
+},
+} as any);
+
+currentUserEsClientMock.fieldCaps.mockResolvedValue({
+fields: [],
+} as any);
+
return new ObservabilityAIAssistantClient({
actionsClient: actionsClientMock,
@@ -158,6 +167,10 @@ describe('Observability AI Assistant service', () => {
);
}

+beforeEach(() => {
+jest.clearAllMocks();
+});
+
describe('when completing a conversation without an initial conversation id', () => {
let stream: Readable;

Expand Down Expand Up @@ -1148,4 +1161,197 @@ describe('Observability AI Assistant service', () => {
});
});
});

+describe('when the LLM keeps on calling a function and the limit has been exceeded', () => {
+let stream: Readable;
+
+let dataHandler: jest.Mock;
+
+beforeEach(async () => {
+client = createClient();
+
+const onLlmCall = new EventEmitter();
+
+function waitForNextLlmCall() {
+return new Promise<void>((resolve) => onLlmCall.addListener('next', resolve));
+}
+
+actionsClientMock.execute.mockImplementation(async () => {
+llmSimulator = createLlmSimulator();
+onLlmCall.emit('next');
+return {
+actionId: '',
+status: 'ok',
+data: llmSimulator.stream,
+};
+});
+
+functionClientMock.getFunctions.mockImplementation(() => [
+{
+definition: {
+name: 'get_top_alerts',
+contexts: ['core'],
+description: '',
+parameters: {},
+},
+respond: async () => {
+return { content: 'Call this function again' };
+},
+},
+]);
+
+functionClientMock.hasFunction.mockImplementation((name) => name === 'get_top_alerts');
+functionClientMock.executeFunction.mockImplementation(async () => ({
+content: 'Call this function again',
+}));
+
+stream = await client.complete({
+connectorId: 'foo',
+messages: [system('This is a system message'), user('How many alerts do I have?')],
+functionClient: functionClientMock,
+signal: new AbortController().signal,
+title: 'My predefined title',
+persist: true,
+});
+
+dataHandler = jest.fn();
+
+stream.on('data', dataHandler);
+
+async function requestAlertsFunctionCall() {
+const body = JSON.parse(
+(actionsClientMock.execute.mock.lastCall![0].params as any).subActionParams.body
+);
+
+let nextLlmCallPromise: Promise<void>;
+
+if (body.functions?.length) {
+nextLlmCallPromise = waitForNextLlmCall();
+await llmSimulator.next({ function_call: { name: 'get_top_alerts' } });
+} else {
+nextLlmCallPromise = Promise.resolve();
+await llmSimulator.next({ content: 'Looks like we are done here' });
+}
+
+await llmSimulator.complete();
+
+await nextLlmCallPromise;
+}
+
+await requestAlertsFunctionCall();
+
+await requestAlertsFunctionCall();
+
+await requestAlertsFunctionCall();
+
+await requestAlertsFunctionCall();
+
+await finished(stream);
+});
+
+it('executed the function no more than three times', () => {
+expect(functionClientMock.executeFunction).toHaveBeenCalledTimes(3);
+});
+
+it('does not give the LLM the choice to call a function anymore', () => {
+const firstBody = JSON.parse(
+(actionsClientMock.execute.mock.calls[0][0].params as any).subActionParams.body
+);
+const body = JSON.parse(
+(actionsClientMock.execute.mock.lastCall![0].params as any).subActionParams.body
+);
+
+expect(firstBody.functions.length).toBe(1);
+
+expect(body.functions).toBeUndefined();
+});
+});
+
+describe('when the function response exceeds the max no of tokens for one', () => {
+let stream: Readable;
+
+let dataHandler: jest.Mock;
+
+beforeEach(async () => {
+client = createClient();
+
+let functionResponsePromiseResolve: Function | undefined;
+
+actionsClientMock.execute.mockImplementation(async () => {
+llmSimulator = createLlmSimulator();
+return {
+actionId: '',
+status: 'ok',
+data: llmSimulator.stream,
+};
+});
+
+functionClientMock.getFunctions.mockImplementation(() => [
+{
+definition: {
+name: 'get_top_alerts',
+contexts: ['core'],
+description: '',
+parameters: {},
+},
+respond: async () => {
+return { content: '' };
+},
+},
+]);
+
+functionClientMock.hasFunction.mockImplementation((name) => name === 'get_top_alerts');
+
+functionClientMock.executeFunction.mockImplementation(() => {
+return new Promise((resolve) => {
+functionResponsePromiseResolve = resolve;
+});
+});
+
+stream = await client.complete({
+connectorId: 'foo',
+messages: [system('This is a system message'), user('How many alerts do I have?')],
+functionClient: functionClientMock,
+signal: new AbortController().signal,
+title: 'My predefined title',
+persist: true,
+});
+
+dataHandler = jest.fn();
+
+stream.on('data', dataHandler);
+
+await llmSimulator.next({ function_call: { name: 'get_top_alerts' } });
+
+await llmSimulator.complete();
+
+await waitFor(() => functionResponsePromiseResolve !== undefined);
+
+functionResponsePromiseResolve!({
+content: repeat('word ', 10000),
+});
+
+await waitFor(() => actionsClientMock.execute.mock.calls.length > 1);
+
+await llmSimulator.next({ content: 'Looks like this was truncated' });
+
+await llmSimulator.complete();
+
+await finished(stream);
+});
+it('truncates the message', () => {
+const body = JSON.parse(
+(actionsClientMock.execute.mock.lastCall![0].params as any).subActionParams.body
+);
+
+const parsed = JSON.parse(last(body.messages as ChatCompletionResponseMessage[])!.content!);
+
+expect(parsed).toEqual({
+message: 'Function response exceeded the maximum length allowed and was truncated',
+truncated: expect.any(String),
+});
+
+expect(parsed.truncated.includes('word ')).toBe(true);
+});
+});
});
File 2 of 2: the ObservabilityAIAssistantClient implementation
@@ -10,12 +10,13 @@
import type { ActionsClient } from '@kbn/actions-plugin/server';
import type { ElasticsearchClient } from '@kbn/core/server';
import type { Logger } from '@kbn/logging';
import type { PublicMethodsOf } from '@kbn/utility-types';
-import { compact, isEmpty, last, merge, omit, pick } from 'lodash';
import type {
ChatCompletionRequestMessage,
CreateChatCompletionRequest,
CreateChatCompletionResponse,
} from 'openai';
+import { decode, encode } from 'gpt-tokenizer';
+import { compact, isEmpty, last, merge, omit, pick, take } from 'lodash';
import { isObservable, lastValueFrom } from 'rxjs';
import { PassThrough, Readable } from 'stream';
import { v4 } from 'uuid';
@@ -176,6 +177,11 @@ export class ObservabilityAIAssistantClient {
});
}

+let numFunctionsCalled: number = 0;
+
+const MAX_FUNCTION_CALLS = 3;
+const MAX_FUNCTION_RESPONSE_TOKEN_COUNT = 4000;
+
const next = async (nextMessages: Message[]): Promise<void> => {
const lastMessage = last(nextMessages);

@@ -222,32 +228,65 @@
connectorId,
stream: true,
signal,
-functions: functionClient
-.getFunctions()
-.map((fn) => pick(fn.definition, 'name', 'description', 'parameters')),
+functions:
+numFunctionsCalled >= MAX_FUNCTION_CALLS
+? []
+: functionClient
+.getFunctions()
+.map((fn) => pick(fn.definition, 'name', 'description', 'parameters')),
})
).pipe(processOpenAiStream()),
});
return await next(nextMessages.concat({ message, '@timestamp': new Date().toISOString() }));
}

if (isAssistantMessageWithFunctionRequest) {
-const functionResponse = await functionClient
-.executeFunction({
-connectorId,
-name: lastMessage.message.function_call!.name,
-messages: nextMessages,
-args: lastMessage.message.function_call!.arguments,
-signal,
-})
-.catch((error): FunctionResponse => {
-return {
-content: {
-message: error.toString(),
-error,
-},
-};
-});
+const functionResponse =
+numFunctionsCalled >= MAX_FUNCTION_CALLS
+? {
+content: {
+error: {},
+message: 'Function limit exceeded, ask the user what to do next',
+},
+}
+: await functionClient
+.executeFunction({
+connectorId,
+name: lastMessage.message.function_call!.name,
+messages: nextMessages,
+args: lastMessage.message.function_call!.arguments,
+signal,
+})
+.then((response) => {
+if (isObservable(response)) {
+return response;
+}
+
+const encoded = encode(JSON.stringify(response.content || {}));
+
+if (encoded.length <= MAX_FUNCTION_RESPONSE_TOKEN_COUNT) {
+return response;
+}
+
+return {
+data: response.data,
+content: {
+message:
+'Function response exceeded the maximum length allowed and was truncated',
+truncated: decode(take(encoded, MAX_FUNCTION_RESPONSE_TOKEN_COUNT)),
+},
+};
+})
+.catch((error): FunctionResponse => {
+return {
+content: {
+message: error.toString(),
+error,
+},
+};
+});
+
+numFunctionsCalled++;
+
if (signal.aborted) {
return;

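As a sanity check on the numbers in the truncation test above, the same token math can be reproduced standalone (assuming the gpt-tokenizer and lodash packages; exact token counts vary by encoding, but 10,000 repetitions of 'word ' comfortably exceed the 4,000-token cap):

import { encode, decode } from 'gpt-tokenizer';
import { repeat, take } from 'lodash';

const content = repeat('word ', 10000); // the oversized function response from the test
const encoded = encode(JSON.stringify({ content }));
console.log(encoded.length > 4000); // true, so the guardrail kicks in
const truncated = decode(take(encoded, 4000)); // keep only the first 4000 tokens
console.log(truncated.includes('word ')); // true: the decoded prefix still contains the payload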