-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: finalize developing langchain4j agentexecutor using langgraph4j
- Loading branch information
1 parent
cb3cf80
commit 7dd851c
Showing
13 changed files
with
454 additions
and
117 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
166 changes: 147 additions & 19 deletions
166
agents/src/main/java/dev/langchain4j/AgentExecutor.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,42 +1,170 @@ | ||
package dev.langchain4j; | ||
|
||
import dev.langchain4j.agent.tool.ToolExecutionRequest; | ||
import dev.langchain4j.agent.tool.ToolSpecification; | ||
import dev.langchain4j.model.openai.OpenAiChatModel; | ||
import dev.langchain4j.model.output.FinishReason; | ||
import org.bsc.langgraph4j.GraphState; | ||
import org.bsc.langgraph4j.async.AsyncIterator; | ||
import org.bsc.langgraph4j.state.AgentState; | ||
import org.bsc.langgraph4j.state.AppendableValue; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Optional; | ||
import java.util.*; | ||
|
||
import static java.util.Optional.ofNullable; | ||
import static org.bsc.langgraph4j.GraphState.END; | ||
import static org.bsc.langgraph4j.action.AsyncEdgeAction.edge_async; | ||
import static org.bsc.langgraph4j.action.AsyncNodeAction.node_async; | ||
|
||
public class AgentExecutor { | ||
|
||
record AgentAction ( String tool, String tool_input, String log ) {} | ||
record AgentFinish ( Map<String,Object> return_values, String log ) {} | ||
record AgentAction ( ToolExecutionRequest toolExecutionRequest, String log ) { | ||
public AgentAction { | ||
Objects.requireNonNull(toolExecutionRequest); | ||
} | ||
} | ||
record AgentFinish ( Map<String,Object> returnValues, String log ) {} | ||
|
||
record AgentOutcome(AgentAction action, AgentFinish finish) {} | ||
|
||
record IntermediateStep(AgentAction action, String observation) { | ||
public record IntermediateStep(AgentAction action, String observation) { | ||
} | ||
record BaseAgentState( Map<String,Object> data ) implements AgentState { | ||
public static class State implements AgentState { | ||
|
||
private final Map<String,Object> data; | ||
|
||
BaseAgentState(Map<String, Object> data) { | ||
this.data = data; | ||
data.put("intermediate_steps", new ArrayList<>() ); | ||
public State( Map<String,Object> initData ) { | ||
this.data = new HashMap<>(initData); | ||
if( !data.containsKey("intermediate_steps")) { | ||
this.data.put("intermediate_steps", | ||
new AppendableValue<IntermediateStep>()); | ||
} | ||
} | ||
|
||
public Map<String,Object> data() { | ||
return Map.copyOf(data); | ||
} | ||
|
||
Optional<String> input() { | ||
return ofNullable((String) data.get("input")); | ||
return value("input"); | ||
} | ||
Optional<AgentOutcome> agent_outcome() { | ||
return ofNullable((AgentOutcome) data.get("agent_outcome")); | ||
Optional<AgentOutcome> agentOutcome() { | ||
return value("agent_outcome"); | ||
} | ||
List<IntermediateStep> intermediate_steps() { | ||
return (List<IntermediateStep>) data.get("intermediate_steps"); | ||
Optional<List<IntermediateStep>> intermediateSteps() { | ||
return appendableValue("intermediate_steps"); | ||
} | ||
|
||
@Override | ||
public String toString() { | ||
return data.toString(); | ||
} | ||
} | ||
|
||
Map<String,Object> runAgent( Agent agentRunnable, State state ) throws Exception { | ||
|
||
var input = state.input() | ||
.orElseThrow(() -> new IllegalArgumentException("no input provided!")); | ||
var intermediateSteps = state.intermediateSteps() | ||
.orElseThrow(() -> new IllegalArgumentException("no intermediateSteps provided!")); | ||
|
||
var response = agentRunnable.execute( input, intermediateSteps ); | ||
|
||
if( response.finishReason() == FinishReason.TOOL_EXECUTION ) { | ||
|
||
var toolExecutionRequest = response.content().toolExecutionRequests(); | ||
if (toolExecutionRequest.size() != 1) { | ||
throw new IllegalStateException("unexpected number of tool execution requests: " + toolExecutionRequest.size()); | ||
} | ||
var action = new AgentAction( toolExecutionRequest.get(0), ""); | ||
|
||
return Map.of("agent_outcome", new AgentOutcome( action, null ) ); | ||
|
||
} | ||
else { | ||
var result = response.content().text(); | ||
var finish = new AgentFinish( Map.of("returnValues", result), result ); | ||
|
||
return Map.of("agent_outcome", new AgentOutcome( null, finish ) ); | ||
} | ||
|
||
} | ||
Map<String,Object> executeTools( List<ToolInfo> toolInfoList, State state ) throws Exception { | ||
|
||
var agentOutcome = state.agentOutcome().orElseThrow(() -> new IllegalArgumentException("no agentOutcome provided!")); | ||
|
||
if (agentOutcome.action() == null) { | ||
throw new IllegalStateException("no action provided!" ); | ||
} | ||
|
||
var toolExecutionRequest = agentOutcome.action().toolExecutionRequest(); | ||
|
||
var tool = toolInfoList.stream() | ||
.filter( v -> v.specification().name().equals(toolExecutionRequest.name())) | ||
.findFirst() | ||
.orElseThrow(() -> new IllegalStateException("no tool found for: " + toolExecutionRequest.name())); | ||
|
||
var result = tool.executor().execute( toolExecutionRequest, null ); | ||
|
||
return Map.of("intermediate_steps", new IntermediateStep( agentOutcome.action(), result ) ); | ||
|
||
} | ||
|
||
String shouldContinue(State state) { | ||
|
||
if (state.agentOutcome().map(AgentOutcome::finish).isPresent()) { | ||
return "end"; | ||
} | ||
return "continue"; | ||
} | ||
public static void execute() { | ||
|
||
System.out.println("Hello World!"); | ||
public AsyncIterator<GraphState.Runnable.NodeOutput<State>> execute(Map<String, Object> inputs, List<Object> objectsWithTools) throws Exception { | ||
|
||
|
||
var openApiKey = DotEnvConfig.valueOf("OPENAI_API_KEY") | ||
.orElseThrow( () -> new IllegalArgumentException("no APIKEY provided!")); | ||
|
||
var chatLanguageModel = OpenAiChatModel.builder() | ||
.apiKey( openApiKey ) | ||
.modelName( "gpt-3.5-turbo-0613" ) | ||
.logResponses(true) | ||
.maxRetries(2) | ||
.temperature(0.0) | ||
.maxTokens(2000) | ||
.build(); | ||
|
||
var toolInfoList = ToolInfo.fromList( objectsWithTools ); | ||
|
||
final List<ToolSpecification> toolSpecifications = toolInfoList.stream() | ||
.map(ToolInfo::specification) | ||
.toList(); | ||
|
||
var agentRunnable = Agent.builder() | ||
.chatLanguageModel(chatLanguageModel) | ||
.tools( toolSpecifications ) | ||
.build(); | ||
|
||
var workflow = new GraphState<>(State::new); | ||
|
||
workflow.setEntryPoint("agent"); | ||
|
||
workflow.addNode( "agent", node_async( state -> | ||
runAgent(agentRunnable, state)) | ||
); | ||
|
||
workflow.addNode( "action", node_async( state -> | ||
executeTools(toolInfoList, state)) | ||
); | ||
|
||
workflow.addConditionalEdge( | ||
"agent", | ||
edge_async(this::shouldContinue), | ||
Map.of("continue", "action", "end", END) | ||
); | ||
|
||
workflow.addEdge("action", "agent"); | ||
|
||
var app = workflow.compile(); | ||
|
||
return app.stream( inputs ); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
package dev.langchain4j; | ||
|
||
import dev.langchain4j.agent.tool.DefaultToolExecutor; | ||
import dev.langchain4j.agent.tool.Tool; | ||
import dev.langchain4j.agent.tool.ToolExecutor; | ||
import dev.langchain4j.agent.tool.ToolSpecification; | ||
|
||
import java.lang.reflect.Method; | ||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.Objects; | ||
|
||
import static dev.langchain4j.agent.tool.ToolSpecifications.toolSpecificationFrom; | ||
|
||
public record ToolInfo(ToolSpecification specification, ToolExecutor executor ) { | ||
|
||
public ToolInfo { | ||
Objects.requireNonNull(specification); | ||
Objects.requireNonNull(executor); | ||
} | ||
|
||
public static List<ToolInfo> of( Object ...objectsWithTools) { | ||
return fromArray( (Object[])objectsWithTools ); | ||
} | ||
public static List<ToolInfo> fromArray( Object[] objectsWithTools ) { | ||
var toolSpecifications = new ArrayList<ToolInfo>(); | ||
|
||
for (Object objectWithTools : objectsWithTools) { | ||
for (Method method : objectWithTools.getClass().getDeclaredMethods()) { | ||
if (method.isAnnotationPresent(Tool.class)) { | ||
var toolSpecification = toolSpecificationFrom(method); | ||
var executor = new DefaultToolExecutor(objectWithTools, method); | ||
toolSpecifications.add( new ToolInfo( toolSpecification, executor)); | ||
} | ||
} | ||
} | ||
return List.copyOf(toolSpecifications); | ||
} | ||
public static List<ToolInfo> fromList(List<Object> objectsWithTools ) { | ||
return fromArray(objectsWithTools.toArray()); | ||
} | ||
|
||
} |
Oops, something went wrong.