diff --git a/Sources/LangGraph/LangGraph.swift b/Sources/LangGraph/LangGraph.swift index 417e3aa..ab9f7e2 100644 --- a/Sources/LangGraph/LangGraph.swift +++ b/Sources/LangGraph/LangGraph.swift @@ -169,14 +169,14 @@ public let END = "__END__" // id of the edge ending workflow let log = Logger( subsystem: Bundle.module.bundleIdentifier ?? "langgraph", category: "main") -public class GraphState { +public class StateGraph { enum EdgeValue /* Union */ { case id(String) case condition( ( EdgeCondition, [String:String] ) ) } - public class Runner { + public class CompiledGraph { var stateFactory: () -> State var nodes:Dictionary> @@ -184,7 +184,7 @@ public class GraphState { var entryPoint:String var finishPoint:String? - init( owner: GraphState ) { + init( owner: StateGraph ) { self.stateFactory = owner.stateFactory self.nodes = Dictionary() @@ -315,7 +315,7 @@ public class GraphState { var id: String { sourceId } - static func == (lhs: GraphState.Edge, rhs: GraphState.Edge) -> Bool { + static func == (lhs: StateGraph.Edge, rhs: StateGraph.Edge) -> Bool { lhs.id == rhs.id } @@ -331,7 +331,7 @@ public class GraphState { private var edges: Set = [] struct Node : Hashable, Identifiable { - static func == (lhs: GraphState.Node, rhs: GraphState.Node) -> Bool { + static func == (lhs: StateGraph.Node, rhs: StateGraph.Node) -> Bool { lhs.id == rhs.id } @@ -407,7 +407,7 @@ public class GraphState { Node(id: id, action: fakeAction) } - public func compile() throws -> Runner { + public func compile() throws -> CompiledGraph { guard let entryPoint else { throw GraphStateError.missingEntryPoint } @@ -443,6 +443,6 @@ public class GraphState { } } - return Runner( owner: self ) + return CompiledGraph( owner: self ) } } diff --git a/Tests/LangGraphTests/LangGraphTests.swift b/Tests/LangGraphTests/LangGraphTests.swift index 4005c91..ec48681 100644 --- a/Tests/LangGraphTests/LangGraphTests.swift +++ b/Tests/LangGraphTests/LangGraphTests.swift @@ -48,7 +48,7 @@ final class LangGraphTests: XCTestCase { } func testValidation() async throws { - let workflow = GraphState { BaseAgentState() } + let workflow = StateGraph { BaseAgentState() } XCTAssertThrowsError( try workflow.compile() ) {error in print( error ) @@ -128,7 +128,7 @@ final class LangGraphTests: XCTestCase { func testRunningOneNode() async throws { - let workflow = GraphState { BaseAgentState() } + let workflow = StateGraph { BaseAgentState() } try workflow.setEntryPoint("agent_1") try workflow.addNode("agent_1") { state in @@ -171,7 +171,7 @@ final class LangGraphTests: XCTestCase { func testRunningTreeNodes() async throws { - let workflow = GraphState { BinaryOpState() } + let workflow = StateGraph { BinaryOpState() } try workflow.addNode("agent_1") { state in @@ -209,7 +209,7 @@ final class LangGraphTests: XCTestCase { func testRunningFourNodesWithCondition() async throws { - let workflow = GraphState { BinaryOpState() } + let workflow = StateGraph { BinaryOpState() } try workflow.addNode("agent_1") { state in @@ -293,7 +293,7 @@ final class LangGraphTests: XCTestCase { func testAppender() async throws { - let workflow = GraphState { AgentStateWithAppender() } + let workflow = StateGraph { AgentStateWithAppender() } try workflow.addNode("agent_1") { state in @@ -327,7 +327,7 @@ final class LangGraphTests: XCTestCase { func testWithStream() async throws { - let workflow = GraphState { AgentStateWithAppender() } + let workflow = StateGraph { AgentStateWithAppender() } try workflow.addNode("agent_1") { state in ["messages": "message1"] @@ -363,7 +363,7 @@ final class LangGraphTests: XCTestCase { func testWithStreamAnCancellation() async throws { - let workflow = GraphState { AgentStateWithAppender() } + let workflow = StateGraph { AgentStateWithAppender() } try workflow.addNode("agent_1") { state in try await Task.sleep(nanoseconds: 500_000_000)