Skip to content

Commit

Permalink
feat: refine Channels management
Browse files Browse the repository at this point in the history
introducing State Schema concept
  • Loading branch information
bsorrentino committed Aug 4, 2024
1 parent 2e96e05 commit 74a4768
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 62 deletions.
105 changes: 61 additions & 44 deletions Sources/LangGraph/LangGraph.swift
Original file line number Diff line number Diff line change
@@ -1,64 +1,77 @@
import OSLog

public typealias PartialAgentState = [String: Any]
public typealias NodeAction<Action: AgentState> = ( Action ) async throws -> PartialAgentState
public typealias EdgeCondition<Action: AgentState> = ( Action ) async throws -> String

public typealias Reducer<Value> = ( Value?, Value ) -> Value
//public typealias UnaryOperator<Value> = () -> Value
public typealias UnaryOperator<Value> = () -> Value

public protocol ChannelProtocol {
associatedtype ItemType
var value: ItemType? { get }
func update( _ newValue: Any ) throws -> Self
associatedtype T
var reducer: Reducer<T> { get }
var `default`: UnaryOperator<T> { get }

func update( oldValue: Any?, newValue: Any ) throws -> Any
}


public class Channel<T> : ChannelProtocol {
public var value: T?
public var reducer: Reducer<T>?
// var `default`: UnaryOperator<T>?
public var reducer: Reducer<T>
public var `default`: UnaryOperator<T>

public init(reducer: @escaping Reducer<T>, default defaultValueProvider: @escaping UnaryOperator<T> ) {
self.reducer = reducer
self.`default` = defaultValueProvider
}

public func update( _ newValue: Any ) throws -> Self {
public func update( oldValue: Any?, newValue: Any ) throws -> Any {
guard let new = newValue as? T else {
throw CompiledGraphError.executionError( "Channel update type mismatch!")
throw CompiledGraphError.executionError( "Channel update 'newValue' type mismatch!")
}
if let reducer {
value = reducer( value, new )

var old:T
if oldValue == nil {
old = self.`default`()
}
else {
value = new
guard let _old = oldValue as? T else {
throw CompiledGraphError.executionError( "Channel update 'oldValue' type mismatch!")
}
old = _old
}
return self
}
}

public typealias PartialAgentState = [String: Any]
return reducer( old, new )

}
}

