From 8f6ff98640a3dd808f49a221b19e43ca9ff3be33 Mon Sep 17 00:00:00 2001 From: tenkai Date: Tue, 21 May 2024 00:19:36 -0700 Subject: [PATCH] fixed various bugs --- Readme.md | 2 +- languages.toml | 3 +- sazid-term/src/application.rs | 3 +- sazid-term/src/args.rs | 13 +- sazid-term/src/commands/llm.rs | 11 +- sazid-term/src/config.rs | 1 - sazid-term/src/ui/markdown_renderer.rs | 2 +- sazid/src/action.rs | 1 + sazid/src/app/consts.rs | 5 + sazid/src/app/lsi/interface.rs | 6 +- sazid/src/app/lsi/mod.rs | 139 ++++++++++++++---- sazid/src/app/lsi/symbol_types.rs | 9 +- sazid/src/app/lsi/tool_impl.rs | 31 +++- sazid/src/app/lsi/workspace.rs | 17 +-- .../app/model_tools/lsp_read_symbol_source.rs | 88 +++++++++++ sazid/src/app/model_tools/mod.rs | 1 + sazid/src/app/model_tools/tool_call.rs | 2 + sazid/src/app/session_config.rs | 2 +- 18 files changed, 279 insertions(+), 57 deletions(-) create mode 100644 sazid/src/app/model_tools/lsp_read_symbol_source.rs diff --git a/Readme.md b/Readme.md index 6eddf1d..d81d90f 100644 --- a/Readme.md +++ b/Readme.md @@ -5,7 +5,7 @@ [![CI](https://github.com/cosmikwolf/sazid/workflows/CI/badge.svg)](https://github.com/cosmikwolf/sazid/actions) **Currently a work in progress** ---- embeddings are not hooked up at the moment --- +--- embeddings are currently broken--- Sazid is an interactive LLM interface written in Rust that provides: diff --git a/languages.toml b/languages.toml index bf4d1bf..31de13b 100644 --- a/languages.toml +++ b/languages.toml @@ -185,8 +185,9 @@ command = [ [language-server.rust-analyzer] +# command = "ra-multiplex" command = "rust-analyzer" -args = ["-v", "--log-file", "/tmp/rust-analyzer.log"] +# args = ["-v", "--log-file", "/tmp/rust-analyzer.log"] [language-server.rust-analyzer.config] inlayHints.bindingModeHints.enable = false diff --git a/sazid-term/src/application.rs b/sazid-term/src/application.rs index 8fbc584..7cbb17c 100644 --- a/sazid-term/src/application.rs +++ b/sazid-term/src/application.rs @@ -181,9 +181,10 @@ impl Application { session_config.workspace = Some(WorkspaceParams { workspace_path, language, - language_server: "rust-analyzer".to_string(), + language_server: args.language_server.unwrap_or("rust-analyzer".to_string()), doc_path: None, }); + log::debug!("workspace: {:#?}", session_config.workspace); }, (None, None) => {}, (None, Some(_)) => { diff --git a/sazid-term/src/args.rs b/sazid-term/src/args.rs index 0f43ac0..999a112 100644 --- a/sazid-term/src/args.rs +++ b/sazid-term/src/args.rs @@ -19,6 +19,7 @@ pub struct Args { pub files: Vec<(PathBuf, Position)>, pub workspace: Option, pub language: Option, + pub language_server: Option, } impl Args { @@ -72,11 +73,15 @@ impl Args { None => anyhow::bail!("--log must specify a path to write"), } }, - "-l" | "--language" => match argv.next().as_deref() { - Some(language) => { + "-l" | "--language" => { + if let Some(language) = argv.next().as_deref() { args.language = Some(language.into()); - }, - None => {}, + } + }, + "-ls" | "--language-server" => { + if let Some(language_server) = argv.next().as_deref() { + args.language_server = Some(language_server.into()); + } }, "-w" | "--workspace" => match argv.next().as_deref() { Some(path) => { diff --git a/sazid-term/src/commands/llm.rs b/sazid-term/src/commands/llm.rs index 8d71ccd..8ad677c 100644 --- a/sazid-term/src/commands/llm.rs +++ b/sazid-term/src/commands/llm.rs @@ -1,4 +1,4 @@ -use std::{sync::Arc}; +use std::sync::Arc; use super::Context; use crate::{ @@ -93,9 +93,9 @@ impl ChatMessageItem { if self.plaintext_wrapped_width == width { self.plain_text.len_lines() } else { - log::error!( - "need to update wrapping before trying to get wrapped height, or else it is not up to date" - ); + //log::error!( + // "need to update wrapping before trying to get wrapped height, or else it is not up to date" + //); self.plain_text.len_lines() } } @@ -212,7 +212,8 @@ impl ChatMessageItem { let content = if let ChatMessageType::Chat(ChatCompletionRequestMessage::Tool(_)) = &self.chat_message { if self.content().lines().count() > 1 { - "tool call response content" + //"tool call response content" + self.content() } else { self.content() } diff --git a/sazid-term/src/config.rs b/sazid-term/src/config.rs index e3e6f03..c8b256f 100644 --- a/sazid-term/src/config.rs +++ b/sazid-term/src/config.rs @@ -44,7 +44,6 @@ pub enum ConfigLoadError { BadConfig(TomlError), Error(IOError), } - impl Default for ConfigLoadError { fn default() -> Self { ConfigLoadError::Error(IOError::new(std::io::ErrorKind::NotFound, "place holder")) diff --git a/sazid-term/src/ui/markdown_renderer.rs b/sazid-term/src/ui/markdown_renderer.rs index 4161a56..86c33fc 100644 --- a/sazid-term/src/ui/markdown_renderer.rs +++ b/sazid-term/src/ui/markdown_renderer.rs @@ -285,7 +285,7 @@ impl MarkdownRenderer { }, // TaskListMarker(bool) true if checked _ => { - log::warn!("unhandled markdown event {:?}", event); + //log::warn!("unhandled markdown event {:?}", event); }, } // build up a vec of Paragraph tui widgets diff --git a/sazid/src/action.rs b/sazid/src/action.rs index 72c326d..5d237c0 100644 --- a/sazid/src/action.rs +++ b/sazid/src/action.rs @@ -64,6 +64,7 @@ pub enum LsiAction { QueryWorkspaceSymbols(LsiQuery), GetWorkspaceFiles(LsiQuery), ReplaceSymbolText(String, LsiQuery), + ReadSymbolSource(LsiQuery), GoToSymbolDefinition(LsiQuery), GoToSymbolDeclaration(LsiQuery), GoToTypeDefinition(LsiQuery), diff --git a/sazid/src/app/consts.rs b/sazid/src/app/consts.rs index dc7cf9b..c625372 100644 --- a/sazid/src/app/consts.rs +++ b/sazid/src/app/consts.rs @@ -10,6 +10,11 @@ pub const INGESTED_DIR: &str = ".local/share/sazid/data/ingested"; lazy_static! { // model constants + pub static ref GPT4_O: Model = Model { + name: "gpt-4o".to_string(), + endpoint: "https://api.openai.com/v1/completions".to_string(), + token_limit: 16384, + }; pub static ref GPT4_TURBO: Model = Model { name: "gpt-4-turbo-preview".to_string(), endpoint: "https://api.openai.com/v1/completions".to_string(), diff --git a/sazid/src/app/lsi/interface.rs b/sazid/src/app/lsi/interface.rs index 5a74c15..84a845a 100644 --- a/sazid/src/app/lsi/interface.rs +++ b/sazid/src/app/lsi/interface.rs @@ -49,6 +49,10 @@ impl LanguageServerInterface { log::error!("{}", error); Ok(None) }, + LsiAction::ReadSymbolSource(lsi_query) => { + let lsi_query_result = self.lsi_read_symbol_source(&lsi_query); + Self::handle_lsi_query_result(lsi_query, lsi_query_result) + }, LsiAction::ReplaceSymbolText(replacement_text, lsi_query) => { let lsi_query_result = self.lsi_replace_symbol_text(replacement_text, &lsi_query); Self::handle_lsi_query_result(lsi_query, lsi_query_result) @@ -514,7 +518,7 @@ impl LanguageServerInterface { enable_snippets, ) .find(|(name, _client)| name == languge_server_name) - .unwrap() + .expect("language server not found") .1 .map_err(|e| anyhow::anyhow!(e))?; Ok(Some(client)) diff --git a/sazid/src/app/lsi/mod.rs b/sazid/src/app/lsi/mod.rs index 3c27a91..3b251ed 100644 --- a/sazid/src/app/lsi/mod.rs +++ b/sazid/src/app/lsi/mod.rs @@ -20,26 +20,12 @@ fn position_gt(pos1: lsp::Position, pos2: lsp::Position) -> bool { } fn get_file_range_contents(file_path: &Path, range: lsp::Range) -> anyhow::Result { - let source_code = std::fs::read_to_string(file_path)?; - if range.start == range.end { - return Ok(String::new()); - } - let source_code = source_code - .lines() - .skip(range.start.line as usize) - .take((range.end.line - range.start.line) as usize + 1) - .enumerate() - .map(|(i, line)| { - if i == 0 { - line.chars().skip(range.start.character as usize).collect() - } else if i == (range.end.line - range.start.line) as usize { - line.chars().take(range.end.character as usize).collect() - } else { - line.to_string() - } - }) - .collect::>() - .join("\n"); + let rope = Rope::from_reader(std::fs::File::open(file_path)?)?; + + let start_char = rope.line_to_char(range.start.line as usize) + range.start.character as usize; + let end_char = rope.line_to_char(range.end.line as usize) + range.end.character as usize; + + let source_code = rope.slice(start_char..end_char).to_string(); Ok(source_code) } @@ -50,12 +36,15 @@ pub fn replace_file_range_contents( ) -> anyhow::Result { let mut rope = Rope::from_reader(std::fs::File::open(file_path)?)?; + println!("rope: {}-", rope); let start_char = rope.line_to_char(range.start.line as usize) + range.start.character as usize; let end_char = rope.line_to_char(range.end.line as usize) + range.end.character as usize; - rope.remove(start_char..end_char); + let end_rope = rope.split_off(end_char); + println!("end_rope: {}-", end_rope); + rope.remove(start_char..); rope.insert(start_char, &contents); - + rope.append(end_rope); let new_contents = rope.to_string(); std::fs::write(file_path, &new_contents)?; @@ -65,17 +54,113 @@ pub fn replace_file_range_contents( #[cfg(test)] mod tests { use super::*; + use lsp::Range; + use std::fs::read_to_string; use std::fs::File; use std::io::Write; use tempfile::tempdir; + #[test] + fn test_get_file_range_contents_standard_case() -> anyhow::Result<()> { + let tmp_dir = tempdir().unwrap(); + let file_path = tmp_dir.path().join("example.txt"); + + let mut file = File::create(&file_path)?; + write!(file, "line 1\nline 2\nline 3\nline 4")?; + + let range = Range { + start: lsp_types::Position { line: 1, character: 3 }, + end: lsp_types::Position { line: 2, character: 4 }, + }; + + let content = get_file_range_contents(&file_path, range)?; + assert_eq!(content, "e 2\nline"); + + Ok(()) + } + + #[test] + fn test_get_file_range_contents_empty_range() -> anyhow::Result<()> { + let tmp_dir = tempdir().unwrap(); + let file_path = tmp_dir.path().join("example.txt"); + + let mut file = File::create(&file_path)?; + write!(file, "line 1\nline 2\nline 3\nline 4")?; + + let range = Range { + start: lsp_types::Position { line: 1, character: 3 }, + end: lsp_types::Position { line: 1, character: 3 }, + }; + + let content = get_file_range_contents(&file_path, range)?; + assert_eq!(content, ""); + + Ok(()) + } + + #[test] + fn test_get_file_range_contents_whole_file() -> anyhow::Result<()> { + let tmp_dir = tempdir().unwrap(); + let file_path = tmp_dir.path().join("example.txt"); + + let mut file = File::create(&file_path)?; + write!(file, "line 1\nline 2\nline 3\nline 4")?; + + let range = Range { + start: lsp_types::Position { line: 0, character: 0 }, + end: lsp_types::Position { line: 3, character: 6 }, + }; + + let content = get_file_range_contents(&file_path, range)?; + assert_eq!(content, "line 1\nline 2\nline 3\nline 4"); + + Ok(()) + } + + #[test] + fn test_get_file_range_contents_single_line() -> anyhow::Result<()> { + let tmp_dir = tempdir().unwrap(); + let file_path = tmp_dir.path().join("example.txt"); + + let mut file = File::create(&file_path)?; + write!(file, "line 1\nline 2\nline 3\nline 4")?; + + let range = Range { + start: lsp_types::Position { line: 1, character: 2 }, + end: lsp_types::Position { line: 1, character: 5 }, + }; + + let content = get_file_range_contents(&file_path, range)?; + assert_eq!(content, "ne 2"); + + Ok(()) + } + + #[test] + fn test_get_file_range_contents_with_special_characters() -> anyhow::Result<()> { + let tmp_dir = tempdir().unwrap(); + let file_path = tmp_dir.path().join("example.txt"); + + let mut file = File::create(&file_path)?; + write!(file, "line 1\nlïne 2\nline 3\nlįne 4")?; + + let range = Range { + start: lsp_types::Position { line: 1, character: 1 }, + end: lsp_types::Position { line: 3, character: 3 }, + }; + + let content = get_file_range_contents(&file_path, range)?; + assert_eq!(content, "ïne 2\nline 3\nlįn"); + + Ok(()) + } #[test] fn test_replace_file_range_contents() { // Create a temporary directory and file for testing let temp_dir = tempdir().unwrap(); let file_path = temp_dir.path().join("test.txt"); let mut file = File::create(&file_path).unwrap(); - writeln!(file, "line 1\nline 2\nline 3\nline 4\nline 5").unwrap(); + write!(file, "line 1\nline 2\nline 3\nline 4\nline 5").unwrap(); // Test replacing content within multiple lines let range = lsp::Range { @@ -84,7 +169,7 @@ mod tests { }; let contents = "new content".to_string(); let result = replace_file_range_contents(&file_path, range, contents.clone()).unwrap(); - let expected_result = "line 1\nlinew content\nline 5".to_string(); + let expected_result = "line 1\nlinew content3\nline 4\nline 5".to_string(); assert_eq!(result, expected_result); // Check the contents of the file @@ -98,7 +183,7 @@ mod tests { }; let contents = "new".to_string(); let result = replace_file_range_contents(&file_path, range, contents).unwrap(); - let expected_result = "linew 1\nline 2\nline 3\nline 4\nline 5".to_string(); + let expected_result = "linew1\nlinew content3\nline 4\nline 5".to_string(); assert_eq!(result, expected_result); // Test replacing content from the beginning of the file to the middle of a line @@ -108,13 +193,13 @@ mod tests { }; let contents = "start".to_string(); let result = replace_file_range_contents(&file_path, range, contents).unwrap(); - let expected_result = "starte 2\nline 3\nline 4\nline 5".to_string(); + let expected_result = "startew content3\nline 4\nline 5".to_string(); assert_eq!(result, expected_result); // Test replacing the entire content of the file let range = lsp::Range { start: lsp::Position { line: 0, character: 0 }, - end: lsp::Position { line: 4, character: 6 }, + end: lsp::Position { line: 2, character: 6 }, }; let contents = "new file content".to_string(); let result = replace_file_range_contents(&file_path, range, contents).unwrap(); diff --git a/sazid/src/app/lsi/symbol_types.rs b/sazid/src/app/lsi/symbol_types.rs index e290ad3..4bf8c95 100644 --- a/sazid/src/app/lsi/symbol_types.rs +++ b/sazid/src/app/lsi/symbol_types.rs @@ -36,7 +36,7 @@ pub struct SerializableSourceSymbol { pub kind: lsp::SymbolKind, pub tags: Option>, pub range: lsp::Range, - pub selection_range: lsp::Range, + //pub selection_range: lsp::Range, pub workspace_path: PathBuf, pub file_path: PathBuf, pub hash: [u8; 32], @@ -51,7 +51,7 @@ impl From> for SerializableSourceSymbol { kind: symbol.kind, tags: symbol.tags.clone(), range: *symbol.range.lock().unwrap(), - selection_range: *symbol.selection_range.lock().unwrap(), + //selection_range: *symbol.selection_range.lock().unwrap(), workspace_path: symbol.workspace_path.clone(), file_path: symbol.file_path.clone(), hash: symbol.symbol_id, @@ -100,7 +100,10 @@ impl SourceSymbol { tags: doc_sym.tags.clone(), range: Arc::new(Mutex::new(doc_sym.range)), selection_range: Arc::new(Mutex::new(doc_sym.selection_range)), - file_path: file_path.to_path_buf(), + file_path: file_path + .strip_prefix(workspace_path) + .expect("file is not in workspace directory") + .to_path_buf(), parent: Arc::new(Mutex::new(Weak::new())), children: Arc::new(Mutex::new(vec![])), workspace_path: workspace_path.to_path_buf(), diff --git a/sazid/src/app/lsi/tool_impl.rs b/sazid/src/app/lsi/tool_impl.rs index dfa69dd..82e2191 100644 --- a/sazid/src/app/lsi/tool_impl.rs +++ b/sazid/src/app/lsi/tool_impl.rs @@ -156,17 +156,44 @@ impl LanguageServerInterface { .files .iter() .filter(|file| pattern.is_match(&file.file_path.display().to_string())) - .map(|file| file.file_path.clone()) + .map(|file| { + file + .file_path + .strip_prefix(workspace.workspace_path.clone()) + .expect("file path is not in workspace") + }) .collect::>(); Ok(json!(files).to_string()) }, None => { - let files = workspace.files.iter().map(|file| file.file_path.clone()).collect::>(); + let files = workspace + .files + .iter() + .map(|file| { + file + .file_path + .strip_prefix(workspace.workspace_path.clone()) + .expect("file path is not in workspace") + }) + .collect::>(); Ok(json!(files).to_string()) }, } } + pub fn lsi_read_symbol_source(&mut self, lsi_query: &LsiQuery) -> anyhow::Result { + match self.get_workspace(lsi_query)?.query_symbols(lsi_query) { + Ok(symbols) => match symbols.len() { + 0 => Ok("no symbols found".to_string()), + _ => { + let symbol = symbols.first().unwrap(); + symbol.get_source() + }, + }, + Err(e) => Err(anyhow::anyhow!("error querying workspace symbols: {}", e)), + } + } + pub fn lsi_replace_symbol_text( &mut self, replacement_text: String, diff --git a/sazid/src/app/lsi/workspace.rs b/sazid/src/app/lsi/workspace.rs index 5b2eea0..84950fc 100644 --- a/sazid/src/app/lsi/workspace.rs +++ b/sazid/src/app/lsi/workspace.rs @@ -2,8 +2,8 @@ use super::query::LsiQuery; use super::symbol_types::SourceSymbol; use super::workspace_file::WorkspaceFile; use helix_core::syntax::{FileType, LanguageConfiguration}; -use lsp_types::{DocumentSymbol, TextDocumentIdentifier}; use helix_lsp::Client; +use lsp_types::{DocumentSymbol, TextDocumentIdentifier}; use std::path::{Path, PathBuf}; use std::sync::{Arc, Weak}; @@ -98,12 +98,12 @@ impl Workspace { if let Some(regex) = &query.file_path_regex { let regex = regex::Regex::new(regex).unwrap(); - if !self - .files - .iter() - .any(|f| regex.is_match(f.file_path.file_name().unwrap().to_str().unwrap())) - { - return Err(anyhow::anyhow!("no files match the provided regex")); + if !self.files.iter().any(|f| { + let file_path = f.file_path.to_str().unwrap(); + log::warn!("\nfile_path: {:?}\nregex: {:?}", file_path, regex); + regex.is_match(file_path) + }) { + return Err(anyhow::anyhow!("no files match the provided regex\nregex: {:?}", regex)); } } @@ -115,8 +115,7 @@ impl Workspace { .filter(|s| { if let Some(file_name) = &query.file_path_regex { s.file_path.file_name().unwrap().to_str().unwrap() == file_name - || &s.file_path.strip_prefix(&self.workspace_path).unwrap().display().to_string() - == file_name + || &s.file_path.display().to_string() == file_name } else { true } diff --git a/sazid/src/app/model_tools/lsp_read_symbol_source.rs b/sazid/src/app/model_tools/lsp_read_symbol_source.rs new file mode 100644 index 0000000..035ebb4 --- /dev/null +++ b/sazid/src/app/model_tools/lsp_read_symbol_source.rs @@ -0,0 +1,88 @@ +use futures_util::Future; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::pin::Pin; + +use crate::action::{ChatToolAction, LsiAction}; +use crate::app::lsi::query::LsiQuery; + +use super::errors::ToolCallError; +use super::tool_call::{ToolCallParams, ToolCallTrait}; +use super::types::*; + +#[derive(Serialize, Deserialize)] +pub struct LspReadSymbolSource { + pub name: String, + pub description: String, + pub parameters: FunctionProperty, +} + +impl ToolCallTrait for LspReadSymbolSource { + fn init() -> Self + where + Self: Sized, + { + LspReadSymbolSource { + name: "lsp_read_symbol_source".to_string(), + description: "read the source code represented by symbol_id".to_string(), + parameters: FunctionProperty::Parameters { + properties: HashMap::from([( + "symbol_id".to_string(), + FunctionProperty::Array { + required: true, + description: Some("the 32 byte symbol_id to read".to_string()), + items: Box::new(FunctionProperty::Integer { + description: None, + required: true, + minimum: Some(0), + maximum: Some(255), + }), + min_items: Some(32), + max_items: Some(32), + }, + )]), + }, + } + } + + fn name(&self) -> &str { + &self.name + } + fn parameters(&self) -> FunctionProperty { + self.parameters.clone() + } + + fn description(&self) -> String { + self.description.clone() + } + + fn call( + &self, + params: ToolCallParams, + ) -> Pin, ToolCallError>> + Send + 'static>> { + let validated_arguments = validate_arguments(params.function_args, &self.parameters, None) + .expect("error validating arguments"); + + let symbol_id = get_validated_argument(&validated_arguments, "symbol_id"); + + let workspace_root = + params.session_config.workspace.expect("workspace not set").workspace_path.clone(); + + Box::pin(async move { + let query = LsiQuery { + symbol_id, + workspace_root, + + tool_call_id: params.tool_call_id, + session_id: params.session_id, + ..Default::default() + }; + + params + .tx + .send(ChatToolAction::LsiRequest(Box::new(LsiAction::ReadSymbolSource(query)))) + .unwrap(); + Ok(None) + }) + } +} diff --git a/sazid/src/app/model_tools/mod.rs b/sazid/src/app/model_tools/mod.rs index 4ffcd2d..6f0f0a9 100644 --- a/sazid/src/app/model_tools/mod.rs +++ b/sazid/src/app/model_tools/mod.rs @@ -12,6 +12,7 @@ pub mod lsp_goto_symbol_declaration; pub mod lsp_goto_symbol_definition; pub mod lsp_goto_type_definition; pub mod lsp_query_symbols; +pub mod lsp_read_symbol_source; pub mod lsp_replace_symbol_text; pub mod argument_validation; diff --git a/sazid/src/app/model_tools/tool_call.rs b/sazid/src/app/model_tools/tool_call.rs index 10a1a1d..ceb0a59 100644 --- a/sazid/src/app/model_tools/tool_call.rs +++ b/sazid/src/app/model_tools/tool_call.rs @@ -22,6 +22,7 @@ use super::{ lsp_goto_symbol_definition::LspGotoSymbolDefinition, lsp_goto_type_definition::LspGotoTypeDefinition, lsp_query_symbols::LspQuerySymbol, + lsp_read_symbol_source::LspReadSymbolSource, lsp_replace_symbol_text::LspReplaceSymbolText, types::{FunctionProperty, ToolCall}, }; @@ -105,6 +106,7 @@ impl ChatTools { // Arc::new(FileSearchFunction::init()), Arc::new(LspGetWorkspaceFiles::init()), Arc::new(LspQuerySymbol::init()), + Arc::new(LspReadSymbolSource::init()), Arc::new(LspReplaceSymbolText::init()), Arc::new(LspGotoSymbolDefinition::init()), Arc::new(LspGotoSymbolDeclaration::init()), diff --git a/sazid/src/app/session_config.rs b/sazid/src/app/session_config.rs index 1129bee..74cc7d9 100644 --- a/sazid/src/app/session_config.rs +++ b/sazid/src/app/session_config.rs @@ -44,7 +44,7 @@ impl Default for SessionConfig { workspace: None, tools_enabled: true, accessible_paths: vec![], - model: GPT4_TURBO.clone(), + model: GPT4_O.clone(), retrieval_augmentation_message_count: Some(10), user: "sazid_user_1234".to_string(), function_result_max_tokens: 8192,