Skip to content

Commit

Permalink
test: test refinements
Browse files Browse the repository at this point in the history
  • Loading branch information
bsorrentino committed Mar 28, 2024
1 parent ccaf2da commit abead3b
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 32 deletions.
82 changes: 55 additions & 27 deletions agents/src/test/java/dev/langchain4j/AgentExecutorTest.java
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
package dev.langchain4j;

import dev.langchain4j.model.openai.OpenAiChatModel;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

import java.util.List;
import java.util.Map;

import static java.util.Optional.ofNullable;
import static org.junit.jupiter.api.Assertions.*;

public class AgentExecutorTest {

public static void main( String[] args) {

@BeforeAll
public static void loadEnv() {
DotEnvConfig.load();
}

private AgentExecutor.State executeAgent(String prompt ) throws Exception {

var openApiKey = DotEnvConfig.valueOf("OPENAI_API_KEY")
.orElseThrow( () -> new IllegalArgumentException("no APIKEY provided!"));
Expand All @@ -25,31 +30,54 @@ public static void main( String[] args) {
.maxTokens(2000)
.build();

try {
var agentExecutor = new AgentExecutor();

var iterator = agentExecutor.execute(
chatLanguageModel,
Map.of( "input", "what is the result of test with messages: 'MY FIRST TEST' and the result of test with message: 'MY SECOND TEST'"),
//Map.of( "input", "what is the result of test with messages: 'MY FIRST TEST'"),
List.of(new TestTool()) );

AgentExecutor.State output = null;

for( var i : iterator ) {
output = i.state();
System.out.println(i.node());
}
System.out.println( "Finished! " + ofNullable(output)
.flatMap(AgentExecutor.State::agentOutcome)
.map(AgentExecutor.AgentOutcome::finish)
.map(AgentExecutor.AgentFinish::returnValues)
.orElse(Map.of( "result", "state undefined!")) );

} catch (Exception e) {
System.out.println( "ERROR! " + e.getMessage() );

var agentExecutor = new AgentExecutor();

var iterator = agentExecutor.execute(
chatLanguageModel,
Map.of( "input", prompt ),
List.of(new TestTool()) );

AgentExecutor.State state = null;

for( var i : iterator ) {
state = i.state();
System.out.println(i.node());
}
System.exit(0);

return state;

}

@Test
void executeAgentWithSingleToolInvocation() throws Exception {

var state = executeAgent("what is the result of test with messages: 'MY FIRST TEST'");

assertNotNull(state);
assertTrue(state.intermediateSteps().isPresent());
assertEquals( 1, state.intermediateSteps().get().size());
assertTrue(state.agentOutcome().isPresent());
assertNotNull(state.agentOutcome().get().finish());
assertTrue( state.agentOutcome().get().finish().returnValues().containsKey("returnValues"));
assertEquals("The test with the message 'MY FIRST TEST' has been executed successfully.",
state.agentOutcome().get().finish().returnValues().get("returnValues") );
}
@Test
void executeAgentWithDoubleToolInvocation() throws Exception {

var state = executeAgent("what is the result of test with messages: 'MY FIRST TEST' and the result of test with message: 'MY SECOND TEST'");

assertNotNull(state);
assertTrue(state.intermediateSteps().isPresent());
assertEquals( 2, state.intermediateSteps().get().size());
assertTrue(state.agentOutcome().isPresent());
assertNotNull(state.agentOutcome().get().finish());
assertTrue( state.agentOutcome().get().finish().returnValues().containsKey("returnValues"));
assertEquals(
"The result of the test with the message 'MY FIRST TEST' is: test tool executed: MY FIRST TEST\n" +
"The result of the test with the message 'MY SECOND TEST' is: test tool executed: MY SECOND TEST",
state.agentOutcome().get().finish().returnValues().get("returnValues") );

}
}
9 changes: 4 additions & 5 deletions agents/src/test/java/dev/langchain4j/AgentTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import dev.langchain4j.model.output.FinishReason;
import org.junit.jupiter.api.BeforeAll;
import dev.langchain4j.model.openai.OpenAiChatModel;
import org.junit.jupiter.api.Test;

import java.util.List;
import java.util.Map;
Expand All @@ -16,12 +17,12 @@
public class AgentTest {

@BeforeAll
static void init() {
public static void loadEnv() {
DotEnvConfig.load();
}

public static void main( String[] args) throws Exception {
DotEnvConfig.load();
@Test
public void runAgentTest() throws Exception {

assertTrue(DotEnvConfig.valueOf("OPENAI_API_KEY").isPresent());

Expand Down Expand Up @@ -55,7 +56,5 @@ public static void main( String[] args) throws Exception {
assertEquals("execTest", toolExecutionRequest.name());
assertEquals("{ \"arg0\": \"hello world\"}", toolExecutionRequest.arguments().replaceAll("\n",""));

System.out.println( response );

}
}

0 comments on commit abead3b

Please sign in to comment.