Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Security Assistant] Fix langgraph issues #189287

Merged
merged 15 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,15 @@
"resolutions": {
"**/@bazel/typescript/protobufjs": "6.11.4",
"**/@hello-pangea/dnd": "16.6.0",
"**/@langchain/core": "^0.2.17",
"**/@langchain/core": "^0.2.18",
"**/@types/node": "20.10.5",
"**/@typescript-eslint/utils": "5.62.0",
"**/chokidar": "^3.5.3",
"**/d3-scale/**/d3-color": "npm:@elastic/kibana-d3-color@2.0.1",
"**/globule/minimatch": "^3.1.2",
"**/hoist-non-react-statics": "^3.3.2",
"**/isomorphic-fetch/node-fetch": "^2.6.7",
"**/langchain": "^0.2.10",
"**/langchain": "^0.2.11",
"**/react-intl/**/@types/react": "^17.0.45",
"**/remark-parse/trim": "1.0.1",
"**/sharp": "0.32.6",
Expand Down Expand Up @@ -949,9 +949,9 @@
"@kbn/zod": "link:packages/kbn-zod",
"@kbn/zod-helpers": "link:packages/kbn-zod-helpers",
"@langchain/community": "0.2.18",
"@langchain/core": "^0.2.17",
"@langchain/core": "^0.2.18",
"@langchain/google-genai": "^0.0.23",
"@langchain/langgraph": "^0.0.29",
"@langchain/langgraph": "^0.0.31",
"@langchain/openai": "^0.1.3",
"@langtrase/trace-attributes": "^3.0.8",
"@launchdarkly/node-server-sdk": "^9.4.7",
Expand Down Expand Up @@ -1093,8 +1093,8 @@
"jsonwebtoken": "^9.0.2",
"jsts": "^1.6.2",
"kea": "^2.6.0",
"langchain": "^0.2.10",
"langsmith": "^0.1.37",
"langchain": "^0.2.11",
"langsmith": "^0.1.39",
"launchdarkly-js-client-sdk": "^3.4.0",
"launchdarkly-node-server-sdk": "^7.0.3",
"load-json-file": "^6.2.0",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@ export type AssistantFeatureKey = keyof AssistantFeatures;
export const defaultAssistantFeatures = Object.freeze({
assistantKnowledgeBaseByDefault: false,
assistantModelEvaluation: false,
assistantBedrockChat: false,
assistantBedrockChat: true,
});
spong marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ export class ActionsClientBedrockChatModel extends _BedrockChat {
params: {
subAction: 'invokeAIRaw',
subActionParams: {
messages: inputBody.messages,
messages: prepareMessages(inputBody.messages),
temperature: params.temperature ?? inputBody.temperature,
stopSequences: inputBody.stop_sequences,
system: inputBody.system,
Expand Down Expand Up @@ -99,3 +99,20 @@ export class ActionsClientBedrockChatModel extends _BedrockChat {
});
}
}

/**
 * Collapses consecutive messages that share the same role into a single
 * message whose `content` arrays are concatenated.
 *
 * NOTE(review): presumably required because Bedrock rejects conversations
 * with back-to-back messages from the same role — confirm against the
 * connector's API contract.
 *
 * @param messages ordered conversation messages, each with a role tag and
 *                 an array of content parts
 * @returns a new array in which no two adjacent entries have the same role
 */
const prepareMessages = (messages: Array<{ role: string; content: string[] }>) =>
  messages.reduce((acc, { role, content }) => {
    const lastMessage = acc[acc.length - 1];

    if (!lastMessage || lastMessage.role !== role) {
      // Role changed (or this is the first message): start a new entry.
      // A fresh object is pushed so callers' message objects are not mutated.
      acc.push({ role, content });
    } else {
      // Same role as the previous entry: merge the content arrays.
      // (The original code re-checked `lastMessage.role === role` here and
      // had an unreachable trailing `return acc`; at this point the roles
      // are guaranteed equal, so a plain else suffices.)
      acc[acc.length - 1].content = lastMessage.content.concat(content);
    }

    return acc;
  }, [] as Array<{ role: string; content: string[] }>);
spong marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI {
}

async completionWithRetry(
request: string | GenerateContentRequest | Array<string | Part>,
request: GenerateContentRequest,
options?: this['ParsedCallOptions']
): Promise<GenerateContentResult> {
return this.caller.callWithOptions({ signal: options?.signal }, async () => {
Expand All @@ -80,7 +80,8 @@ export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI {
subAction: 'invokeAIRaw',
subActionParams: {
model: this.#model,
messages: request,
messages: request.contents,
tools: request.tools,
temperature: this.#temperature,
},
},
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ export const getDefaultAssistantGraph = ({
default: () => undefined,
},
messages: {
value: (x: BaseMessage[], y: BaseMessage[]) => x.concat(y),
value: (x: BaseMessage[], y: BaseMessage[]) => y ?? x,
default: () => [],
},
chatTitle: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ export const streamGraph = async ({
runName: DEFAULT_ASSISTANT_GRAPH_ID,
tags: traceOptions?.tags ?? [],
version: 'v2',
streamMode: 'values',
},
llmType === 'bedrock' ? { includeNames: ['Summarizer'] } : undefined
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
streamRunnable: isStream,
})
: llmType && ['bedrock', 'gemini'].includes(llmType) && bedrockChatEnabled
? createToolCallingAgent({
? await createToolCallingAgent({
spong marked this conversation as resolved.
Show resolved Hide resolved
llm,
tools,
prompt:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import { RunnableConfig } from '@langchain/core/runnables';
import { StructuredTool } from '@langchain/core/tools';
import { ToolExecutor } from '@langchain/langgraph/prebuilt';
import { isArray } from 'lodash';
import { castArray } from 'lodash';
import { AgentAction } from 'langchain/agents';
import { AgentState, NodeParamsBase } from '../types';

export interface ExecuteToolsParams extends NodeParamsBase {
Expand All @@ -34,23 +35,30 @@ export const executeTools = async ({ config, logger, state, tools }: ExecuteTool
logger.debug(() => `Node state:\n${JSON.stringify(state, null, 2)}`);

const toolExecutor = new ToolExecutor({ tools });
const agentAction = isArray(state.agentOutcome) ? state.agentOutcome[0] : state.agentOutcome;
const agentAction = state.agentOutcome;

if (!agentAction || 'returnValues' in agentAction) {
throw new Error('Agent has not been run yet');
}

let out;
try {
out = await toolExecutor.invoke(agentAction, config);
} catch (err) {
return {
steps: [{ action: agentAction, observation: JSON.stringify(`Error: ${err}`, null, 2) }],
};
}
const steps = await Promise.all(
castArray(state.agentOutcome as AgentAction)?.map(async (action) => {
Comment on lines +44 to +45
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: This fixes the multiple tool call code path which was previously not working. So now if the LLM chooses multiple tools at once (only Bedrock right now, OpenAI supports it as well, but not via langchain chatmodel IIRC), both will execute.

let out;
try {
out = await toolExecutor.invoke(action, config);
} catch (err) {
return {
action,
observation: JSON.stringify(`Error: ${err}`, null, 2),
};
}

return {
action,
observation: JSON.stringify(out, null, 2),
};
})
);

return {
...state,
steps: [{ action: agentAction, observation: JSON.stringify(out, null, 2) }],
};
return { steps };
};
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ export const generateChatTitle = async ({
logger.debug(`chatTitle: ${chatTitle}`);

return {
...state,
chatTitle,
};
};
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ export const getPersistedConversation = async ({
if (!conversationId) {
logger.debug('Cannot get conversation, because conversationId is undefined');
return {
...state,
conversation: undefined,
messages: [],
chatTitle: '',
Expand All @@ -39,7 +38,6 @@ export const getPersistedConversation = async ({
if (!conversation) {
logger.debug('Requested conversation, because conversation is undefined');
return {
...state,
conversation: undefined,
messages: [],
chatTitle: '',
Expand All @@ -50,11 +48,21 @@ export const getPersistedConversation = async ({
logger.debug(`conversationId: ${conversationId}`);

const messages = getLangChainMessages(conversation.messages ?? []);

if (!state.input) {
const lastMessage = messages?.splice(-1)[0];
return {
conversation,
messages,
chatTitle: conversation.title,
input: lastMessage?.content as string,
};
}

return {
...state,
conversation,
messages,
chatTitle: conversation.title,
input: !state.input ? conversation.messages?.slice(-1)[0].content : state.input,
input: state.input,
};
};
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ export const persistConversationChanges = async ({
const langChainMessages = getLangChainMessages(state.conversation.messages ?? []);
const messages = langChainMessages.slice(0, -1); // all but the last message
return {
...state,
conversation: state.conversation,
messages,
};
Expand All @@ -78,15 +77,14 @@ export const persistConversationChanges = async ({
});
if (!updatedConversation) {
logger.debug('Not updated conversation');
return { ...state, conversation: undefined, messages: [] };
return { conversation: undefined, messages: [] };
}

logger.debug(`conversationId: ${conversationId}`);
const langChainMessages = getLangChainMessages(updatedConversation.messages ?? []);
const messages = langChainMessages.slice(0, -1); // all but the last message

return {
...state,
conversation: updatedConversation,
messages,
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,12 @@ export const runAgent = async ({
const agentOutcome = await agentRunnable.withConfig({ tags: [AGENT_NODE_TAG] }).invoke(
{
...state,
messages: state.messages.splice(-1),
chat_history: state.messages, // TODO: Message de-dupe with ...state spread
},
config
);

return {
...state,
agentOutcome,
};
};
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ export const shouldContinue = ({ logger, state }: ShouldContinueParams) => {
export const shouldContinueGenerateTitle = ({ logger, state }: ShouldContinueParams) => {
logger.debug(`Node state:\n${JSON.stringify(state, null, 2)}`);

if (state.conversation?.title !== NEW_CHAT) {
if (state.conversation?.title?.length && state.conversation?.title !== NEW_CHAT) {
return 'end';
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ export const openAIFunctionAgentPrompt = ChatPromptTemplate.fromMessages([
export const bedrockToolCallingAgentPrompt = ChatPromptTemplate.fromMessages([
[
'system',
'You are a helpful assistant. ALWAYS use the provided tools. Use tools as often as possible, as they have access to the latest data and syntax.',
'You are a helpful assistant. ALWAYS use the provided tools. Use tools as often as possible, as they have access to the latest data and syntax. Always return value from ESQLKnowledgeBaseTool as is. Never return <thinking> tags in the response, but make sure to include <result> tags content in the response. Do not reflect on the quality of the returned search results in your response.',
],
['placeholder', '{chat_history}'],
['human', '{input}'],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ export const allowedExperimentalValues = Object.freeze({
/**
* Enables the Assistant BedrockChat Langchain model, introduced in `8.15.0`.
*/
assistantBedrockChat: false,
assistantBedrockChat: true,

/**
* Enables the Managed User section inside the new user details flyout.
Expand Down
1 change: 1 addition & 0 deletions x-pack/plugins/stack_connectors/common/gemini/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ export const InvokeAIRawActionParamsSchema = schema.object({
stopSequences: schema.maybe(schema.arrayOf(schema.string())),
signal: schema.maybe(schema.any()),
timeout: schema.maybe(schema.number()),
tools: schema.maybe(schema.arrayOf(schema.any())),
});

export const InvokeAIActionResponseSchema = schema.object({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,10 @@ export class GeminiConnector extends SubActionConnector<Config, Secrets> {
temperature = 0,
signal,
timeout,
tools,
}: InvokeAIRawActionParams): Promise<InvokeAIRawActionResponse> {
const res = await this.runApi({
body: JSON.stringify(formatGeminiPayload(messages, temperature)),
body: JSON.stringify({ ...formatGeminiPayload(messages, temperature), tools }),
model,
signal,
timeout,
Expand Down
Loading