Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Bugfix Cohere streaming stop_reason #150

Merged
merged 8 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/Providers/Amazon.Bedrock/src/BedrockModelStreamRequest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using System.Text;
using System.Text.Json.Nodes;
using Amazon.BedrockRuntime.Model;

namespace LangChain.Providers.Amazon.Bedrock;

/// <summary>
/// Factory for building Amazon Bedrock streaming invocation requests from a JSON payload.
/// </summary>
internal record BedrockModelStreamRequest
{
    /// <summary>
    /// Creates an <see cref="InvokeModelWithResponseStreamRequest"/> for the given model,
    /// serializing <paramref name="bodyJson"/> as UTF-8 JSON into the request body stream.
    /// </summary>
    /// <param name="modelId">Bedrock model identifier to invoke.</param>
    /// <param name="bodyJson">Request payload to send as the request body.</param>
    /// <returns>A fully populated streaming invocation request.</returns>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="modelId"/> or <paramref name="bodyJson"/> is <c>null</c>.
    /// </exception>
    public static InvokeModelWithResponseStreamRequest Create(string modelId, JsonObject bodyJson)
    {
        modelId = modelId ?? throw new ArgumentNullException(nameof(modelId));
        bodyJson = bodyJson ?? throw new ArgumentNullException(nameof(bodyJson));

        var payload = Encoding.UTF8.GetBytes(bodyJson.ToJsonString());

        // The SDK request takes ownership of the stream; it is disposed together
        // with the request, so no using-block is needed here.
        var bedrockRequest = new InvokeModelWithResponseStreamRequest
        {
            ModelId = modelId,
            ContentType = "application/json",
            Accept = "application/json",
            Body = new MemoryStream(payload),
        };

        return bedrockRequest;
    }
}
82 changes: 63 additions & 19 deletions src/Providers/Amazon.Bedrock/src/Chat/AmazonTitanChatModel.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Text.Json.Nodes;
using Amazon.BedrockRuntime.Model;
using LangChain.Providers.Amazon.Bedrock.Internal;

