Skip to content

Commit

Permalink
read indexer entities and reports improvements, compute community wei…
Browse files Browse the repository at this point in the history
…ghts
  • Loading branch information
benolt committed Aug 13, 2024
1 parent a9339a7 commit 85d8658
Show file tree
Hide file tree
Showing 5 changed files with 226 additions and 136 deletions.
3 changes: 1 addition & 2 deletions shinkai-libs/shinkai-graphrag/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,4 @@ tiktoken-rs = "0.5.9"
tokio = { version = "1.36", features = ["full"] }

[dev-dependencies]
async-openai = "0.23.4"
tokio = { version = "1.36", features = ["full"] }
async-openai = "0.23.4"
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::{
collections::HashMap,
collections::{HashMap, HashSet},
io::{Cursor, Read},
};

Expand Down Expand Up @@ -276,7 +276,7 @@ impl CommunityContext {
) -> Vec<CommunityReport> {
// Calculate a community's weight as the count of text units associated with entities within the community.
if let Some(entities) = entities {
let mut community_reports = community_reports.clone();
let mut community_reports = community_reports;
let mut community_text_units = std::collections::HashMap::new();
for entity in entities {
if let Some(community_ids) = entity.community_ids.clone() {
Expand All @@ -297,7 +297,7 @@ impl CommunityContext {
weight_attribute.to_string(),
community_text_units
.get(&report.community_id)
.map(|text_units| text_units.len())
.map(|text_units| text_units.iter().flatten().cloned().collect::<HashSet<String>>().len())
.unwrap_or(0)
.to_string(),
);
Expand All @@ -316,7 +316,7 @@ impl CommunityContext {
})
.collect();
if let Some(max_weight) = all_weights.iter().cloned().max_by(|a, b| a.partial_cmp(b).unwrap()) {
for mut report in community_reports {
for report in &mut community_reports {
if let Some(attributes) = &mut report.attributes {
if let Some(weight) = attributes.get_mut(weight_attribute) {
*weight = (weight.parse::<f64>().unwrap_or(0.0) / max_weight).to_string();
Expand All @@ -325,6 +325,8 @@ impl CommunityContext {
}
}
}

return community_reports;
}
community_reports
}
Expand Down
224 changes: 138 additions & 86 deletions shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};

use polars::prelude::*;
use polars_lazy::dsl::col;
Expand All @@ -12,44 +12,19 @@ pub fn read_indexer_entities(
community_level: u32,
) -> anyhow::Result<Vec<Entity>> {
let entity_df = final_nodes.clone();
let mut entity_df = filter_under_community_level(&entity_df, community_level)?;
let entity_df = filter_under_community_level(&entity_df, community_level)?;

let entity_df = entity_df.rename("title", "name")?.rename("degree", "rank")?;
let entity_embedding_df = final_entities.clone();

let entity_df = entity_df
.clone()
.lazy()
.rename(["title", "degree"], ["name", "rank"])
.with_column(col("community").fill_null(lit(-1)))
.collect()?;
let entity_df = entity_df
.clone()
.lazy()
.with_column(col("community").cast(DataType::Int32))
.collect()?;
let entity_df = entity_df
.clone()
.lazy()
.with_column(col("rank").cast(DataType::Int32))
.collect()?;

let entity_embedding_df = final_entities.clone();

let entity_df = entity_df
.clone()
.lazy()
.group_by([col("name"), col("rank")])
.agg([col("community").max()])
.collect()?;

let entity_df = entity_df
.clone()
.lazy()
.with_column(col("community").cast(DataType::String))
.collect()?;

let entity_df = entity_df
.clone()
.lazy()
.join(
entity_embedding_df.clone().lazy(),
[col("name")],
Expand All @@ -58,12 +33,6 @@ pub fn read_indexer_entities(
)
.collect()?;

let entity_df = entity_df
.clone()
.lazy()
.filter(len().over([col("name")]).gt(lit(1)))
.collect()?;

let entities = read_entities(
&entity_df,
"id",
Expand Down Expand Up @@ -134,77 +103,160 @@ pub fn read_entities(
.filter_map(|&v| v.map(|v| v.to_string()))
.collect::<Vec<_>>();

let column_names = column_names.into_iter().collect::<HashSet<String>>().into_vec();

let mut df = df.clone();
df.as_single_chunk_par();
let mut iters = df.columns(column_names)?.iter().map(|s| s.iter()).collect::<Vec<_>>();
let mut iters = df
.columns(column_names.clone())?
.iter()
.map(|s| s.iter())
.collect::<Vec<_>>();

let mut rows = Vec::new();
for _row in 0..df.height() {
let mut row_values = Vec::new();
for iter in &mut iters {
let value = iter.next();
if let Some(value) = value {
row_values.push(value.to_string());
row_values.push(value);
}
}
rows.push(row_values);
}

let mut entities = Vec::new();
for row in rows {
for (idx, row) in rows.iter().enumerate() {
let report = Entity {
id: row.get(0).unwrap_or(&String::new()).to_string(),
short_id: Some(row.get(1).unwrap_or(&String::new()).to_string()),
title: row.get(2).unwrap_or(&String::new()).to_string(),
entity_type: Some(row.get(3).unwrap_or(&String::new()).to_string()),
description: Some(row.get(4).unwrap_or(&String::new()).to_string()),
name_embedding: Some(
row.get(5)
.unwrap_or(&String::new())
.split(',')
.map(|v| v.parse::<f64>().unwrap_or(0.0))
.collect(),
),
description_embedding: Some(
row.get(6)
.unwrap_or(&String::new())
.split(',')
.map(|v| v.parse::<f64>().unwrap_or(0.0))
.collect(),
),
graph_embedding: Some(
row.get(7)
.unwrap_or(&String::new())
.split(',')
.map(|v| v.parse::<f64>().unwrap_or(0.0))
.collect(),
),
community_ids: Some(
row.get(8)
.unwrap_or(&String::new())
.split(',')
.map(|v| v.to_string())
.collect(),
),
text_unit_ids: Some(
row.get(9)
.unwrap_or(&String::new())
.split(',')
.map(|v| v.to_string())
.collect(),
id: get_field(&row, id_col, &column_names)
.map(|id| id.to_string())
.unwrap_or(String::new()),
short_id: Some(
short_id_col
.map(|short_id| get_field(&row, short_id, &column_names))
.flatten()
.map(|short_id| short_id.to_string())
.unwrap_or(idx.to_string()),
),
document_ids: Some(
row.get(10)
.unwrap_or(&String::new())
.split(',')
.map(|v| v.to_string())
.collect(),
),
rank: Some(row.get(11).and_then(|v| v.parse::<i32>().ok()).unwrap_or(0)),
title: get_field(&row, title_col, &column_names)
.map(|title| title.to_string())
.unwrap_or(String::new()),
entity_type: type_col
.map(|type_col| get_field(&row, type_col, &column_names))
.flatten()
.map(|entity_type| entity_type.to_string()),
description: description_col
.map(|description_col| get_field(&row, description_col, &column_names))
.flatten()
.map(|description| description.to_string()),
name_embedding: name_embedding_col.map(|name_embedding_col| {
get_field(&row, name_embedding_col, &column_names)
.map(|name_embedding| match name_embedding {
AnyValue::List(series) => series
.f64()
.unwrap_or(&ChunkedArray::from_vec(name_embedding_col, vec![]))
.iter()
.map(|v| v.unwrap_or(0.0))
.collect::<Vec<f64>>(),
value => vec![value.to_string().parse::<f64>().unwrap_or(0.0)],
})
.unwrap_or_else(|| Vec::new())
}),
description_embedding: description_embedding_col.map(|description_embedding_col| {
get_field(&row, description_embedding_col, &column_names)
.map(|description_embedding| match description_embedding {
AnyValue::List(series) => series
.f64()
.unwrap_or(&ChunkedArray::from_vec(description_embedding_col, vec![]))
.iter()
.map(|v| v.unwrap_or(0.0))
.collect::<Vec<f64>>(),
value => vec![value.to_string().parse::<f64>().unwrap_or(0.0)],
})
.unwrap_or_else(|| Vec::new())
}),
graph_embedding: graph_embedding_col.map(|graph_embedding_col| {
get_field(&row, graph_embedding_col, &column_names)
.map(|graph_embedding| match graph_embedding {
AnyValue::List(series) => series
.f64()
.unwrap_or(&ChunkedArray::from_vec(graph_embedding_col, vec![]))
.iter()
.map(|v| v.unwrap_or(0.0))
.collect::<Vec<f64>>(),
value => vec![value.to_string().parse::<f64>().unwrap_or(0.0)],
})
.unwrap_or_else(|| Vec::new())
}),
community_ids: community_col.map(|community_col| {
get_field(&row, community_col, &column_names)
.map(|community_ids| match community_ids {
AnyValue::List(series) => series
.str()
.unwrap_or(&StringChunked::default())
.iter()
.map(|v| v.unwrap_or("").to_string())
.collect::<Vec<String>>(),
value => vec![value.to_string()],
})
.unwrap_or_else(|| Vec::new())
}),
text_unit_ids: text_unit_ids_col.map(|text_unit_ids_col| {
get_field(&row, text_unit_ids_col, &column_names)
.map(|text_unit_ids| match text_unit_ids {
AnyValue::List(series) => series
.str()
.unwrap_or(&StringChunked::default())
.iter()
.map(|v| v.unwrap_or("").to_string())
.collect::<Vec<String>>(),
value => vec![value.to_string()],
})
.unwrap_or_else(|| Vec::new())
}),
document_ids: document_ids_col.map(|document_ids_col| {
get_field(&row, document_ids_col, &column_names)
.map(|document_ids| match document_ids {
AnyValue::List(series) => series
.str()
.unwrap_or(&StringChunked::default())
.iter()
.map(|v| v.unwrap_or("").to_string())
.collect::<Vec<String>>(),
value => vec![value.to_string()],
})
.unwrap_or_else(|| Vec::new())
}),
rank: rank_col
.map(|rank_col| {
get_field(&row, rank_col, &column_names).map(|v| v.to_string().parse::<i32>().unwrap_or(0))
})
.flatten(),
attributes: None,
};
entities.push(report);
}

Ok(entities)
let mut unique_entities: Vec<Entity> = Vec::new();
let mut entity_ids: HashSet<String> = HashSet::new();

for entity in entities {
if !entity_ids.contains(&entity.id) {
unique_entities.push(entity.clone());
entity_ids.insert(entity.id);
}
}

Ok(unique_entities)
}

/// Fetches the value stored in `row` for the column called `column_name`.
///
/// `column_names` supplies the name-to-position mapping: the index of the
/// first entry equal to `column_name` is used to index into `row`.
///
/// Returns `None` when `column_name` does not occur in `column_names`, or
/// when `row` has no element at the resolved index. The value is cloned out
/// of the row, so the row itself is left untouched.
pub fn get_field<'a>(
    row: &'a Vec<AnyValue<'a>>,
    column_name: &'a str,
    column_names: &'a Vec<String>,
) -> Option<AnyValue<'a>> {
    // Bail out early when the requested column is unknown.
    let position = column_names.iter().position(|name| name == column_name)?;
    row.get(position).cloned()
}
Loading

0 comments on commit 85d8658

Please sign in to comment.