diff --git a/Sources/LangGraph/LangGraph.swift b/Sources/LangGraph/LangGraph.swift index 0912022..9d0e32f 100644 --- a/Sources/LangGraph/LangGraph.swift +++ b/Sources/LangGraph/LangGraph.swift @@ -238,12 +238,15 @@ public class GraphState { if( verbose ) { log.debug("start processing node \(currentNodeId)") } + + try Task.checkCancellation() let partialState = try await action( currentState ) - + currentState = mergeState( currentState: currentState, partialState: partialState) let output = NodeOutput(node: currentNodeId,state: currentState) + try Task.checkCancellation() continuation.yield( output ) if( currentNodeId == finishPoint ) { @@ -252,7 +255,7 @@ public class GraphState { currentNodeId = try await nextNodeId(nodeId: currentNodeId, agentState: currentState) - } while( currentNodeId != END ) + } while( currentNodeId != END && !Task.isCancelled ) continuation.finish() } @@ -264,13 +267,21 @@ public class GraphState { return stream } + + /// run the graph an return the final State + /// + /// - Parameters: + /// - inputs: partial state + /// - verbose: enable verbose output (log) + /// - Returns: final State public func invoke( inputs: PartialAgentState, verbose:Bool = false ) async throws -> State { - var result:State? - for try await output in stream(inputs: inputs) { - result = output.state - } - return result! + let initResult:[NodeOutput] = [] + let result = try await stream(inputs: inputs).reduce( initResult, { partialResult, output in + [output] + }) + + return result[0].state } } @@ -388,19 +399,19 @@ public class GraphState { for edge in edges { guard nodes.contains( makeFakeNode(edge.sourceId) ) else { - throw GraphStateError.missingNodeReferencedByEdge( "edge sourceId: \(edge.sourceId) reference a not existent node!") + throw GraphStateError.missingNodeReferencedByEdge( "edge sourceId: \(edge.sourceId) reference to non existent node!") } switch( edge.target ) { case .id( let targetId ): guard targetId==END || nodes.contains(makeFakeNode(targetId) ) else { - throw GraphStateError.missingNodeReferencedByEdge( "edge sourceId: \(edge.sourceId) reference a not existent targetId: \(targetId) node!") + throw GraphStateError.missingNodeReferencedByEdge( "edge sourceId: \(edge.sourceId) reference to non existent node targetId: \(targetId) node!") } break case .condition((_, let edgeMappings)): for (_,nodeId) in edgeMappings { guard nodeId==END || nodes.contains(makeFakeNode(nodeId) ) else { - throw GraphStateError.missingNodeInEdgeMapping( "edge mapping for sourceId: \(edge.sourceId) contains a not existen nodeId \(nodeId)!") + throw GraphStateError.missingNodeInEdgeMapping( "edge mapping for sourceId: \(edge.sourceId) contains a not existent nodeId \(nodeId)!") } } } diff --git a/Tests/LangGraphTests/LangGraphTests.swift b/Tests/LangGraphTests/LangGraphTests.swift index 9535fc0..f6b7bec 100644 --- a/Tests/LangGraphTests/LangGraphTests.swift +++ b/Tests/LangGraphTests/LangGraphTests.swift @@ -361,4 +361,54 @@ final class LangGraphTests: XCTestCase { XCTAssertEqual( ["agent_1", "agent_2", "agent_3"], nodesInvolved) } + func testWithStreamAnCancellation() async throws { + + let workflow = GraphState { AgentStateWithAppender() } + + try workflow.addNode("agent_1") { state in + try await Task.sleep(nanoseconds: 500_000_000) + return ["messages": "message1"] + } + try workflow.addNode("agent_2") { state in + try await Task.sleep(nanoseconds: 500_000_000) + return ["messages": ["message2", "message3"] ] + } + try workflow.addNode("agent_3") { state in + try await Task.sleep(nanoseconds: 500_000_000) + return ["result": state.messages?.count ?? 0] + } + + try workflow.addEdge(sourceId: "agent_1", targetId: "agent_2") + try workflow.addEdge(sourceId: "agent_2", targetId: "agent_3") + + try workflow.setEntryPoint("agent_1") + workflow.setFinishPoint("agent_3") + + let app = try workflow.compile() + + let task = Task { + + return try await app.stream(inputs: [:] ).reduce([] as [String]) { partialResult, output in + + print( "-------------") + print( "Agent Output of \(output.node)" ) + print( output.state ) + print( "-------------") + + return partialResult + [output.node ] + } + + } + + Task { + try await Task.sleep(nanoseconds: 1_150_000_000) // Sleep for 1/2 second + task.cancel() + print("Cancellation requested") + } + + let nodesInvolved = try await task.value + + XCTAssertEqual( ["agent_1", "agent_2" ], nodesInvolved) + } + }