Skip to content

Commit

Permalink
feat: initial implementation
Browse files Browse the repository at this point in the history
    graph creation
    graph compilation
  • Loading branch information
bsorrentino committed Mar 17, 2024
1 parent dbd8a87 commit dc46c9b
Show file tree
Hide file tree
Showing 9 changed files with 329 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/main/java/org/bsc/langgraph4j/AgentState.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package org.bsc.langgraph4j;

import java.util.Optional;

public interface AgentState {

java.util.Map<String,Object> data();

default <T> Optional<T> getValue(String key) {
return Optional.ofNullable((T) data().get(key));
};
}
9 changes: 9 additions & 0 deletions src/main/java/org/bsc/langgraph4j/AgentStateFactory.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package org.bsc.langgraph4j;


import java.util.Map;
import java.util.function.Function;

public interface AgentStateFactory<State extends AgentState> extends Function<Map<String,Object>, State> {

}
6 changes: 6 additions & 0 deletions src/main/java/org/bsc/langgraph4j/AppendableValue.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package org.bsc.langgraph4j;

import java.util.List;

public record AppendableValue<T>(List<T> values ) {
}
9 changes: 9 additions & 0 deletions src/main/java/org/bsc/langgraph4j/EdgeAction.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package org.bsc.langgraph4j;

import java.util.concurrent.CompletableFuture;

@FunctionalInterface
public interface EdgeAction<S extends AgentState> {

String apply(S t) throws Exception;
}
182 changes: 182 additions & 0 deletions src/main/java/org/bsc/langgraph4j/GraphState.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
package org.bsc.langgraph4j;

import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CompletableFuture;

import static java.lang.String.format;

enum GraphStateError {
invalidNodeIdentifier( "END is not a valid node id!"),
invalidEdgeIdentifier( "END is not a valid edge sourceId!"),
duplicateNodeError("node with id: %s already exist!"),
duplicateEdgeError("edge with id: %s already exist!"),
edgeMappingIsEmpty( "edge mapping is empty!" ),
missingEntryPoint( "missing Entry Point" ),
entryPointNotExist("entryPoint: %s doesn't exist!" ),
finishPointNotExist( "finishPoint: %s doesn't exist!"),
missingNodeReferencedByEdge( "edge sourceId: %s reference a not existent node!"),
missingNodeInEdgeMapping( "edge mapping for sourceId: %s contains a not existent nodeId %s!"),
invalidEdgeTarget( "edge sourceId: %s has an initialized target value!" )
;
private final String errorMessage;

GraphStateError(String errorMessage ) {
this.errorMessage = errorMessage;
}

GraphStateException exception(String... args ) {
return new GraphStateException( format(errorMessage, (Object[]) args) );
}
}

record EdgeCondition<S extends AgentState>(EdgeAction<S> action, Map<String,String> mappings) {}
record EdgeValue<State extends AgentState>(String id, EdgeCondition<State> value) {}
public class GraphState<State extends AgentState> {

public class Runnable {

}
public static String END = "__END__";

record Node<State extends AgentState>(String id, NodeAsyncAction<State> action) {
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Node<?> node = (Node<?>) o;
return Objects.equals(id, node.id);
}

@Override
public int hashCode() {
return Objects.hash(id);
}
}
record Edge<State extends AgentState>(String sourceId, EdgeValue<State> target) {
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Edge<?> node = (Edge<?>) o;
return Objects.equals(sourceId, node.sourceId);
}

@Override
public int hashCode() {
return Objects.hash(sourceId);
}

}

Set<Node<State>> nodes = new HashSet<>();
Set<Edge<State>> edges = new HashSet<>();

String entryPoint;
String finishPoint;

AgentStateFactory<State> stateFactory;

public GraphState( AgentStateFactory<State> stateFactory ) {
this.stateFactory = stateFactory;
}

public void setEntryPoint(String entryPoint) {
this.entryPoint = entryPoint;
}

public void setFinishPoint(String finishPoint) {
this.finishPoint = finishPoint;
}

public void addNode(String id, NodeAsyncAction<State> action) throws GraphStateException {
if( Objects.equals( id, END)) {
throw GraphStateError.invalidNodeIdentifier.exception(END);
}
var node = new Node<State>(id, action);

if( nodes.contains(node ) ) {
throw GraphStateError.duplicateNodeError.exception(id);
}

nodes.add( node );
}

public void addEdge(String sourceId, String targetId) throws GraphStateException {
if( Objects.equals( sourceId, END)) {
throw GraphStateError.invalidEdgeIdentifier.exception(END);
}
var edge = new Edge<State>(sourceId, new EdgeValue<>(targetId, null) );

if( edges.contains(edge ) ) {
throw GraphStateError.duplicateEdgeError.exception(sourceId);
}

edges.add( edge );
}

public void addConditionalEdge( String sourceId, EdgeAction<State> condition, Map<String,String> mappings ) throws GraphStateException {
if( Objects.equals( sourceId, END)) {
throw GraphStateError.invalidEdgeIdentifier.exception(END);
}
if( mappings == null || mappings.isEmpty() ) {
throw GraphStateError.edgeMappingIsEmpty.exception(sourceId);
}
var edge = new Edge<State>(sourceId, new EdgeValue<>(null, new EdgeCondition<>(condition, mappings)) );

if( edges.contains(edge ) ) {
throw GraphStateError.duplicateEdgeError.exception(sourceId);
}

edges.add( edge );
}

