From 9b66f2d5ac02523817b9d40680f50a41a4dfd45d Mon Sep 17 00:00:00 2001
From: benolt
Date: Fri, 16 Aug 2024 12:13:53 +0200
Subject: [PATCH] test global search with llama 3.1

---
 Cargo.lock                                        |   1 +
 shinkai-libs/shinkai-graphrag/Cargo.toml          |   1 +
 shinkai-libs/shinkai-graphrag/src/llm/llm.rs      |   6 +
 .../src/search/global_search/global_search.rs     |  11 +-
 .../tests/global_search_tests.rs                  | 106 +++++++++++++++++-
 .../shinkai-graphrag/tests/utils/mod.rs           |   1 +
 .../shinkai-graphrag/tests/utils/ollama.rs        | 100 +++++++++++++++++
 .../shinkai-graphrag/tests/utils/openai.rs        |   3 +-
 8 files changed, 224 insertions(+), 5 deletions(-)
 create mode 100644 shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs

diff --git a/Cargo.lock b/Cargo.lock
index 8d3c41100..1bdfcdd91 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -10406,6 +10406,7 @@ dependencies = [
  "polars",
  "polars-lazy",
  "rand 0.8.5",
+ "reqwest 0.11.27",
  "serde",
  "serde_json",
  "tiktoken-rs",
diff --git a/shinkai-libs/shinkai-graphrag/Cargo.toml b/shinkai-libs/shinkai-graphrag/Cargo.toml
index 18650385c..7975fa77c 100644
--- a/shinkai-libs/shinkai-graphrag/Cargo.toml
+++ b/shinkai-libs/shinkai-graphrag/Cargo.toml
@@ -16,4 +16,5 @@ tokio = { version = "1.36", features = ["full"] }
 
 [dev-dependencies]
 async-openai = "0.23.4"
+reqwest = { version = "0.11.26", features = ["json"] }
 tiktoken-rs = "0.5.9"
\ No newline at end of file
diff --git a/shinkai-libs/shinkai-graphrag/src/llm/llm.rs b/shinkai-libs/shinkai-graphrag/src/llm/llm.rs
index 0a8482144..5fa5cf633 100644
--- a/shinkai-libs/shinkai-graphrag/src/llm/llm.rs
+++ b/shinkai-libs/shinkai-graphrag/src/llm/llm.rs
@@ -40,6 +40,7 @@ pub trait BaseLLM {
         streaming: bool,
         callbacks: Option<Vec<BaseLLMCallback>>,
         llm_params: LLMParams,
+        search_phase: Option<GlobalSearchPhase>,
     ) -> anyhow::Result<String>;
 }
 
@@ -47,3 +48,8 @@ pub trait BaseTextEmbedding {
     async fn aembed(&self, text: &str) -> Vec<f64>;
 }
+
+pub enum GlobalSearchPhase {
+    Map,
+    Reduce,
+}
diff --git a/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs b/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs
index 0368d4a59..8bb60b955 100644
--- a/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs
+++ b/shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs
@@ -6,7 +6,7 @@ use std::time::Instant;
 
 use crate::context_builder::community_context::GlobalCommunityContext;
 use crate::context_builder::context_builder::{ContextBuilderParams, ConversationHistory};
-use crate::llm::llm::{BaseLLM, BaseLLMCallback, LLMParams, MessageType};
+use crate::llm::llm::{BaseLLM, BaseLLMCallback, GlobalSearchPhase, LLMParams, MessageType};
 use crate::search::global_search::prompts::NO_DATA_ANSWER;
 
 use super::prompts::{GENERAL_KNOWLEDGE_INSTRUCTION, MAP_SYSTEM_PROMPT, REDUCE_SYSTEM_PROMPT};
@@ -258,7 +258,13 @@ impl GlobalSearch {
 
         let search_response = self
             .llm
-            .agenerate(MessageType::Dictionary(search_messages), false, None, llm_params)
+            .agenerate(
+                MessageType::Dictionary(search_messages),
+                false,
+                None,
+                llm_params,
+                Some(GlobalSearchPhase::Map),
+            )
             .await?;
 
         let processed_response = self.parse_search_response(&search_response);
@@ -412,6 +418,7 @@ impl GlobalSearch {
                 true,
                 llm_callbacks,
                 llm_params,
+                Some(GlobalSearchPhase::Reduce),
             )
             .await?;
diff --git a/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs
index 42bcd5834..5c888a8df 100644
--- a/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs
+++ b/shinkai-libs/shinkai-graphrag/tests/global_search_tests.rs
@@ -7,12 +7,114 @@ use shinkai_graphrag::{
     llm::llm::LLMParams,
     search::global_search::global_search::{GlobalSearch, GlobalSearchParams},
 };
-use utils::openai::{num_tokens, ChatOpenAI};
+use utils::{
+    ollama::Ollama,
+    openai::{num_tokens, ChatOpenAI},
+};
 
 mod utils;
 
 // #[tokio::test]
-async fn global_search_test() -> Result<(), Box<dyn std::error::Error>> {
+async fn ollama_global_search_test() -> Result<(), Box<dyn std::error::Error>> {
+    let base_url = "http://localhost:11434";
+    let model_type = "llama3.1";
+
+    let llm = Ollama::new(base_url.to_string(), model_type.to_string());
+
+    // Load community reports
+    // Download dataset: https://microsoft.github.io/graphrag/data/operation_dulce/dataset.zip
+
+    let input_dir = "./dataset";
+    let community_report_table = "create_final_community_reports";
+    let entity_table = "create_final_nodes";
+    let entity_embedding_table = "create_final_entities";
+
+    let community_level = 2;
+
+    let mut entity_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, entity_table)).unwrap();
+    let entity_df = ParquetReader::new(&mut entity_file).finish().unwrap();
+
+    let mut report_file = std::fs::File::open(format!("{}/{}.parquet", input_dir, community_report_table)).unwrap();
+    let report_df = ParquetReader::new(&mut report_file).finish().unwrap();
+
+    let mut entity_embedding_file =
+        std::fs::File::open(format!("{}/{}.parquet", input_dir, entity_embedding_table)).unwrap();
+    let entity_embedding_df = ParquetReader::new(&mut entity_embedding_file).finish().unwrap();
+
+    let reports = read_indexer_reports(&report_df, &entity_df, community_level)?;
+    let entities = read_indexer_entities(&entity_df, &entity_embedding_df, community_level)?;
+
+    println!("Reports: {:?}", report_df.head(Some(5)));
+
+    // Build global context based on community reports
+
+    // Using tiktoken for token count estimation
+    let context_builder = GlobalCommunityContext::new(reports, Some(entities), num_tokens);
+
+    let context_builder_params = ContextBuilderParams {
+        use_community_summary: false, // False means using full community reports. True means using community short summaries.
+        shuffle_data: true,
+        include_community_rank: true,
+        min_community_rank: 0,
+        community_rank_name: String::from("rank"),
+        include_community_weight: true,
+        community_weight_name: String::from("occurrence weight"),
+        normalize_community_weight: true,
+        max_tokens: 5000, // change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000)
+        context_name: String::from("Reports"),
+        column_delimiter: String::from("|"),
+    };
+
+    // LLM params are ignored for Ollama
+    let map_llm_params = LLMParams {
+        max_tokens: 1000,
+        temperature: 0.0,
+        response_format: std::collections::HashMap::from([("type".to_string(), "json_object".to_string())]),
+    };
+
+    let reduce_llm_params = LLMParams {
+        max_tokens: 2000,
+        temperature: 0.0,
+        response_format: std::collections::HashMap::new(),
+    };
+
+    // Perform global search
+
+    let search_engine = GlobalSearch::new(GlobalSearchParams {
+        llm: Box::new(llm),
+        context_builder,
+        num_tokens_fn: num_tokens,
+        map_system_prompt: None,
+        reduce_system_prompt: None,
+        response_type: String::from("multiple paragraphs"),
+        allow_general_knowledge: false,
+        general_knowledge_inclusion_prompt: None,
+        json_mode: true,
+        callbacks: None,
+        max_data_tokens: 5000,
+        map_llm_params,
+        reduce_llm_params,
+        context_builder_params,
+    });
+
+    let result = search_engine
+        .asearch(
+            "What is the major conflict in this story and who are the protagonist and antagonist?".to_string(),
+            None,
+        )
+        .await?;
+
+    println!("Response: {:?}", result.response);
+
+    println!("Context: {:?}", result.context_data);
+
+    println!("LLM calls: {}. LLM tokens: {}", result.llm_calls, result.prompt_tokens);
+
+    Ok(())
+}
+
+// #[tokio::test]
+async fn openai_global_search_test() -> Result<(), Box<dyn std::error::Error>> {
     let api_key = std::env::var("GRAPHRAG_API_KEY").unwrap();
     let llm_model = std::env::var("GRAPHRAG_LLM_MODEL").unwrap();
diff --git a/shinkai-libs/shinkai-graphrag/tests/utils/mod.rs b/shinkai-libs/shinkai-graphrag/tests/utils/mod.rs
index d8c308735..3ef32f620 100644
--- a/shinkai-libs/shinkai-graphrag/tests/utils/mod.rs
+++ b/shinkai-libs/shinkai-graphrag/tests/utils/mod.rs
@@ -1 +1,2 @@
+pub mod ollama;
 pub mod openai;
diff --git a/shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs b/shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs
new file mode 100644
index 000000000..41d3619b8
--- /dev/null
+++ b/shinkai-libs/shinkai-graphrag/tests/utils/ollama.rs
@@ -0,0 +1,100 @@
+use async_trait::async_trait;
+use reqwest::Client;
+use serde::{Deserialize, Serialize};
+use serde_json::json;
+use shinkai_graphrag::llm::llm::{BaseLLM, BaseLLMCallback, GlobalSearchPhase, LLMParams, MessageType};
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct OllamaResponse {
+    pub model: String,
+    pub created_at: String,
+    pub message: OllamaMessage,
+}
+
+#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
+pub struct OllamaMessage {
+    pub role: String,
+    pub content: String,
+}
+
+pub struct Ollama {
+    base_url: String,
+    model_type: String,
+}
+
+impl Ollama {
+    pub fn new(base_url: String, model_type: String) -> Self {
+        Ollama { base_url, model_type }
+    }
+}
+
+#[async_trait]
+impl BaseLLM for Ollama {
+    async fn agenerate(
+        &self,
+        messages: MessageType,
+        _streaming: bool,
+        _callbacks: Option<Vec<BaseLLMCallback>>,
+        _llm_params: LLMParams,
+        search_phase: Option<GlobalSearchPhase>,
+    ) -> anyhow::Result<String> {
+        let client = Client::new();
+        let chat_url = format!("{}{}", &self.base_url, "/api/chat");
+
+        let messages_json = match messages {
+            MessageType::String(message) => json!(message),
+            MessageType::Strings(messages) => json!(messages),
+            MessageType::Dictionary(messages) => {
+                let messages = match search_phase {
+                    Some(GlobalSearchPhase::Map) => {
+                        // Keep only the system messages and relabel them with the user role
+                        messages
+                            .into_iter()
+                            .filter(|map| map.get_key_value("role").is_some_and(|(_, v)| v == "system"))
+                            .map(|map| {
+                                map.into_iter()
+                                    .map(|(key, value)| {
+                                        if key == "role" {
+                                            return (key, "user".to_string());
+                                        }
+                                        (key, value)
+                                    })
+                                    .collect()
+                            })
+                            .collect()
+                    }
+                    Some(GlobalSearchPhase::Reduce) => {
+                        // Convert all roles to user
+                        messages
+                            .into_iter()
+                            .map(|map| {
+                                map.into_iter()
+                                    .map(|(key, value)| {
+                                        if key == "role" {
+                                            return (key, "user".to_string());
+                                        }
+                                        (key, value)
+                                    })
+                                    .collect()
+                            })
+                            .collect()
+                    }
+                    _ => messages,
+                };
+
+                json!(messages)
+            }
+        };
+
+        let payload = json!({
+            "model": self.model_type,
+            "messages": messages_json,
+            "stream": false,
+        });
+
+        let response = client.post(chat_url).json(&payload).send().await?;
+        let response = response.json::<OllamaResponse>().await?;
+
+        Ok(response.message.content)
+    }
+}
diff --git a/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs b/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs
index 95e7e7b80..255d5b4e5 100644
--- a/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs
+++ b/shinkai-libs/shinkai-graphrag/tests/utils/openai.rs
@@ -7,7 +7,7 @@ use async_openai::{
     Client,
 };
 use async_trait::async_trait;
-use shinkai_graphrag::llm::llm::{BaseLLM, BaseLLMCallback, LLMParams, MessageType};
+use shinkai_graphrag::llm::llm::{BaseLLM, BaseLLMCallback, GlobalSearchPhase, LLMParams, MessageType};
 use tiktoken_rs::{get_bpe_from_tokenizer, tokenizer::Tokenizer};
 
 pub struct ChatOpenAI {
@@ -133,6 +133,7 @@ impl BaseLLM for ChatOpenAI {
         streaming: bool,
         callbacks: Option<Vec<BaseLLMCallback>>,
         llm_params: LLMParams,
+        _search_phase: Option<GlobalSearchPhase>,
     ) -> anyhow::Result<String> {
         self.agenerate(messages, streaming, callbacks, llm_params).await
     }
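
A quick way to exercise the new search_phase plumbing without downloading the dataset is to call the Ollama adapter's agenerate directly. The sketch below is not part of the patch: it assumes it sits next to the existing tests in global_search_tests.rs (so `mod utils;` is in scope), that MessageType::Dictionary wraps Vec<HashMap<String, String>> as the adapter code suggests, that a local Ollama server at http://localhost:11434 has llama3.1 pulled, and that the LLMParams values are placeholders (the Ollama adapter ignores them). The test name is hypothetical.

use std::collections::HashMap;

use shinkai_graphrag::llm::llm::{BaseLLM, GlobalSearchPhase, LLMParams, MessageType};
use utils::ollama::Ollama;

// Hypothetical smoke test: checks that a Map-phase call reaches the local model and
// returns a non-empty answer after the adapter relabels the system prompt as a user message.
// #[tokio::test]
async fn ollama_map_phase_smoke_test() -> anyhow::Result<()> {
    let llm = Ollama::new("http://localhost:11434".to_string(), "llama3.1".to_string());

    // Messages shaped the way the search engine sends them: a system prompt plus a user query.
    let messages = vec![
        HashMap::from([
            ("role".to_string(), "system".to_string()),
            ("content".to_string(), "You answer questions about a short story.".to_string()),
        ]),
        HashMap::from([
            ("role".to_string(), "user".to_string()),
            ("content".to_string(), "Who is the protagonist?".to_string()),
        ]),
    ];

    // Placeholder values; the Ollama adapter does not forward LLMParams to the server.
    let llm_params = LLMParams {
        max_tokens: 1000,
        temperature: 0.0,
        response_format: HashMap::new(),
    };

    let answer = llm
        .agenerate(
            MessageType::Dictionary(messages),
            false,
            None,
            llm_params,
            Some(GlobalSearchPhase::Map),
        )
        .await?;

    assert!(!answer.is_empty());
    Ok(())
}

As with the tests above, the #[tokio::test] attribute is left commented out so the test only runs when a local Ollama instance is actually available.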