// ReSharper disable once CheckNamespace
Expand All @@ -11,8 +13,6 @@ public abstract class AmazonTitanChatModel(
string id)
: ChatModel(id)
{
public override int ContextLength => 4096;

public override async Task<ChatResponse> GenerateAsync(
ChatRequest request,
ChatSettings? settings = null,
Expand All @@ -22,32 +22,60 @@ public override async Task<ChatResponse> GenerateAsync(

var watch = Stopwatch.StartNew();
var prompt = request.Messages.ToSimplePrompt();
var messages = request.Messages.ToList();

var stringBuilder = new StringBuilder();

var usedSettings = BedrockChatSettings.Calculate(
requestSettings: settings,
modelSettings: Settings,
providerSettings: provider.ChatSettings);
var response = await provider.Api.InvokeModelAsync(
Id,
new JsonObject

var bodyJson = CreateBodyJson(prompt, usedSettings);

if (usedSettings.UseStreaming == true)
{
var streamRequest = BedrockModelStreamRequest.Create(Id, bodyJson);
var response = await provider.Api.InvokeModelWithResponseStreamAsync(streamRequest, cancellationToken);

foreach (var payloadPart in response.Body)
{
["inputText"] = prompt,
["textGenerationConfig"] = new JsonObject
var streamEvent = (PayloadPart)payloadPart;
var chunk = await JsonSerializer.DeserializeAsync<JsonObject>(streamEvent.Bytes, cancellationToken: cancellationToken)
.ConfigureAwait(false);
var delta = chunk?["outputText"]!.GetValue<string>();

OnPartialResponseGenerated(delta!);
stringBuilder.Append(delta);

var finished = chunk?["completionReason"]?.GetValue<string>();
if (finished?.ToLower() == "finish")
{
["maxTokenCount"] = usedSettings.MaxTokens!.Value,
["temperature"] = usedSettings.Temperature!.Value,
["topP"] = usedSettings.TopP!.Value,
["stopSequences"] = usedSettings.StopSequences!.AsArray()
OnCompletedResponseGenerated(stringBuilder.ToString());
}
},
cancellationToken).ConfigureAwait(false);
}

OnPartialResponseGenerated(Environment.NewLine);
stringBuilder.Append(Environment.NewLine);

var generatedText = response?["results"]?[0]?["outputText"]?.GetValue<string>() ?? string.Empty;
var newMessage = new Message(
Content: stringBuilder.ToString(),
Role: MessageRole.Ai);
messages.Add(newMessage);

var result = request.Messages.ToList();
result.Add(generatedText.AsAiMessage());
OnCompletedResponseGenerated(newMessage.Content);
}
else
{
var response = await provider.Api.InvokeModelAsync(Id, bodyJson, cancellationToken)
.ConfigureAwait(false);

var generatedText = response?["results"]?[0]?["outputText"]?.GetValue<string>() ?? string.Empty;

messages.Add(generatedText.AsAiMessage());
OnCompletedResponseGenerated(generatedText);
}

// Unsupported
var usage = Usage.Empty with
{
Time = watch.Elapsed,
Expand All @@ -57,9 +85,25 @@ public override async Task<ChatResponse> GenerateAsync(

return new ChatResponse
{
Messages = result,
Messages = messages,
UsedSettings = usedSettings,
Usage = usage,
};
}

/// <summary>
/// Builds the Amazon Titan request payload: the prompt plus a nested
/// <c>textGenerationConfig</c> object carrying the resolved sampling settings.
/// </summary>
/// <param name="prompt">Flattened chat prompt text.</param>
/// <param name="usedSettings">Effective settings (already resolved, so the values are non-null).</param>
/// <returns>The JSON body expected by the Titan model.</returns>
private static JsonObject CreateBodyJson(string prompt, BedrockChatSettings usedSettings)
{
    var generationConfig = new JsonObject
    {
        ["maxTokenCount"] = usedSettings.MaxTokens!.Value,
        ["temperature"] = usedSettings.Temperature!.Value,
        ["topP"] = usedSettings.TopP!.Value,
        ["stopSequences"] = usedSettings.StopSequences!.AsArray(),
    };

    return new JsonObject
    {
        ["inputText"] = prompt,
        ["textGenerationConfig"] = generationConfig,
    };
}
}
85 changes: 65 additions & 20 deletions src/Providers/Amazon.Bedrock/src/Chat/AnthropicClaudeChatModel.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
using System.Diagnostics;
using System.Text;
using System.Text.Json;
using System.Text.Json.Nodes;
using Amazon.BedrockRuntime.Model;
using LangChain.Providers.Amazon.Bedrock.Internal;

// ReSharper disable once CheckNamespace
Expand All @@ -19,31 +22,59 @@

var watch = Stopwatch.StartNew();
var prompt = request.Messages.ToRolePrompt();
var messages = request.Messages.ToList();

var stringBuilder = new StringBuilder();

var usedSettings = BedrockChatSettings.Calculate(
requestSettings: settings,
modelSettings: Settings,
providerSettings: provider.ChatSettings);
var response = await provider.Api.InvokeModelAsync(
Id,
new JsonObject

var bodyJson = CreateBodyJson(prompt, usedSettings);

if (usedSettings.UseStreaming == true)
{
var streamRequest = BedrockModelStreamRequest.Create(Id, bodyJson);
var response = await provider.Api.InvokeModelWithResponseStreamAsync(streamRequest, cancellationToken);

foreach (var payloadPart in response.Body)
{
["prompt"] = prompt,
["max_tokens_to_sample"] = usedSettings.MaxTokens!.Value,
["temperature"] = usedSettings.Temperature!.Value,
["top_p"] = usedSettings.TopP!.Value,
["top_k"] = usedSettings.TopK!.Value,
["stop_sequences"] = new JsonArray("\n\nHuman:")
},
cancellationToken).ConfigureAwait(false);

var generatedText = response?["completion"]?
.GetValue<string>() ?? "";

var result = request.Messages.ToList();
result.Add(generatedText.AsAiMessage());

// Unsupported
var streamEvent = (PayloadPart)payloadPart;
var chunk = await JsonSerializer.DeserializeAsync<JsonObject>(streamEvent.Bytes, cancellationToken: cancellationToken)
.ConfigureAwait(false);
var delta = chunk?["completion"]!.GetValue<string>();

OnPartialResponseGenerated(delta!);
stringBuilder.Append(delta);

var finished = chunk?["completionReason"]?.GetValue<string>();
if (finished?.ToLower() == "finish")

Check warning on line 52 in src/Providers/Amazon.Bedrock/src/Chat/AnthropicClaudeChatModel.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Specify a culture or use an invariant version to avoid implicit dependency on current culture (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1311)
{
OnCompletedResponseGenerated(stringBuilder.ToString());
}
}

OnPartialResponseGenerated(Environment.NewLine);
stringBuilder.Append(Environment.NewLine);

var newMessage = new Message(
Content: stringBuilder.ToString(),
Role: MessageRole.Ai);
messages.Add(newMessage);

OnCompletedResponseGenerated(newMessage.Content);
}
else
{
var response = await provider.Api.InvokeModelAsync(Id, bodyJson, cancellationToken).ConfigureAwait(false);

var generatedText = response?["completion"]?.GetValue<string>() ?? "";

messages.Add(generatedText.AsAiMessage());
OnCompletedResponseGenerated(generatedText);
}

var usage = Usage.Empty with
{
Time = watch.Elapsed,
Expand All @@ -53,9 +84,23 @@

return new ChatResponse
{
Messages = result,
Messages = messages,
UsedSettings = usedSettings,
Usage = usage,
};
}

/// <summary>
/// Builds the Anthropic Claude request payload from the prompt and the
/// resolved sampling settings. The hard-coded <c>"\n\nHuman:"</c> stop
/// sequence terminates generation at the next human turn marker.
/// </summary>
/// <param name="prompt">Role-formatted chat prompt text.</param>
/// <param name="usedSettings">Effective settings (already resolved, so the values are non-null).</param>
/// <returns>The JSON body expected by the Claude model.</returns>
private static JsonObject CreateBodyJson(string prompt, BedrockChatSettings usedSettings)
{
    var body = new JsonObject();
    body["prompt"] = prompt;
    body["max_tokens_to_sample"] = usedSettings.MaxTokens!.Value;
    body["temperature"] = usedSettings.Temperature!.Value;
    body["top_p"] = usedSettings.TopP!.Value;
    body["top_k"] = usedSettings.TopK!.Value;
    body["stop_sequences"] = new JsonArray("\n\nHuman:");

    return body;
}
}
9 changes: 8 additions & 1 deletion src/Providers/Amazon.Bedrock/src/Chat/BedrockChatSettings.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ public class BedrockChatSettings : ChatSettings
{
StopSequences = ChatSettings.Default.StopSequences,
User = ChatSettings.Default.User,
UseStreaming = false,
Temperature = 0.7,
MaxTokens = 4096,
MaxTokens = 2048,
TopP = 0.9,
TopK = 0.0
};
Expand Down Expand Up @@ -66,6 +67,12 @@ public class BedrockChatSettings : ChatSettings
providerSettingsCasted?.User ??
Default.User ??
throw new InvalidOperationException("Default User is not set."),
UseStreaming =
requestSettings?.UseStreaming ??
modelSettings?.UseStreaming ??
providerSettings?.UseStreaming ??
Default.UseStreaming ??
throw new InvalidOperationException("Default UseStreaming is not set."),
Temperature =
requestSettingsCasted?.Temperature ??
modelSettingsCasted?.Temperature ??
Expand Down
66 changes: 52 additions & 14 deletions src/Providers/Amazon.Bedrock/src/Chat/CohereCommandChatModel.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
using System.Diagnostics;
using System.Text;
using System.Text.Json;
using System.Text.Json.Nodes;
using Amazon.BedrockRuntime.Model;
using LangChain.Providers.Amazon.Bedrock.Internal;

