Skip to content

Commit

Permalink
feat: add llama.cpp backend
Browse files Browse the repository at this point in the history
  • Loading branch information
McPatate committed May 21, 2024
1 parent 078d4c7 commit 9f79420
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 17 deletions.
7 changes: 3 additions & 4 deletions crates/custom-types/src/llm_ls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,9 @@ pub enum Backend {
#[serde(default = "hf_default_url", deserialize_with = "parse_url")]
url: String,
},
// TODO:
// LlamaCpp {
// url: String,
// },
LlamaCpp {
url: String,
},
Ollama {
url: String,
},
Expand Down
36 changes: 36 additions & 0 deletions crates/llm-ls/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,37 @@ fn parse_api_text(text: &str) -> Result<Vec<Generation>> {
}
}

#[derive(Debug, Serialize, Deserialize)]
struct LlamaCppGeneration {
content: String,
}

impl From<LlamaCppGeneration> for Generation {
fn from(value: LlamaCppGeneration) -> Self {
Generation {
generated_text: value.content,
}
}
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum LlamaCppAPIResponse {
Generation(LlamaCppGeneration),
Error(APIError),
}

fn build_llamacpp_headers() -> HeaderMap {
HeaderMap::new()
}

fn parse_llamacpp_text(text: &str) -> Result<Vec<Generation>> {
match serde_json::from_str(text)? {
LlamaCppAPIResponse::Generation(gen) => Ok(vec![gen.into()]),
LlamaCppAPIResponse::Error(err) => Err(Error::LlamaCpp(err)),
}
}

#[derive(Debug, Serialize, Deserialize)]
struct OllamaGeneration {
response: String,
Expand Down Expand Up @@ -192,6 +223,9 @@ pub(crate) fn build_body(
request_body.insert("parameters".to_owned(), params);
}
}
Backend::LlamaCpp { .. } => {
request_body.insert("prompt".to_owned(), Value::String(prompt));
}
Backend::Ollama { .. } | Backend::OpenAi { .. } => {
request_body.insert("prompt".to_owned(), Value::String(prompt));
request_body.insert("model".to_owned(), Value::String(model));
Expand All @@ -208,6 +242,7 @@ pub(crate) fn build_headers(
) -> Result<HeaderMap> {
match backend {
Backend::HuggingFace { .. } => build_api_headers(api_token, ide),
Backend::LlamaCpp { .. } => Ok(build_llamacpp_headers()),
Backend::Ollama { .. } => Ok(build_ollama_headers()),
Backend::OpenAi { .. } => build_openai_headers(api_token, ide),
Backend::Tgi { .. } => build_tgi_headers(api_token, ide),
Expand All @@ -217,6 +252,7 @@ pub(crate) fn build_headers(
pub(crate) fn parse_generations(backend: &Backend, text: &str) -> Result<Vec<Generation>> {
match backend {
Backend::HuggingFace { .. } => parse_api_text(text),
Backend::LlamaCpp { .. } => parse_llamacpp_text(text),
Backend::Ollama { .. } => parse_ollama_text(text),
Backend::OpenAi { .. } => parse_openai_text(text),
Backend::Tgi { .. } => parse_tgi_text(text),
Expand Down
34 changes: 24 additions & 10 deletions crates/llm-ls/src/document.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ impl TryFrom<Vec<tower_lsp::lsp_types::PositionEncodingKind>> for PositionEncodi
}

impl PositionEncodingKind {
pub fn to_lsp_type(&self) -> tower_lsp::lsp_types::PositionEncodingKind {
pub fn to_lsp_type(self) -> tower_lsp::lsp_types::PositionEncodingKind {
match self {
PositionEncodingKind::Utf8 => tower_lsp::lsp_types::PositionEncodingKind::UTF8,
PositionEncodingKind::Utf16 => tower_lsp::lsp_types::PositionEncodingKind::UTF16,
Expand Down Expand Up @@ -205,9 +205,10 @@ impl Document {
) -> Result<()> {
match change.range {
Some(range) => {
if range.start.line < range.end.line
if range.start.line > range.end.line
|| (range.start.line == range.end.line
&& range.start.character <= range.end.character) {
&& range.start.character > range.end.character)
{
return Err(Error::InvalidRange(range));
}

Expand All @@ -219,7 +220,10 @@ impl Document {

// 1. Get the line at which the change starts.
let change_start_line_idx = range.start.line as usize;
let change_start_line = self.text.get_line(change_start_line_idx).ok_or_else(|| Error::OutOfBoundLine(change_start_line_idx, self.text.len_lines()))?;
let change_start_line =
self.text.get_line(change_start_line_idx).ok_or_else(|| {
Error::OutOfBoundLine(change_start_line_idx, self.text.len_lines())
})?;

// 2. Get the line at which the change ends. (Small optimization
// where we first check whether start and end line are the
Expand All @@ -228,7 +232,9 @@ impl Document {
let change_end_line_idx = range.end.line as usize;
let change_end_line = match same_line {
true => change_start_line,
false => self.text.get_line(change_end_line_idx).ok_or_else(|| Error::OutOfBoundLine(change_end_line_idx, self.text.len_lines()))?,
false => self.text.get_line(change_end_line_idx).ok_or_else(|| {
Error::OutOfBoundLine(change_end_line_idx, self.text.len_lines())
})?,
};

fn compute_char_idx(
Expand Down Expand Up @@ -330,7 +336,7 @@ impl Document {
self.tree = Some(new_tree);
}
None => {
return Err(Error::TreeSitterParseError);
return Err(Error::TreeSitterParsing);
}
}
}
Expand Down Expand Up @@ -416,7 +422,9 @@ mod test {
let mut rope = Rope::from_str(
"let a = '🥸 你好';\rfunction helloWorld() { return '🤲🏿'; }\nlet b = 'Hi, 😊';",
);
let mut doc = Document::open(&LanguageId::JavaScript.to_string(), &rope.to_string()).await.unwrap();
let mut doc = Document::open(&LanguageId::JavaScript.to_string(), &rope.to_string())
.await
.unwrap();
let mut parser = Parser::new();

parser
Expand Down Expand Up @@ -464,7 +472,9 @@ mod test {
#[tokio::test]
async fn test_text_document_apply_content_change_bounds() {
let rope = Rope::from_str("");
let mut doc = Document::open(&LanguageId::Unknown.to_string(), &rope.to_string()).await.unwrap();
let mut doc = Document::open(&LanguageId::Unknown.to_string(), &rope.to_string())
.await
.unwrap();

assert!(doc
.apply_content_change(new_change!(0, 0, 0, 1, ""), PositionEncodingKind::Utf16)
Expand Down Expand Up @@ -513,7 +523,9 @@ mod test {
async fn test_document_update_tree_consistency_easy() {
let a = "let a = '你好';\rlet b = 'Hi, 😊';";

let mut document = Document::open(&LanguageId::JavaScript.to_string(), a).await.unwrap();
let mut document = Document::open(&LanguageId::JavaScript.to_string(), a)
.await
.unwrap();

document
.apply_content_change(new_change!(0, 9, 0, 11, "𐐀"), PositionEncodingKind::Utf16)
Expand Down Expand Up @@ -541,7 +553,9 @@ mod test {
async fn test_document_update_tree_consistency_medium() {
let a = "let a = '🥸 你好';\rfunction helloWorld() { return '🤲🏿'; }\nlet b = 'Hi, 😊';";

let mut document = Document::open(&LanguageId::JavaScript.to_string(), a).await.unwrap();
let mut document = Document::open(&LanguageId::JavaScript.to_string(), a)
.await
.unwrap();

document
.apply_content_change(new_change!(0, 14, 2, 13, ","), PositionEncodingKind::Utf16)
Expand Down
6 changes: 4 additions & 2 deletions crates/llm-ls/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ pub enum Error {
InvalidRepositoryId,
#[error("invalid tokenizer path")]
InvalidTokenizerPath,
#[error("llama.cpp error: {0}")]
LlamaCpp(crate::backend::APIError),
#[error("ollama error: {0}")]
Ollama(crate::backend::APIError),
#[error("openai error: {0}")]
Expand All @@ -50,7 +52,7 @@ pub enum Error {
#[error("tgi error: {0}")]
Tgi(crate::backend::APIError),
#[error("tree-sitter parse error: timeout possibly exceeded")]
TreeSitterParseError,
TreeSitterParsing,
#[error("tree-sitter language error: {0}")]
TreeSitterLanguage(#[from] tree_sitter::LanguageError),
#[error("tokenizer error: {0}")]
Expand All @@ -60,7 +62,7 @@ pub enum Error {
#[error("unknown backend: {0}")]
UnknownBackend(String),
#[error("unknown encoding kind: {0}")]
UnknownEncodingKind(String)
UnknownEncodingKind(String),
}

pub(crate) type Result<T> = std::result::Result<T, Error>;
Expand Down
14 changes: 13 additions & 1 deletion crates/llm-ls/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,17 @@ async fn get_tokenizer(
fn build_url(backend: Backend, model: &str) -> String {
match backend {
Backend::HuggingFace { url } => format!("{url}/models/{model}"),
Backend::LlamaCpp { mut url } => {
if url.ends_with("/completions") {
url
} else if url.ends_with('/') {
url.push_str("completions");
url
} else {
url.push_str("/completions");
url
}
}
Backend::Ollama { url } => url,
Backend::OpenAi { url } => url,
Backend::Tgi { url } => url,
Expand Down Expand Up @@ -540,7 +551,8 @@ impl LanguageServer for LlmService {
general_capabilities
.position_encodings
.map(TryFrom::try_from)
}).unwrap_or(Ok(document::PositionEncodingKind::Utf16))?;
})
.unwrap_or(Ok(document::PositionEncodingKind::Utf16))?;

*self.position_encoding.write().await = position_encoding;

Expand Down

0 comments on commit 9f79420

Please sign in to comment.