diff --git a/crates/custom-types/src/llm_ls.rs b/crates/custom-types/src/llm_ls.rs index 9b402f9..cd66d5c 100644 --- a/crates/custom-types/src/llm_ls.rs +++ b/crates/custom-types/src/llm_ls.rs @@ -67,10 +67,9 @@ pub enum Backend { #[serde(default = "hf_default_url", deserialize_with = "parse_url")] url: String, }, - // TODO: - // LlamaCpp { - // url: String, - // }, + LlamaCpp { + url: String, + }, Ollama { url: String, }, diff --git a/crates/llm-ls/src/backend.rs b/crates/llm-ls/src/backend.rs index dba1fab..90324c5 100644 --- a/crates/llm-ls/src/backend.rs +++ b/crates/llm-ls/src/backend.rs @@ -67,6 +67,37 @@ fn parse_api_text(text: &str) -> Result> { } } +#[derive(Debug, Serialize, Deserialize)] +struct LlamaCppGeneration { + content: String, +} + +impl From for Generation { + fn from(value: LlamaCppGeneration) -> Self { + Generation { + generated_text: value.content, + } + } +} + +#[derive(Debug, Deserialize)] +#[serde(untagged)] +enum LlamaCppAPIResponse { + Generation(LlamaCppGeneration), + Error(APIError), +} + +fn build_llamacpp_headers() -> HeaderMap { + HeaderMap::new() +} + +fn parse_llamacpp_text(text: &str) -> Result> { + match serde_json::from_str(text)? { + LlamaCppAPIResponse::Generation(gen) => Ok(vec![gen.into()]), + LlamaCppAPIResponse::Error(err) => Err(Error::LlamaCpp(err)), + } +} + #[derive(Debug, Serialize, Deserialize)] struct OllamaGeneration { response: String, @@ -192,6 +223,9 @@ pub(crate) fn build_body( request_body.insert("parameters".to_owned(), params); } } + Backend::LlamaCpp { .. } => { + request_body.insert("prompt".to_owned(), Value::String(prompt)); + } Backend::Ollama { .. } | Backend::OpenAi { .. } => { request_body.insert("prompt".to_owned(), Value::String(prompt)); request_body.insert("model".to_owned(), Value::String(model)); @@ -208,6 +242,7 @@ pub(crate) fn build_headers( ) -> Result { match backend { Backend::HuggingFace { .. } => build_api_headers(api_token, ide), + Backend::LlamaCpp { .. } => Ok(build_llamacpp_headers()), Backend::Ollama { .. } => Ok(build_ollama_headers()), Backend::OpenAi { .. } => build_openai_headers(api_token, ide), Backend::Tgi { .. } => build_tgi_headers(api_token, ide), @@ -217,6 +252,7 @@ pub(crate) fn build_headers( pub(crate) fn parse_generations(backend: &Backend, text: &str) -> Result> { match backend { Backend::HuggingFace { .. } => parse_api_text(text), + Backend::LlamaCpp { .. } => parse_llamacpp_text(text), Backend::Ollama { .. } => parse_ollama_text(text), Backend::OpenAi { .. } => parse_openai_text(text), Backend::Tgi { .. } => parse_tgi_text(text), diff --git a/crates/llm-ls/src/document.rs b/crates/llm-ls/src/document.rs index cd5029d..83c1072 100644 --- a/crates/llm-ls/src/document.rs +++ b/crates/llm-ls/src/document.rs @@ -168,7 +168,7 @@ impl TryFrom> for PositionEncodi } impl PositionEncodingKind { - pub fn to_lsp_type(&self) -> tower_lsp::lsp_types::PositionEncodingKind { + pub fn to_lsp_type(self) -> tower_lsp::lsp_types::PositionEncodingKind { match self { PositionEncodingKind::Utf8 => tower_lsp::lsp_types::PositionEncodingKind::UTF8, PositionEncodingKind::Utf16 => tower_lsp::lsp_types::PositionEncodingKind::UTF16, @@ -205,9 +205,10 @@ impl Document { ) -> Result<()> { match change.range { Some(range) => { - if range.start.line < range.end.line + if range.start.line > range.end.line || (range.start.line == range.end.line - && range.start.character <= range.end.character) { + && range.start.character > range.end.character) + { return Err(Error::InvalidRange(range)); } @@ -219,7 +220,10 @@ impl Document { // 1. Get the line at which the change starts. let change_start_line_idx = range.start.line as usize; - let change_start_line = self.text.get_line(change_start_line_idx).ok_or_else(|| Error::OutOfBoundLine(change_start_line_idx, self.text.len_lines()))?; + let change_start_line = + self.text.get_line(change_start_line_idx).ok_or_else(|| { + Error::OutOfBoundLine(change_start_line_idx, self.text.len_lines()) + })?; // 2. Get the line at which the change ends. (Small optimization // where we first check whether start and end line are the @@ -228,7 +232,9 @@ impl Document { let change_end_line_idx = range.end.line as usize; let change_end_line = match same_line { true => change_start_line, - false => self.text.get_line(change_end_line_idx).ok_or_else(|| Error::OutOfBoundLine(change_end_line_idx, self.text.len_lines()))?, + false => self.text.get_line(change_end_line_idx).ok_or_else(|| { + Error::OutOfBoundLine(change_end_line_idx, self.text.len_lines()) + })?, }; fn compute_char_idx( @@ -330,7 +336,7 @@ impl Document { self.tree = Some(new_tree); } None => { - return Err(Error::TreeSitterParseError); + return Err(Error::TreeSitterParsing); } } } @@ -416,7 +422,9 @@ mod test { let mut rope = Rope::from_str( "let a = '🥸 你好';\rfunction helloWorld() { return '🤲🏿'; }\nlet b = 'Hi, 😊';", ); - let mut doc = Document::open(&LanguageId::JavaScript.to_string(), &rope.to_string()).await.unwrap(); + let mut doc = Document::open(&LanguageId::JavaScript.to_string(), &rope.to_string()) + .await + .unwrap(); let mut parser = Parser::new(); parser @@ -464,7 +472,9 @@ mod test { #[tokio::test] async fn test_text_document_apply_content_change_bounds() { let rope = Rope::from_str(""); - let mut doc = Document::open(&LanguageId::Unknown.to_string(), &rope.to_string()).await.unwrap(); + let mut doc = Document::open(&LanguageId::Unknown.to_string(), &rope.to_string()) + .await + .unwrap(); assert!(doc .apply_content_change(new_change!(0, 0, 0, 1, ""), PositionEncodingKind::Utf16) @@ -513,7 +523,9 @@ mod test { async fn test_document_update_tree_consistency_easy() { let a = "let a = '你好';\rlet b = 'Hi, 😊';"; - let mut document = Document::open(&LanguageId::JavaScript.to_string(), a).await.unwrap(); + let mut document = Document::open(&LanguageId::JavaScript.to_string(), a) + .await + .unwrap(); document .apply_content_change(new_change!(0, 9, 0, 11, "𐐀"), PositionEncodingKind::Utf16) @@ -541,7 +553,9 @@ mod test { async fn test_document_update_tree_consistency_medium() { let a = "let a = '🥸 你好';\rfunction helloWorld() { return '🤲🏿'; }\nlet b = 'Hi, 😊';"; - let mut document = Document::open(&LanguageId::JavaScript.to_string(), a).await.unwrap(); + let mut document = Document::open(&LanguageId::JavaScript.to_string(), a) + .await + .unwrap(); document .apply_content_change(new_change!(0, 14, 2, 13, ","), PositionEncodingKind::Utf16) diff --git a/crates/llm-ls/src/error.rs b/crates/llm-ls/src/error.rs index 8ab92b9..8b84474 100644 --- a/crates/llm-ls/src/error.rs +++ b/crates/llm-ls/src/error.rs @@ -33,6 +33,8 @@ pub enum Error { InvalidRepositoryId, #[error("invalid tokenizer path")] InvalidTokenizerPath, + #[error("llama.cpp error: {0}")] + LlamaCpp(crate::backend::APIError), #[error("ollama error: {0}")] Ollama(crate::backend::APIError), #[error("openai error: {0}")] @@ -50,7 +52,7 @@ pub enum Error { #[error("tgi error: {0}")] Tgi(crate::backend::APIError), #[error("tree-sitter parse error: timeout possibly exceeded")] - TreeSitterParseError, + TreeSitterParsing, #[error("tree-sitter language error: {0}")] TreeSitterLanguage(#[from] tree_sitter::LanguageError), #[error("tokenizer error: {0}")] @@ -60,7 +62,7 @@ pub enum Error { #[error("unknown backend: {0}")] UnknownBackend(String), #[error("unknown encoding kind: {0}")] - UnknownEncodingKind(String) + UnknownEncodingKind(String), } pub(crate) type Result = std::result::Result; diff --git a/crates/llm-ls/src/main.rs b/crates/llm-ls/src/main.rs index 87bed4c..ff74be3 100644 --- a/crates/llm-ls/src/main.rs +++ b/crates/llm-ls/src/main.rs @@ -417,6 +417,17 @@ async fn get_tokenizer( fn build_url(backend: Backend, model: &str) -> String { match backend { Backend::HuggingFace { url } => format!("{url}/models/{model}"), + Backend::LlamaCpp { mut url } => { + if url.ends_with("/completions") { + url + } else if url.ends_with('/') { + url.push_str("completions"); + url + } else { + url.push_str("/completions"); + url + } + } Backend::Ollama { url } => url, Backend::OpenAi { url } => url, Backend::Tgi { url } => url, @@ -540,7 +551,8 @@ impl LanguageServer for LlmService { general_capabilities .position_encodings .map(TryFrom::try_from) - }).unwrap_or(Ok(document::PositionEncodingKind::Utf16))?; + }) + .unwrap_or(Ok(document::PositionEncodingKind::Utf16))?; *self.position_encoding.write().await = position_encoding;