Skip to content

Commit

Permalink
Merge branch 'release/1.2.0'
Browse files Browse the repository at this point in the history
  • Loading branch information
bsorrentino committed Mar 19, 2024
2 parents 185f938 + 6167de2 commit c3d6e5a
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 10 deletions.
31 changes: 21 additions & 10 deletions Sources/LangGraph/LangGraph.swift
Original file line number Diff line number Diff line change
Expand Up @@ -238,12 +238,15 @@ public class GraphState<State: AgentState> {
if( verbose ) {
log.debug("start processing node \(currentNodeId)")
}

try Task.checkCancellation()
let partialState = try await action( currentState )

currentState = mergeState( currentState: currentState, partialState: partialState)

let output = NodeOutput(node: currentNodeId,state: currentState)

try Task.checkCancellation()
continuation.yield( output )

if( currentNodeId == finishPoint ) {
Expand All @@ -252,7 +255,7 @@ public class GraphState<State: AgentState> {

currentNodeId = try await nextNodeId(nodeId: currentNodeId, agentState: currentState)

} while( currentNodeId != END )
} while( currentNodeId != END && !Task.isCancelled )

continuation.finish()
}
Expand All @@ -264,13 +267,21 @@ public class GraphState<State: AgentState> {
return stream
}


/// run the graph an return the final State
///
/// - Parameters:
/// - inputs: partial state
/// - verbose: enable verbose output (log)
/// - Returns: final State
public func invoke( inputs: PartialAgentState, verbose:Bool = false ) async throws -> State {

var result:State?
for try await output in stream(inputs: inputs) {
result = output.state
}
return result!
let initResult:[NodeOutput<State>] = []
let result = try await stream(inputs: inputs).reduce( initResult, { partialResult, output in
[output]
})

return result[0].state
}
}

Expand Down Expand Up @@ -388,19 +399,19 @@ public class GraphState<State: AgentState> {
for edge in edges {

guard nodes.contains( makeFakeNode(edge.sourceId) ) else {
throw GraphStateError.missingNodeReferencedByEdge( "edge sourceId: \(edge.sourceId) reference a not existent node!")
throw GraphStateError.missingNodeReferencedByEdge( "edge sourceId: \(edge.sourceId) reference to non existent node!")
}

switch( edge.target ) {
case .id( let targetId ):
guard targetId==END || nodes.contains(makeFakeNode(targetId) ) else {
throw GraphStateError.missingNodeReferencedByEdge( "edge sourceId: \(edge.sourceId) reference a not existent targetId: \(targetId) node!")
throw GraphStateError.missingNodeReferencedByEdge( "edge sourceId: \(edge.sourceId) reference to non existent node targetId: \(targetId) node!")
}
break
case .condition((_, let edgeMappings)):
for (_,nodeId) in edgeMappings {
guard nodeId==END || nodes.contains(makeFakeNode(nodeId) ) else {
throw GraphStateError.missingNodeInEdgeMapping( "edge mapping for sourceId: \(edge.sourceId) contains a not existen nodeId \(nodeId)!")
throw GraphStateError.missingNodeInEdgeMapping( "edge mapping for sourceId: \(edge.sourceId) contains a not existent nodeId \(nodeId)!")
}
}
}
Expand Down
50 changes: 50 additions & 0 deletions Tests/LangGraphTests/LangGraphTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -361,4 +361,54 @@ final class LangGraphTests: XCTestCase {
XCTAssertEqual( ["agent_1", "agent_2", "agent_3"], nodesInvolved)
}

func testWithStreamAnCancellation() async throws {

let workflow = GraphState { AgentStateWithAppender() }

try workflow.addNode("agent_1") { state in
try await Task.sleep(nanoseconds: 500_000_000)
return ["messages": "message1"]
}
try workflow.addNode("agent_2") { state in
try await Task.sleep(nanoseconds: 500_000_000)
return ["messages": ["message2", "message3"] ]
}
try workflow.addNode("agent_3") { state in
try await Task.sleep(nanoseconds: 500_000_000)
return ["result": state.messages?.count ?? 0]
}

try workflow.addEdge(sourceId: "agent_1", targetId: "agent_2")
try workflow.addEdge(sourceId: "agent_2", targetId: "agent_3")

try workflow.setEntryPoint("agent_1")
workflow.setFinishPoint("agent_3")

let app = try workflow.compile()

let task = Task {

return try await app.stream(inputs: [:] ).reduce([] as [String]) { partialResult, output in

print( "-------------")
print( "Agent Output of \(output.node)" )
print( output.state )
print( "-------------")

return partialResult + [output.node ]
}

}

Task {
try await Task.sleep(nanoseconds: 1_150_000_000) // Sleep for 1/2 second
task.cancel()
print("Cancellation requested")
}

let nodesInvolved = try await task.value

XCTAssertEqual( ["agent_1", "agent_2" ], nodesInvolved)
}

}

0 comments on commit c3d6e5a

Please sign in to comment.