Skip to content

Commit

Permalink
added create file function, and also made session history save on eve…
Browse files Browse the repository at this point in the history
…ry chat completion[
  • Loading branch information
cosmikwolf committed Jun 2, 2024
1 parent bccf2ef commit f823269
Show file tree
Hide file tree
Showing 10 changed files with 82 additions and 308,970 deletions.
17 changes: 17 additions & 0 deletions sazid-term/src/application.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,23 @@ impl Application {

Some(action) = self.session_events.next() => {
match action.clone() {
SessionAction::SaveSession => {
let data_folder = helix_loader::data_dir().join("session_history");
if !data_folder.exists() {
if let Err(e) = std::fs::create_dir_all(&data_folder) {
self.editor.set_error(format!("error creating data directory: {}", e));
}
}
let save_path = data_folder.join(self.session.config.title.clone()).with_extension("szd");
log::info!("saving session history to: {:#?}", save_path );
match self.session.save_session(save_path.clone()) {
Ok(_) => self.editor.set_status(format!("session saved to: {:?}", save_path)),
Err(e) => {
log::error!("error saving session: {}", e);
self.editor.set_error(format!("error saving session: {}", e));
},
};
},
SessionAction::ChatToolAction(event) => {
chat_tool_tx.send(event).unwrap();
},
Expand Down
2 changes: 1 addition & 1 deletion sazid-term/src/commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ fn save_session(cx: &mut Context) {
}

fn load_session_picker(cx: &mut Context) {
let root = helix_loader::data_dir().join("sessions");
let root = helix_loader::data_dir().join("session_history");
if !root.exists() {
cx.editor.set_error("data directory does not exist");
return;
Expand Down
2 changes: 1 addition & 1 deletion sazid-term/src/ui/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1427,7 +1427,7 @@ pub fn session_picker(root: PathBuf, config: &helix_view::editor::Config) -> Pic
.git_ignore(config.file_picker.git_ignore)
.git_global(config.file_picker.git_global)
.git_exclude(config.file_picker.git_exclude)
.sort_by_file_name(|name1, name2| name1.cmp(name2))
.sort_by_file_name(|name1, name2| name2.cmp(name1))
.max_depth(config.file_picker.max_depth)
.filter_entry(move |entry| filter_picker_entry(entry, &absolute_root, dedup_symlinks));

Expand Down
6 changes: 3 additions & 3 deletions sazid-term/tests/lsi_interface_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,9 @@ mod test {
match &app.get_session().test_tool_call_response {
Some((lsi_query, content)) => {
// read the contents of the file at workspace_path joined with file_path
let file = std::fs::read_to_string(workspace_path.join("src/main.rs"))?;
let debugprnt = format!("DEBUG:::: ----\n\n{:#?}", file);
println!("{}", debugprnt);
//let file = std::fs::read_to_string(workspace_path.join("src/main.rs"))?;
//let debugprnt = format!("DEBUG:::: ----\n\n{:#?}", file);
//println!("{}", debugprnt);
let symbol = serde_json::from_str::<Vec<SerializableSourceSymbol>>(content)
.expect("failed to parse symbol");
assert_eq!(query, *lsi_query);
Expand Down
308,910 changes: 0 additions & 308,910 deletions sazid-term/tests/test_assets/svd_to_csv/assets/stm32_h5_svd/STM32H563.svd

This file was deleted.

95 changes: 46 additions & 49 deletions sazid/src/app/model_tools/create_file_function.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use std::{
collections::HashMap,
fs::{self, File},
io::Write,
path::Path,
path::{Path, PathBuf},
pin::Pin,
};

Expand All @@ -11,53 +12,52 @@ use serde::{Deserialize, Serialize};
use super::{
errors::ToolCallError,
tool_call::{ToolCallParams, ToolCallTrait},
types::{FunctionProperty, PropertyType},
types::{get_validated_argument, validate_arguments, FunctionProperty},
};

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct CreateFileFunction {
name: String,
description: String,
properties: Vec<FunctionProperty>,
parameters: FunctionProperty,
}

impl ToolCallTrait for CreateFileFunction {
fn name(&self) -> &str {
&self.name
}
fn init() -> Self {
fn init() -> Self
where
Self: Sized,
{
CreateFileFunction {
name: "create_file".to_string(),
description: "create a file at path with text. this command cannot overwrite files"
.to_string(),
properties: vec![
FunctionProperty {
name: "path".to_string(),
required: true,
property_type: PropertyType::String,
description: Some("path to file".to_string()),
enum_values: None,
},
FunctionProperty {
name: "text".to_string(),
required: true,
property_type: PropertyType::String,
description: Some("text to write to file.".to_string()),
enum_values: None,
},
FunctionProperty {
name: "overwrite".to_string(),
required: true,
property_type: PropertyType::Boolean,
description: Some("overwrite an existing file. default false".to_string()),
enum_values: None,
},
],
parameters: FunctionProperty::Parameters {
properties: HashMap::from([
(
"path".to_string(),
FunctionProperty::String {
required: false,
description: Some("path to new file".to_string()),
},
),
(
"content".to_string(),
FunctionProperty::String {
required: false,
description: Some("content of the newly created file".to_string()),
},
),
]),
},
}
}

fn parameters(&self) -> Vec<FunctionProperty> {
self.properties.clone()
fn name(&self) -> &str {
&self.name
}

fn parameters(&self) -> FunctionProperty {
self.parameters.clone()
}

fn description(&self) -> String {
Expand All @@ -68,15 +68,15 @@ impl ToolCallTrait for CreateFileFunction {
&self,
params: ToolCallParams,
) -> Pin<Box<dyn Future<Output = Result<Option<String>, ToolCallError>> + Send + 'static>> {
Box::pin(async move {
let path: Option<&str> = params.function_args.get("path").and_then(|s| s.as_str());
let text: Option<&str> = params.function_args.get("text").and_then(|s| s.as_str());
let overwrite =
params.function_args.get("overwrite").and_then(|b| b.as_bool()).unwrap_or(false);
let validated_arguments = validate_arguments(params.function_args, &self.parameters, None)
.expect("error validating arguments");

let path = get_validated_argument::<PathBuf>(&validated_arguments, "path");
let text = get_validated_argument::<String>(&validated_arguments, "content");
Box::pin(async move {
if let Some(path) = path {
if let Some(text) = text {
create_file(path, text, overwrite)
create_file(&path, text.as_str(), false)
} else {
Err(ToolCallError::new("text argument is required"))
}
Expand All @@ -88,13 +88,10 @@ impl ToolCallTrait for CreateFileFunction {
}

pub fn create_file(
path: &str,
path: &PathBuf,
text: &str,
overwrite: bool,
) -> Result<Option<String>, ToolCallError> {
// Convert the string path to a `Path` object to manipulate file paths.
let path = Path::new(path);

// Attempt to get the parent directory of the path.
if let Some(parent_dir) = path.parent() {
// Try to create the parent directory (and all necessary parent directories).
Expand Down Expand Up @@ -141,7 +138,7 @@ mod tests {
let file_path = tmp_dir.path().join("test_file.txt");
let file_contents = "Test file contents.";

let result = create_file(file_path.to_str().unwrap(), file_contents, false);
let result = create_file(&file_path, file_contents, false);
assert!(result.is_ok());
check_file_contents(&file_path, file_contents);
}
Expand All @@ -154,7 +151,7 @@ mod tests {
let file_path = non_existent_subfolder.join("test_file.txt");
let file_contents = "Test file contents.";

let result = create_file(file_path.to_str().unwrap(), file_contents, false);
let result = create_file(&file_path, file_contents, false);
assert!(result.is_ok());
check_file_contents(&file_path, file_contents);
}
Expand All @@ -166,7 +163,7 @@ mod tests {
let file_path = tmp_dir.path().join("\0"); // Null byte is not allowed in file names.
let file_contents = "Test file contents.";

let result = create_file(file_path.to_str().unwrap(), file_contents, false);
let result = create_file(&file_path, file_contents, false);
assert!(result.is_ok());
assert!(result.unwrap().unwrap().contains("error"));
}
Expand All @@ -180,7 +177,7 @@ mod tests {
let file_path = Path::new(permissions_dir).join("test_file.txt");
let file_contents = "Test file contents.";

let result = create_file(file_path.to_str().unwrap(), file_contents, false);
let result = create_file(&file_path, file_contents, false);
assert!(result.is_ok());
assert!(result.unwrap().unwrap().contains("error"));
}
Expand All @@ -194,7 +191,7 @@ mod tests {
let file_path = Path::new(read_only_dir).join("test_file.txt");
let file_contents = "Test file contents.";

let result = create_file(file_path.to_str().unwrap(), file_contents, false);
let result = create_file(&file_path, file_contents, false);
assert!(result.is_ok());
assert!(result.unwrap().unwrap().contains("error"));
}
Expand All @@ -214,7 +211,7 @@ mod tests {
}

// Perform the operation to create the file again with different contents.
let result = create_file(file_path.to_str().unwrap(), new_contents, false);
let result = create_file(&file_path, new_contents, false);
assert!(result.is_ok());
check_file_contents(&file_path, new_contents);
}
Expand All @@ -226,7 +223,7 @@ mod tests {
let file_path = tmp_dir.path().join("large_test_file.txt");
let file_contents = "a".repeat(10_000_000); // 10 MB of 'a'.

let result = create_file(file_path.to_str().unwrap(), &file_contents, false);
let result = create_file(&file_path, &file_contents, false);
assert!(result.is_ok());
check_file_contents(&file_path, &file_contents);
}
Expand Down
1 change: 1 addition & 0 deletions sazid/src/app/model_tools/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
// pub mod read_file_lines_function;
// pub mod treesitter_function;

pub mod create_file_function;
pub mod lsp_get_diagnostics;
pub mod lsp_get_workspace_files;
pub mod lsp_goto_symbol_declaration;
Expand Down
4 changes: 3 additions & 1 deletion sazid/src/app/model_tools/tool_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use async_openai::types::{
ChatCompletionMessageToolCall, ChatCompletionRequestToolMessage, ChatCompletionTool,
ChatCompletionToolType, FunctionObject, Role,
};
use lsp_types::CreateFile;
use serde_json::Value;
use std::{any::Any, collections::HashMap, pin::Pin, sync::Arc};
use tokio::sync::mpsc::UnboundedSender;
Expand All @@ -15,6 +16,7 @@ use futures_util::Future;
use crate::app::session_config::SessionConfig;

use super::{
create_file_function::CreateFileFunction,
errors::ToolCallError,
lsp_get_diagnostics::LspGetDiagnostics,
lsp_get_workspace_files::LspGetWorkspaceFiles,
Expand Down Expand Up @@ -106,7 +108,7 @@ impl ChatTools {
// Arc::new(FileSearchFunction::init()),
Arc::new(LspGetWorkspaceFiles::init()),
Arc::new(LspQuerySymbol::init()),
//Arc::new(LspReadSymbolSource::init()),
Arc::new(CreateFileFunction::init()),
Arc::new(LspReplaceSymbolText::init()),
Arc::new(LspGotoSymbolDefinition::init()),
Arc::new(LspGotoSymbolDeclaration::init()),
Expand Down
13 changes: 8 additions & 5 deletions sazid/src/app/model_tools/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ pub fn validate_arguments(
(Some(value), FunctionProperty::Pattern { required: _, .. }) => {
if let Some(pattern_str) = value.as_str() {
match regex::Regex::new(pattern_str) {
Ok(regex) => {
Ok(_regex) => {
validated_args.insert(name.clone(), Value::String(pattern_str.to_string()));
},
Err(err) => {
Expand All @@ -267,10 +267,13 @@ pub fn validate_arguments(
let path = PathBuf::from(path_str);

// return an error if path is not within workspace
if let Some(workspace_dir) = workspace_root {
if !path.starts_with(workspace_dir) {
return Err("cannot read files outside of the current working directory".into());
}
match workspace_root {
Some(workspace_dir) => {
if !path.starts_with(workspace_dir) {
return Err("cannot read files outside of the current working directory".into());
}
},
None => return Err("cannot create files without a workspace set".into()),
}

if !path.is_absolute() {
Expand Down
2 changes: 2 additions & 0 deletions sazid/src/app/session_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub struct WorkspaceParams {
pub struct SessionConfig {
pub prompt: String,
pub id: String,
pub title: String,
pub session_dir: PathBuf,
pub disabled_tools: Vec<String>,
pub tools_enabled: bool,
Expand All @@ -39,6 +40,7 @@ impl Default for SessionConfig {
SessionConfig {
prompt: String::new(),
id: Self::generate_session_id(),
title: chrono::Utc::now().to_rfc3339(),
session_dir: PathBuf::new(),
disabled_tools: vec![],
workspace: None,
Expand Down

0 comments on commit f823269

Please sign in to comment.