Skip to content

Commit

Permalink
Merge branch 'feature/#11_checkpointer' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
bsorrentino committed Aug 7, 2024
2 parents 9ed434a + 77e4723 commit 663daff
Show file tree
Hide file tree
Showing 16 changed files with 654 additions and 13 deletions.
8 changes: 7 additions & 1 deletion core-jdk8/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
<dependency>
<groupId>org.bsc.async</groupId>
<artifactId>async-generator-jdk8</artifactId>
<version>2.0.0</version>
<version>2.0.1</version>
</dependency>

<dependency>
Expand All @@ -41,6 +41,12 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-jdk14</artifactId>
<scope>test</scope>
</dependency>

</dependencies>

<build>
Expand Down
45 changes: 45 additions & 0 deletions core-jdk8/src/main/java/org/bsc/langgraph4j/CompileConfig.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package org.bsc.langgraph4j;

import org.bsc.langgraph4j.checkpoint.BaseCheckpointSaver;

import java.util.Optional;


public class CompileConfig {

private BaseCheckpointSaver checkpointSaver;
private String[] interruptBefore = {};
private String[] interruptAfter = {};

public Optional<BaseCheckpointSaver> getCheckpointSaver() { return Optional.ofNullable(checkpointSaver); }
public String[] getInterruptBefore() { return interruptBefore; }
public String[] getInterruptAfter() { return interruptAfter; }

public static Builder builder() {
return new Builder();
}

public static class Builder {
private final CompileConfig config = new CompileConfig();

public Builder checkpointSaver(BaseCheckpointSaver checkpointSaver) {
this.config.checkpointSaver = checkpointSaver;
return this;
}
public Builder interruptBefore(String... interruptBefore) {
this.config.interruptBefore = interruptBefore;
return this;
}
public Builder interruptAfter(String... interruptAfter) {
this.config.interruptAfter = interruptAfter;
return this;
}
public CompileConfig build() {
return config;
}
}


private CompileConfig() {}

}
63 changes: 58 additions & 5 deletions core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import org.bsc.async.AsyncGenerator;
import org.bsc.async.AsyncGeneratorQueue;
import org.bsc.langgraph4j.action.AsyncNodeAction;
import org.bsc.langgraph4j.checkpoint.BaseCheckpointSaver;
import org.bsc.langgraph4j.checkpoint.Checkpoint;
import org.bsc.langgraph4j.state.AgentState;

