Skip to content

Commit

Permalink
Merge branch 'release/3.0.2'
Browse files Browse the repository at this point in the history
  • Loading branch information
bsorrentino committed Aug 10, 2024
2 parents 14856ff + 286315d commit 1566eef
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 24 deletions.
4 changes: 2 additions & 2 deletions LangChainDemo/LangChainDemo/AgentExecutor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,8 @@ public func runAgent( input: String, llm: LLM, tools: [BaseTool], callbacks: [Ba
return [ "intermediate_steps" : (action, result) ]
}

try workflow.setEntryPoint("call_start")
workflow.setFinishPoint("call_end")
try workflow.addEdge(sourceId: START, targetId: "call_start")
try workflow.addEdge(sourceId: "call_end", targetId: END)

try workflow.addEdge(sourceId: "call_start", targetId: "call_agent")
try workflow.addEdge(sourceId: "call_action", targetId: "call_agent")
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ In the [LangChainDemo](LangChainDemo) project, you can find the porting of [Agen
return [ "intermediate_steps" : (action, result) ]
}

try workflow.setEntryPoint("call_agent")
try workflow.addEdge(sourceId: START, targetId: "call_agent")

try workflow.addConditionalEdge( sourceId: "call_agent", condition: { state in

guard let agentOutcome = state.agentOutcome else {
Expand Down
31 changes: 22 additions & 9 deletions Sources/LangGraph/LangGraph.swift
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ public enum CompiledGraphError : Error, LocalizedError {
}
}

public let START = "__START__" // id of the edge staring workflow
public let END = "__END__" // id of the edge ending workflow

//enum Either<Left, Right> {
Expand Down Expand Up @@ -419,7 +420,7 @@ public class StateGraph<State: AgentState> {

}

public func addNode( _ id: String, action: @escaping NodeAction<State> ) throws {
public func addNode( _ id: String, action: @escaping NodeAction<State> ) throws {
guard id != END else {
throw StateGraphError.invalidNodeIdentifier( "END is not a valid node id!")
}
Expand All @@ -434,6 +435,13 @@ public class StateGraph<State: AgentState> {
guard sourceId != END else {
throw StateGraphError.invalidEdgeIdentifier( "END is not a valid edge sourceId!")
}
guard sourceId != START else {
if targetId == END {
throw StateGraphError.invalidNodeIdentifier( "END is not a valid node entry point!")
}
entryPoint = EdgeValue.id(targetId)
return
}

let edge = Edge(sourceId: sourceId, target: .id(targetId) )
if edges.contains(edge) {
Expand All @@ -448,25 +456,30 @@ public class StateGraph<State: AgentState> {
if edgeMapping.isEmpty {
throw StateGraphError.edgeMappingIsEmpty
}
guard sourceId != START else {
entryPoint = EdgeValue.condition((condition, edgeMapping))
return
}

let edge = Edge(sourceId: sourceId, target: .condition(( condition, edgeMapping)) )
if edges.contains(edge) {
throw StateGraphError.duplicateEdgeError("edge with id:\(sourceId) already exist!")
}
edges.insert( edge)
return
}

@available(*, deprecated, message: "This method is deprecated. Use `addEdge( START, nodeId )` instead.")
public func setEntryPoint( _ nodeId: String ) throws {
guard nodeId != END else {
throw StateGraphError.invalidNodeIdentifier( "END is not a valid node entry point!")
}
entryPoint = EdgeValue.id(nodeId)
let _ = try addEdge( sourceId: START, targetId: nodeId )
}

@available(*, deprecated, message: "This method is deprecated. Use `addConditionalEdge( START, condition, edgeMappings )` instead.")
public func setConditionalEntryPoint( condition: @escaping EdgeCondition<State>, edgeMapping: [String:String] ) throws {
if edgeMapping.isEmpty {
throw StateGraphError.edgeMappingIsEmpty
}
entryPoint = EdgeValue.condition((condition, edgeMapping))
let _ = try self.addConditionalEdge(sourceId: START, condition: condition, edgeMapping: edgeMapping )
}

@available(*, deprecated, message: "This method is deprecated. Use `addEdge( nodeId, END )` instead.")
public func setFinishPoint( _ nodeId: String ) {
finishPoint = nodeId
}
Expand Down
22 changes: 11 additions & 11 deletions Tests/LangGraphTests/LangGraphTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ final class LangGraphTests: XCTestCase {
XCTAssertTrue(error is StateGraphError, "\(error) is not a GraphStateError")
}

try workflow.setEntryPoint("agent_1")
try workflow.addEdge(sourceId: START, targetId: "agent_1")

XCTAssertThrowsError( try workflow.compile() ) {error in
print( error )
Expand Down Expand Up @@ -121,7 +121,7 @@ final class LangGraphTests: XCTestCase {
func testRunningOneNode() async throws {

let workflow = StateGraph { BaseAgentState($0) }
try workflow.setEntryPoint("agent_1")
try workflow.addEdge( sourceId: START, targetId: "agent_1")
try workflow.addNode("agent_1") { state in

print( "agent_1", state )
Expand Down Expand Up @@ -188,8 +188,8 @@ final class LangGraphTests: XCTestCase {
try workflow.addEdge(sourceId: "agent_1", targetId: "agent_2")
try workflow.addEdge(sourceId: "agent_2", targetId: "sum")

try workflow.setEntryPoint("agent_1")
workflow.setFinishPoint("sum")
try workflow.addEdge( sourceId: START, targetId: "agent_1")
try workflow.addEdge(sourceId: "sum", targetId: END )

let app = try workflow.compile()

Expand Down Expand Up @@ -255,7 +255,7 @@ final class LangGraphTests: XCTestCase {
try workflow.addEdge(sourceId: "sum", targetId: END)
try workflow.addEdge(sourceId: "mul", targetId: END)

try workflow.setEntryPoint("agent_1")
try workflow.addEdge(sourceId: START, targetId: "agent_1")

let app = try workflow.compile()

Expand Down Expand Up @@ -306,8 +306,8 @@ final class LangGraphTests: XCTestCase {
try workflow.addEdge(sourceId: "agent_1", targetId: "agent_2")
try workflow.addEdge(sourceId: "agent_2", targetId: "agent_3")

try workflow.setEntryPoint("agent_1")
workflow.setFinishPoint("agent_3")
try workflow.addEdge(sourceId: START, targetId: "agent_1")
try workflow.addEdge(sourceId: "agent_3", targetId: END)

let app = try workflow.compile()

Expand Down Expand Up @@ -335,8 +335,8 @@ final class LangGraphTests: XCTestCase {
try workflow.addEdge(sourceId: "agent_1", targetId: "agent_2")
try workflow.addEdge(sourceId: "agent_2", targetId: "agent_3")

try workflow.setEntryPoint("agent_1")
workflow.setFinishPoint("agent_3")
try workflow.addEdge(sourceId: START, targetId: "agent_1")
try workflow.addEdge(sourceId: "agent_3", targetId: END)

let app = try workflow.compile()

Expand Down Expand Up @@ -374,8 +374,8 @@ final class LangGraphTests: XCTestCase {
try workflow.addEdge(sourceId: "agent_1", targetId: "agent_2")
try workflow.addEdge(sourceId: "agent_2", targetId: "agent_3")

try workflow.setEntryPoint("agent_1")
workflow.setFinishPoint("agent_3")
try workflow.addEdge(sourceId: START, targetId: "agent_1")
try workflow.addEdge(sourceId: "agent_3", targetId: END)

let app = try workflow.compile()

Expand Down

0 comments on commit 1566eef

Please sign in to comment.