Skip to content

Commit

Permalink
fix: pause management
Browse files Browse the repository at this point in the history
- check resume startpoint
- refine state cloning
- improve unit test
  • Loading branch information
bsorrentino committed Sep 5, 2024
1 parent e8b0735 commit 7042bce
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 25 deletions.
44 changes: 27 additions & 17 deletions core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.bsc.langgraph4j;

import lombok.Getter;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.var;
import org.bsc.async.AsyncGenerator;
Expand Down Expand Up @@ -156,7 +157,10 @@ private String getEntryPoint( State state ) throws Exception {
return nextNodeId(stateGraph.getEntryPoint(), state, "entryPoint");
}

private boolean shouldInterruptBefore( String nodeId ) {
private boolean shouldInterruptBefore(@NonNull String nodeId, String startNodeId ) {
if( nodeId.equals(startNodeId)) { // FIX RESUME ERROR
return false;
}
return Arrays.asList(compileConfig.getInterruptBefore()).contains(nodeId);
}

Expand All @@ -168,11 +172,12 @@ private void addCheckpoint( RunnableConfig config, String nodeId, State state, S
if( compileConfig.checkpointSaver().isPresent() ) {
Checkpoint cp = Checkpoint.builder()
.nodeId( nodeId )
.state( state )
.state( state.data() )
.nextNodeId( nextNodeId )
.build();
compileConfig.checkpointSaver().get().put( config, cp );
}

}

Map<String,Object> getInitialStateFromSchema() {
Expand Down Expand Up @@ -214,30 +219,30 @@ private void streamData( State initialState,
while( !Objects.equals(currentNodeId, END) ) {

log.trace( "NEXT NODE: {}", currentNodeId);

var action = nodes.get(currentNodeId);

if (action == null)
throw StateGraph.RunnableErrors.missingNode.exception(currentNodeId);

if ( shouldInterruptBefore( currentNodeId ) ) {
if ( shouldInterruptBefore( currentNodeId, startNodeId )) {
log.trace("interrupt before node {}", currentNodeId);
addCheckpoint( config, currentNodeId, cloneState(currentState.data()), currentNodeId );
return;
}

partialState = action.apply(currentState).get();
partialState = action.apply( cloneState(currentState.data())).get();

currentState = cloneState( AgentState.updateState(currentState, partialState, stateGraph.getChannels()) );
currentState = stateGraph.getStateFactory().apply(AgentState.updateState(currentState, partialState, stateGraph.getChannels()));

yieldData.accept( NodeOutput.of(currentNodeId, currentState) );
yieldData.accept( NodeOutput.of(currentNodeId, cloneState(currentState.data())) );

if ( Objects.equals(currentNodeId, stateGraph.getFinishPoint()) ) {
addCheckpoint( config, currentNodeId, currentState, stateGraph.getFinishPoint() );
addCheckpoint( config, currentNodeId, cloneState(currentState.data()), stateGraph.getFinishPoint() );
break;
}

final String nextNodeId = nextNodeId(currentNodeId, currentState);
addCheckpoint( config, currentNodeId, currentState, nextNodeId );
addCheckpoint( config, currentNodeId, cloneState(currentState.data()), nextNodeId );

if ( shouldInterruptAfter( currentNodeId ) ) {
log.trace( "interrupt after node {}", currentNodeId);
Expand All @@ -257,7 +262,7 @@ private void streamData( State initialState,

}

yieldData.accept( NodeOutput.of(END, currentState) );
yieldData.accept( NodeOutput.of(END, cloneState(currentState.data())) );

// addCheckpoint( config, END, currentState, null );

Expand All @@ -280,7 +285,7 @@ public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs, Runna
final boolean isResumeRequest = (inputs == null);

if( isResumeRequest ) {

log.trace( "RESUME REQUEST" );
BaseCheckpointSaver saver = compileConfig.checkpointSaver().orElseThrow(() -> (new IllegalStateException("inputs cannot be null (ie. resume request) if no checkpoint saver is configured")));

Checkpoint startCheckpoint = saver.get( config ).orElseThrow( () -> (new IllegalStateException("Resume request without a saved checkpoint!")) );
Expand All @@ -291,9 +296,14 @@ public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs, Runna

State startState = stateGraph.getStateFactory().apply( startCheckpoint.getState() );

// Reset checkpoint id
RunnableConfig resumeConfig = RunnableConfig.builder(config)
.checkPointId(null)
.build();

streamData( startState,
startCheckpoint.getNextNodeId(),
config,
resumeConfig,
data -> queue.add( AsyncGenerator.Data.of( completedFuture(data) ) )
);
}));
Expand All @@ -304,17 +314,17 @@ public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs, Runna

log.trace( "START" );

State startState = cloneState( getInitialState(inputs, config) ) ;
queue.add( AsyncGenerator.Data.of( NodeOutput.of( START, startState ) ));
State startState = stateGraph.getStateFactory().apply(getInitialState(inputs, config )) ;

queue.add( AsyncGenerator.Data.of( NodeOutput.of( START, cloneState(startState.data()) ) ));

String startNodeId = this.getEntryPoint( startState );
if( shouldInterruptBefore( startNodeId ) ) return;
if( shouldInterruptBefore( startNodeId, null ) ) return;

addCheckpoint( config, START, startState, startNodeId );
addCheckpoint( config, START, cloneState(startState.data()), startNodeId );

if( shouldInterruptAfter( startNodeId ) ) return;


streamData( startState,
startNodeId,
config,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package org.bsc.langgraph4j.checkpoint;

import org.bsc.langgraph4j.RunnableConfig;
import org.bsc.langgraph4j.serializer.CheckpointSerializer;

import java.util.*;
import java.util.concurrent.locks.Lock;
Expand All @@ -19,8 +18,6 @@ public class MemorySaver implements BaseCheckpointSaver {
private final Lock r = rwl.readLock();
private final Lock w = rwl.writeLock();

private final CheckpointSerializer _serializer = CheckpointSerializer.of();

public MemorySaver() {
}

Expand Down Expand Up @@ -68,22 +65,21 @@ public RunnableConfig put(RunnableConfig config, Checkpoint checkpoint) throws E

w.lock();
try {
final Checkpoint clonedCheckpoint = _serializer.cloneObject(checkpoint);

if (config.checkPointId().isPresent()) { // Replace Checkpoint
String checkPointId = config.checkPointId().get();
int index = IntStream.range(0, checkpoints.size())
.filter(i -> checkpoints.get(i).getId().equals(checkPointId))
.findFirst()
.orElseThrow(() -> (new NoSuchElementException(format("Checkpoint with id %s not found!", checkPointId))));
checkpoints.set(index, clonedCheckpoint);
checkpoints.set(index, checkpoint );
return config;
}

checkpoints.push(clonedCheckpoint); // Add Checkpoint
checkpoints.push( checkpoint ); // Add Checkpoint

return RunnableConfig.builder(config)
.checkPointId(clonedCheckpoint.getId())
.checkPointId(checkpoint.getId())
.build();
}
finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ List<String> messages() {
}

Optional<String> lastMessage() {
List<String> messages = messages();
List<String> messages = messages();
if( messages.isEmpty() ) {
return Optional.empty();
}
Expand Down Expand Up @@ -311,6 +311,70 @@ public void testViewAndUpdatePastGraphState() throws Exception {

@Test
public void testPauseAndUpdatePastGraphState() throws Exception {
var workflow = new StateGraph<>(MessagesState.SCHEMA, MessagesState::new)
.addNode("agent", node_async( state -> {
String lastMessage = state.lastMessage().orElseThrow( () -> new IllegalStateException("No last message!") );

if( lastMessage.contains( "temperature")) {
return mapOf("messages", "whether in Naples is sunny");
}
if( lastMessage.contains( "whether")) {
return mapOf("messages", "tool_calls");
}
if( lastMessage.contains( "bartolo")) {
return mapOf("messages", "Hi bartolo, nice to meet you too! How can I assist you today?");
}
if(state.messages().stream().anyMatch(m -> m.contains("bartolo"))) {
return mapOf("messages", "Hi, bartolo welcome back?");
}
throw new IllegalStateException( "unknown message!" );
}))
.addNode("tools", node_async( state ->
mapOf( "messages", "temperature in Napoli is 30 degree" )
))
.addEdge(START, "agent")
.addConditionalEdges("agent", edge_async( state ->
state.lastMessage().filter( m -> m.equals("tool_calls") ).map( m -> "tools" ).orElse(END)
), mapOf("tools", "tools", END, END))
.addEdge("tools", "agent");


var saver = new MemorySaver();

var compileConfig = CompileConfig.builder()
.checkpointSaver(saver)
.interruptBefore("tools")
.build();

var app = workflow.compile( compileConfig );

var runnableConfig = RunnableConfig.builder()
.threadId("thread_1")
.build();

Map<String,Object> inputs = mapOf( "messages","whether in Naples?" ) ;
var results = app.stream( inputs, runnableConfig ).stream().collect(Collectors.toList());

assertNotNull( results );
assertEquals( 2, results.size() );
assertEquals( START, results.get(0).node() );
assertEquals( "agent", results.get(1).node() );
assertTrue( results.get(1).state().lastMessage().isPresent() );

var state = app.getState(runnableConfig);

assertNotNull( state );
assertEquals( "tools", state.getNext() );

results = app.stream( null, state.getConfig() ).stream().collect(Collectors.toList());

assertNotNull( results );
assertEquals( 3, results.size() );
assertEquals( "tools", results.get(0).node() );
assertEquals( "agent", results.get(1).node() );
assertEquals( END, results.get(2).node() );
assertTrue( results.get(2).state().lastMessage().isPresent() );

System.out.println( results.get(2).state().lastMessage().get() );
}
}

0 comments on commit 7042bce

Please sign in to comment.