Skip to content

Commit

Permalink
feat: add Channel support
Browse files Browse the repository at this point in the history
add reducer, default value provider
add AppenderChannel to manage accumulated list of values
deprecate AppendableValue

work on #13
  • Loading branch information
bsorrentino committed Aug 7, 2024
1 parent c88be6f commit cd50013
Show file tree
Hide file tree
Showing 7 changed files with 193 additions and 13 deletions.
27 changes: 21 additions & 6 deletions core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@
import org.bsc.langgraph4j.checkpoint.BaseCheckpointSaver;
import org.bsc.langgraph4j.checkpoint.Checkpoint;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.Channel;

import java.util.*;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

import static java.lang.String.format;
Expand Down Expand Up @@ -116,17 +119,29 @@ private void addCheckpoint( String nodeId, State state ) throws Exception {
}
}

Map<String,Object> getInitialStateFromSchema() {
return stateGraph.getChannels().entrySet().stream()
.filter( c -> c.getValue().getDefault().isPresent() )
.collect(Collectors.toMap(Map.Entry::getKey, e ->
e.getValue().getDefault().get().get()
));
}

State getInitialState(Map<String,Object> inputs) {

return compileConfig.getCheckpointSaver()
.flatMap(BaseCheckpointSaver::getLast)
.map( cp -> {
var state = cp.getValue().getState();
return state.mergeWith(inputs, stateGraph.getStateFactory());
var state = cp.getValue().getState().mergeWith(inputs, stateGraph.getChannels());
return stateGraph.getStateFactory().apply(state);
})
.orElseGet( () ->
stateGraph.getStateFactory().apply(inputs)
);
.orElseGet( () -> {
var initialState =
stateGraph.getStateFactory()
.andThen( state -> state.mergeWith(inputs, stateGraph.getChannels()) )
.apply( getInitialStateFromSchema() );
return stateGraph.getStateFactory().apply(initialState);
});
}

