Skip to content

Commit

Permalink
working on building tests to validate ChatTransaction
Browse files Browse the repository at this point in the history
  • Loading branch information
cosmikwolf committed Oct 18, 2023
1 parent 425bbe9 commit 44da7a5
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 72 deletions.
9 changes: 8 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,14 @@ async-openai = "0.14.3"
async-recursion = "1.0.5"
backoff = { version = "0.4.0", features = ["tokio"] }
better-panic = "0.3.0"
clap = { version = "4.4.5", features = ["derive", "cargo", "wrap_help", "unicode", "string", "unstable-styles"] }
clap = { version = "4.4.5", features = [
"derive",
"cargo",
"wrap_help",
"unicode",
"string",
"unstable-styles",
] }
color-eyre = "0.6.2"
config = "0.13.3"
console-subscriber = "0.2.0"
Expand Down
96 changes: 51 additions & 45 deletions src/app/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ impl From<ChatTransaction> for Option<CreateChatCompletionStreamResponse> {
pub enum ChatMessage {
Request(ChatCompletionRequestMessage),
Response(ChatChoice),
StreamResponse(Vec<ChatCompletionResponseStreamMessage>),
StreamResponse(ChatCompletionResponseStreamMessage),
}

#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
Expand Down Expand Up @@ -179,6 +179,7 @@ impl From<FunctionCallStream> for RenderedFunctionCall {
RenderedFunctionCall { name: function_call.name, arguments: function_call.arguments }
}
}

impl From<ChatTransaction> for Vec<RenderedChatMessage> {
fn from(transaction: ChatTransaction) -> Self {
match transaction {
Expand All @@ -196,15 +197,52 @@ impl From<ChatTransaction> for Vec<RenderedChatMessage> {
rendered_response
})
.collect(),
ChatTransaction::StreamResponse(response_stream) => {
let mut rendered_response =
RenderedChatMessage::from(ChatMessage::StreamResponse(response_stream.choices.to_owned()));
rendered_response.id = Some(response_stream.id.clone());
vec![rendered_response]
},
ChatTransaction::StreamResponse(response_stream) => response_stream
.choices
.iter()
.map(|choice| {
// let mut content = String::new();
// let mut function_call: Option<RenderedFunctionCall> = None;
// let mut finish_reason = None;
// let mut role = None;
// content = content + response_stream.delta.content.clone().unwrap_or_default().as_str();
// if let Some(function_call_stream) = response_stream.delta.function_call.clone() {
// function_call = Some(RenderedFunctionCall {
// name: if function_call_stream.name.is_some() {
// Some(format!(
// "{}{}",
// function_call.clone().unwrap().name.unwrap(),
// function_call_stream.name.as_ref().unwrap_or(&String::new())
// ))
// } else {
// None
// },
// arguments: if function_call_stream.arguments.is_some() {
// Some(format!(
// "{}{}",
// function_call.unwrap().arguments.unwrap(),
// function_call_stream.arguments.as_ref().unwrap_or(&String::new())
// ))
// } else {
// None
// },
// })
// }
// if response_stream.finish_reason.is_some() {
// finish_reason = response_stream.finish_reason.clone();
// }
// if response_stream.delta.role.is_some() {
// role = response_stream.delta.role.clone();
// }
let mut rendered_response = RenderedChatMessage::from(ChatMessage::StreamResponse(choice.to_owned()));
rendered_response.id = Some(response_stream.id.clone());
rendered_response
})
.collect(),
}
}
}

