Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exclude system prompt from saving of chat history #240

Merged
merged 2 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions Runtime/LLMCharacter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -239,17 +239,17 @@ protected async Task LoadHistory()
}
}

protected string GetSavePath(string filename)
public string GetSavePath(string filename)
{
return Path.Combine(Application.persistentDataPath, filename).Replace('\\', '/');
}

protected string GetJsonSavePath(string filename)
public string GetJsonSavePath(string filename)
{
return GetSavePath(filename + ".json");
}

protected string GetCacheSavePath(string filename)
public string GetCacheSavePath(string filename)
{
// this is saved already in the Application.persistentDataPath folder
return GetSavePath(filename + ".cache");
Expand Down Expand Up @@ -648,7 +648,7 @@ public async Task<string> Save(string filename)
string filepath = GetJsonSavePath(filename);
string dirname = Path.GetDirectoryName(filepath);
if (!Directory.Exists(dirname)) Directory.CreateDirectory(dirname);
string json = JsonUtility.ToJson(new ChatListWrapper { chat = chat });
string json = JsonUtility.ToJson(new ChatListWrapper { chat = chat.GetRange(1, chat.Count - 1) });
File.WriteAllText(filepath, json);

string cachepath = GetCacheSavePath(filename);
Expand All @@ -671,7 +671,9 @@ public async Task<string> Load(string filename)
return null;
}
string json = File.ReadAllText(filepath);
chat = JsonUtility.FromJson<ChatListWrapper>(json).chat;
List<ChatMessage> chatHistory = JsonUtility.FromJson<ChatListWrapper>(json).chat;
InitPrompt(true);
chat.AddRange(chatHistory);
LLMUnitySetup.Log($"Loaded {filepath}");

string cachepath = GetCacheSavePath(filename);
Expand Down
51 changes: 47 additions & 4 deletions Tests/Runtime/TestLLM.cs
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,9 @@ public virtual async Task Tests()

public void TestInitParameters(int nkeep, int chats)
{
Assert.That(llmCharacter.nKeep == nkeep);
Assert.AreEqual(llmCharacter.nKeep, nkeep);
Assert.That(ChatTemplate.GetTemplate(llm.chatTemplate).GetStop(llmCharacter.playerName, llmCharacter.AIName).Length > 0);
Assert.That(llmCharacter.chat.Count == chats);
Assert.AreEqual(llmCharacter.chat.Count, chats);
}

public void TestTokens(List<int> tokens)
Expand Down Expand Up @@ -410,7 +410,7 @@ public override async Task Tests()
public class TestLLM_Double : TestLLM
{
LLM llm1;
LLMCharacter lLMCharacter1;
LLMCharacter llmCharacter1;

public override async Task Init()
{
Expand All @@ -421,8 +421,51 @@ public override async Task Init()
llm = CreateLLM();
llmCharacter = CreateLLMCharacter();
llm1 = CreateLLM();
lLMCharacter1 = CreateLLMCharacter();
llmCharacter1 = CreateLLMCharacter();
gameObject.SetActive(true);
}
}

public class TestLLMCharacter_Save : TestLLM
{
string saveName = "TestLLMCharacter_Save";

public override LLMCharacter CreateLLMCharacter()
{
LLMCharacter llmCharacter = base.CreateLLMCharacter();
llmCharacter.save = saveName;
llmCharacter.saveCache = true;
return llmCharacter;
}

public override async Task Tests()
{
await base.Tests();
TestSave();
}

public void TestSave()
{
string jsonPath = llmCharacter.GetJsonSavePath(saveName);
string cachePath = llmCharacter.GetCacheSavePath(saveName);
Assert.That(File.Exists(jsonPath));
Assert.That(File.Exists(cachePath));
string json = File.ReadAllText(jsonPath);
File.Delete(jsonPath);
File.Delete(cachePath);

List<ChatMessage> chatHistory = JsonUtility.FromJson<ChatListWrapper>(json).chat;
Assert.AreEqual(chatHistory.Count, 2);
Assert.AreEqual(chatHistory[0].role, llmCharacter.playerName);
Assert.AreEqual(chatHistory[0].content, "hi");
Assert.AreEqual(chatHistory[1].role, llmCharacter.AIName);

Assert.AreEqual(llmCharacter.chat.Count, chatHistory.Count + 1);
for (int i = 0; i < chatHistory.Count; i++)
{
Assert.AreEqual(chatHistory[i].role, llmCharacter.chat[i + 1].role);
Assert.AreEqual(chatHistory[i].content, llmCharacter.chat[i + 1].content);
}
}
}
}
Loading