feat: update backend & model parameter (#74)
* feat: update backend & model parameter

* fix: add `stream: false` in request body for ollama & openai
McPatate authored Feb 9, 2024
1 parent 92fc885 commit 4891468
Showing 4 changed files with 81 additions and 42 deletions.
README.md: 8 changes (5 additions, 3 deletions)
@@ -19,12 +19,16 @@ It also makes sure that you are within the context window of the model by tokeni

Gathers information about requests and completions that can enable retraining.

Note that **llm-ls** does not export any data anywhere (other than setting a user agent when querying the model API), everything is stored in a log file if you set the log level to `info`.
Note that **llm-ls** does not export any data anywhere (other than setting a user agent when querying the model API), everything is stored in a log file (`~/.cache/llm_ls/llm-ls.log`) if you set the log level to `info`.

### Completion

**llm-ls** parses the AST of the code to determine if completions should be multi line, single line or empty (no completion).

### Multiple backends

**llm-ls** is compatible with Hugging Face's [Inference API](https://huggingface.co/docs/api-inference/en/index), Hugging Face's [text-generation-inference](https://github.com/huggingface/text-generation-inference), [ollama](https://github.com/ollama/ollama) and OpenAI compatible APIs, like [llama.cpp](https://github.com/ggerganov/llama.cpp/tree/master/examples/server).

## Compatible extensions

- [x] [llm.nvim](https://github.com/huggingface/llm.nvim)
@@ -38,6 +38,4 @@ Note that **llm-ls** does not export any data anywhere (other than setting a use
- add `suffix_percent` setting that determines the ratio of # of tokens for the prefix vs the suffix in the prompt
- add context window fill percent or change context_window to `max_tokens`
- filter bad suggestions (repetitive, same as below, etc)
- support for ollama
- support for llama.cpp
- oltp traces ?
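
The new "Multiple backends" section above is the user-facing side of this commit: the backend and its endpoint are now chosen per completion request rather than being hard-wired to the Hugging Face Inference API. As a rough illustration (not part of the diff; the model name and URL are placeholders, and only the backend-related fields are shown), the ollama case might be built with `serde_json` like this:

```rust
use serde_json::json;

fn main() {
    // Illustrative excerpt of completion params selecting the ollama backend.
    // "backend" is the serde tag of the new Backend enum; "url" is the variant's
    // field, flattened to the top level of the params object.
    let params = json!({
        "model": "codellama:7b-code",                // placeholder model
        "backend": "ollama",
        "url": "http://localhost:11434/api/generate" // placeholder endpoint
    });
    println!("{params}");
}
```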
crates/custom-types/src/llm_ls.rs: 47 changes (37 additions, 10 deletions)
@@ -5,6 +5,8 @@ use serde::{Deserialize, Deserializer, Serialize};
use serde_json::{Map, Value};
use uuid::Uuid;

const HF_INFERENCE_API_HOSTNAME: &str = "api-inference.huggingface.co";

#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct AcceptCompletionParams {
@@ -47,19 +49,42 @@ where
Deserialize::deserialize(d).map(|b: Option<_>| b.unwrap_or(Ide::Unknown))
}

#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
fn hf_default_url() -> String {
format!("https://{HF_INFERENCE_API_HOSTNAME}")
}

#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "lowercase", tag = "backend")]
pub enum Backend {
#[default]
HuggingFace,
Ollama,
OpenAi,
Tgi,
HuggingFace {
#[serde(default = "hf_default_url")]
url: String,
},
Ollama {
url: String,
},
OpenAi {
url: String,
},
Tgi {
url: String,
},
}

impl Display for Backend {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.serialize(f)
impl Default for Backend {
fn default() -> Self {
Self::HuggingFace {
url: hf_default_url(),
}
}
}

impl Backend {
pub fn is_using_inference_api(&self) -> bool {
match self {
Self::HuggingFace { url } => url.contains(HF_INFERENCE_API_HOSTNAME),
_ => false,
}
}
}
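
Because `Backend` is now an internally tagged enum (`tag = "backend"`, variant names lowercased) whose variants each carry a `url`, the backend name and its endpoint deserialize from one flat object, and `Default` keeps pointing at the hosted Inference API. A minimal sketch of the expected behavior, assuming `serde_json` is available; it is not part of the commit and the ollama URL is a placeholder:

```rust
use custom_types::llm_ls::Backend;

fn backend_deserialization_sketch() -> serde_json::Result<()> {
    // The tag ("backend") and the variant's "url" sit side by side in one object.
    let ollama: Backend = serde_json::from_str(
        r#"{ "backend": "ollama", "url": "http://localhost:11434/api/generate" }"#,
    )?;
    assert!(!ollama.is_using_inference_api());

    // The default variant keeps the Inference API hostname, which is what the
    // unauthenticated-token warning in main.rs now keys off of.
    assert!(Backend::default().is_using_inference_api());
    Ok(())
}
```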

@@ -98,11 +123,13 @@ pub struct GetCompletionsParams {
pub fim: FimParams,
pub api_token: Option<String>,
pub model: String,
#[serde(flatten)]
pub backend: Backend,
pub tokens_to_clear: Vec<String>,
pub tokenizer_config: Option<TokenizerConfig>,
pub context_window: usize,
pub tls_skip_verify_insecure: bool,
#[serde(default)]
pub request_body: Map<String, Value>,
}
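
The two new attributes on `GetCompletionsParams` keep the wire format flat and tolerant: `#[serde(flatten)]` lifts the backend fields to the top level of the params object, and `#[serde(default)]` lets clients omit the extra request body entirely. A reduced sketch with a hypothetical stand-in struct (the real struct has more fields, and its field casing is not shown in this diff; the URL is a placeholder):

```rust
use custom_types::llm_ls::Backend;
use serde::Deserialize;
use serde_json::{Map, Value};

// Hypothetical, trimmed-down stand-in for GetCompletionsParams; illustration only.
#[derive(Debug, Deserialize)]
struct ParamsSketch {
    model: String,
    #[serde(flatten)]
    backend: Backend,
    #[serde(default)]
    request_body: Map<String, Value>,
}

fn params_sketch() -> serde_json::Result<()> {
    // "backend"/"url" sit next to "model"; the request body is absent and
    // falls back to an empty map thanks to #[serde(default)].
    let p: ParamsSketch = serde_json::from_str(
        r#"{ "model": "my-model", "backend": "openai", "url": "http://localhost:8080/v1/completions" }"#,
    )?;
    assert!(p.request_body.is_empty());
    assert!(!p.backend.is_using_inference_api());
    assert_eq!(p.model, "my-model");
    Ok(())
}
```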

crates/llm-ls/src/backend.rs: 40 changes (23 additions, 17 deletions)
@@ -1,5 +1,5 @@
use super::{APIError, APIResponse, Generation, NAME, VERSION};
use custom_types::llm_ls::{Backend, GetCompletionsParams, Ide};
use custom_types::llm_ls::{Backend, Ide};
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
@@ -151,33 +151,39 @@ fn parse_openai_text(text: &str) -> Result<Vec<Generation>> {
}
}

pub fn build_body(prompt: String, params: &GetCompletionsParams) -> Map<String, Value> {
let mut body = params.request_body.clone();
match params.backend {
Backend::HuggingFace | Backend::Tgi => {
body.insert("inputs".to_string(), Value::String(prompt))
pub fn build_body(
backend: &Backend,
model: String,
prompt: String,
mut request_body: Map<String, Value>,
) -> Map<String, Value> {
match backend {
Backend::HuggingFace { .. } | Backend::Tgi { .. } => {
request_body.insert("inputs".to_owned(), Value::String(prompt));
}
Backend::Ollama | Backend::OpenAi => {
body.insert("prompt".to_string(), Value::String(prompt))
Backend::Ollama { .. } | Backend::OpenAi { .. } => {
request_body.insert("prompt".to_owned(), Value::String(prompt));
request_body.insert("model".to_owned(), Value::String(model));
request_body.insert("stream".to_owned(), Value::Bool(false));
}
};
body
request_body
}

pub fn build_headers(backend: &Backend, api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
match backend {
Backend::HuggingFace => build_api_headers(api_token, ide),
Backend::Ollama => Ok(build_ollama_headers()),
Backend::OpenAi => build_openai_headers(api_token, ide),
Backend::Tgi => build_tgi_headers(api_token, ide),
Backend::HuggingFace { .. } => build_api_headers(api_token, ide),
Backend::Ollama { .. } => Ok(build_ollama_headers()),
Backend::OpenAi { .. } => build_openai_headers(api_token, ide),
Backend::Tgi { .. } => build_tgi_headers(api_token, ide),
}
}

pub fn parse_generations(backend: &Backend, text: &str) -> Result<Vec<Generation>> {
match backend {
Backend::HuggingFace => parse_api_text(text),
Backend::Ollama => parse_ollama_text(text),
Backend::OpenAi => parse_openai_text(text),
Backend::Tgi => parse_tgi_text(text),
Backend::HuggingFace { .. } => parse_api_text(text),
Backend::Ollama { .. } => parse_ollama_text(text),
Backend::OpenAi { .. } => parse_openai_text(text),
Backend::Tgi { .. } => parse_tgi_text(text),
}
}
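
The second line of the commit message is implemented in `build_body` above: for the ollama and OpenAI-compatible backends, the prompt is now joined by the `model` name and `stream: false`, so the server answers with a single JSON response instead of a stream, while Hugging Face and TGI keep using `inputs`. A hypothetical unit test that could sit at the bottom of backend.rs (not part of this commit; model name and endpoint are placeholders) spells out that expectation:

```rust
// Sketch only: illustrates the new build_body behavior for the ollama backend.
#[cfg(test)]
mod build_body_sketch {
    use super::build_body;
    use custom_types::llm_ls::Backend;
    use serde_json::{Map, Value};

    #[test]
    fn ollama_body_carries_model_and_disables_streaming() {
        let backend = Backend::Ollama {
            url: "http://localhost:11434/api/generate".to_owned(),
        };
        let body = build_body(
            &backend,
            "codellama:7b-code".to_owned(),
            "def fib(n):".to_owned(),
            Map::new(), // any client-supplied request_body entries would be merged here
        );
        assert_eq!(body.get("prompt"), Some(&Value::String("def fib(n):".to_owned())));
        assert_eq!(body.get("model"), Some(&Value::String("codellama:7b-code".to_owned())));
        assert_eq!(body.get("stream"), Some(&Value::Bool(false)));
        // HuggingFace / TGI would instead carry the prompt under "inputs"
        // and leave "model"/"stream" untouched.
    }
}
```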
crates/llm-ls/src/main.rs: 28 changes (16 additions, 12 deletions)
@@ -34,7 +34,6 @@ mod language_id;
const MAX_WARNING_REPEAT: Duration = Duration::from_secs(3_600);
pub const NAME: &str = "llm-ls";
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
const HF_INFERENCE_API_HOSTNAME: &str = "api-inference.huggingface.co";

fn get_position_idx(rope: &Rope, row: usize, col: usize) -> Result<usize> {
Ok(rope.try_line_to_char(row)?
@@ -305,10 +304,15 @@ async fn request_completion(
) -> Result<Vec<Generation>> {
let t = Instant::now();

let json = build_body(prompt, params);
let json = build_body(
&params.backend,
params.model.clone(),
prompt,
params.request_body.clone(),
);
let headers = build_headers(&params.backend, params.api_token.as_ref(), params.ide)?;
let res = http_client
.post(build_url(&params.model))
.post(build_url(params.backend.clone(), &params.model))
.json(&json)
.headers(headers)
.send()
@@ -367,7 +371,7 @@ async fn download_tokenizer_file(
return Ok(());
}
tokio::fs::create_dir_all(to.as_ref().parent().ok_or(Error::InvalidTokenizerPath)?).await?;
let headers = build_headers(&Backend::HuggingFace, api_token, ide)?;
let headers = build_headers(&Backend::default(), api_token, ide)?;
let mut file = tokio::fs::OpenOptions::new()
.write(true)
.create(true)
@@ -475,11 +479,12 @@ async fn get_tokenizer(
}
}

fn build_url(model: &str) -> String {
if model.starts_with("http://") || model.starts_with("https://") {
model.to_owned()
} else {
format!("https://{HF_INFERENCE_API_HOSTNAME}/models/{model}")
fn build_url(backend: Backend, model: &str) -> String {
match backend {
Backend::HuggingFace { url } => format!("{url}/models/{model}"),
Backend::Ollama { url } => url,
Backend::OpenAi { url } => url,
Backend::Tgi { url } => url,
}
}
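
`build_url` no longer guesses the endpoint by checking whether the `model` string starts with `http(s)://`; the URL now always comes from the backend, and only the Hugging Face variant appends `/models/{model}`. A small illustration, written as if it sat next to `build_url` in main.rs where `Backend` is already in scope (the TGI endpoint below is a placeholder):

```rust
// Hypothetical check of the new routing; not part of this commit.
fn build_url_examples() {
    let hf = Backend::HuggingFace {
        url: "https://api-inference.huggingface.co".to_owned(),
    };
    assert_eq!(
        build_url(hf, "bigcode/starcoder"),
        "https://api-inference.huggingface.co/models/bigcode/starcoder"
    );

    // Ollama, OpenAI-compatible and TGI backends pass their URL through untouched,
    // so the model string is not used to build the URL at all.
    let tgi = Backend::Tgi {
        url: "http://localhost:8080/generate".to_owned(),
    };
    assert_eq!(build_url(tgi, "model-ignored-here"), "http://localhost:8080/generate");
}
```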

@@ -502,14 +507,13 @@ impl LlmService {
cursor_character = ?params.text_document_position.position.character,
language_id = %document.language_id,
model = params.model,
backend = %params.backend,
backend = ?params.backend,
ide = %params.ide,
request_body = serde_json::to_string(&params.request_body).map_err(internal_error)?,
"received completion request for {}",
params.text_document_position.text_document.uri
);
let is_using_inference_api = matches!(params.backend, Backend::HuggingFace);
if params.api_token.is_none() && is_using_inference_api {
if params.api_token.is_none() && params.backend.is_using_inference_api() {
let now = Instant::now();
let unauthenticated_warn_at = self.unauthenticated_warn_at.read().await;
if now.duration_since(*unauthenticated_warn_at) > MAX_WARNING_REPEAT {