Skip to content

Commit

Permalink
refactor: start porting LangChain agent to LangGraph
Browse files Browse the repository at this point in the history
  • Loading branch information
bsorrentino committed Mar 15, 2024
1 parent 4cb667e commit d829826
Show file tree
Hide file tree
Showing 4 changed files with 307 additions and 34 deletions.
37 changes: 18 additions & 19 deletions LangChainDemo/LangChainDemo.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,21 @@
archiveVersion = 1;
classes = {
};
objectVersion = 56;
objectVersion = 60;
objects = {

/* Begin PBXBuildFile section */
A04EAC2B2BA441B500C49DC0 /* AgentExecutor.swift in Sources */ = {isa = PBXBuildFile; fileRef = A04EAC2A2BA441B500C49DC0 /* AgentExecutor.swift */; };
A04EAC2E2BA4572E00C49DC0 /* LangChain in Frameworks */ = {isa = PBXBuildFile; productRef = A04EAC2D2BA4572E00C49DC0 /* LangChain */; };
A08CC7552BA373E9007A8248 /* LangChainDemoApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = A08CC7542BA373E9007A8248 /* LangChainDemoApp.swift */; };
A08CC7572BA373E9007A8248 /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = A08CC7562BA373E9007A8248 /* ContentView.swift */; };
A08CC7592BA373EA007A8248 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = A08CC7582BA373EA007A8248 /* Assets.xcassets */; };
A08CC75C2BA373EA007A8248 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = A08CC75B2BA373EA007A8248 /* Preview Assets.xcassets */; };
A08CC7652BA37494007A8248 /* LangGraph in Frameworks */ = {isa = PBXBuildFile; productRef = A08CC7642BA37494007A8248 /* LangGraph */; };
A08CC7682BA3772E007A8248 /* LangChain in Frameworks */ = {isa = PBXBuildFile; productRef = A08CC7672BA3772E007A8248 /* LangChain */; };
/* End PBXBuildFile section */