// ReSharper disable once CheckNamespace
Expand All @@ -19,27 +22,49 @@

var watch = Stopwatch.StartNew();
var prompt = request.Messages.ToSimplePrompt();
var messages = request.Messages.ToList();

var stringBuilder = new StringBuilder();

var usedSettings = BedrockChatSettings.Calculate(
requestSettings: settings,
modelSettings: Settings,
providerSettings: provider.ChatSettings);
var response = await provider.Api.InvokeModelAsync(
Id,
new JsonObject

var bodyJson = CreateBodyJson(prompt, usedSettings);

if (usedSettings.UseStreaming == true)
{
var streamRequest = BedrockModelStreamRequest.Create(Id, bodyJson);
var response = await provider.Api.InvokeModelWithResponseStreamAsync(streamRequest, cancellationToken);

foreach (var payloadPart in response.Body)
{
["prompt"] = prompt,
["max_tokens"] = usedSettings.MaxTokens!.Value,
["temperature"] = usedSettings.Temperature!.Value,
["p"] = usedSettings.TopP!.Value,
["k"] = usedSettings.TopK!.Value,
},
cancellationToken).ConfigureAwait(false);
var streamEvent = (PayloadPart)payloadPart;
var chunk = await JsonSerializer.DeserializeAsync<JsonObject>(streamEvent.Bytes, cancellationToken: cancellationToken)
.ConfigureAwait(false);
var delta = chunk?["generations"]?[0]?["text"]?.GetValue<string>() ?? string.Empty;

OnPartialResponseGenerated(delta!);
stringBuilder.Append(delta);

var generatedText = response?["generations"]?[0]?["text"]?.GetValue<string>() ?? string.Empty;
var finished = chunk?["generations"]?[0]?["finish_reason"]?.GetValue<string>() ?? string.Empty;
if (string.Equals(finished?.ToLower(), "complete", StringComparison.Ordinal))

Check warning on line 52 in src/Providers/Amazon.Bedrock/src/Chat/CohereCommandChatModel.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

Specify a culture or use an invariant version to avoid implicit dependency on current culture (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1311)

Check warning on line 52 in src/Providers/Amazon.Bedrock/src/Chat/CohereCommandChatModel.cs

View workflow job for this annotation

GitHub Actions / Build and test / Build, test and publish

The behavior of 'string.ToLower()' could vary based on the current user's locale settings. Replace this call in 'CohereCommandChatModel.GenerateAsync(ChatRequest,
{
OnCompletedResponseGenerated(stringBuilder.ToString());
}
}
}
else
{
var response = await provider.Api.InvokeModelAsync(Id, bodyJson, cancellationToken)
.ConfigureAwait(false);

var result = request.Messages.ToList();
result.Add(generatedText.AsAiMessage());
var generatedText = response?["generations"]?[0]?["text"]?.GetValue<string>() ?? string.Empty;

messages.Add(generatedText.AsAiMessage());
OnCompletedResponseGenerated(generatedText);
}

var usage = Usage.Empty with
{
Expand All @@ -50,9 +75,22 @@

return new ChatResponse
{
Messages = result,
Messages = messages,
UsedSettings = usedSettings,
Usage = usage,
};
}

/// <summary>
/// Builds the Cohere Command request payload from the prompt and the
/// resolved sampling settings (note Cohere uses <c>p</c>/<c>k</c> rather
/// than <c>top_p</c>/<c>top_k</c> as parameter names).
/// </summary>
/// <param name="prompt">Flattened chat prompt text.</param>
/// <param name="usedSettings">Effective settings (already resolved, so the values are non-null).</param>
/// <returns>The JSON body expected by the Cohere Command model.</returns>
private static JsonObject CreateBodyJson(string prompt, BedrockChatSettings usedSettings)
{
    var body = new JsonObject();
    body["prompt"] = prompt;
    body["max_tokens"] = usedSettings.MaxTokens!.Value;
    body["temperature"] = usedSettings.Temperature!.Value;
    body["p"] = usedSettings.TopP!.Value;
    body["k"] = usedSettings.TopK!.Value;

    return body;
}
}
Loading
Loading