Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix load custom google credentials #388

Merged
merged 1 commit into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/Providers/Google.VertexAI/src/VertexAIChatModel.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Google.Cloud.AIPlatform.V1;
using Google.Apis.Auth.OAuth2;
using Google.Cloud.AIPlatform.V1;
using System.Diagnostics;

namespace LangChain.Providers.Google.VertexAI
Expand Down Expand Up @@ -38,9 +39,10 @@ public override async Task<ChatResponse> GenerateAsync(ChatRequest request, Chat

private GenerateContentRequest ToPrompt(IEnumerable<Message> messages)
{
var serviceAccountCredential = Provider.Configuration.GoogleCredential?.UnderlyingCredential as ServiceAccountCredential;
return new GenerateContentRequest
{
Model = $"projects/{Provider.Configuration.ProjectId}/locations/{Provider.Configuration.Location}/publishers/{Provider.Configuration.Publisher}/models/{Id}",
Model = $"projects/{serviceAccountCredential?.ProjectId}/locations/{Provider.Configuration.Location}/publishers/{Provider.Configuration.Publisher}/models/{Id}",
Contents = { messages.Select(ConvertMessage) },
GenerationConfig = Provider.Configuration.GenerationConfig
};
Expand Down
5 changes: 3 additions & 2 deletions src/Providers/Google.VertexAI/src/VertexAIConfiguration.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Google.Cloud.AIPlatform.V1;
using Google.Apis.Auth.OAuth2;
using Google.Cloud.AIPlatform.V1;

namespace LangChain.Providers.Google.VertexAI
{
Expand All @@ -7,7 +8,7 @@ public class VertexAIConfiguration
public const string SectionName = "VertexAI";
public string Location { get; set; } = "us-central1";
public string Publisher { get; set; } = "google";
public required string ProjectId { get; set; }
public GoogleCredential? GoogleCredential { get; set; } = GoogleCredential.GetApplicationDefault();
public GenerationConfig? GenerationConfig { get; set; }
}
}
6 changes: 4 additions & 2 deletions src/Providers/Google.VertexAI/src/VertexAITextToImageModel.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Google.Cloud.AIPlatform.V1;
using Google.Apis.Auth.OAuth2;
using Google.Cloud.AIPlatform.V1;
using System.Diagnostics;
using Value = Google.Protobuf.WellKnownTypes.Value;

Expand All @@ -14,10 +15,11 @@ public async Task<TextToImageResponse> GenerateImageAsync(TextToImageRequest req
throw new ArgumentNullException(nameof(request));

var watch = Stopwatch.StartNew();
var serviceAccountCredential = Provider.Configuration.GoogleCredential?.UnderlyingCredential as ServiceAccountCredential;
var predictRequest = new PredictRequest
{
EndpointAsEndpointName = EndpointName.FromProjectLocationPublisherModel(
Provider.Configuration.ProjectId,
serviceAccountCredential?.ProjectId,
Provider.Configuration.Location,
Provider.Configuration.Publisher, Id),
Instances =
Expand Down
76 changes: 26 additions & 50 deletions src/Providers/Google.VertexAI/test/VertexAITest.cs
Original file line number Diff line number Diff line change
@@ -1,77 +1,53 @@
using Google.Apis.Auth.OAuth2;
using LangChain.Providers.Google.VertexAI.Predefined;

namespace LangChain.Providers.Google.VertexAI.Test
{
[TestFixture]
[Explicit]
//Required 'GOOGLE_APPLICATION_CREDENTIALS' env with Google credentials path json file.
//Required 'GOOGLE_APPLICATION_CREDENTIALS' env with Google credentials path json file.
public partial class VertexAITests
{
[Test]
public async Task Chat()
{

var credentials = GoogleCredential.GetApplicationDefault();

if (credentials.UnderlyingCredential is ServiceAccountCredential serviceAccountCredential)
var config = new VertexAIConfiguration()
{

var config = new VertexAIConfiguration()
//Publisher = "google",
//Location = "us-central1",
//GoogleCredential = GoogleCredential.FromJson("{your-json}"),
/*GenerationConfig = new GenerationConfig
{
ProjectId = serviceAccountCredential.ProjectId,
//Publisher = "google",
//Location = "us-central1",
/*GenerationConfig = new GenerationConfig
{
Temperature = 0.4f,
TopP = 1,
TopK = 32,
MaxOutputTokens = 2048
}*/
};
Temperature = 0.4f,
TopP = 1,
TopK = 32,
MaxOutputTokens = 2048
}*/
};

var provider = new VertexAIProvider(config);
var model = new Gemini15ProChatModel(provider);
var provider = new VertexAIProvider(config);
var model = new Gemini15ProChatModel(provider);

string answer = await model.GenerateAsync("Generate some random name:");
string answer = await model.GenerateAsync("Generate some random name:");

answer.Should().NotBeNull();

Console.WriteLine(answer);
}
answer.Should().NotBeNull();
Console.WriteLine(answer);

}

[Test]
public async Task TextToImage()
{
var provider = new VertexAIProvider(new VertexAIConfiguration());
var model = new VertexAITextToImageModel(provider, "imagegeneration@006", 2);
var answer = await model.GenerateImageAsync("a dog reading a newspaper");
answer.Should().NotBeNull();

var credentials = GoogleCredential.GetApplicationDefault();

if (credentials.UnderlyingCredential is ServiceAccountCredential serviceAccountCredential)
foreach (var img in answer.Images)
{

var config = new VertexAIConfiguration()
{
ProjectId = serviceAccountCredential.ProjectId
};

var provider = new VertexAIProvider(config);

var model = new VertexAITextToImageModel(provider, "imagegeneration@006", 2);

var answer = await model.GenerateImageAsync("a dog reading a newspaper");

answer.Should().NotBeNull();

foreach (var img in answer.Images)
{
string outputFileName = $"dog_newspaper_{Guid.NewGuid()}.png";
File.WriteAllBytes(outputFileName, Convert.FromBase64String(img));
FileInfo fileInfo = new(Path.GetFullPath(outputFileName));
Console.WriteLine($"Created output image {fileInfo.FullName} with {fileInfo.Length} bytes");
}
string outputFileName = $"dog_newspaper_{Guid.NewGuid()}.png";
File.WriteAllBytes(outputFileName, Convert.FromBase64String(img));
FileInfo fileInfo = new(Path.GetFullPath(outputFileName));
Console.WriteLine($"Created output image {fileInfo.FullName} with {fileInfo.Length} bytes");
}
}
}
Expand Down