Skip to content

Commit

Permalink
feat: complete AdaptiveRag implementation
Browse files Browse the repository at this point in the history
resolve #6
  • Loading branch information
bsorrentino committed Jun 19, 2024
1 parent 7ab1205 commit e3d6240
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 42 deletions.
3 changes: 1 addition & 2 deletions adaptive-rag/logging.properties
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
handlers=java.util.logging.ConsoleHandler
.level=INFO
DiagramCorrectionProcess.level=FINEST
ImageToDiagramProcess.level=FINEST
AdaptiveRag.level=FINE
java.util.logging.ConsoleHandler.level=ALL
java.util.logging.ConsoleHandler.formatter=java.util.logging.SimpleFormatter
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import lombok.extern.slf4j.Slf4j;
import lombok.var;
import org.bsc.langgraph4j.CompiledGraph;
import org.bsc.langgraph4j.StateGraph;
import org.bsc.langgraph4j.state.AgentState;

import java.util.List;
Expand All @@ -12,9 +15,13 @@
import java.util.stream.Collectors;

import static java.util.Collections.emptyList;
import static org.bsc.langgraph4j.StateGraph.END;
import static org.bsc.langgraph4j.action.AsyncEdgeAction.edge_async;
import static org.bsc.langgraph4j.action.AsyncNodeAction.node_async;
import static org.bsc.langgraph4j.utils.CollectionsUtils.listOf;
import static org.bsc.langgraph4j.utils.CollectionsUtils.mapOf;

