Skip to content

Commit

Permalink
Agents (#92)
Browse files Browse the repository at this point in the history
* added more settings for llama sharp configuration

* Agents

* Fixed forgotten message limit exception

* Changed MessagesLimitReachedException to InvalidOperationException
  • Loading branch information
TesAnti authored Dec 29, 2023
1 parent 94c3fea commit 052e3ca
Show file tree
Hide file tree
Showing 14 changed files with 481 additions and 14 deletions.
21 changes: 21 additions & 0 deletions src/libs/LangChain.Core/Chains/Chain.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using LangChain.Abstractions.Chains.Base;
using LangChain.Chains.HelperChains;
using LangChain.Chains.StackableChains;
using LangChain.Chains.StackableChains.Agents;
using LangChain.Chains.StackableChains.ReAct;
using LangChain.Indexes;
using LangChain.Memory;
using LangChain.Providers;
Expand Down Expand Up @@ -130,4 +132,23 @@ public static STTChain<T> STT<T>(ISpeechToTextModel<T> model,
{
return new STTChain<T>(model, settings, inputKey, outputKey);
}

public static ReActAgentExecutorChain ReActAgentExecutor(IChatModel model, string reActPrompt = null,

Check warning on line 136 in src/libs/LangChain.Core/Chains/Chain.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Cannot convert null literal to non-nullable reference type.

Check warning on line 136 in src/libs/LangChain.Core/Chains/Chain.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Missing XML comment for publicly visible type or member 'Chain.ReActAgentExecutor(IChatModel, string, int, string, string)'
int maxActions = 5, string inputKey = "input",
string outputKey = "final_answer")
{
return new ReActAgentExecutorChain(model, reActPrompt, maxActions, inputKey, outputKey);
}

public static ReActParserChain ReActParser(
string inputKey = "text", string outputKey = "answer")
{
return new ReActParserChain(inputKey, outputKey);
}

public static GroupChat GroupChat(
IList<AgentExecutorChain> agents, string? stopPhrase = null, int messagesLimit = 10, string inputKey = "input", string outputKey = "output")
{
return new GroupChat(agents, stopPhrase, messagesLimit, inputKey, outputKey);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
using LangChain.Abstractions.Chains.Base;
using LangChain.Abstractions.Schema;
using LangChain.Chains.HelperChains;
using LangChain.Memory;
using LangChain.Providers;
using LangChain.Schema;

namespace LangChain.Chains.StackableChains.Agents;

public class AgentExecutorChain: BaseStackableChain
{
public string HistoryKey { get; }
private readonly BaseStackableChain _originalChain;

private BaseStackableChain _chainWithHistory;

public string Name { get; private set; }

Check warning on line 17 in src/libs/LangChain.Core/Chains/StackableChains/Agents/AgentExecutorChain.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

'AgentExecutorChain.Name' hides inherited member 'BaseStackableChain.Name'. Use the new keyword if hiding was intended.

/// <summary>
/// Messages of this agent will not be added to the history
/// </summary>
public bool IsObserver { get; set; } = false;

public AgentExecutorChain(BaseStackableChain originalChain, string name, string historyKey="history",

Check warning on line 24 in src/libs/LangChain.Core/Chains/StackableChains/Agents/AgentExecutorChain.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Non-nullable field '_chainWithHistory' must contain a non-null value when exiting constructor. Consider declaring the field as nullable.
string outputKey = "final_answer")
{
Name = name;
HistoryKey = historyKey;
_originalChain = originalChain;

InputKeys = new[] { historyKey};
OutputKeys = new[] { outputKey };

SetHistory("");
}

public void SetHistory(string history)
{

_chainWithHistory =
Chain.Set(history, HistoryKey)
|_originalChain;
}

protected override async Task<IChainValues> InternalCall(IChainValues values)
{
var res=await _chainWithHistory.CallAsync(values);
return res;
}
}
93 changes: 93 additions & 0 deletions src/libs/LangChain.Core/Chains/StackableChains/Agents/GroupChat.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
using LangChain.Abstractions.Schema;
using LangChain.Chains.HelperChains;
using LangChain.Chains.HelperChains.Exceptions;
using LangChain.Memory;
using LangChain.Providers;

namespace LangChain.Chains.StackableChains.Agents;

public class GroupChat:BaseStackableChain
{
private readonly IList<AgentExecutorChain> _agents;

private readonly string _stopPhrase;
private readonly int _messagesLimit;
private readonly string _inputKey;
private readonly string _outputKey;


int _currentAgentId=0;
private readonly ConversationBufferMemory _conversationBufferMemory;


public bool ThrowOnLimit { get; set; } = false;
public GroupChat(IList<AgentExecutorChain> agents, string? stopPhrase=null, int messagesLimit=10, string inputKey="input", string outputKey="output")

Check warning on line 24 in src/libs/LangChain.Core/Chains/StackableChains/Agents/GroupChat.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Non-nullable field '_stopPhrase' must contain a non-null value when exiting constructor. Consider declaring the field as nullable.
{
_agents = agents;

_stopPhrase = stopPhrase;

Check warning on line 28 in src/libs/LangChain.Core/Chains/StackableChains/Agents/GroupChat.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Possible null reference assignment.
_messagesLimit = messagesLimit;
_inputKey = inputKey;
_outputKey = outputKey;
_conversationBufferMemory = new ConversationBufferMemory(new ChatMessageHistory()) { AiPrefix = "", HumanPrefix = "", SystemPrefix = "", SaveHumanMessages = false };
InputKeys = new[] { inputKey };
OutputKeys = new[] { outputKey };

}

public IReadOnlyList<Message> GetHistory()
{
return _conversationBufferMemory.ChatHistory.Messages;
}


protected override async Task<IChainValues> InternalCall(IChainValues values)
{

await _conversationBufferMemory.Clear().ConfigureAwait(false);
foreach (var agent in _agents)
{
agent.SetHistory("");
}
var firstAgent = _agents[0];
var firstAgentMessage = (string)values.Value[_inputKey];
await _conversationBufferMemory.ChatHistory.AddMessage(new Message($"{firstAgent.Name}: {firstAgentMessage}",
MessageRole.System)).ConfigureAwait(false);
int messagesCount = 1;
while (messagesCount<_messagesLimit)
{
var agent = GetNextAgent();
agent.SetHistory(_conversationBufferMemory.BufferAsString+"\n"+$"{agent.Name}:");
var res = await agent.CallAsync(values).ConfigureAwait(false);
var message = (string)res.Value[agent.OutputKeys[0]];
if (message.Contains(_stopPhrase))
{
break;
}

if (!agent.IsObserver)
{
await _conversationBufferMemory.ChatHistory.AddMessage(new Message($"{agent.Name}: {message}",
MessageRole.System)).ConfigureAwait(false);
}
}

var result = _conversationBufferMemory.ChatHistory.Messages.Last();
messagesCount = _conversationBufferMemory.ChatHistory.Messages.Count;
if (ThrowOnLimit && messagesCount >= _messagesLimit)
{
throw new InvalidOperationException($"Message limit reached:{_messagesLimit}");
}
values.Value.Add(_outputKey, result);
return values;

}

AgentExecutorChain GetNextAgent()
{
_currentAgentId++;
if (_currentAgentId >= _agents.Count)
_currentAgentId = 0;
return _agents[_currentAgentId];
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using LangChain.Chains.HelperChains;
using LangChain.Providers;

namespace LangChain.Chains.StackableChains.Agents;

public class PromptedAgent: AgentExecutorChain
{
public const string Template =
@"{system}
{history}";

private static BaseStackableChain MakeChain(string name, string system, IChatModel model, string outputKey)
{
return Chain.Set(system, "system")
| Chain.Template(Template)
| Chain.LLM(model,outputKey: outputKey);
}


public PromptedAgent(string name, string prompt, IChatModel model, string outputKey = "final_answer") : base(MakeChain(name,prompt,model, outputKey),name, "history", outputKey)
{

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
using LangChain.Abstractions.Chains.Base;
using LangChain.Abstractions.Schema;
using LangChain.Chains.HelperChains;
using LangChain.Chains.StackableChains.ReAct;
using LangChain.Memory;
using LangChain.Providers;
using LangChain.Schema;
using System.Reflection;
using static LangChain.Chains.Chain;

namespace LangChain.Chains.StackableChains.Agents;

public class ReActAgentExecutorChain : BaseStackableChain
{
public const string DefaultPrompt =
@"Answer the following questions as best you can. You have access to the following tools:
{tools}
Use the following format:
Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
(this Thought/Action/Action Input/Observation can repeat multiple times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question
Always add [END] after final answer
Begin!
Question: {input}
Thought:{history}";

private IChain? _chain = null;
private bool _useCache;
Dictionary<string, ReActAgentTool> _tools = new();
private readonly IChatModel _model;
private readonly string _reActPrompt;
private readonly int _maxActions;
private readonly ConversationBufferMemory _conversationBufferMemory;


public ReActAgentExecutorChain(IChatModel model, string reActPrompt = null, int maxActions = 5, string inputKey = "answer",

Check warning on line 46 in src/libs/LangChain.Core/Chains/StackableChains/Agents/ReActAgentExecutorChain.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Cannot convert null literal to non-nullable reference type.
string outputKey = "final_answer")
{
reActPrompt ??= DefaultPrompt;
_model = model;
_reActPrompt = reActPrompt;
_maxActions = maxActions;

InputKeys = new[] { inputKey };
OutputKeys = new[] { outputKey };

_conversationBufferMemory = new ConversationBufferMemory(new ChatMessageHistory()) { AiPrefix = "", HumanPrefix = "", SystemPrefix = "", SaveHumanMessages = false };

}

private string? _userInput = null;
private const string ReActAnswer = "answer";
private void InitializeChain()
{
string tool_names = string.Join(",", _tools.Select(x => x.Key));
string tools = string.Join("\n", _tools.Select(x => $"{x.Value.Name}, {x.Value.Description}"));

var chain =
Set(() => _userInput, "input")

Check warning on line 69 in src/libs/LangChain.Core/Chains/StackableChains/Agents/ReActAgentExecutorChain.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Possible null reference return.
| Set(tools, "tools")
| Set(tool_names, "tool_names")
| Set(() => _conversationBufferMemory.BufferAsString, "history")
| Template(_reActPrompt)
| Chain.LLM(_model).UseCache(_useCache)
| UpdateMemory(_conversationBufferMemory, requestKey: "input", responseKey: "text")
| ReActParser(inputKey: "text", outputKey: ReActAnswer);

_chain = chain;
}

protected override async Task<IChainValues> InternalCall(IChainValues values)
{

var input = (string)values.Value[InputKeys[0]];
var values_chain = new ChainValues();

_userInput = input;


if (_chain == null)
{
InitializeChain();
}

for (int i = 0; i < _maxActions; i++)
{
var res = await _chain.CallAsync(values_chain);

Check warning on line 97 in src/libs/LangChain.Core/Chains/StackableChains/Agents/ReActAgentExecutorChain.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Dereference of a possibly null reference.
if (res.Value[ReActAnswer] is AgentAction)
{
var action = (AgentAction)res.Value[ReActAnswer];
var tool = _tools[action.Action];
var tool_res = tool.ToolCall(action.ActionInput);
await _conversationBufferMemory.ChatHistory.AddMessage(new Message("Observation: " + tool_res, MessageRole.System))
.ConfigureAwait(false);
await _conversationBufferMemory.ChatHistory.AddMessage(new Message("Thought:", MessageRole.System))
.ConfigureAwait(false);
continue;
}
else if (res.Value[ReActAnswer] is AgentFinish)
{
var finish = (AgentFinish)res.Value[ReActAnswer];
values.Value.Add(OutputKeys[0], finish.Output);
return values;
}
}



return values;
}

public ReActAgentExecutorChain UseCache(bool enabled = true)
{
_useCache = enabled;
return this;
}


public ReActAgentExecutorChain UseTool(ReActAgentTool tool)
{
_tools.Add(tool.Name, tool);
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,12 @@ public async Task<IChainValues> Run()
return res.Value[resultKey].ToString();
}

public async Task<T> Run<T>(string resultKey)
{
var res = await CallAsync(new ChainValues()).ConfigureAwait(false);
return (T)res.Value[resultKey];
}

/// <summary>
///
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
namespace LangChain.Chains.StackableChains.ReAct;

public class ReActAgentTool
{
public ReActAgentTool(string name, string description, Func<string, string> func)
{
Name = name;
Description = description;
ToolCall = func;
}

public string Name { get; set; }
public string Description { get; set; }

public Func<string, string> ToolCall { get; set; }

}
Loading

0 comments on commit 052e3ca

Please sign in to comment.