public class AppendChannel<T> : Channel<[T]> {

init(_ value: [T]? = nil ) {
super.init()
self.value = value
self.reducer = { (left, right) in

public class AppenderChannel<T> : Channel<[T]> {

public init( default defaultValueProvider: @escaping UnaryOperator<[T]> = { [] } ) {
super.init( reducer: { left, right in

guard var left else {
return right
}

left.append(contentsOf: right)
return left
}
},
default : defaultValueProvider)
}

public override func update( _ newValue: Any ) throws -> Self {
if let new = newValue as? T {
return try super.update( [new] )
public override func update( oldValue: Any?, newValue: Any ) throws -> Any {
if let new = newValue as? T {
return try super.update(oldValue: oldValue, newValue: [new] )
}
return try super.update( newValue )
return try super.update(oldValue: oldValue, newValue: newValue )
}

}

public typealias Channels = [String: any ChannelProtocol ]

public protocol AgentState {

var data: [String: Any] { get }
Expand All @@ -71,10 +84,6 @@ public protocol AgentState {
extension AgentState {

public func value<T>( _ key: String ) -> T? {
if let channel = data[ key ] as? Channel<T> {
return channel.value
}

return data[ key ] as? T

}
Expand All @@ -94,7 +103,7 @@ public struct NodeOutput<State: AgentState> {

public struct BaseAgentState : AgentState {

subscript(key: String) -> Any? {
public subscript(key: String) -> Any? {
value( key )
}

Expand Down Expand Up @@ -192,9 +201,10 @@ public class StateGraph<State: AgentState> {
var edges:Dictionary<String, EdgeValue>
var entryPoint:EdgeValue
var finishPoint:String?

let schema: Channels

init( owner: StateGraph ) {

self.schema = owner.schema
self.stateFactory = owner.stateFactory
self.nodes = Dictionary()
self.edges = Dictionary()
Expand All @@ -214,12 +224,17 @@ public class StateGraph<State: AgentState> {
if partialState.isEmpty {
return currentState
}
let newState = try currentState.data.merging(partialState, uniquingKeysWith: {
(current, new) in

if let value = current as? (any ChannelProtocol) {
return try value.update( new )

let _partialState = try partialState.map { key, value in
if let channel = schema[key] {
return ( key , try channel.update( oldValue: currentState.data[key], newValue: value ))
}
return (key, value)

}
let newState = currentState.data.merging(_partialState, uniquingKeysWith: {
(current, new) in

return new
})
return State.init(newState)
Expand Down Expand Up @@ -358,10 +373,12 @@ public class StateGraph<State: AgentState> {
private var finishPoint: String?

private var stateFactory: () -> State
private var schema: Channels

public init( stateFactory: @escaping () -> State ) {
public init( schema: Channels = [:], stateFactory: @escaping () -> State ) {
self.schema = schema
self.stateFactory = stateFactory

}

public func addNode( _ id: String, action: @escaping NodeAction<State> ) throws {
Expand Down
27 changes: 9 additions & 18 deletions Tests/LangGraphTests/LangGraphTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,6 @@ final class LangGraphTests: XCTestCase {
}
}
}
if let value1 = value as? (any ChannelProtocol) {
if let values = value1.value as? [Any] {
if values.count == values2.count {
for ( v1, v2) in zip(values, values2) {
return compareAsEquatable( v1, v2 )
}
}

}
}
}
return false
}
Expand Down Expand Up @@ -279,11 +269,12 @@ final class LangGraphTests: XCTestCase {
}

struct AgentStateWithAppender : AgentState {
var data: [String : Any]

init() {
self.init( ["messages": AppendChannel<String>()] )
}
static var schema: Channels = {
["messages": AppenderChannel<String>( default: { [] })]
}()

var data: [String : Any]

init(_ initState: [String : Any]) {
data = initState
Expand All @@ -294,8 +285,8 @@ final class LangGraphTests: XCTestCase {
}

func testAppender() async throws {
let workflow = StateGraph { AgentStateWithAppender() }

let workflow = StateGraph( schema: AgentStateWithAppender.schema ) { AgentStateWithAppender( [:] ) }

try workflow.addNode("agent_1") { state in

Expand Down Expand Up @@ -329,7 +320,7 @@ final class LangGraphTests: XCTestCase {

func testWithStream() async throws {

let workflow = StateGraph { AgentStateWithAppender() }
let workflow = StateGraph( schema: AgentStateWithAppender.schema ) { AgentStateWithAppender( [:] ) }

try workflow.addNode("agent_1") { state in
["messages": "message1"]
Expand Down Expand Up @@ -365,7 +356,7 @@ final class LangGraphTests: XCTestCase {

func testWithStreamAnCancellation() async throws {

let workflow = StateGraph { AgentStateWithAppender() }
let workflow = StateGraph( schema: AgentStateWithAppender.schema ) { AgentStateWithAppender([:]) }

try workflow.addNode("agent_1") { state in
try await Task.sleep(nanoseconds: 500_000_000)
Expand Down
27 changes: 27 additions & 0 deletions channels.playground/Contents.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import UIKit

struct Channel<T> {
var value: T

}

func update<T>( channel: Channel<T>, newValue: Any ) throws -> Any {
print( "update single \(newValue)" )

return newValue
}

func update<T>( channel: Channel<[T]>, newValue: Any ) throws -> Any {
print( "update array \(newValue)" )

return newValue
}


let channel = Channel(value: [ "1", "2", "3" ])

try update(channel: channel, newValue: [])

let channel1 = Channel(value: "BBBBB")

try update(channel: channel1, newValue: "CCCCC")
4 changes: 4 additions & 0 deletions channels.playground/contents.xcplayground
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<playground version='5.0' target-platform='ios' buildActiveScheme='true' importAppTypes='true'>
<timeline fileName='timeline.xctimeline'/>
</playground>

0 comments on commit 74a4768

Please sign in to comment.