From d829826d94b9c0e0c9b414ed419e99d550cb5979 Mon Sep 17 00:00:00 2001 From: bsorrentino Date: Fri, 15 Mar 2024 20:18:30 +0100 Subject: [PATCH] refactor: start porting LangChain agent to LangGraph --- .../LangChainDemo.xcodeproj/project.pbxproj | 37 ++- .../xcshareddata/swiftpm/Package.resolved | 11 +- .../LangChainDemo/AgentExecutor.swift | 223 ++++++++++++++++++ LangChainDemo/LangChainDemo/ContentView.swift | 70 +++++- 4 files changed, 307 insertions(+), 34 deletions(-) create mode 100644 LangChainDemo/LangChainDemo/AgentExecutor.swift diff --git a/LangChainDemo/LangChainDemo.xcodeproj/project.pbxproj b/LangChainDemo/LangChainDemo.xcodeproj/project.pbxproj index 8267502..34c61d1 100644 --- a/LangChainDemo/LangChainDemo.xcodeproj/project.pbxproj +++ b/LangChainDemo/LangChainDemo.xcodeproj/project.pbxproj @@ -3,19 +3,21 @@ archiveVersion = 1; classes = { }; - objectVersion = 56; + objectVersion = 60; objects = { /* Begin PBXBuildFile section */ + A04EAC2B2BA441B500C49DC0 /* AgentExecutor.swift in Sources */ = {isa = PBXBuildFile; fileRef = A04EAC2A2BA441B500C49DC0 /* AgentExecutor.swift */; }; + A04EAC2E2BA4572E00C49DC0 /* LangChain in Frameworks */ = {isa = PBXBuildFile; productRef = A04EAC2D2BA4572E00C49DC0 /* LangChain */; }; A08CC7552BA373E9007A8248 /* LangChainDemoApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = A08CC7542BA373E9007A8248 /* LangChainDemoApp.swift */; }; A08CC7572BA373E9007A8248 /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = A08CC7562BA373E9007A8248 /* ContentView.swift */; }; A08CC7592BA373EA007A8248 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = A08CC7582BA373EA007A8248 /* Assets.xcassets */; }; A08CC75C2BA373EA007A8248 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = A08CC75B2BA373EA007A8248 /* Preview Assets.xcassets */; }; A08CC7652BA37494007A8248 /* LangGraph in Frameworks */ = {isa = PBXBuildFile; productRef = A08CC7642BA37494007A8248 /* LangGraph */; }; - A08CC7682BA3772E007A8248 /* LangChain in Frameworks */ = {isa = PBXBuildFile; productRef = A08CC7672BA3772E007A8248 /* LangChain */; }; /* End PBXBuildFile section */ /* Begin PBXFileReference section */ + A04EAC2A2BA441B500C49DC0 /* AgentExecutor.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = AgentExecutor.swift; sourceTree = ""; }; A08CC7512BA373E9007A8248 /* LangChainDemo.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = LangChainDemo.app; sourceTree = BUILT_PRODUCTS_DIR; }; A08CC7542BA373E9007A8248 /* LangChainDemoApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LangChainDemoApp.swift; sourceTree = ""; }; A08CC7562BA373E9007A8248 /* ContentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = ""; }; @@ -29,7 +31,7 @@ isa = PBXFrameworksBuildPhase; buildActionMask = 2147483647; files = ( - A08CC7682BA3772E007A8248 /* LangChain in Frameworks */, + A04EAC2E2BA4572E00C49DC0 /* LangChain in Frameworks */, A08CC7652BA37494007A8248 /* LangGraph in Frameworks */, ); runOnlyForDeploymentPostprocessing = 0; @@ -57,6 +59,7 @@ A08CC7532BA373E9007A8248 /* LangChainDemo */ = { isa = PBXGroup; children = ( + A04EAC2A2BA441B500C49DC0 /* AgentExecutor.swift */, A08CC7542BA373E9007A8248 /* LangChainDemoApp.swift */, A08CC7562BA373E9007A8248 /* ContentView.swift */, A08CC7582BA373EA007A8248 /* Assets.xcassets */, @@ -99,7 +102,7 @@ name = LangChainDemo; packageProductDependencies = ( A08CC7642BA37494007A8248 /* LangGraph */, - A08CC7672BA3772E007A8248 /* LangChain */, + A04EAC2D2BA4572E00C49DC0 /* LangChain */, ); productName = LangChainDemo; productReference = A08CC7512BA373E9007A8248 /* LangChainDemo.app */; @@ -130,7 +133,7 @@ ); mainGroup = A08CC7482BA373E9007A8248; packageReferences = ( - A08CC7662BA3772E007A8248 /* XCRemoteSwiftPackageReference "langchain-swift" */, + A04EAC2C2BA4572E00C49DC0 /* XCLocalSwiftPackageReference "../../langchain-swift" */, ); productRefGroup = A08CC7522BA373E9007A8248 /* Products */; projectDirPath = ""; @@ -160,6 +163,7 @@ files = ( A08CC7572BA373E9007A8248 /* ContentView.swift in Sources */, A08CC7552BA373E9007A8248 /* LangChainDemoApp.swift in Sources */, + A04EAC2B2BA441B500C49DC0 /* AgentExecutor.swift in Sources */, ); runOnlyForDeploymentPostprocessing = 0; }; @@ -366,26 +370,21 @@ }; /* End XCConfigurationList section */ -/* Begin XCRemoteSwiftPackageReference section */ - A08CC7662BA3772E007A8248 /* XCRemoteSwiftPackageReference "langchain-swift" */ = { - isa = XCRemoteSwiftPackageReference; - repositoryURL = "https://github.com/buhe/langchain-swift.git"; - requirement = { - kind = upToNextMajorVersion; - minimumVersion = 0.48.0; - }; +/* Begin XCLocalSwiftPackageReference section */ + A04EAC2C2BA4572E00C49DC0 /* XCLocalSwiftPackageReference "../../langchain-swift" */ = { + isa = XCLocalSwiftPackageReference; + relativePath = "../../langchain-swift"; }; -/* End XCRemoteSwiftPackageReference section */ +/* End XCLocalSwiftPackageReference section */ /* Begin XCSwiftPackageProductDependency section */ - A08CC7642BA37494007A8248 /* LangGraph */ = { + A04EAC2D2BA4572E00C49DC0 /* LangChain */ = { isa = XCSwiftPackageProductDependency; - productName = LangGraph; + productName = LangChain; }; - A08CC7672BA3772E007A8248 /* LangChain */ = { + A08CC7642BA37494007A8248 /* LangGraph */ = { isa = XCSwiftPackageProductDependency; - package = A08CC7662BA3772E007A8248 /* XCRemoteSwiftPackageReference "langchain-swift" */; - productName = LangChain; + productName = LangGraph; }; /* End XCSwiftPackageProductDependency section */ }; diff --git a/LangChainDemo/LangChainDemo.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/LangChainDemo/LangChainDemo.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved index 2aa2e6e..0328c92 100644 --- a/LangChainDemo/LangChainDemo.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved +++ b/LangChainDemo/LangChainDemo.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved @@ -1,5 +1,5 @@ { - "originHash" : "9f967c268a37c7ace07b0924d8a430e8dca22c0e207c303bb411a059e4ad9863", + "originHash" : "b47df776245fa51fa189bff68b7fac0f33c7b9ada1ab0403cc1479ad43538ee3", "pins" : [ { "identity" : "async-http-client", @@ -73,15 +73,6 @@ "version" : "4.2.2" } }, - { - "identity" : "langchain-swift", - "kind" : "remoteSourceControl", - "location" : "https://github.com/buhe/langchain-swift.git", - "state" : { - "revision" : "d2c945cfbb0b1b9bc571dc5121fbaa264bb42bd6", - "version" : "0.48.0" - } - }, { "identity" : "openai-kit", "kind" : "remoteSourceControl", diff --git a/LangChainDemo/LangChainDemo/AgentExecutor.swift b/LangChainDemo/LangChainDemo/AgentExecutor.swift new file mode 100644 index 0000000..b913080 --- /dev/null +++ b/LangChainDemo/LangChainDemo/AgentExecutor.swift @@ -0,0 +1,223 @@ +// +// AgentExecutor.swift +// LangChainDemo +// +// Created by bsorrentino on 14/03/24. +// + +import Foundation +import LangChain +import LangGraph + +struct AgentExecutorState : AgentState { + var data: [String : Any] + + init() { + data = ["intermediate_steps": AppendableValue()] + } + + init(_ initState: [String : Any]) { + data = initState + } + + var start:Double? { + value("start") + } + var input:String? { + value("input") + } + + var output:(LLMResult, Parsed)? { + value("output") + } + + var intermediate_steps: [(AgentAction, String)]? { + appendableValue("intermediate_steps" ) + } +} + +struct ToolOutputParser: BaseOutputParser { + public init() {} + public func parse(text: String) -> Parsed { + print(text.uppercased()) + let pattern = "Action\\s*:[\\s]*(.*)[\\s]*Action\\s*Input\\s*:[\\s]*(.*)" + let regex = try! NSRegularExpression(pattern: pattern) + + if let match = regex.firstMatch(in: text, options: [], range: NSRange(location: 0, length: text.utf16.count)) { + + let firstCaptureGroup = Range(match.range(at: 1), in: text).map { String(text[$0]) } +// print(firstCaptureGroup!) + + + let secondCaptureGroup = Range(match.range(at: 2), in: text).map { String(text[$0]) } +// print(secondCaptureGroup!) + return Parsed.action(AgentAction(action: firstCaptureGroup!, input: secondCaptureGroup!, log: text)) + } else { + if text.uppercased().contains(FINAL_ANSWER_ACTION) { + return Parsed.finish(AgentFinish(final: text)) + } + return Parsed.error + } + } +} + +public func runAgent( input: String, llm: LLM, tools: [BaseTool], callbacks: [BaseCallbackHandler] = []) async throws -> Void { + + let output_parser = ToolOutputParser() + let llm_chain = LLMChain(llm: llm, + prompt: ZeroShotAgent.create_prompt(tools: tools), + parser: output_parser, + stop: ["\nObservation: ", "\n\tObservation: "]) + let agent = ZeroShotAgent(llm_chain: llm_chain) + + let AGENT_REQ_ID = "agent_req_id" + + let chain_reqId = UUID().uuidString + + func take_next_step( input: String, intermediate_steps: [(AgentAction, String)]) async -> (Parsed, String) { + let step = await agent.plan(input: input, intermediate_steps: intermediate_steps) + switch step { + case .finish(let finish): + return (step, finish.final) + case .action(let action): + let tool = tools.filter{$0.name() == action.action}.first! + do { + print("try call \(tool.name()) tool.") + var observation = try await tool.run(args: action.input) + if observation.count > 1000 { + observation = String(observation.prefix(1000)) + } + return (step, observation) + } catch { + print("\(error.localizedDescription) at run \(tool.name()) tool.") + let observation = try! await InvalidTool(tool_name: tool.name()).run(args: action.input) + return (step, observation) + } + default: + return (step, "fail") + } + } + + + func callEnd(output: String, reqId: String, cost: Double) { + for callback in callbacks { + do { + try callback.on_chain_end(output: output, metadata: [DefaultChain.CHAIN_REQ_ID_KEY: reqId, DefaultChain.CHAIN_COST_KEY: "\(cost)"]) + } catch { + print("call chain end callback errer: \(error)") + } + } + } + + func callStart(prompt: String, reqId: String) { + for callback in callbacks { + do { + try callback.on_chain_start(prompts: prompt, metadata: [DefaultChain.CHAIN_REQ_ID_KEY: reqId]) + } catch { + print("call chain end callback errer: \(error)") + } + } + } + + func callCatch(error: Error, reqId: String, cost: Double) { + for callback in callbacks { + do { + try callback.on_chain_error(error: error, metadata: [DefaultChain.CHAIN_REQ_ID_KEY: reqId, DefaultChain.CHAIN_COST_KEY: "\(cost)"]) + } catch { + print("call LLM start callback errer: \(error)") + } + } + } + + let workflow = GraphState( stateType: AgentExecutorState.self ) + + try workflow.addNode( "call_start" ) { state in + + guard let prompt = state.input else { + throw GraphRunnerError.executionError("'inputs' argument not found!") + } + + callStart(prompt: prompt, reqId: chain_reqId) + + return ["start": Date.now.timeIntervalSince1970] + } + + try workflow.addNode( "call_end" ) { state in + + guard let output = state.output else { + throw GraphRunnerError.executionError("'output' argument not found!") + } + guard let start = state.start else { + throw GraphRunnerError.executionError("'start' argument not found!") + } + + let cost = Date.now.timeIntervalSince1970 - start + + + callEnd(output: output.0.llm_output ?? "", reqId: chain_reqId, cost: cost) + + return [:] + } + + try workflow.addNode("call_agent" ) { state in + + guard let input = state.input else { + throw GraphRunnerError.executionError("'inputs' argument not found in state!") + } + guard let intermediate_steps = state.intermediate_steps else { + throw GraphRunnerError.executionError("'intermediate_steps' property not found in state!") + } + + let agent_reqId = UUID().uuidString + do { + for callback in callbacks { + try callback.on_agent_start(prompt: input, metadata: [AGENT_REQ_ID: agent_reqId]) + } + } catch { + print( "call agent start callback error: \(error)") + } + + let result = await take_next_step(input: input, intermediate_steps: intermediate_steps) + + switch result.0 { + case .finish(let finish): + print("Found final answer.") + do { + for callback in callbacks { + try callback.on_agent_finish(action: finish, metadata: [AGENT_REQ_ID: agent_reqId]) + } + } catch { + print( "call chain end callback error: \(error)") + } + return [ "output": (LLMResult(llm_output: result.1), Parsed.str(result.1)) ] + case .action(let action): + do { + for callback in callbacks { + try callback.on_agent_action(action: action, metadata: [AGENT_REQ_ID: agent_reqId]) + } + } catch { + print( "call chain end callback error: \(error)") + } + return [ "intermediate_steps" : (action, result.1) ] + default: + throw GraphRunnerError.executionError( "Parsed.error" ) + } + } + + try workflow.setEntryPoint("call_start") + workflow.setFinishPoint("call_end") + + try workflow.addEdge(sourceId: "call_start", targetId: "call_agent") + try workflow.addConditionalEdge( sourceId: "call_agent", condition: { state in + return "terminate" + }, edgeMapping: [ + "continue" : "call_agent", + "terminate": "call_end"]) + + + let runner = try workflow.compile() + + let result = try await runner.invoke(inputs: [ "input": input ]) + + print( result ) +} diff --git a/LangChainDemo/LangChainDemo/ContentView.swift b/LangChainDemo/LangChainDemo/ContentView.swift index 1b2393e..bf357d1 100644 --- a/LangChainDemo/LangChainDemo/ContentView.swift +++ b/LangChainDemo/LangChainDemo/ContentView.swift @@ -6,17 +6,77 @@ // import SwiftUI +import LangChain +import OpenAIKit +import AsyncHTTPClient + +class Callback : BaseCallbackHandler { + + override func on_tool_start(tool: BaseTool, input: String, metadata: [String: String]) throws { + + print( "on_tool_start", tool.name()) + } +} struct ContentView: View { + + @State var openai_api_key: String = "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" + @State var input:String = "perform a test call" + @State var progress: String = "" + var body: some View { - VStack { - Image(systemName: "globe") - .imageScale(.large) - .foregroundStyle(.tint) - Text("Hello, world!") + VStack(alignment: .center) { + + TextField(text: $openai_api_key, + label: { Label("OPENAI API KEY", systemImage: "bolt.fill") }) + Divider() + + TextField(text: $input, + label: { Label("PROMPT", systemImage: "bolt.fill") }) + + Button( action: executeAgent, label: { + Label("EXECUTE", systemImage: "bolt.fill") + }) + Divider() + Text( progress ) } .padding() } + + @MainActor + func setProgress( _ msg: String ) { + progress = msg + } + + func executeAgent() { + Env.initSet(["OPENAI_API_KEY": openai_api_key]) + + Task { + do { + let httpClient = HTTPClient() + defer { + // it's important to shutdown the httpClient after all requests are done, even if one failed. See: https://github.com/swift-server/async-http-client + try? httpClient.syncShutdown() + } + + let llm = ChatOpenAI( httpClient: httpClient, model: Model.GPT3.gpt3_5Turbo_0125, callbacks: [ Callback() ]) + +// let agent = initialize_agent(llm: llm, tools: [Dummy(), JavascriptREPLTool(), TTSTool()], callbacks: [ Callback() ]) +// +// print( await agent.run(args: input) ) + + try await runAgent(input: input, + llm: llm, + tools: [Dummy(), JavascriptREPLTool(), TTSTool()], + callbacks: [ Callback() ]) + } + catch { + await setProgress("ERROR: \(error)") + } + } + + } + } #Preview {