impl From<ChatMessage> for RenderedChatMessage {
fn from(message: ChatMessage) -> Self {
match message {
Expand All @@ -222,44 +260,12 @@ impl From<ChatMessage> for RenderedChatMessage {
function_call: response.message.function_call.map(|function_call| function_call.into()),
finish_reason: response.finish_reason,
},
ChatMessage::StreamResponse(response_streams) => {
let mut content = String::new();
let mut function_call: Option<RenderedFunctionCall> = None;
let mut finish_reason = None;
let mut role = None;

for response_stream in response_streams.iter() {
content = format!("{}{}", content, response_stream.delta.content.clone().unwrap_or_default());
if let Some(function_call_stream) = response_stream.delta.function_call.clone() {
function_call = Some(RenderedFunctionCall {
name: if function_call_stream.name.is_some() {
Some(format!(
"{}{}",
function_call.clone().unwrap().name.unwrap(),
function_call_stream.name.as_ref().unwrap_or(&String::new())
))
} else {
None
},
arguments: if function_call_stream.arguments.is_some() {
Some(format!(
"{}{}",
function_call.unwrap().arguments.unwrap(),
function_call_stream.arguments.as_ref().unwrap_or(&String::new())
))
} else {
None
},
})
}
if response_stream.finish_reason.is_some() {
finish_reason = response_stream.finish_reason.clone();
}
if response_stream.delta.role.is_some() {
role = response_stream.delta.role.clone();
}
}
RenderedChatMessage { id: None, role, content, function_call, finish_reason }
ChatMessage::StreamResponse(response_streams) => RenderedChatMessage {
id: None,
role: response_streams.delta.role,
content: response_streams.delta.content.unwrap_or_default(),
function_call: response_streams.delta.function_call.map(|function_call| function_call.into()),
finish_reason: response_streams.finish_reason,
},
}
}
Expand Down
77 changes: 51 additions & 26 deletions src/components/session.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use async_openai::error::OpenAIError;
use async_openai::types::{
ChatChoice, ChatCompletionRequestMessage, ChatCompletionResponseMessage, CreateChatCompletionRequest,
CreateChatCompletionResponse, CreateEmbeddingRequestArgs, CreateEmbeddingResponse, Role,
ChatChoice, ChatCompletionRequestMessage, ChatCompletionResponseMessage, ChatCompletionResponseStreamMessage,
CreateChatCompletionRequest, CreateChatCompletionResponse, CreateChatCompletionStreamResponse,
CreateEmbeddingRequestArgs, CreateEmbeddingResponse, Role,
};
use color_eyre::eyre::Result;
use crossterm::event::{KeyCode, KeyEvent, MouseEvent};
use futures::StreamExt;
use ratatui::layout::Rect;
use ratatui::{prelude::*, symbols::scrollbar, widgets::*};
use ratatui::{prelude::*, symbols::scrollbar, widgets::block::*, widgets::*};
use serde_derive::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use std::time::{SystemTime, UNIX_EPOCH};
Expand All @@ -34,6 +35,9 @@ use crate::app::gpt_interface::{create_chat_completion_function_args, define_com
use crate::app::tools::utils::ensure_directory_exists;
use crate::components::home::Mode;

use std::fs::File;
use std::io::prelude::*;

#[derive(Serialize, Deserialize, Debug, Clone)]

pub struct SessionConfig {
Expand Down Expand Up @@ -137,16 +141,20 @@ impl Component for Session {
.direction(Direction::Horizontal)
.constraints(vec![Constraint::Length(1), Constraint::Min(10), Constraint::Length(1)])
.split(rects[0]);

// a function that will return a vec with an aribitrary numbr of the same item

let textbox = Layout::default()
.direction(Direction::Vertical)
.constraints(vec![Constraint::Length(1), Constraint::Min(1)])
.constraints(vec![Constraint::Min(2), Constraint::Length(3)])
.split(shorter[1]);

let get_style_from_role = |role| match role {
Role::User => Style::default().fg(Color::Yellow),
Role::Assistant => Style::default().fg(Color::Green),
Role::System => Style::default().fg(Color::Blue),
Role::Function => Style::default().fg(Color::Red),
Some(Role::User) => Style::default().fg(Color::Yellow),
Some(Role::Assistant) => Style::default().fg(Color::Green),
Some(Role::System) => Style::default().fg(Color::Blue),
Some(Role::Function) => Style::default().fg(Color::Red),
None => Style::default(),
};

let title = "Chat";
Expand All @@ -155,28 +163,37 @@ impl Component for Session {
.borders(Borders::ALL)
.gray()
.title(Span::styled(title, Style::default().add_modifier(Modifier::BOLD)));

for transaction in &self.transactions {
let messages: Vec<RenderedChatMessage> = transaction.clone().into();
for (i, message) in messages.iter().enumerate() {
let style = get_style_from_role(message.role.clone().unwrap());
let paragraph =
Paragraph::new(message.content.clone()).style(style).block(block.clone()).wrap(Wrap { trim: true });
f.render_widget(paragraph, textbox[1]);
let mut text = Vec::new();
for (index, transaction) in self.transactions.iter().enumerate() {
let mut style = Style::default().fg(Color::White);
let mut content = String::new() + "test";
let messages = <Vec<RenderedChatMessage>>::from(transaction.clone());
for message in messages.iter() {
style = get_style_from_role(message.role.clone());
content = content + message.content.clone().as_str();
}
text.push(Line::styled(content, style));
}

f.render_stateful_widget(
Scrollbar::default()
.orientation(ScrollbarOrientation::VerticalRight)
.begin_symbol(Some("↑"))
.end_symbol(Some("↓")),
textbox[1],
&mut self.vertical_scroll_state,
);
let block = Block::default()
.borders(Borders::TOP)
.gray()
.title(Title::from("left").alignment(Alignment::Left))
.title(Title::from("right").alignment(Alignment::Right));
let paragraph = Paragraph::new(text).block(block).wrap(Wrap { trim: true });
f.render_widget(paragraph, textbox[0]);

// f.render_stateful_widget(
// Scrollbar::default()
// .orientation(ScrollbarOrientation::VerticalRight)
// .begin_symbol(Some("↑"))
// .end_symbol(Some("↓")),
// textbox[1],
// &mut self.vertical_scroll_state,
// );
Ok(())
}
}

impl Session {
pub fn new() -> Session {
Self::default()
Expand All @@ -194,9 +211,11 @@ impl Session {
match stream_response {
true => {
let mut stream = client.chat().create_stream(request).await.unwrap();
let mut file = File::create("saved_response.txt").unwrap();
while let Some(response_result) = stream.next().await {
match response_result {
Ok(response) => {
let _ = file.write_all(serde_json::to_string(&response).unwrap().as_bytes());
tx.send(Action::ProcessResponse(Box::new(ChatTransaction::StreamResponse(response)))).unwrap()
},
Err(e) => {
Expand All @@ -220,7 +239,13 @@ impl Session {

pub fn process_response_handler(&mut self, transaction: ChatTransaction) {
let tx = self.action_tx.clone().unwrap();
self.transactions.push(transaction);
if let ChatTransaction::StreamResponse(mut t) = transaction {
if let Some(ChatTransaction::StreamResponse(r)) = self.transactions.last_mut() {
r.choices.append(&mut t.choices);
}
} else {
self.transactions.push(transaction);
}
tx.send(Action::Update).unwrap();
}

Expand Down
1 change: 1 addition & 0 deletions tests/assets/saved_stream_response.json

Large diffs are not rendered by default.

25 changes: 25 additions & 0 deletions tests/stream_response_parsing.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
extern crate sazid;
pub mod app;

#[cfg(test)]
mod tests {
use std::fs::File;
use std::io::Read;
use std::path::PathBuf;
// a test that reads in the file tests/assets/saved_stream_response.json and parses it
// asserting that it is a ChatTransaction::StreamResponse
#[test]
fn test_stream_response_parsing() {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("tests/assets/saved_stream_response.json");
let mut file = File::open(path).unwrap();
let mut contents = String::new();
let transaction: ChatTransaction;
file.read_to_string(&mut contents).unwrap();
let parsed = serde_json::from_str::<ChatTransaction>(&contents).unwrap();
match parsed {
crate::ChatTransaction::StreamResponse { .. } => {},
_ => panic!("Parsed transaction was not a StreamResponse"),
}
}
}

0 comments on commit 44da7a5

Please sign in to comment.