Skip to content

Commit

Permalink
.Net Agents - Align with expectations when arguments override specifi…
Browse files Browse the repository at this point in the history
…ed. (#9096)

### Motivation and Context
<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

When default `KernelArguments` are specified for
`KernelAgent.Arguments`, it is likely expected that when override
`arguments` parameter is provided that this override doesn't necessarily
invalidate the default `KernelArguments.ExecutionSettings`.

### Description
<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

Provide performant merge logic for _default_ and _override_
`KernelArguments` for all invocations.
This merge preserves any _default_ values not specifically included in
the _override_...including both `ExecutionSettings` and parameters.

Solves this (common) pattern:
```c#
ChatCompletionAgent agent =
    new()
    {
        Name = "SampleAssistantAgent",
        Instructions =
            """
            Something something.
            
            Something dynamic: {{$topic}}

            The current date and time is: {{$now}}. 
            """,
        Kernel = kernel,
        Arguments =
            new KernelArguments(new AzureOpenAIPromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() })
            {
                { "topic", "default topic" }
            }
    };
    
KernelArguments arguments =
    new()
    {
        { "now", $"{now.ToShortDateString()} {now.ToShortTimeString()}" }
    };
await foreach (ChatMessageContent response in agent.InvokeAsync(history, arguments))
{
    ...
} 	
```

### Contribution Checklist
<!-- Before submitting this PR, please make sure: -->

- [X] The code builds clean without any errors or warnings
- [X] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [X] All unit tests pass, and I have added new tests where possible
- [X] I didn't break anyone 😄
  • Loading branch information
crickman authored Oct 4, 2024
1 parent 27fa987 commit 81953f2
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 14 deletions.
46 changes: 46 additions & 0 deletions dotnet/src/Agents/Abstractions/KernelAgent.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

Expand Down Expand Up @@ -62,4 +64,48 @@ public abstract class KernelAgent : Agent

return await this.Template.RenderAsync(kernel, arguments, cancellationToken).ConfigureAwait(false);
}

/// <summary>
/// Provide a merged instance of <see cref="KernelArguments"/> with precedence for override arguments.
/// </summary>
/// <param name="arguments">The override arguments</param>
/// <remarks>
/// This merge preserves original <see cref="PromptExecutionSettings"/> and <see cref="KernelArguments"/> parameters.
/// and allows for incremental addition or replacement of specific parameters while also preserving the ability
/// to override the execution settings.
/// </remarks>
protected KernelArguments? MergeArguments(KernelArguments? arguments)
{
// Avoid merge when default arguments are not set.
if (this.Arguments == null)
{
return arguments;
}

// Avoid merge when override arguments are not set.
if (arguments == null)
{
return this.Arguments;
}

// Both instances are not null, merge with precedence for override arguments.

// Merge execution settings with precedence for override arguments.
Dictionary<string, PromptExecutionSettings>? settings =
(arguments.ExecutionSettings ?? s_emptySettings)
.Concat(this.Arguments.ExecutionSettings ?? s_emptySettings)
.GroupBy(entry => entry.Key)
.ToDictionary(entry => entry.Key, entry => entry.First().Value);

// Merge parameters with precedence for override arguments.
Dictionary<string, object?>? parameters =
arguments
.Concat(this.Arguments)
.GroupBy(entry => entry.Key)
.ToDictionary(entry => entry.Key, entry => entry.First().Value);

return new KernelArguments(parameters, settings);
}

