Skip to content

Commit

Permalink
feat(AsyncNodeGenerator): add output factory method
Browse files Browse the repository at this point in the history
work on #24
  • Loading branch information
bsorrentino committed Sep 14, 2024
1 parent 04bcd13 commit 0f61236
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -314,15 +314,15 @@ public GraphRepresentation getGraph( GraphRepresentation.Type type ) {
}


public class AsyncNodeGenerator<O extends NodeOutput<State>> implements AsyncGenerator<O> {
public class AsyncNodeGenerator<Output extends NodeOutput<State>> implements AsyncGenerator<Output> {

Map<String,Object> currentState;
String currentNodeId;
String nextNodeId;
int iteration = 0;
RunnableConfig config;

public AsyncNodeGenerator(Map<String,Object> inputs, RunnableConfig config) throws Exception {
protected AsyncNodeGenerator(Map<String,Object> inputs, RunnableConfig config) throws Exception {
final boolean isResumeRequest = (inputs == null);

if( isResumeRequest ) {
Expand Down Expand Up @@ -361,8 +361,12 @@ public AsyncNodeGenerator(Map<String,Object> inputs, RunnableConfig config) thro
}
}

protected Output buildOutput(String nodeId ) throws Exception {
return (Output)NodeOutput.of( nodeId, cloneState(currentState) );
}

@Override
public Data<O> next() {
public Data<Output> next() {
// GUARD: CHECK MAX ITERATION REACHED
if( ++iteration > maxIterations ) {
log.warn( "Maximum number of iterations ({}) reached!", maxIterations);
Expand All @@ -372,21 +376,22 @@ public Data<O> next() {
// GUARD: CHECK IF IT IS END
if( nextNodeId == null && currentNodeId == null ) return Data.done();

CompletableFuture<O> future = new CompletableFuture<>();
CompletableFuture<Output> future = new CompletableFuture<>();

try {

if( START.equals(currentNodeId) ) {
nextNodeId = getEntryPoint( currentState );
currentNodeId = nextNodeId;
addCheckpoint( config, START, currentState, nextNodeId );
return Data.of((O)NodeOutput.of( START, cloneState(currentState) ));
return Data.of( buildOutput( START ) );

}

if( END.equals(nextNodeId) ) {
nextNodeId = null;
currentNodeId = null;
return Data.of((O)NodeOutput.of( END, cloneState(currentState) ));
return Data.of( buildOutput( END ) );
}

// check on previous node
Expand All @@ -407,7 +412,8 @@ public Data<O> next() {
nextNodeId = nextNodeId(currentNodeId, currentState);
addCheckpoint(config, currentNodeId, currentState, nextNodeId);

return (O)NodeOutput.of( currentNodeId, cloneState(currentState) );
return buildOutput( currentNodeId );

}
catch (Exception e) {
throw new CompletionException(e);
Expand Down

0 comments on commit 0f61236

Please sign in to comment.