Skip to content

Commit

Permalink
Only warn of rate-limits when using HF endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
HennerM committed Jan 20, 2024
1 parent 585ea3a commit 2ea196d
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions crates/llm-ls/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ mod language_id;
const MAX_WARNING_REPEAT: Duration = Duration::from_secs(3_600);
pub const NAME: &str = "llm-ls";
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
const HUGGINGFACE_INFERENCE_HOSTNAME: &str = "api-inference.huggingface.co";

fn get_position_idx(rope: &Rope, row: usize, col: usize) -> Result<usize> {
Ok(rope.try_line_to_char(row).map_err(internal_error)?
Expand Down Expand Up @@ -247,15 +248,15 @@ where
}

/// Serializable record tying a `request_id` to one accepted completion
/// (`accepted_completion`) and to the set of completions that were shown
/// (`shown_completions`).
///
/// NOTE(review): the `u32` values are presumably indices into the list of
/// completions returned for the request — confirm against the caller.
///
/// NOTE(review): `rename_all = "camelCase"` appears both active and commented
/// out directly below. With it, serde maps the fields to `requestId`,
/// `acceptedCompletion`, `shownCompletions`; without it, the snake_case field
/// names are used on the wire. Confirm which format clients send.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
// #[serde(rename_all = "camelCase")]
struct AcceptedCompletion {
request_id: Uuid,
accepted_completion: u32,
shown_completions: Vec<u32>,
}

#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
// #[serde(rename_all = "camelCase")]
struct RejectedCompletion {
request_id: Uuid,
shown_completions: Vec<u32>,
Expand Down Expand Up @@ -584,10 +585,14 @@ fn build_url(model: &str) -> String {
if model.starts_with("http://") || model.starts_with("https://") {
model.to_owned()
} else {
format!("https://api-inference.huggingface.co/models/{model}")
format!("https://{HUGGINGFACE_INFERENCE_HOSTNAME}/models/{model}")
}
}

/// Reports whether `model` resolves to the Hugging Face Inference API.
///
/// `model` is first resolved to a URL with `build_url`, and that URL is then
/// searched for `HUGGINGFACE_INFERENCE_HOSTNAME`.
///
/// NOTE(review): this is a substring search over the whole URL, not a
/// comparison of the host component, so a custom endpoint whose path or query
/// string mentions the hostname is also reported as a Hugging Face model.
fn is_hf_model(model: &str) -> bool {
    build_url(model).contains(HUGGINGFACE_INFERENCE_HOSTNAME)
}

impl Backend {
async fn get_completions(&self, params: CompletionParams) -> Result<CompletionResult> {
let request_id = Uuid::new_v4();
Expand All @@ -613,7 +618,7 @@ impl Backend {
"received completion request for {}",
params.text_document_position.text_document.uri
);
if params.api_token.is_none() {
if params.api_token.is_none() && is_hf_model(&params.model) {
let now = Instant::now();
let unauthenticated_warn_at = self.unauthenticated_warn_at.read().await;
if now.duration_since(*unauthenticated_warn_at) > MAX_WARNING_REPEAT {
Expand Down Expand Up @@ -734,8 +739,11 @@ impl LanguageServer for Backend {

async fn did_change(&self, params: DidChangeTextDocumentParams) {
let uri = params.text_document.uri.to_string();
if params.content_changes.is_empty() {
return;
}
self.client
.log_message(MessageType::INFO, format!("{uri} changed"))
.log_message(MessageType::LOG, format!("{uri} changed"))
.await;
let mut document_map = self.document_map.write().await;
let doc = document_map.get_mut(&uri);
Expand Down

0 comments on commit 2ea196d

Please sign in to comment.