From 32b460cbd5db3ecf9b399d58c60375391efa8bf0 Mon Sep 17 00:00:00 2001 From: bsorrentino Date: Sat, 16 Mar 2024 16:31:42 +0100 Subject: [PATCH] feat: complete refactory on LangChain AgentExecutor in LangGraph --- .../LangChainDemo/AgentExecutor.swift | 240 ++++++++++-------- 1 file changed, 129 insertions(+), 111 deletions(-) diff --git a/LangChainDemo/LangChainDemo/AgentExecutor.swift b/LangChainDemo/LangChainDemo/AgentExecutor.swift index b913080..9104152 100644 --- a/LangChainDemo/LangChainDemo/AgentExecutor.swift +++ b/LangChainDemo/LangChainDemo/AgentExecutor.swift @@ -9,31 +9,50 @@ import Foundation import LangChain import LangGraph +enum AgentOutcome /* Union */ { + case action(AgentAction) + case finish(AgentFinish) +} + + struct AgentExecutorState : AgentState { var data: [String : Any] init() { - data = ["intermediate_steps": AppendableValue()] + self.init([ + "intermediate_steps": AppendableValue(), + "chat_history": AppendableValue() + ]) } init(_ initState: [String : Any]) { data = initState } - - var start:Double? { - value("start") - } + + // from langchain var input:String? { value("input") } + + var chatHistory:[BaseMessage]? { + appendableValue("chat_history" ) + } - var output:(LLMResult, Parsed)? { - value("output") + var agentOutcome:AgentOutcome? { + return value("agent_outcome") } - + var intermediate_steps: [(AgentAction, String)]? { appendableValue("intermediate_steps" ) } + + // Tracing + var start:Double? { + value("start") + } + var cost:Double? { + value("cost") + } } struct ToolOutputParser: BaseOutputParser { @@ -63,161 +82,160 @@ struct ToolOutputParser: BaseOutputParser { public func runAgent( input: String, llm: LLM, tools: [BaseTool], callbacks: [BaseCallbackHandler] = []) async throws -> Void { - let output_parser = ToolOutputParser() - let llm_chain = LLMChain(llm: llm, - prompt: ZeroShotAgent.create_prompt(tools: tools), - parser: output_parser, - stop: ["\nObservation: ", "\n\tObservation: "]) - let agent = ZeroShotAgent(llm_chain: llm_chain) let AGENT_REQ_ID = "agent_req_id" - let chain_reqId = UUID().uuidString + let agent_reqId = UUID().uuidString + + let agent = { + let output_parser = ToolOutputParser() + let llm_chain = LLMChain(llm: llm, + prompt: ZeroShotAgent.create_prompt(tools: tools), + parser: output_parser, + stop: ["\nObservation: ", "\n\tObservation: "]) + return ZeroShotAgent(llm_chain: llm_chain) - func take_next_step( input: String, intermediate_steps: [(AgentAction, String)]) async -> (Parsed, String) { - let step = await agent.plan(input: input, intermediate_steps: intermediate_steps) - switch step { - case .finish(let finish): - return (step, finish.final) - case .action(let action): - let tool = tools.filter{$0.name() == action.action}.first! - do { - print("try call \(tool.name()) tool.") - var observation = try await tool.run(args: action.input) - if observation.count > 1000 { - observation = String(observation.prefix(1000)) - } - return (step, observation) - } catch { - print("\(error.localizedDescription) at run \(tool.name()) tool.") - let observation = try! await InvalidTool(tool_name: tool.name()).run(args: action.input) - return (step, observation) + }() + + let toolExecutor = { (action: AgentAction) in + guard let tool = tools.filter({$0.name() == action.action}).first else { + throw GraphRunnerError.executionError("tool \(action.action) not found!") + } + + do { + + print("try call \(tool.name()) tool.") + var observation = try await tool.run(args: action.input) + if observation.count > 1000 { + observation = String(observation.prefix(1000)) } - default: - return (step, "fail") + return observation + } catch { + print("\(error.localizedDescription) at run \(tool.name()) tool.") + let observation = try! await InvalidTool(tool_name: tool.name()).run(args: action.input) + return observation } + } - - - func callEnd(output: String, reqId: String, cost: Double) { - for callback in callbacks { - do { - try callback.on_chain_end(output: output, metadata: [DefaultChain.CHAIN_REQ_ID_KEY: reqId, DefaultChain.CHAIN_COST_KEY: "\(cost)"]) - } catch { - print("call chain end callback errer: \(error)") + + + let onAgentStart = { (input: String ) in + do { + for callback in callbacks { + try callback.on_agent_start(prompt: input, metadata: [AGENT_REQ_ID: agent_reqId]) } + } catch { + print( "call on_agent_start callback error: \(error)") } } - - func callStart(prompt: String, reqId: String) { - for callback in callbacks { - do { - try callback.on_chain_start(prompts: prompt, metadata: [DefaultChain.CHAIN_REQ_ID_KEY: reqId]) - } catch { - print("call chain end callback errer: \(error)") + + let onAgentAction = { (action: AgentAction ) in + do { + for callback in callbacks { + try callback.on_agent_action(action: action, metadata: [AGENT_REQ_ID: agent_reqId]) } + } catch { + print( "call on_agent_action callback error: \(error)") } + } - func callCatch(error: Error, reqId: String, cost: Double) { - for callback in callbacks { - do { - try callback.on_chain_error(error: error, metadata: [DefaultChain.CHAIN_REQ_ID_KEY: reqId, DefaultChain.CHAIN_COST_KEY: "\(cost)"]) - } catch { - print("call LLM start callback errer: \(error)") + let onAgentFinish = { (action: AgentFinish ) in + do { + for callback in callbacks { + try callback.on_agent_finish(action: action, metadata: [AGENT_REQ_ID: agent_reqId]) } + } catch { + print( "call on_agent_finish callback error: \(error)") } } - let workflow = GraphState( stateType: AgentExecutorState.self ) + + let workflow = GraphState { + AgentExecutorState() + } try workflow.addNode( "call_start" ) { state in - guard let prompt = state.input else { - throw GraphRunnerError.executionError("'inputs' argument not found!") - } - - callStart(prompt: prompt, reqId: chain_reqId) - - return ["start": Date.now.timeIntervalSince1970] + ["start": Date.now.timeIntervalSince1970] } try workflow.addNode( "call_end" ) { state in - guard let output = state.output else { - throw GraphRunnerError.executionError("'output' argument not found!") - } - guard let start = state.start else { - throw GraphRunnerError.executionError("'start' argument not found!") - } - - let cost = Date.now.timeIntervalSince1970 - start + var cost:Double = 0 + if let start = state.start { + cost = Date.now.timeIntervalSince1970 - start + + } - callEnd(output: output.0.llm_output ?? "", reqId: chain_reqId, cost: cost) - - return [:] + return [ "cost": cost ] } try workflow.addNode("call_agent" ) { state in guard let input = state.input else { - throw GraphRunnerError.executionError("'inputs' argument not found in state!") + throw GraphRunnerError.executionError("'input' argument not found in state!") } guard let intermediate_steps = state.intermediate_steps else { throw GraphRunnerError.executionError("'intermediate_steps' property not found in state!") } - let agent_reqId = UUID().uuidString - do { - for callback in callbacks { - try callback.on_agent_start(prompt: input, metadata: [AGENT_REQ_ID: agent_reqId]) - } - } catch { - print( "call agent start callback error: \(error)") + onAgentStart( input ) + let step = await agent.plan(input: input, intermediate_steps: intermediate_steps) + switch( step ) { + case .finish( let finish ): + onAgentFinish( finish ) + return [ "agent_outcome": AgentOutcome.finish(finish) ] + case .action( let action ): + onAgentAction( action ) + return [ "agent_outcome": AgentOutcome.action(action) ] + default: + throw GraphRunnerError.executionError( "Parsed.error" ) } + } - let result = await take_next_step(input: input, intermediate_steps: intermediate_steps) + try workflow.addNode("call_action" ) { state in - switch result.0 { - case .finish(let finish): - print("Found final answer.") - do { - for callback in callbacks { - try callback.on_agent_finish(action: finish, metadata: [AGENT_REQ_ID: agent_reqId]) - } - } catch { - print( "call chain end callback error: \(error)") - } - return [ "output": (LLMResult(llm_output: result.1), Parsed.str(result.1)) ] - case .action(let action): - do { - for callback in callbacks { - try callback.on_agent_action(action: action, metadata: [AGENT_REQ_ID: agent_reqId]) - } - } catch { - print( "call chain end callback error: \(error)") - } - return [ "intermediate_steps" : (action, result.1) ] - default: - throw GraphRunnerError.executionError( "Parsed.error" ) + guard let agentOutcome = state.agentOutcome else { + throw GraphRunnerError.executionError("'agent_outcome' property not found in state!") + } + + guard case .action(let action) = agentOutcome else { + throw GraphRunnerError.executionError("'agent_outcome' is not an action!") } + + let result = try await toolExecutor( action ) + + return [ "intermediate_steps" : (action, result) ] } - + try workflow.setEntryPoint("call_start") workflow.setFinishPoint("call_end") try workflow.addEdge(sourceId: "call_start", targetId: "call_agent") + try workflow.addEdge(sourceId: "call_action", targetId: "call_agent") try workflow.addConditionalEdge( sourceId: "call_agent", condition: { state in - return "terminate" + + guard let agentOutcome = state.agentOutcome else { + throw GraphRunnerError.executionError("'agent_outcome' property not found in state!") + } + + switch agentOutcome { + case .finish: + return "finish" + case .action: + return "continue" + } + }, edgeMapping: [ - "continue" : "call_agent", - "terminate": "call_end"]) + "continue" : "call_action", + "finish": "call_end"]) let runner = try workflow.compile() - let result = try await runner.invoke(inputs: [ "input": input ]) + let result = try await runner.invoke(inputs: [ "input": input, "chat_history": [] ]) print( result ) }