diff --git a/src/main/java/org/bsc/langgraph4j/GraphState.java b/src/main/java/org/bsc/langgraph4j/GraphState.java index 6c35285..1dacfb5 100644 --- a/src/main/java/org/bsc/langgraph4j/GraphState.java +++ b/src/main/java/org/bsc/langgraph4j/GraphState.java @@ -2,6 +2,8 @@ import org.bsc.langgraph4j.action.AsyncEdgeAction; import org.bsc.langgraph4j.action.AsyncNodeAction; +import org.bsc.langgraph4j.async.AsyncIterator; +import org.bsc.langgraph4j.async.AsyncQueue; import org.bsc.langgraph4j.flow.SyncSubmissionPublisher; import org.bsc.langgraph4j.state.AgentState; import org.bsc.langgraph4j.state.AgentStateFactory; @@ -46,40 +48,6 @@ GraphStateException exception(String... args ) { } } - public static CompletableFuture> convertPublisherToStream( Flow.Publisher publisher ) { - - var future = new CompletableFuture>(); - - var list = new ArrayList(); - - publisher.subscribe(new Flow.Subscriber<>() { - - @Override - public void onSubscribe(Flow.Subscription subscription) { - subscription.request(Long.MAX_VALUE); - } - - @Override - public void onNext(T item) { - list.add(item); - } - - @Override - public void onError(Throwable throwable) { - future.completeExceptionally(throwable); - } - - @Override - public void onComplete() { - var result = StreamSupport.stream(Spliterators.spliterator(list, Spliterator.ORDERED), false); - future.complete(result); - - } - }); - - return future; - } - public class Runnable { enum Errors { @@ -156,21 +124,22 @@ private String nextNodeId( String nodeId , State state ) throws Exception { } - public Flow.Publisher> stream( Map inputs ) throws Exception { - var publisher = new SyncSubmissionPublisher>(); + public AsyncIterator> stream( Map inputs ) throws Exception { + + var queue = new AsyncQueue>( java.lang.Runnable::run ); - var executor = Executors.newSingleThreadExecutor(); + var executor = Executors.newSingleThreadExecutor(); - executor.submit(() -> { + executor.submit(() -> { var currentState = stateFactory.apply(inputs); var currentNodeId = entryPoint; Map partialState; - do { - try { + try (queue) { + do { var action = nodes.get(currentNodeId); if (action == null) { - publisher.closeExceptionally(Errors.missingNode.exception(currentNodeId)); + queue.closeExceptionally(Errors.missingNode.exception(currentNodeId)); break; } @@ -178,7 +147,7 @@ public Flow.Publisher> stream( Map inputs ) thr currentState = mergeState(currentState, partialState); - publisher.submit(new NodeOutput<>(currentNodeId, currentState)); + queue.put(new NodeOutput<>(currentNodeId, currentState)); if (Objects.equals(currentNodeId, finishPoint)) { break; @@ -186,24 +155,24 @@ public Flow.Publisher> stream( Map inputs ) thr currentNodeId = nextNodeId(currentNodeId, currentState); - } catch (Exception e) { - publisher.closeExceptionally(e); - break; - } + } while (!Objects.equals(currentNodeId, END)); - } while (!Objects.equals(currentNodeId, END)); + } catch (Exception e) { + queue.closeExceptionally(e); + } - publisher.close(); - }); + }); - return publisher; + return queue; } public Optional invoke( Map inputs ) throws Exception { - var future = convertPublisherToStream(stream(inputs)); + var sourceIterator = stream(inputs).iterator(); - var result = future.get(); + var result = StreamSupport.stream( + Spliterators.spliteratorUnknownSize(sourceIterator, Spliterator.ORDERED), + false); return result.reduce((a, b) -> b).map( NodeOutput::state); } diff --git a/src/main/java/org/bsc/langgraph4j/async/AsyncIterator.java b/src/main/java/org/bsc/langgraph4j/async/AsyncIterator.java index 8e6553e..b4601de 100644 --- a/src/main/java/org/bsc/langgraph4j/async/AsyncIterator.java +++ b/src/main/java/org/bsc/langgraph4j/async/AsyncIterator.java @@ -9,13 +9,25 @@ public interface AsyncIterator extends Iterable { - record Data(T data, boolean done) {} + record Data(T data, boolean done, Throwable error) { + public Data(T data, boolean done) { + this(data, done, null ); + } + public Data(Throwable error) { + this(null, false, error ); + } + } CompletableFuture> next(); default CompletableFuture forEachAsync( final AsyncFunction consumer) { return next().thenCompose(data -> { + if (data.error != null ) { + var error = new CompletableFuture(); + error.completeExceptionally(data.error); + return error; + } if (data.done) { return completedFuture(null); } @@ -35,8 +47,9 @@ public boolean hasNext() { return false; } - return !currentFetchedData.updateAndGet( (v) -> AsyncIterator.this.next().join() ).done; + var next = currentFetchedData.updateAndGet( (v) -> AsyncIterator.this.next().join() ); + return !next.done(); } @Override @@ -47,11 +60,14 @@ public T next() { throw new NoSuchElementException("no more elements into iterator"); } } - if (currentFetchedData.get().done) { + if (currentFetchedData.get().error() != null ) { + throw new IllegalStateException(currentFetchedData.get().error()); + } + if (currentFetchedData.get().done()) { throw new NoSuchElementException("no more elements into iterator"); } - return currentFetchedData.getAndUpdate((v) -> null).data; + return currentFetchedData.getAndUpdate((v) -> null).data(); } }; } diff --git a/src/main/java/org/bsc/langgraph4j/async/AsyncQueue.java b/src/main/java/org/bsc/langgraph4j/async/AsyncQueue.java index cb227e8..a2af80f 100644 --- a/src/main/java/org/bsc/langgraph4j/async/AsyncQueue.java +++ b/src/main/java/org/bsc/langgraph4j/async/AsyncQueue.java @@ -2,6 +2,7 @@ import java.util.Objects; import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicReference; import static java.util.concurrent.CompletableFuture.completedFuture; @@ -9,6 +10,7 @@ public class AsyncQueue implements AsyncIterator, AutoCloseable { private BlockingQueue> queue; private final Executor executor; + private final AtomicReference exception = new AtomicReference<>(); public AsyncQueue() { this(ForkJoinPool.commonPool()); @@ -24,30 +26,36 @@ public AsyncQueue(Executor executor) { * @throws InterruptedException if interrupted while waiting for space to become available */ public void put(E e) throws InterruptedException { - Objects.requireNonNull(queue); + if( exception.get() != null ) { + throw new IllegalStateException("Queue has been closed with exception!"); + } queue.put(new Data<>(e, false)); } + public void closeExceptionally(Throwable ex) { + exception.set(ex); + } + @Override public void close() throws Exception { - Objects.requireNonNull(queue); - queue.put(new Data<>(null, true)); - queue = null; + if (exception.get() != null) { + queue.put(new Data<>(exception.get())); + } + else { + queue.put(new Data<>(null, true)); + } } @Override public CompletableFuture> next() { - // queue has been closed - if( queue == null ) { -// final var result = new CompletableFuture>(); -// result.completeExceptionally( new IllegalStateException("Queue has been closed")); -// return result; - return completedFuture(new Data<>(null, true)); - } return CompletableFuture.supplyAsync( () -> { try { - return queue.take(); + var result = queue.take(); + if( result.error() != null ) { + queue = null; + } + return result; } catch (InterruptedException e) { throw new RuntimeException(e); } diff --git a/src/test/java/org/bsc/langgraph4j/AsyncTest.java b/src/test/java/org/bsc/langgraph4j/AsyncTest.java index eda38d5..f269e24 100644 --- a/src/test/java/org/bsc/langgraph4j/AsyncTest.java +++ b/src/test/java/org/bsc/langgraph4j/AsyncTest.java @@ -4,10 +4,9 @@ import org.bsc.langgraph4j.async.AsyncQueue; import org.junit.jupiter.api.Test; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; +import java.util.*; import java.util.concurrent.CompletableFuture; +import java.util.stream.StreamSupport; import static java.util.concurrent.CompletableFuture.completedFuture; import static java.util.concurrent.ForkJoinPool.commonPool; @@ -106,4 +105,86 @@ public void asyncQueueDirectTest() throws Exception { } + @Test + public void asyncQueueToStreamTest() throws Exception { + + // AsyncQueue initialized with a direct executor. No thread is used on next() invocation + final var queue = new AsyncQueue(Runnable::run); + + commonPool().execute( () -> { + try(queue) { + for( int i = 0 ; i < 10 ; ++i ) { + queue.put( "e"+i ); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + + }); + + var sourceIterator = queue.iterator(); + + var result = StreamSupport.stream( + Spliterators.spliteratorUnknownSize(sourceIterator, Spliterator.ORDERED), + false); + + var lastElement = result.reduce((a, b) -> b); + + assertTrue( lastElement.isPresent()); + assertEquals( lastElement.get(), "e9" ); + + } + + @Test + public void asyncQueueIteratorExceptionTest() throws Exception { + + // AsyncQueue initialized with a direct executor. No thread is used on next() invocation + final var queue = new AsyncQueue(Runnable::run); + + commonPool().execute( () -> { + try(queue) { + for( int i = 0 ; i < 2 ; ++i ) { + queue.put( "e"+i ); + } + queue.closeExceptionally(new Exception("test")); + + } catch (Exception e) { + queue.closeExceptionally(e); + } + + }); + + var sourceIterator = queue.iterator(); + + var result = StreamSupport.stream( + Spliterators.spliteratorUnknownSize(sourceIterator, Spliterator.ORDERED), + false); + + assertThrows( Exception.class, () -> result.reduce((a, b) -> b )); + + } + + @Test + public void asyncQueueForEachExceptionTest() throws Exception { + + // AsyncQueue initialized with a direct executor. No thread is used on next() invocation + final var queue = new AsyncQueue(Runnable::run); + + commonPool().execute( () -> { + try(queue) { + for( int i = 0 ; i < 2 ; ++i ) { + queue.put( "e"+i ); + } + queue.closeExceptionally(new Exception("test")); + + } catch (Exception e) { + queue.closeExceptionally(e); + } + + }); + + assertThrows( Exception.class, () -> queue.forEachAsync( consumer_async(System.out::println) ).get() ); + + } + }