Skip to content

Commit

Permalink
feat: enable fluent interface on graph definition
Browse files Browse the repository at this point in the history
deprecate: setEntryPoint, setFinishPoint, setConditionalEntryPoint
  • Loading branch information
bsorrentino committed Aug 7, 2024
1 parent cd50013 commit 787d41c
Show file tree
Hide file tree
Showing 6 changed files with 280 additions and 139 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import lombok.extern.slf4j.Slf4j;
import lombok.var;
import org.bsc.langgraph4j.CompiledGraph;
import org.bsc.langgraph4j.GraphRepresentation;
import org.bsc.langgraph4j.StateGraph;
import org.bsc.langgraph4j.state.AgentState;

Expand All @@ -20,6 +19,7 @@

import static java.util.Collections.emptyList;
import static org.bsc.langgraph4j.StateGraph.END;
import static org.bsc.langgraph4j.StateGraph.START;
import static org.bsc.langgraph4j.action.AsyncEdgeAction.edge_async;
import static org.bsc.langgraph4j.action.AsyncNodeAction.node_async;
import static org.bsc.langgraph4j.utils.CollectionsUtils.listOf;
Expand Down Expand Up @@ -50,8 +50,7 @@ public Optional<String> generation() {

}
public List<String> documents() {
Optional<List<String>> result = value("documents");
return result.orElse(emptyList());
return this.<List<String>>value("documents").orElse(emptyList());
}

}
Expand Down Expand Up @@ -248,43 +247,40 @@ private String gradeGeneration_v_documentsAndQuestion( State state ) {
}

public CompiledGraph<State> buildGraph() throws Exception {
var workflow = new StateGraph<>(State::new);

// Define the nodes
workflow.addNode("web_search", node_async(this::webSearch) ); // web search
workflow.addNode("retrieve", node_async(this::retrieve) ); // retrieve
workflow.addNode("grade_documents", node_async(this::gradeDocuments) ); // grade documents
workflow.addNode("generate", node_async(this::generate) ); // generatae
workflow.addNode("transform_query", node_async(this::transformQuery)); // transform_query

// Build graph
workflow.setConditionalEntryPoint(
edge_async(this::routeQuestion),
mapOf(
"web_search", "web_search",
"vectorstore", "retrieve"
));

workflow.addEdge("web_search", "generate");
workflow.addEdge("retrieve", "grade_documents");
workflow.addConditionalEdges(
"grade_documents",
edge_async(this::decideToGenerate),
mapOf(
"transform_query","transform_query",
"generate", "generate"
));
workflow.addEdge("transform_query", "retrieve");
workflow.addConditionalEdges(
"generate",
edge_async(this::gradeGeneration_v_documentsAndQuestion),
mapOf(
"not supported", "generate",
"useful", END,
"not useful", "transform_query"
));

return workflow.compile();
return new StateGraph<>(State::new)
// Define the nodes
.addNode("web_search", node_async(this::webSearch) ) // web search
.addNode("retrieve", node_async(this::retrieve) ) // retrieve
.addNode("grade_documents", node_async(this::gradeDocuments) ) // grade documents
.addNode("generate", node_async(this::generate) ) // generatae
.addNode("transform_query", node_async(this::transformQuery)) // transform_query
// Build graph
.addConditionalEdges(START,
edge_async(this::routeQuestion),
mapOf(
"web_search", "web_search",
"vectorstore", "retrieve"
))

.addEdge("web_search", "generate")
.addEdge("retrieve", "grade_documents")
.addConditionalEdges(
"grade_documents",
edge_async(this::decideToGenerate),
mapOf(
"transform_query","transform_query",
"generate", "generate"
))
.addEdge("transform_query", "retrieve")
.addConditionalEdges(
"generate",
edge_async(this::gradeGeneration_v_documentsAndQuestion),
mapOf(
"not supported", "generate",
"useful", END,
"not useful", "transform_query"
))
.compile();
}

public static void main( String[] args ) throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,24 @@
import org.bsc.langgraph4j.NodeOutput;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.AppendableValue;
import org.bsc.langgraph4j.state.AppenderChannel;
import org.bsc.langgraph4j.state.Channel;

import java.util.*;
import java.util.stream.Collectors;

