Skip to content

Commit

Permalink
refactor: error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
McPatate committed Feb 7, 2024
1 parent a9831d5 commit 5b2e23c
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 218 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/llm-ls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ reqwest = { version = "0.11", default-features = false, features = [
] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
thiserror = "1"
tokenizers = { version = "0.14", default-features = false, features = ["onig"] }
tokio = { version = "1", features = [
"fs",
Expand Down
93 changes: 38 additions & 55 deletions crates/llm-ls/src/backend.rs
Original file line number Diff line number Diff line change
@@ -1,59 +1,44 @@
use super::{
internal_error, APIError, APIResponse, CompletionParams, Generation, Ide, NAME, VERSION,
};
use super::{APIError, APIResponse, CompletionParams, Generation, Ide, NAME, VERSION};
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use std::fmt::Display;
use tower_lsp::jsonrpc;

fn build_tgi_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
use crate::error::{Error, Result};

fn build_tgi_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
let mut headers = HeaderMap::new();
let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}");
headers.insert(
USER_AGENT,
HeaderValue::from_str(&user_agent).map_err(internal_error)?,
);
headers.insert(USER_AGENT, HeaderValue::from_str(&user_agent)?);

if let Some(api_token) = api_token {
headers.insert(
AUTHORIZATION,
HeaderValue::from_str(&format!("Bearer {api_token}")).map_err(internal_error)?,
HeaderValue::from_str(&format!("Bearer {api_token}"))?,
);
}

Ok(headers)
}

fn parse_tgi_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
let generations =
match serde_json::from_str(text).map_err(internal_error)? {
APIResponse::Generation(gen) => vec![gen],
APIResponse::Generations(_) => {
return Err(internal_error(
"You are attempting to parse a result in the API inference format when using the `tgi` backend",
))
}
APIResponse::Error(err) => return Err(internal_error(err)),
};
Ok(generations)
fn parse_tgi_text(text: &str) -> Result<Vec<Generation>> {
match serde_json::from_str(text)? {
APIResponse::Generation(gen) => Ok(vec![gen]),
APIResponse::Generations(_) => Err(Error::InvalidBackend),
APIResponse::Error(err) => Err(Error::Tgi(err)),
}
}

fn build_api_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
fn build_api_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
build_tgi_headers(api_token, ide)
}

fn parse_api_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
let generations = match serde_json::from_str(text).map_err(internal_error)? {
APIResponse::Generation(gen) => vec![gen],
APIResponse::Generations(gens) => gens,
APIResponse::Error(err) => return Err(internal_error(err)),
};
Ok(generations)
}

fn build_ollama_headers() -> Result<HeaderMap, jsonrpc::Error> {
Ok(HeaderMap::new())
fn parse_api_text(text: &str) -> Result<Vec<Generation>> {
match serde_json::from_str(text)? {
APIResponse::Generation(gen) => Ok(vec![gen]),
APIResponse::Generations(gens) => Ok(gens),
APIResponse::Error(err) => Err(Error::InferenceApi(err)),
}
}

#[derive(Debug, Serialize, Deserialize)]
Expand All @@ -76,16 +61,15 @@ enum OllamaAPIResponse {
Error(APIError),
}

fn parse_ollama_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
let generations = match serde_json::from_str(text).map_err(internal_error)? {
OllamaAPIResponse::Generation(gen) => vec![gen.into()],
OllamaAPIResponse::Error(err) => return Err(internal_error(err)),
};
Ok(generations)
fn build_ollama_headers() -> HeaderMap {
HeaderMap::new()
}

fn build_openai_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
build_api_headers(api_token, ide)
fn parse_ollama_text(text: &str) -> Result<Vec<Generation>> {
match serde_json::from_str(text)? {
OllamaAPIResponse::Generation(gen) => Ok(vec![gen.into()]),
OllamaAPIResponse::Error(err) => Err(Error::Ollama(err)),
}
}

#[derive(Debug, Deserialize)]
Expand Down Expand Up @@ -130,7 +114,7 @@ struct OpenAIErrorDetail {
}

#[derive(Debug, Deserialize)]
struct OpenAIError {
pub struct OpenAIError {
detail: Vec<OpenAIErrorDetail>,
}

Expand All @@ -153,13 +137,16 @@ enum OpenAIAPIResponse {
Error(OpenAIError),
}

fn parse_openai_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
match serde_json::from_str(text).map_err(internal_error) {
Ok(OpenAIAPIResponse::Generation(completion)) => {
fn build_openai_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
build_api_headers(api_token, ide)
}

fn parse_openai_text(text: &str) -> Result<Vec<Generation>> {
match serde_json::from_str(text)? {
OpenAIAPIResponse::Generation(completion) => {
Ok(completion.choices.into_iter().map(|x| x.into()).collect())
}
Ok(OpenAIAPIResponse::Error(err)) => Err(internal_error(err)),
Err(err) => Err(internal_error(err)),
OpenAIAPIResponse::Error(err) => Err(Error::OpenAI(err)),
}
}

Expand All @@ -186,20 +173,16 @@ pub fn build_body(prompt: String, params: &CompletionParams) -> Map<String, Valu
body
}

pub fn build_headers(
backend: &Backend,
api_token: Option<&String>,
ide: Ide,
) -> Result<HeaderMap, jsonrpc::Error> {
pub fn build_headers(backend: &Backend, api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
match backend {
Backend::HuggingFace => build_api_headers(api_token, ide),
Backend::Ollama => build_ollama_headers(),
Backend::Ollama => Ok(build_ollama_headers()),
Backend::OpenAi => build_openai_headers(api_token, ide),
Backend::Tgi => build_tgi_headers(api_token, ide),
}
}

pub fn parse_generations(backend: &Backend, text: &str) -> jsonrpc::Result<Vec<Generation>> {
pub fn parse_generations(backend: &Backend, text: &str) -> Result<Vec<Generation>> {
match backend {
Backend::HuggingFace => parse_api_text(text),
Backend::Ollama => parse_ollama_text(text),
Expand Down
Loading

0 comments on commit 5b2e23c

Please sign in to comment.