Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add llama.cpp backend #94

Merged
merged 3 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ env:
RUSTFLAGS: "-D warnings -W unreachable-pub"
RUSTUP_MAX_RETRIES: 10
FETCH_DEPTH: 0 # pull in the tags for the version string
MACOSX_DEPLOYMENT_TARGET: 10.15
CARGO_TARGET_AARCH64_UNKNOWN_LINUX_GNU_LINKER: aarch64-linux-gnu-gcc
CARGO_TARGET_ARM_UNKNOWN_LINUX_GNUEABIHF_LINKER: arm-linux-gnueabihf-gcc

Expand Down
5 changes: 1 addition & 4 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,7 @@ jobs:
DEBIAN_FRONTEND=noninteractive apt install -y pkg-config protobuf-compiler libssl-dev curl build-essential git-all gfortran
- name: Install Rust toolchain
uses: actions-rust-lang/setup-rust-toolchain@v1
with:
rustflags: ''
toolchain: nightly
uses: dtolnay/rust-toolchain@stable

- name: Install Python 3.10
uses: actions/setup-python@v5
Expand Down
7 changes: 3 additions & 4 deletions crates/custom-types/src/llm_ls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,9 @@ pub enum Backend {
#[serde(default = "hf_default_url", deserialize_with = "parse_url")]
url: String,
},
// TODO:
// LlamaCpp {
// url: String,
// },
LlamaCpp {
url: String,
},
Ollama {
url: String,
},
Expand Down
36 changes: 36 additions & 0 deletions crates/llm-ls/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,37 @@ fn parse_api_text(text: &str) -> Result<Vec<Generation>> {
}
}

/// One completion payload as returned by a llama.cpp server.
///
/// llama.cpp puts the generated text in a field named `content`, unlike the
/// other backends in this file (e.g. Ollama's `response`), so each backend
/// gets its own deserialization struct.
// NOTE(review): `Serialize` looks unused for a response-only type, but it is
// derived here to stay consistent with the sibling generation structs below
// (e.g. `OllamaGeneration`) — confirm before removing.
#[derive(Debug, Serialize, Deserialize)]
struct LlamaCppGeneration {
    // The generated completion text.
    content: String,
}

impl From<LlamaCppGeneration> for Generation {
fn from(value: LlamaCppGeneration) -> Self {
Generation {
generated_text: value.content,
}
}
}

/// The two JSON body shapes a llama.cpp server may return: a successful
/// completion or an API error.
///
/// `#[serde(untagged)]` means the payload carries no discriminant field;
/// serde tries each variant in declaration order and picks the first whose
/// shape matches, so `Generation` must stay listed before `Error`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum LlamaCppAPIResponse {
    Generation(LlamaCppGeneration),
    Error(APIError),
}

/// Build the request headers for a llama.cpp backend.
///
/// llama.cpp needs no extra headers here, so this returns an empty map
/// (`HeaderMap::default()` is documented to be equivalent to `HeaderMap::new()`).
fn build_llamacpp_headers() -> HeaderMap {
    HeaderMap::default()
}

fn parse_llamacpp_text(text: &str) -> Result<Vec<Generation>> {
match serde_json::from_str(text)? {
LlamaCppAPIResponse::Generation(gen) => Ok(vec![gen.into()]),
LlamaCppAPIResponse::Error(err) => Err(Error::LlamaCpp(err)),
}
}

