Skip to content

Commit

Permalink
refactor: refine AsyncIterator support
Browse files Browse the repository at this point in the history
experimental feature
  • Loading branch information
bsorrentino committed Mar 25, 2024
1 parent e29517b commit 19b43fd
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 71 deletions.
73 changes: 21 additions & 52 deletions src/main/java/org/bsc/langgraph4j/GraphState.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import org.bsc.langgraph4j.action.AsyncEdgeAction;
import org.bsc.langgraph4j.action.AsyncNodeAction;
import org.bsc.langgraph4j.async.AsyncIterator;
import org.bsc.langgraph4j.async.AsyncQueue;
import org.bsc.langgraph4j.flow.SyncSubmissionPublisher;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.AgentStateFactory;
Expand Down Expand Up @@ -46,40 +48,6 @@ GraphStateException exception(String... args ) {
}
}

public static <T> CompletableFuture<Stream<T>> convertPublisherToStream( Flow.Publisher<T> publisher ) {

var future = new CompletableFuture<Stream<T>>();

var list = new ArrayList<T>();

publisher.subscribe(new Flow.Subscriber<>() {

@Override
public void onSubscribe(Flow.Subscription subscription) {
subscription.request(Long.MAX_VALUE);
}

@Override
public void onNext(T item) {
list.add(item);
}

@Override
public void onError(Throwable throwable) {
future.completeExceptionally(throwable);
}

@Override
public void onComplete() {
var result = StreamSupport.stream(Spliterators.spliterator(list, Spliterator.ORDERED), false);
future.complete(result);

}
});

return future;
}