import static org.bsc.langgraph4j.StateGraph.END;
import static org.bsc.langgraph4j.StateGraph.START;
import static org.bsc.langgraph4j.action.AsyncEdgeAction.edge_async;
import static org.bsc.langgraph4j.action.AsyncNodeAction.node_async;
import static org.bsc.langgraph4j.utils.CollectionsUtils.mapOf;

public class AgentExecutor {

public static class State extends AgentState {
static Map<String, Channel<?>> SCHEMA = mapOf(
"intermediate_steps", AppenderChannel.<IntermediateStep>of(ArrayList::new)
);

public State(Map<String, Object> initData) {
super(initData);
Expand All @@ -33,8 +39,8 @@ Optional<String> input() {
Optional<AgentOutcome> agentOutcome() {
return value("agent_outcome");
}
AppendableValue<IntermediateStep> intermediateSteps() {
return appendableValue("intermediate_steps");
List<IntermediateStep> intermediateSteps() {
return this.<List<IntermediateStep>>value("intermediate_steps").orElseGet(ArrayList::new);
}


Expand All @@ -45,7 +51,7 @@ Map<String,Object> runAgent( Agent agentRunnable, State state ) throws Exception
var input = state.input()
.orElseThrow(() -> new IllegalArgumentException("no input provided!"));

var intermediateSteps = state.intermediateSteps().values();
var intermediateSteps = state.intermediateSteps();

var response = agentRunnable.execute( input, intermediateSteps );

Expand Down Expand Up @@ -106,27 +112,21 @@ public CompiledGraph<State> compile(ChatLanguageModel chatLanguageModel, List<Ob
.tools( toolSpecifications )
.build();

var workflow = new StateGraph<>(State::new);

workflow.setEntryPoint("agent");

workflow.addNode( "agent", node_async( state ->
runAgent(agentRunnable, state))
);

workflow.addNode( "action", node_async( state ->
executeTools(toolInfoList, state))
);

workflow.addConditionalEdges(
"agent",
edge_async(this::shouldContinue),
mapOf("continue", "action", "end", END)
);

workflow.addEdge("action", "agent");

return workflow.compile();
return new StateGraph<>(State.SCHEMA,State::new)
.addEdge(START,"agent")
.addNode( "agent", node_async( state ->
runAgent(agentRunnable, state))
)
.addNode( "action", node_async( state ->
executeTools(toolInfoList, state))
)
.addConditionalEdges(
"agent",
edge_async(this::shouldContinue),
mapOf("continue", "action", "end", END)
)
.addEdge("action", "agent")
.compile();

}

Expand Down
71 changes: 63 additions & 8 deletions core-jdk8/src/main/java/org/bsc/langgraph4j/StateGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
import org.bsc.langgraph4j.action.AsyncNodeAction;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.AgentStateFactory;
import org.bsc.langgraph4j.state.Channel;

import java.util.*;

import static java.lang.String.format;
import static java.util.Collections.unmodifiableMap;
import static org.bsc.langgraph4j.utils.CollectionsUtils.mapOf;

/**
* Represents a state graph with nodes and edges.
Expand Down Expand Up @@ -76,6 +79,7 @@ GraphRunnerException exception(String... args) {
}

public static String END = "__END__";
public static String START = "__START__";

Set<Node<State>> nodes = new LinkedHashSet<>();
Set<Edge<State>> edges = new LinkedHashSet<>();
Expand All @@ -84,20 +88,36 @@ GraphRunnerException exception(String... args) {
private String finishPoint;

private final AgentStateFactory<State> stateFactory;
private final Map<String, Channel<?>> channels;

/**
* Constructs a new StateGraph with the specified state factory.
*
* @param stateFactory the factory to create agent states
*/
public StateGraph(AgentStateFactory<State> stateFactory) {
this( mapOf(), stateFactory );

}

/**
*
* @param channels the state's schema of the graph
* @param stateFactory the factory to create agent states
*/
public StateGraph(Map<String, Channel<?>> channels, AgentStateFactory<State> stateFactory) {
this.stateFactory = stateFactory;
this.channels = channels;
}

public AgentStateFactory<State> getStateFactory() {
return stateFactory;
}

public Map<String, Channel<?>> getChannels() {
return unmodifiableMap(channels);
}

public EdgeValue<State> getEntryPoint() {
return entryPoint;
}
Expand All @@ -106,17 +126,37 @@ public String getFinishPoint() {
return finishPoint;
}

/**
* Sets the entry point of the graph.
*
* @param entryPoint the nodeId of the graph's entry-point
* @deprecated use addEdge(START, nodeId)
*/
@Deprecated
public void setEntryPoint(String entryPoint) {
this.entryPoint = new EdgeValue<>(entryPoint, null);
}
public void setConditionalEntryPoint(AsyncEdgeAction<State> condition, Map<String, String> mappings) throws GraphStateException {
if (mappings == null || mappings.isEmpty()) {
throw Errors.edgeMappingIsEmpty.exception("entry point");
}
this.entryPoint = new EdgeValue<>(null, new EdgeCondition<>(condition, mappings));

/**
* Sets a conditional entry point of the graph.
*
* @param condition the edge condition
* @param mappings the edge mappings
* @throws GraphStateException if the edge mappings is null or empty
* @deprecated use addConditionalEdge(START, consition, mappings)
*/
@Deprecated
public void setConditionalEntryPoint(AsyncEdgeAction<State> condition, Map<String, String> mappings) throws GraphStateException {
addConditionalEdges(START, condition, mappings);
}

/**
* Sets the identifier of the node that represents the end of the graph execution.
*
* @param finishPoint the identifier of the finish point node
* @deprecated use use addEdge(nodeId, END)
*/
@Deprecated
public void setFinishPoint(String finishPoint) {
this.finishPoint = finishPoint;
}
Expand All @@ -128,7 +168,7 @@ public void setFinishPoint(String finishPoint) {
* @param action the action to be performed by the node
* @throws GraphStateException if the node identifier is invalid or the node already exists
*/
public void addNode(String id, AsyncNodeAction<State> action) throws GraphStateException {
public StateGraph<State> addNode(String id, AsyncNodeAction<State> action) throws GraphStateException {
if (Objects.equals(id, END)) {
throw Errors.invalidNodeIdentifier.exception(END);
}
Expand All @@ -139,6 +179,7 @@ public void addNode(String id, AsyncNodeAction<State> action) throws GraphStateE
}

nodes.add(node);
return this;
}

/**
Expand All @@ -148,17 +189,24 @@ public void addNode(String id, AsyncNodeAction<State> action) throws GraphStateE
* @param targetId the identifier of the target node
* @throws GraphStateException if the edge identifier is invalid or the edge already exists
*/
public void addEdge(String sourceId, String targetId) throws GraphStateException {
public StateGraph<State> addEdge(String sourceId, String targetId) throws GraphStateException {
if (Objects.equals(sourceId, END)) {
throw Errors.invalidEdgeIdentifier.exception(END);
}

if (Objects.equals(sourceId, START)) {
this.entryPoint = new EdgeValue<>(targetId, null);
return this;
}

var edge = new Edge<State>(sourceId, new EdgeValue<>(targetId, null));

if (edges.contains(edge)) {
throw Errors.duplicateEdgeError.exception(sourceId);
}

edges.add(edge);
return this;
}

/**
Expand All @@ -169,20 +217,27 @@ public void addEdge(String sourceId, String targetId) throws GraphStateException
* @param mappings the mappings of conditions to target nodes
* @throws GraphStateException if the edge identifier is invalid, the mappings are empty, or the edge already exists
*/
public void addConditionalEdges(String sourceId, AsyncEdgeAction<State> condition, Map<String, String> mappings) throws GraphStateException {
public StateGraph<State> addConditionalEdges(String sourceId, AsyncEdgeAction<State> condition, Map<String, String> mappings) throws GraphStateException {
if (Objects.equals(sourceId, END)) {
throw Errors.invalidEdgeIdentifier.exception(END);
}
if (mappings == null || mappings.isEmpty()) {
throw Errors.edgeMappingIsEmpty.exception(sourceId);
}

if (Objects.equals(sourceId, START)) {
this.entryPoint = new EdgeValue<>(null, new EdgeCondition<>(condition, mappings));
return this;
}

var edge = new Edge<State>(sourceId, new EdgeValue<>(null, new EdgeCondition<>(condition, mappings)));

if (edges.contains(edge)) {
throw Errors.duplicateEdgeError.exception(sourceId);
}

edges.add(edge);
return this;
}

/**
Expand Down
Loading

0 comments on commit 787d41c

Please sign in to comment.