Added batch embedding computing #86

Merged

Changes from 18 commits
2 changes: 2 additions & 0 deletions crates/llm-ls/src/error.rs
@@ -79,6 +79,8 @@ pub enum Error {
UnknownBackend(String),
#[error("yaml serialization error: {0}")]
Yaml(#[from] serde_yaml::Error),
#[error("No embedding built")]
MissingEmbedding,
}

pub(crate) type Result<T> = std::result::Result<T, Error>;
160 changes: 117 additions & 43 deletions crates/llm-ls/src/retrieval.rs
@@ -7,11 +7,14 @@ use candle_transformers::models::bert::{BertModel, Config, DTYPE};
use gitignore::Gitignore;
use hf_hub::{api::tokio::Api, Repo, RepoType};
use std::collections::{HashMap, VecDeque};
use std::iter::zip;
use std::path::Path;
use std::{path::PathBuf, sync::Arc};
use tinyvec_embed::db::{Collection, Compare, Db, Embedding, FilterBuilder, SimilarityResult};
use tinyvec_embed::similarity::Distance;
use tokenizers::{Encoding, Tokenizer, TruncationDirection};
use tokenizers::{
Encoding, PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer, TruncationDirection,
};
use tokio::io::AsyncReadExt;
use tokio::task::spawn_blocking;
use tokio::time::Instant;
@@ -156,9 +159,16 @@ async fn build_model_and_tokenizer(
let config = tokio::fs::read_to_string(config_filename).await?;
let config: Config = serde_json::from_str(&config)?;
let mut tokenizer: Tokenizer = Tokenizer::from_file(tokenizer_filename)?;
tokenizer.with_padding(None);
tokenizer.with_padding(Some(PaddingParams {
strategy: PaddingStrategy::BatchLongest,
direction: PaddingDirection::Right,
pad_to_multiple_of: Some(8),
// TODO: use values provided in model config
pad_id: 0,
pad_type_id: 0,
pad_token: "<pad>".to_string(),
}));
tokenizer.with_truncation(None)?;

let vb = VarBuilder::from_pth(&weights_filename, DTYPE, &device)?;
let model = BertModel::load(vb, &config)?;
debug!(
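The padding change above is what makes batching work: with `PaddingStrategy::BatchLongest`, every encoding produced by `encode_batch` is padded to the longest member of the batch, so the id sequences can later be stacked into a single tensor. A minimal sketch of that behavior (assuming a local `tokenizer.json`; the pad values are placeholders, as the TODO in the diff notes):

```rust
use tokenizers::{PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mut tokenizer = Tokenizer::from_file("tokenizer.json")?;
    tokenizer.with_padding(Some(PaddingParams {
        strategy: PaddingStrategy::BatchLongest,
        direction: PaddingDirection::Right,
        pad_to_multiple_of: Some(8), // round padded lengths up to a multiple of 8
        pad_id: 0,
        pad_type_id: 0,
        pad_token: "<pad>".to_string(),
    }));
    let encodings = tokenizer.encode_batch(vec!["short", "a noticeably longer input"], true)?;
    // Every encoding in the batch now shares one padded length.
    assert!(encodings
        .windows(2)
        .all(|pair| pair[0].get_ids().len() == pair[1].get_ids().len()));
    Ok(())
}
```

Padding to a multiple of 8 keeps sequence lengths aligned to sizes that GPU kernels tend to handle efficiently.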
@@ -191,6 +201,8 @@ fn device(cpu: bool) -> Result<Device> {
pub(crate) struct Snippet {
pub(crate) file_url: String,
pub(crate) code: String,
pub(crate) start_line: usize,
pub(crate) end_line: usize,
}

impl TryFrom<&SimilarityResult> for Snippet {
@@ -210,7 +222,20 @@ impl TryFrom<&SimilarityResult> for Snippet {
.get("snippet")
.ok_or_else(|| Error::MalformattedEmbeddingMetadata("snippet".to_owned()))?
.inner_string()?;
Ok(Snippet { file_url, code })
let start_line = meta
.get("start_line_no")
.ok_or_else(|| Error::MalformattedEmbeddingMetadata("start_line_no".to_owned()))?
.try_into()?;
let end_line = meta
.get("end_line_no")
.ok_or_else(|| Error::MalformattedEmbeddingMetadata("end_line_no".to_owned()))?
.try_into()?;
Ok(Snippet {
file_url,
code,
start_line,
end_line,
})
}
}

@@ -280,6 +305,7 @@ impl SnippetRetriever {
workspace_root: &str,
) -> Result<()> {
debug!("building workspace snippets");
let start = Instant::now();
let workspace_root = PathBuf::from(workspace_root);
if self.db.is_none() {
self.initialise_database(&format!(
Expand Down Expand Up @@ -360,7 +386,10 @@ impl SnippetRetriever {
})
.await;
}

debug!(
"Built workspace snippets in {} ms",
start.elapsed().as_millis()
);
Ok(())
}

@@ -384,15 +413,16 @@
snippet: String,
strategy: BuildFrom,
) -> Result<Vec<f32>> {
match strategy {
let result = match strategy {
BuildFrom::Start => {
let mut encoding = self.tokenizer.encode(snippet.clone(), true)?;
encoding.truncate(
self.model_config.max_input_size,
1,
TruncationDirection::Right,
);
self.generate_embedding(encoding, self.model.clone()).await
self.generate_embeddings(vec![encoding], self.model.clone())
.await?
}
BuildFrom::Cursor { cursor_position } => {
let (before, after) = snippet.split_at(cursor_position);
@@ -404,8 +434,8 @@
before_encoding.take_overflowing();
after_encoding.take_overflowing();
before_encoding.merge_with(after_encoding, false);
self.generate_embedding(before_encoding, self.model.clone())
.await
self.generate_embeddings(vec![before_encoding], self.model.clone())
.await?
}
BuildFrom::End => {
let mut encoding = self.tokenizer.encode(snippet.clone(), true)?;
@@ -414,9 +444,15 @@
1,
TruncationDirection::Left,
);
self.generate_embedding(encoding, self.model.clone()).await
self.generate_embeddings(vec![encoding], self.model.clone())
.await?
}
};
result.into_iter().next().ok_or(Error::MissingEmbedding)
}

pub(crate) async fn search(
@@ -477,21 +513,24 @@ impl SnippetRetriever {

impl SnippetRetriever {
// TODO: handle overflowing in Encoding
async fn generate_embedding(
/// Embeddings are returned in the same order as the input encodings.
async fn generate_embeddings(
&self,
encoding: Encoding,
encodings: Vec<Encoding>,
model: Arc<BertModel>,
) -> Result<Vec<f32>> {
) -> Result<Vec<Vec<f32>>> {
let start = Instant::now();
let embedding = spawn_blocking(move || -> Result<Vec<f32>> {
let tokens = encoding.get_ids().to_vec();
let token_ids = Tensor::new(&tokens[..], &model.device)?.unsqueeze(0)?;
let embedding = spawn_blocking(move || -> Result<Vec<Vec<f32>>> {
let tokens = encodings
.iter()
.map(|elem| Ok(Tensor::new(elem.get_ids().to_vec(), &model.device)?))
.collect::<Result<Vec<_>>>()?;
let token_ids = Tensor::stack(&tokens, 0)?;
let token_type_ids = token_ids.zeros_like()?;
let embedding = model.forward(&token_ids, &token_type_ids)?;
let (_n_sentence, n_tokens, _hidden_size) = embedding.dims3()?;
let embedding = (embedding.sum(1)? / (n_tokens as f64))?;
let embedding = embedding.get(0)?.to_vec1::<f32>()?;
Ok(embedding)
Ok(embedding.to_vec2::<f32>()?)
})
.await?;
debug!("embedding generated in {} ms", start.elapsed().as_millis());
@@ -512,6 +551,8 @@ impl SnippetRetriever {
let file = tokio::fs::read_to_string(&file_url).await?;
let lines = file.split('\n').collect::<Vec<_>>();
let end = end.unwrap_or(lines.len()).min(lines.len());
let mut snippets: Vec<Snippet> = Vec::new();
debug!("Building embeddings for {file_url}");
for start_line in (start..end).step_by(self.window_step) {
let end_line = (start_line + self.window_size - 1).min(lines.len());
if !col
@@ -538,35 +579,68 @@
let window = lines[start_line..end_line].to_vec();
let snippet = window.join("\n");
if snippet.is_empty() {
debug!("snippet {file_url}[{start_line}, {end_line}] empty");
continue;
}
snippets.push(Snippet {
file_url: file_url.clone().into(),
code: snippet,
start_line,
end_line,
});
}
{
let nb_snippets = snippets.len();
let steps = self.window_step;
debug!("Build {nb_snippets} snippets for {file_url}: {start}, {end}, {steps}");
}

let mut encoding = self.tokenizer.encode(snippet.clone(), true)?;
encoding.truncate(
self.model_config.max_input_size,
1,
TruncationDirection::Right,
);
let result = self.generate_embedding(encoding, self.model.clone()).await;
let embedding = match result {
Ok(e) => e,
Err(err) => {
error!(
"error generating embedding for {file_url}[{start_line}, {end_line}]: {err}",
);
continue;
}
};
col.write().await.insert(Embedding::new(
embedding,
Some(HashMap::from([
("file_url".to_owned(), file_url.clone().into()),
("start_line_no".to_owned(), start_line.into()),
("end_line_no".to_owned(), end_line.into()),
("snippet".to_owned(), snippet.clone().into()),
])),
))?;
// Group by length to reduce padding effect
let snippets = spawn_blocking(move || -> Result<Vec<Snippet>> {
snippets.sort_unstable_by_key(|snippet| snippet.code.len());
Ok(snippets)
})
.await?;

// TODO: compute a more efficient batch size, e.g. relative to the
// cumulative token count of all the elements in the batch
let embedding_batch_size = match self.model.device {
Device::Cpu => 2,
_ => 8,
};
for batch in snippets?.chunks(embedding_batch_size) {
let batch_code = batch.iter().map(|snippet| snippet.code.clone()).collect();
let encodings = self
.tokenizer
.encode_batch(batch_code, true)?
.iter_mut()
.map(|encoding| {
encoding.truncate(self.model_config.max_input_size, 1, TruncationDirection::Right);
encoding.clone()
})
.collect();
let results = self
.generate_embeddings(encodings, self.model.clone())
.await?;
col.write().await.batch_insert(
zip(results, batch)
.map(|(embedding, snippet)| {
Embedding::new(
embedding,
Some(HashMap::from([
("file_url".to_owned(), snippet.file_url.clone().into()),
("start_line_no".to_owned(), snippet.start_line.into()),
("end_line_no".to_owned(), snippet.end_line.into()),
("snippet".to_owned(), snippet.code.clone().into()),
])),
)
})
.collect::<Vec<Embedding>>(),
)?;
}
db.save().await?;
Ok(())
}
}
23 changes: 23 additions & 0 deletions crates/tinyvec-embed/src/db.rs
@@ -173,6 +173,17 @@ impl Collection {
Ok(())
}

pub fn batch_insert(&mut self, embeddings: Vec<Embedding>) -> Result<()> {
if embeddings
.iter()
.any(|embedding| embedding.vector.len() != self.dimension)
{
return Err(CollectionError::DimensionMismatch.into());
}
self.embeddings.extend(embeddings);
Ok(())
}
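A short usage sketch for `batch_insert` (collection setup elided; a 3-dimensional collection is assumed here). Checking every vector before the `extend` makes the insert all-or-nothing: one bad vector rejects the whole batch rather than leaving the collection partially updated.

```rust
let embeddings = vec![
    Embedding::new(vec![0.1, 0.2, 0.3], None),
    Embedding::new(vec![0.4, 0.5, 0.6], None),
];
// Errors with CollectionError::DimensionMismatch if any vector's length
// differs from the collection's dimension; otherwise inserts all of them.
collection.batch_insert(embeddings)?;
```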

/// Remove values matching filter.
///
/// Empties the collection when `filter` is `None`.
@@ -244,6 +255,18 @@ impl Value {
}
}

impl TryInto<usize> for &Value {
type Error = Error;

fn try_into(self) -> Result<usize> {
if let Value::Number(n) = self {
Ok(*n as usize)
} else {
Err(Error::ValueNotNumber(self.to_owned()))
}
}
}
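A round trip through the new conversion (sketch). Since `Value::Number` stores an `f32`, a `usize` survives the trip exactly only while it fits in `f32`'s contiguous integer range (up to 2^24):

```rust
let value: Value = 42usize.into();
let n: usize = (&value).try_into()?;
assert_eq!(n, 42);
```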

impl From<usize> for Value {
fn from(value: usize) -> Self {
Self::Number(value as f32)
4 changes: 4 additions & 0 deletions crates/tinyvec-embed/src/error.rs
@@ -32,8 +32,12 @@ pub enum Error {
InvalidFileName,
#[error("io error: {0}")]
Io(#[from] std::io::Error),
#[error("expected value to be a valid number, got: {0}")]
ValueNotNumber(Value),
#[error("expected value to be string, got: {0}")]
ValueNotString(Value),
#[error("expected value to be a valid size, got: {0}")]
ValueNotUsize(Value),
}

pub type Result<T> = std::result::Result<T, Error>;