Skip to content

Commit

Permalink
Changes due to last review:
Browse files Browse the repository at this point in the history
1. Refactor ClientBase class fields to read-only properties
2. Optimized null handling in assignment of ResponseSafetyRatings
3. Several classes have been restructured and moved to become inner classes and their names have been adjusted to reflect this change by appending "Part" or "Element" to each.
4. Adjusted namespace for public types in Core to the parent namespace: Microsoft.SemanticKernel.Connectors.GoogleVertexAI.
5. The GeminiConfiguration class has been removed, and we now use a direct string for the modelId.
  • Loading branch information
Krzysztof318 committed Jan 22, 2024
1 parent cdafc20 commit f2866cf
Show file tree
Hide file tree
Showing 14 changed files with 218 additions and 261 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using System.IO;
using System.Net.Http;
using System.Threading.Tasks;
using Microsoft.SemanticKernel.Connectors.GoogleVertexAI.Core.Gemini;
using Microsoft.SemanticKernel.Connectors.GoogleVertexAI.Core.Gemini.Common;
using Microsoft.SemanticKernel.Connectors.GoogleVertexAI.Core.Gemini.GoogleAI;
using SemanticKernel.UnitTests;
Expand All @@ -31,8 +30,7 @@ public GeminiClientCountingTokensTests()
public async Task ShouldReturnGreaterThanZeroTokenCountAsync()
{
// Arrange
var geminiConfiguration = new GeminiConfiguration("fake-api-key") { ModelId = "fake-model" };
var client = this.CreateTokenCounterClient(geminiConfiguration);
var client = this.CreateTokenCounterClient("fake-model", "fake-key");

// Act
var tokenCount = await client.CountTokensAsync("fake-text");
Expand All @@ -41,13 +39,13 @@ public async Task ShouldReturnGreaterThanZeroTokenCountAsync()
Assert.True(tokenCount > 0);
}

private GeminiTokenCounterClient CreateTokenCounterClient(GeminiConfiguration geminiConfiguration)
private GeminiTokenCounterClient CreateTokenCounterClient(string modelId, string apiKey)
{
var client = new GeminiTokenCounterClient(
httpClient: this._httpClient,
configuration: geminiConfiguration,
modelId: modelId,
httpRequestFactory: new GoogleAIGeminiHttpRequestFactory(),
endpointProvider: new GoogleAIGeminiEndpointProvider(geminiConfiguration.ApiKey));
endpointProvider: new GoogleAIGeminiEndpointProvider(apiKey));
return client;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.SemanticKernel.Connectors.GoogleVertexAI.Core.Gemini;
using Microsoft.SemanticKernel.Connectors.GoogleVertexAI;
using Xunit;

namespace SemanticKernel.Connectors.GoogleVertexAI.UnitTests.Core.Gemini;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ namespace Microsoft.SemanticKernel.Connectors.GoogleVertexAI.Core;

