diff --git a/core-jdk8/pom.xml b/core-jdk8/pom.xml index ab237d2..a543f47 100644 --- a/core-jdk8/pom.xml +++ b/core-jdk8/pom.xml @@ -27,7 +27,7 @@ org.bsc.async async-generator-jdk8 - 2.0.0 + 2.0.1 @@ -41,6 +41,12 @@ test + + org.slf4j + slf4j-jdk14 + test + + diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/CompileConfig.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/CompileConfig.java new file mode 100644 index 0000000..d6d2f79 --- /dev/null +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/CompileConfig.java @@ -0,0 +1,45 @@ +package org.bsc.langgraph4j; + +import org.bsc.langgraph4j.checkpoint.BaseCheckpointSaver; + +import java.util.Optional; + + +public class CompileConfig { + + private BaseCheckpointSaver checkpointSaver; + private String[] interruptBefore = {}; + private String[] interruptAfter = {}; + + public Optional getCheckpointSaver() { return Optional.ofNullable(checkpointSaver); } + public String[] getInterruptBefore() { return interruptBefore; } + public String[] getInterruptAfter() { return interruptAfter; } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private final CompileConfig config = new CompileConfig(); + + public Builder checkpointSaver(BaseCheckpointSaver checkpointSaver) { + this.config.checkpointSaver = checkpointSaver; + return this; + } + public Builder interruptBefore(String... interruptBefore) { + this.config.interruptBefore = interruptBefore; + return this; + } + public Builder interruptAfter(String... interruptAfter) { + this.config.interruptAfter = interruptAfter; + return this; + } + public CompileConfig build() { + return config; + } + } + + + private CompileConfig() {} + +} diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java index 6336fca..6b4c4c9 100644 --- a/core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java @@ -6,6 +6,8 @@ import org.bsc.async.AsyncGenerator; import org.bsc.async.AsyncGeneratorQueue; import org.bsc.langgraph4j.action.AsyncNodeAction; +import org.bsc.langgraph4j.checkpoint.BaseCheckpointSaver; +import org.bsc.langgraph4j.checkpoint.Checkpoint; import org.bsc.langgraph4j.state.AgentState; import java.util.*; @@ -31,14 +33,16 @@ public class CompiledGraph { final Map> edges = new LinkedHashMap<>(); private int maxIterations = 25; + private final CompileConfig compileConfig; /** * Constructs a CompiledGraph with the given StateGraph. * * @param stateGraph the StateGraph to be used in this CompiledGraph */ - protected CompiledGraph(StateGraph stateGraph) { + protected CompiledGraph(StateGraph stateGraph, CompileConfig compileConfig ) { this.stateGraph = stateGraph; + this.compileConfig = compileConfig; stateGraph.nodes.forEach(n -> nodes.put(n.id(), n.action()) ); @@ -105,21 +109,46 @@ private String getEntryPoint( State state ) throws Exception { return nextNodeId(stateGraph.getEntryPoint(), state, "entryPoint"); } + private void addCheckpoint( String nodeId, State state ) throws Exception { + if( compileConfig.getCheckpointSaver().isPresent() ) { + Checkpoint.Value value = Checkpoint.Value.of(state, nodeId); + compileConfig.getCheckpointSaver().get().put( new Checkpoint(value) ); + } + } + + State getInitialState(Map inputs) { + + return compileConfig.getCheckpointSaver() + .flatMap(BaseCheckpointSaver::getLast) + .map( cp -> { + var state = cp.getValue().getState(); + return state.mergeWith(inputs, stateGraph.getStateFactory()); + }) + .orElseGet( () -> + stateGraph.getStateFactory().apply(inputs) + ); + } + /** * Creates an AsyncGenerator stream of NodeOutput based on the provided inputs. * * @param inputs the input map + * @param config the invoke configuration * @return an AsyncGenerator stream of NodeOutput * @throws Exception if there is an error creating the stream */ - public AsyncGenerator> stream(Map inputs ) throws Exception { + public AsyncGenerator> stream(Map inputs, InvokeConfig config ) throws Exception { + Objects.requireNonNull(config, "config cannot be null"); return AsyncGeneratorQueue.of(new LinkedBlockingQueue<>(), queue -> { try { - var currentState = stateGraph.getStateFactory().apply(inputs); + + var currentState = getInitialState(inputs); queue.add( AsyncGenerator.Data.of( completedFuture( NodeOutput.of("start", currentState)) )); + addCheckpoint( "start", currentState ); + log.trace( "START"); var currentNodeId = this.getEntryPoint( currentState ); @@ -142,6 +171,7 @@ public AsyncGenerator> stream(Map inputs ) thro currentState = currentState.mergeWith(partialState, stateGraph.getStateFactory()); queue.add( AsyncGenerator.Data.of( completedFuture( NodeOutput.of(currentNodeId, currentState) ) )); + addCheckpoint( currentNodeId, currentState ); if ( Objects.equals(currentNodeId, stateGraph.getFinishPoint()) ) { break; @@ -161,6 +191,7 @@ public AsyncGenerator> stream(Map inputs ) thro } queue.add( AsyncGenerator.Data.of( completedFuture( NodeOutput.of("stop", currentState) ) )); + addCheckpoint( "stop", currentState ); log.trace( "STOP"); } catch (Exception e) { @@ -171,16 +202,27 @@ public AsyncGenerator> stream(Map inputs ) thro } + /** + * Creates an AsyncGenerator stream of NodeOutput based on the provided inputs. + * + * @param inputs the input map + * @return an AsyncGenerator stream of NodeOutput + * @throws Exception if there is an error creating the stream + */ + public AsyncGenerator> stream(Map inputs ) throws Exception { + return this.stream( inputs, InvokeConfig.builder().build() ); + } /** * Invokes the graph execution with the provided inputs and returns the final state. * * @param inputs the input map + * @param config the invoke configuration * @return an Optional containing the final state if present, otherwise an empty Optional * @throws Exception if there is an error during invocation */ - public Optional invoke(Map inputs ) throws Exception { + public Optional invoke(Map inputs, InvokeConfig config ) throws Exception { - var sourceIterator = stream(inputs).iterator(); + var sourceIterator = stream(inputs, config).iterator(); var result = StreamSupport.stream( Spliterators.spliteratorUnknownSize(sourceIterator, Spliterator.ORDERED), @@ -189,6 +231,17 @@ public Optional invoke(Map inputs ) throws Exception { return result.reduce((a, b) -> b).map( NodeOutput::state); } + /** + * Invokes the graph execution with the provided inputs and returns the final state. + * + * @param inputs the input map + * @return an Optional containing the final state if present, otherwise an empty Optional + * @throws Exception if there is an error during invocation + */ + public Optional invoke(Map inputs ) throws Exception { + return this.invoke( inputs, InvokeConfig.builder().build() ); + } + /** * Generates a drawable graph representation of the state graph. * diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/EdgeValue.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/EdgeValue.java index 2677ba2..21ef364 100644 --- a/core-jdk8/src/main/java/org/bsc/langgraph4j/EdgeValue.java +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/EdgeValue.java @@ -7,7 +7,7 @@ @Value @Accessors(fluent = true) -class EdgeValue { +public class EdgeValue { /** * The unique identifier for the edge value. diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/InvokeConfig.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/InvokeConfig.java new file mode 100644 index 0000000..3fad8d9 --- /dev/null +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/InvokeConfig.java @@ -0,0 +1,39 @@ +package org.bsc.langgraph4j; + +import org.bsc.langgraph4j.checkpoint.CheckpointConfig; + +import java.util.Optional; + +public class InvokeConfig { + + private CheckpointConfig checkpointConfig; + + public Optional getCheckpointConfig() { + return Optional.ofNullable(checkpointConfig); + } + + static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private String checkpointThreadId; + + public Builder checkpointThreadId(String threadId) { + this.checkpointThreadId = threadId; + return this; + } + public InvokeConfig build() { + InvokeConfig result = new InvokeConfig(); + + if( checkpointThreadId != null ) { + result.checkpointConfig = CheckpointConfig.of(checkpointThreadId); + } + + return result; + } + } + + private InvokeConfig() {} +} diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/StateGraph.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/StateGraph.java index 362528b..c9fbbb2 100644 --- a/core-jdk8/src/main/java/org/bsc/langgraph4j/StateGraph.java +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/StateGraph.java @@ -198,10 +198,13 @@ private Node makeFakeNode(String id) { /** * Compiles the state graph into a compiled graph. * + * @param config the compile configuration * @return a compiled graph * @throws GraphStateException if there are errors related to the graph state */ - public CompiledGraph compile() throws GraphStateException { + public CompiledGraph compile( CompileConfig config ) throws GraphStateException { + Objects.requireNonNull(config, "config cannot be null"); + if (entryPoint == null) { throw Errors.missingEntryPoint.exception(); } @@ -237,6 +240,17 @@ public CompiledGraph compile() throws GraphStateException { } } - return new CompiledGraph<>(this); + return new CompiledGraph<>(this, config); + } + + /** + * Compiles the state graph into a compiled graph. + * + * @return a compiled graph + * @throws GraphStateException if there are errors related to the graph state + */ + public CompiledGraph compile() throws GraphStateException { + return compile(CompileConfig.builder().build()); } + } diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/BaseCheckpointSaver.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/BaseCheckpointSaver.java new file mode 100644 index 0000000..85652a7 --- /dev/null +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/BaseCheckpointSaver.java @@ -0,0 +1,13 @@ +package org.bsc.langgraph4j.checkpoint; + +import java.io.Externalizable; +import java.util.Collection; +import java.util.Optional; + +public interface BaseCheckpointSaver { + + + Collection list(); + Optional getLast(); + void put( Checkpoint checkpoint ) throws Exception; +} diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/Checkpoint.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/Checkpoint.java new file mode 100644 index 0000000..fd3a6e8 --- /dev/null +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/Checkpoint.java @@ -0,0 +1,52 @@ +package org.bsc.langgraph4j.checkpoint; + +import lombok.Data; +import org.bsc.langgraph4j.state.AgentState; + +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.util.*; + + +/** + * Represents a checkpoint of an agent state. + * + * The checkpoint is an immutable object that holds an {@link AgentState} + * and a {@code String} that represents the next state. + * + * The checkpoint is serializable and can be persisted and restored. + * + * @see AgentState + * @see Externalizable + */ +public class Checkpoint { + + @lombok.Value(staticConstructor="of") + public static class Value { + AgentState state; + String nodeId; + } + + String id; + Value value; + + public final String getId() { + return id; + } + public final Value getValue() { + return value; + } + + public Checkpoint( Value value ) { + this(UUID.randomUUID().toString(), value ); + } + public Checkpoint(String id, Value value) { + Objects.requireNonNull(id, "id cannot be null"); + Objects.requireNonNull(value, "value cannot be null"); + this.id = id; + this.value = value; + } + +} diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/CheckpointConfig.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/CheckpointConfig.java new file mode 100644 index 0000000..77de6e7 --- /dev/null +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/CheckpointConfig.java @@ -0,0 +1,8 @@ +package org.bsc.langgraph4j.checkpoint; + +import lombok.Value; + +@Value(staticConstructor = "of") +public class CheckpointConfig { + String threadId; +} diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/MemorySaver.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/MemorySaver.java new file mode 100644 index 0000000..eea17cf --- /dev/null +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/checkpoint/MemorySaver.java @@ -0,0 +1,39 @@ +package org.bsc.langgraph4j.checkpoint; + +import org.bsc.langgraph4j.serializer.CheckpointSerializer; + +import java.util.*; + +import static java.util.Collections.unmodifiableCollection; +import static java.util.Collections.unmodifiableSet; + +public class MemorySaver implements BaseCheckpointSaver { + + private final Stack checkpoints = new Stack<>(); + + + @Override + public Collection list() { + return unmodifiableCollection(checkpoints); // immutable checkpoints; + } + + @Override + public Optional getLast() { + if( checkpoints.isEmpty() ) { + return Optional.empty(); + } + return Optional.ofNullable( checkpoints.peek() ); + } + + public Optional get(String id) { + return checkpoints.stream() + .filter( checkpoint -> checkpoint.getId().equals(id) ) + .findFirst(); + } + + @Override + public void put(Checkpoint checkpoint) throws Exception { + checkpoints.add( CheckpointSerializer.INSTANCE.cloneObject(checkpoint) ); + } + +} diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/AgentStateSerializer.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/AgentStateSerializer.java new file mode 100644 index 0000000..08c1414 --- /dev/null +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/AgentStateSerializer.java @@ -0,0 +1,74 @@ +package org.bsc.langgraph4j.serializer; + +import lombok.extern.log4j.Log4j; +import lombok.extern.slf4j.Slf4j; +import org.bsc.langgraph4j.state.AgentState; + +import java.io.*; +import java.util.HashMap; +import java.util.Map; + +@Slf4j +public class AgentStateSerializer implements Serializer { + public static final AgentStateSerializer INSTANCE = new AgentStateSerializer(); + private AgentStateSerializer() {} + + @Override + public void write(AgentState object, ObjectOutput out) throws IOException { + try( ByteArrayOutputStream baos = new ByteArrayOutputStream() ) { + int actuoalSize = 0; + + final ObjectOutputStream tupleStream = new ObjectOutputStream( baos ); + for( Map.Entry e : object.data().entrySet() ) { + try { + tupleStream.writeUTF(e.getKey()); + tupleStream.writeObject(e.getValue()); + ++actuoalSize; + } catch (IOException ex) { + log.error( "Error writing state key '{}' - {}", e.getKey(), ex.getMessage() ); + throw ex; + } + } + + out.writeInt( object.data().size() ); + out.writeInt( actuoalSize ); // actual size + byte[] data = baos.toByteArray(); + out.writeInt( data.length ); + out.write( data ); + + } + + } + + @Override + public AgentState read(ObjectInput in) throws IOException, ClassNotFoundException { + Map data = new HashMap<>(); + + int expectedSize = in.readInt(); + int actualSize = in.readInt(); + if( expectedSize > 0 && actualSize > 0 ) { + + if( expectedSize != actualSize ) { + final String message = String.format( "Deserialize State: Expected size %d and actual size %d do not match!", expectedSize, actualSize ) ; + log.error( message ) ; + throw new IOException( message ) ; + } + + int byteLen = in.readInt(); + byte[] bytes = new byte[byteLen]; + in.readFully(bytes); + + try( ByteArrayInputStream bais = new ByteArrayInputStream( bytes ) ) { + ObjectInputStream ois = new ObjectInputStream( bais ); + + for( int i = 0; i < actualSize; i++ ) { + String key = ois.readUTF(); + Object value = ois.readObject(); + data.put(key, value); + } + } + + } + return new AgentState(data); + } +} diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/CheckpointSerializer.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/CheckpointSerializer.java new file mode 100644 index 0000000..ba2a725 --- /dev/null +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/CheckpointSerializer.java @@ -0,0 +1,31 @@ +package org.bsc.langgraph4j.serializer; + +import org.bsc.langgraph4j.checkpoint.Checkpoint; +import org.bsc.langgraph4j.state.AgentState; + +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; + +public class CheckpointSerializer implements Serializer { + + public static final CheckpointSerializer INSTANCE = new CheckpointSerializer(); + + private CheckpointSerializer() {} + + public void write( Checkpoint object, ObjectOutput out) throws IOException { + out.writeUTF(object.getId()); + Checkpoint.Value value = object.getValue(); + AgentStateSerializer.INSTANCE.write( value.getState(), out ); + out.writeUTF( value.getNodeId() ); + } + + public Checkpoint read(ObjectInput in) throws IOException, ClassNotFoundException { + String id = in.readUTF(); + AgentState state = AgentStateSerializer.INSTANCE.read( in ); + String nodeId = in.readUTF(); + Checkpoint.Value value = Checkpoint.Value.of( state, nodeId ); + return new Checkpoint(id, value); + } + +} diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/Serializer.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/Serializer.java new file mode 100644 index 0000000..15404f6 --- /dev/null +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/serializer/Serializer.java @@ -0,0 +1,40 @@ +package org.bsc.langgraph4j.serializer; + +import org.bsc.langgraph4j.checkpoint.Checkpoint; + +import java.io.*; +import java.util.Objects; + +public interface Serializer { + + void write(T object, ObjectOutput out) throws IOException; + + + T read(ObjectInput in) throws IOException, ClassNotFoundException ; + + default byte[] writeObject(T object) throws IOException { + Objects.requireNonNull( object, "object cannot be null" ); + try( ByteArrayOutputStream baos = new ByteArrayOutputStream() ) { + ObjectOutputStream oas = new ObjectOutputStream(baos); + write(object, oas); + oas.flush(); + return baos.toByteArray(); + } + } + + default T readObject(byte[] bytes) throws IOException, ClassNotFoundException { + Objects.requireNonNull( bytes, "bytes cannot be null" ); + if( bytes.length == 0 ) { + throw new IllegalArgumentException("bytes cannot be empty"); + } + try( ByteArrayInputStream bais = new ByteArrayInputStream( bytes ) ) { + ObjectInputStream ois = new ObjectInputStream(bais); + return read(ois); + } + } + + default T cloneObject(T object) throws IOException, ClassNotFoundException { + Objects.requireNonNull( object, "object cannot be null" ); + return readObject(writeObject(object)); + } +} diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/state/AppendableValueRW.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/state/AppendableValueRW.java index ed8e965..aa277b3 100644 --- a/core-jdk8/src/main/java/org/bsc/langgraph4j/state/AppendableValueRW.java +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/state/AppendableValueRW.java @@ -1,5 +1,9 @@ package org.bsc.langgraph4j.state; +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; import java.util.*; import static java.util.Collections.unmodifiableList; @@ -10,8 +14,8 @@ * * @param the type of the value */ -public class AppendableValueRW implements AppendableValue { - private final List values; +public class AppendableValueRW implements AppendableValue, Externalizable { + private List values; /** * Constructs an AppendableValueRW with the given initial collection of values. @@ -100,4 +104,14 @@ public Optional lastMinus(int n) { public String toString() { return String.valueOf(values); } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + out.writeObject(values); + } + + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + values = (List) in.readObject(); + } } diff --git a/core-jdk8/src/test/java/org/bsc/langgraph4j/SerializeTest.java b/core-jdk8/src/test/java/org/bsc/langgraph4j/SerializeTest.java new file mode 100644 index 0000000..4721afe --- /dev/null +++ b/core-jdk8/src/test/java/org/bsc/langgraph4j/SerializeTest.java @@ -0,0 +1,72 @@ +package org.bsc.langgraph4j; + +import org.bsc.langgraph4j.serializer.AgentStateSerializer; +import org.bsc.langgraph4j.state.AgentState; +import org.junit.jupiter.api.Test; + +import java.io.*; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +public class SerializeTest { + + + private byte[] serializeState(AgentState state) throws Exception { + try( ByteArrayOutputStream baos = new ByteArrayOutputStream() ) { + ObjectOutputStream oas = new ObjectOutputStream(baos); + AgentStateSerializer.INSTANCE.write(state, oas); + oas.flush(); + return baos.toByteArray(); + } + } + private AgentState deserializeState( byte[] bytes ) throws Exception { + try(ByteArrayInputStream bais = new ByteArrayInputStream( bytes ) ) { + ObjectInputStream ois = new ObjectInputStream( bais ); + return AgentStateSerializer.INSTANCE.read( ois ); + } + } + + @Test + public void serializeStateTest() throws Exception { + + Map data = new HashMap<>(); + data.put("a", "b"); + data.put("f", null); + data.put("c", "d"); + + final AgentState state = new AgentState(data); + + byte[] bytes = serializeState(state); + + assertNotNull(bytes); + AgentState deserializeState = deserializeState( bytes ); + + assertEquals( 3, deserializeState.data().size() ); + assertEquals( "b", deserializeState.data().get("a") ); + assertEquals( "d", deserializeState.data().get("c") ); + } + + static class NonSerializableElement { + String value = "TEST"; + public NonSerializableElement() { + } + } + @Test + public void partiallySerializeStateTest() throws Exception { + + Map data = new HashMap<>(); + data.put("a", "b"); + data.put("f", new NonSerializableElement() ); + data.put("c", "d"); + + final AgentState state = new AgentState(data); + + assertThrows(IOException.class, () -> { + serializeState(state); + }); + + } + +} diff --git a/core-jdk8/src/test/java/org/bsc/langgraph4j/StateGraphTest.java b/core-jdk8/src/test/java/org/bsc/langgraph4j/StateGraphTest.java index 55406f5..d3cb7e7 100644 --- a/core-jdk8/src/test/java/org/bsc/langgraph4j/StateGraphTest.java +++ b/core-jdk8/src/test/java/org/bsc/langgraph4j/StateGraphTest.java @@ -1,13 +1,17 @@ package org.bsc.langgraph4j; import lombok.var; +import org.bsc.langgraph4j.checkpoint.Checkpoint; +import org.bsc.langgraph4j.checkpoint.MemorySaver; import org.bsc.langgraph4j.state.AgentState; +import org.bsc.langgraph4j.state.AppendableValue; import org.junit.jupiter.api.Test; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.stream.Collectors; +import static java.lang.String.format; +import static java.util.Collections.emptyMap; import static org.bsc.langgraph4j.StateGraph.END; import static org.bsc.langgraph4j.action.AsyncEdgeAction.edge_async; import static org.bsc.langgraph4j.action.AsyncNodeAction.node_async; @@ -104,5 +108,142 @@ public void testRunningOneNode() throws Exception { } + @Test + public void testCheckpointInitialState() throws Exception { + + var workflow = new StateGraph<>(AgentState::new); + workflow.setEntryPoint("agent_1"); + + workflow.addNode("agent_1", node_async( state -> { + System.out.print( "agent_1"); + return mapOf("agent_1:prop1", "agent_1:test"); + })); + + workflow.addEdge( "agent_1", END); + + var saver = new MemorySaver(); + + var compileConfig = CompileConfig.builder().checkpointSaver(saver).build(); + + var app = workflow.compile( compileConfig ); + + Map inputs = mapOf( "input", "test1"); + + var initState = app.getInitialState( inputs ); + + assertEquals( 1, initState.data().size() ); + assertTrue( initState.value("input").isPresent() ); + assertEquals( "test1", initState.value("input").get() ); + + // + // Test checkpoint not override inputs + // + var newState = new AgentState( mapOf( "input", "test2") ); + saver.put( new Checkpoint( Checkpoint.Value.of( newState, "start" ) ) ); + + app = workflow.compile( compileConfig ); + initState = app.getInitialState( inputs ); + + assertEquals( 1, initState.data().size() ); + assertTrue( initState.value("input").isPresent() ); + assertEquals( "test1", initState.value("input").get() ); + + // Test checkpoints are saved + newState = new AgentState( mapOf( "input", "test2", "agent_1:prop1", "agent_1:test") ); + saver.put( new Checkpoint( Checkpoint.Value.of( newState, "agent_1" ) ) ); + + app = workflow.compile( compileConfig ); + initState = app.getInitialState( inputs ); + + assertEquals( 2, initState.data().size() ); + assertTrue( initState.value("input").isPresent() ); + assertEquals( "test1", initState.value("input").get() ); + assertTrue( initState.value("agent_1:prop1").isPresent() ); + assertEquals( "agent_1:test", initState.value("agent_1:prop1").get() ); + + var checkpoints = saver.list(); + assertEquals( 2, checkpoints.size() ); + var last = saver.getLast(); + assertTrue( last.isPresent() ); + assertEquals( "agent_1", last.get().getValue().getNodeId() ); + assertTrue( last.get().getValue().getState().value("agent_1:prop1").isPresent() ); + assertEquals( "agent_1:test", last.get().getValue().getState().value("agent_1:prop1").get() ); + + } + + static class MessagesState extends AgentState { + + public MessagesState(Map initData) { + super( initData ); + appendableValue("messages"); // tip: initialize messages + } + + int steps() { + return value("steps").map(Integer.class::cast).orElse(0); + } + + AppendableValue messages() { + return appendableValue("messages"); + } + + } + + @Test + public void testCheckpointSaver() throws Exception { + var STEPS_COUNT = 5; + + var workflow = new StateGraph<>(MessagesState::new); + workflow.setEntryPoint("agent_1"); + + workflow.addNode("agent_1", node_async( state -> { + + System.out.println( "agent_1"); + var steps = state.steps() + 1; + return mapOf("steps", steps, "messages", format( "agent_1:step %d", steps )); + })); + workflow.addConditionalEdges( "agent_1", edge_async( state -> { + var steps = state.steps(); + if( steps >= STEPS_COUNT) { + return "exit"; + } + return "next"; + }), mapOf( "next", "agent_1", "exit", END) ); + + var saver = new MemorySaver(); + + var compileConfig = CompileConfig.builder() + .checkpointSaver(saver) + .build(); + + var app = workflow.compile( compileConfig ); + + Map inputs = mapOf( "steps", 0 ); + + var invokeConfig = InvokeConfig.builder().checkpointThreadId("thread_1").build(); + + var state = app.invoke( inputs, invokeConfig ); + + assertTrue( state.isPresent() ); + assertEquals( STEPS_COUNT, state.get().steps() ); + var messages = state.get().appendableValue("messages"); + assertFalse( messages.isEmpty() ); + + System.out.println( messages.values() ); + + assertEquals( STEPS_COUNT, messages.size() ); + for( int i = 0; i < messages.size(); i++ ) { + assertEquals( format("agent_1:step %d", i+1), messages.values().get(i) ); + } + + state = app.invoke( emptyMap(), invokeConfig ); + + assertTrue( state.isPresent() ); + assertEquals( STEPS_COUNT + 1, state.get().steps() ); + messages = state.get().appendableValue("messages"); + assertEquals( STEPS_COUNT + 1, messages.size() ); + for( int i = 0; i < messages.size(); i++ ) { + assertEquals( format("agent_1:step %d", i+1), messages.values().get(i) ); + } + } }