From 74a4768633e8cafab52940a2e75515e1afac5a2b Mon Sep 17 00:00:00 2001 From: bsorrentino Date: Sun, 4 Aug 2024 13:16:59 +0200 Subject: [PATCH] feat: refine Channels management introducing State Schema concept --- Sources/LangGraph/LangGraph.swift | 105 +++++++++++++--------- Tests/LangGraphTests/LangGraphTests.swift | 27 ++---- channels.playground/Contents.swift | 27 ++++++ channels.playground/contents.xcplayground | 4 + 4 files changed, 101 insertions(+), 62 deletions(-) create mode 100644 channels.playground/Contents.swift create mode 100644 channels.playground/contents.xcplayground diff --git a/Sources/LangGraph/LangGraph.swift b/Sources/LangGraph/LangGraph.swift index baca080..0d85358 100644 --- a/Sources/LangGraph/LangGraph.swift +++ b/Sources/LangGraph/LangGraph.swift @@ -1,64 +1,77 @@ import OSLog +public typealias PartialAgentState = [String: Any] public typealias NodeAction = ( Action ) async throws -> PartialAgentState public typealias EdgeCondition = ( Action ) async throws -> String public typealias Reducer = ( Value?, Value ) -> Value -//public typealias UnaryOperator = () -> Value +public typealias UnaryOperator = () -> Value public protocol ChannelProtocol { - associatedtype ItemType - var value: ItemType? { get } - func update( _ newValue: Any ) throws -> Self + associatedtype T + var reducer: Reducer { get } + var `default`: UnaryOperator { get } + + func update( oldValue: Any?, newValue: Any ) throws -> Any } + public class Channel : ChannelProtocol { - public var value: T? - public var reducer: Reducer? -// var `default`: UnaryOperator? + public var reducer: Reducer + public var `default`: UnaryOperator + + public init(reducer: @escaping Reducer, default defaultValueProvider: @escaping UnaryOperator ) { + self.reducer = reducer + self.`default` = defaultValueProvider + } - public func update( _ newValue: Any ) throws -> Self { + public func update( oldValue: Any?, newValue: Any ) throws -> Any { guard let new = newValue as? T else { - throw CompiledGraphError.executionError( "Channel update type mismatch!") + throw CompiledGraphError.executionError( "Channel update 'newValue' type mismatch!") } - if let reducer { - value = reducer( value, new ) + + var old:T + if oldValue == nil { + old = self.`default`() } else { - value = new + guard let _old = oldValue as? T else { + throw CompiledGraphError.executionError( "Channel update 'oldValue' type mismatch!") + } + old = _old } - return self - } -} -public typealias PartialAgentState = [String: Any] + return reducer( old, new ) + } +} -public class AppendChannel : Channel<[T]> { - - init(_ value: [T]? = nil ) { - super.init() - self.value = value - self.reducer = { (left, right) in - +public class AppenderChannel : Channel<[T]> { + + public init( default defaultValueProvider: @escaping UnaryOperator<[T]> = { [] } ) { + super.init( reducer: { left, right in + guard var left else { return right } - + left.append(contentsOf: right) return left - } + }, + default : defaultValueProvider) } - - public override func update( _ newValue: Any ) throws -> Self { - if let new = newValue as? T { - return try super.update( [new] ) + + public override func update( oldValue: Any?, newValue: Any ) throws -> Any { + if let new = newValue as? T { + return try super.update(oldValue: oldValue, newValue: [new] ) } - return try super.update( newValue ) + return try super.update(oldValue: oldValue, newValue: newValue ) } } +public typealias Channels = [String: any ChannelProtocol ] + public protocol AgentState { var data: [String: Any] { get } @@ -71,10 +84,6 @@ public protocol AgentState { extension AgentState { public func value( _ key: String ) -> T? { - if let channel = data[ key ] as? Channel { - return channel.value - } - return data[ key ] as? T } @@ -94,7 +103,7 @@ public struct NodeOutput { public struct BaseAgentState : AgentState { - subscript(key: String) -> Any? { + public subscript(key: String) -> Any? { value( key ) } @@ -192,9 +201,10 @@ public class StateGraph { var edges:Dictionary var entryPoint:EdgeValue var finishPoint:String? - + let schema: Channels + init( owner: StateGraph ) { - + self.schema = owner.schema self.stateFactory = owner.stateFactory self.nodes = Dictionary() self.edges = Dictionary() @@ -214,12 +224,17 @@ public class StateGraph { if partialState.isEmpty { return currentState } - let newState = try currentState.data.merging(partialState, uniquingKeysWith: { - (current, new) in - - if let value = current as? (any ChannelProtocol) { - return try value.update( new ) + + let _partialState = try partialState.map { key, value in + if let channel = schema[key] { + return ( key , try channel.update( oldValue: currentState.data[key], newValue: value )) } + return (key, value) + + } + let newState = currentState.data.merging(_partialState, uniquingKeysWith: { + (current, new) in + return new }) return State.init(newState) @@ -358,10 +373,12 @@ public class StateGraph { private var finishPoint: String? private var stateFactory: () -> State + private var schema: Channels - public init( stateFactory: @escaping () -> State ) { + public init( schema: Channels = [:], stateFactory: @escaping () -> State ) { + self.schema = schema self.stateFactory = stateFactory - + } public func addNode( _ id: String, action: @escaping NodeAction ) throws { diff --git a/Tests/LangGraphTests/LangGraphTests.swift b/Tests/LangGraphTests/LangGraphTests.swift index dc7d6dd..7975f9b 100644 --- a/Tests/LangGraphTests/LangGraphTests.swift +++ b/Tests/LangGraphTests/LangGraphTests.swift @@ -25,16 +25,6 @@ final class LangGraphTests: XCTestCase { } } } - if let value1 = value as? (any ChannelProtocol) { - if let values = value1.value as? [Any] { - if values.count == values2.count { - for ( v1, v2) in zip(values, values2) { - return compareAsEquatable( v1, v2 ) - } - } - - } - } } return false } @@ -279,11 +269,12 @@ final class LangGraphTests: XCTestCase { } struct AgentStateWithAppender : AgentState { - var data: [String : Any] - init() { - self.init( ["messages": AppendChannel()] ) - } + static var schema: Channels = { + ["messages": AppenderChannel( default: { [] })] + }() + + var data: [String : Any] init(_ initState: [String : Any]) { data = initState @@ -294,8 +285,8 @@ final class LangGraphTests: XCTestCase { } func testAppender() async throws { - - let workflow = StateGraph { AgentStateWithAppender() } + + let workflow = StateGraph( schema: AgentStateWithAppender.schema ) { AgentStateWithAppender( [:] ) } try workflow.addNode("agent_1") { state in @@ -329,7 +320,7 @@ final class LangGraphTests: XCTestCase { func testWithStream() async throws { - let workflow = StateGraph { AgentStateWithAppender() } + let workflow = StateGraph( schema: AgentStateWithAppender.schema ) { AgentStateWithAppender( [:] ) } try workflow.addNode("agent_1") { state in ["messages": "message1"] @@ -365,7 +356,7 @@ final class LangGraphTests: XCTestCase { func testWithStreamAnCancellation() async throws { - let workflow = StateGraph { AgentStateWithAppender() } + let workflow = StateGraph( schema: AgentStateWithAppender.schema ) { AgentStateWithAppender([:]) } try workflow.addNode("agent_1") { state in try await Task.sleep(nanoseconds: 500_000_000) diff --git a/channels.playground/Contents.swift b/channels.playground/Contents.swift new file mode 100644 index 0000000..d2b859e --- /dev/null +++ b/channels.playground/Contents.swift @@ -0,0 +1,27 @@ +import UIKit + +struct Channel { + var value: T + +} + +func update( channel: Channel, newValue: Any ) throws -> Any { + print( "update single \(newValue)" ) + + return newValue +} + +func update( channel: Channel<[T]>, newValue: Any ) throws -> Any { + print( "update array \(newValue)" ) + + return newValue +} + + +let channel = Channel(value: [ "1", "2", "3" ]) + +try update(channel: channel, newValue: []) + +let channel1 = Channel(value: "BBBBB") + +try update(channel: channel1, newValue: "CCCCC") diff --git a/channels.playground/contents.xcplayground b/channels.playground/contents.xcplayground new file mode 100644 index 0000000..cf026f2 --- /dev/null +++ b/channels.playground/contents.xcplayground @@ -0,0 +1,4 @@ + + + + \ No newline at end of file