#[derive(Debug, Serialize, Deserialize)]
struct OllamaGeneration {
response: String,
Expand Down Expand Up @@ -192,6 +223,9 @@ pub(crate) fn build_body(
request_body.insert("parameters".to_owned(), params);
}
}
Backend::LlamaCpp { .. } => {
request_body.insert("prompt".to_owned(), Value::String(prompt));
}
Backend::Ollama { .. } | Backend::OpenAi { .. } => {
request_body.insert("prompt".to_owned(), Value::String(prompt));
request_body.insert("model".to_owned(), Value::String(model));
Expand All @@ -208,6 +242,7 @@ pub(crate) fn build_headers(
) -> Result<HeaderMap> {
match backend {
Backend::HuggingFace { .. } => build_api_headers(api_token, ide),
Backend::LlamaCpp { .. } => Ok(build_llamacpp_headers()),
Backend::Ollama { .. } => Ok(build_ollama_headers()),
Backend::OpenAi { .. } => build_openai_headers(api_token, ide),
Backend::Tgi { .. } => build_tgi_headers(api_token, ide),
Expand All @@ -217,6 +252,7 @@ pub(crate) fn build_headers(
pub(crate) fn parse_generations(backend: &Backend, text: &str) -> Result<Vec<Generation>> {
match backend {
Backend::HuggingFace { .. } => parse_api_text(text),
Backend::LlamaCpp { .. } => parse_llamacpp_text(text),
Backend::Ollama { .. } => parse_ollama_text(text),
Backend::OpenAi { .. } => parse_openai_text(text),
Backend::Tgi { .. } => parse_tgi_text(text),
Expand Down
34 changes: 24 additions & 10 deletions crates/llm-ls/src/document.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ impl TryFrom<Vec<tower_lsp::lsp_types::PositionEncodingKind>> for PositionEncodi
}

impl PositionEncodingKind {
pub fn to_lsp_type(&self) -> tower_lsp::lsp_types::PositionEncodingKind {
pub fn to_lsp_type(self) -> tower_lsp::lsp_types::PositionEncodingKind {
match self {
PositionEncodingKind::Utf8 => tower_lsp::lsp_types::PositionEncodingKind::UTF8,
PositionEncodingKind::Utf16 => tower_lsp::lsp_types::PositionEncodingKind::UTF16,
Expand Down Expand Up @@ -205,9 +205,10 @@ impl Document {
) -> Result<()> {
match change.range {
Some(range) => {
if range.start.line < range.end.line
if range.start.line > range.end.line
|| (range.start.line == range.end.line
&& range.start.character <= range.end.character) {
&& range.start.character > range.end.character)
{
return Err(Error::InvalidRange(range));
}

Expand All @@ -219,7 +220,10 @@ impl Document {

// 1. Get the line at which the change starts.
let change_start_line_idx = range.start.line as usize;
let change_start_line = self.text.get_line(change_start_line_idx).ok_or_else(|| Error::OutOfBoundLine(change_start_line_idx, self.text.len_lines()))?;
let change_start_line =
self.text.get_line(change_start_line_idx).ok_or_else(|| {
Error::OutOfBoundLine(change_start_line_idx, self.text.len_lines())
})?;

// 2. Get the line at which the change ends. (Small optimization
// where we first check whether start and end line are the
Expand All @@ -228,7 +232,9 @@ impl Document {
let change_end_line_idx = range.end.line as usize;
let change_end_line = match same_line {
true => change_start_line,
false => self.text.get_line(change_end_line_idx).ok_or_else(|| Error::OutOfBoundLine(change_end_line_idx, self.text.len_lines()))?,
false => self.text.get_line(change_end_line_idx).ok_or_else(|| {
Error::OutOfBoundLine(change_end_line_idx, self.text.len_lines())
})?,
};

fn compute_char_idx(
Expand Down Expand Up @@ -330,7 +336,7 @@ impl Document {
self.tree = Some(new_tree);
}
None => {
return Err(Error::TreeSitterParseError);
return Err(Error::TreeSitterParsing);
}
}
}
Expand Down Expand Up @@ -416,7 +422,9 @@ mod test {
let mut rope = Rope::from_str(
"let a = '🥸 你好';\rfunction helloWorld() { return '🤲🏿'; }\nlet b = 'Hi, 😊';",
);
let mut doc = Document::open(&LanguageId::JavaScript.to_string(), &rope.to_string()).await.unwrap();
let mut doc = Document::open(&LanguageId::JavaScript.to_string(), &rope.to_string())
.await
.unwrap();
let mut parser = Parser::new();

parser
Expand Down Expand Up @@ -464,7 +472,9 @@ mod test {
#[tokio::test]
async fn test_text_document_apply_content_change_bounds() {
let rope = Rope::from_str("");
let mut doc = Document::open(&LanguageId::Unknown.to_string(), &rope.to_string()).await.unwrap();
let mut doc = Document::open(&LanguageId::Unknown.to_string(), &rope.to_string())
.await
.unwrap();

assert!(doc
.apply_content_change(new_change!(0, 0, 0, 1, ""), PositionEncodingKind::Utf16)
Expand Down Expand Up @@ -513,7 +523,9 @@ mod test {
async fn test_document_update_tree_consistency_easy() {
let a = "let a = '你好';\rlet b = 'Hi, 😊';";

let mut document = Document::open(&LanguageId::JavaScript.to_string(), a).await.unwrap();
let mut document = Document::open(&LanguageId::JavaScript.to_string(), a)
.await
.unwrap();

document
.apply_content_change(new_change!(0, 9, 0, 11, "𐐀"), PositionEncodingKind::Utf16)
Expand Down Expand Up @@ -541,7 +553,9 @@ mod test {
async fn test_document_update_tree_consistency_medium() {
let a = "let a = '🥸 你好';\rfunction helloWorld() { return '🤲🏿'; }\nlet b = 'Hi, 😊';";

let mut document = Document::open(&LanguageId::JavaScript.to_string(), a).await.unwrap();
let mut document = Document::open(&LanguageId::JavaScript.to_string(), a)
.await
.unwrap();

document
.apply_content_change(new_change!(0, 14, 2, 13, ","), PositionEncodingKind::Utf16)
Expand Down
6 changes: 4 additions & 2 deletions crates/llm-ls/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ pub enum Error {
InvalidRepositoryId,
#[error("invalid tokenizer path")]
InvalidTokenizerPath,
#[error("llama.cpp error: {0}")]
LlamaCpp(crate::backend::APIError),
#[error("ollama error: {0}")]
Ollama(crate::backend::APIError),
#[error("openai error: {0}")]
Expand All @@ -50,7 +52,7 @@ pub enum Error {
#[error("tgi error: {0}")]
Tgi(crate::backend::APIError),
#[error("tree-sitter parse error: timeout possibly exceeded")]
TreeSitterParseError,
TreeSitterParsing,
#[error("tree-sitter language error: {0}")]
TreeSitterLanguage(#[from] tree_sitter::LanguageError),
#[error("tokenizer error: {0}")]
Expand All @@ -60,7 +62,7 @@ pub enum Error {
#[error("unknown backend: {0}")]
UnknownBackend(String),
#[error("unknown encoding kind: {0}")]
UnknownEncodingKind(String)
UnknownEncodingKind(String),
}

pub(crate) type Result<T> = std::result::Result<T, Error>;
Expand Down
14 changes: 13 additions & 1 deletion crates/llm-ls/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,17 @@ async fn get_tokenizer(
fn build_url(backend: Backend, model: &str) -> String {
match backend {
Backend::HuggingFace { url } => format!("{url}/models/{model}"),
Backend::LlamaCpp { mut url } => {
if url.ends_with("/completions") {
url
} else if url.ends_with('/') {
url.push_str("completions");
url
} else {
url.push_str("/completions");
url
}
}
Backend::Ollama { url } => url,
Backend::OpenAi { url } => url,
Backend::Tgi { url } => url,
Expand Down Expand Up @@ -540,7 +551,8 @@ impl LanguageServer for LlmService {
general_capabilities
.position_encodings
.map(TryFrom::try_from)
}).unwrap_or(Ok(document::PositionEncodingKind::Utf16))?;
})
.unwrap_or(Ok(document::PositionEncodingKind::Utf16))?;

*self.position_encoding.write().await = position_encoding;

Expand Down
12 changes: 6 additions & 6 deletions crates/testbed/repositories-ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
context_window: 2000
fim:
enabled: true
prefix: <fim_prefix>
middle: <fim_middle>
suffix: <fim_suffix>
model: bigcode/starcoder
prefix: "<PRE> "
middle: " <MID>"
suffix: " <SUF>"
model: codellama/CodeLlama-13b-hf
backend: huggingface
request_body:
max_new_tokens: 150
Expand All @@ -14,8 +14,8 @@ request_body:
top_p: 0.95
tls_skip_verify_insecure: false
tokenizer_config:
repository: bigcode/starcoder
tokens_to_clear: ["<|endoftext|>"]
repository: codellama/CodeLlama-13b-hf
tokens_to_clear: ["<EOT>"]
repositories:
- source:
type: local
Expand Down
Loading