Allow to set lora weights at startup, add unit test #219

Merged
merged 12 commits on Aug 26, 2024
4 changes: 2 additions & 2 deletions Editor/LLMEditor.cs
@@ -206,7 +206,7 @@ async Task createButtons()
}
else if (modelIndex > 1)
{
if (modelLicenses[modelIndex] != null) Debug.LogWarning($"The {modelOptions[modelIndex]} model is released under the following license: {modelLicenses[modelIndex]}. By using this model, you agree to the terms of the license.");
if (modelLicenses[modelIndex] != null) LLMUnitySetup.LogWarning($"The {modelOptions[modelIndex]} model is released under the following license: {modelLicenses[modelIndex]}. By using this model, you agree to the terms of the license.");
string filename = await LLMManager.DownloadModel(modelURLs[modelIndex], true, modelOptions[modelIndex]);
SetModelIfNone(filename, false);
UpdateModels(true);
@@ -300,7 +300,7 @@ void OnEnable()
}
else
{
isSelected = llmScript.lora.Split(" ").Contains(entry.filename);
isSelected = llmScript.loraManager.Contains(entry.filename);
bool newSelected = EditorGUI.Toggle(selectRect, isSelected);
if (newSelected && !isSelected) llmScript.AddLora(entry.filename);
else if (!newSelected && isSelected) llmScript.RemoveLora(entry.filename);
6 changes: 4 additions & 2 deletions README.md
@@ -46,8 +46,10 @@ LLM for Unity is built on top of the awesome [llama.cpp](https://github.com/gger

## How to help
- [⭐ Star](https://github.com/undreamai/LLMUnity) the repo, leave us a [review](https://assetstore.unity.com/packages/slug/273604) and spread the word about the project!
- Join us at [Discord](https://discord.gg/RwXKQb6zdv) and say hi!
- [Contribute](CONTRIBUTING.md) by submitting feature requests or bugs as issues or even submiting a PR and become a collaborator!
- Join us at [Discord](https://discord.gg/RwXKQb6zdv) and say hi.
- [Contribute](CONTRIBUTING.md) by submitting feature requests, bugs or even your own PR.
- [![](https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86)](https://github.com/sponsors/amakropoulos) this work to allow even cooler features!


## Games using LLM for Unity
- [Verbal Verdict](https://store.steampowered.com/app/2778780/Verbal_Verdict/)
193 changes: 112 additions & 81 deletions Runtime/LLM.cs
@@ -3,28 +3,13 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using UnityEditor;
using UnityEngine;

namespace LLMUnity
{
/// \cond HIDE
public class LLMException : Exception
{
public int ErrorCode { get; private set; }

public LLMException(string message, int errorCode) : base(message)
{
ErrorCode = errorCode;
}
}

public class DestroyException : Exception {}
/// \endcond

[DefaultExecutionOrder(-1)]
/// @ingroup llm
/// <summary>
@@ -74,6 +59,8 @@ public class LLM : MonoBehaviour
/// <summary> the paths of the LORA models being used (relative to the Assets/StreamingAssets folder).
/// Models with .gguf format are allowed.</summary>
[ModelAdvanced] public string lora = "";
/// <summary> the weights of the LORA models being used.</summary>
[ModelAdvanced] public string loraWeights = "";
/// <summary> enable use of flash attention </summary>
[ModelExtras] public bool flashAttention = false;

@@ -86,8 +73,10 @@
Thread llmThread = null;
List<StreamWrapper> streamWrappers = new List<StreamWrapper>();
public LLMManager llmManager = new LLMManager();
List<float> loraWeights = new List<float>();
private readonly object startLock = new object();
public LoraManager loraManager = new LoraManager();
string loraPre = "";
string loraWeightsPre = "";

/// \endcond

Expand All @@ -96,6 +85,15 @@ public LLM()
LLMManager.Register(this);
}

void OnValidate()
{
if (lora != loraPre || loraWeights != loraWeightsPre)
{
loraManager.FromStrings(lora, loraWeights);
(loraPre, loraWeightsPre) = (lora, loraWeights);
}
}

/// <summary>
/// The Unity Awake function that starts the LLM server.
/// The server can be started asynchronously if the asynchronousStartup option is set.
@@ -136,35 +134,55 @@ public static async Task<bool> WaitUntilModelSetup(Callback<float> downloadProgr
return !modelSetupFailed;
}

public string GetModelLoraPathRuntime(string path)
public static string GetLLMManagerAsset(string path)
{
string assetPath = LLMManager.GetAssetPath(path);
if (!string.IsNullOrEmpty(assetPath)) return assetPath;
return path;
#if UNITY_EDITOR
if (!EditorApplication.isPlaying) return GetLLMManagerAssetEditor(path);
#endif
return GetLLMManagerAssetRuntime(path);
}

public string GetModelLoraPath(string path, bool lora)
public static string GetLLMManagerAssetEditor(string path)
{
// empty
if (string.IsNullOrEmpty(path)) return path;
// LLMManager - return location the file will be stored in StreamingAssets
ModelEntry modelEntry = LLMManager.Get(path);
if (modelEntry != null) return modelEntry.filename;

string modelType = lora ? "Lora" : "Model";
string assetPath = LLMUnitySetup.GetAssetPath(path);
// StreamingAssets - return relative location within StreamingAssets
string assetPath = LLMUnitySetup.GetAssetPath(path); // Note: this will return the full path if a full path is passed
string basePath = LLMUnitySetup.GetAssetPath();
if (File.Exists(assetPath))
{
if (LLMUnitySetup.IsSubPath(assetPath, basePath)) return LLMUnitySetup.RelativePath(assetPath, basePath);
}
// full path
if (!File.Exists(assetPath))
{
LLMUnitySetup.LogError($"The {modelType} file {path} was not found.");
return path;
LLMUnitySetup.LogError($"Model {path} was not found.");
}

if (!LLMUnitySetup.IsSubPath(assetPath, LLMUnitySetup.GetAssetPath()))
else
{
string errorMessage = $"The {modelType} file {path} was loaded locally. If you want to include it in the build:";
errorMessage += $"\n-Copy the {modelType} inside the StreamingAssets folder and use its relative path or";
errorMessage += $"\n-Load the {modelType} with the LLMManager: `string filename=LLMManager.Load{modelType}(path); llm.Set{modelType}(filename)`";
string errorMessage = $"The model {path} was loaded locally. You can include it in the build in one of these ways:";
errorMessage += $"\n-Copy the model inside the StreamingAssets folder and use its StreamingAssets path";
errorMessage += $"\n-Load the model with the model manager inside the LLM GameObject and use its filename";
LLMUnitySetup.LogWarning(errorMessage);
}
return assetPath;
return path;
}

public static string GetLLMManagerAssetRuntime(string path)
{
// empty
if (string.IsNullOrEmpty(path)) return path;
// LLMManager
string managerPath = LLMManager.GetAssetPath(path);
if (!string.IsNullOrEmpty(managerPath) && File.Exists(managerPath)) return managerPath;
// StreamingAssets
string assetPath = LLMUnitySetup.GetAssetPath(path);
if (File.Exists(assetPath)) return assetPath;
// give up
return path;
}

/// <summary>
@@ -175,11 +193,11 @@ public string GetModelLoraPath(string path, bool lora)
/// <param name="path">path to model to use (.gguf format)</param>
public void SetModel(string path)
{
model = GetModelLoraPath(path, false);
model = GetLLMManagerAsset(path);
if (!string.IsNullOrEmpty(model))
{
ModelEntry modelEntry = LLMManager.Get(model);
if (modelEntry == null) modelEntry = new ModelEntry(GetModelLoraPathRuntime(model));
if (modelEntry == null) modelEntry = new ModelEntry(GetLLMManagerAssetRuntime(model));
SetTemplate(modelEntry.chatTemplate);
if (contextSize == 0 && modelEntry.contextLength > 32768)
{
@@ -197,10 +215,11 @@ public void SetModel(string path)
/// Models supported are in .gguf format.
/// </summary>
/// <param name="path">path to LORA model to use (.gguf format)</param>
public void SetLora(string path)
public void SetLora(string path, float weight = 1)
{
lora = "";
AddLora(path);
AssertNotStarted();
loraManager.Clear();
AddLora(path, weight);
}

/// <summary>
Expand All @@ -209,15 +228,11 @@ public void SetLora(string path)
/// Models supported are in .gguf format.
/// </summary>
/// <param name="path">path to LORA model to use (.gguf format)</param>
public void AddLora(string path)
public void AddLora(string path, float weight = 1)
{
string loraPath = GetModelLoraPath(path, true);
if (lora.Split(" ").Contains(loraPath)) return;
if (lora != "") lora += " ";
lora += loraPath;
#if UNITY_EDITOR
if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this);
#endif
AssertNotStarted();
loraManager.Add(path, weight);
UpdateLoras();
}

/// <summary>
Expand All @@ -227,15 +242,37 @@ public void AddLora(string path)
/// <param name="path">path to LORA model to remove (.gguf format)</param>
public void RemoveLora(string path)
{
string loraPath = GetModelLoraPath(path, true);
List<string> loras = new List<string>(lora.Split(" "));
loras.Remove(loraPath);
lora = "";
for (int i = 0; i < loras.Count; i++)
{
if (i > 0) lora += " ";
lora += loras[i];
}
AssertNotStarted();
loraManager.Remove(path);
UpdateLoras();
}

/// <summary>
/// Allows to remove all LORA models from the LLM.
/// </summary>
public void RemoveLoras()
{
AssertNotStarted();
loraManager.Clear();
UpdateLoras();
}

/// <summary>
/// Allows to change the scale (weight) of a LORA model in the LLM.
/// </summary>
/// <param name="path">path of LORA model to change (.gguf format)</param>
/// <param name="scale">scale of LORA</param>
public void SetLoraScale(string path, float scale)
{
loraManager.SetWeight(path, scale);
UpdateLoras();
if (started) ApplyLoras();
}

public void UpdateLoras()
{
(lora, loraWeights) = loraManager.ToStrings();
(loraPre, loraWeightsPre) = (lora, loraWeights);
#if UNITY_EDITOR
if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this);
#endif
@@ -271,7 +308,7 @@ protected virtual string GetLlamaccpArguments()
LLMUnitySetup.LogError("No model file provided!");
return null;
}
string modelPath = GetModelLoraPathRuntime(model);
string modelPath = GetLLMManagerAssetRuntime(model);
if (!File.Exists(modelPath))
{
LLMUnitySetup.LogError($"File {modelPath} not found!");
@@ -281,7 +318,7 @@
foreach (string lora in lora.Trim().Split(" "))
{
if (lora == "") continue;
string loraPath = GetModelLoraPathRuntime(lora);
string loraPath = GetLLMManagerAssetRuntime(lora);
if (!File.Exists(loraPath))
{
LLMUnitySetup.LogError($"File {loraPath} not found!");
@@ -387,8 +424,7 @@ private void StartService()
llmThread = new Thread(() => llmlib.LLM_Start(LLMObject));
llmThread.Start();
while (!llmlib.LLM_Started(LLMObject)) {}
loraWeights = new List<float>();
for (int i = 0; i < lora.Split(" ").Count(); i++) loraWeights.Add(1f);
ApplyLoras();
started = true;
}

@@ -447,6 +483,16 @@ void AssertStarted()
}
}

void AssertNotStarted()
{
if (started)
{
string error = "This method can't be called when the LLM has started";
LLMUnitySetup.LogError(error);
throw new Exception(error);
}
}

void CheckLLMStatus(bool log = true)
{
if (llmlib == null) { return; }
@@ -537,46 +583,31 @@ public async Task<string> Embeddings(string json)
/// Sets the lora scale, only works after the LLM service has started
/// </summary>
/// <returns>switch result</returns>
public async Task<string> SetLoraScale(string loraToScale, float scale)
public void ApplyLoras()
{
AssertStarted();
List<string> loras = new List<string>(lora.Split(" "));
string loraToScalePath = GetModelLoraPath(loraToScale, true);

int index = loras.IndexOf(loraToScale);
if (index == -1) index = loras.IndexOf(loraToScalePath);
if (index == -1)
{
LLMUnitySetup.LogError($"LoRA {loraToScale} not loaded with the LLM");
return "";
}

loraWeights[index] = scale;
LoraWeightRequestList loraWeightRequest = new LoraWeightRequestList();
loraWeightRequest.loraWeights = new List<LoraWeightRequest>();
for (int i = 0; i < loraWeights.Count; i++)
float[] weights = loraManager.GetWeights();
for (int i = 0; i < weights.Length; i++)
{
loraWeightRequest.loraWeights.Add(new LoraWeightRequest() { id = i, scale = loraWeights[i] });
loraWeightRequest.loraWeights.Add(new LoraWeightRequest() { id = i, scale = weights[i] });
}
;

string json = JsonUtility.ToJson(loraWeightRequest);
int startIndex = json.IndexOf("[");
int endIndex = json.LastIndexOf("]") + 1;
json = json.Substring(startIndex, endIndex - startIndex);

LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) =>
{
llmlib.LLM_Lora_Weight(LLMObject, jsonData, strWrapper);
};
return await LLMReply(callback, json);
IntPtr stringWrapper = llmlib.StringWrapper_Construct();
llmlib.LLM_Lora_Weight(LLMObject, json, stringWrapper);
llmlib.StringWrapper_Delete(stringWrapper);
}

/// <summary>
/// Gets a list of the lora adapters
/// </summary>
/// <returns>list of lara adapters</returns>
public async Task<string> ListLora()
public async Task<string> ListLoras()
{
AssertStarted();
LLMNoInputReplyCallback callback = (IntPtr LLMObject, IntPtr strWrapper) =>
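For orientation, here is a minimal usage sketch of the LoRA API that this diff introduces (SetLora/AddLora with a weight argument, SetLoraScale, RemoveLoras). It is not part of the PR: the helper class and the adapter file names are made up for illustration, and it assumes the calls run before the LLM service has started, since SetLora/AddLora/RemoveLora now assert exactly that.

```csharp
// Hypothetical helper, not part of this PR: shows the calling pattern of the new API.
// Assumes "llm" is an LLM component whose service has not started yet, and that
// "style.gguf" / "facts.gguf" are adapter files already registered with the LLMManager
// or placed under StreamingAssets (both names are invented for this example).
using LLMUnity;

public static class LoraUsageSketch
{
    public static void Configure(LLM llm)
    {
        // Adapters must be set up before startup; these calls throw once the service runs.
        llm.SetLora("style.gguf", 0.8f);   // clears existing adapters, adds one with weight 0.8
        llm.AddLora("facts.gguf", 0.5f);   // adds a second adapter with its own weight
    }

    public static void Rebalance(LLM llm)
    {
        // Weights can be changed at any time; once the LLM has started, the change is
        // pushed to llama.cpp through ApplyLoras.
        llm.SetLoraScale("facts.gguf", 1.0f);
    }
}
```

The inspector fields `lora` and `loraWeights` mirror the same state as space-separated strings, which is what OnValidate and UpdateLoras keep in sync.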
6 changes: 3 additions & 3 deletions Runtime/LLMManager.cs
@@ -24,7 +24,7 @@ public ModelEntry(string path, bool lora = false, string label = null, string ur
filename = Path.GetFileName(path);
this.label = label == null ? filename : label;
this.lora = lora;
this.path = Path.GetFullPath(path).Replace('\\', '/');
this.path = LLMUnitySetup.GetFullPath(path);
this.url = url;
includeInBuild = true;
chatTemplate = null;
@@ -162,7 +162,7 @@ public static void SetTemplate(ModelEntry entry, string chatTemplate)
public static ModelEntry Get(string path)
{
string filename = Path.GetFileName(path);
string fullPath = Path.GetFullPath(path).Replace('\\', '/');
string fullPath = LLMUnitySetup.GetFullPath(path);
foreach (ModelEntry entry in modelEntries)
{
if (entry.filename == filename || entry.path == fullPath) return entry;
@@ -387,7 +387,7 @@ public static void Remove(ModelEntry entry)
foreach (LLM llm in llms)
{
if (!entry.lora && llm.model == entry.filename) llm.model = "";
else if (entry.lora && llm.lora == entry.filename) llm.lora = "";
else if (entry.lora) llm.RemoveLora(entry.filename);
}
}

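The LoraManager type referenced throughout the diff is added elsewhere in this PR and is not shown in this excerpt. As a rough mental model only, the following simplified sketch implements the subset of its interface used above (Add/Remove/Clear, Contains, SetWeight, GetWeights, and the FromStrings/ToStrings round-trip that backs the `lora` and `loraWeights` fields). The real class also resolves and normalises paths, so treat this as an assumption, not the PR's implementation.

```csharp
// Simplified, hypothetical stand-in for the LoraManager used in Runtime/LLM.cs above.
// Inferred from the calls visible in this diff; the actual class in the PR may differ.
using System.Collections.Generic;
using System.Linq;

public class LoraManager
{
    readonly List<string> paths = new List<string>();
    readonly List<float> weights = new List<float>();

    public void Clear() { paths.Clear(); weights.Clear(); }

    public bool Contains(string path) => paths.Contains(path);

    public void Add(string path, float weight = 1f)
    {
        if (Contains(path)) return;      // ignore duplicates
        paths.Add(path);
        weights.Add(weight);
    }

    public void Remove(string path)
    {
        int index = paths.IndexOf(path);
        if (index < 0) return;
        paths.RemoveAt(index);
        weights.RemoveAt(index);
    }

    public void SetWeight(string path, float weight)
    {
        int index = paths.IndexOf(path);
        if (index >= 0) weights[index] = weight;
    }

    public float[] GetWeights() => weights.ToArray();

    // Space-separated serialisation matching the lora / loraWeights inspector fields.
    public (string, string) ToStrings() => (string.Join(" ", paths), string.Join(" ", weights));

    public void FromStrings(string loras, string loraWeights)
    {
        Clear();
        string[] pathTokens = (loras ?? "").Split(' ').Where(s => s != "").ToArray();
        string[] weightTokens = (loraWeights ?? "").Split(' ').Where(s => s != "").ToArray();
        for (int i = 0; i < pathTokens.Length; i++)
        {
            float weight = i < weightTokens.Length ? float.Parse(weightTokens[i]) : 1f;
            Add(pathTokens[i], weight);
        }
    }
}
```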