private static readonly Dictionary<string, PromptExecutionSettings> s_emptySettings = [];
}
4 changes: 2 additions & 2 deletions dotnet/src/Agents/Core/ChatCompletionAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public override async IAsyncEnumerable<ChatMessageContent> InvokeAsync(
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
kernel ??= this.Kernel;
arguments ??= this.Arguments;
arguments = this.MergeArguments(arguments);

(IChatCompletionService chatCompletionService, PromptExecutionSettings? executionSettings) = GetChatCompletionService(kernel, arguments);

Expand Down Expand Up @@ -95,7 +95,7 @@ public override async IAsyncEnumerable<StreamingChatMessageContent> InvokeStream
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
kernel ??= this.Kernel;
arguments ??= this.Arguments;
arguments = this.MergeArguments(arguments);

(IChatCompletionService chatCompletionService, PromptExecutionSettings? executionSettings) = GetChatCompletionService(kernel, arguments);

Expand Down
22 changes: 12 additions & 10 deletions dotnet/src/Agents/OpenAI/Internal/AssistantToolResourcesFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,20 @@ internal static class AssistantToolResourcesFactory

if (hasVectorStore || hasCodeInterpreterFiles)
{
var fileSearch = hasVectorStore
? new FileSearchToolResources
{
VectorStoreIds = { vectorStoreId! }
}
: null;

var codeInterpreter = hasCodeInterpreterFiles
? new CodeInterpreterToolResources()
: null;
FileSearchToolResources? fileSearch =
hasVectorStore ?
new()
{
VectorStoreIds = { vectorStoreId! }
} :
null;

CodeInterpreterToolResources? codeInterpreter =
hasCodeInterpreterFiles ?
new() :
null;
codeInterpreter?.FileIds.AddRange(codeInterpreterFileIds!);

toolResources = new ToolResources
{
FileSearch = fileSearch,
Expand Down
4 changes: 2 additions & 2 deletions dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ public async IAsyncEnumerable<ChatMessageContent> InvokeAsync(
this.ThrowIfDeleted();

kernel ??= this.Kernel;
arguments ??= this.Arguments;
arguments = this.MergeArguments(arguments);

await foreach ((bool isVisible, ChatMessageContent message) in AssistantThreadActions.InvokeAsync(this, this._client, threadId, options, this.Logger, kernel, arguments, cancellationToken).ConfigureAwait(false))
{
Expand Down Expand Up @@ -396,7 +396,7 @@ public IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamingAsync(
this.ThrowIfDeleted();

kernel ??= this.Kernel;
arguments ??= this.Arguments;
arguments = this.MergeArguments(arguments);

return AssistantThreadActions.InvokeStreamingAsync(this, this._client, threadId, messages, options, this.Logger, kernel, arguments, cancellationToken);
}
Expand Down
111 changes: 111 additions & 0 deletions dotnet/src/Agents/UnitTests/KernelAgentTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Linq;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Agents;
using Xunit;

namespace SemanticKernel.Agents.UnitTests;

/// <summary>
/// Verify behavior of <see cref="KernelAgent"/> base class.
/// </summary>
public class KernelAgentTests
{
/// <summary>
/// Verify ability to merge null <see cref="KernelArguments"/>.
/// </summary>
[Fact]
public void VerifyNullArgumentMerge()
{
// Arrange
MockAgent agentWithNullArguments = new();
// Act
KernelArguments? arguments = agentWithNullArguments.MergeArguments(null);
// Assert
Assert.Null(arguments);

// Arrange
KernelArguments overrideArguments = [];
// Act
arguments = agentWithNullArguments.MergeArguments(overrideArguments);
// Assert
Assert.NotNull(arguments);
Assert.StrictEqual(overrideArguments, arguments);

// Arrange
MockAgent agentWithEmptyArguments = new() { Arguments = new() };
// Act
arguments = agentWithEmptyArguments.MergeArguments(null);
// Assert
Assert.NotNull(arguments);
Assert.StrictEqual(agentWithEmptyArguments.Arguments, arguments);
}

/// <summary>
/// Verify ability to merge <see cref="KernelArguments"/> parameters.
/// </summary>
[Fact]
public void VerifyArgumentParameterMerge()
{
// Arrange
MockAgent agentWithArguments = new() { Arguments = new() { { "a", 1 } } };
KernelArguments overrideArguments = new() { { "b", 2 } };

// Act
KernelArguments? arguments = agentWithArguments.MergeArguments(overrideArguments);

// Assert
Assert.NotNull(arguments);
Assert.Equal(2, arguments.Count);
Assert.Equal(1, arguments["a"]);
Assert.Equal(2, arguments["b"]);

// Arrange
overrideArguments["a"] = 11;
overrideArguments["c"] = 3;

// Act
arguments = agentWithArguments.MergeArguments(overrideArguments);

// Assert
Assert.NotNull(arguments);
Assert.Equal(3, arguments.Count);
Assert.Equal(11, arguments["a"]);
Assert.Equal(2, arguments["b"]);
Assert.Equal(3, arguments["c"]);
}

/// <summary>
/// Verify ability to merge <see cref="KernelArguments.ExecutionSettings"/>.
/// </summary>
[Fact]
public void VerifyArgumentSettingsMerge()
{
// Arrange
FunctionChoiceBehavior autoInvoke = FunctionChoiceBehavior.Auto();
MockAgent agentWithSettings = new() { Arguments = new(new PromptExecutionSettings() { FunctionChoiceBehavior = autoInvoke }) };
KernelArguments overrideArgumentsNoSettings = new();

// Act
KernelArguments? arguments = agentWithSettings.MergeArguments(overrideArgumentsNoSettings);

// Assert
Assert.NotNull(arguments);
Assert.NotNull(arguments.ExecutionSettings);
Assert.Single(arguments.ExecutionSettings);
Assert.StrictEqual(autoInvoke, arguments.ExecutionSettings.First().Value.FunctionChoiceBehavior);

// Arrange
FunctionChoiceBehavior noInvoke = FunctionChoiceBehavior.None();
KernelArguments overrideArgumentsWithSettings = new(new PromptExecutionSettings() { FunctionChoiceBehavior = noInvoke });

// Act
arguments = agentWithSettings.MergeArguments(overrideArgumentsWithSettings);

// Assert
Assert.NotNull(arguments);
Assert.NotNull(arguments.ExecutionSettings);
Assert.Single(arguments.ExecutionSettings);
Assert.StrictEqual(noInvoke, arguments.ExecutionSettings.First().Value.FunctionChoiceBehavior);
}
}
6 changes: 6 additions & 0 deletions dotnet/src/Agents/UnitTests/MockAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,10 @@ public override IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamingAsy
this.InvokeCount++;
return this.Response.Select(m => new StreamingChatMessageContent(m.Role, m.Content)).ToAsyncEnumerable();
}

// Expose protected method for testing
public new KernelArguments? MergeArguments(KernelArguments? arguments)
{
return base.MergeArguments(arguments);
}
}

0 comments on commit 81953f2

Please sign in to comment.