-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
9 changed files
with
329 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
package org.bsc.langgraph4j; | ||
|
||
import java.util.Optional; | ||
|
||
public interface AgentState { | ||
|
||
java.util.Map<String,Object> data(); | ||
|
||
default <T> Optional<T> getValue(String key) { | ||
return Optional.ofNullable((T) data().get(key)); | ||
}; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
package org.bsc.langgraph4j; | ||
|
||
|
||
import java.util.Map; | ||
import java.util.function.Function; | ||
|
||
public interface AgentStateFactory<State extends AgentState> extends Function<Map<String,Object>, State> { | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
package org.bsc.langgraph4j; | ||
|
||
import java.util.List; | ||
|
||
public record AppendableValue<T>(List<T> values ) { | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
package org.bsc.langgraph4j; | ||
|
||
import java.util.concurrent.CompletableFuture; | ||
|
||
@FunctionalInterface | ||
public interface EdgeAction<S extends AgentState> { | ||
|
||
String apply(S t) throws Exception; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
package org.bsc.langgraph4j; | ||
|
||
import java.util.HashSet; | ||
import java.util.Map; | ||
import java.util.Objects; | ||
import java.util.Set; | ||
import java.util.concurrent.CompletableFuture; | ||
|
||
import static java.lang.String.format; | ||
|
||
enum GraphStateError { | ||
invalidNodeIdentifier( "END is not a valid node id!"), | ||
invalidEdgeIdentifier( "END is not a valid edge sourceId!"), | ||
duplicateNodeError("node with id: %s already exist!"), | ||
duplicateEdgeError("edge with id: %s already exist!"), | ||
edgeMappingIsEmpty( "edge mapping is empty!" ), | ||
missingEntryPoint( "missing Entry Point" ), | ||
entryPointNotExist("entryPoint: %s doesn't exist!" ), | ||
finishPointNotExist( "finishPoint: %s doesn't exist!"), | ||
missingNodeReferencedByEdge( "edge sourceId: %s reference a not existent node!"), | ||
missingNodeInEdgeMapping( "edge mapping for sourceId: %s contains a not existent nodeId %s!"), | ||
invalidEdgeTarget( "edge sourceId: %s has an initialized target value!" ) | ||
; | ||
private final String errorMessage; | ||
|
||
GraphStateError(String errorMessage ) { | ||
this.errorMessage = errorMessage; | ||
} | ||
|
||
GraphStateException exception(String... args ) { | ||
return new GraphStateException( format(errorMessage, (Object[]) args) ); | ||
} | ||
} | ||
|
||
record EdgeCondition<S extends AgentState>(EdgeAction<S> action, Map<String,String> mappings) {} | ||
record EdgeValue<State extends AgentState>(String id, EdgeCondition<State> value) {} | ||
public class GraphState<State extends AgentState> { | ||
|
||
public class Runnable { | ||
|
||
} | ||
public static String END = "__END__"; | ||
|
||
record Node<State extends AgentState>(String id, NodeAsyncAction<State> action) { | ||
@Override | ||
public boolean equals(Object o) { | ||
if (this == o) return true; | ||
if (o == null || getClass() != o.getClass()) return false; | ||
Node<?> node = (Node<?>) o; | ||
return Objects.equals(id, node.id); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(id); | ||
} | ||
} | ||
record Edge<State extends AgentState>(String sourceId, EdgeValue<State> target) { | ||
@Override | ||
public boolean equals(Object o) { | ||
if (this == o) return true; | ||
if (o == null || getClass() != o.getClass()) return false; | ||
Edge<?> node = (Edge<?>) o; | ||
return Objects.equals(sourceId, node.sourceId); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(sourceId); | ||
} | ||
|
||
} | ||
|
||
Set<Node<State>> nodes = new HashSet<>(); | ||
Set<Edge<State>> edges = new HashSet<>(); | ||
|
||
String entryPoint; | ||
String finishPoint; | ||
|
||
AgentStateFactory<State> stateFactory; | ||
|
||
public GraphState( AgentStateFactory<State> stateFactory ) { | ||
this.stateFactory = stateFactory; | ||
} | ||
|
||
public void setEntryPoint(String entryPoint) { | ||
this.entryPoint = entryPoint; | ||
} | ||
|
||
public void setFinishPoint(String finishPoint) { | ||
this.finishPoint = finishPoint; | ||
} | ||
|
||
public void addNode(String id, NodeAsyncAction<State> action) throws GraphStateException { | ||
if( Objects.equals( id, END)) { | ||
throw GraphStateError.invalidNodeIdentifier.exception(END); | ||
} | ||
var node = new Node<State>(id, action); | ||
|
||
if( nodes.contains(node ) ) { | ||
throw GraphStateError.duplicateNodeError.exception(id); | ||
} | ||
|
||
nodes.add( node ); | ||
} | ||
|
||
public void addEdge(String sourceId, String targetId) throws GraphStateException { | ||
if( Objects.equals( sourceId, END)) { | ||
throw GraphStateError.invalidEdgeIdentifier.exception(END); | ||
} | ||
var edge = new Edge<State>(sourceId, new EdgeValue<>(targetId, null) ); | ||
|
||
if( edges.contains(edge ) ) { | ||
throw GraphStateError.duplicateEdgeError.exception(sourceId); | ||
} | ||
|
||
edges.add( edge ); | ||
} | ||
|
||
public void addConditionalEdge( String sourceId, EdgeAction<State> condition, Map<String,String> mappings ) throws GraphStateException { | ||
if( Objects.equals( sourceId, END)) { | ||
throw GraphStateError.invalidEdgeIdentifier.exception(END); | ||
} | ||
if( mappings == null || mappings.isEmpty() ) { | ||
throw GraphStateError.edgeMappingIsEmpty.exception(sourceId); | ||
} | ||
var edge = new Edge<State>(sourceId, new EdgeValue<>(null, new EdgeCondition<>(condition, mappings)) ); | ||
|
||
if( edges.contains(edge ) ) { | ||
throw GraphStateError.duplicateEdgeError.exception(sourceId); | ||
} | ||
|
||
edges.add( edge ); | ||
} | ||
|
||
private Node<State> makeFakeNode(String id) { | ||
return new Node<>(id, null); | ||
} | ||
|
||
public Runnable compile() throws GraphStateException { | ||
if( entryPoint == null ) { | ||
throw GraphStateError.missingEntryPoint.exception(); | ||
} | ||
|
||
if( !nodes.contains( makeFakeNode(entryPoint) ) ) { | ||
throw GraphStateError.entryPointNotExist.exception(entryPoint); | ||
} | ||
|
||
if( finishPoint!= null ) { | ||
if( !nodes.contains( makeFakeNode(entryPoint) ) ) { | ||
throw GraphStateError.finishPointNotExist.exception(entryPoint); | ||
} | ||
} | ||
|
||
for ( Edge<State> edge: edges ) { | ||
|
||
if( !nodes.contains( makeFakeNode(edge.sourceId) ) ) { | ||
throw GraphStateError.missingNodeReferencedByEdge.exception(edge.sourceId); | ||
} | ||
|
||
if( edge.target.id() != null ) { | ||
if(!Objects.equals(edge.target.id(), END) && !nodes.contains( makeFakeNode(edge.target.id()) ) ) { | ||
throw GraphStateError.missingNodeReferencedByEdge.exception(edge.target.id()); | ||
} | ||
} | ||
else if( edge.target.value() != null ) { | ||
for ( String nodeId: edge.target.value().mappings().values() ) { | ||
if(!Objects.equals(nodeId, END) && !nodes.contains( makeFakeNode(nodeId) ) ) { | ||
throw GraphStateError.missingNodeInEdgeMapping.exception(edge.sourceId, nodeId); | ||
} | ||
} | ||
} | ||
else { | ||
throw GraphStateError.invalidEdgeTarget.exception(edge.sourceId); | ||
} | ||
|
||
|
||
} | ||
|
||
return new Runnable(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
package org.bsc.langgraph4j; | ||
|
||
public class GraphStateException extends Exception { | ||
|
||
public GraphStateException( String errorMessage ) { | ||
super(errorMessage); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
package org.bsc.langgraph4j; | ||
|
||
import java.util.Map; | ||
import java.util.concurrent.CompletableFuture; | ||
|
||
record NodeAsyncActionAdapter<State extends AgentState>(NodeAction<State> action) implements NodeAsyncAction<State> { | ||
@Override public CompletableFuture<Map<String, Object>> apply(State t) throws Exception { | ||
return CompletableFuture.completedFuture(action.apply(t)); | ||
} | ||
} | ||
|
||
public interface NodeAction <T extends AgentState> { | ||
Map<String, Object> apply(T t) throws Exception; | ||
|
||
static <T extends AgentState> NodeAsyncAction<T> async( NodeAction<T> action ) { | ||
return new NodeAsyncActionAdapter<>( action ); | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
package org.bsc.langgraph4j; | ||
|
||
import java.util.Map; | ||
import java.util.concurrent.CompletableFuture; | ||
|
||
|
||
@FunctionalInterface | ||
public interface NodeAsyncAction<T extends AgentState> { | ||
CompletableFuture<Map<String, Object>> apply(T t) throws Exception; | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
package org.bsc.langgraph4j; | ||
|
||
import org.junit.jupiter.api.Test; | ||
|
||
import java.util.Map; | ||
|
||
import static java.util.concurrent.CompletableFuture.completedFuture; | ||
import static org.bsc.langgraph4j.GraphState.END; | ||
import static org.bsc.langgraph4j.NodeAction.async; | ||
import static org.junit.jupiter.api.Assertions.*; | ||
|
||
/** | ||
* Unit test for simple App. | ||
*/ | ||
public class LangGraphTest | ||
{ | ||
record BaseAgentState( Map<String,Object> data ) implements AgentState {} | ||
|
||
@Test | ||
void testValidation() throws Exception { | ||
|
||
var workflow = new GraphState<BaseAgentState>(BaseAgentState::new); | ||
var exception = assertThrows(GraphStateException.class, workflow::compile); | ||
System.out.println(exception.getMessage()); | ||
assertEquals( "missing Entry Point", exception.getMessage()); | ||
|
||
workflow.setEntryPoint("agent_1"); | ||
|
||
exception = assertThrows(GraphStateException.class, workflow::compile); | ||
System.out.println(exception.getMessage()); | ||
assertEquals( "entryPoint: agent_1 doesn't exist!", exception.getMessage()); | ||
|
||
workflow.addNode("agent_1", async(( state ) -> { | ||
System.out.print("agent_1 "); | ||
System.out.println(state); | ||
return Map.of("prop1", "test"); | ||
}) ) ; | ||
|
||
assertNotNull(workflow.compile()); | ||
|
||
workflow.addEdge( "agent_1", END); | ||
|
||
assertNotNull(workflow.compile()); | ||
|
||
exception = assertThrows(GraphStateException.class, () -> | ||
workflow.addEdge(END, "agent_1") ); | ||
System.out.println(exception.getMessage()); | ||
|
||
exception = assertThrows(GraphStateException.class, () -> | ||
workflow.addEdge("agent_1", "agent_2") ); | ||
System.out.println(exception.getMessage()); | ||
|
||
workflow.addNode("agent_2", ( state ) -> { | ||
|
||
System.out.print( "agent_2: "); | ||
System.out.println( state ); | ||
|
||
return completedFuture(Map.of("prop2", "test")); | ||
}); | ||
|
||
workflow.addEdge("agent_2", "agent_3"); | ||
|
||
exception = assertThrows(GraphStateException.class, workflow::compile); | ||
System.out.println(exception.getMessage()); | ||
|
||
exception = assertThrows(GraphStateException.class, () -> | ||
workflow.addConditionalEdge("agent_1", ( state ) -> "agent_3" , Map.of() ) | ||
); | ||
System.out.println(exception.getMessage()); | ||
|
||
} | ||
|
||
} |