@Slf4j( topic="AdaptiveRag")
public class AdaptiveRag {

/**
Expand All @@ -34,9 +41,8 @@ public String question() {
Optional<String> result = value("question");
return result.orElseThrow( () -> new IllegalStateException( "question is not set!" ) );
}
public String generation() {
Optional<String> result = value("generation");
return result.orElseThrow( () -> new IllegalStateException( "generation is not set!" ) );
public Optional<String> generation() {
return value("generation");

}
public List<String> documents() {
Expand All @@ -62,7 +68,8 @@ public AdaptiveRag( String openApiKey, String tavilyApiKey ) {
* @param state The current graph state
* @return New key added to state, documents, that contains retrieved documents
*/
public Map<String,Object> retrieve( State state ) {
private Map<String,Object> retrieve( State state ) {
log.debug("---RETRIEVE---");

String question = state.question();

Expand All @@ -81,7 +88,9 @@ public Map<String,Object> retrieve( State state ) {
* @param state The current graph state
* @return New key added to state, generation, that contains LLM generation
*/
public Map<String,Object> generate( State state ) {
private Map<String,Object> generate( State state ) {
log.debug("---GENERATE---");

String question = state.question();
List<String> documents = state.documents();

Expand All @@ -95,7 +104,8 @@ public Map<String,Object> generate( State state ) {
* @param state The current graph state
* @return Updates documents key with only filtered relevant documents
*/
public Map<String,Object> gradeDocuments( State state ) {
private Map<String,Object> gradeDocuments( State state ) {
log.debug("---CHECK DOCUMENT RELEVANCE TO QUESTION---");

String question = state.question();

Expand All @@ -106,7 +116,14 @@ public Map<String,Object> gradeDocuments( State state ) {
List<String> filteredDocs = documents.stream()
.filter( d -> {
var score = grader.apply( RetrievalGrader.Arguments.of(question, d ));
return score.binaryScore.equals("yes");
boolean relevant = score.binaryScore.equals("yes");
if( relevant ) {
log.debug("---GRADE: DOCUMENT RELEVANT---");
}
else {
log.debug("---GRADE: DOCUMENT NOT RELEVANT---");
}
return relevant;
})
.collect(Collectors.toList());

Expand All @@ -118,7 +135,9 @@ public Map<String,Object> gradeDocuments( State state ) {
* @param state The current graph state
* @return Updates question key with a re-phrased question
*/
public Map<String,Object> transformQuery( State state ) {
private Map<String,Object> transformQuery( State state ) {
log.debug("---TRANSFORM QUERY---");

String question = state.question();

String betterQuestion = QuestionRewriter.of( openApiKey ).apply( question );
Expand All @@ -131,7 +150,9 @@ public Map<String,Object> transformQuery( State state ) {
* @param state The current graph state
* @return Updates documents key with appended web results
*/
public Map<String,Object> webSearch( State state ) {
private Map<String,Object> webSearch( State state ) {
log.debug("---WEB SEARCH---");

String question = state.question();

var result = WebSearchTool.of( tavilyApiKey ).apply(question);
Expand All @@ -148,11 +169,18 @@ public Map<String,Object> webSearch( State state ) {
* @param state The current graph state
* @return Next node to call
*/
public String routeQuestion( State state ) {
private String routeQuestion( State state ) {
log.debug("---ROUTE QUESTION---");

String question = state.question();

var source = QuestionRouter.of( openApiKey ).apply( question );

if( source == QuestionRouter.Type.web_search ) {
log.debug("---ROUTE QUESTION TO WEB SEARCH---");
}
else {
log.debug("---ROUTE QUESTION TO RAG---");
}
return source.name();
}

Expand All @@ -161,12 +189,15 @@ public String routeQuestion( State state ) {
* @param state The current graph state
* @return Binary decision for next node to call
*/
public String decideToGenerate( State state ) {
private String decideToGenerate( State state ) {
log.debug("---ASSESS GRADED DOCUMENTS---");
List<String> documents = state.documents();

if(documents.isEmpty()) {
log.debug("---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---");
return "transform_query";
}
log.debug( "---DECISION: GENERATE---" );
return "generate";
}

Expand All @@ -175,25 +206,73 @@ public String decideToGenerate( State state ) {
* @param state The current graph state
* @return Decision for next node to call
*/
public String gradeGeneration_v_DocumentsAndQuestion( State state ) {
private String gradeGeneration_v_documentsAndQuestion( State state ) {
log.debug("---CHECK HALLUCINATIONS---");

String question = state.question();
List<String> documents = state.documents();
String generation = state.generation();
String generation = state.generation()
.orElseThrow( () -> new IllegalStateException( "generation is not set!" ) );


HallucinationGrader.Score score = HallucinationGrader.of( openApiKey )
.apply( HallucinationGrader.Arguments.of(documents, generation));

if(Objects.equals(score.binaryScore, "yes")) {

log.debug( "---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---" );
log.debug("---GRADE GENERATION vs QUESTION---");
AnswerGrader.Score score2 = AnswerGrader.of( openApiKey )
.apply( AnswerGrader.Arguments.of(question, generation) );
if( Objects.equals( score2.binaryScore, "yes") ) {
log.debug( "---DECISION: GENERATION ADDRESSES QUESTION---" );
return "useful";
}

log.debug("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---");
return "not useful";
}

log.debug( "---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---" );
return "not supported";
}

public CompiledGraph<State> buildGraph() throws Exception {
var workflow = new StateGraph<>(State::new);

// Define the nodes
workflow.addNode("web_search", node_async(this::webSearch) ); // web search
workflow.addNode("retrieve", node_async(this::retrieve) ); // retrieve
workflow.addNode("grade_documents", node_async(this::gradeDocuments) ); // grade documents
workflow.addNode("generate", node_async(this::generate) ); // generatae
workflow.addNode("transform_query", node_async(this::transformQuery)); // transform_query

// Build graph
workflow.setConditionalEntryPoint(
edge_async(this::routeQuestion),
mapOf(
"web_search", "web_search",
"vectorstore", "retrieve"
));

workflow.addEdge("web_search", "generate");
workflow.addEdge("retrieve", "grade_documents");
workflow.addConditionalEdges(
"grade_documents",
edge_async(this::decideToGenerate),
mapOf(
"transform_query","transform_query",
"generate", "generate"
));
workflow.addEdge("transform_query", "retrieve");
workflow.addConditionalEdges(
"generate",
edge_async(this::gradeGeneration_v_documentsAndQuestion),
mapOf(
"not supported", "generate",
"useful", END,
"not useful", "transform_query"
));

return workflow.compile();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.util.logging.LogManager;
import java.util.stream.Collectors;

import static org.bsc.langgraph4j.utils.CollectionsUtils.mapOf;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;

Expand Down Expand Up @@ -134,4 +135,24 @@ public void generationTest() {

System.out.println( result );
}

@Test
public void execute() throws Exception {
AdaptiveRag adaptiveRagTest = new AdaptiveRag(getOpenAiKey(), getTavilyApiKey());

var graph = adaptiveRagTest.buildGraph();

var result = graph.stream( mapOf( "question", "What player at the Bears expected to draft first in the 2024 NFL draft?" ) );

String generation = "";
for( var r : result ) {
System.out.printf( "Node: '%s':\n", r.node() );

generation = r.state().generation().orElse( "")
;
}

System.out.println( generation );

}
}
45 changes: 27 additions & 18 deletions core-jdk8/src/main/java/org/bsc/langgraph4j/CompiledGraph.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,8 @@ void setMaxIterations(int maxIterations) {
this.maxIterations = maxIterations;
}

/**
* Determines the next node ID based on the current node ID and state.
*
* @param nodeId the current node ID
* @param state the current state
* @return the next node ID
* @throws Exception if there is an error determining the next node ID
*/
private String nextNodeId(String nodeId, State state) throws Exception {
private String nextNodeId( EdgeValue<State> route , State state, String nodeId ) throws Exception {

var route = edges.get(nodeId);
if( route == null ) {
throw StateGraph.RunnableErrors.missingEdge.exception(nodeId);
}
Expand All @@ -70,20 +61,33 @@ private String nextNodeId(String nodeId, State state) throws Exception {
}
if( route.value() != null ) {
var condition = route.value().action();

var newRoute = condition.apply(state).get();

var result = route.value().mappings().get(newRoute);
if( result == null ) {
throw StateGraph.RunnableErrors.missingNodeInEdgeMapping.exception(nodeId, newRoute);
}
return result;
}

throw StateGraph.RunnableErrors.executionError.exception( format("invalid edge value for nodeId: [%s] !", nodeId) );
}

/**
* Determines the next node ID based on the current node ID and state.
*
* @param nodeId the current node ID
* @param state the current state
* @return the next node ID
* @throws Exception if there is an error determining the next node ID
*/
private String nextNodeId(String nodeId, State state) throws Exception {
return nextNodeId(edges.get(nodeId), state, nodeId);

}

private String getEntryPoint( State state ) throws Exception {
return nextNodeId(stateGraph.getEntryPoint(), state, "entryPoint");
}

/**
* Creates an AsyncGenerator stream of NodeOutput based on the provided inputs.
*
Expand All @@ -95,11 +99,13 @@ public AsyncGenerator<NodeOutput<State>> stream(Map<String,Object> inputs ) thro

return AsyncGeneratorQueue.of(new LinkedBlockingQueue<>(), queue -> {

var currentState = stateGraph.getStateFactory().apply(inputs);
var currentNodeId = stateGraph.getEntryPoint();
Map<String, Object> partialState;

try {
var currentState = stateGraph.getStateFactory().apply(inputs);

var currentNodeId = this.getEntryPoint( currentState );

Map<String, Object> partialState;

for(int i = 0; i < maxIterations && !Objects.equals(currentNodeId, StateGraph.END); ++i ) {
var action = nodes.get(currentNodeId);
if (action == null) {
Expand Down Expand Up @@ -181,7 +187,10 @@ public GraphRepresentation getGraph( GraphRepresentation.Type type ) {
}
});

sb.append( format("start -down-> \"%s\"\n", stateGraph.getEntryPoint() ));
var entryPoint = stateGraph.getEntryPoint();
if( entryPoint.id() != null ) {
sb.append( format("start -down-> \"%s\"\n", entryPoint.id() ));
}

conditionalEdgeCount[0] = 0; // reset

Expand Down
1 change: 1 addition & 0 deletions core-jdk8/src/main/java/org/bsc/langgraph4j/Node.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
@Accessors(fluent = true)
class Node<State extends AgentState> {


/**
* The unique identifier for the node.
*/
Expand Down
Loading

0 comments on commit e3d6240

Please sign in to comment.