Skip to content

Commit

Permalink
Added batch embedding computing (#86)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Quentin Maire <quentin.maire@corp.ovh.com>
Co-authored-by: Luc Georges <McPatate@users.noreply.github.com>
  • Loading branch information
3 people authored Mar 6, 2024
1 parent 6e3d6c0 commit 64a4c38
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 43 deletions.
2 changes: 2 additions & 0 deletions crates/llm-ls/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ pub enum Error {
UnknownBackend(String),
#[error("yaml serialization error: {0}")]
Yaml(#[from] serde_yaml::Error),
#[error("No embedding built")]
MissingEmbedding,
}

/// Crate-local result alias that fixes the error type to this crate's [`Error`].
pub(crate) type Result<T> = std::result::Result<T, Error>;
Expand Down
160 changes: 117 additions & 43 deletions crates/llm-ls/src/retrieval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@ use candle_transformers::models::bert::{BertModel, Config, DTYPE};
use gitignore::Gitignore;
use hf_hub::{api::tokio::Api, Repo, RepoType};
use std::collections::{HashMap, VecDeque};
use std::iter::zip;
use std::path::Path;
use std::{path::PathBuf, sync::Arc};
use tinyvec_embed::db::{Collection, Compare, Db, Embedding, FilterBuilder, SimilarityResult};
use tinyvec_embed::similarity::Distance;
use tokenizers::{Encoding, Tokenizer, TruncationDirection};
use tokenizers::{
Encoding, PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer, TruncationDirection,
};
use tokio::io::AsyncReadExt;
use tokio::task::spawn_blocking;
use tokio::time::Instant;
Expand Down Expand Up @@ -156,9 +159,16 @@ async fn build_model_and_tokenizer(
let config = tokio::fs::read_to_string(config_filename).await?;
let config: Config = serde_json::from_str(&config)?;
let mut tokenizer: Tokenizer = Tokenizer::from_file(tokenizer_filename)?;
tokenizer.with_padding(None);
tokenizer.with_padding(Some(PaddingParams {
strategy: PaddingStrategy::BatchLongest,
direction: PaddingDirection::Right,
pad_to_multiple_of: Some(8),
// TODO: use values provided in model config
pad_id: 0,
pad_type_id: 0,
pad_token: "<pad>".to_string(),
}));
tokenizer.with_truncation(None)?;

let vb = VarBuilder::from_pth(&weights_filename, DTYPE, &device)?;
let model = BertModel::load(vb, &config)?;
debug!(
Expand Down Expand Up @@ -191,6 +201,8 @@ fn device(cpu: bool) -> Result<Device> {
/// A window of consecutive lines cut from a workspace file — the unit of
/// embedding computation and similarity-search results.
pub(crate) struct Snippet {
    /// URL of the file the snippet was extracted from.
    pub(crate) file_url: String,
    /// The snippet text: the selected lines joined with '\n'.
    pub(crate) code: String,
    /// Index of the first line of the window within the source file.
    pub(crate) start_line: usize,
    /// End of the window; used as an exclusive slice bound when the snippet
    /// is built (`lines[start_line..end_line]`).
    pub(crate) end_line: usize,
}

impl TryFrom<&SimilarityResult> for Snippet {
Expand All @@ -210,7 +222,20 @@ impl TryFrom<&SimilarityResult> for Snippet {
.get("snippet")
.ok_or_else(|| Error::MalformattedEmbeddingMetadata("snippet".to_owned()))?
.inner_string()?;
Ok(Snippet { file_url, code })
let start_line = meta
.get("start_line_no")
.ok_or_else(|| Error::MalformattedEmbeddingMetadata("snippet".to_owned()))?
.try_into()?;
let end_line = meta
.get("start_line_no")
.ok_or_else(|| Error::MalformattedEmbeddingMetadata("snippet".to_owned()))?
.try_into()?;
Ok(Snippet {
file_url,
code,
start_line,
end_line,
})
}
}

Expand Down Expand Up @@ -280,6 +305,7 @@ impl SnippetRetriever {
workspace_root: &str,
) -> Result<()> {
debug!("building workspace snippets");
let start = Instant::now();
let workspace_root = PathBuf::from(workspace_root);
if self.db.is_none() {
self.initialise_database(&format!(
Expand Down Expand Up @@ -360,7 +386,10 @@ impl SnippetRetriever {
})
.await;
}

debug!(
"Built workspace snippets in {} ms",
start.elapsed().as_millis()
);
Ok(())
}

Expand All @@ -384,15 +413,16 @@ impl SnippetRetriever {
snippet: String,
strategy: BuildFrom,
) -> Result<Vec<f32>> {
match strategy {
let result = match strategy {
BuildFrom::Start => {
let mut encoding = self.tokenizer.encode(snippet.clone(), true)?;
encoding.truncate(
self.model_config.max_input_size,
1,
TruncationDirection::Right,
);
self.generate_embedding(encoding, self.model.clone()).await
self.generate_embeddings(vec![encoding], self.model.clone())
.await?
}
BuildFrom::Cursor { cursor_position } => {
let (before, after) = snippet.split_at(cursor_position);
Expand All @@ -404,8 +434,8 @@ impl SnippetRetriever {
before_encoding.take_overflowing();
after_encoding.take_overflowing();
before_encoding.merge_with(after_encoding, false);
self.generate_embedding(before_encoding, self.model.clone())
.await
self.generate_embeddings(vec![before_encoding], self.model.clone())
.await?
}
BuildFrom::End => {
let mut encoding = self.tokenizer.encode(snippet.clone(), true)?;
Expand All @@ -414,9 +444,15 @@ impl SnippetRetriever {
1,
TruncationDirection::Left,
);
self.generate_embedding(encoding, self.model.clone()).await
self.generate_embeddings(vec![encoding], self.model.clone())
.await?
}
};
if result.is_empty() {
return Err(Error::MissingEmbedding);
}
let mut result = result;
Ok(result.remove(0))
}

pub(crate) async fn search(
Expand Down Expand Up @@ -477,21 +513,24 @@ impl SnippetRetriever {

impl SnippetRetriever {
// TODO: handle overflowing in Encoding
async fn generate_embedding(
/// Embedding order is preserved and stays the same as encoding input
async fn generate_embeddings(
&self,
encoding: Encoding,
encodings: Vec<Encoding>,
model: Arc<BertModel>,
) -> Result<Vec<f32>> {
) -> Result<Vec<Vec<f32>>> {
let start = Instant::now();
let embedding = spawn_blocking(move || -> Result<Vec<f32>> {
let tokens = encoding.get_ids().to_vec();
let token_ids = Tensor::new(&tokens[..], &model.device)?.unsqueeze(0)?;
let embedding = spawn_blocking(move || -> Result<Vec<Vec<f32>>> {
let tokens = encodings
.iter()
.map(|elem| Ok(Tensor::new(elem.get_ids().to_vec(), &model.device)?))
.collect::<Result<Vec<_>>>()?;
let token_ids = Tensor::stack(&tokens, 0)?;
let token_type_ids = token_ids.zeros_like()?;
let embedding = model.forward(&token_ids, &token_type_ids)?;
let (_n_sentence, n_tokens, _hidden_size) = embedding.dims3()?;
let embedding = (embedding.sum(1)? / (n_tokens as f64))?;
let embedding = embedding.get(0)?.to_vec1::<f32>()?;
Ok(embedding)
Ok(embedding.to_vec2::<f32>()?)
})
.await?;
debug!("embedding generated in {} ms", start.elapsed().as_millis());
Expand All @@ -512,6 +551,8 @@ impl SnippetRetriever {
let file = tokio::fs::read_to_string(&file_url).await?;
let lines = file.split('\n').collect::<Vec<_>>();
let end = end.unwrap_or(lines.len()).min(lines.len());
let mut snippets: Vec<Snippet> = Vec::new();
debug!("Building embeddings for {file_url}");
for start_line in (start..end).step_by(self.window_step) {
let end_line = (start_line + self.window_size - 1).min(lines.len());
if !col
Expand All @@ -538,35 +579,68 @@ impl SnippetRetriever {
let window = lines[start_line..end_line].to_vec();
let snippet = window.join("\n");
if snippet.is_empty() {
debug!("snippet {file_url}[{start_line}, {end_line}] empty");
continue;
}
snippets.push(Snippet {
file_url: file_url.clone().into(),
code: snippet,
start_line,
end_line,
});
}
{
let nb_snippets = snippets.len();
let steps = self.window_step;
debug!("Build {nb_snippets} snippets for {file_url}: {start}, {end}, {steps}");
}

let mut encoding = self.tokenizer.encode(snippet.clone(), true)?;
encoding.truncate(
self.model_config.max_input_size,
1,
TruncationDirection::Right,
);
let result = self.generate_embedding(encoding, self.model.clone()).await;
let embedding = match result {
Ok(e) => e,
Err(err) => {
error!(
"error generating embedding for {file_url}[{start_line}, {end_line}]: {err}",
);
continue;
}
};
col.write().await.insert(Embedding::new(
embedding,
Some(HashMap::from([
("file_url".to_owned(), file_url.clone().into()),
("start_line_no".to_owned(), start_line.into()),
("end_line_no".to_owned(), end_line.into()),
("snippet".to_owned(), snippet.clone().into()),
])),
))?;
// Group by length to reduce padding effect
let snippets = spawn_blocking(|| -> Result<Vec<Snippet>> {
snippets.sort_unstable_by(|first, second| first.code.len().cmp(&second.code.len()));
Ok(snippets)
})
.await?;

// TODO: improvements to compute an efficient batch size:
// - batch size should be relative to the cumulative size of all elements in the batch,
// Set embedding_batch_size to 8 if device is GPU, use match
let embedding_batch_size = match self.model.device {
Device::Cpu => 2,
_ => 8,
};
for batch in snippets?.chunks(embedding_batch_size) {
let batch_code = batch.iter().map(|snippet| snippet.code.clone()).collect();
let encodings = self
.tokenizer
.encode_batch(batch_code, true)?
.iter_mut()
.map(|encoding| {
encoding.truncate(512, 1, TruncationDirection::Right);
encoding.clone()
})
.collect();
let results = self
.generate_embeddings(encodings, self.model.clone())
.await?;
col.write().await.batch_insert(
zip(results, batch)
.map(|item| {
let (embedding, snippet) = item;
Embedding::new(
embedding,
Some(HashMap::from([
("file_url".to_owned(), snippet.file_url.clone().into()),
("start_line_no".to_owned(), snippet.start_line.into()),
("end_line_no".to_owned(), snippet.end_line.into()),
("snippet".to_owned(), snippet.code.clone().into()),
])),
)
})
.collect::<Vec<Embedding>>(),
)?;
}
db.save().await?;
Ok(())
}
}
Expand Down
23 changes: 23 additions & 0 deletions crates/tinyvec-embed/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,17 @@ impl Collection {
Ok(())
}

/// Appends a batch of embeddings to the collection in a single call.
///
/// Validates the whole batch up front: if any embedding's vector length
/// differs from the collection's dimension, returns
/// `CollectionError::DimensionMismatch` and inserts nothing.
pub fn batch_insert(&mut self, embeddings: Vec<Embedding>) -> Result<()> {
    let dimensions_ok = embeddings
        .iter()
        .all(|embedding| embedding.vector.len() == self.dimension);
    if !dimensions_ok {
        return Err(CollectionError::DimensionMismatch.into());
    }
    self.embeddings.extend(embeddings);
    Ok(())
}

/// Remove values matching filter.
///
/// Empties the collection when `filter` is `None`.
Expand Down Expand Up @@ -244,6 +255,18 @@ impl Value {
}
}

impl TryInto<usize> for &Value {
type Error = Error;

fn try_into(self) -> Result<usize> {
if let Value::Number(n) = self {
Ok(n.clone() as usize)
} else {
Err(Error::ValueNotNumber(self.to_owned()))
}
}
}

impl From<usize> for Value {
fn from(value: usize) -> Self {
Self::Number(value as f32)
Expand Down
2 changes: 2 additions & 0 deletions crates/tinyvec-embed/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ pub enum Error {
InvalidFileName,
#[error("io error: {0}")]
Io(#[from] std::io::Error),
#[error("expected value to be a valid number, got: {0}")]
ValueNotNumber(Value),
#[error("expected value to be string, got: {0}")]
ValueNotString(Value),
}
Expand Down

0 comments on commit 64a4c38

Please sign in to comment.