Skip to content

Commit

Permalink
feat: finalize developing langchain4j agentexecutor using langgraph4j
Browse files Browse the repository at this point in the history
  • Loading branch information
bsorrentino committed Mar 27, 2024
1 parent cb3cf80 commit 7dd851c
Show file tree
Hide file tree
Showing 13 changed files with 454 additions and 117 deletions.
60 changes: 38 additions & 22 deletions agents/src/main/java/dev/langchain4j/Agent.java
Original file line number Diff line number Diff line change
@@ -1,44 +1,28 @@
package dev.langchain4j;

import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.service.AiServices;
import lombok.Builder;
import lombok.Singular;

import java.lang.reflect.Method;
import java.util.*;

import static dev.langchain4j.agent.tool.ToolSpecifications.toolSpecificationFrom;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

@Builder
public class Agent {

private final ChatLanguageModel chatLanguageModel;
@Singular private final List<Object> tools;

@Singular private final List<ToolSpecification> tools;

private List<ToolSpecification> getToolSpecifications() {
var toolSpecifications = new ArrayList<ToolSpecification>();

for (Object tool : tools) {
for (Method method : tool.getClass().getDeclaredMethods()) {
if (method.isAnnotationPresent(Tool.class)) {
var toolSpecification = toolSpecificationFrom(method);
toolSpecifications.add(toolSpecification);
//context.toolExecutors.put(toolSpecification.name(), new DefaultToolExecutor(objectWithTool, method));
}
}
}
return toolSpecifications;
}
public Response<AiMessage> execute( Map<String,Object> inputs ) {
var messages = new ArrayList<ChatMessage>();
var promptTemplate = PromptTemplate.from( "USER: {{input}}" ).apply(inputs);
Expand All @@ -47,6 +31,38 @@ public Response<AiMessage> execute( Map<String,Object> inputs ) {

messages.add( new UserMessage(promptTemplate.text()) );

return chatLanguageModel.generate( messages, getToolSpecifications() );
return chatLanguageModel.generate( messages, tools );
}

private PromptTemplate getToolResponseTemplate( ) {
var TEMPLATE_TOOL_RESPONSE = new StringBuilder()
.append("TOOL RESPONSE:").append('\n')
.append("---------------------").append('\n')
.append("{{observation}}").append('\n')
.append( "--------------------" ).append('\n')
.append('\n')
.toString();
return PromptTemplate.from(TEMPLATE_TOOL_RESPONSE);
}

public Response<AiMessage> execute( String input, List<AgentExecutor.IntermediateStep> intermediateSteps ) {
var agentScratchpadTemplate = getToolResponseTemplate();
var userMessageTemplate = PromptTemplate.from( "USER'S INPUT: {{input}}" ).apply( Map.of( "input", input));

var messages = new ArrayList<ChatMessage>();

messages.add(new SystemMessage("You are a helpful assistant"));

if( intermediateSteps.isEmpty()) {
messages.add(new UserMessage(userMessageTemplate.text()));
}

for( AgentExecutor.IntermediateStep step: intermediateSteps ) {
var agentScratchpad = agentScratchpadTemplate.apply( Map.of("observation", step.observation()) );
messages.add(new UserMessage(agentScratchpad.text()));
;
}

return chatLanguageModel.generate( messages, tools );
}
}
166 changes: 147 additions & 19 deletions agents/src/main/java/dev/langchain4j/AgentExecutor.java
Original file line number Diff line number Diff line change
@@ -1,42 +1,170 @@
package dev.langchain4j;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.output.FinishReason;
import org.bsc.langgraph4j.GraphState;
import org.bsc.langgraph4j.async.AsyncIterator;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.AppendableValue;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.*;

import static java.util.Optional.ofNullable;
import static org.bsc.langgraph4j.GraphState.END;
import static org.bsc.langgraph4j.action.AsyncEdgeAction.edge_async;
import static org.bsc.langgraph4j.action.AsyncNodeAction.node_async;

public class AgentExecutor {

record AgentAction ( String tool, String tool_input, String log ) {}
record AgentFinish ( Map<String,Object> return_values, String log ) {}
record AgentAction ( ToolExecutionRequest toolExecutionRequest, String log ) {
public AgentAction {
Objects.requireNonNull(toolExecutionRequest);
}
}
record AgentFinish ( Map<String,Object> returnValues, String log ) {}

record AgentOutcome(AgentAction action, AgentFinish finish) {}

record IntermediateStep(AgentAction action, String observation) {
public record IntermediateStep(AgentAction action, String observation) {
}
record BaseAgentState( Map<String,Object> data ) implements AgentState {
public static class State implements AgentState {

private final Map<String,Object> data;

BaseAgentState(Map<String, Object> data) {
this.data = data;
data.put("intermediate_steps", new ArrayList<>() );
public State( Map<String,Object> initData ) {
this.data = new HashMap<>(initData);
if( !data.containsKey("intermediate_steps")) {
this.data.put("intermediate_steps",
new AppendableValue<IntermediateStep>());
}
}

public Map<String,Object> data() {
return Map.copyOf(data);
}

Optional<String> input() {
return ofNullable((String) data.get("input"));
return value("input");
}
Optional<AgentOutcome> agent_outcome() {
return ofNullable((AgentOutcome) data.get("agent_outcome"));
Optional<AgentOutcome> agentOutcome() {
return value("agent_outcome");
}
List<IntermediateStep> intermediate_steps() {
return (List<IntermediateStep>) data.get("intermediate_steps");
Optional<List<IntermediateStep>> intermediateSteps() {
return appendableValue("intermediate_steps");
}

@Override
public String toString() {
return data.toString();
}
}

Map<String,Object> runAgent( Agent agentRunnable, State state ) throws Exception {

var input = state.input()
.orElseThrow(() -> new IllegalArgumentException("no input provided!"));
var intermediateSteps = state.intermediateSteps()
.orElseThrow(() -> new IllegalArgumentException("no intermediateSteps provided!"));

var response = agentRunnable.execute( input, intermediateSteps );

if( response.finishReason() == FinishReason.TOOL_EXECUTION ) {

var toolExecutionRequest = response.content().toolExecutionRequests();
if (toolExecutionRequest.size() != 1) {
throw new IllegalStateException("unexpected number of tool execution requests: " + toolExecutionRequest.size());
}
var action = new AgentAction( toolExecutionRequest.get(0), "");

return Map.of("agent_outcome", new AgentOutcome( action, null ) );

}
else {
var result = response.content().text();
var finish = new AgentFinish( Map.of("returnValues", result), result );

return Map.of("agent_outcome", new AgentOutcome( null, finish ) );
}

}
Map<String,Object> executeTools( List<ToolInfo> toolInfoList, State state ) throws Exception {

var agentOutcome = state.agentOutcome().orElseThrow(() -> new IllegalArgumentException("no agentOutcome provided!"));

if (agentOutcome.action() == null) {
throw new IllegalStateException("no action provided!" );
}

var toolExecutionRequest = agentOutcome.action().toolExecutionRequest();

var tool = toolInfoList.stream()
.filter( v -> v.specification().name().equals(toolExecutionRequest.name()))
.findFirst()
.orElseThrow(() -> new IllegalStateException("no tool found for: " + toolExecutionRequest.name()));

var result = tool.executor().execute( toolExecutionRequest, null );

return Map.of("intermediate_steps", new IntermediateStep( agentOutcome.action(), result ) );

}

String shouldContinue(State state) {

if (state.agentOutcome().map(AgentOutcome::finish).isPresent()) {
return "end";
}
return "continue";
}
public static void execute() {

System.out.println("Hello World!");
public AsyncIterator<GraphState.Runnable.NodeOutput<State>> execute(Map<String, Object> inputs, List<Object> objectsWithTools) throws Exception {


var openApiKey = DotEnvConfig.valueOf("OPENAI_API_KEY")
.orElseThrow( () -> new IllegalArgumentException("no APIKEY provided!"));

var chatLanguageModel = OpenAiChatModel.builder()
.apiKey( openApiKey )
.modelName( "gpt-3.5-turbo-0613" )
.logResponses(true)
.maxRetries(2)
.temperature(0.0)
.maxTokens(2000)
.build();

var toolInfoList = ToolInfo.fromList( objectsWithTools );

final List<ToolSpecification> toolSpecifications = toolInfoList.stream()
.map(ToolInfo::specification)
.toList();

var agentRunnable = Agent.builder()
.chatLanguageModel(chatLanguageModel)
.tools( toolSpecifications )
.build();

var workflow = new GraphState<>(State::new);

workflow.setEntryPoint("agent");

workflow.addNode( "agent", node_async( state ->
runAgent(agentRunnable, state))
);

workflow.addNode( "action", node_async( state ->
executeTools(toolInfoList, state))
);

workflow.addConditionalEdge(
"agent",
edge_async(this::shouldContinue),
Map.of("continue", "action", "end", END)
);

workflow.addEdge("action", "agent");

var app = workflow.compile();

return app.stream( inputs );
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

import java.io.FileReader;
import java.io.Reader;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Optional;

public class AIConfig {
public interface DotEnvConfig {

public static AIConfig load() {
static void load() {

// Search for .env file
var path = Paths.get(".").toAbsolutePath();

var filePath = Paths.get( path.toString(), ".env");
Expand All @@ -23,24 +24,21 @@ public static AIConfig load() {
}
}

// load .env contents in System.properties
try {
return new AIConfig( filePath );
final java.util.Properties properties = new java.util.Properties();

try( Reader r = new FileReader(filePath.toFile())) {
properties.load(r);
}
System.getProperties().putAll(properties);
} catch (Exception e) {
throw new RuntimeException(e);
}
}


private AIConfig( Path envFilePath ) throws Exception {
final java.util.Properties properties = new java.util.Properties();

try( Reader r = new FileReader(envFilePath.toFile())) {
properties.load(r);
}
System.getProperties().putAll(properties);
}

public Optional<String> valueOf(String key ) {
static Optional<String> valueOf(String key ) {
return Optional.ofNullable(System.getenv( key ))
.or( () -> Optional.ofNullable(System.getProperty(key)));
}
Expand Down
43 changes: 43 additions & 0 deletions agents/src/main/java/dev/langchain4j/ToolInfo.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package dev.langchain4j;

import dev.langchain4j.agent.tool.DefaultToolExecutor;
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.ToolExecutor;
import dev.langchain4j.agent.tool.ToolSpecification;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

import static dev.langchain4j.agent.tool.ToolSpecifications.toolSpecificationFrom;

public record ToolInfo(ToolSpecification specification, ToolExecutor executor ) {

public ToolInfo {
Objects.requireNonNull(specification);
Objects.requireNonNull(executor);
}

public static List<ToolInfo> of( Object ...objectsWithTools) {
return fromArray( (Object[])objectsWithTools );
}
public static List<ToolInfo> fromArray( Object[] objectsWithTools ) {
var toolSpecifications = new ArrayList<ToolInfo>();

for (Object objectWithTools : objectsWithTools) {
for (Method method : objectWithTools.getClass().getDeclaredMethods()) {
if (method.isAnnotationPresent(Tool.class)) {
var toolSpecification = toolSpecificationFrom(method);
var executor = new DefaultToolExecutor(objectWithTools, method);
toolSpecifications.add( new ToolInfo( toolSpecification, executor));
}
}
}
return List.copyOf(toolSpecifications);
}
public static List<ToolInfo> fromList(List<Object> objectsWithTools ) {
return fromArray(objectsWithTools.toArray());
}

}
Loading

0 comments on commit 7dd851c

Please sign in to comment.