Skip to content

Commit

Permalink
Merge pull request #33 from undreamai/feature/llamafile_v0.6
Browse files Browse the repository at this point in the history
AMD support, switch to llamafile 0.6
  • Loading branch information
amakropoulos authored Jan 17, 2024
2 parents 8c830ca + c5f3892 commit def2237
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 286 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,5 @@ LLMUnity.csproj.meta
hooks.meta
hooks/pre-commit.meta
setup.sh.meta
*.api
*.api.meta
2 changes: 1 addition & 1 deletion Editor/LLMEditor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public override void OnInspectorGUI()
AddServerSettings(llmScriptSO);
GUI.enabled = LLM.binariesProgress == 1 && llmScript.modelProgress == 1 && llmScript.modelCopyProgress == 1;
AddModelLoaders(llmScriptSO, llmScript);
ShowProgress(LLM.binariesProgress, "Binaries Downloading");
ShowProgress(LLM.binariesProgress, "Setup Binaries");
ShowProgress(llmScript.modelProgress, "Model Downloading");
ShowProgress(llmScript.modelCopyProgress, "Model Copying");
if (llmScript.model != "")
Expand Down
28 changes: 0 additions & 28 deletions Editor/undream.llmunity.Editor.api

This file was deleted.

7 changes: 0 additions & 7 deletions Editor/undream.llmunity.Editor.api.meta

This file was deleted.

79 changes: 55 additions & 24 deletions Runtime/LLM.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
Expand Down Expand Up @@ -25,8 +26,8 @@ public class LLM : LLMClient
[ModelAdvanced] public int batchSize = 512;

[HideInInspector] public string modelUrl = "https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF/resolve/main/mistral-7b-instruct-v0.2.Q4_K_M.gguf?download=true";
private static readonly string serverUrl = "https://github.com/Mozilla-Ocho/llamafile/releases/download/0.4.1/llamafile-server-0.4.1";
private static readonly string server = GetAssetPath("llamafile-server.exe");
private static readonly string serverZipUrl = "https://github.com/Mozilla-Ocho/llamafile/releases/download/0.6/llamafile-0.6.zip";
private static readonly string server = Path.Combine(GetAssetPath(Path.GetFileNameWithoutExtension(serverZipUrl)), "bin/llamafile");
private static readonly string apeARMUrl = "https://cosmo.zip/pub/cosmos/bin/ape-arm64.elf";
private static readonly string apeARM = GetAssetPath("ape-arm64.elf");
private static readonly string apeX86_64Url = "https://cosmo.zip/pub/cosmos/bin/ape-x86_64.elf";
Expand All @@ -38,7 +39,7 @@ public class LLM : LLMClient
private static float binariesDone = 0;
private Process process;
private bool serverListening = false;
public ManualResetEvent serverStarted = new ManualResetEvent(false);
private ManualResetEvent serverBlock = new ManualResetEvent(false);

private static string GetAssetPath(string relPath = "")
{
Expand All @@ -51,25 +52,34 @@ private static string GetAssetPath(string relPath = "")
private static async Task InitializeOnLoad()
{
// Perform download when the build is finished
await DownloadBinaries();
await SetupBinaries();
}

private static async Task DownloadBinaries()
private static async Task SetupBinaries()
{
if (binariesProgress == 0) return;
binariesProgress = 0;
binariesDone = 0;
foreach ((string url, string path) in new[] {(serverUrl, server), (apeARMUrl, apeARM), (apeX86_64Url, apeX86_64)})
if (!File.Exists(apeARM)) await LLMUnitySetup.DownloadFile(apeARMUrl, apeARM, true, null, SetBinariesProgress);
binariesDone += 1;
if (!File.Exists(apeX86_64)) await LLMUnitySetup.DownloadFile(apeX86_64Url, apeX86_64, true, null, SetBinariesProgress);
binariesDone += 1;
if (!File.Exists(server))
{
if (!File.Exists(path)) await LLMUnitySetup.DownloadFile(url, path, true, null, SetBinariesProgress);
string serverZip = Path.Combine(Application.dataPath, "llamafile.zip");
if (!File.Exists(serverZip)) await LLMUnitySetup.DownloadFile(serverZipUrl, serverZip, false, null, SetBinariesProgress);
binariesDone += 1;
LLMUnitySetup.ExtractZip(serverZip, GetAssetPath());
LLMUnitySetup.makeExecutable(server);
File.Delete(serverZip);
binariesDone += 1;
}
binariesProgress = 1;
}

public static void SetBinariesProgress(float progress)
{
binariesProgress = binariesDone / 3f + 1f / 3f * progress;
binariesProgress = binariesDone / 4f + 1f / 4f * progress;
}

public void DownloadModel()
Expand Down Expand Up @@ -157,13 +167,38 @@ private void CheckIfListening(string message)
if (status.message == "HTTP server listening")
{
Debug.Log("LLM Server started!");
serverStarted.Set();
serverListening = true;
serverBlock.Set();
}
}
catch { }
}

