Skip to content

Commit

Permalink
.Net: Fix to ensure OpenAIPromptExecutionSettings ChatSystemPrompt is…
Browse files Browse the repository at this point in the history
… not ignored (#4530)

### Motivation and Context

Resolves issue #4510 

### Description

Add a system message using the chat system prompt if no system message
is included in the chat history

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [ ] The code builds clean without any errors or warnings
- [ ] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [ ] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄
  • Loading branch information
markwallace-microsoft authored Jan 10, 2024
1 parent a843f47 commit abc900e
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,11 @@ private static ChatCompletionsOptions CreateChatCompletionsOptions(
}
}

if (!string.IsNullOrWhiteSpace(executionSettings?.ChatSystemPrompt) && !chatHistory.Any(m => m.Role == AuthorRole.System))
{
options.Messages.Add(GetRequestMessage(new ChatMessageContent(AuthorRole.System, executionSettings!.ChatSystemPrompt)));
}

foreach (var message in chatHistory)
{
options.Messages.Add(GetRequestMessage(message));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ public async Task ItAddsIdToChatMessageAsync()
var actualRequestContent = Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent!);
Assert.NotNull(actualRequestContent);
var optionsJson = JsonSerializer.Deserialize<JsonElement>(actualRequestContent);
Assert.Equal(1, optionsJson.GetProperty("messages").GetArrayLength());
Assert.Equal("John Doe", optionsJson.GetProperty("messages")[0].GetProperty("tool_call_id").GetString());
Assert.Equal(2, optionsJson.GetProperty("messages").GetArrayLength());
Assert.Equal("John Doe", optionsJson.GetProperty("messages")[1].GetProperty("tool_call_id").GetString());
}

[Fact]
Expand Down Expand Up @@ -163,6 +163,30 @@ public async Task ItGetTextContentsShouldHaveModelIdDefinedAsync()
Assert.Equal("gpt-3.5-turbo", textContent.ModelId);
}

[Fact]
public async Task ItAddsSystemMessageAsync()
{
// Arrange
var chatCompletion = new OpenAIChatCompletionService(modelId: "gpt-3.5-turbo", apiKey: "NOKEY", httpClient: this._httpClient);
this._messageHandlerStub.ResponseToReturn = new HttpResponseMessage(System.Net.HttpStatusCode.OK)
{ Content = new StringContent(ChatCompletionResponse) };
var chatHistory = new ChatHistory();
chatHistory.AddMessage(AuthorRole.User, "Hello");

// Act
await chatCompletion.GetChatMessageContentsAsync(chatHistory, this._executionSettings);

// Assert
var actualRequestContent = Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent!);
Assert.NotNull(actualRequestContent);
var optionsJson = JsonSerializer.Deserialize<JsonElement>(actualRequestContent);
Assert.Equal(2, optionsJson.GetProperty("messages").GetArrayLength());
Assert.Equal("Assistant is a large language model.", optionsJson.GetProperty("messages")[0].GetProperty("content").GetString());
Assert.Equal("system", optionsJson.GetProperty("messages")[0].GetProperty("role").GetString());
Assert.Equal("Hello", optionsJson.GetProperty("messages")[1].GetProperty("content").GetString());
Assert.Equal("user", optionsJson.GetProperty("messages")[1].GetProperty("role").GetString());
}

public void Dispose()
{
this._httpClient.Dispose();
Expand Down
48 changes: 48 additions & 0 deletions dotnet/src/IntegrationTests/Connectors/OpenAI/ChatHistoryTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,54 @@ public async Task ItSerializesAndDeserializesChatHistoryAsync()
Assert.Null(exception);
}

[Fact]
public async Task ItUsesChatSystemPromptFromSettingsAsync()
{
// Arrange
this._kernelBuilder.Services.AddSingleton<ILoggerFactory>(this._logger);
var builder = this._kernelBuilder;
this.ConfigureAzureOpenAIChatAsText(builder);
builder.Plugins.AddFromType<FakePlugin>();
var kernel = builder.Build();

string systemPrompt = "You are batman. If asked who you are, say 'I am Batman!'";

OpenAIPromptExecutionSettings settings = new() { ChatSystemPrompt = systemPrompt };
ChatHistory history = new();

// Act
history.AddUserMessage("Who are you?");
var service = kernel.GetRequiredService<IChatCompletionService>();
ChatMessageContent result = await service.GetChatMessageContentAsync(history, settings, kernel);

// Assert
Assert.Contains("Batman", result.ToString(), StringComparison.OrdinalIgnoreCase);
}

[Fact]
public async Task ItUsesChatSystemPromptFromChatHistoryAsync()
{
// Arrange
this._kernelBuilder.Services.AddSingleton<ILoggerFactory>(this._logger);
var builder = this._kernelBuilder;
this.ConfigureAzureOpenAIChatAsText(builder);
builder.Plugins.AddFromType<FakePlugin>();
var kernel = builder.Build();

string systemPrompt = "You are batman. If asked who you are, say 'I am Batman!'";

OpenAIPromptExecutionSettings settings = new();
ChatHistory history = new(systemPrompt);

// Act
history.AddUserMessage("Who are you?");
var service = kernel.GetRequiredService<IChatCompletionService>();
ChatMessageContent result = await service.GetChatMessageContentAsync(history, settings, kernel);

// Assert
Assert.Contains("Batman", result.ToString(), StringComparison.OrdinalIgnoreCase);
}

private void ConfigureAzureOpenAIChatAsText(IKernelBuilder kernelBuilder)
{
var azureOpenAIConfiguration = this._configuration.GetSection("Planners:AzureOpenAI").Get<AzureOpenAIConfiguration>();
Expand Down

0 comments on commit abc900e

Please sign in to comment.