diff --git a/README.md b/README.md
index 507b50f..9ac5857 100644
--- a/README.md
+++ b/README.md
@@ -19,12 +19,16 @@ It also makes sure that you are within the context window of the model by tokeni
 
 Gathers information about requests and completions that can enable retraining.
 
-Note that **llm-ls** does not export any data anywhere (other than setting a user agent when querying the model API), everything is stored in a log file if you set the log level to `info`.
+Note that **llm-ls** does not export any data anywhere (other than setting a user agent when querying the model API), everything is stored in a log file (`~/.cache/llm_ls/llm-ls.log`) if you set the log level to `info`.
 
 ### Completion
 
 **llm-ls** parses the AST of the code to determine if completions should be multi line, single line or empty (no completion).
 
+### Multiple backends
+
+**llm-ls** is compatible with Hugging Face's [Inference API](https://huggingface.co/docs/api-inference/en/index), Hugging Face's [text-generation-inference](https://github.com/huggingface/text-generation-inference), [ollama](https://github.com/ollama/ollama) and OpenAI compatible APIs, like [llama.cpp](https://github.com/ggerganov/llama.cpp/tree/master/examples/server).
+
 ## Compatible extensions
 
 - [x] [llm.nvim](https://github.com/huggingface/llm.nvim)
@@ -38,6 +42,4 @@ Note that **llm-ls** does not export any data anywhere (other than setting a use
 - add `suffix_percent` setting that determines the ratio of # of tokens for the prefix vs the suffix in the prompt
 - add context window fill percent or change context_window to `max_tokens`
 - filter bad suggestions (repetitive, same as below, etc)
-- support for ollama
-- support for llama.cpp
 - oltp traces ?
diff --git a/crates/custom-types/src/llm_ls.rs b/crates/custom-types/src/llm_ls.rs
index 10edc7c..c1a81d2 100644
--- a/crates/custom-types/src/llm_ls.rs
+++ b/crates/custom-types/src/llm_ls.rs
@@ -5,6 +5,8 @@ use serde::{Deserialize, Deserializer, Serialize};
 use serde_json::{Map, Value};
 use uuid::Uuid;
 
+const HF_INFERENCE_API_HOSTNAME: &str = "api-inference.huggingface.co";
+
 #[derive(Debug, Deserialize, Serialize)]
 #[serde(rename_all = "camelCase")]
 pub struct AcceptCompletionParams {
@@ -47,19 +49,42 @@ where
     Deserialize::deserialize(d).map(|b: Option<_>| b.unwrap_or(Ide::Unknown))
 }
 
-#[derive(Clone, Debug, Default, Deserialize, Serialize)]
-#[serde(rename_all = "lowercase")]
+fn hf_default_url() -> String {
+    format!("https://{HF_INFERENCE_API_HOSTNAME}")
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize)]
+#[serde(rename_all = "lowercase", tag = "backend")]
 pub enum Backend {
-    #[default]
-    HuggingFace,
-    Ollama,
-    OpenAi,
-    Tgi,
+    HuggingFace {
+        #[serde(default = "hf_default_url")]
+        url: String,
+    },
+    Ollama {
+        url: String,
+    },
+    OpenAi {
+        url: String,
+    },
+    Tgi {
+        url: String,
+    },
 }
 
-impl Display for Backend {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        self.serialize(f)
+impl Default for Backend {
+    fn default() -> Self {
+        Self::HuggingFace {
+            url: hf_default_url(),
+        }
+    }
+}
+
+impl Backend {
+    pub fn is_using_inference_api(&self) -> bool {
+        match self {
+            Self::HuggingFace { url } => url.contains(HF_INFERENCE_API_HOSTNAME),
+            _ => false,
+        }
     }
 }
 
@@ -98,11 +123,13 @@ pub struct GetCompletionsParams {
     pub fim: FimParams,
     pub api_token: Option<String>,
     pub model: String,
+    #[serde(flatten)]
     pub backend: Backend,
     pub tokens_to_clear: Vec<String>,
     pub tokenizer_config: Option<TokenizerConfig>,
     pub context_window: usize,
     pub tls_skip_verify_insecure: bool,
+    #[serde(default)]
     pub request_body: Map<String, Value>,
 }
 
diff --git a/crates/llm-ls/src/backend.rs b/crates/llm-ls/src/backend.rs
index 76a5d7c..357a3d5 100644
--- a/crates/llm-ls/src/backend.rs
+++ b/crates/llm-ls/src/backend.rs
@@ -1,5 +1,5 @@
 use super::{APIError, APIResponse, Generation, NAME, VERSION};
-use custom_types::llm_ls::{Backend, GetCompletionsParams, Ide};
+use custom_types::llm_ls::{Backend, Ide};
 use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
 use serde::{Deserialize, Serialize};
 use serde_json::{Map, Value};
@@ -151,33 +151,39 @@ fn parse_openai_text(text: &str) -> Result<Vec<Generation>> {
     }
 }
 
-pub fn build_body(prompt: String, params: &GetCompletionsParams) -> Map<String, Value> {
-    let mut body = params.request_body.clone();
-    match params.backend {
-        Backend::HuggingFace | Backend::Tgi => {
-            body.insert("inputs".to_string(), Value::String(prompt))
+pub fn build_body(
+    backend: &Backend,
+    model: String,
+    prompt: String,
+    mut request_body: Map<String, Value>,
+) -> Map<String, Value> {
+    match backend {
+        Backend::HuggingFace { .. } | Backend::Tgi { .. } => {
+            request_body.insert("inputs".to_owned(), Value::String(prompt));
         }
-        Backend::Ollama | Backend::OpenAi => {
-            body.insert("prompt".to_string(), Value::String(prompt))
+        Backend::Ollama { .. } | Backend::OpenAi { .. } => {
+            request_body.insert("prompt".to_owned(), Value::String(prompt));
+            request_body.insert("model".to_owned(), Value::String(model));
+            request_body.insert("stream".to_owned(), Value::Bool(false));
         }
     };
-    body
+    request_body
 }
 
 pub fn build_headers(backend: &Backend, api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
     match backend {
-        Backend::HuggingFace => build_api_headers(api_token, ide),
-        Backend::Ollama => Ok(build_ollama_headers()),
-        Backend::OpenAi => build_openai_headers(api_token, ide),
-        Backend::Tgi => build_tgi_headers(api_token, ide),
+        Backend::HuggingFace { .. } => build_api_headers(api_token, ide),
+        Backend::Ollama { .. } => Ok(build_ollama_headers()),
+        Backend::OpenAi { .. } => build_openai_headers(api_token, ide),
+        Backend::Tgi { .. } => build_tgi_headers(api_token, ide),
     }
 }
 
 pub fn parse_generations(backend: &Backend, text: &str) -> Result<Vec<Generation>> {
     match backend {
-        Backend::HuggingFace => parse_api_text(text),
-        Backend::Ollama => parse_ollama_text(text),
-        Backend::OpenAi => parse_openai_text(text),
-        Backend::Tgi => parse_tgi_text(text),
+        Backend::HuggingFace { .. } => parse_api_text(text),
+        Backend::Ollama { .. } => parse_ollama_text(text),
+        Backend::OpenAi { .. } => parse_openai_text(text),
+        Backend::Tgi { .. } => parse_tgi_text(text),
     }
 }
diff --git a/crates/llm-ls/src/main.rs b/crates/llm-ls/src/main.rs
index 27441ad..17be8aa 100644
--- a/crates/llm-ls/src/main.rs
+++ b/crates/llm-ls/src/main.rs
@@ -34,7 +34,6 @@
 const MAX_WARNING_REPEAT: Duration = Duration::from_secs(3_600);
 pub const NAME: &str = "llm-ls";
 pub const VERSION: &str = env!("CARGO_PKG_VERSION");
-const HF_INFERENCE_API_HOSTNAME: &str = "api-inference.huggingface.co";
 
 fn get_position_idx(rope: &Rope, row: usize, col: usize) -> Result<usize> {
     Ok(rope.try_line_to_char(row)?
@@ -305,10 +304,15 @@ async fn request_completion(
 ) -> Result<Vec<Generation>> {
     let t = Instant::now();
 
-    let json = build_body(prompt, params);
+    let json = build_body(
+        &params.backend,
+        params.model.clone(),
+        prompt,
+        params.request_body.clone(),
+    );
     let headers = build_headers(&params.backend, params.api_token.as_ref(), params.ide)?;
     let res = http_client
-        .post(build_url(&params.model))
+        .post(build_url(params.backend.clone(), &params.model))
         .json(&json)
         .headers(headers)
         .send()
@@ -367,7 +371,7 @@ async fn download_tokenizer_file(
         return Ok(());
     }
     tokio::fs::create_dir_all(to.as_ref().parent().ok_or(Error::InvalidTokenizerPath)?).await?;
-    let headers = build_headers(&Backend::HuggingFace, api_token, ide)?;
+    let headers = build_headers(&Backend::default(), api_token, ide)?;
     let mut file = tokio::fs::OpenOptions::new()
         .write(true)
         .create(true)
@@ -475,11 +479,12 @@ async fn get_tokenizer(
     }
 }
 
-fn build_url(model: &str) -> String {
-    if model.starts_with("http://") || model.starts_with("https://") {
-        model.to_owned()
-    } else {
-        format!("https://{HF_INFERENCE_API_HOSTNAME}/models/{model}")
+fn build_url(backend: Backend, model: &str) -> String {
+    match backend {
+        Backend::HuggingFace { url } => format!("{url}/models/{model}"),
+        Backend::Ollama { url } => url,
+        Backend::OpenAi { url } => url,
+        Backend::Tgi { url } => url,
     }
 }
 
@@ -502,14 +507,13 @@ impl LlmService {
             cursor_character = ?params.text_document_position.position.character,
             language_id = %document.language_id,
             model = params.model,
-            backend = %params.backend,
+            backend = ?params.backend,
             ide = %params.ide,
             request_body = serde_json::to_string(&params.request_body).map_err(internal_error)?,
             "received completion request for {}",
             params.text_document_position.text_document.uri
         );
-        let is_using_inference_api = matches!(params.backend, Backend::HuggingFace);
-        if params.api_token.is_none() && is_using_inference_api {
+        if params.api_token.is_none() && params.backend.is_using_inference_api() {
            let now = Instant::now();
            let unauthenticated_warn_at = self.unauthenticated_warn_at.read().await;
            if now.duration_since(*unauthenticated_warn_at) > MAX_WARNING_REPEAT {
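
Below the patch, a minimal sketch (not part of the diff) of how the new internally tagged `backend` field deserializes under the serde attributes added in `crates/custom-types/src/llm_ls.rs`; the localhost endpoint and the `main` wrapper are illustrative assumptions, not values taken from the patch:

```rust
// Illustrative only: exercises the `Backend` enum introduced above
// (assumes the `custom-types` crate and `serde_json` as dependencies).
use custom_types::llm_ls::Backend;

fn main() -> serde_json::Result<()> {
    // Non-Hugging Face backends must provide an explicit `url`;
    // the endpoint below is just an example value.
    let ollama: Backend = serde_json::from_str(
        r#"{"backend": "ollama", "url": "http://localhost:11434/api/generate"}"#,
    )?;
    assert!(!ollama.is_using_inference_api());

    // Omitting `url` for the `huggingface` variant falls back to
    // `hf_default_url()`, i.e. the Inference API hostname.
    let hf: Backend = serde_json::from_str(r#"{"backend": "huggingface"}"#)?;
    assert!(hf.is_using_inference_api());
    Ok(())
}
```

Because only the `huggingface` variant carries a `serde` default for `url`, clients targeting ollama, openai or tgi now have to send an explicit `url` alongside `"backend"` in `getCompletions` requests.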