// Exit handler handed to LLMUnitySetup.CreateProcess for the server process.
// Signals serverBlock so that a thread waiting on it is released when the
// process exits instead of waiting out the full timeout.
private void ProcessExited(object sender, EventArgs e) => serverBlock.Set();

// Launch the server executable `exe` with the command line `args` and store
// the resulting process in `process`.
// On Windows the executable is started directly; on every other platform it
// is passed as the first argument to the APE loader chosen by SelectApeBinary().
private void RunServerCommand(string exe, string args)
{
    RuntimePlatform platform = Application.platform;
    bool onWindows = platform == RuntimePlatform.WindowsEditor || platform == RuntimePlatform.WindowsPlayer;

    string launcher = exe;
    string launcherArgs = args;
    List<(string, string)> envOverrides = null;

    if (!onWindows)
    {
        // use APE binary directly if not on Windows
        launcherArgs = $"\"{exe}\" {args}";
        launcher = SelectApeBinary();

        // prevent nvcc building if not using GPU
        if (numGPULayers <= 0)
        {
            envOverrides = new List<(string, string)>
            {
                ("PATH", ""),
                ("CUDA_PATH", ""),
            };
        }
    }

    Debug.Log($"Server command: {launcher} {launcherArgs}");
    // stdout lines go to CheckIfListening, stderr lines to DebugLogError,
    // and process exit to ProcessExited
    process = LLMUnitySetup.CreateProcess(
        launcher, launcherArgs,
        CheckIfListening, DebugLogError, ProcessExited,
        envOverrides
    );
}

private void StartLLMServer()
{
// Start the LLM server in a cross-platform way
Expand All @@ -182,25 +217,21 @@ private void StartLLMServer()
string binary = server;
string arguments = $" --port {port} -m \"{modelPath}\" -c {contextSize} -b {batchSize} --log-disable --nobrowser -np {slots}";
if (numThreads > 0) arguments += $" -t {numThreads}";
if (numGPULayers > 0) arguments += $" -ngl {numGPULayers}";
if (loraPath != "") arguments += $" --lora \"{loraPath}\"";
List<(string, string)> environment = null;

if (Application.platform != RuntimePlatform.WindowsEditor && Application.platform != RuntimePlatform.WindowsPlayer)
string GPUArgument = numGPULayers <= 0 ? "" : $" -ngl {numGPULayers}";
RunServerCommand(binary, arguments + GPUArgument);
serverBlock.WaitOne(60000);

if (process.HasExited && numGPULayers > 0)
{
// use APE binary directly if not on Windows
arguments = $"\"{binary}\" {arguments}";
binary = SelectApeBinary();
if (numGPULayers <= 0)
{
// prevent nvcc building if not using GPU
environment = new List<(string, string)> { ("PATH", ""), ("CUDA_PATH", "") };
}
Debug.Log("GPU failed, fallback to CPU");
serverBlock.Reset();
RunServerCommand(binary, arguments);
serverBlock.WaitOne(60000);
}
Debug.Log($"Server command: {binary} {arguments}");
process = LLMUnitySetup.CreateProcess(binary, arguments, CheckIfListening, DebugLogError, environment);
// wait for at most 1 minute
serverStarted.WaitOne(60000);

if (process.HasExited) throw new System.Exception("Server could not be started!");
}

public void StopProcess()
Expand Down
41 changes: 32 additions & 9 deletions Runtime/LLMUnitySetup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
using Debug = UnityEngine.Debug;
using System.Threading.Tasks;
using System.Collections.Generic;
using System;
using System.IO.Compression;

