diff --git a/Sources/LangGraph/LangGraph.swift b/Sources/LangGraph/LangGraph.swift index 32c90b7..baca080 100644 --- a/Sources/LangGraph/LangGraph.swift +++ b/Sources/LangGraph/LangGraph.swift @@ -1,123 +1,60 @@ import OSLog -public typealias PartialAgentState = [String: Any] - public typealias NodeAction = ( Action ) async throws -> PartialAgentState public typealias EdgeCondition = ( Action ) async throws -> String -public typealias BinaryOperator = ( Value?, Value ) throws -> Value -public typealias UnaryOperator = () -> Value - -public struct Evaluator { - var value: BinaryOperator - var `default`: UnaryOperator? -} +public typealias Reducer = ( Value?, Value ) -> Value +//public typealias UnaryOperator = () -> Value -enum Channel { - case Value - case Eval( Evaluator ) -} - -protocol EvaluableValueProtocol { +public protocol ChannelProtocol { associatedtype ItemType - var value: ItemType? { get } - - mutating func setValue( _ newValue: Any ) throws -> Void + func update( _ newValue: Any ) throws -> Self } -public struct EvaluableValue : EvaluableValueProtocol { - typealias ItemType = T - - private var channel: Channel - private var _value: T? +public class Channel : ChannelProtocol { + public var value: T? + public var reducer: Reducer? +// var `default`: UnaryOperator? - public var value: T? { - get { - _value + public func update( _ newValue: Any ) throws -> Self { + guard let new = newValue as? T else { + throw CompiledGraphError.executionError( "Channel update type mismatch!") } - } - - public init( ){ - self.channel = Channel.Value - } - - public init( evaluator: Evaluator ){ - self.channel = Channel.Eval(evaluator) - if let defaultValue = evaluator.default { - self._value = defaultValue() as? T + if let reducer { + value = reducer( value, new ) } - } - - mutating func setValue( _ newValue: Any ) throws { - - switch( channel ) { - case .Value: - _value = newValue as? T - - case .Eval( let evaluator ): - _value = try evaluator.value( _value, newValue ) as? T - + else { + value = new } + return self } } -protocol AppendableValueProtocol : EvaluableValueProtocol { - -} +public typealias PartialAgentState = [String: Any] -public struct AppendableValue : AppendableValueProtocol { - typealias ItemType = [T] - - var _value:EvaluableValue<[T]>? - - public var value: [T]? { - get { - _value?.value - } - } + +public class AppendChannel : Channel<[T]> { - func _append( left: Any, right: Any ) throws -> Any { - if let typedValues = right as? [T] { - guard var left = left as? [T] else { - return typedValues + init(_ value: [T]? = nil ) { + super.init() + self.value = value + self.reducer = { (left, right) in + + guard var left else { + return right } - left.append(contentsOf: typedValues) + left.append(contentsOf: right) return left } - - - throw CompiledGraphError.executionError( "AppenderValue type mismatch!") - } - - public init() { - - let evaluator = Evaluator( value: _append ) - _value = EvaluableValue( evaluator: evaluator ) - } - - public init( values: [T] ) { - let evaluator = Evaluator( value: _append) { - values - } - _value = EvaluableValue( evaluator: evaluator ) } - - mutating func setValue( _ newValue: Any ) throws { - if let newValue = newValue as? T { - try self._value?.setValue([newValue]) - return - } - - if let newValue = newValue as? [T] { - - try self._value?.setValue( newValue ) - return + public override func update( _ newValue: Any ) throws -> Self { + if let new = newValue as? T { + return try super.update( [new] ) } - - throw CompiledGraphError.executionError( "AppenderValue type mismatch!") + return try super.update( newValue ) } } @@ -134,22 +71,14 @@ public protocol AgentState { extension AgentState { public func value( _ key: String ) -> T? { + if let channel = data[ key ] as? Channel { + return channel.value + } + return data[ key ] as? T + } - public func evaluableValue(_ key: String ) -> T? { - guard let eval = data[ key ] as? EvaluableValue else { - return nil - } - return eval.value - } - - public func appendableValue( _ key: String ) -> [T]? { - guard let eval = data[ key ] as? AppendableValue else { - return nil - } - return eval.value - } } public struct NodeOutput { @@ -166,7 +95,7 @@ public struct NodeOutput { public struct BaseAgentState : AgentState { subscript(key: String) -> Any? { - data[key] + value( key ) } public var data: [String : Any] @@ -288,10 +217,8 @@ public class StateGraph { let newState = try currentState.data.merging(partialState, uniquingKeysWith: { (current, new) in - if var eval = current as? (any EvaluableValueProtocol) { - try eval.setValue(new) - - return eval + if let value = current as? (any ChannelProtocol) { + return try value.update( new ) } return new }) diff --git a/Tests/LangGraphTests/LangGraphTests.swift b/Tests/LangGraphTests/LangGraphTests.swift index d601041..dc7d6dd 100644 --- a/Tests/LangGraphTests/LangGraphTests.swift +++ b/Tests/LangGraphTests/LangGraphTests.swift @@ -25,7 +25,7 @@ final class LangGraphTests: XCTestCase { } } } - if let value1 = value as? (any AppendableValueProtocol) { + if let value1 = value as? (any ChannelProtocol) { if let values = value1.value as? [Any] { if values.count == values2.count { for ( v1, v2) in zip(values, values2) { @@ -282,14 +282,14 @@ final class LangGraphTests: XCTestCase { var data: [String : Any] init() { - self.init( ["messages": AppendableValue()] ) + self.init( ["messages": AppendChannel()] ) } init(_ initState: [String : Any]) { data = initState } var messages:[String]? { - appendableValue("messages") + value("messages") } }