Skip to content

Commit

Permalink
feat: complete refactory on LangChain AgentExecutor in LangGraph
Browse files Browse the repository at this point in the history
  • Loading branch information
bsorrentino committed Mar 16, 2024
1 parent 17b16a5 commit 32b460c
Showing 1 changed file with 129 additions and 111 deletions.
240 changes: 129 additions & 111 deletions LangChainDemo/LangChainDemo/AgentExecutor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,50 @@ import Foundation
import LangChain
import LangGraph

enum AgentOutcome /* Union */ {
case action(AgentAction)
case finish(AgentFinish)
}


struct AgentExecutorState : AgentState {
var data: [String : Any]

init() {
data = ["intermediate_steps": AppendableValue()]
self.init([
"intermediate_steps": AppendableValue(),
"chat_history": AppendableValue()
])
}

init(_ initState: [String : Any]) {
data = initState
}

var start:Double? {
value("start")
}

// from langchain
var input:String? {
value("input")
}

var chatHistory:[BaseMessage]? {
appendableValue("chat_history" )
}

var output:(LLMResult, Parsed)? {
value("output")
var agentOutcome:AgentOutcome? {
return value("agent_outcome")
}

var intermediate_steps: [(AgentAction, String)]? {
appendableValue("intermediate_steps" )
}

// Tracing
var start:Double? {
value("start")
}
var cost:Double? {
value("cost")
}
}

struct ToolOutputParser: BaseOutputParser {
Expand Down Expand Up @@ -63,161 +82,160 @@ struct ToolOutputParser: BaseOutputParser {

public func runAgent( input: String, llm: LLM, tools: [BaseTool], callbacks: [BaseCallbackHandler] = []) async throws -> Void {

let output_parser = ToolOutputParser()
let llm_chain = LLMChain(llm: llm,
prompt: ZeroShotAgent.create_prompt(tools: tools),
parser: output_parser,
stop: ["\nObservation: ", "\n\tObservation: "])
let agent = ZeroShotAgent(llm_chain: llm_chain)

let AGENT_REQ_ID = "agent_req_id"

let chain_reqId = UUID().uuidString
let agent_reqId = UUID().uuidString

let agent = {
let output_parser = ToolOutputParser()
let llm_chain = LLMChain(llm: llm,
prompt: ZeroShotAgent.create_prompt(tools: tools),
parser: output_parser,
stop: ["\nObservation: ", "\n\tObservation: "])
return ZeroShotAgent(llm_chain: llm_chain)

func take_next_step( input: String, intermediate_steps: [(AgentAction, String)]) async -> (Parsed, String) {
let step = await agent.plan(input: input, intermediate_steps: intermediate_steps)
switch step {
case .finish(let finish):
return (step, finish.final)
case .action(let action):
let tool = tools.filter{$0.name() == action.action}.first!
do {
print("try call \(tool.name()) tool.")
var observation = try await tool.run(args: action.input)
if observation.count > 1000 {
observation = String(observation.prefix(1000))
}
return (step, observation)
} catch {
print("\(error.localizedDescription) at run \(tool.name()) tool.")
let observation = try! await InvalidTool(tool_name: tool.name()).run(args: action.input)
return (step, observation)
}()

let toolExecutor = { (action: AgentAction) in
guard let tool = tools.filter({$0.name() == action.action}).first else {
throw GraphRunnerError.executionError("tool \(action.action) not found!")
}

do {

print("try call \(tool.name()) tool.")
var observation = try await tool.run(args: action.input)
if observation.count > 1000 {
observation = String(observation.prefix(1000))
}
default:
return (step, "fail")
return observation
} catch {
print("\(error.localizedDescription) at run \(tool.name()) tool.")
let observation = try! await InvalidTool(tool_name: tool.name()).run(args: action.input)
return observation
}

}


func callEnd(output: String, reqId: String, cost: Double) {
for callback in callbacks {
do {
try callback.on_chain_end(output: output, metadata: [DefaultChain.CHAIN_REQ_ID_KEY: reqId, DefaultChain.CHAIN_COST_KEY: "\(cost)"])
} catch {
print("call chain end callback errer: \(error)")


let onAgentStart = { (input: String ) in
do {
for callback in callbacks {
try callback.on_agent_start(prompt: input, metadata: [AGENT_REQ_ID: agent_reqId])
}
} catch {
print( "call on_agent_start callback error: \(error)")
}
}

func callStart(prompt: String, reqId: String) {
for callback in callbacks {
do {
try callback.on_chain_start(prompts: prompt, metadata: [DefaultChain.CHAIN_REQ_ID_KEY: reqId])
} catch {
print("call chain end callback errer: \(error)")

let onAgentAction = { (action: AgentAction ) in
do {
for callback in callbacks {
try callback.on_agent_action(action: action, metadata: [AGENT_REQ_ID: agent_reqId])
}
} catch {
print( "call on_agent_action callback error: \(error)")
}

}

func callCatch(error: Error, reqId: String, cost: Double) {
for callback in callbacks {
do {
try callback.on_chain_error(error: error, metadata: [DefaultChain.CHAIN_REQ_ID_KEY: reqId, DefaultChain.CHAIN_COST_KEY: "\(cost)"])
} catch {
print("call LLM start callback errer: \(error)")
let onAgentFinish = { (action: AgentFinish ) in
do {
for callback in callbacks {
try callback.on_agent_finish(action: action, metadata: [AGENT_REQ_ID: agent_reqId])
}
} catch {
print( "call on_agent_finish callback error: \(error)")
}
}

let workflow = GraphState( stateType: AgentExecutorState.self )

let workflow = GraphState {
AgentExecutorState()
}

try workflow.addNode( "call_start" ) { state in

guard let prompt = state.input else {
throw GraphRunnerError.executionError("'inputs' argument not found!")
}

callStart(prompt: prompt, reqId: chain_reqId)

return ["start": Date.now.timeIntervalSince1970]
["start": Date.now.timeIntervalSince1970]
}

try workflow.addNode( "call_end" ) { state in

guard let output = state.output else {
throw GraphRunnerError.executionError("'output' argument not found!")
}
guard let start = state.start else {
throw GraphRunnerError.executionError("'start' argument not found!")
}

let cost = Date.now.timeIntervalSince1970 - start
var cost:Double = 0

if let start = state.start {
cost = Date.now.timeIntervalSince1970 - start

}

callEnd(output: output.0.llm_output ?? "", reqId: chain_reqId, cost: cost)

return [:]
return [ "cost": cost ]
}

try workflow.addNode("call_agent" ) { state in

guard let input = state.input else {
throw GraphRunnerError.executionError("'inputs' argument not found in state!")
throw GraphRunnerError.executionError("'input' argument not found in state!")
}
guard let intermediate_steps = state.intermediate_steps else {
throw GraphRunnerError.executionError("'intermediate_steps' property not found in state!")
}

let agent_reqId = UUID().uuidString
do {
for callback in callbacks {
try callback.on_agent_start(prompt: input, metadata: [AGENT_REQ_ID: agent_reqId])
}
} catch {
print( "call agent start callback error: \(error)")
onAgentStart( input )
let step = await agent.plan(input: input, intermediate_steps: intermediate_steps)
switch( step ) {
case .finish( let finish ):
onAgentFinish( finish )
return [ "agent_outcome": AgentOutcome.finish(finish) ]
case .action( let action ):
onAgentAction( action )
return [ "agent_outcome": AgentOutcome.action(action) ]
default:
throw GraphRunnerError.executionError( "Parsed.error" )
}
}

let result = await take_next_step(input: input, intermediate_steps: intermediate_steps)
try workflow.addNode("call_action" ) { state in

switch result.0 {
case .finish(let finish):
print("Found final answer.")
do {
for callback in callbacks {
try callback.on_agent_finish(action: finish, metadata: [AGENT_REQ_ID: agent_reqId])
}
} catch {
print( "call chain end callback error: \(error)")
}
return [ "output": (LLMResult(llm_output: result.1), Parsed.str(result.1)) ]
case .action(let action):
do {
for callback in callbacks {
try callback.on_agent_action(action: action, metadata: [AGENT_REQ_ID: agent_reqId])
}
} catch {
print( "call chain end callback error: \(error)")
}
return [ "intermediate_steps" : (action, result.1) ]
default:
throw GraphRunnerError.executionError( "Parsed.error" )
guard let agentOutcome = state.agentOutcome else {
throw GraphRunnerError.executionError("'agent_outcome' property not found in state!")
}

guard case .action(let action) = agentOutcome else {
throw GraphRunnerError.executionError("'agent_outcome' is not an action!")
}

let result = try await toolExecutor( action )

return [ "intermediate_steps" : (action, result) ]
}

try workflow.setEntryPoint("call_start")
workflow.setFinishPoint("call_end")

try workflow.addEdge(sourceId: "call_start", targetId: "call_agent")
try workflow.addEdge(sourceId: "call_action", targetId: "call_agent")
try workflow.addConditionalEdge( sourceId: "call_agent", condition: { state in
return "terminate"

guard let agentOutcome = state.agentOutcome else {
throw GraphRunnerError.executionError("'agent_outcome' property not found in state!")
}

switch agentOutcome {
case .finish:
return "finish"
case .action:
return "continue"
}

}, edgeMapping: [
"continue" : "call_agent",
"terminate": "call_end"])
"continue" : "call_action",
"finish": "call_end"])


let runner = try workflow.compile()

let result = try await runner.invoke(inputs: [ "input": input ])
let result = try await runner.invoke(inputs: [ "input": input, "chat_history": [] ])

print( result )
}

0 comments on commit 32b460c

Please sign in to comment.