From dc46c9b49847c52b6ff48d4414c0b23cfcf70352 Mon Sep 17 00:00:00 2001 From: bsorrentino Date: Sun, 17 Mar 2024 20:48:03 +0100 Subject: [PATCH] feat: initial implementation graph creation graph compilation --- .../java/org/bsc/langgraph4j/AgentState.java | 12 ++ .../bsc/langgraph4j/AgentStateFactory.java | 9 + .../org/bsc/langgraph4j/AppendableValue.java | 6 + .../java/org/bsc/langgraph4j/EdgeAction.java | 9 + .../java/org/bsc/langgraph4j/GraphState.java | 182 ++++++++++++++++++ .../bsc/langgraph4j/GraphStateException.java | 8 + .../java/org/bsc/langgraph4j/NodeAction.java | 19 ++ .../org/bsc/langgraph4j/NodeAsyncAction.java | 11 ++ .../org/bsc/langgraph4j/LangGraphTest.java | 73 +++++++ 9 files changed, 329 insertions(+) create mode 100644 src/main/java/org/bsc/langgraph4j/AgentState.java create mode 100644 src/main/java/org/bsc/langgraph4j/AgentStateFactory.java create mode 100644 src/main/java/org/bsc/langgraph4j/AppendableValue.java create mode 100644 src/main/java/org/bsc/langgraph4j/EdgeAction.java create mode 100644 src/main/java/org/bsc/langgraph4j/GraphState.java create mode 100644 src/main/java/org/bsc/langgraph4j/GraphStateException.java create mode 100644 src/main/java/org/bsc/langgraph4j/NodeAction.java create mode 100644 src/main/java/org/bsc/langgraph4j/NodeAsyncAction.java create mode 100644 src/test/java/org/bsc/langgraph4j/LangGraphTest.java diff --git a/src/main/java/org/bsc/langgraph4j/AgentState.java b/src/main/java/org/bsc/langgraph4j/AgentState.java new file mode 100644 index 0000000..444ae08 --- /dev/null +++ b/src/main/java/org/bsc/langgraph4j/AgentState.java @@ -0,0 +1,12 @@ +package org.bsc.langgraph4j; + +import java.util.Optional; + +public interface AgentState { + + java.util.Map data(); + + default Optional getValue(String key) { + return Optional.ofNullable((T) data().get(key)); + }; +} diff --git a/src/main/java/org/bsc/langgraph4j/AgentStateFactory.java b/src/main/java/org/bsc/langgraph4j/AgentStateFactory.java new file mode 100644 index 0000000..d131be4 --- /dev/null +++ b/src/main/java/org/bsc/langgraph4j/AgentStateFactory.java @@ -0,0 +1,9 @@ +package org.bsc.langgraph4j; + + +import java.util.Map; +import java.util.function.Function; + +public interface AgentStateFactory extends Function, State> { + +} diff --git a/src/main/java/org/bsc/langgraph4j/AppendableValue.java b/src/main/java/org/bsc/langgraph4j/AppendableValue.java new file mode 100644 index 0000000..c7145c9 --- /dev/null +++ b/src/main/java/org/bsc/langgraph4j/AppendableValue.java @@ -0,0 +1,6 @@ +package org.bsc.langgraph4j; + +import java.util.List; + +public record AppendableValue(List values ) { +} diff --git a/src/main/java/org/bsc/langgraph4j/EdgeAction.java b/src/main/java/org/bsc/langgraph4j/EdgeAction.java new file mode 100644 index 0000000..ebd575c --- /dev/null +++ b/src/main/java/org/bsc/langgraph4j/EdgeAction.java @@ -0,0 +1,9 @@ +package org.bsc.langgraph4j; + +import java.util.concurrent.CompletableFuture; + +@FunctionalInterface +public interface EdgeAction { + + String apply(S t) throws Exception; +} diff --git a/src/main/java/org/bsc/langgraph4j/GraphState.java b/src/main/java/org/bsc/langgraph4j/GraphState.java new file mode 100644 index 0000000..4199f4e --- /dev/null +++ b/src/main/java/org/bsc/langgraph4j/GraphState.java @@ -0,0 +1,182 @@ +package org.bsc.langgraph4j; + +import java.util.HashSet; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.CompletableFuture; + +import static java.lang.String.format; + +enum GraphStateError { + invalidNodeIdentifier( "END is not a valid node id!"), + invalidEdgeIdentifier( "END is not a valid edge sourceId!"), + duplicateNodeError("node with id: %s already exist!"), + duplicateEdgeError("edge with id: %s already exist!"), + edgeMappingIsEmpty( "edge mapping is empty!" ), + missingEntryPoint( "missing Entry Point" ), + entryPointNotExist("entryPoint: %s doesn't exist!" ), + finishPointNotExist( "finishPoint: %s doesn't exist!"), + missingNodeReferencedByEdge( "edge sourceId: %s reference a not existent node!"), + missingNodeInEdgeMapping( "edge mapping for sourceId: %s contains a not existent nodeId %s!"), + invalidEdgeTarget( "edge sourceId: %s has an initialized target value!" ) + ; + private final String errorMessage; + + GraphStateError(String errorMessage ) { + this.errorMessage = errorMessage; + } + + GraphStateException exception(String... args ) { + return new GraphStateException( format(errorMessage, (Object[]) args) ); + } +} + +record EdgeCondition(EdgeAction action, Map mappings) {} +record EdgeValue(String id, EdgeCondition value) {} +public class GraphState { + + public class Runnable { + + } + public static String END = "__END__"; + + record Node(String id, NodeAsyncAction action) { + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Node node = (Node) o; + return Objects.equals(id, node.id); + } + + @Override + public int hashCode() { + return Objects.hash(id); + } + } + record Edge(String sourceId, EdgeValue target) { + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Edge node = (Edge) o; + return Objects.equals(sourceId, node.sourceId); + } + + @Override + public int hashCode() { + return Objects.hash(sourceId); + } + + } + + Set> nodes = new HashSet<>(); + Set> edges = new HashSet<>(); + + String entryPoint; + String finishPoint; + + AgentStateFactory stateFactory; + + public GraphState( AgentStateFactory stateFactory ) { + this.stateFactory = stateFactory; + } + + public void setEntryPoint(String entryPoint) { + this.entryPoint = entryPoint; + } + + public void setFinishPoint(String finishPoint) { + this.finishPoint = finishPoint; + } + + public void addNode(String id, NodeAsyncAction action) throws GraphStateException { + if( Objects.equals( id, END)) { + throw GraphStateError.invalidNodeIdentifier.exception(END); + } + var node = new Node(id, action); + + if( nodes.contains(node ) ) { + throw GraphStateError.duplicateNodeError.exception(id); + } + + nodes.add( node ); + } + + public void addEdge(String sourceId, String targetId) throws GraphStateException { + if( Objects.equals( sourceId, END)) { + throw GraphStateError.invalidEdgeIdentifier.exception(END); + } + var edge = new Edge(sourceId, new EdgeValue<>(targetId, null) ); + + if( edges.contains(edge ) ) { + throw GraphStateError.duplicateEdgeError.exception(sourceId); + } + + edges.add( edge ); + } + + public void addConditionalEdge( String sourceId, EdgeAction condition, Map mappings ) throws GraphStateException { + if( Objects.equals( sourceId, END)) { + throw GraphStateError.invalidEdgeIdentifier.exception(END); + } + if( mappings == null || mappings.isEmpty() ) { + throw GraphStateError.edgeMappingIsEmpty.exception(sourceId); + } + var edge = new Edge(sourceId, new EdgeValue<>(null, new EdgeCondition<>(condition, mappings)) ); + + if( edges.contains(edge ) ) { + throw GraphStateError.duplicateEdgeError.exception(sourceId); + } + + edges.add( edge ); + } + + private Node makeFakeNode(String id) { + return new Node<>(id, null); + } + + public Runnable compile() throws GraphStateException { + if( entryPoint == null ) { + throw GraphStateError.missingEntryPoint.exception(); + } + + if( !nodes.contains( makeFakeNode(entryPoint) ) ) { + throw GraphStateError.entryPointNotExist.exception(entryPoint); + } + + if( finishPoint!= null ) { + if( !nodes.contains( makeFakeNode(entryPoint) ) ) { + throw GraphStateError.finishPointNotExist.exception(entryPoint); + } + } + + for ( Edge edge: edges ) { + + if( !nodes.contains( makeFakeNode(edge.sourceId) ) ) { + throw GraphStateError.missingNodeReferencedByEdge.exception(edge.sourceId); + } + + if( edge.target.id() != null ) { + if(!Objects.equals(edge.target.id(), END) && !nodes.contains( makeFakeNode(edge.target.id()) ) ) { + throw GraphStateError.missingNodeReferencedByEdge.exception(edge.target.id()); + } + } + else if( edge.target.value() != null ) { + for ( String nodeId: edge.target.value().mappings().values() ) { + if(!Objects.equals(nodeId, END) && !nodes.contains( makeFakeNode(nodeId) ) ) { + throw GraphStateError.missingNodeInEdgeMapping.exception(edge.sourceId, nodeId); + } + } + } + else { + throw GraphStateError.invalidEdgeTarget.exception(edge.sourceId); + } + + + } + + return new Runnable(); + } +} diff --git a/src/main/java/org/bsc/langgraph4j/GraphStateException.java b/src/main/java/org/bsc/langgraph4j/GraphStateException.java new file mode 100644 index 0000000..02c1077 --- /dev/null +++ b/src/main/java/org/bsc/langgraph4j/GraphStateException.java @@ -0,0 +1,8 @@ +package org.bsc.langgraph4j; + +public class GraphStateException extends Exception { + + public GraphStateException( String errorMessage ) { + super(errorMessage); +} +} diff --git a/src/main/java/org/bsc/langgraph4j/NodeAction.java b/src/main/java/org/bsc/langgraph4j/NodeAction.java new file mode 100644 index 0000000..bef5f51 --- /dev/null +++ b/src/main/java/org/bsc/langgraph4j/NodeAction.java @@ -0,0 +1,19 @@ +package org.bsc.langgraph4j; + +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +record NodeAsyncActionAdapter(NodeAction action) implements NodeAsyncAction { + @Override public CompletableFuture> apply(State t) throws Exception { + return CompletableFuture.completedFuture(action.apply(t)); + } +} + +public interface NodeAction { + Map apply(T t) throws Exception; + + static NodeAsyncAction async( NodeAction action ) { + return new NodeAsyncActionAdapter<>( action ); + } + +} diff --git a/src/main/java/org/bsc/langgraph4j/NodeAsyncAction.java b/src/main/java/org/bsc/langgraph4j/NodeAsyncAction.java new file mode 100644 index 0000000..ba2687e --- /dev/null +++ b/src/main/java/org/bsc/langgraph4j/NodeAsyncAction.java @@ -0,0 +1,11 @@ +package org.bsc.langgraph4j; + +import java.util.Map; +import java.util.concurrent.CompletableFuture; + + +@FunctionalInterface +public interface NodeAsyncAction { + CompletableFuture> apply(T t) throws Exception; + +} diff --git a/src/test/java/org/bsc/langgraph4j/LangGraphTest.java b/src/test/java/org/bsc/langgraph4j/LangGraphTest.java new file mode 100644 index 0000000..e4caaf0 --- /dev/null +++ b/src/test/java/org/bsc/langgraph4j/LangGraphTest.java @@ -0,0 +1,73 @@ +package org.bsc.langgraph4j; + +import org.junit.jupiter.api.Test; + +import java.util.Map; + +import static java.util.concurrent.CompletableFuture.completedFuture; +import static org.bsc.langgraph4j.GraphState.END; +import static org.bsc.langgraph4j.NodeAction.async; +import static org.junit.jupiter.api.Assertions.*; + +/** + * Unit test for simple App. + */ +public class LangGraphTest +{ + record BaseAgentState( Map data ) implements AgentState {} + + @Test + void testValidation() throws Exception { + + var workflow = new GraphState(BaseAgentState::new); + var exception = assertThrows(GraphStateException.class, workflow::compile); + System.out.println(exception.getMessage()); + assertEquals( "missing Entry Point", exception.getMessage()); + + workflow.setEntryPoint("agent_1"); + + exception = assertThrows(GraphStateException.class, workflow::compile); + System.out.println(exception.getMessage()); + assertEquals( "entryPoint: agent_1 doesn't exist!", exception.getMessage()); + + workflow.addNode("agent_1", async(( state ) -> { + System.out.print("agent_1 "); + System.out.println(state); + return Map.of("prop1", "test"); + }) ) ; + + assertNotNull(workflow.compile()); + + workflow.addEdge( "agent_1", END); + + assertNotNull(workflow.compile()); + + exception = assertThrows(GraphStateException.class, () -> + workflow.addEdge(END, "agent_1") ); + System.out.println(exception.getMessage()); + + exception = assertThrows(GraphStateException.class, () -> + workflow.addEdge("agent_1", "agent_2") ); + System.out.println(exception.getMessage()); + + workflow.addNode("agent_2", ( state ) -> { + + System.out.print( "agent_2: "); + System.out.println( state ); + + return completedFuture(Map.of("prop2", "test")); + }); + + workflow.addEdge("agent_2", "agent_3"); + + exception = assertThrows(GraphStateException.class, workflow::compile); + System.out.println(exception.getMessage()); + + exception = assertThrows(GraphStateException.class, () -> + workflow.addConditionalEdge("agent_1", ( state ) -> "agent_3" , Map.of() ) + ); + System.out.println(exception.getMessage()); + + } + +}