From 7dd851cc9a632848609898c97561ec38a5900449 Mon Sep 17 00:00:00 2001 From: bsorrentino Date: Thu, 28 Mar 2024 00:17:11 +0100 Subject: [PATCH] feat: finalize developing langchain4j agentexecutor using langgraph4j --- .../src/main/java/dev/langchain4j/Agent.java | 60 ++++--- .../java/dev/langchain4j/AgentExecutor.java | 166 ++++++++++++++++-- .../{AIConfig.java => DotEnvConfig.java} | 26 ++- .../main/java/dev/langchain4j/ToolInfo.java | 43 +++++ .../dev/langchain4j/AgentExecutorTest.java | 43 +++++ .../test/java/dev/langchain4j/AgentTest.java | 32 ++-- .../test/java/dev/langchain4j/TestTool.java | 23 +++ .../java/org/bsc/langgraph4j/GraphState.java | 26 ++- .../org/bsc/langgraph4j/async/AsyncQueue.java | 86 ++++++--- .../org/bsc/langgraph4j/state/AgentState.java | 13 +- .../langgraph4j/state/AgentStateFactory.java | 2 - .../langgraph4j/state/AppendableValue.java | 24 ++- .../java/org/bsc/langgraph4j/AsyncTest.java | 27 ++- 13 files changed, 454 insertions(+), 117 deletions(-) rename agents/src/main/java/dev/langchain4j/{AIConfig.java => DotEnvConfig.java} (61%) create mode 100644 agents/src/main/java/dev/langchain4j/ToolInfo.java create mode 100644 agents/src/test/java/dev/langchain4j/AgentExecutorTest.java create mode 100644 agents/src/test/java/dev/langchain4j/TestTool.java diff --git a/agents/src/main/java/dev/langchain4j/Agent.java b/agents/src/main/java/dev/langchain4j/Agent.java index 16a932a..2eaf56b 100644 --- a/agents/src/main/java/dev/langchain4j/Agent.java +++ b/agents/src/main/java/dev/langchain4j/Agent.java @@ -1,44 +1,28 @@ package dev.langchain4j; -import dev.langchain4j.agent.tool.Tool; import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.data.message.SystemMessage; import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.data.message.ToolExecutionResultMessage; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.input.PromptTemplate; import dev.langchain4j.model.output.Response; -import dev.langchain4j.service.AiServices; import lombok.Builder; import lombok.Singular; -import java.lang.reflect.Method; -import java.util.*; - -import static dev.langchain4j.agent.tool.ToolSpecifications.toolSpecificationFrom; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; @Builder public class Agent { private final ChatLanguageModel chatLanguageModel; - @Singular private final List tools; - + @Singular private final List tools; - private List getToolSpecifications() { - var toolSpecifications = new ArrayList(); - for (Object tool : tools) { - for (Method method : tool.getClass().getDeclaredMethods()) { - if (method.isAnnotationPresent(Tool.class)) { - var toolSpecification = toolSpecificationFrom(method); - toolSpecifications.add(toolSpecification); - //context.toolExecutors.put(toolSpecification.name(), new DefaultToolExecutor(objectWithTool, method)); - } - } - } - return toolSpecifications; - } public Response execute( Map inputs ) { var messages = new ArrayList(); var promptTemplate = PromptTemplate.from( "USER: {{input}}" ).apply(inputs); @@ -47,6 +31,38 @@ public Response execute( Map inputs ) { messages.add( new UserMessage(promptTemplate.text()) ); - return chatLanguageModel.generate( messages, getToolSpecifications() ); + return chatLanguageModel.generate( messages, tools ); + } + + private PromptTemplate getToolResponseTemplate( ) { + var TEMPLATE_TOOL_RESPONSE = new StringBuilder() + .append("TOOL RESPONSE:").append('\n') + .append("---------------------").append('\n') + .append("{{observation}}").append('\n') + .append( "--------------------" ).append('\n') + .append('\n') + .toString(); + return PromptTemplate.from(TEMPLATE_TOOL_RESPONSE); + } + + public Response execute( String input, List intermediateSteps ) { + var agentScratchpadTemplate = getToolResponseTemplate(); + var userMessageTemplate = PromptTemplate.from( "USER'S INPUT: {{input}}" ).apply( Map.of( "input", input)); + + var messages = new ArrayList(); + + messages.add(new SystemMessage("You are a helpful assistant")); + + if( intermediateSteps.isEmpty()) { + messages.add(new UserMessage(userMessageTemplate.text())); + } + + for( AgentExecutor.IntermediateStep step: intermediateSteps ) { + var agentScratchpad = agentScratchpadTemplate.apply( Map.of("observation", step.observation()) ); + messages.add(new UserMessage(agentScratchpad.text())); + ; + } + + return chatLanguageModel.generate( messages, tools ); } } diff --git a/agents/src/main/java/dev/langchain4j/AgentExecutor.java b/agents/src/main/java/dev/langchain4j/AgentExecutor.java index 5d7deb9..7ed074c 100644 --- a/agents/src/main/java/dev/langchain4j/AgentExecutor.java +++ b/agents/src/main/java/dev/langchain4j/AgentExecutor.java @@ -1,42 +1,170 @@ package dev.langchain4j; +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.model.openai.OpenAiChatModel; +import dev.langchain4j.model.output.FinishReason; +import org.bsc.langgraph4j.GraphState; +import org.bsc.langgraph4j.async.AsyncIterator; import org.bsc.langgraph4j.state.AgentState; +import org.bsc.langgraph4j.state.AppendableValue; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Optional; +import java.util.*; -import static java.util.Optional.ofNullable; +import static org.bsc.langgraph4j.GraphState.END; +import static org.bsc.langgraph4j.action.AsyncEdgeAction.edge_async; +import static org.bsc.langgraph4j.action.AsyncNodeAction.node_async; public class AgentExecutor { - record AgentAction ( String tool, String tool_input, String log ) {} - record AgentFinish ( Map return_values, String log ) {} + record AgentAction ( ToolExecutionRequest toolExecutionRequest, String log ) { + public AgentAction { + Objects.requireNonNull(toolExecutionRequest); + } + } + record AgentFinish ( Map returnValues, String log ) {} record AgentOutcome(AgentAction action, AgentFinish finish) {} - record IntermediateStep(AgentAction action, String observation) { + public record IntermediateStep(AgentAction action, String observation) { } - record BaseAgentState( Map data ) implements AgentState { + public static class State implements AgentState { + + private final Map data; - BaseAgentState(Map data) { - this.data = data; - data.put("intermediate_steps", new ArrayList<>() ); + public State( Map initData ) { + this.data = new HashMap<>(initData); + if( !data.containsKey("intermediate_steps")) { + this.data.put("intermediate_steps", + new AppendableValue()); + } } + + public Map data() { + return Map.copyOf(data); + } + Optional input() { - return ofNullable((String) data.get("input")); + return value("input"); } - Optional agent_outcome() { - return ofNullable((AgentOutcome) data.get("agent_outcome")); + Optional agentOutcome() { + return value("agent_outcome"); } - List intermediate_steps() { - return (List) data.get("intermediate_steps"); + Optional> intermediateSteps() { + return appendableValue("intermediate_steps"); + } + + @Override + public String toString() { + return data.toString(); + } + } + + Map runAgent( Agent agentRunnable, State state ) throws Exception { + + var input = state.input() + .orElseThrow(() -> new IllegalArgumentException("no input provided!")); + var intermediateSteps = state.intermediateSteps() + .orElseThrow(() -> new IllegalArgumentException("no intermediateSteps provided!")); + + var response = agentRunnable.execute( input, intermediateSteps ); + + if( response.finishReason() == FinishReason.TOOL_EXECUTION ) { + + var toolExecutionRequest = response.content().toolExecutionRequests(); + if (toolExecutionRequest.size() != 1) { + throw new IllegalStateException("unexpected number of tool execution requests: " + toolExecutionRequest.size()); + } + var action = new AgentAction( toolExecutionRequest.get(0), ""); + + return Map.of("agent_outcome", new AgentOutcome( action, null ) ); + + } + else { + var result = response.content().text(); + var finish = new AgentFinish( Map.of("returnValues", result), result ); + + return Map.of("agent_outcome", new AgentOutcome( null, finish ) ); + } + + } + Map executeTools( List toolInfoList, State state ) throws Exception { + + var agentOutcome = state.agentOutcome().orElseThrow(() -> new IllegalArgumentException("no agentOutcome provided!")); + + if (agentOutcome.action() == null) { + throw new IllegalStateException("no action provided!" ); } + var toolExecutionRequest = agentOutcome.action().toolExecutionRequest(); + + var tool = toolInfoList.stream() + .filter( v -> v.specification().name().equals(toolExecutionRequest.name())) + .findFirst() + .orElseThrow(() -> new IllegalStateException("no tool found for: " + toolExecutionRequest.name())); + + var result = tool.executor().execute( toolExecutionRequest, null ); + + return Map.of("intermediate_steps", new IntermediateStep( agentOutcome.action(), result ) ); + + } + + String shouldContinue(State state) { + + if (state.agentOutcome().map(AgentOutcome::finish).isPresent()) { + return "end"; + } + return "continue"; } - public static void execute() { - System.out.println("Hello World!"); + public AsyncIterator> execute(Map inputs, List objectsWithTools) throws Exception { + + + var openApiKey = DotEnvConfig.valueOf("OPENAI_API_KEY") + .orElseThrow( () -> new IllegalArgumentException("no APIKEY provided!")); + + var chatLanguageModel = OpenAiChatModel.builder() + .apiKey( openApiKey ) + .modelName( "gpt-3.5-turbo-0613" ) + .logResponses(true) + .maxRetries(2) + .temperature(0.0) + .maxTokens(2000) + .build(); + + var toolInfoList = ToolInfo.fromList( objectsWithTools ); + + final List toolSpecifications = toolInfoList.stream() + .map(ToolInfo::specification) + .toList(); + + var agentRunnable = Agent.builder() + .chatLanguageModel(chatLanguageModel) + .tools( toolSpecifications ) + .build(); + + var workflow = new GraphState<>(State::new); + + workflow.setEntryPoint("agent"); + + workflow.addNode( "agent", node_async( state -> + runAgent(agentRunnable, state)) + ); + + workflow.addNode( "action", node_async( state -> + executeTools(toolInfoList, state)) + ); + + workflow.addConditionalEdge( + "agent", + edge_async(this::shouldContinue), + Map.of("continue", "action", "end", END) + ); + + workflow.addEdge("action", "agent"); + + var app = workflow.compile(); + + return app.stream( inputs ); } } diff --git a/agents/src/main/java/dev/langchain4j/AIConfig.java b/agents/src/main/java/dev/langchain4j/DotEnvConfig.java similarity index 61% rename from agents/src/main/java/dev/langchain4j/AIConfig.java rename to agents/src/main/java/dev/langchain4j/DotEnvConfig.java index 5fe222c..7745640 100644 --- a/agents/src/main/java/dev/langchain4j/AIConfig.java +++ b/agents/src/main/java/dev/langchain4j/DotEnvConfig.java @@ -2,13 +2,14 @@ import java.io.FileReader; import java.io.Reader; -import java.nio.file.Path; import java.nio.file.Paths; import java.util.Optional; -public class AIConfig { +public interface DotEnvConfig { - public static AIConfig load() { + static void load() { + + // Search for .env file var path = Paths.get(".").toAbsolutePath(); var filePath = Paths.get( path.toString(), ".env"); @@ -23,24 +24,21 @@ public static AIConfig load() { } } + // load .env contents in System.properties try { - return new AIConfig( filePath ); + final java.util.Properties properties = new java.util.Properties(); + + try( Reader r = new FileReader(filePath.toFile())) { + properties.load(r); + } + System.getProperties().putAll(properties); } catch (Exception e) { throw new RuntimeException(e); } } - private AIConfig( Path envFilePath ) throws Exception { - final java.util.Properties properties = new java.util.Properties(); - - try( Reader r = new FileReader(envFilePath.toFile())) { - properties.load(r); - } - System.getProperties().putAll(properties); - } - - public Optional valueOf(String key ) { + static Optional valueOf(String key ) { return Optional.ofNullable(System.getenv( key )) .or( () -> Optional.ofNullable(System.getProperty(key))); } diff --git a/agents/src/main/java/dev/langchain4j/ToolInfo.java b/agents/src/main/java/dev/langchain4j/ToolInfo.java new file mode 100644 index 0000000..a15f49a --- /dev/null +++ b/agents/src/main/java/dev/langchain4j/ToolInfo.java @@ -0,0 +1,43 @@ +package dev.langchain4j; + +import dev.langchain4j.agent.tool.DefaultToolExecutor; +import dev.langchain4j.agent.tool.Tool; +import dev.langchain4j.agent.tool.ToolExecutor; +import dev.langchain4j.agent.tool.ToolSpecification; + +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +import static dev.langchain4j.agent.tool.ToolSpecifications.toolSpecificationFrom; + +public record ToolInfo(ToolSpecification specification, ToolExecutor executor ) { + + public ToolInfo { + Objects.requireNonNull(specification); + Objects.requireNonNull(executor); + } + + public static List of( Object ...objectsWithTools) { + return fromArray( (Object[])objectsWithTools ); + } + public static List fromArray( Object[] objectsWithTools ) { + var toolSpecifications = new ArrayList(); + + for (Object objectWithTools : objectsWithTools) { + for (Method method : objectWithTools.getClass().getDeclaredMethods()) { + if (method.isAnnotationPresent(Tool.class)) { + var toolSpecification = toolSpecificationFrom(method); + var executor = new DefaultToolExecutor(objectWithTools, method); + toolSpecifications.add( new ToolInfo( toolSpecification, executor)); + } + } + } + return List.copyOf(toolSpecifications); + } + public static List fromList(List objectsWithTools ) { + return fromArray(objectsWithTools.toArray()); + } + +} \ No newline at end of file diff --git a/agents/src/test/java/dev/langchain4j/AgentExecutorTest.java b/agents/src/test/java/dev/langchain4j/AgentExecutorTest.java new file mode 100644 index 0000000..3d3dd27 --- /dev/null +++ b/agents/src/test/java/dev/langchain4j/AgentExecutorTest.java @@ -0,0 +1,43 @@ +package dev.langchain4j; + +import dev.langchain4j.agent.tool.P; +import dev.langchain4j.agent.tool.Tool; + +import java.util.List; +import java.util.Map; + +import static java.lang.String.format; +import static java.util.Optional.ofNullable; + +public class AgentExecutorTest { + + public static void main( String[] args) { + + DotEnvConfig.load(); + + var agentExecutor = new AgentExecutor(); + + try { + var iterator = agentExecutor.execute( + Map.of( "input", "what is the result of test with message: 'MY FIRST TEST'?"), + List.of(new TestTool()) ); + + AgentExecutor.State output = null; + + for( var i : iterator ) { + output = i.state(); + System.out.println(i.node()); + } + System.out.println( "Finished! " + ofNullable(output) + .flatMap(AgentExecutor.State::agentOutcome) + .map(AgentExecutor.AgentOutcome::finish) + .map(AgentExecutor.AgentFinish::returnValues) + .orElse(Map.of( "result", "state undefined!")) ); + + } catch (Exception e) { + System.out.println( "ERROR! " + e.getMessage() ); + } + System.exit(0); + + } +} diff --git a/agents/src/test/java/dev/langchain4j/AgentTest.java b/agents/src/test/java/dev/langchain4j/AgentTest.java index 0aa6ce9..eb67a63 100644 --- a/agents/src/test/java/dev/langchain4j/AgentTest.java +++ b/agents/src/test/java/dev/langchain4j/AgentTest.java @@ -2,9 +2,8 @@ import dev.langchain4j.agent.tool.P; import dev.langchain4j.agent.tool.Tool; -import dev.langchain4j.model.input.PromptTemplate; import dev.langchain4j.model.output.FinishReason; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.BeforeAll; import dev.langchain4j.model.openai.OpenAiChatModel; import java.util.Map; @@ -15,29 +14,18 @@ public class AgentTest { - static class TestTool { - - private String lastMessage; - - Optional lastMessage() { - return Optional.ofNullable(lastMessage); - } - - @Tool("tool for test AI system") - String execTest(@P("test message") String message) { - - lastMessage = format( "test tool executed: %s", message); - return lastMessage; - } + @BeforeAll + static void init() { + DotEnvConfig.load(); } - @Test - void agentCreationTest() throws Exception { - var config = AIConfig.load(); - assertTrue(config.valueOf("OPENAI_API_KEY").isPresent()); + public static void main( String[] args) throws Exception { + DotEnvConfig.load(); + + assertTrue(DotEnvConfig.valueOf("OPENAI_API_KEY").isPresent()); var chatLanguageModel = OpenAiChatModel.builder() - .apiKey( config.valueOf("OPENAI_API_KEY").get() ) + .apiKey( DotEnvConfig.valueOf("OPENAI_API_KEY").get() ) .modelName( "gpt-3.5-turbo-0613" ) .logResponses(true) .maxRetries(2) @@ -48,7 +36,7 @@ void agentCreationTest() throws Exception { var tool = new TestTool(); var agent = Agent.builder() .chatLanguageModel(chatLanguageModel) - .tool(tool) + .tools( ToolInfo.of(tool).stream().map(ToolInfo::specification).toList() ) .build(); var msg = "hello world"; diff --git a/agents/src/test/java/dev/langchain4j/TestTool.java b/agents/src/test/java/dev/langchain4j/TestTool.java new file mode 100644 index 0000000..5812b4c --- /dev/null +++ b/agents/src/test/java/dev/langchain4j/TestTool.java @@ -0,0 +1,23 @@ +package dev.langchain4j; + +import dev.langchain4j.agent.tool.P; +import dev.langchain4j.agent.tool.Tool; + +import java.util.Optional; + +import static java.lang.String.format; + +public class TestTool { + private String lastResult; + + Optional lastResult() { + return Optional.ofNullable(lastResult); + } + + @Tool("tool for test AI agent executor") + String execTest(@P("test message") String message) { + + lastResult = format( "test tool executed: %s", message); + return lastResult; + } +} diff --git a/core/src/main/java/org/bsc/langgraph4j/GraphState.java b/core/src/main/java/org/bsc/langgraph4j/GraphState.java index c5169ad..dbbff12 100644 --- a/core/src/main/java/org/bsc/langgraph4j/GraphState.java +++ b/core/src/main/java/org/bsc/langgraph4j/GraphState.java @@ -6,6 +6,7 @@ import org.bsc.langgraph4j.async.AsyncQueue; import org.bsc.langgraph4j.state.AgentState; import org.bsc.langgraph4j.state.AgentStateFactory; +import org.bsc.langgraph4j.state.AppendableValue; import java.util.*; import java.util.concurrent.Executors; @@ -70,6 +71,8 @@ public record NodeOutput( String node, State state) {} final Map> nodes = new HashMap<>(); final Map> edges = new HashMap<>(); + private final int maxIterations = 25; + Runnable() { GraphState.this.nodes.forEach( n -> @@ -81,6 +84,13 @@ public record NodeOutput( String node, State state) {} ); } + private Object mergeFunction(Object currentValue, Object newValue) { + if (currentValue instanceof AppendableValue ) { + ((AppendableValue) currentValue).append( newValue ); + return currentValue; + } + return newValue; + } private State mergeState( State currentState, Map partialState) { Objects.requireNonNull(currentState, "currentState"); @@ -91,7 +101,7 @@ private State mergeState( State currentState, Map partialState) { .collect(Collectors.toMap( Map.Entry::getKey, Map.Entry::getValue, - (oldValue, newValue) -> newValue)); + this::mergeFunction)); return stateFactory.apply(mergedMap); } @@ -114,9 +124,10 @@ private String nextNodeId( String nodeId , State state ) throws Exception { if( result == null ) { throw Errors.missingNodeInEdgeMapping.exception(nodeId, newRoute); } + return result; } - throw Errors.executionError.exception( format("invalid edge value for nodeId: %s !", nodeId) ); + throw Errors.executionError.exception( format("invalid edge value for nodeId: [%s] !", nodeId) ); } @@ -132,8 +143,8 @@ public AsyncIterator> stream( Map inputs ) thro var currentNodeId = entryPoint; Map partialState; - try (queue) { - do { + try { + for( int i = 0; i < maxIterations && !Objects.equals(currentNodeId, END); ++i ) { var action = nodes.get(currentNodeId); if (action == null) { queue.closeExceptionally(Errors.missingNode.exception(currentNodeId)); @@ -152,11 +163,14 @@ public AsyncIterator> stream( Map inputs ) thro currentNodeId = nextNodeId(currentNodeId, currentState); - } while (!Objects.equals(currentNodeId, END)); + } - } catch (Exception e) { + } catch (Throwable e) { queue.closeExceptionally(e); } + finally { + queue.close(); + } }); diff --git a/core/src/main/java/org/bsc/langgraph4j/async/AsyncQueue.java b/core/src/main/java/org/bsc/langgraph4j/async/AsyncQueue.java index c7d6ffd..15c72a2 100644 --- a/core/src/main/java/org/bsc/langgraph4j/async/AsyncQueue.java +++ b/core/src/main/java/org/bsc/langgraph4j/async/AsyncQueue.java @@ -2,56 +2,96 @@ import java.util.Objects; import java.util.concurrent.*; -import java.util.concurrent.atomic.AtomicReference; -import static java.util.concurrent.CompletableFuture.completedFuture; +import static java.lang.String.format; -public class AsyncQueue implements AsyncIterator, AutoCloseable { +public class AsyncQueue implements AsyncIterator { - private BlockingQueue> queue; + record Item( Data data, Throwable error ) { + + boolean isEnd() { return data.done(); } + boolean isError() { + return error != null; + } + + static Item of(Data data) { + return new Item<>(data, null); + } + + static Item of(Throwable error) { + return new Item<>(null, error); + } + } + private BlockingQueue> queue; private final Executor executor; - private final AtomicReference exception = new AtomicReference<>(); + + final long timeout; + final TimeUnit timeoutUnit; public AsyncQueue() { - this(ForkJoinPool.commonPool()); + this(ForkJoinPool.commonPool(), 60, TimeUnit.SECONDS); } - public AsyncQueue(Executor executor) { + this(executor, 60, TimeUnit.SECONDS); + } + public AsyncQueue(Executor executor, long timeout, TimeUnit timeoutUnit) { + Objects.requireNonNull(executor); + Objects.requireNonNull(timeoutUnit); + queue = new SynchronousQueue<>(); this.executor = executor; + this.timeout = timeout; + this.timeoutUnit = timeoutUnit; } /** * Inserts the specified element into this queue, waiting if necessary for space to become available. * @param e Element to be inserted * @throws InterruptedException if interrupted while waiting for space to become available */ - public void put(E e) throws InterruptedException { + public final void put(E e) throws InterruptedException { Objects.requireNonNull(queue); - queue.put(new Data<>(e, false)); + queue.put( Item.of( new Data<>(e, false) ) ); } - public void closeExceptionally(Throwable ex) { - Objects.requireNonNull(queue); - exception.set(ex); + public boolean closeExceptionally(Throwable ex) { + if( queue == null ) { + return false; + } + try { + return queue.offer( Item.of(ex), timeout, timeoutUnit); + } catch (InterruptedException e) { + return false; + } } - @Override - public void close() throws Exception { - Objects.requireNonNull(queue); - queue.put(new Data<>(null, true)); - queue = null; + public final boolean close() { + if( queue == null ) { + return false; + } + try { + return queue.offer( Item.of(new Data<>(null, true) ), timeout, timeoutUnit); + } catch (InterruptedException e) { + return false; + } } - @Override - public CompletableFuture> next() { + public final CompletableFuture> next() { return CompletableFuture.supplyAsync( () -> { try { - var result = queue.take(); - if( exception.get()!=null ) { - throw new RuntimeException(exception.get()); + var result = queue.poll(timeout, timeoutUnit); + if( result == null ) { + queue = null; + throw new RuntimeException( format("queue exceed the poll timeout %d %s", timeout, timeoutUnit) ); + } + if (result.isError()) { + queue = null; + throw new RuntimeException(result.error); + } + if( result.isEnd() ) { + queue = null; } - return result; + return result.data(); } catch (InterruptedException e) { throw new RuntimeException(e); } diff --git a/core/src/main/java/org/bsc/langgraph4j/state/AgentState.java b/core/src/main/java/org/bsc/langgraph4j/state/AgentState.java index 69345f6..c769fba 100644 --- a/core/src/main/java/org/bsc/langgraph4j/state/AgentState.java +++ b/core/src/main/java/org/bsc/langgraph4j/state/AgentState.java @@ -1,12 +1,21 @@ package org.bsc.langgraph4j.state; +import java.util.List; import java.util.Optional; +import static java.util.Optional.ofNullable; + public interface AgentState { java.util.Map data(); - default Optional getValue(String key) { - return Optional.ofNullable((T) data().get(key)); + default Optional value(String key) { + return ofNullable((T) data().get(key)); }; + + default Optional> appendableValue(String key ) { + return ofNullable( ((AppendableValue)data().get(key))) + .map( ( v ) -> v.values ); + + } } diff --git a/core/src/main/java/org/bsc/langgraph4j/state/AgentStateFactory.java b/core/src/main/java/org/bsc/langgraph4j/state/AgentStateFactory.java index c145c60..a5aac40 100644 --- a/core/src/main/java/org/bsc/langgraph4j/state/AgentStateFactory.java +++ b/core/src/main/java/org/bsc/langgraph4j/state/AgentStateFactory.java @@ -1,8 +1,6 @@ package org.bsc.langgraph4j.state; -import org.bsc.langgraph4j.state.AgentState; - import java.util.Map; import java.util.function.Function; diff --git a/core/src/main/java/org/bsc/langgraph4j/state/AppendableValue.java b/core/src/main/java/org/bsc/langgraph4j/state/AppendableValue.java index 59088b9..4576b9a 100644 --- a/core/src/main/java/org/bsc/langgraph4j/state/AppendableValue.java +++ b/core/src/main/java/org/bsc/langgraph4j/state/AppendableValue.java @@ -1,6 +1,28 @@ package org.bsc.langgraph4j.state; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; import java.util.List; -public record AppendableValue(List values ) { +public class AppendableValue { + final List values; + public AppendableValue( List values) { + this.values = new ArrayList<>(values); + } + public AppendableValue() { + this(Collections.emptyList()); + } + public void append(Object value) { + if (value instanceof Collection ) { + this.values.addAll((Collection) value); + } + else { + this.values.add((T)value); + } + } + + public String toString() { + return String.valueOf(values); + } } diff --git a/core/src/test/java/org/bsc/langgraph4j/AsyncTest.java b/core/src/test/java/org/bsc/langgraph4j/AsyncTest.java index f269e24..05a5da1 100644 --- a/core/src/test/java/org/bsc/langgraph4j/AsyncTest.java +++ b/core/src/test/java/org/bsc/langgraph4j/AsyncTest.java @@ -52,8 +52,8 @@ public CompletableFuture> next() { public void asyncQueueTest() throws Exception { var result = new ArrayList(); - - try ( final var queue = new AsyncQueue() ) { + final var queue = new AsyncQueue(); + try { queue.forEachAsync( consumer_async(result::add)).thenAccept( (t) -> { System.out.println( "Finished"); @@ -64,6 +64,9 @@ public void asyncQueueTest() throws Exception { } } + finally { + queue.close(); + } assertEquals(result.size(), 10); assertIterableEquals(result, List.of("e0", "e1", "e2", "e3", "e4", "e5", "e6", "e7", "e8", "e9")); @@ -77,13 +80,16 @@ public void asyncQueueDirectTest() throws Exception { final var queue = new AsyncQueue(Runnable::run); commonPool().execute( () -> { - try(queue) { + try { for( int i = 0 ; i < 10 ; ++i ) { queue.put( "e"+i ); } } catch (Exception e) { throw new RuntimeException(e); } + finally { + queue.close(); + } }); @@ -112,13 +118,16 @@ public void asyncQueueToStreamTest() throws Exception { final var queue = new AsyncQueue(Runnable::run); commonPool().execute( () -> { - try(queue) { + try { for( int i = 0 ; i < 10 ; ++i ) { queue.put( "e"+i ); } } catch (Exception e) { throw new RuntimeException(e); } + finally { + queue.close(); + } }); @@ -142,7 +151,7 @@ public void asyncQueueIteratorExceptionTest() throws Exception { final var queue = new AsyncQueue(Runnable::run); commonPool().execute( () -> { - try(queue) { + try { for( int i = 0 ; i < 2 ; ++i ) { queue.put( "e"+i ); } @@ -151,6 +160,9 @@ public void asyncQueueIteratorExceptionTest() throws Exception { } catch (Exception e) { queue.closeExceptionally(e); } + finally { + queue.close(); + } }); @@ -171,7 +183,7 @@ public void asyncQueueForEachExceptionTest() throws Exception { final var queue = new AsyncQueue(Runnable::run); commonPool().execute( () -> { - try(queue) { + try { for( int i = 0 ; i < 2 ; ++i ) { queue.put( "e"+i ); } @@ -180,6 +192,9 @@ public void asyncQueueForEachExceptionTest() throws Exception { } catch (Exception e) { queue.closeExceptionally(e); } + finally { + queue.close(); + } });