Skip to content

Commit

Permalink
refactor: add channel management in AgentState
Browse files Browse the repository at this point in the history
  • Loading branch information
bsorrentino committed Aug 3, 2024
1 parent 99d04dc commit 2e96e05
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 115 deletions.
151 changes: 39 additions & 112 deletions Sources/LangGraph/LangGraph.swift
Original file line number Diff line number Diff line change
@@ -1,123 +1,60 @@
import OSLog

public typealias PartialAgentState = [String: Any]

public typealias NodeAction<Action: AgentState> = ( Action ) async throws -> PartialAgentState
public typealias EdgeCondition<Action: AgentState> = ( Action ) async throws -> String

public typealias BinaryOperator<Value> = ( Value?, Value ) throws -> Value
public typealias UnaryOperator<Value> = () -> Value

public struct Evaluator<T> {
var value: BinaryOperator<T>
var `default`: UnaryOperator<T>?
}
public typealias Reducer<Value> = ( Value?, Value ) -> Value
//public typealias UnaryOperator<Value> = () -> Value

enum Channel<T> {
case Value
case Eval( Evaluator<T> )
}

protocol EvaluableValueProtocol {
public protocol ChannelProtocol {
associatedtype ItemType

var value: ItemType? { get }

mutating func setValue( _ newValue: Any ) throws -> Void
func update( _ newValue: Any ) throws -> Self
}

public struct EvaluableValue<T> : EvaluableValueProtocol {
typealias ItemType = T

private var channel: Channel<Any>
private var _value: T?
public class Channel<T> : ChannelProtocol {
public var value: T?
public var reducer: Reducer<T>?
// var `default`: UnaryOperator<T>?

public var value: T? {
get {
_value
public func update( _ newValue: Any ) throws -> Self {
guard let new = newValue as? T else {
throw CompiledGraphError.executionError( "Channel update type mismatch!")
}
}

public init( ){
self.channel = Channel.Value
}

public init( evaluator: Evaluator<Any> ){
self.channel = Channel.Eval(evaluator)
if let defaultValue = evaluator.default {
self._value = defaultValue() as? T
if let reducer {
value = reducer( value, new )
}
}

mutating func setValue( _ newValue: Any ) throws {

switch( channel ) {
case .Value:
_value = newValue as? T

case .Eval( let evaluator ):
_value = try evaluator.value( _value, newValue ) as? T

else {
value = new
}
return self
}
}

protocol AppendableValueProtocol : EvaluableValueProtocol {

}
public typealias PartialAgentState = [String: Any]

public struct AppendableValue<T> : AppendableValueProtocol {
typealias ItemType = [T]

var _value:EvaluableValue<[T]>?

public var value: [T]? {
get {
_value?.value
}
}

public class AppendChannel<T> : Channel<[T]> {

func _append( left: Any, right: Any ) throws -> Any {
if let typedValues = right as? [T] {
guard var left = left as? [T] else {
return typedValues
init(_ value: [T]? = nil ) {
super.init()
self.value = value
self.reducer = { (left, right) in

guard var left else {
return right
}

left.append(contentsOf: typedValues)
left.append(contentsOf: right)
return left
}


throw CompiledGraphError.executionError( "AppenderValue type mismatch!")
}

public init() {

let evaluator = Evaluator( value: _append )
_value = EvaluableValue( evaluator: evaluator )
}

public init( values: [T] ) {
let evaluator = Evaluator( value: _append) {
values
}
_value = EvaluableValue( evaluator: evaluator )
}

mutating func setValue( _ newValue: Any ) throws {
if let newValue = newValue as? T {

try self._value?.setValue([newValue])
return
}

if let newValue = newValue as? [T] {

try self._value?.setValue( newValue )
return
public override func update( _ newValue: Any ) throws -> Self {
if let new = newValue as? T {
return try super.update( [new] )
}

throw CompiledGraphError.executionError( "AppenderValue type mismatch!")
return try super.update( newValue )
}

}
Expand All @@ -134,22 +71,14 @@ public protocol AgentState {
extension AgentState {

public func value<T>( _ key: String ) -> T? {
if let channel = data[ key ] as? Channel<T> {
return channel.value
}

return data[ key ] as? T

}

public func evaluableValue<T>(_ key: String ) -> T? {
guard let eval = data[ key ] as? EvaluableValue<T> else {
return nil
}
return eval.value
}

public func appendableValue<T>( _ key: String ) -> [T]? {
guard let eval = data[ key ] as? AppendableValue<T> else {
return nil
}
return eval.value
}
}

public struct NodeOutput<State: AgentState> {
Expand All @@ -166,7 +95,7 @@ public struct NodeOutput<State: AgentState> {
public struct BaseAgentState : AgentState {

subscript(key: String) -> Any? {
data[key]
value( key )
}

public var data: [String : Any]
Expand Down Expand Up @@ -288,10 +217,8 @@ public class StateGraph<State: AgentState> {
let newState = try currentState.data.merging(partialState, uniquingKeysWith: {
(current, new) in

if var eval = current as? (any EvaluableValueProtocol) {
try eval.setValue(new)

return eval
if let value = current as? (any ChannelProtocol) {
return try value.update( new )
}
return new
})
Expand Down
6 changes: 3 additions & 3 deletions Tests/LangGraphTests/LangGraphTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ final class LangGraphTests: XCTestCase {
}
}
}
if let value1 = value as? (any AppendableValueProtocol) {
if let value1 = value as? (any ChannelProtocol) {
if let values = value1.value as? [Any] {
if values.count == values2.count {
for ( v1, v2) in zip(values, values2) {
Expand Down Expand Up @@ -282,14 +282,14 @@ final class LangGraphTests: XCTestCase {
var data: [String : Any]

init() {
self.init( ["messages": AppendableValue<String>()] )
self.init( ["messages": AppendChannel<String>()] )
}

init(_ initState: [String : Any]) {
data = initState
}
var messages:[String]? {
appendableValue("messages")
value("messages")
}
}

Expand Down

0 comments on commit 2e96e05

Please sign in to comment.