/**
Expand Down Expand Up @@ -168,7 +183,7 @@ public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs, Invok

partialState = action.apply(currentState).get();

currentState = currentState.mergeWith(partialState, stateGraph.getStateFactory());
currentState = stateGraph.getStateFactory().apply( currentState.mergeWith(partialState, stateGraph.getChannels()) );

queue.add( AsyncGenerator.Data.of( completedFuture( NodeOutput.of(currentNodeId, currentState) ) ));
addCheckpoint( currentNodeId, currentState );
Expand Down
39 changes: 32 additions & 7 deletions core-jdk8/src/main/java/org/bsc/langgraph4j/state/AgentState.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import lombok.var;

import java.util.*;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

Expand Down Expand Up @@ -44,14 +45,23 @@ public final java.util.Map<String,Object> data() {
public final <T> Optional<T> value(String key) {
return ofNullable((T) data().get(key));
}
public final <T> T value(String key, T defaultValue ) {
return (T)value(key).orElse(defaultValue);
}

public final <T> T value(String key, Supplier<T> defaultProvider ) {
return (T)value(key).orElseGet(defaultProvider);
}

/**
* Retrieves or creates an AppendableValue associated with the given key.
*
* @param key the key whose associated AppendableValue is to be returned or created
* @param <T> the type of the value
* @return an AppendableValue associated with the given key
* @deprecated use {@link Channel} instead
*/
@Deprecated
public final <T> AppendableValue<T> appendableValue(String key) {
Object value = this.data.get(key);

Expand Down Expand Up @@ -84,25 +94,40 @@ private Object mergeFunction(Object currentValue, Object newValue) {
return newValue;
}

private Map<String,Object> updatePartialStateFromSchema( Map<String,Object> partialState, Map<String, Channel<?>> channels ) {
if( channels == null || channels.isEmpty() ) {
return partialState;
}
return partialState.entrySet().stream().map( entry -> {

var channel = channels.get(entry.getKey());
if (channel != null) {
var newValue = channel.update( entry.getKey(), data().get(entry.getKey()), entry.getValue());
return new AbstractMap.SimpleImmutableEntry<>(entry.getKey(), newValue);
}

return entry;
}).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
}

/**
* Merges the current state with a partial state and returns a new state.
*
* @param partialState the partial state to merge with
* @param factory the factory to create a new state
* @param <State> the type of the agent state
* @return a new state resulting from the merge
*/
public <State extends AgentState> State mergeWith(Map<String,Object> partialState, AgentStateFactory<State> factory) {
public Map<String,Object> mergeWith( Map<String,Object> partialState, Map<String, Channel<?>> channels ) {
if (partialState == null || partialState.isEmpty()) {
return factory.apply(data());
return data();
}
var mergedMap = Stream.concat(data().entrySet().stream(), partialState.entrySet().stream())

var updatedPartialState = updatePartialStateFromSchema(partialState, channels);

return Stream.concat(data().entrySet().stream(), updatedPartialState.entrySet().stream())
.collect(Collectors.toMap(
Map.Entry::getKey,
Map.Entry::getValue,
this::mergeFunction));

return factory.apply(mergedMap);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
*
* @param <T> the type of the value
*/
@Deprecated
public interface AppendableValue<T> {

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
*
* @param <T> the type of the value
*/
@Deprecated
public class AppendableValueRW<T> implements AppendableValue<T>, Externalizable {
private List<T> values;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package org.bsc.langgraph4j.state;

import lombok.extern.slf4j.Slf4j;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.Supplier;

import static java.util.Optional.ofNullable;
import static org.bsc.langgraph4j.utils.CollectionsUtils.listOf;


/*
* AppenderChannel is a {@link Channel} implementation that
* is used to accumulate a list of values.
*
* @param <T> the type of the values being accumulated
* @see Channel
*/
@Slf4j
public class AppenderChannel<T> implements Channel<List<T>> {

private final Reducer<List<T>> reducer;
private final Supplier<List<T>> defaultProvider;

@Override
public Optional<Reducer<List<T>>> getReducer() {
return ofNullable(reducer);
}

@Override
public Optional<Supplier<List<T>>> getDefault() {
return ofNullable(defaultProvider);
}

public static <T> AppenderChannel<T> of( Supplier<List<T>> defaultProvider ) {
return new AppenderChannel<>(defaultProvider);
}

private AppenderChannel( Supplier<List<T>> defaultProvider) {
this.reducer = new Reducer<List<T>>() {
@Override
public List<T> apply(List<T> left, List<T> right) {
if( left == null ) {
return right;
}
left.addAll(right);
return left;
}
};
this.defaultProvider = defaultProvider;
}

public Object update( String key, Object oldValue, Object newValue) {
try {
try { // this is to allow single value other than
T typedValue = (T) newValue;
return Channel.super.update(key, oldValue, listOf(typedValue));
} catch (ClassCastException e) {
return Channel.super.update(key, oldValue, newValue);
}
} catch (UnsupportedOperationException ex) {
log.error("Unsupported operation: probably because the appendable channel has been initialized with a immutable List. Check please !");
throw ex;
}
}

}
63 changes: 63 additions & 0 deletions core-jdk8/src/main/java/org/bsc/langgraph4j/state/Channel.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package org.bsc.langgraph4j.state;

import java.util.Optional;
import java.util.function.Supplier;

/**
* A Channel is a mechanism used to maintain a state property.
* <p>
* A Channel is associated with a key and a value. The Channel is updated
* by calling the {@link #update(String, Object, Object)} method. The update
* operation is applied to the channel's value.
* <p>
* The Channel may be initialized with a default value. This default value
* is provided by a {@link Supplier}. The {@link #getDefault()} method returns
* an optional containing the default supplier.
* <p>
* The Channel may also be associated with a Reducer. The Reducer is a
* function that combines the current value of the channel with a new value
* and returns the updated value.
* <p>
* The {@link #update(String, Object, Object)} method updates the channel's
* value with the provided key, old value and new value. The update operation
* is applied to the channel's value. If the channel is not initialized, the
* default value is used. If the channel is initialized, the reducer is used
* to compute the new value.
*
* @param <T> the type of the state property
*/
public interface Channel<T> {

/**
* The Reducer, if provided, is invoked for each state property to compute value.
*
* @return An optional containing the reducer, if it exists.
*/
Optional<Reducer<T>> getReducer() ;

/**
* a Supplier that provide a default value. The result must be mutable.
*
* @return an Optional containing the default Supplier
*/
Optional<Supplier<T>> getDefault();


/**
* Update the state property with the given key and returns the new value.
*
* @param key the key of the state property to be updated
* @param oldValue the current value of the state property
* @param newValue the new value to be set
* @return the new value of the state property
*/
default Object update(String key, Object oldValue, Object newValue) {
T _new = (T)newValue;

final T _old = (oldValue == null) ?
getDefault().map(Supplier::get).orElse(null) :
(T)oldValue;

return getReducer().map( reducer -> reducer.apply( _old, _new)).orElse(_new);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package org.bsc.langgraph4j.state;

import java.util.function.BiFunction;

public interface Reducer<T> extends BiFunction<T,T,T> {
}

0 comments on commit cd50013

Please sign in to comment.