From a9831d57202386fe2e43d419e1f3f0c6561050db Mon Sep 17 00:00:00 2001 From: Luc Georges Date: Tue, 6 Feb 2024 21:26:53 +0100 Subject: [PATCH 01/10] refactor: adaptor -> backend (#70) --- crates/llm-ls/src/{adaptors.rs => backend.rs} | 117 ++++++------------ crates/llm-ls/src/main.rs | 93 ++++++-------- 2 files changed, 71 insertions(+), 139 deletions(-) rename crates/llm-ls/src/{adaptors.rs => backend.rs} (57%) diff --git a/crates/llm-ls/src/adaptors.rs b/crates/llm-ls/src/backend.rs similarity index 57% rename from crates/llm-ls/src/adaptors.rs rename to crates/llm-ls/src/backend.rs index 553fc87..a139870 100644 --- a/crates/llm-ls/src/adaptors.rs +++ b/crates/llm-ls/src/backend.rs @@ -1,26 +1,12 @@ use super::{ - internal_error, APIError, APIResponse, CompletionParams, Generation, Ide, RequestParams, NAME, - VERSION, + internal_error, APIError, APIResponse, CompletionParams, Generation, Ide, NAME, VERSION, }; use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT}; use serde::{Deserialize, Serialize}; -use serde_json::Value; +use serde_json::{Map, Value}; use std::fmt::Display; use tower_lsp::jsonrpc; -fn build_tgi_body(prompt: String, params: &RequestParams) -> Value { - serde_json::json!({ - "inputs": prompt, - "parameters": { - "max_new_tokens": params.max_new_tokens, - "temperature": params.temperature, - "do_sample": params.do_sample, - "top_p": params.top_p, - "stop_tokens": params.stop_tokens.clone() - }, - }) -} - fn build_tgi_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> { let mut headers = HeaderMap::new(); let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}"); @@ -45,7 +31,7 @@ fn parse_tgi_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> { APIResponse::Generation(gen) => vec![gen], APIResponse::Generations(_) => { return Err(internal_error( - "You are attempting to parse a result in the API inference format when using the `tgi` adaptor", + "You are attempting to parse a result in the API inference format when using the `tgi` backend", )) } APIResponse::Error(err) => return Err(internal_error(err)), }; Ok(generations) } -fn build_api_body(prompt: String, params: &RequestParams) -> Value { - build_tgi_body(prompt, params) -} - fn build_api_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> { build_tgi_headers(api_token, ide) } @@ -70,20 +52,6 @@ fn parse_api_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> { Ok(generations) } -fn build_ollama_body(prompt: String, params: &CompletionParams) -> Value { - serde_json::json!({ - "prompt": prompt, - "model": params.request_body.as_ref().ok_or_else(|| internal_error("missing request_body")).expect("Unable to make request for ollama").get("model"), - "stream": false, - // As per [modelfile](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values) - "options": { - "num_predict": params.request_params.max_new_tokens, - "temperature": params.request_params.temperature, - "top_p": params.request_params.top_p, - "stop": params.request_params.stop_tokens.clone(), - } - }) -} fn build_ollama_headers() -> Result<HeaderMap, jsonrpc::Error> { Ok(HeaderMap::new()) } @@ -116,17 +84,6 @@ fn parse_ollama_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> { Ok(generations) } -fn build_openai_body(prompt: String, params: &CompletionParams) -> Value { - serde_json::json!({ - "prompt": prompt, - "model": params.request_body.as_ref().ok_or_else(|| internal_error("missing request_body")).expect("Unable to make request for 
openai").get("model"), - "max_tokens": params.request_params.max_new_tokens, - "temperature": params.request_params.temperature, - "top_p": params.request_params.top_p, - "stop": params.request_params.stop_tokens.clone(), - }) -} - fn build_openai_headers(api_token: Option<&String>, ide: Ide) -> Result { build_api_headers(api_token, ide) } @@ -206,51 +163,47 @@ fn parse_openai_text(text: &str) -> Result, jsonrpc::Error> { } } -pub(crate) const TGI: &str = "tgi"; -pub(crate) const HUGGING_FACE: &str = "huggingface"; -pub(crate) const OLLAMA: &str = "ollama"; -pub(crate) const OPENAI: &str = "openai"; -pub(crate) const DEFAULT_ADAPTOR: &str = HUGGING_FACE; - -fn unknown_adaptor_error(adaptor: Option<&String>) -> jsonrpc::Error { - internal_error(format!("Unknown adaptor {:?}", adaptor)) +#[derive(Debug, Default, Deserialize, Serialize)] +#[serde(rename_all = "lowercase")] +pub(crate) enum Backend { + #[default] + HuggingFace, + Ollama, + OpenAi, + Tgi, } -pub fn adapt_body(prompt: String, params: &CompletionParams) -> Result { - match params - .adaptor - .as_ref() - .unwrap_or(&DEFAULT_ADAPTOR.to_string()) - .as_str() - { - TGI => Ok(build_tgi_body(prompt, ¶ms.request_params)), - HUGGING_FACE => Ok(build_api_body(prompt, ¶ms.request_params)), - OLLAMA => Ok(build_ollama_body(prompt, params)), - OPENAI => Ok(build_openai_body(prompt, params)), - _ => Err(unknown_adaptor_error(params.adaptor.as_ref())), - } +pub fn build_body(prompt: String, params: &CompletionParams) -> Map { + let mut body = params.request_body.clone(); + match params.backend { + Backend::HuggingFace | Backend::Tgi => { + body.insert("inputs".to_string(), Value::String(prompt)) + } + Backend::Ollama | Backend::OpenAi => { + body.insert("prompt".to_string(), Value::String(prompt)) + } + }; + body } -pub fn adapt_headers( - adaptor: Option<&String>, +pub fn build_headers( + backend: &Backend, api_token: Option<&String>, ide: Ide, ) -> Result { - match adaptor.unwrap_or(&DEFAULT_ADAPTOR.to_string()).as_str() { - TGI => build_tgi_headers(api_token, ide), - HUGGING_FACE => build_api_headers(api_token, ide), - OLLAMA => build_ollama_headers(), - OPENAI => build_openai_headers(api_token, ide), - _ => Err(unknown_adaptor_error(adaptor)), + match backend { + Backend::HuggingFace => build_api_headers(api_token, ide), + Backend::Ollama => build_ollama_headers(), + Backend::OpenAi => build_openai_headers(api_token, ide), + Backend::Tgi => build_tgi_headers(api_token, ide), } } -pub fn parse_generations(adaptor: Option<&String>, text: &str) -> jsonrpc::Result> { - match adaptor.unwrap_or(&DEFAULT_ADAPTOR.to_string()).as_str() { - TGI => parse_tgi_text(text), - HUGGING_FACE => parse_api_text(text), - OLLAMA => parse_ollama_text(text), - OPENAI => parse_openai_text(text), - _ => Err(unknown_adaptor_error(adaptor)), +pub fn parse_generations(backend: &Backend, text: &str) -> jsonrpc::Result> { + match backend { + Backend::HuggingFace => parse_api_text(text), + Backend::Ollama => parse_ollama_text(text), + Backend::OpenAi => parse_openai_text(text), + Backend::Tgi => parse_tgi_text(text), } } diff --git a/crates/llm-ls/src/main.rs b/crates/llm-ls/src/main.rs index 318be3b..fcd7d3c 100644 --- a/crates/llm-ls/src/main.rs +++ b/crates/llm-ls/src/main.rs @@ -1,8 +1,8 @@ -use adaptors::{adapt_body, adapt_headers, parse_generations}; +use backend::{build_body, build_headers, parse_generations, Backend}; use document::Document; -use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT}; use ropey::Rope; use serde::{Deserialize, 
Deserializer, Serialize}; +use serde_json::{Map, Value}; use std::collections::HashMap; use std::fmt::Display; use std::path::{Path, PathBuf}; @@ -19,7 +19,7 @@ use tracing_appender::rolling; use tracing_subscriber::EnvFilter; use uuid::Uuid; -mod adaptors; +mod backend; mod document; mod language_id; @@ -117,10 +117,7 @@ fn should_complete(document: &Document, position: Position) -> Result Result, + }, + Download { + url: String, + to: PathBuf, + }, } #[derive(Clone, Debug, Deserialize, Serialize)] @@ -209,7 +214,7 @@ pub enum APIResponse { Error(APIError), } -struct Backend { +struct LlmService { cache_dir: PathBuf, client: Client, document_map: Arc>>, @@ -272,19 +277,18 @@ struct RejectedCompletion { pub struct CompletionParams { #[serde(flatten)] text_document_position: TextDocumentPositionParams, - request_params: RequestParams, #[serde(default)] #[serde(deserialize_with = "parse_ide")] ide: Ide, fim: FimParams, api_token: Option, model: String, - adaptor: Option, + backend: Backend, tokens_to_clear: Vec, tokenizer_config: Option, context_window: usize, tls_skip_verify_insecure: bool, - request_body: Option>, + request_body: Map, } #[derive(Debug, Deserialize, Serialize)] @@ -413,12 +417,8 @@ async fn request_completion( ) -> Result> { let t = Instant::now(); - let json = adapt_body(prompt, params).map_err(internal_error)?; - let headers = adapt_headers( - params.adaptor.as_ref(), - params.api_token.as_ref(), - params.ide, - )?; + let json = build_body(prompt, params); + let headers = build_headers(¶ms.backend, params.api_token.as_ref(), params.ide)?; let res = http_client .post(build_url(¶ms.model)) .json(&json) @@ -429,7 +429,7 @@ async fn request_completion( let model = ¶ms.model; let generations = parse_generations( - params.adaptor.as_ref(), + ¶ms.backend, res.text().await.map_err(internal_error)?.as_str(), ); let time = t.elapsed().as_millis(); @@ -489,7 +489,7 @@ async fn download_tokenizer_file( ) .await .map_err(internal_error)?; - let headers = build_headers(api_token, ide)?; + let headers = build_headers(&Backend::HuggingFace, api_token, ide)?; let mut file = tokio::fs::OpenOptions::new() .write(true) .create(true) @@ -538,7 +538,6 @@ async fn get_tokenizer( tokenizer_config: Option<&TokenizerConfig>, http_client: &reqwest::Client, cache_dir: impl AsRef, - api_token: Option<&String>, ide: Ide, ) -> Result>> { if let Some(tokenizer) = tokenizer_map.get(model) { @@ -553,11 +552,14 @@ async fn get_tokenizer( None } }, - TokenizerConfig::HuggingFace { repository } => { + TokenizerConfig::HuggingFace { + repository, + api_token, + } => { let path = cache_dir.as_ref().join(repository).join("tokenizer.json"); let url = format!("https://huggingface.co/{repository}/resolve/main/tokenizer.json"); - download_tokenizer_file(http_client, &url, api_token, &path, ide).await?; + download_tokenizer_file(http_client, &url, api_token.as_ref(), &path, ide).await?; match Tokenizer::from_file(path) { Ok(tokenizer) => Some(Arc::new(tokenizer)), Err(err) => { @@ -567,7 +569,7 @@ async fn get_tokenizer( } } TokenizerConfig::Download { url, to } => { - download_tokenizer_file(http_client, url, api_token, &to, ide).await?; + download_tokenizer_file(http_client, url, None, &to, ide).await?; match Tokenizer::from_file(to) { Ok(tokenizer) => Some(Arc::new(tokenizer)), Err(err) => { @@ -594,7 +596,7 @@ fn build_url(model: &str) -> String { } } -impl Backend { +impl LlmService { async fn get_completions(&self, params: CompletionParams) -> Result { let request_id = Uuid::new_v4(); let span = 
info_span!("completion_request", %request_id); @@ -611,15 +613,11 @@ impl Backend { language_id = %document.language_id, model = params.model, ide = %params.ide, - max_new_tokens = params.request_params.max_new_tokens, - temperature = params.request_params.temperature, - do_sample = params.request_params.do_sample, - top_p = params.request_params.top_p, - stop_tokens = ?params.request_params.stop_tokens, + request_body = serde_json::to_string(¶ms.request_body).map_err(internal_error)?, "received completion request for {}", params.text_document_position.text_document.uri ); - let is_using_inference_api = params.adaptor.as_ref().unwrap_or(&adaptors::DEFAULT_ADAPTOR.to_owned()).as_str() == adaptors::HUGGING_FACE; + let is_using_inference_api = matches!(params.backend, Backend::HuggingFace); if params.api_token.is_none() && is_using_inference_api { let now = Instant::now(); let unauthenticated_warn_at = self.unauthenticated_warn_at.read().await; @@ -642,7 +640,6 @@ impl Backend { params.tokenizer_config.as_ref(), &self.http_client, &self.cache_dir, - params.api_token.as_ref(), params.ide, ) .await?; @@ -693,7 +690,7 @@ impl Backend { } #[tower_lsp::async_trait] -impl LanguageServer for Backend { +impl LanguageServer for LlmService { async fn initialize(&self, params: InitializeParams) -> Result { *self.workspace_folders.write().await = params.workspace_folders; Ok(InitializeResult { @@ -795,24 +792,6 @@ impl LanguageServer for Backend { } } -fn build_headers(api_token: Option<&String>, ide: Ide) -> Result { - let mut headers = HeaderMap::new(); - let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}"); - headers.insert( - USER_AGENT, - HeaderValue::from_str(&user_agent).map_err(internal_error)?, - ); - - if let Some(api_token) = api_token { - headers.insert( - AUTHORIZATION, - HeaderValue::from_str(&format!("Bearer {api_token}")).map_err(internal_error)?, - ); - } - - Ok(headers) -} - #[tokio::main] async fn main() { let stdin = tokio::io::stdin(); @@ -846,7 +825,7 @@ async fn main() { .build() .expect("failed to build reqwest unsafe client"); - let (service, socket) = LspService::build(|client| Backend { + let (service, socket) = LspService::build(|client| LlmService { cache_dir, client, document_map: Arc::new(RwLock::new(HashMap::new())), @@ -860,9 +839,9 @@ async fn main() { .expect("instant to be in bounds"), )), }) - .custom_method("llm-ls/getCompletions", Backend::get_completions) - .custom_method("llm-ls/acceptCompletion", Backend::accept_completion) - .custom_method("llm-ls/rejectCompletion", Backend::reject_completion) + .custom_method("llm-ls/getCompletions", LlmService::get_completions) + .custom_method("llm-ls/acceptCompletion", LlmService::accept_completion) + .custom_method("llm-ls/rejectCompletion", LlmService::reject_completion) .finish(); Server::new(stdin, stdout, socket).serve(service).await; From 455b085c963a6b3900395451969b4825c2933d6b Mon Sep 17 00:00:00 2001 From: Luc Georges Date: Wed, 7 Feb 2024 12:34:17 +0100 Subject: [PATCH 02/10] refactor: error handling (#71) --- Cargo.lock | 1 + crates/llm-ls/Cargo.toml | 1 + crates/llm-ls/src/backend.rs | 93 ++++++++++--------------- crates/llm-ls/src/document.rs | 127 +++++++++------------------------- crates/llm-ls/src/error.rs | 64 +++++++++++++++++ crates/llm-ls/src/main.rs | 118 +++++++++++++------------------ 6 files changed, 186 insertions(+), 218 deletions(-) create mode 100644 crates/llm-ls/src/error.rs diff --git a/Cargo.lock b/Cargo.lock index c7185c2..3129b6c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ 
-979,6 +979,7 @@ dependencies = [ "ropey", "serde", "serde_json", + "thiserror", "tokenizers", "tokio", "tower-lsp", diff --git a/crates/llm-ls/Cargo.toml b/crates/llm-ls/Cargo.toml index 7ac7cec..f70e218 100644 --- a/crates/llm-ls/Cargo.toml +++ b/crates/llm-ls/Cargo.toml @@ -18,6 +18,7 @@ reqwest = { version = "0.11", default-features = false, features = [ ] } serde = { version = "1", features = ["derive"] } serde_json = "1" +thiserror = "1" tokenizers = { version = "0.14", default-features = false, features = ["onig"] } tokio = { version = "1", features = [ "fs", diff --git a/crates/llm-ls/src/backend.rs b/crates/llm-ls/src/backend.rs index a139870..c9f18cd 100644 --- a/crates/llm-ls/src/backend.rs +++ b/crates/llm-ls/src/backend.rs @@ -1,59 +1,44 @@ -use super::{ - internal_error, APIError, APIResponse, CompletionParams, Generation, Ide, NAME, VERSION, -}; +use super::{APIError, APIResponse, CompletionParams, Generation, Ide, NAME, VERSION}; use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT}; use serde::{Deserialize, Serialize}; use serde_json::{Map, Value}; use std::fmt::Display; -use tower_lsp::jsonrpc; -fn build_tgi_headers(api_token: Option<&String>, ide: Ide) -> Result { +use crate::error::{Error, Result}; + +fn build_tgi_headers(api_token: Option<&String>, ide: Ide) -> Result { let mut headers = HeaderMap::new(); let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}"); - headers.insert( - USER_AGENT, - HeaderValue::from_str(&user_agent).map_err(internal_error)?, - ); + headers.insert(USER_AGENT, HeaderValue::from_str(&user_agent)?); if let Some(api_token) = api_token { headers.insert( AUTHORIZATION, - HeaderValue::from_str(&format!("Bearer {api_token}")).map_err(internal_error)?, + HeaderValue::from_str(&format!("Bearer {api_token}"))?, ); } Ok(headers) } -fn parse_tgi_text(text: &str) -> Result, jsonrpc::Error> { - let generations = - match serde_json::from_str(text).map_err(internal_error)? { - APIResponse::Generation(gen) => vec![gen], - APIResponse::Generations(_) => { - return Err(internal_error( - "You are attempting to parse a result in the API inference format when using the `tgi` backend", - )) - } - APIResponse::Error(err) => return Err(internal_error(err)), - }; - Ok(generations) +fn parse_tgi_text(text: &str) -> Result> { + match serde_json::from_str(text)? { + APIResponse::Generation(gen) => Ok(vec![gen]), + APIResponse::Generations(_) => Err(Error::InvalidBackend), + APIResponse::Error(err) => Err(Error::Tgi(err)), + } } -fn build_api_headers(api_token: Option<&String>, ide: Ide) -> Result { +fn build_api_headers(api_token: Option<&String>, ide: Ide) -> Result { build_tgi_headers(api_token, ide) } -fn parse_api_text(text: &str) -> Result, jsonrpc::Error> { - let generations = match serde_json::from_str(text).map_err(internal_error)? { - APIResponse::Generation(gen) => vec![gen], - APIResponse::Generations(gens) => gens, - APIResponse::Error(err) => return Err(internal_error(err)), - }; - Ok(generations) -} - -fn build_ollama_headers() -> Result { - Ok(HeaderMap::new()) +fn parse_api_text(text: &str) -> Result> { + match serde_json::from_str(text)? 
{ + APIResponse::Generation(gen) => Ok(vec![gen]), + APIResponse::Generations(gens) => Ok(gens), + APIResponse::Error(err) => Err(Error::InferenceApi(err)), + } } #[derive(Debug, Serialize, Deserialize)] @@ -76,16 +61,15 @@ enum OllamaAPIResponse { Error(APIError), } -fn parse_ollama_text(text: &str) -> Result, jsonrpc::Error> { - let generations = match serde_json::from_str(text).map_err(internal_error)? { - OllamaAPIResponse::Generation(gen) => vec![gen.into()], - OllamaAPIResponse::Error(err) => return Err(internal_error(err)), - }; - Ok(generations) +fn build_ollama_headers() -> HeaderMap { + HeaderMap::new() } -fn build_openai_headers(api_token: Option<&String>, ide: Ide) -> Result { - build_api_headers(api_token, ide) +fn parse_ollama_text(text: &str) -> Result> { + match serde_json::from_str(text)? { + OllamaAPIResponse::Generation(gen) => Ok(vec![gen.into()]), + OllamaAPIResponse::Error(err) => Err(Error::Ollama(err)), + } } #[derive(Debug, Deserialize)] @@ -130,7 +114,7 @@ struct OpenAIErrorDetail { } #[derive(Debug, Deserialize)] -struct OpenAIError { +pub struct OpenAIError { detail: Vec, } @@ -153,13 +137,16 @@ enum OpenAIAPIResponse { Error(OpenAIError), } -fn parse_openai_text(text: &str) -> Result, jsonrpc::Error> { - match serde_json::from_str(text).map_err(internal_error) { - Ok(OpenAIAPIResponse::Generation(completion)) => { +fn build_openai_headers(api_token: Option<&String>, ide: Ide) -> Result { + build_api_headers(api_token, ide) +} + +fn parse_openai_text(text: &str) -> Result> { + match serde_json::from_str(text)? { + OpenAIAPIResponse::Generation(completion) => { Ok(completion.choices.into_iter().map(|x| x.into()).collect()) } - Ok(OpenAIAPIResponse::Error(err)) => Err(internal_error(err)), - Err(err) => Err(internal_error(err)), + OpenAIAPIResponse::Error(err) => Err(Error::OpenAI(err)), } } @@ -186,20 +173,16 @@ pub fn build_body(prompt: String, params: &CompletionParams) -> Map, - ide: Ide, -) -> Result { +pub fn build_headers(backend: &Backend, api_token: Option<&String>, ide: Ide) -> Result { match backend { Backend::HuggingFace => build_api_headers(api_token, ide), - Backend::Ollama => build_ollama_headers(), + Backend::Ollama => Ok(build_ollama_headers()), Backend::OpenAi => build_openai_headers(api_token, ide), Backend::Tgi => build_tgi_headers(api_token, ide), } } -pub fn parse_generations(backend: &Backend, text: &str) -> jsonrpc::Result> { +pub fn parse_generations(backend: &Backend, text: &str) -> Result> { match backend { Backend::HuggingFace => parse_api_text(text), Backend::Ollama => parse_ollama_text(text), diff --git a/crates/llm-ls/src/document.rs b/crates/llm-ls/src/document.rs index 689e33b..bd06f28 100644 --- a/crates/llm-ls/src/document.rs +++ b/crates/llm-ls/src/document.rs @@ -1,172 +1,126 @@ use ropey::Rope; -use tower_lsp::jsonrpc::Result; use tower_lsp::lsp_types::Range; use tree_sitter::{InputEdit, Parser, Point, Tree}; +use crate::error::Result; +use crate::get_position_idx; use crate::language_id::LanguageId; -use crate::{get_position_idx, internal_error}; fn get_parser(language_id: LanguageId) -> Result { match language_id { LanguageId::Bash => { let mut parser = Parser::new(); - parser - .set_language(tree_sitter_bash::language()) - .map_err(internal_error)?; + parser.set_language(tree_sitter_bash::language())?; Ok(parser) } LanguageId::C => { let mut parser = Parser::new(); - parser - .set_language(tree_sitter_c::language()) - .map_err(internal_error)?; + parser.set_language(tree_sitter_c::language())?; Ok(parser) } LanguageId::Cpp 
=> { let mut parser = Parser::new(); - parser - .set_language(tree_sitter_cpp::language()) - .map_err(internal_error)?; + parser.set_language(tree_sitter_cpp::language())?; Ok(parser) } LanguageId::CSharp => { let mut parser = Parser::new(); - parser - .set_language(tree_sitter_c_sharp::language()) - .map_err(internal_error)?; + parser.set_language(tree_sitter_c_sharp::language())?; Ok(parser) } LanguageId::Elixir => { let mut parser = Parser::new(); - parser - .set_language(tree_sitter_elixir::language()) - .map_err(internal_error)?; + parser.set_language(tree_sitter_elixir::language())?; Ok(parser) } LanguageId::Erlang => { let mut parser = Parser::new(); - parser - .set_language(tree_sitter_erlang::language()) - .map_err(internal_error)?; + parser.set_language(tree_sitter_erlang::language())?; Ok(parser) } LanguageId::Go => { let mut parser = Parser::new(); - parser - .set_language(tree_sitter_go::language()) - .map_err(internal_error)?; + parser.set_language(tree_sitter_go::language())?; Ok(parser) } LanguageId::Html => { let mut parser = Parser::new(); - parser - .set_language(tree_sitter_html::language()) - .map_err(internal_error)?; + parser.set_language(tree_sitter_html::language())?; Ok(parser) } LanguageId::Java => { let mut parser = Parser::new(); - parser - .set_language(tree_sitter_java::language()) - .map_err(internal_error)?; + parser.set_language(tree_sitter_java::language())?; Ok(parser) } LanguageId::JavaScript | LanguageId::JavaScriptReact => { let mut parser = Parser::new(); - parser - .set_language(tree_sitter_javascript::language()) - .map_err(internal_error)?; + parser.set_language(tree_sitter_javascript::language())?; Ok(parser) } LanguageId::Json => { let mut parser = Parser::new(); - parser - .set_language(tree_sitter_json::language()) - .map_err(internal_error)?; + parser.set_language(tree_sitter_json::language())?; Ok(parser) } LanguageId::Kotlin => { let mut parser = Parser::new(); - parser - .set_language(tree_sitter_kotlin::language()) - .map_err(internal_error)?; + parser.set_language(tree_sitter_kotlin::language())?; Ok(parser) } LanguageId::Lua => { let mut parser = Parser::new(); - parser - .set_language(tree_sitter_lua::language()) - .map_err(internal_error)?; + parser.set_language(tree_sitter_lua::language())?; Ok(parser) } LanguageId::Markdown => { let mut parser = Parser::new(); - parser - .set_language(tree_sitter_md::language()) - .map_err(internal_error)?; + parser.set_language(tree_sitter_md::language())?; Ok(parser) } LanguageId::ObjectiveC => { let mut parser = Parser::new(); - parser - .set_language(tree_sitter_objc::language()) - .map_err(internal_error)?; + parser.set_language(tree_sitter_objc::language())?; Ok(parser) } LanguageId::Python => { let mut parser = Parser::new(); - parser - .set_language(tree_sitter_python::language()) - .map_err(internal_error)?; + parser.set_language(tree_sitter_python::language())?; Ok(parser) } LanguageId::R => { let mut parser = Parser::new(); - parser - .set_language(tree_sitter_r::language()) - .map_err(internal_error)?; + parser.set_language(tree_sitter_r::language())?; Ok(parser) } LanguageId::Ruby => { let mut parser = Parser::new(); - parser - .set_language(tree_sitter_ruby::language()) - .map_err(internal_error)?; + parser.set_language(tree_sitter_ruby::language())?; Ok(parser) } LanguageId::Rust => { let mut parser = Parser::new(); - parser - .set_language(tree_sitter_rust::language()) - .map_err(internal_error)?; + parser.set_language(tree_sitter_rust::language())?; Ok(parser) } LanguageId::Scala => 
{ let mut parser = Parser::new(); - parser - .set_language(tree_sitter_scala::language()) - .map_err(internal_error)?; + parser.set_language(tree_sitter_scala::language())?; Ok(parser) } LanguageId::Swift => { let mut parser = Parser::new(); - parser - .set_language(tree_sitter_swift::language()) - .map_err(internal_error)?; + parser.set_language(tree_sitter_swift::language())?; Ok(parser) } LanguageId::TypeScript => { let mut parser = Parser::new(); - parser - .set_language(tree_sitter_typescript::language_typescript()) - .map_err(internal_error)?; + parser.set_language(tree_sitter_typescript::language_typescript())?; Ok(parser) } LanguageId::TypeScriptReact => { let mut parser = Parser::new(); - parser - .set_language(tree_sitter_typescript::language_tsx()) - .map_err(internal_error)?; + parser.set_language(tree_sitter_typescript::language_tsx())?; Ok(parser) } LanguageId::Unknown => Ok(Parser::new()), @@ -200,19 +154,13 @@ impl Document { range.start.line as usize, range.start.character as usize, )?; - let start_byte = self - .text - .try_char_to_byte(start_idx) - .map_err(internal_error)?; + let start_byte = self.text.try_char_to_byte(start_idx)?; let old_end_idx = get_position_idx( &self.text, range.end.line as usize, range.end.character as usize, )?; - let old_end_byte = self - .text - .try_char_to_byte(old_end_idx) - .map_err(internal_error)?; + let old_end_byte = self.text.try_char_to_byte(old_end_idx)?; let start_position = Point { row: range.start.line as usize, column: range.start.character as usize, @@ -224,7 +172,7 @@ impl Document { let (new_end_idx, new_end_position) = if range.start == range.end { let row = range.start.line as usize; let column = range.start.character as usize; - let idx = self.text.try_line_to_char(row).map_err(internal_error)? + column; + let idx = self.text.try_line_to_char(row)? + column; let rope = Rope::from_str(text); let text_len = rope.len_chars(); let end_idx = idx + text_len; @@ -237,11 +185,10 @@ impl Document { }, ) } else { - let removal_idx = self.text.try_line_to_char(range.end.line as usize).map_err(internal_error)? + (range.end.character as usize); + let removal_idx = self.text.try_line_to_char(range.end.line as usize)? 
+ + (range.end.character as usize); let slice_size = removal_idx - start_idx; - self.text - .try_remove(start_idx..removal_idx) - .map_err(internal_error)?; + self.text.try_remove(start_idx..removal_idx)?; self.text.insert(start_idx, text); let rope = Rope::from_str(text); let text_len = rope.len_chars(); @@ -251,11 +198,8 @@ impl Document { } else { removal_idx + character_difference as usize }; - let row = self - .text - .try_char_to_line(new_end_idx) - .map_err(internal_error)?; - let line_start = self.text.try_line_to_char(row).map_err(internal_error)?; + let row = self.text.try_char_to_line(new_end_idx)?; + let line_start = self.text.try_line_to_char(row)?; let column = new_end_idx - line_start; (new_end_idx, Point { row, column }) }; @@ -263,10 +207,7 @@ impl Document { let edit = InputEdit { start_byte, old_end_byte, - new_end_byte: self - .text - .try_char_to_byte(new_end_idx) - .map_err(internal_error)?, + new_end_byte: self.text.try_char_to_byte(new_end_idx)?, start_position, old_end_position, new_end_position, diff --git a/crates/llm-ls/src/error.rs b/crates/llm-ls/src/error.rs new file mode 100644 index 0000000..aa51588 --- /dev/null +++ b/crates/llm-ls/src/error.rs @@ -0,0 +1,64 @@ +use std::fmt::Display; + +use tower_lsp::jsonrpc::Error as LspError; +use tracing::error; + +pub fn internal_error(err: E) -> LspError { + let err_msg = err.to_string(); + error!(err_msg); + LspError { + code: tower_lsp::jsonrpc::ErrorCode::InternalError, + message: err_msg.into(), + data: None, + } +} + +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("http error: {0}")] + Http(#[from] reqwest::Error), + #[error("io error: {0}")] + Io(#[from] std::io::Error), + #[error("inference api error: {0}")] + InferenceApi(crate::APIError), + #[error("You are attempting to parse a result in the API inference format when using the `tgi` backend")] + InvalidBackend, + #[error("invalid header value: {0}")] + InvalidHeaderValue(#[from] reqwest::header::InvalidHeaderValue), + #[error("invalid repository id")] + InvalidRepositoryId, + #[error("invalid tokenizer path")] + InvalidTokenizerPath, + #[error("ollama error: {0}")] + Ollama(crate::APIError), + #[error("openai error: {0}")] + OpenAI(crate::backend::OpenAIError), + #[error("index out of bounds: {0}")] + OutOfBoundIndexing(usize), + #[error("line out of bounds: {0}")] + OutOfBoundLine(usize), + #[error("slice out of bounds: {0}..{1}")] + OutOfBoundSlice(usize, usize), + #[error("rope error: {0}")] + Rope(#[from] ropey::Error), + #[error("serde json error: {0}")] + SerdeJson(#[from] serde_json::Error), + #[error("tgi error: {0}")] + Tgi(crate::APIError), + #[error("tree-sitter language error: {0}")] + TreeSitterLanguage(#[from] tree_sitter::LanguageError), + #[error("tokenizer error: {0}")] + Tokenizer(#[from] tokenizers::Error), + #[error("tokio join error: {0}")] + TokioJoin(#[from] tokio::task::JoinError), + #[error("unknown backend: {0}")] + UnknownBackend(String), +} + +pub type Result = std::result::Result; + +impl From for LspError { + fn from(err: Error) -> Self { + internal_error(err) + } +} diff --git a/crates/llm-ls/src/main.rs b/crates/llm-ls/src/main.rs index fcd7d3c..c6f2c54 100644 --- a/crates/llm-ls/src/main.rs +++ b/crates/llm-ls/src/main.rs @@ -1,5 +1,3 @@ -use backend::{build_body, build_headers, parse_generations, Backend}; -use document::Document; use ropey::Rope; use serde::{Deserialize, Deserializer, Serialize}; use serde_json::{Map, Value}; @@ -11,7 +9,7 @@ use std::time::{Duration, Instant}; use tokenizers::Tokenizer; 
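The new error.rs above replaces the ad-hoc `.map_err(internal_error)` chains: `thiserror` derives `Display` plus the `#[from]` conversions that `?` relies on, and the single `From<Error> for LspError` impl turns any internal error into a JSON-RPC error at the handler boundary. A minimal self-contained sketch of the same pattern, trimmed to two variants for illustration:

    use tower_lsp::jsonrpc::Error as LspError;
    use tracing::error;

    #[derive(thiserror::Error, Debug)]
    pub enum Error {
        #[error("io error: {0}")]
        Io(#[from] std::io::Error),
        #[error("rope error: {0}")]
        Rope(#[from] ropey::Error),
    }

    pub type Result<T> = std::result::Result<T, Error>;

    impl From<Error> for LspError {
        fn from(err: Error) -> Self {
            // Same shape as internal_error above: log the message, then wrap
            // it in an InternalError response.
            let err_msg = err.to_string();
            error!(err_msg);
            LspError {
                code: tower_lsp::jsonrpc::ErrorCode::InternalError,
                message: err_msg.into(),
                data: None,
            }
        }
    }

    // Fallible helpers can now bubble errors up with a bare `?`:
    fn read_file(path: &std::path::Path) -> Result<String> {
        Ok(std::fs::read_to_string(path)?)
    }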
use tokio::io::AsyncWriteExt; use tokio::sync::RwLock; -use tower_lsp::jsonrpc::{Error, Result}; +use tower_lsp::jsonrpc::Result as LspResult; use tower_lsp::lsp_types::*; use tower_lsp::{Client, LanguageServer, LspService, Server}; use tracing::{debug, error, info, info_span, warn, Instrument}; @@ -19,8 +17,13 @@ use tracing_appender::rolling; use tracing_subscriber::EnvFilter; use uuid::Uuid; +use crate::backend::{build_body, build_headers, parse_generations, Backend}; +use crate::document::Document; +use crate::error::{internal_error, Error, Result}; + mod backend; mod document; +mod error; mod language_id; const MAX_WARNING_REPEAT: Duration = Duration::from_secs(3_600); @@ -29,10 +32,10 @@ pub const VERSION: &str = env!("CARGO_PKG_VERSION"); const HF_INFERENCE_API_HOSTNAME: &str = "api-inference.huggingface.co"; fn get_position_idx(rope: &Rope, row: usize, col: usize) -> Result { - Ok(rope.try_line_to_char(row).map_err(internal_error)? + Ok(rope.try_line_to_char(row)? + col.min( rope.get_line(row.min(rope.len_lines().saturating_sub(1))) - .ok_or_else(|| internal_error(format!("failed to find line at {row}")))? + .ok_or(Error::OutOfBoundLine(row))? .len_chars() .saturating_sub(1), )) @@ -80,16 +83,12 @@ fn should_complete(document: &Document, position: Position) -> Result Result &str { + &self.error + } +} + impl Display for APIError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.error) @@ -297,16 +295,6 @@ struct CompletionResult { completions: Vec, } -pub fn internal_error(err: E) -> Error { - let err_msg = err.to_string(); - error!(err_msg); - Error { - code: tower_lsp::jsonrpc::ErrorCode::InternalError, - message: err_msg.into(), - data: None, - } -} - fn build_prompt( pos: Position, text: &Rope, @@ -335,10 +323,7 @@ fn build_prompt( if let Some(before_line) = before_line { let before_line = before_line.to_string(); let tokens = if let Some(tokenizer) = tokenizer.clone() { - tokenizer - .encode(before_line.clone(), false) - .map_err(internal_error)? - .len() + tokenizer.encode(before_line.clone(), false)?.len() } else { before_line.len() }; @@ -351,10 +336,7 @@ fn build_prompt( if let Some(after_line) = after_line { let after_line = after_line.to_string(); let tokens = if let Some(tokenizer) = tokenizer.clone() { - tokenizer - .encode(after_line.clone(), false) - .map_err(internal_error)? - .len() + tokenizer.encode(after_line.clone(), false)?.len() } else { after_line.len() }; @@ -390,10 +372,7 @@ fn build_prompt( } let line = line.to_string(); let tokens = if let Some(tokenizer) = tokenizer.clone() { - tokenizer - .encode(line.clone(), false) - .map_err(internal_error)? 
- .len() + tokenizer.encode(line.clone(), false)?.len() } else { line.len() }; @@ -424,22 +403,18 @@ async fn request_completion( .json(&json) .headers(headers) .send() - .await - .map_err(internal_error)?; + .await?; let model = ¶ms.model; - let generations = parse_generations( - ¶ms.backend, - res.text().await.map_err(internal_error)?.as_str(), - ); + let generations = parse_generations(¶ms.backend, res.text().await?.as_str())?; let time = t.elapsed().as_millis(); info!( model, compute_generations_ms = time, - generations = serde_json::to_string(&generations).map_err(internal_error)?, + generations = serde_json::to_string(&generations)?, "{model} computed generations in {time} ms" ); - generations + Ok(generations) } fn format_generations( @@ -482,22 +457,19 @@ async fn download_tokenizer_file( if to.as_ref().exists() { return Ok(()); } - tokio::fs::create_dir_all( - to.as_ref() - .parent() - .ok_or_else(|| internal_error("invalid tokenizer path"))?, - ) - .await - .map_err(internal_error)?; + tokio::fs::create_dir_all(to.as_ref().parent().ok_or(Error::InvalidTokenizerPath)?).await?; let headers = build_headers(&Backend::HuggingFace, api_token, ide)?; let mut file = tokio::fs::OpenOptions::new() .write(true) .create(true) .open(to) - .await - .map_err(internal_error)?; + .await?; let http_client = http_client.clone(); let url = url.to_owned(); + // TODO: + // - create oneshot channel to send result of tokenizer download to display error message + // to user? + // - retry logic? tokio::spawn(async move { let res = match http_client.get(url).headers(headers).send().await { Ok(res) => res, @@ -527,8 +499,7 @@ async fn download_tokenizer_file( } }; }) - .await - .map_err(internal_error)?; + .await?; Ok(()) } @@ -556,7 +527,14 @@ async fn get_tokenizer( repository, api_token, } => { - let path = cache_dir.as_ref().join(repository).join("tokenizer.json"); + let (org, repo) = repository + .split_once('/') + .ok_or(Error::InvalidRepositoryId)?; + let path = cache_dir + .as_ref() + .join(org) + .join(repo) + .join("tokenizer.json"); let url = format!("https://huggingface.co/{repository}/resolve/main/tokenizer.json"); download_tokenizer_file(http_client, &url, api_token.as_ref(), &path, ide).await?; @@ -597,7 +575,7 @@ fn build_url(model: &str) -> String { } impl LlmService { - async fn get_completions(&self, params: CompletionParams) -> Result { + async fn get_completions(&self, params: CompletionParams) -> LspResult { let request_id = Uuid::new_v4(); let span = info_span!("completion_request", %request_id); async move { @@ -669,7 +647,7 @@ impl LlmService { }.instrument(span).await } - async fn accept_completion(&self, accepted: AcceptedCompletion) -> Result<()> { + async fn accept_completion(&self, accepted: AcceptedCompletion) -> LspResult<()> { info!( request_id = %accepted.request_id, accepted_position = accepted.accepted_completion, @@ -679,7 +657,7 @@ impl LlmService { Ok(()) } - async fn reject_completion(&self, rejected: RejectedCompletion) -> Result<()> { + async fn reject_completion(&self, rejected: RejectedCompletion) -> LspResult<()> { info!( request_id = %rejected.request_id, shown_completions = serde_json::to_string(&rejected.shown_completions).map_err(internal_error)?, @@ -691,7 +669,7 @@ impl LlmService { #[tower_lsp::async_trait] impl LanguageServer for LlmService { - async fn initialize(&self, params: InitializeParams) -> Result { + async fn initialize(&self, params: InitializeParams) -> LspResult { *self.workspace_folders.write().await = params.workspace_folders; 
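To make the tokenizer cache change above concrete: a HuggingFace repository id is now split into organization and name, so the downloaded file lands under a nested directory instead of a single path component containing a slash. A small sketch mirroring the `split_once('/')` logic (the repository id below is only an example):

    use std::path::{Path, PathBuf};

    // Returns None for ids without a `/`, which get_tokenizer maps to
    // Error::InvalidRepositoryId.
    fn tokenizer_cache_path(cache_dir: &Path, repository: &str) -> Option<PathBuf> {
        let (org, repo) = repository.split_once('/')?;
        // e.g. ~/.cache/llm_ls/bigcode/starcoder/tokenizer.json
        Some(cache_dir.join(org).join(repo).join("tokenizer.json"))
    }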
Ok(InitializeResult { server_info: Some(ServerInfo { @@ -743,7 +721,7 @@ impl LanguageServer for LlmService { } // ignore the output scheme - if uri.starts_with("output:") { + if params.text_document.uri.scheme() == "output" { return; } @@ -786,7 +764,7 @@ impl LanguageServer for LlmService { info!("{uri} closed"); } - async fn shutdown(&self) -> Result<()> { + async fn shutdown(&self) -> LspResult<()> { debug!("shutdown"); Ok(()) } From 54b25a873125aecd4e0e347858b39881361b3329 Mon Sep 17 00:00:00 2001 From: Luc Georges Date: Wed, 7 Feb 2024 15:06:58 +0100 Subject: [PATCH 03/10] feat: add socket connection (#72) --- Cargo.lock | 17 +++++++++-------- crates/llm-ls/Cargo.toml | 9 +++++---- crates/llm-ls/src/main.rs | 34 +++++++++++++++++++++++++++++++--- 3 files changed, 45 insertions(+), 15 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3129b6c..bd3f737 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -974,6 +974,7 @@ checksum = "da2479e8c062e40bf0066ffa0bc823de0a9368974af99c9f6df941d2c231e03f" name = "llm-ls" version = "0.4.0" dependencies = [ + "clap", "home", "reqwest", "ropey", @@ -2063,9 +2064,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokenizers" -version = "0.14.1" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9be88c795d8b9f9c4002b3a8f26a6d0876103a6f523b32ea3bac52d8560c17c" +checksum = "6db445cceba5dfeb0f9702be7d6bfd91801ddcbe8fe8722defe7f2e96da75812" dependencies = [ "aho-corasick", "derive_builder", @@ -2368,9 +2369,9 @@ dependencies = [ [[package]] name = "tree-sitter-erlang" -version = "0.2.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d110d62a7ae35b985d8cfbc4de6e9281c7cbf268c466e30ebb31c2d3f861141" +checksum = "93ced5145ebb17f83243bf055b74e108da7cc129e12faab4166df03f59b287f4" dependencies = [ "cc", "tree-sitter", @@ -2388,9 +2389,9 @@ dependencies = [ [[package]] name = "tree-sitter-html" -version = "0.19.0" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "184e6b77953a354303dc87bf5fe36558c83569ce92606e7b382a0dc1b7443443" +checksum = "017822b6bd42843c4bd67fabb834f61ce23254e866282dd93871350fd6b7fa1d" dependencies = [ "cc", "tree-sitter", @@ -2518,9 +2519,9 @@ dependencies = [ [[package]] name = "tree-sitter-swift" -version = "0.3.6" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eee2dbeb101a88a1d9e4883e3fbda6c799cf676f6a1cf59e4fc3862e67e70118" +checksum = "452e6ee0a14b82a0dcd93400b8d3fe3784fdbd775191a89ef84586e5ccec6be7" dependencies = [ "cc", "tree-sitter", diff --git a/crates/llm-ls/Cargo.toml b/crates/llm-ls/Cargo.toml index f70e218..dbc3324 100644 --- a/crates/llm-ls/Cargo.toml +++ b/crates/llm-ls/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" name = "llm-ls" [dependencies] +clap = { version = "4", features = ["derive"] } home = "0.5" ropey = { version = "1.6", default-features = false, features = [ "simd", @@ -19,7 +20,7 @@ reqwest = { version = "0.11", default-features = false, features = [ serde = { version = "1", features = ["derive"] } serde_json = "1" thiserror = "1" -tokenizers = { version = "0.14", default-features = false, features = ["onig"] } +tokenizers = { version = "0.15", default-features = false, features = ["onig"] } tokio = { version = "1", features = [ "fs", "io-std", @@ -37,9 +38,9 @@ tree-sitter-c = "0.20" tree-sitter-cpp = "0.20" tree-sitter-c-sharp = 
"0.20" tree-sitter-elixir = "0.1" -tree-sitter-erlang = "0.2" +tree-sitter-erlang = "0.4" tree-sitter-go = "0.20" -tree-sitter-html = "0.19" +tree-sitter-html = "0.20" tree-sitter-java = "0.20" tree-sitter-javascript = "0.20" tree-sitter-json = "0.20" @@ -52,7 +53,7 @@ tree-sitter-r = "0.19" tree-sitter-ruby = "0.20" tree-sitter-rust = "0.20" tree-sitter-scala = "0.20" -tree-sitter-swift = "0.3" +tree-sitter-swift = "0.4" tree-sitter-typescript = "0.20" [dependencies.uuid] diff --git a/crates/llm-ls/src/main.rs b/crates/llm-ls/src/main.rs index c6f2c54..cab0879 100644 --- a/crates/llm-ls/src/main.rs +++ b/crates/llm-ls/src/main.rs @@ -1,3 +1,4 @@ +use clap::Parser; use ropey::Rope; use serde::{Deserialize, Deserializer, Serialize}; use serde_json::{Map, Value}; @@ -8,6 +9,7 @@ use std::sync::Arc; use std::time::{Duration, Instant}; use tokenizers::Tokenizer; use tokio::io::AsyncWriteExt; +use tokio::net::TcpListener; use tokio::sync::RwLock; use tower_lsp::jsonrpc::Result as LspResult; use tower_lsp::lsp_types::*; @@ -770,10 +772,22 @@ impl LanguageServer for LlmService { } } +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Wether to use a tcp socket for data transfer + #[arg(long = "port")] + socket: Option, + + /// Wether to use stdio transport for data transfer, ignored because it is the default + /// behaviour + #[arg(short, long, default_value_t = true)] + stdio: bool, +} + #[tokio::main] async fn main() { - let stdin = tokio::io::stdin(); - let stdout = tokio::io::stdout(); + let args = Args::parse(); let home_dir = home::home_dir().ok_or(()).expect("failed to find home dir"); let cache_dir = home_dir.join(".cache/llm_ls"); @@ -822,5 +836,19 @@ async fn main() { .custom_method("llm-ls/rejectCompletion", LlmService::reject_completion) .finish(); - Server::new(stdin, stdout, socket).serve(service).await; + if let Some(port) = args.socket { + let addr = format!("127.0.0.1:{port}"); + let listener = TcpListener::bind(&addr) + .await + .unwrap_or_else(|_| panic!("failed to bind tcp listener to {addr}")); + let (stream, _) = listener + .accept() + .await + .unwrap_or_else(|_| panic!("failed to accept new connections on {addr}")); + let (read, write) = tokio::io::split(stream); + Server::new(read, write, socket).serve(service).await; + } else { + let (stdin, stdout) = (tokio::io::stdin(), tokio::io::stdout()); + Server::new(stdin, stdout, socket).serve(service).await; + } } From 92fc8855039da8c5e35d4afd8c122678ad500082 Mon Sep 17 00:00:00 2001 From: Luc Georges Date: Thu, 8 Feb 2024 22:43:56 +0100 Subject: [PATCH 04/10] fix: helix editor build crash (#73) --- Cargo.lock | 12 +++ crates/custom-types/Cargo.toml | 14 ++++ crates/custom-types/src/lib.rs | 2 + crates/custom-types/src/llm_ls.rs | 118 ++++++++++++++++++++++++++ crates/custom-types/src/request.rs | 32 +++++++ crates/llm-ls/Cargo.toml | 1 + crates/llm-ls/src/backend.rs | 15 +--- crates/llm-ls/src/main.rs | 121 ++++----------------------- crates/lsp-client/Cargo.toml | 1 - crates/lsp-client/src/client.rs | 5 +- crates/lsp-client/src/error.rs | 10 +++ crates/testbed/Cargo.toml | 1 + crates/testbed/holes/helix-smol.json | 2 +- crates/testbed/holes/helix.json | 2 +- crates/testbed/repositories-ci.yaml | 11 +-- crates/testbed/repositories.yaml | 11 +-- crates/testbed/src/main.rs | 29 ++++--- crates/testbed/src/types.rs | 90 -------------------- 18 files changed, 240 insertions(+), 237 deletions(-) create mode 100644 crates/custom-types/Cargo.toml create mode 100644 
crates/custom-types/src/lib.rs create mode 100644 crates/custom-types/src/llm_ls.rs create mode 100644 crates/custom-types/src/request.rs delete mode 100644 crates/testbed/src/types.rs diff --git a/Cargo.lock b/Cargo.lock index bd3f737..5b4d6b3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -427,6 +427,16 @@ dependencies = [ "typenum", ] +[[package]] +name = "custom-types" +version = "0.1.0" +dependencies = [ + "lsp-types", + "serde", + "serde_json", + "uuid", +] + [[package]] name = "darling" version = "0.14.4" @@ -975,6 +985,7 @@ name = "llm-ls" version = "0.4.0" dependencies = [ "clap", + "custom-types", "home", "reqwest", "ropey", @@ -1968,6 +1979,7 @@ version = "0.1.0" dependencies = [ "anyhow", "clap", + "custom-types", "futures", "futures-util", "home", diff --git a/crates/custom-types/Cargo.toml b/crates/custom-types/Cargo.toml new file mode 100644 index 0000000..c71324b --- /dev/null +++ b/crates/custom-types/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "custom-types" +version = "0.1.0" +edition.workspace = true +license.workspace = true +authors.workspace = true + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +lsp-types = "0.94" +serde = "1" +serde_json = "1" +uuid = "1" diff --git a/crates/custom-types/src/lib.rs b/crates/custom-types/src/lib.rs new file mode 100644 index 0000000..63ea557 --- /dev/null +++ b/crates/custom-types/src/lib.rs @@ -0,0 +1,2 @@ +pub mod llm_ls; +pub mod request; diff --git a/crates/custom-types/src/llm_ls.rs b/crates/custom-types/src/llm_ls.rs new file mode 100644 index 0000000..10edc7c --- /dev/null +++ b/crates/custom-types/src/llm_ls.rs @@ -0,0 +1,118 @@ +use std::{fmt::Display, path::PathBuf}; + +use lsp_types::TextDocumentPositionParams; +use serde::{Deserialize, Deserializer, Serialize}; +use serde_json::{Map, Value}; +use uuid::Uuid; + +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct AcceptCompletionParams { + pub request_id: Uuid, + pub accepted_completion: u32, + pub shown_completions: Vec, +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct RejectCompletionParams { + pub request_id: Uuid, + pub shown_completions: Vec, +} + +#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum Ide { + Neovim, + VSCode, + JetBrains, + Emacs, + Jupyter, + Sublime, + VisualStudio, + #[default] + Unknown, +} + +impl Display for Ide { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.serialize(f) + } +} + +fn parse_ide<'de, D>(d: D) -> std::result::Result +where + D: Deserializer<'de>, +{ + Deserialize::deserialize(d).map(|b: Option<_>| b.unwrap_or(Ide::Unknown)) +} + +#[derive(Clone, Debug, Default, Deserialize, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum Backend { + #[default] + HuggingFace, + Ollama, + OpenAi, + Tgi, +} + +impl Display for Backend { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.serialize(f) + } +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct FimParams { + pub enabled: bool, + pub prefix: String, + pub middle: String, + pub suffix: String, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +#[serde(untagged)] +pub enum TokenizerConfig { + Local { + path: PathBuf, + }, + HuggingFace { + repository: String, + api_token: Option, + }, + Download { + url: String, + to: PathBuf, + }, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] 
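Because TokenizerConfig is `#[serde(untagged)]`, the shape of the client-supplied JSON alone selects the variant. A hypothetical set of payloads (paths, repository id, and URL are illustrative values, not values from this patch):

    use custom_types::llm_ls::TokenizerConfig;

    fn deserialize_examples() -> serde_json::Result<()> {
        // Matches TokenizerConfig::Local.
        let _local: TokenizerConfig =
            serde_json::from_str(r#"{ "path": "/tmp/tokenizer.json" }"#)?;
        // Matches TokenizerConfig::HuggingFace; api_token may be null.
        let _hub: TokenizerConfig = serde_json::from_str(
            r#"{ "repository": "bigcode/starcoder", "api_token": null }"#,
        )?;
        // Matches TokenizerConfig::Download.
        let _download: TokenizerConfig = serde_json::from_str(
            r#"{ "url": "https://example.com/tokenizer.json", "to": "/tmp/tokenizer.json" }"#,
        )?;
        Ok(())
    }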
+#[serde(rename_all = "camelCase")] +pub struct GetCompletionsParams { + #[serde(flatten)] + pub text_document_position: TextDocumentPositionParams, + #[serde(default)] + #[serde(deserialize_with = "parse_ide")] + pub ide: Ide, + pub fim: FimParams, + pub api_token: Option<String>, + pub model: String, + pub backend: Backend, + pub tokens_to_clear: Vec<String>, + pub tokenizer_config: Option<TokenizerConfig>, + pub context_window: usize, + pub tls_skip_verify_insecure: bool, + pub request_body: Map<String, Value>, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Completion { + pub generated_text: String, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct GetCompletionsResult { + pub request_id: Uuid, + pub completions: Vec<Completion>, +} diff --git a/crates/custom-types/src/request.rs b/crates/custom-types/src/request.rs new file mode 100644 index 0000000..d9c4fe6 --- /dev/null +++ b/crates/custom-types/src/request.rs @@ -0,0 +1,32 @@ +use lsp_types::request::Request; + +use crate::llm_ls::{ + AcceptCompletionParams, GetCompletionsParams, GetCompletionsResult, RejectCompletionParams, +}; + +#[derive(Debug)] +pub enum GetCompletions {} + +impl Request for GetCompletions { + type Params = GetCompletionsParams; + type Result = GetCompletionsResult; + const METHOD: &'static str = "llm-ls/getCompletions"; +} + +#[derive(Debug)] +pub enum AcceptCompletion {} + +impl Request for AcceptCompletion { + type Params = AcceptCompletionParams; + type Result = (); + const METHOD: &'static str = "llm-ls/acceptCompletion"; +} + +#[derive(Debug)] +pub enum RejectCompletion {} + +impl Request for RejectCompletion { + type Params = RejectCompletionParams; + type Result = (); + const METHOD: &'static str = "llm-ls/rejectCompletion"; +} diff --git a/crates/llm-ls/Cargo.toml b/crates/llm-ls/Cargo.toml index dbc3324..64cd202 100644 --- a/crates/llm-ls/Cargo.toml +++ b/crates/llm-ls/Cargo.toml @@ -8,6 +8,7 @@ name = "llm-ls" [dependencies] clap = { version = "4", features = ["derive"] } +custom-types = { path = "../custom-types" } home = "0.5" ropey = { version = "1.6", default-features = false, features = [ "simd", diff --git a/crates/llm-ls/src/backend.rs b/crates/llm-ls/src/backend.rs index c9f18cd..76a5d7c 100644 --- a/crates/llm-ls/src/backend.rs +++ b/crates/llm-ls/src/backend.rs @@ -1,4 +1,5 @@ -use super::{APIError, APIResponse, CompletionParams, Generation, Ide, NAME, VERSION}; +use super::{APIError, APIResponse, Generation, NAME, VERSION}; +use custom_types::llm_ls::{Backend, GetCompletionsParams, Ide}; use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT}; use serde::{Deserialize, Serialize}; use serde_json::{Map, Value}; @@ -150,17 +151,7 @@ fn parse_openai_text(text: &str) -> Result<Vec<Generation>> { } } -#[derive(Debug, Default, Deserialize, Serialize)] -#[serde(rename_all = "lowercase")] -pub(crate) enum Backend { - #[default] - HuggingFace, - Ollama, - OpenAi, - Tgi, -} - -pub fn build_body(prompt: String, params: &CompletionParams) -> Map<String, Value> { +pub fn build_body(prompt: String, params: &GetCompletionsParams) -> Map<String, Value> { + let mut body = params.request_body.clone(); match params.backend { Backend::HuggingFace | Backend::Tgi => { diff --git a/crates/llm-ls/src/main.rs b/crates/llm-ls/src/main.rs index cab0879..27441ad 100644 --- a/crates/llm-ls/src/main.rs +++ b/crates/llm-ls/src/main.rs @@ -1,7 +1,10 @@ use clap::Parser; +use custom_types::llm_ls::{ + AcceptCompletionParams, Backend, Completion, FimParams, GetCompletionsParams, + GetCompletionsResult, Ide, RejectCompletionParams, TokenizerConfig, +}; use ropey::Rope; -use serde::{Deserialize, 
Deserializer, Serialize}; -use serde_json::{Map, Value}; +use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::fmt::Display; use std::path::{Path, PathBuf}; @@ -19,7 +22,7 @@ use tracing_appender::rolling; use tracing_subscriber::EnvFilter; use uuid::Uuid; -use crate::backend::{build_body, build_headers, parse_generations, Backend}; +use crate::backend::{build_body, build_headers, parse_generations}; use crate::document::Document; use crate::error::{internal_error, Error, Result}; @@ -119,22 +122,6 @@ fn should_complete(document: &Document, position: Position) -> Result, - }, - Download { - url: String, - to: PathBuf, - }, -} - #[derive(Clone, Debug, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct RequestParams { @@ -145,14 +132,6 @@ pub struct RequestParams { stop_tokens: Option>, } -#[derive(Debug, Deserialize, Serialize)] -struct FimParams { - enabled: bool, - prefix: String, - middle: String, - suffix: String, -} - #[derive(Debug, Serialize)] struct APIParams { max_new_tokens: u32, @@ -225,78 +204,6 @@ struct LlmService { unauthenticated_warn_at: Arc>, } -#[derive(Debug, Deserialize, Serialize)] -struct Completion { - generated_text: String, -} - -#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize)] -#[serde(rename_all = "lowercase")] -pub enum Ide { - Neovim, - VSCode, - JetBrains, - Emacs, - Jupyter, - Sublime, - VisualStudio, - #[default] - Unknown, -} - -impl Display for Ide { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.serialize(f) - } -} - -fn parse_ide<'de, D>(d: D) -> std::result::Result -where - D: Deserializer<'de>, -{ - Deserialize::deserialize(d).map(|b: Option<_>| b.unwrap_or(Ide::Unknown)) -} - -#[derive(Debug, Deserialize, Serialize)] -#[serde(rename_all = "camelCase")] -struct AcceptedCompletion { - request_id: Uuid, - accepted_completion: u32, - shown_completions: Vec, -} - -#[derive(Debug, Deserialize, Serialize)] -#[serde(rename_all = "camelCase")] -struct RejectedCompletion { - request_id: Uuid, - shown_completions: Vec, -} - -#[derive(Debug, Deserialize, Serialize)] -#[serde(rename_all = "camelCase")] -pub struct CompletionParams { - #[serde(flatten)] - text_document_position: TextDocumentPositionParams, - #[serde(default)] - #[serde(deserialize_with = "parse_ide")] - ide: Ide, - fim: FimParams, - api_token: Option, - model: String, - backend: Backend, - tokens_to_clear: Vec, - tokenizer_config: Option, - context_window: usize, - tls_skip_verify_insecure: bool, - request_body: Map, -} - -#[derive(Debug, Deserialize, Serialize)] -struct CompletionResult { - request_id: Uuid, - completions: Vec, -} - fn build_prompt( pos: Position, text: &Rope, @@ -394,7 +301,7 @@ fn build_prompt( async fn request_completion( http_client: &reqwest::Client, prompt: String, - params: &CompletionParams, + params: &GetCompletionsParams, ) -> Result> { let t = Instant::now(); @@ -577,7 +484,10 @@ fn build_url(model: &str) -> String { } impl LlmService { - async fn get_completions(&self, params: CompletionParams) -> LspResult { + async fn get_completions( + &self, + params: GetCompletionsParams, + ) -> LspResult { let request_id = Uuid::new_v4(); let span = info_span!("completion_request", %request_id); async move { @@ -592,6 +502,7 @@ impl LlmService { cursor_character = ?params.text_document_position.position.character, language_id = %document.language_id, model = params.model, + backend = %params.backend, ide = %params.ide, request_body = 
serde_json::to_string(&params.request_body).map_err(internal_error)?, "received completion request for {}", params.text_document_position.text_document.uri ); @@ -611,7 +522,7 @@ impl LlmService { let completion_type = should_complete(document, params.text_document_position.position)?; info!(%completion_type, "completion type: {completion_type:?}"); if completion_type == CompletionType::Empty { - return Ok(CompletionResult { request_id, completions: vec![]}); + return Ok(GetCompletionsResult { request_id, completions: vec![]}); } let tokenizer = get_tokenizer( @@ -645,11 +556,11 @@ impl LlmService { .await?; let completions = format_generations(result, &params.tokens_to_clear, completion_type); - Ok(CompletionResult { request_id, completions }) + Ok(GetCompletionsResult { request_id, completions }) }.instrument(span).await } - async fn accept_completion(&self, accepted: AcceptedCompletion) -> LspResult<()> { + async fn accept_completion(&self, accepted: AcceptCompletionParams) -> LspResult<()> { info!( request_id = %accepted.request_id, accepted_position = accepted.accepted_completion, @@ -659,7 +570,7 @@ impl LlmService { Ok(()) } - async fn reject_completion(&self, rejected: RejectedCompletion) -> LspResult<()> { + async fn reject_completion(&self, rejected: RejectCompletionParams) -> LspResult<()> { info!( request_id = %rejected.request_id, shown_completions = serde_json::to_string(&rejected.shown_completions).map_err(internal_error)?, diff --git a/crates/lsp-client/Cargo.toml b/crates/lsp-client/Cargo.toml index 8dcd552..5cb1e60 100644 --- a/crates/lsp-client/Cargo.toml +++ b/crates/lsp-client/Cargo.toml @@ -13,4 +13,3 @@ serde = "1" serde_json = "1" tokio = { version = "1", features = ["io-util", "process"] } tracing = "0.1" - diff --git a/crates/lsp-client/src/client.rs b/crates/lsp-client/src/client.rs index 8fe3789..d6d6968 100644 --- a/crates/lsp-client/src/client.rs +++ b/crates/lsp-client/src/client.rs @@ -58,7 +58,7 @@ impl LspClient { pub async fn send_request<R: Request>( &self, params: R::Params, - ) -> Result<Response> { + ) -> Result<R::Result> { let (sender, receiver) = oneshot::channel::<Response>(); let request = self.res_queue .lock() .unwrap() .register(R::METHOD.to_string(), params, sender); self.send(request.into()); - Ok(receiver.await?) 
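For context, a hypothetical call site for the reworked send_request, using the typed request markers from custom-types/src/request.rs: the turbofish picks both the method string and the result type, so the caller receives a typed GetCompletionsResult rather than a raw JSON-RPC response. Module paths for the client and result types are assumed from the crate's file layout:

    use custom_types::llm_ls::GetCompletionsParams;
    use custom_types::request::GetCompletions;
    use lsp_client::client::LspClient;
    use lsp_client::error::Result;

    // `client` is assumed to be an already-initialized LspClient.
    async fn fetch_completions(client: &LspClient, params: GetCompletionsParams) -> Result<()> {
        let result = client.send_request::<GetCompletions>(params).await?;
        println!("received {} completions", result.completions.len());
        Ok(())
    }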
+ let (_, result) = receiver.await?.extract::()?; + Ok(result) } async fn complete_request( diff --git a/crates/lsp-client/src/error.rs b/crates/lsp-client/src/error.rs index 32f82fb..b176237 100644 --- a/crates/lsp-client/src/error.rs +++ b/crates/lsp-client/src/error.rs @@ -65,7 +65,9 @@ impl fmt::Display for ExtractError { pub enum Error { ChannelClosed(RecvError), Io(io::Error), + Extract(ExtractError), MissingBinaryPath, + Parse(String), } impl std::error::Error for Error {} @@ -76,6 +78,8 @@ impl fmt::Display for Error { Error::ChannelClosed(e) => write!(f, "Channel closed: {}", e), Error::Io(e) => write!(f, "IO error: {}", e), Error::MissingBinaryPath => write!(f, "Missing binary path"), + Error::Parse(e) => write!(f, "parse error: {}", e), + Error::Extract(e) => write!(f, "extract error: {}", e), } } } @@ -92,4 +96,10 @@ impl From for Error { } } +impl From for Error { + fn from(value: ExtractError) -> Self { + Self::Extract(value) + } +} + pub type Result = std::result::Result; diff --git a/crates/testbed/Cargo.toml b/crates/testbed/Cargo.toml index fdde7cc..05b7ea8 100644 --- a/crates/testbed/Cargo.toml +++ b/crates/testbed/Cargo.toml @@ -11,6 +11,7 @@ authors.workspace = true [dependencies] anyhow = "1" clap = { version = "4", features = ["derive"] } +custom-types = { path = "../custom-types" } futures = "0.3" futures-util = "0.3" home = "0.5" diff --git a/crates/testbed/holes/helix-smol.json b/crates/testbed/holes/helix-smol.json index 1a2b62b..261004c 100644 --- a/crates/testbed/holes/helix-smol.json +++ b/crates/testbed/holes/helix-smol.json @@ -1 +1 @@ -[{"cursor":{"line":73,"character":10},"file":"helix-core/src/chars.rs"},{"cursor":{"line":257,"character":11},"file":"helix-dap/src/types.rs"},{"cursor":{"line":39,"character":14},"file":"helix-view/src/info.rs"},{"cursor":{"line":116,"character":12},"file":"helix-term/src/ui/mod.rs"},{"cursor":{"line":1,"character":14},"file":"helix-term/src/ui/text.rs"},{"cursor":{"line":2,"character":5},"file":"helix-core/src/config.rs"},{"cursor":{"line":151,"character":14},"file":"helix-view/src/gutter.rs"},{"cursor":{"line":11,"character":10},"file":"helix-term/src/ui/lsp.rs"},{"cursor":{"line":18,"character":0},"file":"helix-term/src/ui/text.rs"},{"cursor":{"line":230,"character":3},"file":"helix-term/src/ui/markdown.rs"}] \ No newline at end of file +[{"cursor":{"line":9,"character":0},"file":"helix-core/src/increment/mod.rs"},{"cursor":{"line":47,"character":5},"file":"helix-stdx/src/env.rs"},{"cursor":{"line":444,"character":4},"file":"helix-term/src/ui/editor.rs"},{"cursor":{"line":939,"character":8},"file":"helix-tui/src/buffer.rs"},{"cursor":{"line":30,"character":6},"file":"helix-view/src/handlers.rs"},{"cursor":{"line":332,"character":0},"file":"helix-term/src/health.rs"},{"cursor":{"line":15,"character":2},"file":"helix-term/src/events.rs"},{"cursor":{"line":415,"character":2},"file":"helix-tui/src/widgets/reflow.rs"},{"cursor":{"line":316,"character":2},"file":"helix-core/src/shellwords.rs"},{"cursor":{"line":218,"character":2},"file":"helix-tui/src/backend/crossterm.rs"}] \ No newline at end of file diff --git a/crates/testbed/holes/helix.json b/crates/testbed/holes/helix.json index 3da8a85..2c144fa 100644 --- a/crates/testbed/holes/helix.json +++ b/crates/testbed/holes/helix.json @@ -1 +1 @@ 
-[{"cursor":{"line":330,"character":13},"file":"helix-core/src/position.rs"},{"cursor":{"line":21,"character":10},"file":"helix-term/src/lib.rs"},{"cursor":{"line":212,"character":13},"file":"helix-view/src/view.rs"},{"cursor":{"line":74,"character":8},"file":"helix-vcs/src/git.rs"},{"cursor":{"line":78,"character":7},"file":"helix-core/src/auto_pairs.rs"},{"cursor":{"line":61,"character":0},"file":"helix-term/src/ui/overlay.rs"},{"cursor":{"line":179,"character":5},"file":"helix-core/src/graphemes.rs"},{"cursor":{"line":82,"character":12},"file":"helix-tui/src/backend/test.rs"},{"cursor":{"line":486,"character":6},"file":"helix-term/src/ui/prompt.rs"},{"cursor":{"line":263,"character":8},"file":"helix-term/src/keymap/default.rs"},{"cursor":{"line":19,"character":4},"file":"helix-term/src/application.rs"},{"cursor":{"line":23,"character":5},"file":"helix-tui/src/backend/mod.rs"},{"cursor":{"line":54,"character":10},"file":"helix-term/src/ui/menu.rs"},{"cursor":{"line":9,"character":0},"file":"helix-core/src/fuzzy.rs"},{"cursor":{"line":22,"character":4},"file":"helix-view/src/info.rs"},{"cursor":{"line":58,"character":11},"file":"helix-vcs/src/lib.rs"},{"cursor":{"line":54,"character":7},"file":"helix-dap/src/client.rs"},{"cursor":{"line":177,"character":5},"file":"helix-view/src/register.rs"},{"cursor":{"line":54,"character":7},"file":"helix-core/src/increment/integer.rs"},{"cursor":{"line":53,"character":4},"file":"helix-core/src/increment/integer.rs"},{"cursor":{"line":2,"character":3},"file":"helix-core/src/lib.rs"},{"cursor":{"line":43,"character":1},"file":"helix-term/src/main.rs"},{"cursor":{"line":404,"character":13},"file":"helix-tui/src/widgets/block.rs"},{"cursor":{"line":405,"character":11},"file":"helix-dap/src/types.rs"},{"cursor":{"line":2,"character":0},"file":"helix-view/src/env.rs"},{"cursor":{"line":63,"character":3},"file":"helix-view/src/handlers/dap.rs"},{"cursor":{"line":77,"character":10},"file":"helix-tui/src/backend/crossterm.rs"},{"cursor":{"line":132,"character":13},"file":"helix-term/src/ui/markdown.rs"},{"cursor":{"line":190,"character":11},"file":"helix-tui/src/layout.rs"},{"cursor":{"line":62,"character":4},"file":"helix-core/src/auto_pairs.rs"},{"cursor":{"line":146,"character":14},"file":"helix-term/src/ui/prompt.rs"},{"cursor":{"line":280,"character":14},"file":"helix-core/src/shellwords.rs"},{"cursor":{"line":495,"character":14},"file":"helix-term/src/ui/prompt.rs"},{"cursor":{"line":274,"character":9},"file":"helix-lsp/src/transport.rs"},{"cursor":{"line":243,"character":10},"file":"helix-core/src/test.rs"},{"cursor":{"line":2,"character":10},"file":"helix-core/src/config.rs"},{"cursor":{"line":701,"character":1},"file":"helix-dap/src/types.rs"},{"cursor":{"line":67,"character":11},"file":"helix-view/src/lib.rs"},{"cursor":{"line":8,"character":4},"file":"helix-term/src/job.rs"},{"cursor":{"line":0,"character":4},"file":"helix-core/src/wrap.rs"},{"cursor":{"line":27,"character":12},"file":"helix-vcs/src/lib.rs"},{"cursor":{"line":1270,"character":0},"file":"helix-term/src/commands/lsp.rs"},{"cursor":{"line":109,"character":5},"file":"helix-tui/src/widgets/list.rs"},{"cursor":{"line":198,"character":3},"file":"helix-core/src/increment/integer.rs"},{"cursor":{"line":84,"character":10},"file":"helix-term/src/commands.rs"},{"cursor":{"line":102,"character":2},"file":"helix-view/src/base64.rs"},{"cursor":{"line":57,"character":6},"file":"helix-dap/src/types.rs"},{"cursor":{"line":7,"character":9},"file":"helix-view/src/gutter.rs"},{"cursor":{"line":99,"charac
ter":10},"file":"helix-term/src/keymap.rs"},{"cursor":{"line":317,"character":1},"file":"helix-core/src/increment/date_time.rs"},{"cursor":{"line":303,"character":13},"file":"helix-term/src/health.rs"},{"cursor":{"line":69,"character":0},"file":"helix-core/src/doc_formatter/test.rs"},{"cursor":{"line":79,"character":4},"file":"helix-tui/src/symbols.rs"},{"cursor":{"line":156,"character":14},"file":"helix-core/src/increment/date_time.rs"},{"cursor":{"line":3760,"character":14},"file":"helix-term/src/commands.rs"},{"cursor":{"line":256,"character":5},"file":"helix-view/src/theme.rs"},{"cursor":{"line":231,"character":11},"file":"helix-tui/src/widgets/list.rs"},{"cursor":{"line":27,"character":4},"file":"helix-core/src/search.rs"},{"cursor":{"line":293,"character":11},"file":"helix-core/src/auto_pairs.rs"},{"cursor":{"line":52,"character":7},"file":"helix-core/src/line_ending.rs"},{"cursor":{"line":144,"character":8},"file":"helix-core/src/comment.rs"},{"cursor":{"line":20,"character":3},"file":"helix-view/src/info.rs"},{"cursor":{"line":131,"character":12},"file":"helix-view/src/base64.rs"},{"cursor":{"line":8,"character":8},"file":"helix-view/src/lib.rs"},{"cursor":{"line":489,"character":14},"file":"helix-term/src/ui/mod.rs"},{"cursor":{"line":23,"character":7},"file":"helix-view/src/info.rs"},{"cursor":{"line":41,"character":4},"file":"helix-term/src/ui/overlay.rs"},{"cursor":{"line":48,"character":13},"file":"helix-core/src/diagnostic.rs"},{"cursor":{"line":6,"character":7},"file":"helix-core/src/lib.rs"},{"cursor":{"line":845,"character":5},"file":"helix-view/src/view.rs"},{"cursor":{"line":258,"character":2},"file":"helix-core/src/doc_formatter.rs"},{"cursor":{"line":5,"character":6},"file":"helix-tui/src/buffer.rs"},{"cursor":{"line":61,"character":4},"file":"helix-view/src/lib.rs"},{"cursor":{"line":157,"character":6},"file":"helix-term/src/ui/prompt.rs"},{"cursor":{"line":92,"character":3},"file":"helix-term/src/ui/lsp.rs"},{"cursor":{"line":128,"character":2},"file":"helix-term/src/ui/menu.rs"},{"cursor":{"line":1701,"character":11},"file":"helix-core/src/movement.rs"},{"cursor":{"line":1,"character":9},"file":"helix-view/src/env.rs"},{"cursor":{"line":330,"character":3},"file":"helix-view/src/keyboard.rs"},{"cursor":{"line":10,"character":0},"file":"helix-view/src/lib.rs"},{"cursor":{"line":625,"character":4},"file":"helix-dap/src/types.rs"},{"cursor":{"line":81,"character":11},"file":"helix-core/src/syntax.rs"},{"cursor":{"line":2268,"character":1},"file":"helix-term/src/commands/typed.rs"},{"cursor":{"line":21,"character":1},"file":"helix-core/src/fuzzy.rs"},{"cursor":{"line":57,"character":11},"file":"helix-term/src/ui/document.rs"},{"cursor":{"line":460,"character":7},"file":"helix-tui/src/text.rs"},{"cursor":{"line":7,"character":2},"file":"helix-term/src/lib.rs"},{"cursor":{"line":42,"character":12},"file":"helix-term/src/ui/text.rs"},{"cursor":{"line":39,"character":8},"file":"helix-term/src/ui/spinner.rs"},{"cursor":{"line":10,"character":2},"file":"helix-term/src/lib.rs"},{"cursor":{"line":57,"character":9},"file":"helix-vcs/src/git.rs"},{"cursor":{"line":15,"character":1},"file":"helix-tui/src/backend/test.rs"},{"cursor":{"line":109,"character":3},"file":"helix-dap/src/transport.rs"},{"cursor":{"line":119,"character":4},"file":"helix-core/src/history.rs"},{"cursor":{"line":18,"character":2},"file":"helix-core/src/path.rs"},{"cursor":{"line":13,"character":11},"file":"helix-lsp/src/snippet.rs"},{"cursor":{"line":0,"character":12},"file":"helix-loader/src/main.rs"},{"cursor
":{"line":3,"character":6},"file":"helix-term/src/lib.rs"},{"cursor":{"line":49,"character":8},"file":"helix-term/src/job.rs"},{"cursor":{"line":176,"character":7},"file":"helix-parsec/src/lib.rs"}] \ No newline at end of file +[{"cursor":{"line":234,"character":14},"file":"helix-core/src/history.rs"},{"cursor":{"line":209,"character":1},"file":"helix-term/src/ui/menu.rs"},{"cursor":{"line":465,"character":0},"file":"helix-lsp/src/client.rs"},{"cursor":{"line":95,"character":2},"file":"helix-term/src/keymap/macros.rs"},{"cursor":{"line":14,"character":10},"file":"helix-term/src/ui/markdown.rs"},{"cursor":{"line":6,"character":5},"file":"helix-core/src/rope_reader.rs"},{"cursor":{"line":150,"character":7},"file":"helix-term/src/ui/document.rs"},{"cursor":{"line":48,"character":5},"file":"helix-view/src/macros.rs"},{"cursor":{"line":1582,"character":11},"file":"helix-term/src/commands.rs"},{"cursor":{"line":6,"character":8},"file":"helix-core/src/comment.rs"},{"cursor":{"line":5,"character":4},"file":"helix-core/src/rope_reader.rs"},{"cursor":{"line":88,"character":3},"file":"helix-tui/src/terminal.rs"},{"cursor":{"line":15,"character":14},"file":"helix-term/src/events.rs"},{"cursor":{"line":366,"character":3},"file":"helix-view/src/clipboard.rs"},{"cursor":{"line":7,"character":7},"file":"helix-view/src/events.rs"},{"cursor":{"line":152,"character":8},"file":"helix-dap/src/client.rs"},{"cursor":{"line":290,"character":3},"file":"helix-core/src/match_brackets.rs"},{"cursor":{"line":119,"character":6},"file":"helix-view/src/register.rs"},{"cursor":{"line":113,"character":2},"file":"helix-term/src/keymap/macros.rs"},{"cursor":{"line":34,"character":7},"file":"helix-core/src/rope_reader.rs"},{"cursor":{"line":37,"character":13},"file":"helix-core/src/object.rs"},{"cursor":{"line":209,"character":9},"file":"helix-term/src/keymap.rs"},{"cursor":{"line":4,"character":11},"file":"helix-term/build.rs"},{"cursor":{"line":31,"character":7},"file":"helix-core/src/rope_reader.rs"},{"cursor":{"line":1,"character":7},"file":"helix-stdx/src/lib.rs"},{"cursor":{"line":324,"character":0},"file":"helix-loader/src/lib.rs"},{"cursor":{"line":231,"character":13},"file":"helix-term/src/ui/mod.rs"},{"cursor":{"line":134,"character":7},"file":"helix-tui/src/terminal.rs"},{"cursor":{"line":121,"character":0},"file":"helix-tui/src/widgets/paragraph.rs"},{"cursor":{"line":276,"character":10},"file":"helix-core/src/match_brackets.rs"},{"cursor":{"line":565,"character":0},"file":"helix-core/src/selection.rs"},{"cursor":{"line":579,"character":8},"file":"helix-view/src/input.rs"},{"cursor":{"line":63,"character":5},"file":"helix-term/src/ui/text.rs"},{"cursor":{"line":6,"character":8},"file":"helix-term/src/ui/info.rs"},{"cursor":{"line":214,"character":9},"file":"helix-tui/src/widgets/block.rs"},{"cursor":{"line":270,"character":3},"file":"helix-term/src/ui/menu.rs"},{"cursor":{"line":130,"character":7},"file":"helix-term/src/ui/editor.rs"},{"cursor":{"line":8,"character":12},"file":"helix-core/src/search.rs"},{"cursor":{"line":404,"character":10},"file":"helix-tui/src/widgets/reflow.rs"},{"cursor":{"line":6,"character":2},"file":"helix-core/src/object.rs"},{"cursor":{"line":91,"character":14},"file":"helix-term/src/args.rs"},{"cursor":{"line":195,"character":14},"file":"helix-term/src/ui/prompt.rs"},{"cursor":{"line":118,"character":12},"file":"helix-vcs/src/git.rs"},{"cursor":{"line":906,"character":13},"file":"helix-lsp/src/client.rs"},{"cursor":{"line":7,"character":5},"file":"helix-view/src/events.rs"},{"cursor":{"l
ine":239,"character":6},"file":"helix-core/src/shellwords.rs"},{"cursor":{"line":52,"character":0},"file":"helix-event/src/redraw.rs"},{"cursor":{"line":29,"character":12},"file":"helix-loader/src/grammar.rs"},{"cursor":{"line":603,"character":2},"file":"helix-core/src/selection.rs"},{"cursor":{"line":252,"character":10},"file":"helix-tui/src/layout.rs"},{"cursor":{"line":27,"character":13},"file":"helix-view/src/info.rs"},{"cursor":{"line":733,"character":11},"file":"helix-term/src/ui/picker.rs"},{"cursor":{"line":29,"character":0},"file":"helix-term/src/handlers.rs"},{"cursor":{"line":56,"character":3},"file":"helix-loader/src/grammar.rs"},{"cursor":{"line":249,"character":0},"file":"helix-term/src/ui/markdown.rs"},{"cursor":{"line":267,"character":4},"file":"helix-view/src/register.rs"},{"cursor":{"line":3143,"character":2},"file":"helix-term/src/commands/typed.rs"},{"cursor":{"line":73,"character":2},"file":"helix-core/src/increment/date_time.rs"},{"cursor":{"line":181,"character":0},"file":"helix-dap/src/types.rs"},{"cursor":{"line":28,"character":6},"file":"helix-term/src/ui/popup.rs"},{"cursor":{"line":15,"character":7},"file":"helix-event/src/cancel.rs"},{"cursor":{"line":40,"character":7},"file":"helix-event/src/lib.rs"},{"cursor":{"line":132,"character":2},"file":"helix-tui/src/widgets/table.rs"},{"cursor":{"line":25,"character":0},"file":"helix-view/src/info.rs"},{"cursor":{"line":141,"character":7},"file":"helix-tui/src/backend/crossterm.rs"},{"cursor":{"line":251,"character":1},"file":"helix-dap/src/client.rs"},{"cursor":{"line":327,"character":8},"file":"helix-term/src/handlers/completion.rs"},{"cursor":{"line":230,"character":7},"file":"helix-view/src/input.rs"},{"cursor":{"line":68,"character":0},"file":"helix-view/src/gutter.rs"},{"cursor":{"line":23,"character":11},"file":"helix-stdx/src/rope.rs"},{"cursor":{"line":639,"character":10},"file":"helix-view/src/tree.rs"},{"cursor":{"line":104,"character":3},"file":"helix-lsp/src/transport.rs"},{"cursor":{"line":117,"character":14},"file":"helix-core/src/doc_formatter/test.rs"},{"cursor":{"line":129,"character":4},"file":"helix-core/src/doc_formatter/test.rs"},{"cursor":{"line":22,"character":1},"file":"helix-tui/src/backend/mod.rs"},{"cursor":{"line":297,"character":4},"file":"helix-loader/src/lib.rs"},{"cursor":{"line":1884,"character":13},"file":"helix-view/src/document.rs"},{"cursor":{"line":62,"character":11},"file":"helix-vcs/src/diff/worker.rs"},{"cursor":{"line":42,"character":4},"file":"helix-event/src/debounce.rs"},{"cursor":{"line":371,"character":4},"file":"helix-core/src/position.rs"},{"cursor":{"line":248,"character":2},"file":"helix-core/src/line_ending.rs"},{"cursor":{"line":1,"character":10},"file":"helix-core/src/diagnostic.rs"},{"cursor":{"line":10,"character":5},"file":"helix-term/src/events.rs"},{"cursor":{"line":885,"character":13},"file":"helix-term/src/commands.rs"},{"cursor":{"line":297,"character":8},"file":"helix-view/src/gutter.rs"},{"cursor":{"line":152,"character":12},"file":"helix-core/src/diff.rs"},{"cursor":{"line":228,"character":4},"file":"helix-tui/src/widgets/table.rs"},{"cursor":{"line":25,"character":4},"file":"helix-view/src/lib.rs"},{"cursor":{"line":8,"character":0},"file":"helix-loader/src/main.rs"},{"cursor":{"line":80,"character":1},"file":"helix-event/src/hook.rs"},{"cursor":{"line":284,"character":13},"file":"helix-term/src/handlers/signature_help.rs"},{"cursor":{"line":7,"character":8},"file":"helix-view/src/graphics.rs"},{"cursor":{"line":214,"character":1},"file":"helix-core/src
/increment/integer.rs"},{"cursor":{"line":19,"character":3},"file":"helix-loader/src/config.rs"},{"cursor":{"line":57,"character":0},"file":"helix-lsp/src/file_operations.rs"},{"cursor":{"line":16,"character":7},"file":"helix-loader/src/config.rs"},{"cursor":{"line":0,"character":6},"file":"helix-core/src/increment/mod.rs"},{"cursor":{"line":41,"character":2},"file":"helix-tui/src/widgets/mod.rs"},{"cursor":{"line":50,"character":12},"file":"helix-term/src/ui/overlay.rs"},{"cursor":{"line":173,"character":14},"file":"helix-tui/src/widgets/list.rs"}] \ No newline at end of file diff --git a/crates/testbed/repositories-ci.yaml b/crates/testbed/repositories-ci.yaml index 0c1bd21..3d49f1a 100644 --- a/crates/testbed/repositories-ci.yaml +++ b/crates/testbed/repositories-ci.yaml @@ -6,11 +6,12 @@ fim: middle: suffix: model: bigcode/starcoder -request_params: - maxNewTokens: 150 +backend: huggingface +request_body: + max_new_tokens: 150 temperature: 0.2 - doSample: true - topP: 0.95 + do_sample: true + top_p: 0.95 tls_skip_verify_insecure: false tokenizer_config: repository: bigcode/starcoder @@ -202,7 +203,7 @@ repositories: type: github owner: helix-editor name: helix - revision: ae6a0a9cfd377fbfa494760282498cf2ca322782 + revision: a1272bdb17a63361342a318982e46129d558743c exclude_paths: - .cargo - .github diff --git a/crates/testbed/repositories.yaml b/crates/testbed/repositories.yaml index 4418993..1ec7dfb 100644 --- a/crates/testbed/repositories.yaml +++ b/crates/testbed/repositories.yaml @@ -6,11 +6,12 @@ fim: middle: suffix: model: bigcode/starcoder -request_params: - maxNewTokens: 150 +backend: huggingface +request_body: + max_new_tokens: 150 temperature: 0.2 - doSample: true - topP: 0.95 + do_sample: true + top_p: 0.95 tls_skip_verify_insecure: false tokenizer_config: repository: bigcode/starcoder @@ -202,7 +203,7 @@ repositories: type: github owner: helix-editor name: helix - revision: ae6a0a9cfd377fbfa494760282498cf2ca322782 + revision: a1272bdb17a63361342a318982e46129d558743c exclude_paths: - .cargo - .github diff --git a/crates/testbed/src/main.rs b/crates/testbed/src/main.rs index 2169b80..fa7a543 100644 --- a/crates/testbed/src/main.rs +++ b/crates/testbed/src/main.rs @@ -10,9 +10,13 @@ use std::{ use anyhow::anyhow; use clap::Parser; +use custom_types::{ + llm_ls::{Backend, FimParams, GetCompletionsParams, Ide, TokenizerConfig}, + request::GetCompletions, +}; use futures_util::{stream::FuturesUnordered, StreamExt, TryStreamExt}; use lang::Language; -use lsp_client::{client::LspClient, error::ExtractError, msg::RequestId, server::Server}; +use lsp_client::{client::LspClient, error::ExtractError, server::Server}; use lsp_types::{ DidOpenTextDocumentParams, InitializeParams, TextDocumentIdentifier, TextDocumentItem, TextDocumentPositionParams, @@ -20,6 +24,7 @@ use lsp_types::{ use ropey::Rope; use runner::Runner; use serde::{Deserialize, Serialize}; +use serde_json::{Map, Value}; use tempfile::TempDir; use tokio::{ fs::{self, read_to_string, File, OpenOptions}, @@ -32,19 +37,11 @@ use tracing::{debug, error, info, info_span, warn, Instrument}; use tracing_subscriber::EnvFilter; use url::Url; -use crate::{ - holes_generator::generate_holes, - runner::run_test, - types::{ - FimParams, GetCompletions, GetCompletionsParams, GetCompletionsResult, Ide, RequestParams, - TokenizerConfig, - }, -}; +use crate::{holes_generator::generate_holes, runner::run_test}; mod holes_generator; mod lang; mod runner; -mod types; /// Testbed runs llm-ls' code completion to measure its performance 
diff --git a/crates/testbed/src/main.rs b/crates/testbed/src/main.rs
index 2169b80..fa7a543 100644
--- a/crates/testbed/src/main.rs
+++ b/crates/testbed/src/main.rs
@@ -10,9 +10,13 @@ use std::{
 
 use anyhow::anyhow;
 use clap::Parser;
+use custom_types::{
+    llm_ls::{Backend, FimParams, GetCompletionsParams, Ide, TokenizerConfig},
+    request::GetCompletions,
+};
 use futures_util::{stream::FuturesUnordered, StreamExt, TryStreamExt};
 use lang::Language;
-use lsp_client::{client::LspClient, error::ExtractError, msg::RequestId, server::Server};
+use lsp_client::{client::LspClient, error::ExtractError, server::Server};
 use lsp_types::{
     DidOpenTextDocumentParams, InitializeParams, TextDocumentIdentifier, TextDocumentItem,
     TextDocumentPositionParams,
@@ -20,6 +24,7 @@ use lsp_types::{
 use ropey::Rope;
 use runner::Runner;
 use serde::{Deserialize, Serialize};
+use serde_json::{Map, Value};
 use tempfile::TempDir;
 use tokio::{
     fs::{self, read_to_string, File, OpenOptions},
@@ -32,19 +37,11 @@ use tracing::{debug, error, info, info_span, warn, Instrument};
 use tracing_subscriber::EnvFilter;
 use url::Url;
 
-use crate::{
-    holes_generator::generate_holes,
-    runner::run_test,
-    types::{
-        FimParams, GetCompletions, GetCompletionsParams, GetCompletionsResult, Ide, RequestParams,
-        TokenizerConfig,
-    },
-};
+use crate::{holes_generator::generate_holes, runner::run_test};
 
 mod holes_generator;
 mod lang;
 mod runner;
-mod types;
 
 /// Testbed runs llm-ls' code completion to measure its performance
 #[derive(Parser, Debug)]
@@ -201,11 +198,12 @@ struct RepositoriesConfig {
     context_window: usize,
     fim: FimParams,
     model: String,
-    request_params: RequestParams,
+    backend: Backend,
     repositories: Vec<Repository>,
     tls_skip_verify_insecure: bool,
     tokenizer_config: Option<TokenizerConfig>,
     tokens_to_clear: Vec<String>,
+    request_body: Map<String, Value>,
 }
 
 struct HoleCompletionResult {
@@ -463,10 +461,11 @@ async fn complete_holes(
         context_window,
         fim,
         model,
-        request_params,
+        backend,
         tls_skip_verify_insecure,
         tokenizer_config,
         tokens_to_clear,
+        request_body,
         ..
     } = repos_config;
     async move {
@@ -516,14 +515,14 @@
                 },
             },
         );
-            let response = client
+            let result = client
                 .send_request::<GetCompletions>(GetCompletionsParams {
                     api_token: api_token.clone(),
                     context_window,
                     fim: fim.clone(),
                     ide: Ide::default(),
                     model: model.clone(),
-                    request_params: request_params.clone(),
+                    backend,
                     text_document_position: TextDocumentPositionParams {
                         position: hole.cursor,
                         text_document: TextDocumentIdentifier { uri },
@@ -531,9 +530,9 @@
                     tls_skip_verify_insecure,
                     tokens_to_clear: tokens_to_clear.clone(),
                     tokenizer_config: tokenizer_config.clone(),
+                    request_body: request_body.clone(),
                 })
                 .await?;
-            let (_, result): (RequestId, GetCompletionsResult) = response.extract()?;
 
             file_content.insert(hole_start, &result.completions[0].generated_text);
             let mut file = OpenOptions::new()
diff --git a/crates/testbed/src/types.rs b/crates/testbed/src/types.rs
deleted file mode 100644
index c75da6d..0000000
--- a/crates/testbed/src/types.rs
+++ /dev/null
@@ -1,90 +0,0 @@
-use std::path::PathBuf;
-
-use lsp_types::{request::Request, TextDocumentPositionParams};
-use serde::{Deserialize, Deserializer, Serialize};
-use uuid::Uuid;
-
-#[derive(Debug)]
-pub(crate) enum GetCompletions {}
-
-impl Request for GetCompletions {
-    type Params = GetCompletionsParams;
-    type Result = GetCompletionsResult;
-    const METHOD: &'static str = "llm-ls/getCompletions";
-}
-
-#[derive(Clone, Debug, Deserialize, Serialize)]
-#[serde(rename_all = "camelCase")]
-pub(crate) struct RequestParams {
-    pub(crate) max_new_tokens: u32,
-    pub(crate) temperature: f32,
-    pub(crate) do_sample: bool,
-    pub(crate) top_p: f32,
-    pub(crate) stop_tokens: Option<Vec<String>>,
-}
-
-#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize)]
-#[serde(rename_all = "lowercase")]
-pub(crate) enum Ide {
-    Neovim,
-    VSCode,
-    JetBrains,
-    Emacs,
-    Jupyter,
-    Sublime,
-    VisualStudio,
-    #[default]
-    Unknown,
-}
-
-fn parse_ide<'de, D>(d: D) -> std::result::Result<Ide, D::Error>
-where
-    D: Deserializer<'de>,
-{
-    Deserialize::deserialize(d).map(|b: Option<_>| b.unwrap_or(Ide::Unknown))
-}
-
-#[derive(Clone, Debug, Deserialize, Serialize)]
-pub(crate) struct FimParams {
-    pub(crate) enabled: bool,
-    pub(crate) prefix: String,
-    pub(crate) middle: String,
-    pub(crate) suffix: String,
-}
-
-#[derive(Clone, Debug, Deserialize, Serialize)]
-#[serde(untagged)]
-pub(crate) enum TokenizerConfig {
-    Local { path: PathBuf },
-    HuggingFace { repository: String },
-    Download { url: String, to: PathBuf },
-}
-
-#[derive(Clone, Debug, Deserialize, Serialize)]
-#[serde(rename_all = "camelCase")]
-pub(crate) struct GetCompletionsParams {
-    #[serde(flatten)]
-    pub(crate) text_document_position: TextDocumentPositionParams,
-    pub(crate) request_params: RequestParams,
-    #[serde(default)]
-    #[serde(deserialize_with = "parse_ide")]
-    pub(crate) ide: Ide,
-    pub(crate) fim: FimParams,
-    pub(crate) api_token: Option<String>,
-    pub(crate) model: String,
-    pub(crate) tokens_to_clear: Vec<String>,
-    pub(crate) tokenizer_config: Option<TokenizerConfig>,
-    pub(crate) context_window: usize,
-    pub(crate) tls_skip_verify_insecure: bool,
-}
-
-#[derive(Clone, Debug, Deserialize, Serialize)]
-pub(crate) struct Completion {
-    pub(crate) generated_text: String,
-}
-
-#[derive(Clone, Debug, Deserialize, Serialize)]
-pub(crate) struct GetCompletionsResult {
-    request_id: Uuid,
-    pub(crate) completions: Vec<Completion>,
-}
From 4891468c1ab785658ffc881405e3203d27a20b3a Mon Sep 17 00:00:00 2001
From: Luc Georges
Date: Fri, 9 Feb 2024 18:42:41 +0100
Subject: [PATCH 05/10] feat: update backend & model parameter (#74)

* feat: update backend & model parameter

* fix: add `stream: false` in request body for ollama & openai
---
 README.md                         |  8 ++++--
 crates/custom-types/src/llm_ls.rs | 47 ++++++++++++++++++++++++-------
 crates/llm-ls/src/backend.rs      | 40 +++++++++++++++-----------
 crates/llm-ls/src/main.rs         | 28 ++++++++++--------
 4 files changed, 81 insertions(+), 42 deletions(-)

diff --git a/README.md b/README.md
index 507b50f..9ac5857 100644
--- a/README.md
+++ b/README.md
@@ -19,12 +19,16 @@ It also makes sure that you are within the context window of the model by tokeni
 
 Gathers information about requests and completions that can enable retraining.
 
-Note that **llm-ls** does not export any data anywhere (other than setting a user agent when querying the model API), everything is stored in a log file if you set the log level to `info`.
+Note that **llm-ls** does not export any data anywhere (other than setting a user agent when querying the model API), everything is stored in a log file (`~/.cache/llm_ls/llm-ls.log`) if you set the log level to `info`.
 
 ### Completion
 
 **llm-ls** parses the AST of the code to determine if completions should be multi line, single line or empty (no completion).
 
+### Multiple backends
+
+**llm-ls** is compatible with Hugging Face's [Inference API](https://huggingface.co/docs/api-inference/en/index), Hugging Face's [text-generation-inference](https://github.com/huggingface/text-generation-inference), [ollama](https://github.com/ollama/ollama) and OpenAI compatible APIs, like [llama.cpp](https://github.com/ggerganov/llama.cpp/tree/master/examples/server).
+
 ## Compatible extensions
 
 - [x] [llm.nvim](https://github.com/huggingface/llm.nvim)
@@ -38,6 +42,4 @@ Note that **llm-ls** does not export any data anywhere (other than setting a use
 - add `suffix_percent` setting that determines the ratio of # of tokens for the prefix vs the suffix in the prompt
 - add context window fill percent or change context_window to `max_tokens`
 - filter bad suggestions (repetitive, same as below, etc)
-- support for ollama
-- support for llama.cpp
 - oltp traces ?
diff --git a/crates/custom-types/src/llm_ls.rs b/crates/custom-types/src/llm_ls.rs
index 10edc7c..c1a81d2 100644
--- a/crates/custom-types/src/llm_ls.rs
+++ b/crates/custom-types/src/llm_ls.rs
@@ -5,6 +5,8 @@ use serde::{Deserialize, Deserializer, Serialize};
 use serde_json::{Map, Value};
 use uuid::Uuid;
 
+const HF_INFERENCE_API_HOSTNAME: &str = "api-inference.huggingface.co";
+
 #[derive(Debug, Deserialize, Serialize)]
 #[serde(rename_all = "camelCase")]
 pub struct AcceptCompletionParams {
@@ -47,19 +49,42 @@ where
     Deserialize::deserialize(d).map(|b: Option<_>| b.unwrap_or(Ide::Unknown))
 }
 
-#[derive(Clone, Debug, Default, Deserialize, Serialize)]
-#[serde(rename_all = "lowercase")]
+fn hf_default_url() -> String {
+    format!("https://{HF_INFERENCE_API_HOSTNAME}")
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize)]
+#[serde(rename_all = "lowercase", tag = "backend")]
 pub enum Backend {
-    #[default]
-    HuggingFace,
-    Ollama,
-    OpenAi,
-    Tgi,
+    HuggingFace {
+        #[serde(default = "hf_default_url")]
+        url: String,
+    },
+    Ollama {
+        url: String,
+    },
+    OpenAi {
+        url: String,
+    },
+    Tgi {
+        url: String,
+    },
 }
 
-impl Display for Backend {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        self.serialize(f)
+impl Default for Backend {
+    fn default() -> Self {
+        Self::HuggingFace {
+            url: hf_default_url(),
+        }
+    }
+}
+
+impl Backend {
+    pub fn is_using_inference_api(&self) -> bool {
+        match self {
+            Self::HuggingFace { url } => url.contains(HF_INFERENCE_API_HOSTNAME),
+            _ => false,
+        }
     }
 }
 
@@ -98,11 +123,13 @@ pub struct GetCompletionsParams {
     pub fim: FimParams,
     pub api_token: Option<String>,
     pub model: String,
+    #[serde(flatten)]
     pub backend: Backend,
     pub tokens_to_clear: Vec<String>,
     pub tokenizer_config: Option<TokenizerConfig>,
     pub context_window: usize,
     pub tls_skip_verify_insecure: bool,
+    #[serde(default)]
     pub request_body: Map<String, Value>,
 }
diff --git a/crates/llm-ls/src/backend.rs b/crates/llm-ls/src/backend.rs
index 76a5d7c..357a3d5 100644
--- a/crates/llm-ls/src/backend.rs
+++ b/crates/llm-ls/src/backend.rs
@@ -1,5 +1,5 @@
 use super::{APIError, APIResponse, Generation, NAME, VERSION};
-use custom_types::llm_ls::{Backend, GetCompletionsParams, Ide};
+use custom_types::llm_ls::{Backend, Ide};
 use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
 use serde::{Deserialize, Serialize};
 use serde_json::{Map, Value};
@@ -151,33 +151,39 @@ fn parse_openai_text(text: &str) -> Result<Vec<Generation>> {
     }
 }
 
-pub fn build_body(prompt: String, params: &GetCompletionsParams) -> Map<String, Value> {
-    let mut body = params.request_body.clone();
-    match params.backend {
-        Backend::HuggingFace | Backend::Tgi => {
-            body.insert("inputs".to_string(), Value::String(prompt))
+pub fn build_body(
+    backend: &Backend,
+    model: String,
+    prompt: String,
+    mut request_body: Map<String, Value>,
+) -> Map<String, Value> {
+    match backend {
+        Backend::HuggingFace { .. } | Backend::Tgi { .. } => {
+            request_body.insert("inputs".to_owned(), Value::String(prompt));
         }
-        Backend::Ollama | Backend::OpenAi => {
-            body.insert("prompt".to_string(), Value::String(prompt))
+        Backend::Ollama { .. } | Backend::OpenAi { .. } => {
+            request_body.insert("prompt".to_owned(), Value::String(prompt));
+            request_body.insert("model".to_owned(), Value::String(model));
+            request_body.insert("stream".to_owned(), Value::Bool(false));
+        }
     };
-    body
+    request_body
 }
 
 pub fn build_headers(backend: &Backend, api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
     match backend {
-        Backend::HuggingFace => build_api_headers(api_token, ide),
-        Backend::Ollama => Ok(build_ollama_headers()),
-        Backend::OpenAi => build_openai_headers(api_token, ide),
-        Backend::Tgi => build_tgi_headers(api_token, ide),
+        Backend::HuggingFace { .. } => build_api_headers(api_token, ide),
+        Backend::Ollama { .. } => Ok(build_ollama_headers()),
+        Backend::OpenAi { .. } => build_openai_headers(api_token, ide),
+        Backend::Tgi { .. } => build_tgi_headers(api_token, ide),
     }
 }
 
 pub fn parse_generations(backend: &Backend, text: &str) -> Result<Vec<Generation>> {
     match backend {
-        Backend::HuggingFace => parse_api_text(text),
-        Backend::Ollama => parse_ollama_text(text),
-        Backend::OpenAi => parse_openai_text(text),
-        Backend::Tgi => parse_tgi_text(text),
+        Backend::HuggingFace { .. } => parse_api_text(text),
+        Backend::Ollama { .. } => parse_ollama_text(text),
+        Backend::OpenAi { .. } => parse_openai_text(text),
+        Backend::Tgi { .. } => parse_tgi_text(text),
     }
 }
diff --git a/crates/llm-ls/src/main.rs b/crates/llm-ls/src/main.rs
index 27441ad..17be8aa 100644
--- a/crates/llm-ls/src/main.rs
+++ b/crates/llm-ls/src/main.rs
@@ -34,7 +34,6 @@ mod language_id;
 const MAX_WARNING_REPEAT: Duration = Duration::from_secs(3_600);
 pub const NAME: &str = "llm-ls";
 pub const VERSION: &str = env!("CARGO_PKG_VERSION");
-const HF_INFERENCE_API_HOSTNAME: &str = "api-inference.huggingface.co";
 
 fn get_position_idx(rope: &Rope, row: usize, col: usize) -> Result<usize> {
     Ok(rope.try_line_to_char(row)?
@@ -305,10 +304,15 @@ async fn request_completion(
 ) -> Result<Vec<Generation>> {
     let t = Instant::now();
 
-    let json = build_body(prompt, params);
+    let json = build_body(
+        &params.backend,
+        params.model.clone(),
+        prompt,
+        params.request_body.clone(),
+    );
     let headers = build_headers(&params.backend, params.api_token.as_ref(), params.ide)?;
     let res = http_client
-        .post(build_url(&params.model))
+        .post(build_url(params.backend.clone(), &params.model))
         .json(&json)
         .headers(headers)
         .send()
@@ -367,7 +371,7 @@ async fn download_tokenizer_file(
         return Ok(());
     }
     tokio::fs::create_dir_all(to.as_ref().parent().ok_or(Error::InvalidTokenizerPath)?).await?;
-    let headers = build_headers(&Backend::HuggingFace, api_token, ide)?;
+    let headers = build_headers(&Backend::default(), api_token, ide)?;
     let mut file = tokio::fs::OpenOptions::new()
         .write(true)
         .create(true)
@@ -475,11 +479,12 @@ async fn get_tokenizer(
         }
     }
 }
 
-fn build_url(model: &str) -> String {
-    if model.starts_with("http://") || model.starts_with("https://") {
-        model.to_owned()
-    } else {
-        format!("https://{HF_INFERENCE_API_HOSTNAME}/models/{model}")
+fn build_url(backend: Backend, model: &str) -> String {
+    match backend {
+        Backend::HuggingFace { url } => format!("{url}/models/{model}"),
+        Backend::Ollama { url } => url,
+        Backend::OpenAi { url } => url,
+        Backend::Tgi { url } => url,
     }
 }
@@ -502,14 +507,13 @@ impl LlmService {
             cursor_character = ?params.text_document_position.position.character,
             language_id = %document.language_id,
             model = params.model,
-            backend = %params.backend,
+            backend = ?params.backend,
             ide = %params.ide,
             request_body = serde_json::to_string(&params.request_body).map_err(internal_error)?,
             "received completion request for {}",
             params.text_document_position.text_document.uri
         );
-        let is_using_inference_api = matches!(params.backend, Backend::HuggingFace);
-        if params.api_token.is_none() && is_using_inference_api {
+        if params.api_token.is_none() && params.backend.is_using_inference_api() {
            let now = Instant::now();
            let unauthenticated_warn_at = self.unauthenticated_warn_at.read().await;
            if now.duration_since(*unauthenticated_warn_at) > MAX_WARNING_REPEAT {
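With this patch applied, the backend choice is part of the request itself: `Backend` is internally tagged with `backend` and flattened into `GetCompletionsParams`, so clients pick a backend with a plain string field next to the other completion parameters. The following self-contained sketch re-declares the enum for illustration only; the URLs are examples, not defaults mandated by the patch.

```rust
use serde::Deserialize;

// Re-declaration (for illustration) of the tagged enum this patch introduces
// in crates/custom-types/src/llm_ls.rs.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "lowercase", tag = "backend")]
enum Backend {
    HuggingFace {
        #[serde(default = "hf_default_url")]
        url: String,
    },
    Ollama { url: String },
    OpenAi { url: String },
    Tgi { url: String },
}

fn hf_default_url() -> String {
    "https://api-inference.huggingface.co".to_owned()
}

fn main() {
    // The `backend` tag selects the variant; `url` fills the field.
    let b: Backend = serde_json::from_str(
        r#"{"backend":"ollama","url":"http://localhost:11434/api/generate"}"#,
    )
    .unwrap();
    println!("{b:?}");

    // For `huggingface`, `url` may be omitted and falls back to the
    // Inference API hostname.
    let b: Backend = serde_json::from_str(r#"{"backend":"huggingface"}"#).unwrap();
    println!("{b:?}");
}
```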
From 8926969265990202e3b399955364cc090df389f4 Mon Sep 17 00:00:00 2001
From: Luc Georges
Date: Mon, 12 Feb 2024 11:07:31 +0100
Subject: [PATCH 06/10] docs(README): replace llama.cpp link with python bindings
---
 README.md                         | 2 +-
 crates/custom-types/src/llm_ls.rs | 4 ++++
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 9ac5857..5aa2211 100644
--- a/README.md
+++ b/README.md
@@ -27,7 +27,7 @@ Note that **llm-ls** does not export any data anywhere (other than setting a use
 
 ### Multiple backends
 
-**llm-ls** is compatible with Hugging Face's [Inference API](https://huggingface.co/docs/api-inference/en/index), Hugging Face's [text-generation-inference](https://github.com/huggingface/text-generation-inference), [ollama](https://github.com/ollama/ollama) and OpenAI compatible APIs, like [llama.cpp](https://github.com/ggerganov/llama.cpp/tree/master/examples/server).
+**llm-ls** is compatible with Hugging Face's [Inference API](https://huggingface.co/docs/api-inference/en/index), Hugging Face's [text-generation-inference](https://github.com/huggingface/text-generation-inference), [ollama](https://github.com/ollama/ollama) and OpenAI compatible APIs, like the [python llama.cpp server bindings](https://github.com/abetlen/llama-cpp-python?tab=readme-ov-file#openai-compatible-web-server).
 
 ## Compatible extensions
diff --git a/crates/custom-types/src/llm_ls.rs b/crates/custom-types/src/llm_ls.rs
index c1a81d2..737acd1 100644
--- a/crates/custom-types/src/llm_ls.rs
+++ b/crates/custom-types/src/llm_ls.rs
@@ -60,6 +60,10 @@ pub enum Backend {
         #[serde(default = "hf_default_url")]
         url: String,
     },
+    // TODO:
+    // LlamaCpp {
+    //     url: String,
+    // },
     Ollama {
         url: String,
     },
From 86043ce3af6b25a0c7bc56e152d2ac536f113b3d Mon Sep 17 00:00:00 2001
From: Luc Georges
Date: Mon, 12 Feb 2024 15:58:56 +0100
Subject: [PATCH 07/10] fix: change visibility of internal functions
---
 crates/llm-ls/src/backend.rs | 10 +++++++---
 crates/llm-ls/src/error.rs   |  4 ++--
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/crates/llm-ls/src/backend.rs b/crates/llm-ls/src/backend.rs
index 357a3d5..da15312 100644
--- a/crates/llm-ls/src/backend.rs
+++ b/crates/llm-ls/src/backend.rs
@@ -151,7 +151,7 @@ fn parse_openai_text(text: &str) -> Result<Vec<Generation>> {
     }
 }
 
-pub fn build_body(
+pub(crate) fn build_body(
     backend: &Backend,
     model: String,
     prompt: String,
@@ -170,7 +170,11 @@ pub fn build_body(
     request_body
 }
 
-pub fn build_headers(backend: &Backend, api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
+pub(crate) fn build_headers(
+    backend: &Backend,
+    api_token: Option<&String>,
+    ide: Ide,
+) -> Result<HeaderMap> {
     match backend {
         Backend::HuggingFace { .. } => build_api_headers(api_token, ide),
         Backend::Ollama { .. } => Ok(build_ollama_headers()),
@@ -179,7 +183,7 @@ pub fn build_headers(backend: &Backend, api_token: Option<&String>, ide: Ide) ->
     }
 }
 
-pub fn parse_generations(backend: &Backend, text: &str) -> Result<Vec<Generation>> {
+pub(crate) fn parse_generations(backend: &Backend, text: &str) -> Result<Vec<Generation>> {
     match backend {
         Backend::HuggingFace { .. } => parse_api_text(text),
         Backend::Ollama { .. } => parse_ollama_text(text),
diff --git a/crates/llm-ls/src/error.rs b/crates/llm-ls/src/error.rs
index aa51588..d6f305e 100644
--- a/crates/llm-ls/src/error.rs
+++ b/crates/llm-ls/src/error.rs
@@ -3,7 +3,7 @@ use std::fmt::Display;
 use tower_lsp::jsonrpc::Error as LspError;
 use tracing::error;
 
-pub fn internal_error<E: Display>(err: E) -> LspError {
+pub(crate) fn internal_error<E: Display>(err: E) -> LspError {
     let err_msg = err.to_string();
     error!(err_msg);
     LspError {
@@ -55,7 +55,7 @@ pub enum Error {
     UnknownBackend(String),
 }
 
-pub type Result<T> = std::result::Result<T, Error>;
+pub(crate) type Result<T> = std::result::Result<T, Error>;
 
 impl From<Error> for LspError {
     fn from(err: Error) -> Self {
From 4437c0c8a61c6629281498b0212787270b2243e0 Mon Sep 17 00:00:00 2001
From: Luc Georges
Date: Mon, 12 Feb 2024 20:07:59 +0100
Subject: [PATCH 08/10] fix: deserialize `url` null value w/ default if `backend: huggingface` (#75)
---
 crates/custom-types/src/llm_ls.rs | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/crates/custom-types/src/llm_ls.rs b/crates/custom-types/src/llm_ls.rs
index 737acd1..9b402f9 100644
--- a/crates/custom-types/src/llm_ls.rs
+++ b/crates/custom-types/src/llm_ls.rs
@@ -46,7 +46,14 @@ fn parse_ide<'de, D>(d: D) -> std::result::Result<Ide, D::Error>
 where
     D: Deserializer<'de>,
 {
-    Deserialize::deserialize(d).map(|b: Option<_>| b.unwrap_or(Ide::Unknown))
+    Option::deserialize(d).map(|b| b.unwrap_or_default())
+}
+
+fn parse_url<'de, D>(d: D) -> std::result::Result<String, D::Error>
+where
+    D: Deserializer<'de>,
+{
+    Option::deserialize(d).map(|b| b.unwrap_or_else(hf_default_url))
 }
 
 fn hf_default_url() -> String {
@@ -57,7 +64,7 @@ fn hf_default_url() -> String {
 #[serde(rename_all = "lowercase", tag = "backend")]
 pub enum Backend {
     HuggingFace {
-        #[serde(default = "hf_default_url")]
+        #[serde(default = "hf_default_url", deserialize_with = "parse_url")]
         url: String,
     },
     // TODO:
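The deserializer added above closes a subtle gap: `#[serde(default = "...")]` only fires when the key is absent, so an explicit `url: null` in the request would still fail without `parse_url`. A minimal sketch of the distinction, using a hypothetical `Cfg` struct rather than the real `Backend` enum:

```rust
use serde::{Deserialize, Deserializer};

fn default_url() -> String {
    "https://api-inference.huggingface.co".to_owned()
}

// Same shape as `parse_url` in the patch: treat JSON null as "not provided".
fn null_as_default<'de, D>(d: D) -> Result<String, D::Error>
where
    D: Deserializer<'de>,
{
    Option::deserialize(d).map(|u: Option<String>| u.unwrap_or_else(default_url))
}

#[derive(Debug, Deserialize)]
struct Cfg {
    #[serde(default = "default_url", deserialize_with = "null_as_default")]
    url: String,
}

fn main() {
    // Missing key: handled by `default = "..."`.
    let missing: Cfg = serde_json::from_str("{}").unwrap();
    // Explicit null: handled by the custom deserializer; without it this
    // input would produce an "invalid type: null" error.
    let null: Cfg = serde_json::from_str(r#"{"url":null}"#).unwrap();
    assert_eq!(missing.url, null.url);
}
```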
From fe1f6aab477db580cf91e64c83821ce69771b221 Mon Sep 17 00:00:00 2001
From: Luc Georges
Date: Tue, 13 Feb 2024 11:02:50 +0100
Subject: [PATCH 09/10] fix: always set `return_full_text` to false for better UX (#78)
---
 Cargo.lock                   | 2 +-
 crates/llm-ls/Cargo.toml     | 2 +-
 crates/llm-ls/src/backend.rs | 8 +++++++-
 3 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 5b4d6b3..febca6c 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -982,7 +982,7 @@ checksum = "da2479e8c062e40bf0066ffa0bc823de0a9368974af99c9f6df941d2c231e03f"
 
 [[package]]
 name = "llm-ls"
-version = "0.4.0"
+version = "0.5.2"
 dependencies = [
  "clap",
  "custom-types",
diff --git a/crates/llm-ls/Cargo.toml b/crates/llm-ls/Cargo.toml
index 64cd202..72f1f5b 100644
--- a/crates/llm-ls/Cargo.toml
+++ b/crates/llm-ls/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "llm-ls"
-version = "0.4.0"
+version = "0.5.2"
 edition = "2021"
 
 [[bin]]
diff --git a/crates/llm-ls/src/backend.rs b/crates/llm-ls/src/backend.rs
index da15312..5d0e68f 100644
--- a/crates/llm-ls/src/backend.rs
+++ b/crates/llm-ls/src/backend.rs
@@ -2,7 +2,7 @@ use super::{APIError, APIResponse, Generation, NAME, VERSION};
 use custom_types::llm_ls::{Backend, Ide};
 use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
 use serde::{Deserialize, Serialize};
-use serde_json::{Map, Value};
+use serde_json::{json, Map, Value};
 use std::fmt::Display;
@@ -160,6 +160,12 @@ pub(crate) fn build_body(
     match backend {
         Backend::HuggingFace { .. } | Backend::Tgi { .. } => {
             request_body.insert("inputs".to_owned(), Value::String(prompt));
+            if let Some(Value::Object(params)) = request_body.get_mut("parameters") {
+                params.insert("return_full_text".to_owned(), Value::Bool(false));
+            } else {
+                let params = json!({ "return_full_text": false });
+                request_body.insert("parameters".to_owned(), params);
+            }
         }
         Backend::Ollama { .. } | Backend::OpenAi { .. } => {
             request_body.insert("prompt".to_owned(), Value::String(prompt));
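The hunk above forces `return_full_text: false` for the Hugging Face and TGI backends whether or not the client sent a `parameters` object. Below is a self-contained sketch of that merge logic; `force_no_full_text` is an illustrative name, not the function in the patch, which does this inside `build_body`.

```rust
use serde_json::{json, Map, Value};

// Illustrative helper mirroring the HuggingFace/TGI branch of `build_body`.
fn force_no_full_text(mut request_body: Map<String, Value>) -> Map<String, Value> {
    if let Some(Value::Object(params)) = request_body.get_mut("parameters") {
        // The client sent generation parameters: just override the flag.
        params.insert("return_full_text".to_owned(), Value::Bool(false));
    } else {
        // No parameters yet: create the object with only the flag set.
        request_body.insert("parameters".to_owned(), json!({ "return_full_text": false }));
    }
    request_body
}

fn main() {
    let body = serde_json::from_value::<Map<String, Value>>(
        json!({ "inputs": "fn main() {", "parameters": { "max_new_tokens": 150 } }),
    )
    .unwrap();
    let body = force_no_full_text(body);
    // Existing keys survive; the flag is overridden or added as needed.
    assert_eq!(body["parameters"]["return_full_text"], json!(false));
    assert_eq!(body["parameters"]["max_new_tokens"], json!(150));
}
```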
From 0b75e5dd601945d8ef4c876a2b897ea66d3a69e4 Mon Sep 17 00:00:00 2001
From: Luc Georges
Date: Mon, 19 Feb 2024 10:27:01 +0100
Subject: [PATCH 10/10] refactor: cleanup unused code (#82)
---
 crates/llm-ls/src/backend.rs | 27 ++++++++++++++++++++++++++-
 crates/llm-ls/src/error.rs   |  6 ++--
 crates/llm-ls/src/main.rs    | 66 ------------------------------------
 3 files changed, 29 insertions(+), 70 deletions(-)

diff --git a/crates/llm-ls/src/backend.rs b/crates/llm-ls/src/backend.rs
index 5d0e68f..dba1fab 100644
--- a/crates/llm-ls/src/backend.rs
+++ b/crates/llm-ls/src/backend.rs
@@ -1,4 +1,4 @@
-use super::{APIError, APIResponse, Generation, NAME, VERSION};
+use super::{Generation, NAME, VERSION};
 use custom_types::llm_ls::{Backend, Ide};
 use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
 use serde::{Deserialize, Serialize};
@@ -7,6 +7,31 @@ use std::fmt::Display;
 
 use crate::error::{Error, Result};
 
+#[derive(Debug, Deserialize)]
+pub struct APIError {
+    error: String,
+}
+
+impl std::error::Error for APIError {
+    fn description(&self) -> &str {
+        &self.error
+    }
+}
+
+impl Display for APIError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.error)
+    }
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(untagged)]
+pub enum APIResponse {
+    Generation(Generation),
+    Generations(Vec<Generation>),
+    Error(APIError),
+}
+
 fn build_tgi_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
     let mut headers = HeaderMap::new();
     let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}");
diff --git a/crates/llm-ls/src/error.rs b/crates/llm-ls/src/error.rs
index d6f305e..6812cf8 100644
--- a/crates/llm-ls/src/error.rs
+++ b/crates/llm-ls/src/error.rs
@@ -20,7 +20,7 @@ pub enum Error {
     #[error("io error: {0}")]
     Io(#[from] std::io::Error),
     #[error("inference api error: {0}")]
-    InferenceApi(crate::APIError),
+    InferenceApi(crate::backend::APIError),
     #[error("You are attempting to parse a result in the API inference format when using the `tgi` backend")]
     InvalidBackend,
     #[error("invalid header value: {0}")]
@@ -30,7 +30,7 @@ pub enum Error {
     #[error("invalid tokenizer path")]
     InvalidTokenizerPath,
     #[error("ollama error: {0}")]
-    Ollama(crate::APIError),
+    Ollama(crate::backend::APIError),
     #[error("openai error: {0}")]
     OpenAI(crate::backend::OpenAIError),
     #[error("index out of bounds: {0}")]
@@ -44,7 +44,7 @@ pub enum Error {
     #[error("serde json error: {0}")]
     SerdeJson(#[from] serde_json::Error),
     #[error("tgi error: {0}")]
-    Tgi(crate::APIError),
+    Tgi(crate::backend::APIError),
     #[error("tree-sitter language error: {0}")]
     TreeSitterLanguage(#[from] tree_sitter::LanguageError),
     #[error("tokenizer error: {0}")]
diff --git a/crates/llm-ls/src/main.rs b/crates/llm-ls/src/main.rs
index 17be8aa..1ee932e 100644
--- a/crates/llm-ls/src/main.rs
+++ b/crates/llm-ls/src/main.rs
@@ -121,77 +121,11 @@ fn should_complete(document: &Document, position: Position) -> Result<CompletionType> {
 }
 
-#[derive(Clone, Debug, Deserialize, Serialize)]
-#[serde(rename_all = "camelCase")]
-pub struct RequestParams {
-    max_new_tokens: u32,
-    temperature: f32,
-    do_sample: bool,
-    top_p: f32,
-    stop_tokens: Option<Vec<String>>,
-}
-
-#[derive(Debug, Serialize)]
-struct APIParams {
-    max_new_tokens: u32,
-    temperature: f32,
-    do_sample: bool,
-    top_p: f32,
-    #[allow(dead_code)]
-    #[serde(skip_serializing)]
-    stop: Option<Vec<String>>,
-    return_full_text: bool,
-}
-
-impl From<RequestParams> for APIParams {
-    fn from(params: RequestParams) -> Self {
-        Self {
-            max_new_tokens: params.max_new_tokens,
-            temperature: params.temperature,
-            do_sample: params.do_sample,
-            top_p: params.top_p,
-            stop: params.stop_tokens,
-            return_full_text: false,
-        }
-    }
-}
-
-#[derive(Serialize)]
-struct APIRequest {
-    inputs: String,
-    parameters: APIParams,
-}
 
 #[derive(Debug, Serialize, Deserialize)]
 pub struct Generation {
     generated_text: String,
 }
 
-#[derive(Debug, Deserialize)]
-pub struct APIError {
-    error: String,
-}
-
-impl std::error::Error for APIError {
-    fn description(&self) -> &str {
-        &self.error
-    }
-}
-
-impl Display for APIError {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "{}", self.error)
-    }
-}
-
-#[derive(Debug, Deserialize)]
-#[serde(untagged)]
-pub enum APIResponse {
-    Generation(Generation),
-    Generations(Vec<Generation>),
-    Error(APIError),
-}
 
 struct LlmService {
     cache_dir: PathBuf,
     client: Client,