Skip to content

Commit

Permalink
feat(GoogleGemini): Added CountTokens
Browse files Browse the repository at this point in the history
fix: fixed usages
  • Loading branch information
gunpal5 committed Jun 6, 2024
1 parent ec08718 commit 28db041
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 19 deletions.
2 changes: 1 addition & 1 deletion src/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
<PackageVersion Include="DotNet.ReproducibleBuilds" Version="1.1.1" />
<PackageVersion Include="FluentAssertions" Version="6.12.0" />
<PackageVersion Include="GitHubActionsTestLogger" Version="2.4.1" />
<PackageVersion Include="Google_GenerativeAI" Version="1.0.0" />
<PackageVersion Include="Google_GenerativeAI" Version="1.0.1" />
<PackageVersion Include="GroqSharp" Version="1.1.2" />
<PackageVersion Include="H.Generators.Extensions" Version="1.22.0" />
<PackageVersion Include="H.NSwag.Generator" Version="14.0.7.76">
Expand Down
24 changes: 24 additions & 0 deletions src/Providers/Google/src/GoogleChatModel.Tokens.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using GenerativeAI.Types;

namespace LangChain.Providers.Google
{
public partial class GoogleChatModel
{
public async Task<int> CountTokens(string text)
{
return await CountTokens(new Message[] { new Message(text, MessageRole.Human) }).ConfigureAwait(false);
}

public async Task<int> CountTokens(IEnumerable<Message> messages)
{
var response = await this.Api.CountTokens(new CountTokensRequest() { Contents = messages.Select(ToRequestMessage).ToArray() }).ConfigureAwait(false);

return response.TotalTokens;
}
}
}
26 changes: 12 additions & 14 deletions src/Providers/Google/src/GoogleChatModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,16 @@ public override async Task<ChatResponse> GenerateAsync(
settings,
Settings,
provider.ChatSettings);
var usage = Usage.Empty;

if (usedSettings.UseStreaming == true)
{
var message = await StreamCompletionAsync(messages, cancellationToken).ConfigureAwait(false);
messages.Add(message);
usage += Usage.Empty with
{
Time = watch.Elapsed
};
}
else
{
Expand All @@ -145,14 +150,14 @@ public override async Task<ChatResponse> GenerateAsync(
OnCompletedResponseGenerated(response.Text() ?? string.Empty);


var usage2 = GetUsage(response) with
usage = GetUsage(response) with
{
Time = watch.Elapsed
};

//Add Usage
AddUsage(usage2);
provider.AddUsage(usage2);
AddUsage(usage);
provider.AddUsage(usage);

//Handle Function Call
while (ReplyToToolCallsAutomatically && response.IsFunctionCall())
Expand Down Expand Up @@ -185,28 +190,21 @@ public override async Task<ChatResponse> GenerateAsync(
messages.Add(message);

//Add Usage
usage2 = GetUsage(response) with
var usage2 = GetUsage(response) with
{
Time = watch.Elapsed
};
AddUsage(usage2);
provider.AddUsage(usage2);
usage += usage2;
}
}
}

//Add Usage
var usage = Usage.Empty with
{
Time = watch.Elapsed
};
AddUsage(usage);
provider.AddUsage(usage);


return new ChatResponse
{
Messages = messages,
Usage = Usage,
Usage = usage,
UsedSettings = ChatSettings.Default
};
}
Expand Down
2 changes: 1 addition & 1 deletion src/Providers/Google/src/Predefined/GeminiModels.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@ public class Gemini15FlashModel(GoogleProvider provider)
public class Gemini15ProModel(GoogleProvider provider)
: GoogleChatModel(
provider,
GoogleAIModels.Gemini15Flash, 2 * 1024 * 1024, 3.5 * 0.000001, 10.50 * 0.000001, 7.0 * 0.000001, 21.00 * 0.000001);
GoogleAIModels.Gemini15Pro, 2 * 1024 * 1024, 3.5 * 0.000001, 10.50 * 0.000001, 7.0 * 0.000001, 21.00 * 0.000001);
7 changes: 4 additions & 3 deletions src/Providers/OpenAI/src/Chat/OpenAiChatModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -266,12 +266,13 @@ public override async Task<ChatResponse> GenerateAsync(
OnPartialResponseGenerated(Environment.NewLine);
OnCompletedResponseGenerated(newMessage.Content);

usage = GetUsage(response) with
var usage2 = GetUsage(response) with
{
Time = watch.Elapsed,
};
AddUsage(usage);
provider.AddUsage(usage);
AddUsage(usage2);
provider.AddUsage(usage2);
usage += usage2;
}
}

Expand Down

0 comments on commit 28db041

Please sign in to comment.