/* Begin PBXFileReference section */
A04EAC2A2BA441B500C49DC0 /* AgentExecutor.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = AgentExecutor.swift; sourceTree = "<group>"; };
A08CC7512BA373E9007A8248 /* LangChainDemo.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = LangChainDemo.app; sourceTree = BUILT_PRODUCTS_DIR; };
A08CC7542BA373E9007A8248 /* LangChainDemoApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LangChainDemoApp.swift; sourceTree = "<group>"; };
A08CC7562BA373E9007A8248 /* ContentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = "<group>"; };
Expand All @@ -29,7 +31,7 @@
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
A08CC7682BA3772E007A8248 /* LangChain in Frameworks */,
A04EAC2E2BA4572E00C49DC0 /* LangChain in Frameworks */,
A08CC7652BA37494007A8248 /* LangGraph in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
Expand Down Expand Up @@ -57,6 +59,7 @@
A08CC7532BA373E9007A8248 /* LangChainDemo */ = {
isa = PBXGroup;
children = (
A04EAC2A2BA441B500C49DC0 /* AgentExecutor.swift */,
A08CC7542BA373E9007A8248 /* LangChainDemoApp.swift */,
A08CC7562BA373E9007A8248 /* ContentView.swift */,
A08CC7582BA373EA007A8248 /* Assets.xcassets */,
Expand Down Expand Up @@ -99,7 +102,7 @@
name = LangChainDemo;
packageProductDependencies = (
A08CC7642BA37494007A8248 /* LangGraph */,
A08CC7672BA3772E007A8248 /* LangChain */,
A04EAC2D2BA4572E00C49DC0 /* LangChain */,
);
productName = LangChainDemo;
productReference = A08CC7512BA373E9007A8248 /* LangChainDemo.app */;
Expand Down Expand Up @@ -130,7 +133,7 @@
);
mainGroup = A08CC7482BA373E9007A8248;
packageReferences = (
A08CC7662BA3772E007A8248 /* XCRemoteSwiftPackageReference "langchain-swift" */,
A04EAC2C2BA4572E00C49DC0 /* XCLocalSwiftPackageReference "../../langchain-swift" */,
);
productRefGroup = A08CC7522BA373E9007A8248 /* Products */;
projectDirPath = "";
Expand Down Expand Up @@ -160,6 +163,7 @@
files = (
A08CC7572BA373E9007A8248 /* ContentView.swift in Sources */,
A08CC7552BA373E9007A8248 /* LangChainDemoApp.swift in Sources */,
A04EAC2B2BA441B500C49DC0 /* AgentExecutor.swift in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
Expand Down Expand Up @@ -366,26 +370,21 @@
};
/* End XCConfigurationList section */

/* Begin XCRemoteSwiftPackageReference section */
A08CC7662BA3772E007A8248 /* XCRemoteSwiftPackageReference "langchain-swift" */ = {
isa = XCRemoteSwiftPackageReference;
repositoryURL = "https://github.com/buhe/langchain-swift.git";
requirement = {
kind = upToNextMajorVersion;
minimumVersion = 0.48.0;
};
/* Begin XCLocalSwiftPackageReference section */
A04EAC2C2BA4572E00C49DC0 /* XCLocalSwiftPackageReference "../../langchain-swift" */ = {
isa = XCLocalSwiftPackageReference;
relativePath = "../../langchain-swift";
};
/* End XCRemoteSwiftPackageReference section */
/* End XCLocalSwiftPackageReference section */

/* Begin XCSwiftPackageProductDependency section */
A08CC7642BA37494007A8248 /* LangGraph */ = {
A04EAC2D2BA4572E00C49DC0 /* LangChain */ = {
isa = XCSwiftPackageProductDependency;
productName = LangGraph;
productName = LangChain;
};
A08CC7672BA3772E007A8248 /* LangChain */ = {
A08CC7642BA37494007A8248 /* LangGraph */ = {
isa = XCSwiftPackageProductDependency;
package = A08CC7662BA3772E007A8248 /* XCRemoteSwiftPackageReference "langchain-swift" */;
productName = LangChain;
productName = LangGraph;
};
/* End XCSwiftPackageProductDependency section */
};
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"originHash" : "9f967c268a37c7ace07b0924d8a430e8dca22c0e207c303bb411a059e4ad9863",
"originHash" : "b47df776245fa51fa189bff68b7fac0f33c7b9ada1ab0403cc1479ad43538ee3",
"pins" : [
{
"identity" : "async-http-client",
Expand Down Expand Up @@ -73,15 +73,6 @@
"version" : "4.2.2"
}
},
{
"identity" : "langchain-swift",
"kind" : "remoteSourceControl",
"location" : "https://github.com/buhe/langchain-swift.git",
"state" : {
"revision" : "d2c945cfbb0b1b9bc571dc5121fbaa264bb42bd6",
"version" : "0.48.0"
}
},
{
"identity" : "openai-kit",
"kind" : "remoteSourceControl",
Expand Down
223 changes: 223 additions & 0 deletions LangChainDemo/LangChainDemo/AgentExecutor.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
//
// AgentExecutor.swift
// LangChainDemo
//
// Created by bsorrentino on 14/03/24.
//

import Foundation
import LangChain
import LangGraph

struct AgentExecutorState : AgentState {
var data: [String : Any]

init() {
data = ["intermediate_steps": AppendableValue()]
}

init(_ initState: [String : Any]) {
data = initState
}

var start:Double? {
value("start")
}
var input:String? {
value("input")
}

var output:(LLMResult, Parsed)? {
value("output")
}

var intermediate_steps: [(AgentAction, String)]? {
appendableValue("intermediate_steps" )
}
}

struct ToolOutputParser: BaseOutputParser {
public init() {}
public func parse(text: String) -> Parsed {
print(text.uppercased())
let pattern = "Action\\s*:[\\s]*(.*)[\\s]*Action\\s*Input\\s*:[\\s]*(.*)"
let regex = try! NSRegularExpression(pattern: pattern)

if let match = regex.firstMatch(in: text, options: [], range: NSRange(location: 0, length: text.utf16.count)) {

let firstCaptureGroup = Range(match.range(at: 1), in: text).map { String(text[$0]) }
// print(firstCaptureGroup!)


let secondCaptureGroup = Range(match.range(at: 2), in: text).map { String(text[$0]) }
// print(secondCaptureGroup!)
return Parsed.action(AgentAction(action: firstCaptureGroup!, input: secondCaptureGroup!, log: text))
} else {
if text.uppercased().contains(FINAL_ANSWER_ACTION) {
return Parsed.finish(AgentFinish(final: text))
}
return Parsed.error
}
}
}

public func runAgent( input: String, llm: LLM, tools: [BaseTool], callbacks: [BaseCallbackHandler] = []) async throws -> Void {

let output_parser = ToolOutputParser()
let llm_chain = LLMChain(llm: llm,
prompt: ZeroShotAgent.create_prompt(tools: tools),
parser: output_parser,
stop: ["\nObservation: ", "\n\tObservation: "])
let agent = ZeroShotAgent(llm_chain: llm_chain)

let AGENT_REQ_ID = "agent_req_id"

let chain_reqId = UUID().uuidString

func take_next_step( input: String, intermediate_steps: [(AgentAction, String)]) async -> (Parsed, String) {
let step = await agent.plan(input: input, intermediate_steps: intermediate_steps)
switch step {
case .finish(let finish):
return (step, finish.final)
case .action(let action):
let tool = tools.filter{$0.name() == action.action}.first!
do {
print("try call \(tool.name()) tool.")
var observation = try await tool.run(args: action.input)
if observation.count > 1000 {
observation = String(observation.prefix(1000))
}
return (step, observation)
} catch {
print("\(error.localizedDescription) at run \(tool.name()) tool.")
let observation = try! await InvalidTool(tool_name: tool.name()).run(args: action.input)
return (step, observation)
}
default:
return (step, "fail")
}
}


func callEnd(output: String, reqId: String, cost: Double) {
for callback in callbacks {
do {
try callback.on_chain_end(output: output, metadata: [DefaultChain.CHAIN_REQ_ID_KEY: reqId, DefaultChain.CHAIN_COST_KEY: "\(cost)"])
} catch {
print("call chain end callback errer: \(error)")
}
}
}

func callStart(prompt: String, reqId: String) {
for callback in callbacks {
do {
try callback.on_chain_start(prompts: prompt, metadata: [DefaultChain.CHAIN_REQ_ID_KEY: reqId])
} catch {
print("call chain end callback errer: \(error)")
}
}
}

func callCatch(error: Error, reqId: String, cost: Double) {
for callback in callbacks {
do {
try callback.on_chain_error(error: error, metadata: [DefaultChain.CHAIN_REQ_ID_KEY: reqId, DefaultChain.CHAIN_COST_KEY: "\(cost)"])
} catch {
print("call LLM start callback errer: \(error)")
}
}
}

let workflow = GraphState( stateType: AgentExecutorState.self )

try workflow.addNode( "call_start" ) { state in

guard let prompt = state.input else {
throw GraphRunnerError.executionError("'inputs' argument not found!")
}

callStart(prompt: prompt, reqId: chain_reqId)

return ["start": Date.now.timeIntervalSince1970]
}

try workflow.addNode( "call_end" ) { state in

guard let output = state.output else {
throw GraphRunnerError.executionError("'output' argument not found!")
}
guard let start = state.start else {
throw GraphRunnerError.executionError("'start' argument not found!")
}

let cost = Date.now.timeIntervalSince1970 - start


callEnd(output: output.0.llm_output ?? "", reqId: chain_reqId, cost: cost)

return [:]
}

try workflow.addNode("call_agent" ) { state in

guard let input = state.input else {
throw GraphRunnerError.executionError("'inputs' argument not found in state!")
}
guard let intermediate_steps = state.intermediate_steps else {
throw GraphRunnerError.executionError("'intermediate_steps' property not found in state!")
}

let agent_reqId = UUID().uuidString
do {
for callback in callbacks {
try callback.on_agent_start(prompt: input, metadata: [AGENT_REQ_ID: agent_reqId])
}
} catch {
print( "call agent start callback error: \(error)")
}

let result = await take_next_step(input: input, intermediate_steps: intermediate_steps)

switch result.0 {
case .finish(let finish):
print("Found final answer.")
do {
for callback in callbacks {
try callback.on_agent_finish(action: finish, metadata: [AGENT_REQ_ID: agent_reqId])
}
} catch {
print( "call chain end callback error: \(error)")
}
return [ "output": (LLMResult(llm_output: result.1), Parsed.str(result.1)) ]
case .action(let action):
do {
for callback in callbacks {
try callback.on_agent_action(action: action, metadata: [AGENT_REQ_ID: agent_reqId])
}
} catch {
print( "call chain end callback error: \(error)")
}
return [ "intermediate_steps" : (action, result.1) ]
default:
throw GraphRunnerError.executionError( "Parsed.error" )
}
}

try workflow.setEntryPoint("call_start")
workflow.setFinishPoint("call_end")

try workflow.addEdge(sourceId: "call_start", targetId: "call_agent")
try workflow.addConditionalEdge( sourceId: "call_agent", condition: { state in
return "terminate"
}, edgeMapping: [
"continue" : "call_agent",
"terminate": "call_end"])


let runner = try workflow.compile()

let result = try await runner.invoke(inputs: [ "input": input ])

print( result )
}
Loading

0 comments on commit d829826

Please sign in to comment.