-
-
Notifications
You must be signed in to change notification settings - Fork 82
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add GroqSharp provider #303
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
using GroqSharp.Models; | ||
using System.Diagnostics; | ||
using MessageGroqSharp = GroqSharp.Models.Message; | ||
|
||
namespace LangChain.Providers.GroqSharp | ||
{ | ||
public class GroqSharpChatModel( | ||
GroqSharpProvider provider, | ||
string id) | ||
: ChatModel(id), IChatModel | ||
{ | ||
public GroqSharpProvider Provider { get; } = provider ?? throw new ArgumentNullException(nameof(provider)); | ||
public override async Task<ChatResponse> GenerateAsync( | ||
ChatRequest request, | ||
ChatSettings settings = null, | ||
CancellationToken cancellationToken = default) | ||
{ | ||
request = request ?? throw new ArgumentNullException(nameof(request)); | ||
var prompt = ToPrompt(request.Messages); | ||
var watch = Stopwatch.StartNew(); | ||
var response = await Provider.Api.CreateChatCompletionAsync(prompt); | ||
|
||
var usage = Usage.Empty with | ||
{ | ||
Time = watch.Elapsed, | ||
}; | ||
AddUsage(usage); | ||
provider.AddUsage(usage); | ||
|
||
var result = request.Messages.ToList(); | ||
result.Add(response.AsAiMessage()); | ||
|
||
return new ChatResponse | ||
{ | ||
Messages = result, | ||
Usage = usage, | ||
UsedSettings = ChatSettings.Default, | ||
}; | ||
} | ||
|
||
protected static MessageGroqSharp[] ToPrompt(IEnumerable<Message> messages) | ||
{ | ||
return messages.Select(ConvertMessage).ToArray(); | ||
} | ||
|
||
protected static MessageGroqSharp ConvertMessage(Message message) | ||
{ | ||
return new MessageGroqSharp { Role = ConvertRole(message.Role), Content = message.Content }; | ||
} | ||
protected static MessageRoleType ConvertRole(MessageRole role) | ||
{ | ||
return role switch | ||
{ | ||
MessageRole.Human => MessageRoleType.User, | ||
MessageRole.Ai => MessageRoleType.Assistant, | ||
MessageRole.System => MessageRoleType.System, | ||
_ => throw new NotSupportedException($"the role {role} is not supported") | ||
}; | ||
} | ||
|
||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
namespace LangChain.Providers.GroqSharp | ||
{ | ||
public class GroqSharpConfiguration | ||
{ | ||
public const string SectionName = "GroqSharp"; | ||
public string ApiKey { get; set; } | ||
public string ModelId { get; set; } | ||
public double Temperature { get; set; } = 0.5; | ||
public int MaxTokens { get; set; } = int.MaxValue; | ||
public double TopP { get; set; } = 1.0; | ||
public string Stop { get; set; } = "NONE"; | ||
public int StructuredRetryPolicy { get; set; } = 5; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
namespace LangChain.Providers.GroqSharp | ||
{ | ||
public class GroqSharpProvider : Provider | ||
{ | ||
public IGroqClient Api { get; private set; } | ||
public GroqSharpProvider(GroqClient groqClient) | ||
: base(id: GroqSharpConfiguration.SectionName) | ||
{ | ||
Api = groqClient ?? throw new ArgumentNullException(nameof(groqClient)); | ||
} | ||
|
||
public GroqSharpProvider(GroqSharpConfiguration configuration) | ||
: base(id: GroqSharpConfiguration.SectionName) | ||
{ | ||
configuration = configuration ?? throw new ArgumentNullException(nameof(configuration)); | ||
var apiKey = configuration.ApiKey ?? throw new ArgumentException("ApiKey is not defined", nameof(configuration)); | ||
var apiModel = configuration.ModelId ?? throw new ArgumentException("ModelId is not defined", nameof(configuration)); | ||
|
||
Api = new GroqClient(apiKey, apiModel) | ||
.SetTemperature(configuration.Temperature) | ||
.SetMaxTokens(configuration.MaxTokens) | ||
.SetTopP(configuration.TopP) | ||
.SetStop(configuration.Stop) | ||
.SetStructuredRetryPolicy(configuration.StructuredRetryPolicy); | ||
|
||
} | ||
|
||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
<Project Sdk="Microsoft.NET.Sdk"> | ||
|
||
<PropertyGroup> | ||
<TargetFrameworks>net8.0</TargetFrameworks> | ||
<NoWarn>$(NoWarn);CA1054</NoWarn> | ||
</PropertyGroup> | ||
|
||
<ItemGroup> | ||
<ProjectReference Include="..\..\Abstractions\src\LangChain.Providers.Abstractions.csproj" /> | ||
<PackageReference Include="GroqSharp" /> | ||
</ItemGroup> | ||
|
||
<ItemGroup Label="Usings"> | ||
<Using Include="System.Net.Http" /> | ||
<Using Include="GroqSharp" Version="1.1.2" /> | ||
</ItemGroup> | ||
|
||
<PropertyGroup Label="NuGet"> | ||
<Description>Groq model provider.</Description> | ||
<PackageTags>$(PackageTags);groq;ai;api</PackageTags> | ||
</PropertyGroup> | ||
|
||
</Project> |
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,16 @@ | ||||||||||
namespace LangChain.Providers.GroqSharp.Predefined | ||||||||||
{ | ||||||||||
|
||||||||||
public class Llama38b8192(GroqSharpProvider provider) | ||||||||||
: GroqSharpChatModel(provider, id: "llama3-8b-8192"); | ||||||||||
|
||||||||||
public class Llama370b8192(GroqSharpProvider provider) | ||||||||||
: GroqSharpChatModel(provider, id: "llama3-70b-8192"); | ||||||||||
Comment on lines
+7
to
+8
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Class name should follow PascalCase naming convention. - public class Llama370b8192(GroqSharpProvider provider)
+ public class Llama370B8192(GroqSharpProvider provider) Committable suggestion
Suggested change
|
||||||||||
|
||||||||||
public class Mixtral8x7b32768(GroqSharpProvider provider) | ||||||||||
: GroqSharpChatModel(provider, id: "mixtral-8x7b-32768"); | ||||||||||
Comment on lines
+10
to
+11
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Class name should follow PascalCase naming convention. - public class Mixtral8x7b32768(GroqSharpProvider provider)
+ public class Mixtral8X7B32768(GroqSharpProvider provider) Committable suggestion
Suggested change
|
||||||||||
|
||||||||||
public class Gemma7bIt(GroqSharpProvider provider) | ||||||||||
: GroqSharpChatModel(provider, id: "gemma-7b-it"); | ||||||||||
Comment on lines
+13
to
+14
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Class name should follow PascalCase naming convention. - public class Gemma7bIt(GroqSharpProvider provider)
+ public class Gemma7BIt(GroqSharpProvider provider) Committable suggestion
Suggested change
|
||||||||||
|
||||||||||
} |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,29 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
using LangChain.Providers.GroqSharp.Predefined; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
using NUnit.Framework; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
namespace LangChain.Providers.GroqSharp.Test; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
[TestFixture] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
[Explicit] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
public partial class GroqSharpTest | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
[Test] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
public async Task Chat() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
var config = new GroqSharpConfiguration() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ApiKey = Environment.GetEnvironmentVariable("GROQ_API_KEY") ?? | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
throw new InconclusiveException("GROQ_API_KEY is not set."), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ModelId = "llama3-70b-8192" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
var provider = new GroqSharpProvider(config); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
var model = new Llama370b8192(provider); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
string answer = await model.GenerateAsync("Generate some random name:"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Console.WriteLine(answer); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+10
to
+26
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add assertions to verify the generated response. - Console.WriteLine(answer);
+ Assert.IsNotNull(answer);
+ Assert.IsNotEmpty(answer);
+ Console.WriteLine(answer); Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
<Project Sdk="Microsoft.NET.Sdk"> | ||
|
||
<PropertyGroup> | ||
<TargetFramework>net8.0</TargetFramework> | ||
</PropertyGroup> | ||
|
||
<ItemGroup> | ||
<PackageReference Include="NUnit" /> | ||
</ItemGroup> | ||
|
||
<ItemGroup> | ||
<ProjectReference Include="..\..\Abstractions\src\LangChain.Providers.Abstractions.csproj" /> | ||
<ProjectReference Include="..\src\LangChain.Providers.GroqSharp.csproj" /> | ||
</ItemGroup> | ||
|
||
</Project> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Class name should follow PascalCase naming convention.
Committable suggestion