public class Runnable {

enum Errors {
Expand Down Expand Up @@ -156,54 +124,55 @@ private String nextNodeId( String nodeId , State state ) throws Exception {
}


public Flow.Publisher<NodeOutput<State>> stream( Map<String,Object> inputs ) throws Exception {
var publisher = new SyncSubmissionPublisher<NodeOutput<State>>();
public AsyncIterator<NodeOutput<State>> stream( Map<String,Object> inputs ) throws Exception {

var queue = new AsyncQueue<NodeOutput<State>>( java.lang.Runnable::run );

var executor = Executors.newSingleThreadExecutor();
var executor = Executors.newSingleThreadExecutor();

executor.submit(() -> {
executor.submit(() -> {
var currentState = stateFactory.apply(inputs);
var currentNodeId = entryPoint;
Map<String, Object> partialState;

do {
try {
try (queue) {
do {
var action = nodes.get(currentNodeId);
if (action == null) {
publisher.closeExceptionally(Errors.missingNode.exception(currentNodeId));
queue.closeExceptionally(Errors.missingNode.exception(currentNodeId));
break;
}

partialState = action.apply(currentState).get();

currentState = mergeState(currentState, partialState);

publisher.submit(new NodeOutput<>(currentNodeId, currentState));
queue.put(new NodeOutput<>(currentNodeId, currentState));

if (Objects.equals(currentNodeId, finishPoint)) {
break;
}

currentNodeId = nextNodeId(currentNodeId, currentState);

} catch (Exception e) {
publisher.closeExceptionally(e);
break;
}
} while (!Objects.equals(currentNodeId, END));

} while (!Objects.equals(currentNodeId, END));
} catch (Exception e) {
queue.closeExceptionally(e);
}

publisher.close();
});
});

return publisher;
return queue;
}

public Optional<State> invoke( Map<String,Object> inputs ) throws Exception {

var future = convertPublisherToStream(stream(inputs));
var sourceIterator = stream(inputs).iterator();

var result = future.get();
var result = StreamSupport.stream(
Spliterators.spliteratorUnknownSize(sourceIterator, Spliterator.ORDERED),
false);

return result.reduce((a, b) -> b).map( NodeOutput::state);
}
Expand Down
24 changes: 20 additions & 4 deletions src/main/java/org/bsc/langgraph4j/async/AsyncIterator.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,25 @@

public interface AsyncIterator<T> extends Iterable<T> {

record Data<T>(T data, boolean done) {}
record Data<T>(T data, boolean done, Throwable error) {
public Data(T data, boolean done) {
this(data, done, null );
}
public Data(Throwable error) {
this(null, false, error );
}
}

CompletableFuture<Data<T>> next();

default CompletableFuture<Void> forEachAsync( final AsyncFunction<T,Void> consumer) {

return next().thenCompose(data -> {
if (data.error != null ) {
var error = new CompletableFuture<Void>();
error.completeExceptionally(data.error);
return error;
}
if (data.done) {
return completedFuture(null);
}
Expand All @@ -35,8 +47,9 @@ public boolean hasNext() {
return false;
}

return !currentFetchedData.updateAndGet( (v) -> AsyncIterator.this.next().join() ).done;
var next = currentFetchedData.updateAndGet( (v) -> AsyncIterator.this.next().join() );

return !next.done();
}

@Override
Expand All @@ -47,11 +60,14 @@ public T next() {
throw new NoSuchElementException("no more elements into iterator");
}
}
if (currentFetchedData.get().done) {
if (currentFetchedData.get().error() != null ) {
throw new IllegalStateException(currentFetchedData.get().error());
}
if (currentFetchedData.get().done()) {
throw new NoSuchElementException("no more elements into iterator");
}

return currentFetchedData.getAndUpdate((v) -> null).data;
return currentFetchedData.getAndUpdate((v) -> null).data();
}
};
}
Expand Down
32 changes: 20 additions & 12 deletions src/main/java/org/bsc/langgraph4j/async/AsyncQueue.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

import java.util.Objects;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicReference;

import static java.util.concurrent.CompletableFuture.completedFuture;

public class AsyncQueue<E> implements AsyncIterator<E>, AutoCloseable {

private BlockingQueue<Data<E>> queue;
private final Executor executor;
private final AtomicReference<Throwable> exception = new AtomicReference<>();

public AsyncQueue() {
this(ForkJoinPool.commonPool());
Expand All @@ -24,30 +26,36 @@ public AsyncQueue(Executor executor) {
* @throws InterruptedException if interrupted while waiting for space to become available
*/
public void put(E e) throws InterruptedException {
Objects.requireNonNull(queue);
if( exception.get() != null ) {
throw new IllegalStateException("Queue has been closed with exception!");
}
queue.put(new Data<>(e, false));
}

public void closeExceptionally(Throwable ex) {
exception.set(ex);
}

@Override
public void close() throws Exception {
Objects.requireNonNull(queue);
queue.put(new Data<>(null, true));
queue = null;
if (exception.get() != null) {
queue.put(new Data<>(exception.get()));
}
else {
queue.put(new Data<>(null, true));
}
}


@Override
public CompletableFuture<Data<E>> next() {
// queue has been closed
if( queue == null ) {
// final var result = new CompletableFuture<Data<E>>();
// result.completeExceptionally( new IllegalStateException("Queue has been closed"));
// return result;
return completedFuture(new Data<>(null, true));
}
return CompletableFuture.supplyAsync( () -> {
try {
return queue.take();
var result = queue.take();
if( result.error() != null ) {
queue = null;
}
return result;
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
Expand Down
87 changes: 84 additions & 3 deletions src/test/java/org/bsc/langgraph4j/AsyncTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
import org.bsc.langgraph4j.async.AsyncQueue;
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.stream.StreamSupport;

import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.concurrent.ForkJoinPool.commonPool;
Expand Down Expand Up @@ -106,4 +105,86 @@ public void asyncQueueDirectTest() throws Exception {

}

@Test
public void asyncQueueToStreamTest() throws Exception {

// AsyncQueue initialized with a direct executor. No thread is used on next() invocation
final var queue = new AsyncQueue<String>(Runnable::run);

commonPool().execute( () -> {
try(queue) {
for( int i = 0 ; i < 10 ; ++i ) {
queue.put( "e"+i );
}
} catch (Exception e) {
throw new RuntimeException(e);
}

});

var sourceIterator = queue.iterator();

var result = StreamSupport.stream(
Spliterators.spliteratorUnknownSize(sourceIterator, Spliterator.ORDERED),
false);

var lastElement = result.reduce((a, b) -> b);

assertTrue( lastElement.isPresent());
assertEquals( lastElement.get(), "e9" );

}

@Test
public void asyncQueueIteratorExceptionTest() throws Exception {

// AsyncQueue initialized with a direct executor. No thread is used on next() invocation
final var queue = new AsyncQueue<String>(Runnable::run);

commonPool().execute( () -> {
try(queue) {
for( int i = 0 ; i < 2 ; ++i ) {
queue.put( "e"+i );
}
queue.closeExceptionally(new Exception("test"));

} catch (Exception e) {
queue.closeExceptionally(e);
}

});

var sourceIterator = queue.iterator();

var result = StreamSupport.stream(
Spliterators.spliteratorUnknownSize(sourceIterator, Spliterator.ORDERED),
false);

assertThrows( Exception.class, () -> result.reduce((a, b) -> b ));

}

@Test
public void asyncQueueForEachExceptionTest() throws Exception {

// AsyncQueue initialized with a direct executor. No thread is used on next() invocation
final var queue = new AsyncQueue<String>(Runnable::run);

commonPool().execute( () -> {
try(queue) {
for( int i = 0 ; i < 2 ; ++i ) {
queue.put( "e"+i );
}
queue.closeExceptionally(new Exception("test"));

} catch (Exception e) {
queue.closeExceptionally(e);
}

});

assertThrows( Exception.class, () -> queue.forEachAsync( consumer_async(System.out::println) ).get() );

}

}

0 comments on commit 19b43fd

Please sign in to comment.