private Node<State> makeFakeNode(String id) {
return new Node<>(id, null);
}

public Runnable compile() throws GraphStateException {
if( entryPoint == null ) {
throw GraphStateError.missingEntryPoint.exception();
}

if( !nodes.contains( makeFakeNode(entryPoint) ) ) {
throw GraphStateError.entryPointNotExist.exception(entryPoint);
}

if( finishPoint!= null ) {
if( !nodes.contains( makeFakeNode(entryPoint) ) ) {
throw GraphStateError.finishPointNotExist.exception(entryPoint);
}
}

for ( Edge<State> edge: edges ) {

if( !nodes.contains( makeFakeNode(edge.sourceId) ) ) {
throw GraphStateError.missingNodeReferencedByEdge.exception(edge.sourceId);
}

if( edge.target.id() != null ) {
if(!Objects.equals(edge.target.id(), END) && !nodes.contains( makeFakeNode(edge.target.id()) ) ) {
throw GraphStateError.missingNodeReferencedByEdge.exception(edge.target.id());
}
}
else if( edge.target.value() != null ) {
for ( String nodeId: edge.target.value().mappings().values() ) {
if(!Objects.equals(nodeId, END) && !nodes.contains( makeFakeNode(nodeId) ) ) {
throw GraphStateError.missingNodeInEdgeMapping.exception(edge.sourceId, nodeId);
}
}
}
else {
throw GraphStateError.invalidEdgeTarget.exception(edge.sourceId);
}


}

return new Runnable();
}
}
8 changes: 8 additions & 0 deletions src/main/java/org/bsc/langgraph4j/GraphStateException.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package org.bsc.langgraph4j;

public class GraphStateException extends Exception {

public GraphStateException( String errorMessage ) {
super(errorMessage);
}
}
19 changes: 19 additions & 0 deletions src/main/java/org/bsc/langgraph4j/NodeAction.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package org.bsc.langgraph4j;

import java.util.Map;
import java.util.concurrent.CompletableFuture;

record NodeAsyncActionAdapter<State extends AgentState>(NodeAction<State> action) implements NodeAsyncAction<State> {
@Override public CompletableFuture<Map<String, Object>> apply(State t) throws Exception {
return CompletableFuture.completedFuture(action.apply(t));
}
}

public interface NodeAction <T extends AgentState> {
Map<String, Object> apply(T t) throws Exception;

static <T extends AgentState> NodeAsyncAction<T> async( NodeAction<T> action ) {
return new NodeAsyncActionAdapter<>( action );
}

}
11 changes: 11 additions & 0 deletions src/main/java/org/bsc/langgraph4j/NodeAsyncAction.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package org.bsc.langgraph4j;

import java.util.Map;
import java.util.concurrent.CompletableFuture;


@FunctionalInterface
public interface NodeAsyncAction<T extends AgentState> {
CompletableFuture<Map<String, Object>> apply(T t) throws Exception;

}
73 changes: 73 additions & 0 deletions src/test/java/org/bsc/langgraph4j/LangGraphTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package org.bsc.langgraph4j;

import org.junit.jupiter.api.Test;

import java.util.Map;

import static java.util.concurrent.CompletableFuture.completedFuture;
import static org.bsc.langgraph4j.GraphState.END;
import static org.bsc.langgraph4j.NodeAction.async;
import static org.junit.jupiter.api.Assertions.*;

/**
* Unit test for simple App.
*/
public class LangGraphTest
{
record BaseAgentState( Map<String,Object> data ) implements AgentState {}

@Test
void testValidation() throws Exception {

var workflow = new GraphState<BaseAgentState>(BaseAgentState::new);
var exception = assertThrows(GraphStateException.class, workflow::compile);
System.out.println(exception.getMessage());
assertEquals( "missing Entry Point", exception.getMessage());

workflow.setEntryPoint("agent_1");

exception = assertThrows(GraphStateException.class, workflow::compile);
System.out.println(exception.getMessage());
assertEquals( "entryPoint: agent_1 doesn't exist!", exception.getMessage());

workflow.addNode("agent_1", async(( state ) -> {
System.out.print("agent_1 ");
System.out.println(state);
return Map.of("prop1", "test");
}) ) ;

assertNotNull(workflow.compile());

workflow.addEdge( "agent_1", END);

assertNotNull(workflow.compile());

exception = assertThrows(GraphStateException.class, () ->
workflow.addEdge(END, "agent_1") );
System.out.println(exception.getMessage());

exception = assertThrows(GraphStateException.class, () ->
workflow.addEdge("agent_1", "agent_2") );
System.out.println(exception.getMessage());

workflow.addNode("agent_2", ( state ) -> {

System.out.print( "agent_2: ");
System.out.println( state );

return completedFuture(Map.of("prop2", "test"));
});

workflow.addEdge("agent_2", "agent_3");

exception = assertThrows(GraphStateException.class, workflow::compile);
System.out.println(exception.getMessage());

exception = assertThrows(GraphStateException.class, () ->
workflow.addConditionalEdge("agent_1", ( state ) -> "agent_3" , Map.of() )
);
System.out.println(exception.getMessage());

}

}

0 comments on commit dc46c9b

Please sign in to comment.