import java.util.*;
Expand All @@ -31,14 +33,16 @@ public class CompiledGraph<State extends AgentState> {
final Map<String, EdgeValue<State>> edges = new LinkedHashMap<>();

private int maxIterations = 25;
private final CompileConfig compileConfig;

/**
* Constructs a CompiledGraph with the given StateGraph.
*
* @param stateGraph the StateGraph to be used in this CompiledGraph
*/
protected CompiledGraph(StateGraph<State> stateGraph) {
protected CompiledGraph(StateGraph<State> stateGraph, CompileConfig compileConfig ) {
this.stateGraph = stateGraph;
this.compileConfig = compileConfig;
stateGraph.nodes.forEach(n ->
nodes.put(n.id(), n.action())
);
Expand Down Expand Up @@ -105,21 +109,46 @@ private String getEntryPoint( State state ) throws Exception {
return nextNodeId(stateGraph.getEntryPoint(), state, "entryPoint");
}

private void addCheckpoint( String nodeId, State state ) throws Exception {
if( compileConfig.getCheckpointSaver().isPresent() ) {
Checkpoint.Value value = Checkpoint.Value.of(state, nodeId);
compileConfig.getCheckpointSaver().get().put( new Checkpoint(value) );
}
}

State getInitialState(Map<String,Object> inputs) {

return compileConfig.getCheckpointSaver()
.flatMap(BaseCheckpointSaver::getLast)
.map( cp -> {
var state = cp.getValue().getState();
return state.mergeWith(inputs, stateGraph.getStateFactory());
})
.orElseGet( () ->
stateGraph.getStateFactory().apply(inputs)
);
}

/**
* Creates an AsyncGenerator stream of NodeOutput based on the provided inputs.
*
* @param inputs the input map
* @param config the invoke configuration
* @return an AsyncGenerator stream of NodeOutput
* @throws Exception if there is an error creating the stream
*/
public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs ) throws Exception {
public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs, InvokeConfig config ) throws Exception {
Objects.requireNonNull(config, "config cannot be null");

return AsyncGeneratorQueue.of(new LinkedBlockingQueue<>(), queue -> {

try {
var currentState = stateGraph.getStateFactory().apply(inputs);

var currentState = getInitialState(inputs);

queue.add( AsyncGenerator.Data.of( completedFuture( NodeOutput.of("start", currentState)) ));
addCheckpoint( "start", currentState );

log.trace( "START");

var currentNodeId = this.getEntryPoint( currentState );
Expand All @@ -142,6 +171,7 @@ public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs ) thro
currentState = currentState.mergeWith(partialState, stateGraph.getStateFactory());

queue.add( AsyncGenerator.Data.of( completedFuture( NodeOutput.of(currentNodeId, currentState) ) ));
addCheckpoint( currentNodeId, currentState );

if ( Objects.equals(currentNodeId, stateGraph.getFinishPoint()) ) {
break;
Expand All @@ -161,6 +191,7 @@ public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs ) thro
}

queue.add( AsyncGenerator.Data.of( completedFuture( NodeOutput.of("stop", currentState) ) ));
addCheckpoint( "stop", currentState );
log.trace( "STOP");

} catch (Exception e) {
Expand All @@ -171,16 +202,27 @@ public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs ) thro

}

/**
* Creates an AsyncGenerator stream of NodeOutput based on the provided inputs.
*
* @param inputs the input map
* @return an AsyncGenerator stream of NodeOutput
* @throws Exception if there is an error creating the stream
*/
public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs ) throws Exception {
return this.stream( inputs, InvokeConfig.builder().build() );
}
/**
* Invokes the graph execution with the provided inputs and returns the final state.
*
* @param inputs the input map
* @param config the invoke configuration
* @return an Optional containing the final state if present, otherwise an empty Optional
* @throws Exception if there is an error during invocation
*/
public Optional<State> invoke(Map<String,Object> inputs ) throws Exception {
public Optional<State> invoke(Map<String,Object> inputs, InvokeConfig config ) throws Exception {

var sourceIterator = stream(inputs).iterator();
var sourceIterator = stream(inputs, config).iterator();

var result = StreamSupport.stream(
Spliterators.spliteratorUnknownSize(sourceIterator, Spliterator.ORDERED),
Expand All @@ -189,6 +231,17 @@ public Optional<State> invoke(Map<String,Object> inputs ) throws Exception {
return result.reduce((a, b) -> b).map( NodeOutput::state);
}

/**
* Invokes the graph execution with the provided inputs and returns the final state.
*
* @param inputs the input map
* @return an Optional containing the final state if present, otherwise an empty Optional
* @throws Exception if there is an error during invocation
*/
public Optional<State> invoke(Map<String,Object> inputs ) throws Exception {
return this.invoke( inputs, InvokeConfig.builder().build() );
}

/**
* Generates a drawable graph representation of the state graph.
*
Expand Down
2 changes: 1 addition & 1 deletion core-jdk8/src/main/java/org/bsc/langgraph4j/EdgeValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

@Value
@Accessors(fluent = true)
class EdgeValue<State extends AgentState> {
public class EdgeValue<State extends AgentState> {

/**
* The unique identifier for the edge value.
Expand Down
39 changes: 39 additions & 0 deletions core-jdk8/src/main/java/org/bsc/langgraph4j/InvokeConfig.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package org.bsc.langgraph4j;

import org.bsc.langgraph4j.checkpoint.CheckpointConfig;

import java.util.Optional;

public class InvokeConfig {

private CheckpointConfig checkpointConfig;

public Optional<CheckpointConfig> getCheckpointConfig() {
return Optional.ofNullable(checkpointConfig);
}

static Builder builder() {
return new Builder();
}

public static class Builder {

private String checkpointThreadId;

public Builder checkpointThreadId(String threadId) {
this.checkpointThreadId = threadId;
return this;
}
public InvokeConfig build() {
InvokeConfig result = new InvokeConfig();

if( checkpointThreadId != null ) {
result.checkpointConfig = CheckpointConfig.of(checkpointThreadId);
}

return result;
}
}

private InvokeConfig() {}
}
18 changes: 16 additions & 2 deletions core-jdk8/src/main/java/org/bsc/langgraph4j/StateGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,13 @@ private Node<State> makeFakeNode(String id) {
/**
* Compiles the state graph into a compiled graph.
*
* @param config the compile configuration
* @return a compiled graph
* @throws GraphStateException if there are errors related to the graph state
*/
public CompiledGraph<State> compile() throws GraphStateException {
public CompiledGraph<State> compile( CompileConfig config ) throws GraphStateException {
Objects.requireNonNull(config, "config cannot be null");

if (entryPoint == null) {
throw Errors.missingEntryPoint.exception();
}
Expand Down Expand Up @@ -237,6 +240,17 @@ public CompiledGraph<State> compile() throws GraphStateException {
}
}

return new CompiledGraph<>(this);
return new CompiledGraph<>(this, config);
}

/**
* Compiles the state graph into a compiled graph.
*
* @return a compiled graph
* @throws GraphStateException if there are errors related to the graph state
*/
public CompiledGraph<State> compile() throws GraphStateException {
return compile(CompileConfig.builder().build());
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package org.bsc.langgraph4j.checkpoint;

import java.io.Externalizable;
import java.util.Collection;
import java.util.Optional;

public interface BaseCheckpointSaver {


Collection<Checkpoint> list();
Optional<Checkpoint> getLast();
void put( Checkpoint checkpoint ) throws Exception;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package org.bsc.langgraph4j.checkpoint;

import lombok.Data;
import org.bsc.langgraph4j.state.AgentState;

import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.*;


/**
* Represents a checkpoint of an agent state.
*
* The checkpoint is an immutable object that holds an {@link AgentState}
* and a {@code String} that represents the next state.
*
* The checkpoint is serializable and can be persisted and restored.
*
* @see AgentState
* @see Externalizable
*/
public class Checkpoint {

@lombok.Value(staticConstructor="of")
public static class Value {
AgentState state;
String nodeId;
}

String id;
Value value;

public final String getId() {
return id;
}
public final Value getValue() {
return value;
}

public Checkpoint( Value value ) {
this(UUID.randomUUID().toString(), value );
}
public Checkpoint(String id, Value value) {
Objects.requireNonNull(id, "id cannot be null");
Objects.requireNonNull(value, "value cannot be null");
this.id = id;
this.value = value;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package org.bsc.langgraph4j.checkpoint;

import lombok.Value;

@Value(staticConstructor = "of")
public class CheckpointConfig {
String threadId;
}
Loading

0 comments on commit 663daff

Please sign in to comment.