diff --git a/crates/llm-ls/src/error.rs b/crates/llm-ls/src/error.rs index f6ac083..46f17f1 100644 --- a/crates/llm-ls/src/error.rs +++ b/crates/llm-ls/src/error.rs @@ -79,6 +79,8 @@ pub enum Error { UnknownBackend(String), #[error("yaml serialization error: {0}")] Yaml(#[from] serde_yaml::Error), + #[error("No embedding built")] + MissingEmbedding, } pub(crate) type Result = std::result::Result; diff --git a/crates/llm-ls/src/retrieval.rs b/crates/llm-ls/src/retrieval.rs index 09615d5..6bea3f8 100644 --- a/crates/llm-ls/src/retrieval.rs +++ b/crates/llm-ls/src/retrieval.rs @@ -7,11 +7,14 @@ use candle_transformers::models::bert::{BertModel, Config, DTYPE}; use gitignore::Gitignore; use hf_hub::{api::tokio::Api, Repo, RepoType}; use std::collections::{HashMap, VecDeque}; +use std::iter::zip; use std::path::Path; use std::{path::PathBuf, sync::Arc}; use tinyvec_embed::db::{Collection, Compare, Db, Embedding, FilterBuilder, SimilarityResult}; use tinyvec_embed::similarity::Distance; -use tokenizers::{Encoding, Tokenizer, TruncationDirection}; +use tokenizers::{ + Encoding, PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer, TruncationDirection, +}; use tokio::io::AsyncReadExt; use tokio::task::spawn_blocking; use tokio::time::Instant; @@ -156,9 +159,16 @@ async fn build_model_and_tokenizer( let config = tokio::fs::read_to_string(config_filename).await?; let config: Config = serde_json::from_str(&config)?; let mut tokenizer: Tokenizer = Tokenizer::from_file(tokenizer_filename)?; - tokenizer.with_padding(None); + tokenizer.with_padding(Some(PaddingParams { + strategy: PaddingStrategy::BatchLongest, + direction: PaddingDirection::Right, + pad_to_multiple_of: Some(8), + // TODO: use values provided in model config + pad_id: 0, + pad_type_id: 0, + pad_token: "".to_string(), + })); tokenizer.with_truncation(None)?; - let vb = VarBuilder::from_pth(&weights_filename, DTYPE, &device)?; let model = BertModel::load(vb, &config)?; debug!( @@ -191,6 +201,8 @@ fn device(cpu: bool) -> Result { pub(crate) struct Snippet { pub(crate) file_url: String, pub(crate) code: String, + pub(crate) start_line: usize, + pub(crate) end_line: usize, } impl TryFrom<&SimilarityResult> for Snippet { @@ -210,7 +222,20 @@ impl TryFrom<&SimilarityResult> for Snippet { .get("snippet") .ok_or_else(|| Error::MalformattedEmbeddingMetadata("snippet".to_owned()))? .inner_string()?; - Ok(Snippet { file_url, code }) + let start_line = meta + .get("start_line_no") + .ok_or_else(|| Error::MalformattedEmbeddingMetadata("snippet".to_owned()))? + .try_into()?; + let end_line = meta + .get("start_line_no") + .ok_or_else(|| Error::MalformattedEmbeddingMetadata("snippet".to_owned()))? + .try_into()?; + Ok(Snippet { + file_url, + code, + start_line, + end_line, + }) } } @@ -280,6 +305,7 @@ impl SnippetRetriever { workspace_root: &str, ) -> Result<()> { debug!("building workspace snippets"); + let start = Instant::now(); let workspace_root = PathBuf::from(workspace_root); if self.db.is_none() { self.initialise_database(&format!( @@ -360,7 +386,10 @@ impl SnippetRetriever { }) .await; } - + debug!( + "Built workspace snippets in {} ms", + start.elapsed().as_millis() + ); Ok(()) } @@ -384,7 +413,7 @@ impl SnippetRetriever { snippet: String, strategy: BuildFrom, ) -> Result> { - match strategy { + let result = match strategy { BuildFrom::Start => { let mut encoding = self.tokenizer.encode(snippet.clone(), true)?; encoding.truncate( @@ -392,7 +421,8 @@ impl SnippetRetriever { 1, TruncationDirection::Right, ); - self.generate_embedding(encoding, self.model.clone()).await + self.generate_embeddings(vec![encoding], self.model.clone()) + .await? } BuildFrom::Cursor { cursor_position } => { let (before, after) = snippet.split_at(cursor_position); @@ -404,8 +434,8 @@ impl SnippetRetriever { before_encoding.take_overflowing(); after_encoding.take_overflowing(); before_encoding.merge_with(after_encoding, false); - self.generate_embedding(before_encoding, self.model.clone()) - .await + self.generate_embeddings(vec![before_encoding], self.model.clone()) + .await? } BuildFrom::End => { let mut encoding = self.tokenizer.encode(snippet.clone(), true)?; @@ -414,9 +444,15 @@ impl SnippetRetriever { 1, TruncationDirection::Left, ); - self.generate_embedding(encoding, self.model.clone()).await + self.generate_embeddings(vec![encoding], self.model.clone()) + .await? } + }; + if result.is_empty() { + return Err(Error::MissingEmbedding); } + let mut result = result; + Ok(result.remove(0)) } pub(crate) async fn search( @@ -477,21 +513,24 @@ impl SnippetRetriever { impl SnippetRetriever { // TODO: handle overflowing in Encoding - async fn generate_embedding( + /// Embedding order is preserved and stays the same as encoding input + async fn generate_embeddings( &self, - encoding: Encoding, + encodings: Vec, model: Arc, - ) -> Result> { + ) -> Result>> { let start = Instant::now(); - let embedding = spawn_blocking(move || -> Result> { - let tokens = encoding.get_ids().to_vec(); - let token_ids = Tensor::new(&tokens[..], &model.device)?.unsqueeze(0)?; + let embedding = spawn_blocking(move || -> Result>> { + let tokens = encodings + .iter() + .map(|elem| Ok(Tensor::new(elem.get_ids().to_vec(), &model.device)?)) + .collect::>>()?; + let token_ids = Tensor::stack(&tokens, 0)?; let token_type_ids = token_ids.zeros_like()?; let embedding = model.forward(&token_ids, &token_type_ids)?; let (_n_sentence, n_tokens, _hidden_size) = embedding.dims3()?; let embedding = (embedding.sum(1)? / (n_tokens as f64))?; - let embedding = embedding.get(0)?.to_vec1::()?; - Ok(embedding) + Ok(embedding.to_vec2::()?) }) .await?; debug!("embedding generated in {} ms", start.elapsed().as_millis()); @@ -512,6 +551,8 @@ impl SnippetRetriever { let file = tokio::fs::read_to_string(&file_url).await?; let lines = file.split('\n').collect::>(); let end = end.unwrap_or(lines.len()).min(lines.len()); + let mut snippets: Vec = Vec::new(); + debug!("Building embeddings for {file_url}"); for start_line in (start..end).step_by(self.window_step) { let end_line = (start_line + self.window_size - 1).min(lines.len()); if !col @@ -538,35 +579,68 @@ impl SnippetRetriever { let window = lines[start_line..end_line].to_vec(); let snippet = window.join("\n"); if snippet.is_empty() { + debug!("snippet {file_url}[{start_line}, {end_line}] empty"); continue; } + snippets.push(Snippet { + file_url: file_url.clone().into(), + code: snippet, + start_line, + end_line, + }); + } + { + let nb_snippets = snippets.len(); + let steps = self.window_step; + debug!("Build {nb_snippets} snippets for {file_url}: {start}, {end}, {steps}"); + } - let mut encoding = self.tokenizer.encode(snippet.clone(), true)?; - encoding.truncate( - self.model_config.max_input_size, - 1, - TruncationDirection::Right, - ); - let result = self.generate_embedding(encoding, self.model.clone()).await; - let embedding = match result { - Ok(e) => e, - Err(err) => { - error!( - "error generating embedding for {file_url}[{start_line}, {end_line}]: {err}", - ); - continue; - } - }; - col.write().await.insert(Embedding::new( - embedding, - Some(HashMap::from([ - ("file_url".to_owned(), file_url.clone().into()), - ("start_line_no".to_owned(), start_line.into()), - ("end_line_no".to_owned(), end_line.into()), - ("snippet".to_owned(), snippet.clone().into()), - ])), - ))?; + // Group by length to reduce padding effect + let snippets = spawn_blocking(|| -> Result> { + snippets.sort_unstable_by(|first, second| first.code.len().cmp(&second.code.len())); + Ok(snippets) + }) + .await?; + + // TODO: improvements to compute an efficient batch size: + // - batch size should be relative to the cumulative size of all elements in the batch, + // Set embedding_batch_size to 8 if device is GPU, use match + let embedding_batch_size = match self.model.device { + Device::Cpu => 2, + _ => 8, + }; + for batch in snippets?.chunks(embedding_batch_size) { + let batch_code = batch.iter().map(|snippet| snippet.code.clone()).collect(); + let encodings = self + .tokenizer + .encode_batch(batch_code, true)? + .iter_mut() + .map(|encoding| { + encoding.truncate(512, 1, TruncationDirection::Right); + encoding.clone() + }) + .collect(); + let results = self + .generate_embeddings(encodings, self.model.clone()) + .await?; + col.write().await.batch_insert( + zip(results, batch) + .map(|item| { + let (embedding, snippet) = item; + Embedding::new( + embedding, + Some(HashMap::from([ + ("file_url".to_owned(), snippet.file_url.clone().into()), + ("start_line_no".to_owned(), snippet.start_line.into()), + ("end_line_no".to_owned(), snippet.end_line.into()), + ("snippet".to_owned(), snippet.code.clone().into()), + ])), + ) + }) + .collect::>(), + )?; } + db.save().await?; Ok(()) } } diff --git a/crates/tinyvec-embed/src/db.rs b/crates/tinyvec-embed/src/db.rs index 8267d2b..dd20995 100644 --- a/crates/tinyvec-embed/src/db.rs +++ b/crates/tinyvec-embed/src/db.rs @@ -173,6 +173,17 @@ impl Collection { Ok(()) } + pub fn batch_insert(&mut self, embeddings: Vec) -> Result<()> { + if embeddings + .iter() + .any(|embedding| embedding.vector.len() != self.dimension) + { + return Err(CollectionError::DimensionMismatch.into()); + } + self.embeddings.extend(embeddings); + Ok(()) + } + /// Remove values matching filter. /// /// Empties the collection when `filter` is `None`. @@ -244,6 +255,18 @@ impl Value { } } +impl TryInto for &Value { + type Error = Error; + + fn try_into(self) -> Result { + if let Value::Number(n) = self { + Ok(n.clone() as usize) + } else { + Err(Error::ValueNotNumber(self.to_owned())) + } + } +} + impl From for Value { fn from(value: usize) -> Self { Self::Number(value as f32) diff --git a/crates/tinyvec-embed/src/error.rs b/crates/tinyvec-embed/src/error.rs index 6a31c86..8dc0f54 100644 --- a/crates/tinyvec-embed/src/error.rs +++ b/crates/tinyvec-embed/src/error.rs @@ -32,6 +32,8 @@ pub enum Error { InvalidFileName, #[error("io error: {0}")] Io(#[from] std::io::Error), + #[error("expected value to be a valid number, got: {0}")] + ValueNotNumber(Value), #[error("expected value to be string, got: {0}")] ValueNotString(Value), }