From 7e19f1e8fc6e731a8def851e81100f05e752dbcf Mon Sep 17 00:00:00 2001 From: bsorrentino Date: Tue, 14 May 2024 01:34:10 +0200 Subject: [PATCH] refactor(core-jdk8): Agent State Management - AgentState from interface to concrete class - AppendableValue a readonly interface - Create internal AppendableValueRW to update state --- .../java/org/bsc/langgraph4j/GraphState.java | 26 +------- .../org/bsc/langgraph4j/state/AgentState.java | 65 ++++++++++++++++--- .../langgraph4j/state/AppendableValue.java | 34 +++------- .../langgraph4j/state/AppendableValueRW.java | 50 ++++++++++++++ .../langgraph4j/utils/CollectionsUtils.java | 7 ++ .../org/bsc/langgraph4j/BaseAgentState.java | 15 ----- .../org/bsc/langgraph4j/LangGraphTest.java | 5 +- 7 files changed, 126 insertions(+), 76 deletions(-) create mode 100644 core-jdk8/src/main/java/org/bsc/langgraph4j/state/AppendableValueRW.java delete mode 100644 core-jdk8/src/test/java/org/bsc/langgraph4j/BaseAgentState.java diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/GraphState.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/GraphState.java index e1d6091..b94f038 100644 --- a/core-jdk8/src/main/java/org/bsc/langgraph4j/GraphState.java +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/GraphState.java @@ -7,12 +7,9 @@ import org.bsc.langgraph4j.action.AsyncNodeAction; import org.bsc.langgraph4j.state.AgentState; import org.bsc.langgraph4j.state.AgentStateFactory; -import org.bsc.langgraph4j.state.AppendableValue; import java.util.*; import java.util.concurrent.LinkedBlockingQueue; -import java.util.stream.Collectors; -import java.util.stream.Stream; import java.util.stream.StreamSupport; import static java.lang.String.format; @@ -80,27 +77,6 @@ public class Runnable { ); } - private Object mergeFunction(Object currentValue, Object newValue) { - if (currentValue instanceof AppendableValue ) { - ((AppendableValue) currentValue).append( newValue ); - return currentValue; - } - return newValue; - } - private State mergeState( State currentState, Map partialState) { - Objects.requireNonNull(currentState, "currentState"); - - if( partialState == null || partialState.isEmpty() ) { - return currentState; - } - var mergedMap = Stream.concat(currentState.data().entrySet().stream(), partialState.entrySet().stream()) - .collect(Collectors.toMap( - Map.Entry::getKey, - Map.Entry::getValue, - this::mergeFunction)); - - return stateFactory.apply(mergedMap); - } private String nextNodeId( String nodeId , State state ) throws Exception { @@ -145,7 +121,7 @@ public AsyncGenerator> stream(Map inputs ) thro partialState = action.apply(currentState).get(); - currentState = mergeState(currentState, partialState); + currentState = currentState.mergeWith(partialState, stateFactory); var data = new NodeOutput<>(currentNodeId, currentState); diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/state/AgentState.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/state/AgentState.java index a8977e5..c16f862 100644 --- a/core-jdk8/src/main/java/org/bsc/langgraph4j/state/AgentState.java +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/state/AgentState.java @@ -1,21 +1,70 @@ package org.bsc.langgraph4j.state; -import java.util.List; -import java.util.Optional; +import lombok.var; +import java.util.*; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static java.util.Collections.unmodifiableMap; import static java.util.Optional.ofNullable; -public interface AgentState { +public class AgentState { + + private final java.util.Map data; - java.util.Map data(); + public AgentState( Map initData ) { + this.data = new HashMap<>(initData); + } + public final java.util.Map data() { + return unmodifiableMap(data); + } - default Optional value(String key) { + public final Optional value(String key) { return ofNullable((T) data().get(key)); }; - default Optional> appendableValue(String key ) { - return ofNullable( ((AppendableValue)data().get(key))) - .map(AppendableValue::values); + public final AppendableValue appendableValue(String key ) { + Object value = this.data.get(key); + + if( value instanceof AppendableValue ) { + return (AppendableValue) value; + } + if( value instanceof Collection) { + return new AppendableValueRW<>((Collection)value); + } + AppendableValueRW rw = new AppendableValueRW<>(); + if ( value != null ) { + rw.append(value); + } + this.data.put(key, rw); + return rw; + + } + + private Object mergeFunction(Object currentValue, Object newValue) { + if (currentValue instanceof AppendableValueRW) { + ((AppendableValueRW) currentValue).append( newValue ); + return currentValue; + } + return newValue; + } + public State mergeWith(Map partialState, AgentStateFactory factory) { + + if( partialState == null || partialState.isEmpty() ) { + return factory.apply(data()); + } + var mergedMap = Stream.concat(data().entrySet().stream(), partialState.entrySet().stream()) + .collect(Collectors.toMap( + Map.Entry::getKey, + Map.Entry::getValue, + this::mergeFunction)); + + return factory.apply(mergedMap); + } + @Override + public String toString() { + return data.toString(); } } diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/state/AppendableValue.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/state/AppendableValue.java index 69a068e..d2205f5 100644 --- a/core-jdk8/src/main/java/org/bsc/langgraph4j/state/AppendableValue.java +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/state/AppendableValue.java @@ -1,34 +1,16 @@ package org.bsc.langgraph4j.state; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.List; +import java.util.*; import static java.util.Collections.unmodifiableList; -public class AppendableValue { - private final List values; - public AppendableValue( List values) { - this.values = new ArrayList<>(values); - } - public AppendableValue() { - this(Collections.emptyList()); - } +public interface AppendableValue { - public List values() { - return unmodifiableList(values); - } - public void append(Object value) { - if (value instanceof Collection ) { - this.values.addAll((Collection) value); - } - else { - this.values.add((T)value); - } - } + List values(); - public String toString() { - return String.valueOf(values); - } + boolean isEmpty() ; + int size() ; + + Optional last() ; + Optional lastMinus( int n ) ; } diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/state/AppendableValueRW.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/state/AppendableValueRW.java new file mode 100644 index 0000000..36f0485 --- /dev/null +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/state/AppendableValueRW.java @@ -0,0 +1,50 @@ +package org.bsc.langgraph4j.state; + +import java.util.*; + +import static java.util.Collections.unmodifiableList; + +public class AppendableValueRW implements AppendableValue { + private final List values; + + public AppendableValueRW( Collection values) { + this.values = new ArrayList<>(values); + } + public AppendableValueRW() { + this(Collections.emptyList()); + } + public void append(Object value) { + if (value instanceof Collection) { + this.values.addAll((Collection) value); + } + else { + this.values.add((T)value); + } + } + + public List values() { + return unmodifiableList(values); + } + + public boolean isEmpty() { + return values().isEmpty(); + } + public int size() { + return values().size(); + } + public Optional last() { + List values = values(); + return ( values == null || values.isEmpty() ) ? Optional.empty() : Optional.of(values.get(values.size()-1)); + } + public Optional lastMinus( int n ) { + if( values == null || values.isEmpty() ) return Optional.empty(); + if( n < 0 ) return Optional.empty(); + if( values.size() - n - 1 < 0 ) return Optional.empty(); + return Optional.of(values.get(values.size()-n-1)); + } + + public String toString() { + return String.valueOf(values); + } + +} diff --git a/core-jdk8/src/main/java/org/bsc/langgraph4j/utils/CollectionsUtils.java b/core-jdk8/src/main/java/org/bsc/langgraph4j/utils/CollectionsUtils.java index 1148dc2..e8f39f3 100644 --- a/core-jdk8/src/main/java/org/bsc/langgraph4j/utils/CollectionsUtils.java +++ b/core-jdk8/src/main/java/org/bsc/langgraph4j/utils/CollectionsUtils.java @@ -29,4 +29,11 @@ public static Map mapOf( K k1, V v1, K k2, V v2 ) { result.put(k2,v2); return unmodifiableMap(result); } + public static Map mapOf( K k1, V v1, K k2, V v2, K k3, V v3 ) { + Map result = new HashMap(); + result.put(k1,v1); + result.put(k2,v2); + result.put(k3,v3); + return unmodifiableMap(result); + } } diff --git a/core-jdk8/src/test/java/org/bsc/langgraph4j/BaseAgentState.java b/core-jdk8/src/test/java/org/bsc/langgraph4j/BaseAgentState.java deleted file mode 100644 index 9e92ae7..0000000 --- a/core-jdk8/src/test/java/org/bsc/langgraph4j/BaseAgentState.java +++ /dev/null @@ -1,15 +0,0 @@ -package org.bsc.langgraph4j; - -import lombok.Value; -import lombok.experimental.Accessors; -import org.bsc.langgraph4j.state.AgentState; - -import java.util.Map; - -@Value -@Accessors(fluent = true) -class BaseAgentState implements AgentState { - Map data; - - -} diff --git a/core-jdk8/src/test/java/org/bsc/langgraph4j/LangGraphTest.java b/core-jdk8/src/test/java/org/bsc/langgraph4j/LangGraphTest.java index ffa626f..5d099be 100644 --- a/core-jdk8/src/test/java/org/bsc/langgraph4j/LangGraphTest.java +++ b/core-jdk8/src/test/java/org/bsc/langgraph4j/LangGraphTest.java @@ -1,6 +1,7 @@ package org.bsc.langgraph4j; import lombok.var; +import org.bsc.langgraph4j.state.AgentState; import org.junit.jupiter.api.Test; import java.util.List; @@ -26,7 +27,7 @@ public static List> sortMap(Map map ) { @Test void testValidation() throws Exception { - var workflow = new GraphState<>(BaseAgentState::new); + var workflow = new GraphState<>(AgentState::new); var exception = assertThrows(GraphStateException.class, workflow::compile); System.out.println(exception.getMessage()); assertEquals( "missing Entry Point", exception.getMessage()); @@ -80,7 +81,7 @@ void testValidation() throws Exception { @Test public void testRunningOneNode() throws Exception { - var workflow = new GraphState<>(BaseAgentState::new); + var workflow = new GraphState<>(AgentState::new); workflow.setEntryPoint("agent_1"); workflow.addNode("agent_1", node_async( state -> {