internal abstract class ClientBase
{
protected readonly IHttpRequestFactory HTTPRequestFactory;
protected readonly IEndpointProvider EndpointProvider;
protected readonly HttpClient HTTPClient;
protected readonly ILogger Logger;
protected IHttpRequestFactory HTTPRequestFactory { get; }
protected IEndpointProvider EndpointProvider { get; }
protected HttpClient HTTPClient { get; }
protected ILogger Logger { get; }

protected ClientBase(
HttpClient httpClient,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ protected static GeminiMetadata GetResponseMetadata(
TotalTokenCount = geminiResponse.UsageMetadata?.TotalTokenCount ?? 0,
PromptFeedbackBlockReason = geminiResponse.PromptFeedback?.BlockReason,
PromptFeedbackSafetyRatings = geminiResponse.PromptFeedback?.SafetyRatings.ToList(),
ResponseSafetyRatings = candidate.SafetyRatings.ToList(),
ResponseSafetyRatings = candidate.SafetyRatings?.ToList(),
};

protected void LogUsageMetadata(GeminiMetadata metadata)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ internal class GeminiTokenCounterClient : GeminiClient, IGeminiTokenCounterClien
/// Represents a client for token counting gemini model.
/// </summary>
/// <param name="httpClient">HttpClient instance used to send HTTP requests</param>
/// <param name="configuration">Gemini configuration instance containing API key and other configuration options</param>
/// <param name="modelId">Id of the model to use to counting tokens</param>
/// <param name="httpRequestFactory">Request factory for gemini rest api or gemini vertex ai</param>
/// <param name="endpointProvider">Endpoints provider for gemini rest api or gemini vertex ai</param>
/// <param name="logger">Logger instance used for logging (optional)</param>
public GeminiTokenCounterClient(
HttpClient httpClient,
GeminiConfiguration configuration,
string modelId,
IHttpRequestFactory httpRequestFactory,
IEndpointProvider endpointProvider,
ILogger? logger = null)
Expand All @@ -36,9 +36,9 @@ public GeminiTokenCounterClient(
endpointProvider: endpointProvider,
logger: logger)
{
VerifyModelId(configuration);
Verify.NotNullOrWhiteSpace(modelId);

this._modelId = configuration.ModelId!;
this._modelId = modelId;
}

/// <inheritdoc/>
Expand All @@ -64,9 +64,4 @@ private static int DeserializeAndProcessCountTokensResponse(string body)
var node = DeserializeResponse<JsonNode>(body);
return node["totalTokens"]?.GetValue<int>() ?? throw new KernelException("Invalid response from model");
}

private static void VerifyModelId(GeminiConfiguration configuration)
{
Verify.NotNullOrWhiteSpace(configuration.ModelId, $"{nameof(configuration)}.{nameof(configuration.ModelId)}");
}
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.Text.Json.Serialization;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.GoogleVertexAI.Core.Gemini;

namespace Microsoft.SemanticKernel.Connectors.GoogleVertexAI;

/// <summary>
/// The base structured datatype containing multi-part content of a message.
/// </summary>
public sealed class GeminiContent
{
/// <summary>
/// Ordered Parts that constitute a single message. Parts may have different MIME types.
/// </summary>
[JsonPropertyName("parts")]
[JsonRequired]
public IList<GeminiPart> Parts { get; set; } = null!;

/// <summary>
/// Optional. The producer of the content. Must be either 'user' or 'model'.
/// </summary>
/// <remarks>Useful to set for multi-turn conversations, otherwise can be left blank or unset.</remarks>
[JsonPropertyName("role")]
[JsonConverter(typeof(AuthorRoleConverter))]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public AuthorRole? Role { get; set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
using System.Text.Json;
using System.Text.Json.Serialization;

namespace Microsoft.SemanticKernel.Connectors.GoogleVertexAI.Core.Gemini;
namespace Microsoft.SemanticKernel.Connectors.GoogleVertexAI;

/// <summary>
/// Represents a Gemini Finish Reason.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
using System.Collections.ObjectModel;
using System.Runtime.CompilerServices;

namespace Microsoft.SemanticKernel.Connectors.GoogleVertexAI.Core.Gemini;
namespace Microsoft.SemanticKernel.Connectors.GoogleVertexAI;

/// <summary>
/// Represents the metadata associated with a Gemini response.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;

namespace Microsoft.SemanticKernel.Connectors.GoogleVertexAI.Core.Gemini;
namespace Microsoft.SemanticKernel.Connectors.GoogleVertexAI;

/// <summary>
/// Union field data can be only one of properties in class GeminiPart
Expand All @@ -24,21 +24,21 @@ public sealed class GeminiPart : IJsonOnDeserialized
/// </summary>
[JsonPropertyName("inlineData")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public GeminiPartInlineData? InlineData { get; set; }
public InlineDataPart? InlineData { get; set; }

/// <summary>
/// Function call data.
/// </summary>
[JsonPropertyName("functionCall")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public GeminiPartFunctionCall? FunctionCall { get; set; }
public FunctionCallPart? FunctionCall { get; set; }

/// <summary>
/// Object representing the function call response.
/// </summary>
[JsonPropertyName("functionResponse")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public GeminiPartFunctionResponse? FunctionResponse { get; set; }
public FunctionResponsePart? FunctionResponse { get; set; }

/// <summary>
/// Checks whether only one property of the GeminiPart instance is not null.
Expand All @@ -62,69 +62,69 @@ public void OnDeserialized()
"GeminiPart is invalid. One and only one property among Text, InlineData, FunctionCall, and FunctionResponse should be set.");
}
}
}

/// <summary>
/// Inline media bytes like image or video data.
/// </summary>
public sealed class GeminiPartInlineData
{
/// <summary>
/// The IANA standard MIME type of the source data.
/// Inline media bytes like image or video data.
/// </summary>
/// <remarks>
/// Accepted types include: "image/png", "image/jpeg", "image/heic", "image/heif", "image/webp".
/// </remarks>
[JsonPropertyName("mimeType")]
[JsonRequired]
public string MimeType { get; set; } = null!;
public sealed class InlineDataPart
{
/// <summary>
/// The IANA standard MIME type of the source data.
/// </summary>
/// <remarks>
/// Accepted types include: "image/png", "image/jpeg", "image/heic", "image/heif", "image/webp".
/// </remarks>
[JsonPropertyName("mimeType")]
[JsonRequired]
public string MimeType { get; set; } = null!;

/// <summary>
/// Base64 encoded data
/// </summary>
[JsonPropertyName("data")]
[JsonRequired]
public string InlineData { get; set; } = null!;
}
/// <summary>
/// Base64 encoded data
/// </summary>
[JsonPropertyName("data")]
[JsonRequired]
public string InlineData { get; set; } = null!;
}

/// <summary>
/// A predicted FunctionCall returned from the model that contains a
/// string representing the FunctionDeclaration.name with the arguments and their values.
/// </summary>
public sealed class GeminiPartFunctionCall
{
/// <summary>
/// Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63.
/// A predicted FunctionCall returned from the model that contains a
/// string representing the FunctionDeclaration.name with the arguments and their values.
/// </summary>
[JsonPropertyName("name")]
[JsonRequired]
public string FunctionName { get; set; } = null!;
public sealed class FunctionCallPart
{
/// <summary>
/// Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63.
/// </summary>
[JsonPropertyName("name")]
[JsonRequired]
public string FunctionName { get; set; } = null!;

/// <summary>
/// Optional. The function parameters and values in JSON object format.
/// </summary>
[JsonPropertyName("args")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public IList<JsonNode>? Arguments { get; set; }
}
/// <summary>
/// Optional. The function parameters and values in JSON object format.
/// </summary>
[JsonPropertyName("args")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public IList<JsonNode>? Arguments { get; set; }
}

/// <summary>
/// The result output of a FunctionCall that contains a string representing the FunctionDeclaration.name and
/// a structured JSON object containing any output from the function is used as context to the model.
/// </summary>
public sealed class GeminiPartFunctionResponse
{
/// <summary>
/// Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63.
/// The result output of a FunctionCall that contains a string representing the FunctionDeclaration.name and
/// a structured JSON object containing any output from the function is used as context to the model.
/// </summary>
[JsonPropertyName("name")]
[JsonRequired]
public string FunctionName { get; set; } = null!;
public sealed class FunctionResponsePart
{
/// <summary>
/// Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63.
/// </summary>
[JsonPropertyName("name")]
[JsonRequired]
public string FunctionName { get; set; } = null!;

/// <summary>
/// Required. The function response in JSON object format.
/// </summary>
[JsonPropertyName("response")]
[JsonRequired]
public IList<JsonNode> Response { get; set; } = null!;
/// <summary>
/// Required. The function response in JSON object format.
/// </summary>
[JsonPropertyName("response")]
[JsonRequired]
public IList<JsonNode> Response { get; set; } = null!;
}
}
Loading

0 comments on commit f2866cf

Please sign in to comment.