diff --git a/crates/llm-ls/src/main.rs b/crates/llm-ls/src/main.rs index 0ad9ecf..38996dd 100644 --- a/crates/llm-ls/src/main.rs +++ b/crates/llm-ls/src/main.rs @@ -33,6 +33,7 @@ use uuid::Uuid; use crate::backend::{build_body, build_headers, parse_generations}; use crate::document::Document; use crate::error::{internal_error, Error, Result}; +use crate::retrieval::BuildFrom; mod backend; mod config; @@ -238,11 +239,21 @@ async fn build_prompt( after_line = after_iter.next(); } let before = before.into_iter().rev().collect::>().join(""); + let query = snippet_retriever + .read() + .await + .build_query( + format!("{before}{after}"), + BuildFrom::Cursor { + cursor_position: before.len(), + }, + ) + .await?; let snippets = snippet_retriever .read() .await .search( - format!("{before}{after}"), + &query, Some(FilterBuilder::new().comparison( "file_url".to_owned(), Compare::Neq, @@ -281,11 +292,16 @@ async fn build_prompt( before.push(line); } let prompt = before.into_iter().rev().collect::>().join(""); + let query = snippet_retriever + .read() + .await + .build_query(prompt.clone(), BuildFrom::End) + .await?; let snippets = snippet_retriever .read() .await .search( - prompt.clone(), + &query, Some(FilterBuilder::new().comparison( "file_url".to_owned(), Compare::Neq, diff --git a/crates/llm-ls/src/retrieval.rs b/crates/llm-ls/src/retrieval.rs index 8860b2d..09615d5 100644 --- a/crates/llm-ls/src/retrieval.rs +++ b/crates/llm-ls/src/retrieval.rs @@ -251,7 +251,7 @@ impl SnippetRetriever { }) } - pub(crate) async fn initialse_database(&mut self, db_name: &str) -> Result { + pub(crate) async fn initialise_database(&mut self, db_name: &str) -> Result { let uri = self.cache_path.join(db_name); let mut db = Db::open(uri).await.expect("failed to open database"); match db @@ -282,13 +282,15 @@ impl SnippetRetriever { debug!("building workspace snippets"); let workspace_root = PathBuf::from(workspace_root); if self.db.is_none() { - self.initialse_database( + self.initialise_database(&format!( + "{}--{}", workspace_root .file_name() .ok_or_else(|| Error::NoFinalPath(workspace_root.clone()))? .to_str() .ok_or(Error::NonUnicode)?, - ) + self.model_config.id.replace('/', "--"), + )) .await?; } let mut files = Vec::new(); @@ -377,9 +379,49 @@ impl SnippetRetriever { Ok(()) } - pub(crate) async fn search( + pub(crate) async fn build_query( &self, snippet: String, + strategy: BuildFrom, + ) -> Result> { + match strategy { + BuildFrom::Start => { + let mut encoding = self.tokenizer.encode(snippet.clone(), true)?; + encoding.truncate( + self.model_config.max_input_size, + 1, + TruncationDirection::Right, + ); + self.generate_embedding(encoding, self.model.clone()).await + } + BuildFrom::Cursor { cursor_position } => { + let (before, after) = snippet.split_at(cursor_position); + let mut before_encoding = self.tokenizer.encode(before, true)?; + let mut after_encoding = self.tokenizer.encode(after, true)?; + let share = self.model_config.max_input_size / 2; + before_encoding.truncate(share, 1, TruncationDirection::Left); + after_encoding.truncate(share, 1, TruncationDirection::Right); + before_encoding.take_overflowing(); + after_encoding.take_overflowing(); + before_encoding.merge_with(after_encoding, false); + self.generate_embedding(before_encoding, self.model.clone()) + .await + } + BuildFrom::End => { + let mut encoding = self.tokenizer.encode(snippet.clone(), true)?; + encoding.truncate( + self.model_config.max_input_size, + 1, + TruncationDirection::Left, + ); + self.generate_embedding(encoding, self.model.clone()).await + } + } + } + + pub(crate) async fn search( + &self, + query: &[f32], filter: Option, ) -> Result> { let db = match self.db.as_ref() { @@ -387,19 +429,10 @@ impl SnippetRetriever { None => return Err(Error::UninitialisedDatabase), }; let col = db.get_collection(&self.collection_name).await?; - let mut encoding = self.tokenizer.encode(snippet.clone(), true)?; - encoding.truncate( - self.model_config.max_input_size, - 1, - TruncationDirection::Right, - ); - let query = self - .generate_embedding(encoding, self.model.clone()) - .await?; let result = col .read() .await - .get(&query, 5, filter) + .get(query, 5, filter) .await? .iter() .map(TryInto::try_into) @@ -537,3 +570,12 @@ impl SnippetRetriever { Ok(()) } } + +pub(crate) enum BuildFrom { + Cursor { + cursor_position: usize, + }, + End, + #[allow(dead_code)] + Start, +}