Skip to content

Commit

Permalink
feat: add strategies for building query embedding vector
Browse files Browse the repository at this point in the history
  • Loading branch information
McPatate committed Feb 28, 2024
1 parent baedf85 commit 6e3d6c0
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 16 deletions.
20 changes: 18 additions & 2 deletions crates/llm-ls/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use uuid::Uuid;
use crate::backend::{build_body, build_headers, parse_generations};
use crate::document::Document;
use crate::error::{internal_error, Error, Result};
use crate::retrieval::BuildFrom;

mod backend;
mod config;
Expand Down Expand Up @@ -238,11 +239,21 @@ async fn build_prompt(
after_line = after_iter.next();
}
let before = before.into_iter().rev().collect::<Vec<_>>().join("");
let query = snippet_retriever
.read()
.await
.build_query(
format!("{before}{after}"),
BuildFrom::Cursor {
cursor_position: before.len(),
},
)
.await?;
let snippets = snippet_retriever
.read()
.await
.search(
format!("{before}{after}"),
&query,
Some(FilterBuilder::new().comparison(
"file_url".to_owned(),
Compare::Neq,
Expand Down Expand Up @@ -281,11 +292,16 @@ async fn build_prompt(
before.push(line);
}
let prompt = before.into_iter().rev().collect::<Vec<_>>().join("");
let query = snippet_retriever
.read()
.await
.build_query(prompt.clone(), BuildFrom::End)
.await?;
let snippets = snippet_retriever
.read()
.await
.search(
prompt.clone(),
&query,
Some(FilterBuilder::new().comparison(
"file_url".to_owned(),
Compare::Neq,
Expand Down
70 changes: 56 additions & 14 deletions crates/llm-ls/src/retrieval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ impl SnippetRetriever {
})
}

pub(crate) async fn initialse_database(&mut self, db_name: &str) -> Result<Db> {
pub(crate) async fn initialise_database(&mut self, db_name: &str) -> Result<Db> {
let uri = self.cache_path.join(db_name);
let mut db = Db::open(uri).await.expect("failed to open database");
match db
Expand Down Expand Up @@ -282,13 +282,15 @@ impl SnippetRetriever {
debug!("building workspace snippets");
let workspace_root = PathBuf::from(workspace_root);
if self.db.is_none() {
self.initialse_database(
self.initialise_database(&format!(
"{}--{}",
workspace_root
.file_name()
.ok_or_else(|| Error::NoFinalPath(workspace_root.clone()))?
.to_str()
.ok_or(Error::NonUnicode)?,
)
self.model_config.id.replace('/', "--"),
))
.await?;
}
let mut files = Vec::new();
Expand Down Expand Up @@ -377,29 +379,60 @@ impl SnippetRetriever {
Ok(())
}

pub(crate) async fn search(
/// Builds the embedding vector used to query the snippet database.
///
/// `snippet` is tokenised, truncated so that it fits within
/// `model_config.max_input_size` tokens, and then passed to
/// `generate_embedding`. `strategy` selects which part of the text is kept
/// when it is too long:
///
/// * [`BuildFrom::Start`] truncates on the right.
/// * [`BuildFrom::End`] truncates on the left.
/// * [`BuildFrom::Cursor`] splits the text at `cursor_position` (a byte
///   offset), gives each side half of the token budget, truncates the part
///   before the cursor on the left and the part after it on the right, and
///   merges the two encodings.
///
/// # Errors
///
/// Returns an error if tokenisation or embedding generation fails.
pub(crate) async fn build_query(
    &self,
    snippet: String,
    strategy: BuildFrom,
) -> Result<Vec<f32>> {
    let max_input_size = self.model_config.max_input_size;
    let encoding = match strategy {
        BuildFrom::Start => {
            // `snippet` is owned and not needed afterwards: move it instead of cloning.
            let mut encoding = self.tokenizer.encode(snippet, true)?;
            encoding.truncate(max_input_size, 1, TruncationDirection::Right);
            encoding
        }
        BuildFrom::Cursor { cursor_position } => {
            // `str::split_at` panics when the offset is past the end of the string
            // or falls inside a multi-byte character. Move the split point back to
            // the closest valid boundary instead; offset 0 is always a boundary, so
            // the loop terminates.
            let mut split = cursor_position.min(snippet.len());
            while !snippet.is_char_boundary(split) {
                split -= 1;
            }
            let (before, after) = snippet.split_at(split);
            let mut before_encoding = self.tokenizer.encode(before, true)?;
            let mut after_encoding = self.tokenizer.encode(after, true)?;
            // Integer division: with an odd `max_input_size` one token of the
            // budget is left unused.
            let share = max_input_size / 2;
            before_encoding.truncate(share, 1, TruncationDirection::Left);
            after_encoding.truncate(share, 1, TruncationDirection::Right);
            // Drop the truncated remainders so that they are not carried along
            // by the merge below.
            before_encoding.take_overflowing();
            after_encoding.take_overflowing();
            before_encoding.merge_with(after_encoding, false);
            before_encoding
        }
        BuildFrom::End => {
            let mut encoding = self.tokenizer.encode(snippet, true)?;
            encoding.truncate(max_input_size, 1, TruncationDirection::Left);
            encoding
        }
    };
    self.generate_embedding(encoding, self.model.clone()).await
}

pub(crate) async fn search(
&self,
query: &[f32],
filter: Option<FilterBuilder>,
) -> Result<Vec<Snippet>> {
let db = match self.db.as_ref() {
Some(db) => db.clone(),
None => return Err(Error::UninitialisedDatabase),
};
let col = db.get_collection(&self.collection_name).await?;
let mut encoding = self.tokenizer.encode(snippet.clone(), true)?;
encoding.truncate(
self.model_config.max_input_size,
1,
TruncationDirection::Right,
);
let query = self
.generate_embedding(encoding, self.model.clone())
.await?;
let result = col
.read()
.await
.get(&query, 5, filter)
.get(query, 5, filter)
.await?
.iter()
.map(TryInto::try_into)
Expand Down Expand Up @@ -537,3 +570,12 @@ impl SnippetRetriever {
Ok(())
}
}

/// Strategy selecting which part of a text is kept when building a query
/// embedding from input that may be longer than the model's input size.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum BuildFrom {
    /// Keep the text surrounding a cursor.
    Cursor {
        /// Byte offset of the cursor within the text.
        cursor_position: usize,
    },
    /// Keep the end of the text.
    End,
    /// Keep the start of the text.
    // Not constructed by any caller visible here; the lint is silenced so the
    // variant can stay available.
    #[allow(dead_code)]
    Start,
}

0 comments on commit 6e3d6c0

Please sign in to comment.