namespace LLMUnity
{
Expand All @@ -19,7 +19,7 @@ public class LLMUnitySetup : MonoBehaviour
{
public static Process CreateProcess(
string command, string commandArgs = "",
Callback<string> outputCallback = null, Callback<string> errorCallback = null,
Callback<string> outputCallback = null, Callback<string> errorCallback = null, System.EventHandler exitCallback = null,
List<(string, string)> environment = null,
bool redirectOutput = false, bool redirectError = false
)
Expand All @@ -44,6 +44,11 @@ public static Process CreateProcess(
Process process = new Process { StartInfo = startInfo };
if (outputCallback != null) process.OutputDataReceived += (sender, e) => outputCallback(e.Data);
if (errorCallback != null) process.ErrorDataReceived += (sender, e) => errorCallback(e.Data);
if (exitCallback != null)
{
process.EnableRaisingEvents = true;
process.Exited += exitCallback;
}
process.Start();
if (outputCallback != null) process.BeginOutputReadLine();
if (errorCallback != null) process.BeginErrorReadLine();
Expand All @@ -53,13 +58,22 @@ public static Process CreateProcess(
public static string RunProcess(string command, string commandArgs = "", Callback<string> outputCallback = null, Callback<string> errorCallback = null)
{
// run a process and return the output
Process process = CreateProcess(command, commandArgs, null, null, null, true);
Process process = CreateProcess(command, commandArgs, null, null, null, null, true);
string output = process.StandardOutput.ReadToEnd();
process.WaitForExit();
return output;
}

#if UNITY_EDITOR
// Mark the file at `path` as executable.
// Does nothing when running in the Windows editor or Windows player;
// on all other platforms it runs `chmod +x` on the path.
public static void makeExecutable(string path)
{
    RuntimePlatform platform = Application.platform;
    bool onWindows = platform == RuntimePlatform.WindowsEditor || platform == RuntimePlatform.WindowsPlayer;
    if (onWindows) return;

    // macOS/Linux: Set executable permissions using chmod
    RunProcess("chmod", $"+x \"{path}\"");
}

public static async Task DownloadFile(
string fileUrl, string savePath, bool executable = false,
TaskCallback<string> callback = null, Callback<float> progresscallback = null,
Expand Down Expand Up @@ -117,11 +131,7 @@ public static async Task DownloadFile(
}
}

if (executable && Application.platform != RuntimePlatform.WindowsEditor && Application.platform != RuntimePlatform.WindowsPlayer)
{
// macOS/Linux: Set executable permissions using chmod
RunProcess("chmod", $"+x \"{savePath}\"");
}
if (executable) makeExecutable(savePath);
AssetDatabase.StopAssetEditing();
Debug.Log($"Download complete!");
}
Expand All @@ -143,7 +153,7 @@ public static async Task<string> AddAsset(string assetPath, string basePath)
{
// if the asset is not in the assets dir copy it over
fullPath = Path.Combine(basePath, Path.GetFileName(assetPath));
Debug.Log("copying " + assetPath + " to " + fullPath);
Debug.Log($"copying {assetPath} to {fullPath}");
AssetDatabase.StartAssetEditing();
await Task.Run(() =>
{
Expand All @@ -160,6 +170,19 @@ await Task.Run(() =>
return fullPath.Substring(basePath.Length + 1);
}

// Extract the zip archive at `zipPath` into the directory `extractToPath`,
// creating the directory if needed.
// The extraction is wrapped in AssetDatabase.StartAssetEditing/StopAssetEditing.
// StopAssetEditing runs in a finally block so that a failed extraction
// (e.g. corrupt archive, file already present) does not leave the
// AssetDatabase stuck in batch-editing mode; the exception still propagates
// to the caller.
public static void ExtractZip(string zipPath, string extractToPath)
{
    Debug.Log($"extracting {zipPath} to {extractToPath}");
    AssetDatabase.StartAssetEditing();
    try
    {
        // CreateDirectory is a no-op when the directory already exists
        Directory.CreateDirectory(extractToPath);
        ZipFile.ExtractToDirectory(zipPath, extractToPath);
    }
    finally
    {
        AssetDatabase.StopAssetEditing();
    }
    Debug.Log($"extraction complete!");
}

#endif
}
}
Loading

0 comments on commit def2237

Please sign in to comment.