diff --git a/core-jdk8/pom.xml b/core-jdk8/pom.xml
index ab237d2..a543f47 100644
--- a/core-jdk8/pom.xml
+++ b/core-jdk8/pom.xml
@@ -27,7 +27,7 @@
org.bsc.async
async-generator-jdk8
- 2.0.0
+ 2.0.1
@@ -41,6 +41,12 @@
test
+
+ org.slf4j
+ slf4j-jdk14
+ test
+
+
diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/CompileConfig.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/CompileConfig.java
new file mode 100644
index 0000000..d6d2f79
--- /dev/null
+++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/CompileConfig.java
@@ -0,0 +1,45 @@
+package org.bsc.langgraph4j;
+
+import org.bsc.langgraph4j.checkpoint.BaseCheckpointSaver;
+
+import java.util.Optional;
+
+
+public class CompileConfig {
+
+ private BaseCheckpointSaver checkpointSaver;
+ private String[] interruptBefore = {};
+ private String[] interruptAfter = {};
+
+ public Optional getCheckpointSaver() { return Optional.ofNullable(checkpointSaver); }
+ public String[] getInterruptBefore() { return interruptBefore; }
+ public String[] getInterruptAfter() { return interruptAfter; }
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ public static class Builder {
+ private final CompileConfig config = new CompileConfig();
+
+ public Builder checkpointSaver(BaseCheckpointSaver checkpointSaver) {
+ this.config.checkpointSaver = checkpointSaver;
+ return this;
+ }
+ public Builder interruptBefore(String... interruptBefore) {
+ this.config.interruptBefore = interruptBefore;
+ return this;
+ }
+ public Builder interruptAfter(String... interruptAfter) {
+ this.config.interruptAfter = interruptAfter;
+ return this;
+ }
+ public CompileConfig build() {
+ return config;
+ }
+ }
+
+
+ private CompileConfig() {}
+
+}
diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java
index 6336fca..6b4c4c9 100644
--- a/core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java
+++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java
@@ -6,6 +6,8 @@
import org.bsc.async.AsyncGenerator;
import org.bsc.async.AsyncGeneratorQueue;
import org.bsc.langgraph4j.action.AsyncNodeAction;
+import org.bsc.langgraph4j.checkpoint.BaseCheckpointSaver;
+import org.bsc.langgraph4j.checkpoint.Checkpoint;
import org.bsc.langgraph4j.state.AgentState;
import java.util.*;
@@ -31,14 +33,16 @@ public class CompiledGraph {
final Map> edges = new LinkedHashMap<>();
private int maxIterations = 25;
+ private final CompileConfig compileConfig;
/**
* Constructs a CompiledGraph with the given StateGraph.
*
* @param stateGraph the StateGraph to be used in this CompiledGraph
*/
- protected CompiledGraph(StateGraph stateGraph) {
+ protected CompiledGraph(StateGraph stateGraph, CompileConfig compileConfig ) {
this.stateGraph = stateGraph;
+ this.compileConfig = compileConfig;
stateGraph.nodes.forEach(n ->
nodes.put(n.id(), n.action())
);
@@ -105,21 +109,46 @@ private String getEntryPoint( State state ) throws Exception {
return nextNodeId(stateGraph.getEntryPoint(), state, "entryPoint");
}
+ private void addCheckpoint( String nodeId, State state ) throws Exception {
+ if( compileConfig.getCheckpointSaver().isPresent() ) {
+ Checkpoint.Value value = Checkpoint.Value.of(state, nodeId);
+ compileConfig.getCheckpointSaver().get().put( new Checkpoint(value) );
+ }
+ }
+
+ State getInitialState(Map inputs) {
+
+ return compileConfig.getCheckpointSaver()
+ .flatMap(BaseCheckpointSaver::getLast)
+ .map( cp -> {
+ var state = cp.getValue().getState();
+ return state.mergeWith(inputs, stateGraph.getStateFactory());
+ })
+ .orElseGet( () ->
+ stateGraph.getStateFactory().apply(inputs)
+ );
+ }
+
/**
* Creates an AsyncGenerator stream of NodeOutput based on the provided inputs.
*
* @param inputs the input map
+ * @param config the invoke configuration
* @return an AsyncGenerator stream of NodeOutput
* @throws Exception if there is an error creating the stream
*/
- public AsyncGenerator> stream(Map inputs ) throws Exception {
+ public AsyncGenerator> stream(Map inputs, InvokeConfig config ) throws Exception {
+ Objects.requireNonNull(config, "config cannot be null");
return AsyncGeneratorQueue.of(new LinkedBlockingQueue<>(), queue -> {
try {
- var currentState = stateGraph.getStateFactory().apply(inputs);
+
+ var currentState = getInitialState(inputs);
queue.add( AsyncGenerator.Data.of( completedFuture( NodeOutput.of("start", currentState)) ));
+ addCheckpoint( "start", currentState );
+
log.trace( "START");
var currentNodeId = this.getEntryPoint( currentState );
@@ -142,6 +171,7 @@ public AsyncGenerator> stream(Map inputs ) thro
currentState = currentState.mergeWith(partialState, stateGraph.getStateFactory());
queue.add( AsyncGenerator.Data.of( completedFuture( NodeOutput.of(currentNodeId, currentState) ) ));
+ addCheckpoint( currentNodeId, currentState );
if ( Objects.equals(currentNodeId, stateGraph.getFinishPoint()) ) {
break;
@@ -161,6 +191,7 @@ public AsyncGenerator> stream(Map inputs ) thro
}
queue.add( AsyncGenerator.Data.of( completedFuture( NodeOutput.of("stop", currentState) ) ));
+ addCheckpoint( "stop", currentState );
log.trace( "STOP");
} catch (Exception e) {
@@ -171,16 +202,27 @@ public AsyncGenerator> stream(Map inputs ) thro
}
+ /**
+ * Creates an AsyncGenerator stream of NodeOutput based on the provided inputs.
+ *
+ * @param inputs the input map
+ * @return an AsyncGenerator stream of NodeOutput
+ * @throws Exception if there is an error creating the stream
+ */
+ public AsyncGenerator> stream(Map inputs ) throws Exception {
+ return this.stream( inputs, InvokeConfig.builder().build() );
+ }
/**
* Invokes the graph execution with the provided inputs and returns the final state.
*
* @param inputs the input map
+ * @param config the invoke configuration
* @return an Optional containing the final state if present, otherwise an empty Optional
* @throws Exception if there is an error during invocation
*/
- public Optional invoke(Map inputs ) throws Exception {
+ public Optional invoke(Map inputs, InvokeConfig config ) throws Exception {
- var sourceIterator = stream(inputs).iterator();
+ var sourceIterator = stream(inputs, config).iterator();
var result = StreamSupport.stream(
Spliterators.spliteratorUnknownSize(sourceIterator, Spliterator.ORDERED),
@@ -189,6 +231,17 @@ public Optional invoke(Map inputs ) throws Exception {
return result.reduce((a, b) -> b).map( NodeOutput::state);
}
+ /**
+ * Invokes the graph execution with the provided inputs and returns the final state.
+ *
+ * @param inputs the input map
+ * @return an Optional containing the final state if present, otherwise an empty Optional
+ * @throws Exception if there is an error during invocation
+ */
+ public Optional invoke(Map inputs ) throws Exception {
+ return this.invoke( inputs, InvokeConfig.builder().build() );
+ }
+
/**
* Generates a drawable graph representation of the state graph.
*
diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/EdgeValue.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/EdgeValue.java
index 2677ba2..21ef364 100644
--- a/core-jdk8/src/main/java/org/bsc/langgraph4j/EdgeValue.java
+++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/EdgeValue.java
@@ -7,7 +7,7 @@
@Value
@Accessors(fluent = true)
-class EdgeValue {
+public class EdgeValue {
/**
* The unique identifier for the edge value.
diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/InvokeConfig.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/InvokeConfig.java
new file mode 100644
index 0000000..3fad8d9
--- /dev/null
+++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/InvokeConfig.java
@@ -0,0 +1,39 @@
+package org.bsc.langgraph4j;
+
+import org.bsc.langgraph4j.checkpoint.CheckpointConfig;
+
+import java.util.Optional;
+
+public class InvokeConfig {
+
+ private CheckpointConfig checkpointConfig;
+
+ public Optional getCheckpointConfig() {
+ return Optional.ofNullable(checkpointConfig);
+ }
+
+ static Builder builder() {
+ return new Builder();
+ }
+
+ public static class Builder {
+
+ private String checkpointThreadId;
+
+ public Builder checkpointThreadId(String threadId) {
+ this.checkpointThreadId = threadId;
+ return this;
+ }
+ public InvokeConfig build() {
+ InvokeConfig result = new InvokeConfig();
+
+ if( checkpointThreadId != null ) {
+ result.checkpointConfig = CheckpointConfig.of(checkpointThreadId);
+ }
+
+ return result;
+ }
+ }
+
+ private InvokeConfig() {}
+}
diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/StateGraph.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/StateGraph.java
index 362528b..c9fbbb2 100644
--- a/core-jdk8/src/main/java/org/bsc/langgraph4j/StateGraph.java
+++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/StateGraph.java
@@ -198,10 +198,13 @@ private Node makeFakeNode(String id) {
/**
* Compiles the state graph into a compiled graph.
*
+ * @param config the compile configuration
* @return a compiled graph
* @throws GraphStateException if there are errors related to the graph state
*/
- public CompiledGraph compile() throws GraphStateException {
+ public CompiledGraph compile( CompileConfig config ) throws GraphStateException {
+ Objects.requireNonNull(config, "config cannot be null");
+
if (entryPoint == null) {
throw Errors.missingEntryPoint.exception();
}
@@ -237,6 +240,17 @@ public CompiledGraph compile() throws GraphStateException {
}
}
- return new CompiledGraph<>(this);
+ return new CompiledGraph<>(this, config);
+ }
+
+ /**
+ * Compiles the state graph into a compiled graph.
+ *
+ * @return a compiled graph
+ * @throws GraphStateException if there are errors related to the graph state
+ */
+ public CompiledGraph compile() throws GraphStateException {
+ return compile(CompileConfig.builder().build());
}
+
}
diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/BaseCheckpointSaver.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/BaseCheckpointSaver.java
new file mode 100644
index 0000000..85652a7
--- /dev/null
+++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/BaseCheckpointSaver.java
@@ -0,0 +1,13 @@
+package org.bsc.langgraph4j.checkpoint;
+
+import java.io.Externalizable;
+import java.util.Collection;
+import java.util.Optional;
+
+public interface BaseCheckpointSaver {
+
+
+ Collection list();
+ Optional getLast();
+ void put( Checkpoint checkpoint ) throws Exception;
+}
diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/Checkpoint.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/Checkpoint.java
new file mode 100644
index 0000000..fd3a6e8
--- /dev/null
+++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/Checkpoint.java
@@ -0,0 +1,52 @@
+package org.bsc.langgraph4j.checkpoint;
+
+import lombok.Data;
+import org.bsc.langgraph4j.state.AgentState;
+
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import java.util.*;
+
+
+/**
+ * Represents a checkpoint of an agent state.
+ *
+ * The checkpoint is an immutable object that holds an {@link AgentState}
+ * and a {@code String} that represents the next state.
+ *
+ * The checkpoint is serializable and can be persisted and restored.
+ *
+ * @see AgentState
+ * @see Externalizable
+ */
+public class Checkpoint {
+
+ @lombok.Value(staticConstructor="of")
+ public static class Value {
+ AgentState state;
+ String nodeId;
+ }
+
+ String id;
+ Value value;
+
+ public final String getId() {
+ return id;
+ }
+ public final Value getValue() {
+ return value;
+ }
+
+ public Checkpoint( Value value ) {
+ this(UUID.randomUUID().toString(), value );
+ }
+ public Checkpoint(String id, Value value) {
+ Objects.requireNonNull(id, "id cannot be null");
+ Objects.requireNonNull(value, "value cannot be null");
+ this.id = id;
+ this.value = value;
+ }
+
+}
diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/CheckpointConfig.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/CheckpointConfig.java
new file mode 100644
index 0000000..77de6e7
--- /dev/null
+++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/CheckpointConfig.java
@@ -0,0 +1,8 @@
+package org.bsc.langgraph4j.checkpoint;
+
+import lombok.Value;
+
+@Value(staticConstructor = "of")
+public class CheckpointConfig {
+ String threadId;
+}
diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/MemorySaver.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/MemorySaver.java
new file mode 100644
index 0000000..eea17cf
--- /dev/null
+++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/MemorySaver.java
@@ -0,0 +1,39 @@
+package org.bsc.langgraph4j.checkpoint;
+
+import org.bsc.langgraph4j.serializer.CheckpointSerializer;
+
+import java.util.*;
+
+import static java.util.Collections.unmodifiableCollection;
+import static java.util.Collections.unmodifiableSet;
+
+public class MemorySaver implements BaseCheckpointSaver {
+
+ private final Stack checkpoints = new Stack<>();
+
+
+ @Override
+ public Collection list() {
+ return unmodifiableCollection(checkpoints); // immutable checkpoints;
+ }
+
+ @Override
+ public Optional getLast() {
+ if( checkpoints.isEmpty() ) {
+ return Optional.empty();
+ }
+ return Optional.ofNullable( checkpoints.peek() );
+ }
+
+ public Optional get(String id) {
+ return checkpoints.stream()
+ .filter( checkpoint -> checkpoint.getId().equals(id) )
+ .findFirst();
+ }
+
+ @Override
+ public void put(Checkpoint checkpoint) throws Exception {
+ checkpoints.add( CheckpointSerializer.INSTANCE.cloneObject(checkpoint) );
+ }
+
+}
diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/AgentStateSerializer.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/AgentStateSerializer.java
new file mode 100644
index 0000000..08c1414
--- /dev/null
+++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/AgentStateSerializer.java
@@ -0,0 +1,74 @@
+package org.bsc.langgraph4j.serializer;
+
+import lombok.extern.log4j.Log4j;
+import lombok.extern.slf4j.Slf4j;
+import org.bsc.langgraph4j.state.AgentState;
+
+import java.io.*;
+import java.util.HashMap;
+import java.util.Map;
+
+@Slf4j
+public class AgentStateSerializer implements Serializer {
+ public static final AgentStateSerializer INSTANCE = new AgentStateSerializer();
+ private AgentStateSerializer() {}
+
+ @Override
+ public void write(AgentState object, ObjectOutput out) throws IOException {
+ try( ByteArrayOutputStream baos = new ByteArrayOutputStream() ) {
+ int actuoalSize = 0;
+
+ final ObjectOutputStream tupleStream = new ObjectOutputStream( baos );
+ for( Map.Entry e : object.data().entrySet() ) {
+ try {
+ tupleStream.writeUTF(e.getKey());
+ tupleStream.writeObject(e.getValue());
+ ++actuoalSize;
+ } catch (IOException ex) {
+ log.error( "Error writing state key '{}' - {}", e.getKey(), ex.getMessage() );
+ throw ex;
+ }
+ }
+
+ out.writeInt( object.data().size() );
+ out.writeInt( actuoalSize ); // actual size
+ byte[] data = baos.toByteArray();
+ out.writeInt( data.length );
+ out.write( data );
+
+ }
+
+ }
+
+ @Override
+ public AgentState read(ObjectInput in) throws IOException, ClassNotFoundException {
+ Map data = new HashMap<>();
+
+ int expectedSize = in.readInt();
+ int actualSize = in.readInt();
+ if( expectedSize > 0 && actualSize > 0 ) {
+
+ if( expectedSize != actualSize ) {
+ final String message = String.format( "Deserialize State: Expected size %d and actual size %d do not match!", expectedSize, actualSize ) ;
+ log.error( message ) ;
+ throw new IOException( message ) ;
+ }
+
+ int byteLen = in.readInt();
+ byte[] bytes = new byte[byteLen];
+ in.readFully(bytes);
+
+ try( ByteArrayInputStream bais = new ByteArrayInputStream( bytes ) ) {
+ ObjectInputStream ois = new ObjectInputStream( bais );
+
+ for( int i = 0; i < actualSize; i++ ) {
+ String key = ois.readUTF();
+ Object value = ois.readObject();
+ data.put(key, value);
+ }
+ }
+
+ }
+ return new AgentState(data);
+ }
+}
diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/CheckpointSerializer.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/CheckpointSerializer.java
new file mode 100644
index 0000000..ba2a725
--- /dev/null
+++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/CheckpointSerializer.java
@@ -0,0 +1,31 @@
+package org.bsc.langgraph4j.serializer;
+
+import org.bsc.langgraph4j.checkpoint.Checkpoint;
+import org.bsc.langgraph4j.state.AgentState;
+
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+
+public class CheckpointSerializer implements Serializer {
+
+ public static final CheckpointSerializer INSTANCE = new CheckpointSerializer();
+
+ private CheckpointSerializer() {}
+
+ public void write( Checkpoint object, ObjectOutput out) throws IOException {
+ out.writeUTF(object.getId());
+ Checkpoint.Value value = object.getValue();
+ AgentStateSerializer.INSTANCE.write( value.getState(), out );
+ out.writeUTF( value.getNodeId() );
+ }
+
+ public Checkpoint read(ObjectInput in) throws IOException, ClassNotFoundException {
+ String id = in.readUTF();
+ AgentState state = AgentStateSerializer.INSTANCE.read( in );
+ String nodeId = in.readUTF();
+ Checkpoint.Value value = Checkpoint.Value.of( state, nodeId );
+ return new Checkpoint(id, value);
+ }
+
+}
diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/Serializer.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/Serializer.java
new file mode 100644
index 0000000..15404f6
--- /dev/null
+++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/Serializer.java
@@ -0,0 +1,40 @@
+package org.bsc.langgraph4j.serializer;
+
+import org.bsc.langgraph4j.checkpoint.Checkpoint;
+
+import java.io.*;
+import java.util.Objects;
+
+public interface Serializer {
+
+ void write(T object, ObjectOutput out) throws IOException;
+
+
+ T read(ObjectInput in) throws IOException, ClassNotFoundException ;
+
+ default byte[] writeObject(T object) throws IOException {
+ Objects.requireNonNull( object, "object cannot be null" );
+ try( ByteArrayOutputStream baos = new ByteArrayOutputStream() ) {
+ ObjectOutputStream oas = new ObjectOutputStream(baos);
+ write(object, oas);
+ oas.flush();
+ return baos.toByteArray();
+ }
+ }
+
+ default T readObject(byte[] bytes) throws IOException, ClassNotFoundException {
+ Objects.requireNonNull( bytes, "bytes cannot be null" );
+ if( bytes.length == 0 ) {
+ throw new IllegalArgumentException("bytes cannot be empty");
+ }
+ try( ByteArrayInputStream bais = new ByteArrayInputStream( bytes ) ) {
+ ObjectInputStream ois = new ObjectInputStream(bais);
+ return read(ois);
+ }
+ }
+
+ default T cloneObject(T object) throws IOException, ClassNotFoundException {
+ Objects.requireNonNull( object, "object cannot be null" );
+ return readObject(writeObject(object));
+ }
+}
diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/state/AppendableValueRW.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/state/AppendableValueRW.java
index ed8e965..aa277b3 100644
--- a/core-jdk8/src/main/java/org/bsc/langgraph4j/state/AppendableValueRW.java
+++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/state/AppendableValueRW.java
@@ -1,5 +1,9 @@
package org.bsc.langgraph4j.state;
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
import java.util.*;
import static java.util.Collections.unmodifiableList;
@@ -10,8 +14,8 @@
*
* @param the type of the value
*/
-public class AppendableValueRW implements AppendableValue {
- private final List values;
+public class AppendableValueRW implements AppendableValue, Externalizable {
+ private List values;
/**
* Constructs an AppendableValueRW with the given initial collection of values.
@@ -100,4 +104,14 @@ public Optional lastMinus(int n) {
public String toString() {
return String.valueOf(values);
}
+
+ @Override
+ public void writeExternal(ObjectOutput out) throws IOException {
+ out.writeObject(values);
+ }
+
+ @Override
+ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+ values = (List) in.readObject();
+ }
}
diff --git a/core-jdk8/src/test/java/org/bsc/langgraph4j/SerializeTest.java b/core-jdk8/src/test/java/org/bsc/langgraph4j/SerializeTest.java
new file mode 100644
index 0000000..4721afe
--- /dev/null
+++ b/core-jdk8/src/test/java/org/bsc/langgraph4j/SerializeTest.java
@@ -0,0 +1,72 @@
+package org.bsc.langgraph4j;
+
+import org.bsc.langgraph4j.serializer.AgentStateSerializer;
+import org.bsc.langgraph4j.state.AgentState;
+import org.junit.jupiter.api.Test;
+
+import java.io.*;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+public class SerializeTest {
+
+
+ private byte[] serializeState(AgentState state) throws Exception {
+ try( ByteArrayOutputStream baos = new ByteArrayOutputStream() ) {
+ ObjectOutputStream oas = new ObjectOutputStream(baos);
+ AgentStateSerializer.INSTANCE.write(state, oas);
+ oas.flush();
+ return baos.toByteArray();
+ }
+ }
+ private AgentState deserializeState( byte[] bytes ) throws Exception {
+ try(ByteArrayInputStream bais = new ByteArrayInputStream( bytes ) ) {
+ ObjectInputStream ois = new ObjectInputStream( bais );
+ return AgentStateSerializer.INSTANCE.read( ois );
+ }
+ }
+
+ @Test
+ public void serializeStateTest() throws Exception {
+
+ Map data = new HashMap<>();
+ data.put("a", "b");
+ data.put("f", null);
+ data.put("c", "d");
+
+ final AgentState state = new AgentState(data);
+
+ byte[] bytes = serializeState(state);
+
+ assertNotNull(bytes);
+ AgentState deserializeState = deserializeState( bytes );
+
+ assertEquals( 3, deserializeState.data().size() );
+ assertEquals( "b", deserializeState.data().get("a") );
+ assertEquals( "d", deserializeState.data().get("c") );
+ }
+
+ static class NonSerializableElement {
+ String value = "TEST";
+ public NonSerializableElement() {
+ }
+ }
+ @Test
+ public void partiallySerializeStateTest() throws Exception {
+
+ Map data = new HashMap<>();
+ data.put("a", "b");
+ data.put("f", new NonSerializableElement() );
+ data.put("c", "d");
+
+ final AgentState state = new AgentState(data);
+
+ assertThrows(IOException.class, () -> {
+ serializeState(state);
+ });
+
+ }
+
+}
diff --git a/core-jdk8/src/test/java/org/bsc/langgraph4j/StateGraphTest.java b/core-jdk8/src/test/java/org/bsc/langgraph4j/StateGraphTest.java
index 55406f5..d3cb7e7 100644
--- a/core-jdk8/src/test/java/org/bsc/langgraph4j/StateGraphTest.java
+++ b/core-jdk8/src/test/java/org/bsc/langgraph4j/StateGraphTest.java
@@ -1,13 +1,17 @@
package org.bsc.langgraph4j;
import lombok.var;
+import org.bsc.langgraph4j.checkpoint.Checkpoint;
+import org.bsc.langgraph4j.checkpoint.MemorySaver;
import org.bsc.langgraph4j.state.AgentState;
+import org.bsc.langgraph4j.state.AppendableValue;
import org.junit.jupiter.api.Test;
-import java.util.List;
-import java.util.Map;
+import java.util.*;
import java.util.stream.Collectors;
+import static java.lang.String.format;
+import static java.util.Collections.emptyMap;
import static org.bsc.langgraph4j.StateGraph.END;
import static org.bsc.langgraph4j.action.AsyncEdgeAction.edge_async;
import static org.bsc.langgraph4j.action.AsyncNodeAction.node_async;
@@ -104,5 +108,142 @@ public void testRunningOneNode() throws Exception {
}
+ @Test
+ public void testCheckpointInitialState() throws Exception {
+
+ var workflow = new StateGraph<>(AgentState::new);
+ workflow.setEntryPoint("agent_1");
+
+ workflow.addNode("agent_1", node_async( state -> {
+ System.out.print( "agent_1");
+ return mapOf("agent_1:prop1", "agent_1:test");
+ }));
+
+ workflow.addEdge( "agent_1", END);
+
+ var saver = new MemorySaver();
+
+ var compileConfig = CompileConfig.builder().checkpointSaver(saver).build();
+
+ var app = workflow.compile( compileConfig );
+
+ Map inputs = mapOf( "input", "test1");
+
+ var initState = app.getInitialState( inputs );
+
+ assertEquals( 1, initState.data().size() );
+ assertTrue( initState.value("input").isPresent() );
+ assertEquals( "test1", initState.value("input").get() );
+
+ //
+ // Test checkpoint not override inputs
+ //
+ var newState = new AgentState( mapOf( "input", "test2") );
+ saver.put( new Checkpoint( Checkpoint.Value.of( newState, "start" ) ) );
+
+ app = workflow.compile( compileConfig );
+ initState = app.getInitialState( inputs );
+
+ assertEquals( 1, initState.data().size() );
+ assertTrue( initState.value("input").isPresent() );
+ assertEquals( "test1", initState.value("input").get() );
+
+ // Test checkpoints are saved
+ newState = new AgentState( mapOf( "input", "test2", "agent_1:prop1", "agent_1:test") );
+ saver.put( new Checkpoint( Checkpoint.Value.of( newState, "agent_1" ) ) );
+
+ app = workflow.compile( compileConfig );
+ initState = app.getInitialState( inputs );
+
+ assertEquals( 2, initState.data().size() );
+ assertTrue( initState.value("input").isPresent() );
+ assertEquals( "test1", initState.value("input").get() );
+ assertTrue( initState.value("agent_1:prop1").isPresent() );
+ assertEquals( "agent_1:test", initState.value("agent_1:prop1").get() );
+
+ var checkpoints = saver.list();
+ assertEquals( 2, checkpoints.size() );
+ var last = saver.getLast();
+ assertTrue( last.isPresent() );
+ assertEquals( "agent_1", last.get().getValue().getNodeId() );
+ assertTrue( last.get().getValue().getState().value("agent_1:prop1").isPresent() );
+ assertEquals( "agent_1:test", last.get().getValue().getState().value("agent_1:prop1").get() );
+
+ }
+
+ static class MessagesState extends AgentState {
+
+ public MessagesState(Map initData) {
+ super( initData );
+ appendableValue("messages"); // tip: initialize messages
+ }
+
+ int steps() {
+ return value("steps").map(Integer.class::cast).orElse(0);
+ }
+
+ AppendableValue messages() {
+ return appendableValue("messages");
+ }
+
+ }
+
+ @Test
+ public void testCheckpointSaver() throws Exception {
+ var STEPS_COUNT = 5;
+
+ var workflow = new StateGraph<>(MessagesState::new);
+ workflow.setEntryPoint("agent_1");
+
+ workflow.addNode("agent_1", node_async( state -> {
+
+ System.out.println( "agent_1");
+ var steps = state.steps() + 1;
+ return mapOf("steps", steps, "messages", format( "agent_1:step %d", steps ));
+ }));
+ workflow.addConditionalEdges( "agent_1", edge_async( state -> {
+ var steps = state.steps();
+ if( steps >= STEPS_COUNT) {
+ return "exit";
+ }
+ return "next";
+ }), mapOf( "next", "agent_1", "exit", END) );
+
+ var saver = new MemorySaver();
+
+ var compileConfig = CompileConfig.builder()
+ .checkpointSaver(saver)
+ .build();
+
+ var app = workflow.compile( compileConfig );
+
+ Map inputs = mapOf( "steps", 0 );
+
+ var invokeConfig = InvokeConfig.builder().checkpointThreadId("thread_1").build();
+
+ var state = app.invoke( inputs, invokeConfig );
+
+ assertTrue( state.isPresent() );
+ assertEquals( STEPS_COUNT, state.get().steps() );
+ var messages = state.get().appendableValue("messages");
+ assertFalse( messages.isEmpty() );
+
+ System.out.println( messages.values() );
+
+ assertEquals( STEPS_COUNT, messages.size() );
+ for( int i = 0; i < messages.size(); i++ ) {
+ assertEquals( format("agent_1:step %d", i+1), messages.values().get(i) );
+ }
+
+ state = app.invoke( emptyMap(), invokeConfig );
+
+ assertTrue( state.isPresent() );
+ assertEquals( STEPS_COUNT + 1, state.get().steps() );
+ messages = state.get().appendableValue("messages");
+ assertEquals( STEPS_COUNT + 1, messages.size() );
+ for( int i = 0; i < messages.size(); i++ ) {
+ assertEquals( format("agent_1:step %d", i+1), messages.values().get(i) );
+ }
+ }
}