Skip to content

Commit

Permalink
add prompts, global search adjustments
Browse files Browse the repository at this point in the history
  • Loading branch information
benolt committed Aug 9, 2024
1 parent 6f19a98 commit a662504
Show file tree
Hide file tree
Showing 10 changed files with 272 additions and 53 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions shinkai-libs/shinkai-graphrag/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
.vscode
dataset
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,13 @@ impl CommunityContext {
(result, context)
};

let compute_community_weights = entities.is_some()
let compute_community_weights = entities.as_ref().is_some_and(|e| !e.is_empty())
&& !community_reports.is_empty()
&& include_community_weight
&& (community_reports[0].attributes.is_none()
|| !community_reports[0]
.attributes
.clone()
.as_ref()
.unwrap()
.contains_key(community_weight_name));

Expand Down Expand Up @@ -219,6 +219,7 @@ impl CommunityContext {
community_rank_name,
include_community_weight,
include_community_rank,
column_delimiter,
)?;

if single_batch {
Expand All @@ -243,6 +244,7 @@ impl CommunityContext {
community_rank_name,
include_community_weight,
include_community_rank,
column_delimiter,
)?;
}

Expand Down Expand Up @@ -365,8 +367,9 @@ impl Batch {
community_rank_name: &str,
include_community_weight: bool,
include_community_rank: bool,
column_delimiter: &str,
) -> anyhow::Result<()> {
let weight_column = if include_community_weight && entities.is_some_and(|e| !e.is_empty()) {
let weight_column = if include_community_weight && entities.as_ref().is_some_and(|e| !e.is_empty()) {
Some(community_weight_name)
} else {
None
Expand All @@ -387,10 +390,20 @@ impl Batch {
return Ok(());
}

let column_delimiter = if column_delimiter.is_empty() {
b'|'
} else {
column_delimiter.as_bytes()[0]
};

let mut buffer = Cursor::new(Vec::new());
CsvWriter::new(buffer.clone()).finish(&mut record_df).unwrap();
CsvWriter::new(&mut buffer)
.include_header(true)
.with_separator(column_delimiter)
.finish(&mut record_df)?;

let mut current_context_text = String::new();
buffer.set_position(0);
buffer.read_to_string(&mut current_context_text)?;

all_context_text.push(current_context_text);
Expand All @@ -410,7 +423,11 @@ impl Batch {
}

let mut data_series = Vec::new();
for (header, records) in header.iter().zip(context_records.iter()) {
for (index, header) in header.iter().enumerate() {
let records = context_records
.iter()
.map(|r| r.get(index).unwrap_or(&String::new()).to_owned())
.collect::<Vec<_>>();
let series = Series::new(header, records);
data_series.push(series);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ use crate::context_builder::community_context::GlobalCommunityContext;
use crate::context_builder::context_builder::{ContextBuilderParams, ConversationHistory};
use crate::llm::llm::{BaseLLM, BaseLLMCallback, LLMParams, MessageType};
use crate::llm::utils::num_tokens;
use crate::search::global_search::prompts::NO_DATA_ANSWER;

use super::prompts::{GENERAL_KNOWLEDGE_INSTRUCTION, MAP_SYSTEM_PROMPT, REDUCE_SYSTEM_PROMPT};

#[derive(Debug, Clone)]
pub struct SearchResult {
Expand Down Expand Up @@ -90,6 +93,7 @@ pub struct GlobalSearch {
context_builder: GlobalCommunityContext,
token_encoder: Option<Tokenizer>,
context_builder_params: ContextBuilderParams,
map_system_prompt: String,
reduce_system_prompt: String,
response_type: String,
allow_general_knowledge: bool,
Expand All @@ -100,22 +104,42 @@ pub struct GlobalSearch {
reduce_llm_params: LLMParams,
}

/// Configuration bundle for constructing a [`GlobalSearch`] via `GlobalSearch::new`.
///
/// Groups the former long positional parameter list into a single struct;
/// `GlobalSearch::new` destructures it field-by-field.
pub struct GlobalSearchParams {
    // LLM backend used for both the map and reduce phases.
    pub llm: Box<dyn BaseLLM>,
    // Builds the community-report context injected into the map prompt.
    pub context_builder: GlobalCommunityContext,
    // Optional tokenizer used for token counting (see `num_tokens`); None falls back
    // to whatever default `num_tokens` applies.
    pub token_encoder: Option<Tokenizer>,
    // Map-phase system prompt; defaults to `MAP_SYSTEM_PROMPT` when None.
    // `{context_data}` in the prompt is replaced with the built context.
    pub map_system_prompt: Option<String>,
    // Reduce-phase system prompt; defaults to `REDUCE_SYSTEM_PROMPT` when None.
    pub reduce_system_prompt: Option<String>,
    // Desired answer format/verbosity description passed to the reduce prompt
    // — presumably interpolated there; confirm against the reduce implementation.
    pub response_type: String,
    // When false and no map response scores above 0, the search short-circuits
    // with the canned `NO_DATA_ANSWER` instead of asking the LLM to improvise.
    pub allow_general_knowledge: bool,
    // Extra instruction appended when general knowledge is allowed; defaults to
    // `GENERAL_KNOWLEDGE_INSTRUCTION` when None.
    pub general_knowledge_inclusion_prompt: Option<String>,
    // When true, a `response_format` entry is inserted into `map_llm_params`
    // (requesting JSON output); when false, any such entry is removed.
    pub json_mode: bool,
    // Optional callbacks invoked during the search lifecycle.
    pub callbacks: Option<Vec<GlobalSearchLLMCallback>>,
    // Token budget for the reduce phase: formatted key points are accumulated
    // until this limit would be exceeded.
    pub max_data_tokens: usize,
    // LLM call parameters for the map phase (mutated per `json_mode` above).
    pub map_llm_params: LLMParams,
    // LLM call parameters for the reduce phase.
    pub reduce_llm_params: LLMParams,
    // Parameters forwarded to the context builder.
    pub context_builder_params: ContextBuilderParams,
}

impl GlobalSearch {
pub fn new(
llm: Box<dyn BaseLLM>,
context_builder: GlobalCommunityContext,
token_encoder: Option<Tokenizer>,
reduce_system_prompt: String,
response_type: String,
allow_general_knowledge: bool,
general_knowledge_inclusion_prompt: String,
json_mode: bool,
callbacks: Option<Vec<GlobalSearchLLMCallback>>,
max_data_tokens: usize,
map_llm_params: LLMParams,
reduce_llm_params: LLMParams,
context_builder_params: ContextBuilderParams,
) -> Self {
pub fn new(global_search_params: GlobalSearchParams) -> Self {
let GlobalSearchParams {
llm,
context_builder,
token_encoder,
map_system_prompt,
reduce_system_prompt,
response_type,
allow_general_knowledge,
general_knowledge_inclusion_prompt,
json_mode,
callbacks,
max_data_tokens,
map_llm_params,
reduce_llm_params,
context_builder_params,
} = global_search_params;

let mut map_llm_params = map_llm_params;

if json_mode {
Expand All @@ -126,11 +150,17 @@ impl GlobalSearch {
map_llm_params.response_format.remove("response_format");
}

let map_system_prompt = map_system_prompt.unwrap_or(MAP_SYSTEM_PROMPT.to_string());
let reduce_system_prompt = reduce_system_prompt.unwrap_or(REDUCE_SYSTEM_PROMPT.to_string());
let general_knowledge_inclusion_prompt =
general_knowledge_inclusion_prompt.unwrap_or(GENERAL_KNOWLEDGE_INSTRUCTION.to_string());

GlobalSearch {
llm,
context_builder,
token_encoder,
context_builder_params,
map_system_prompt,
reduce_system_prompt,
response_type,
allow_general_knowledge,
Expand Down Expand Up @@ -218,7 +248,8 @@ impl GlobalSearch {
llm_params: LLMParams,
) -> anyhow::Result<SearchResult> {
let start_time = Instant::now();
let search_prompt = String::new();
let search_prompt = self.map_system_prompt.replace("{context_data}", context_data);

let mut search_messages = Vec::new();
search_messages.push(HashMap::from([
("role".to_string(), "system".to_string()),
Expand Down Expand Up @@ -253,6 +284,7 @@ impl GlobalSearch {
if let Some(points) = points.as_array() {
return points
.iter()
.filter(|element| element.get("description").is_some() && element.get("score").is_some())
.map(|element| KeyPoint {
answer: element
.get("description")
Expand All @@ -268,7 +300,10 @@ impl GlobalSearch {
}
}

Vec::new()
vec![KeyPoint {
answer: "".to_string(),
score: 0,
}]
}

async fn _reduce_response(
Expand All @@ -282,15 +317,13 @@ impl GlobalSearch {
let mut key_points: Vec<HashMap<String, String>> = Vec::new();

for (index, response) in map_responses.iter().enumerate() {
if let ResponseType::Dictionaries(response_list) = &response.response {
for element in response_list {
if let (Some(answer), Some(score)) = (element.get("answer"), element.get("score")) {
let mut point = HashMap::new();
point.insert("analyst".to_string(), (index + 1).to_string());
point.insert("answer".to_string(), answer.to_string());
point.insert("score".to_string(), score.to_string());
key_points.push(point);
}
if let ResponseType::KeyPoints(response_list) = &response.response {
for key_point in response_list {
let mut point = HashMap::new();
point.insert("analyst".to_string(), (index + 1).to_string());
point.insert("answer".to_string(), key_point.answer.clone());
point.insert("score".to_string(), key_point.score.to_string());
key_points.push(point);
}
}
}
Expand All @@ -301,8 +334,10 @@ impl GlobalSearch {
.collect();

if filtered_key_points.is_empty() && !self.allow_general_knowledge {
eprintln!("Warning: All map responses have score 0 (i.e., no relevant information found from the dataset), returning a canned 'I do not know' answer. You can try enabling `allow_general_knowledge` to encourage the LLM to incorporate relevant general knowledge, at the risk of increasing hallucinations.");

return Ok(SearchResult {
response: ResponseType::String("NO_DATA_ANSWER".to_string()),
response: ResponseType::String(NO_DATA_ANSWER.to_string()),
context_data: ContextData::String("".to_string()),
context_text: ContextText::String("".to_string()),
completion_time: start_time.elapsed().as_secs_f64(),
Expand All @@ -328,9 +363,11 @@ impl GlobalSearch {
formatted_response_data.push(format!("Importance Score: {}", point.get("score").unwrap()));
formatted_response_data.push(point.get("answer").unwrap().to_string());
let formatted_response_text = formatted_response_data.join("\n");

if total_tokens + num_tokens(&formatted_response_text, self.token_encoder) > self.max_data_tokens {
break;
}

data.push(formatted_response_text.clone());
total_tokens += num_tokens(&formatted_response_text, self.token_encoder);
}
Expand Down
2 changes: 2 additions & 0 deletions shinkai-libs/shinkai-graphrag/src/search/global_search/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
// Map/reduce global-search orchestration (`GlobalSearch`, `SearchResult`).
pub mod global_search;
// Prompt constants used by global search (MAP_SYSTEM_PROMPT, REDUCE_SYSTEM_PROMPT,
// GENERAL_KNOWLEDGE_INSTRUCTION, NO_DATA_ANSWER).
pub mod prompts;
Loading

0 comments on commit a662504

Please sign in to comment.