Commit

test global search with llama 3.1
benolt committed Aug 21, 2024
1 parent 327c996 commit 9b66f2d
Showing 8 changed files with 224 additions and 5 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

(Generated file; diff not rendered.)

1 change: 1 addition & 0 deletions shinkai-libs/shinkai-graphrag/Cargo.toml
@@ -16,4 +16,5 @@ tokio = { version = "1.36", features = ["full"] }

[dev-dependencies]
async-openai = "0.23.4"
reqwest = { version = "0.11.26", features = ["json"] }
tiktoken-rs = "0.5.9"
6 changes: 6 additions & 0 deletions shinkai-libs/shinkai-graphrag/src/llm/llm.rs
@@ -40,10 +40,16 @@ pub trait BaseLLM {
streaming: bool,
callbacks: Option<Vec<BaseLLMCallback>>,
llm_params: LLMParams,
search_phase: Option<GlobalSearchPhase>,
) -> anyhow::Result<String>;
}

#[async_trait]
pub trait BaseTextEmbedding {
async fn aembed(&self, text: &str) -> Vec<f64>;
}

pub enum GlobalSearchPhase {
Map,
Reduce,
}
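The new search_phase argument lets a backend tailor its request to the stage of the global search that is calling it. A minimal, hypothetical sketch (not part of this commit) of how an implementation might branch on the phase, e.g. to request JSON output only during the map step:

use shinkai_graphrag::llm::llm::GlobalSearchPhase;

// Hypothetical helper: only the map step asks the model for structured JSON output,
// mirroring the json_mode / json_object response_format used by the test below.
fn wants_json_output(phase: &Option<GlobalSearchPhase>) -> bool {
    matches!(phase, Some(GlobalSearchPhase::Map))
}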
11 changes: 9 additions & 2 deletions shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs
@@ -6,7 +6,7 @@ use std::time::Instant;

use crate::context_builder::community_context::GlobalCommunityContext;
use crate::context_builder::context_builder::{ContextBuilderParams, ConversationHistory};
use crate::llm::llm::{BaseLLM, BaseLLMCallback, LLMParams, MessageType};
use crate::llm::llm::{BaseLLM, BaseLLMCallback, GlobalSearchPhase, LLMParams, MessageType};
use crate::search::global_search::prompts::NO_DATA_ANSWER;

use super::prompts::{GENERAL_KNOWLEDGE_INSTRUCTION, MAP_SYSTEM_PROMPT, REDUCE_SYSTEM_PROMPT};
@@ -258,7 +258,13 @@ impl GlobalSearch {

let search_response = self
.llm
.agenerate(MessageType::Dictionary(search_messages), false, None, llm_params)
.agenerate(
MessageType::Dictionary(search_messages),
false,
None,
llm_params,
Some(GlobalSearchPhase::Map),
)
.await?;

let processed_response = self.parse_search_response(&search_response);
@@ -412,6 +418,7 @@ impl GlobalSearch {
true,
llm_callbacks,
llm_params,
Some(GlobalSearchPhase::Reduce),
)
.await?;

106 changes: 104 additions & 2 deletions shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs
@@ -7,12 +7,114 @@ use shinkai_graphrag::{
llm::llm::LLMParams,
search::global_search::global_search::{GlobalSearch, GlobalSearchParams},
};
use utils::openai::{num_tokens, ChatOpenAI};
use utils::{
ollama::Ollama,
openai::{num_tokens, ChatOpenAI},
};

mod utils;

// #[tokio::test]
async fn global_search_test() -> Result<(), Box<dyn std::error::Error>> {
async fn ollama_global_search_test() -> Result<(), Box<dyn std::error::Error>> {
let base_url = "http://localhost:11434";
let model_type = "llama3.1";

let llm = Ollama::new(base_url.to_string(), model_type.to_string());

// Load community reports
// Download dataset: https://microsoft.github.io/graphrag/data/operation_dulce/dataset.zip

let input_dir = "./dataset";
let community_report_table = "create_final_community_reports";
let entity_table = "create_final_nodes";
let entity_embedding_table = "create_final_entities";

let community_level = 2;

let mut entity_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, entity_table)).unwrap();
let entity_df = ParquetReader::new(&mut entity_file).finish().unwrap();

let mut report_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, community_report_table)).unwrap();
let report_df = ParquetReader::new(&mut report_file).finish().unwrap();

let mut entity_embedding_file =
std::fs::File::open(format!("{}/{}.parquet", input_dir, entity_embedding_table)).unwrap();
let entity_embedding_df = ParquetReader::new(&mut entity_embedding_file).finish().unwrap();

let reports = read_indexer_reports(&report_df, &entity_df, community_level)?;
let entities = read_indexer_entities(&entity_df, &entity_embedding_df, community_level)?;

println!("Reports: {:?}", report_df.head(Some(5)));

// Build global context based on community reports

// Using tiktoken for token count estimation
let context_builder = GlobalCommunityContext::new(reports, Some(entities), num_tokens);

let context_builder_params = ContextBuilderParams {
use_community_summary: false, // False means using full community reports. True means using community short summaries.
shuffle_data: true,
include_community_rank: true,
min_community_rank: 0,
community_rank_name: String::from("rank"),
include_community_weight: true,
community_weight_name: String::from("occurrence weight"),
normalize_community_weight: true,
max_tokens: 5000, // change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000)
context_name: String::from("Reports"),
column_delimiter: String::from("|"),
};

// LLM params are ignored for Ollama
let map_llm_params = LLMParams {
max_tokens: 1000,
temperature: 0.0,
response_format: std::collections::HashMap::from([("type".to_string(), "json_object".to_string())]),
};

let reduce_llm_params = LLMParams {
max_tokens: 2000,
temperature: 0.0,
response_format: std::collections::HashMap::new(),
};

// Perform global search

let search_engine = GlobalSearch::new(GlobalSearchParams {
llm: Box::new(llm),
context_builder,
num_tokens_fn: num_tokens,
map_system_prompt: None,
reduce_system_prompt: None,
response_type: String::from("multiple paragraphs"),
allow_general_knowledge: false,
general_knowledge_inclusion_prompt: None,
json_mode: true,
callbacks: None,
max_data_tokens: 5000,
map_llm_params,
reduce_llm_params,
context_builder_params,
});

let result = search_engine
.asearch(
"What is the major conflict in this story and who are the protagonist and antagonist?".to_string(),
None,
)
.await?;

println!("Response: {:?}", result.response);

println!("Context: {:?}", result.context_data);

println!("LLM calls: {}. LLM tokens: {}", result.llm_calls, result.prompt_tokens);

Ok(())
}
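Note: the #[tokio::test] attribute is left commented out, so this test is skipped by default. Running it assumes a local Ollama server listening on http://localhost:11434 with the llama3.1 model already pulled, and the dataset extracted to ./dataset.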

// #[tokio::test]
async fn openai_global_search_test() -> Result<(), Box<dyn std::error::Error>> {
let api_key = std::env::var("GRAPHRAG_API_KEY").unwrap();
let llm_model = std::env::var("GRAPHRAG_LLM_MODEL").unwrap();

1 change: 1 addition & 0 deletions shinkai-libs/shinkai-graphrag/tests/utils/mod.rs
@@ -1 +1,2 @@
pub mod ollama;
pub mod openai;
100 changes: 100 additions & 0 deletions shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs
@@ -0,0 +1,100 @@
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::json;
use shinkai_graphrag::llm::llm::{BaseLLM, BaseLLMCallback, GlobalSearchPhase, LLMParams, MessageType};

#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaResponse {
pub model: String,
pub created_at: String,
pub message: OllamaMessage,
}

#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
pub struct OllamaMessage {
pub role: String,
pub content: String,
}

pub struct Ollama {
base_url: String,
model_type: String,
}

impl Ollama {
pub fn new(base_url: String, model_type: String) -> Self {
Ollama { base_url, model_type }
}
}

#[async_trait]
impl BaseLLM for Ollama {
async fn agenerate(
&self,
messages: MessageType,
_streaming: bool,
_callbacks: Option<Vec<BaseLLMCallback>>,
_llm_params: LLMParams,
search_phase: Option<GlobalSearchPhase>,
) -> anyhow::Result<String> {
let client = Client::new();
let chat_url = format!("{}{}", &self.base_url, "/api/chat");

let messages_json = match messages {
MessageType::String(message) => json![message],
MessageType::Strings(messages) => json!(messages),
MessageType::Dictionary(messages) => {
let messages = match search_phase {
Some(GlobalSearchPhase::Map) => {
// Keep only the system messages and resend them with the user role
messages
.into_iter()
.filter(|map| map.get_key_value("role").is_some_and(|(_, v)| v == "system"))
.map(|map| {
map.into_iter()
.map(|(key, value)| {
if key == "role" {
return (key, "user".to_string());
}
(key, value)
})
.collect()
})
.collect()
}
Some(GlobalSearchPhase::Reduce) => {
// Convert roles to user
messages
.into_iter()
.map(|map| {
map.into_iter()
.map(|(key, value)| {
if key == "role" {
return (key, "user".to_string());
}
(key, value)
})
.collect()
})
.collect()
}
_ => messages,
};

json!(messages)
}
};

let payload = json!({
"model": self.model_type,
"messages": messages_json,
"stream": false,
});

let response = client.post(chat_url).json(&payload).send().await?;
let response = response.json::<OllamaResponse>().await?;

Ok(response.message.content)
}
}
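For reference, a minimal usage sketch (hypothetical, not part of this commit) that calls the helper above directly, bypassing the global-search pipeline. It assumes a local Ollama server with llama3.1 pulled:

use shinkai_graphrag::llm::llm::{BaseLLM, LLMParams, MessageType};
use std::collections::HashMap;

// Sends a single user message to /api/chat via the Ollama helper defined above.
async fn ollama_smoke_check() -> anyhow::Result<()> {
    let llm = Ollama::new("http://localhost:11434".to_string(), "llama3.1".to_string());
    let params = LLMParams {
        max_tokens: 100,
        temperature: 0.0,
        response_format: HashMap::new(),
    };
    let messages = vec![HashMap::from([
        ("role".to_string(), "user".to_string()),
        ("content".to_string(), "Reply with a single word.".to_string()),
    ])];
    // With no search phase, the dictionary messages are forwarded unchanged.
    let answer = llm
        .agenerate(MessageType::Dictionary(messages), false, None, params, None)
        .await?;
    println!("{}", answer);
    Ok(())
}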
3 changes: 2 additions & 1 deletion shinkai-libs/shinkai-graphrag/tests/utils/openai.rs
@@ -7,7 +7,7 @@ use async_openai::{
Client,
};
use async_trait::async_trait;
use shinkai_graphrag::llm::llm::{BaseLLM, BaseLLMCallback, LLMParams, MessageType};
use shinkai_graphrag::llm::llm::{BaseLLM, BaseLLMCallback, GlobalSearchPhase, LLMParams, MessageType};
use tiktoken_rs::{get_bpe_from_tokenizer, tokenizer::Tokenizer};

pub struct ChatOpenAI {
@@ -133,6 +133,7 @@ impl BaseLLM for ChatOpenAI {
streaming: bool,
callbacks: Option<Vec<BaseLLMCallback>>,
llm_params: LLMParams,
_search_phase: Option<GlobalSearchPhase>,
) -> anyhow::Result<String> {
self.agenerate(messages, streaming, callbacks, llm_params).await
}
