From 329d5728d4d11f5b3c733feaed6f7c4fbbc4a517 Mon Sep 17 00:00:00 2001
From: Chris <66376200+crickman@users.noreply.github.com>
Date: Wed, 2 Oct 2024 14:27:25 -0700
Subject: [PATCH] .Net Agents - Include code and output generated by code-interpreter in streaming output (#9068)

### Motivation and Context

While building an updated sample for the Learn site, I discovered that code generated by the code interpreter wasn't visible in the streamed output.

### Description

Showing the generated code keeps streaming at parity with the non-streamed features.

### Contribution Checklist

- [X] The code builds clean without any errors or warnings
- [X] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations
- [X] All unit tests pass, and I have added new tests where possible
- [X] I didn't break anyone :smile:
---
 .../Agents/OpenAIAssistant_Streaming.cs       | 70 +++++++++++++++----
 .../OpenAI/Internal/AssistantThreadActions.cs | 39 +++++++++++
 2 files changed, 94 insertions(+), 15 deletions(-)

diff --git a/dotnet/samples/Concepts/Agents/OpenAIAssistant_Streaming.cs b/dotnet/samples/Concepts/Agents/OpenAIAssistant_Streaming.cs
index 127aa8ba5657..e394b8c49dad 100644
--- a/dotnet/samples/Concepts/Agents/OpenAIAssistant_Streaming.cs
+++ b/dotnet/samples/Concepts/Agents/OpenAIAssistant_Streaming.cs
@@ -11,23 +11,24 @@ namespace Agents;
 /// </summary>
 public class OpenAIAssistant_Streaming(ITestOutputHelper output) : BaseAgentsTest(output)
 {
-    private const string ParrotName = "Parrot";
-    private const string ParrotInstructions = "Repeat the user message in the voice of a pirate and then end with a parrot sound.";
-
     [Fact]
     public async Task UseStreamingAssistantAgentAsync()
     {
+        const string AgentName = "Parrot";
+        const string AgentInstructions = "Repeat the user message in the voice of a pirate and then end with a parrot sound.";
+
         // Define the agent
         OpenAIAssistantAgent agent =
-            await OpenAIAssistantAgent.CreateAsync(
-                kernel: new(),
-                clientProvider: this.GetClientProvider(),
-                definition: new OpenAIAssistantDefinition(this.Model)
-                {
-                    Instructions = ParrotInstructions,
-                    Name = ParrotName,
-                    Metadata = AssistantSampleMetadata,
-                });
+            await OpenAIAssistantAgent.CreateAsync(
+                kernel: new(),
+                clientProvider: this.GetClientProvider(),
+                definition: new OpenAIAssistantDefinition(this.Model)
+                {
+                    Instructions = AgentInstructions,
+                    Name = AgentName,
+                    EnableCodeInterpreter = true,
+                    Metadata = AssistantSampleMetadata,
+                });
 
         // Create a thread for the agent conversation.
         string threadId = await agent.CreateThreadAsync(new OpenAIThreadCreationOptions { Metadata = AssistantSampleMetadata });
@@ -44,7 +45,8 @@ await OpenAIAssistantAgent.CreateAsync(
     [Fact]
     public async Task UseStreamingAssistantAgentWithPluginAsync()
     {
-        const string MenuInstructions = "Answer questions about the menu.";
+        const string AgentName = "Host";
+        const string AgentInstructions = "Answer questions about the menu.";
 
         // Define the agent
         OpenAIAssistantAgent agent =
@@ -53,8 +55,8 @@ await OpenAIAssistantAgent.CreateAsync(
                 clientProvider: this.GetClientProvider(),
                 definition: new OpenAIAssistantDefinition(this.Model)
                 {
-                    Instructions = MenuInstructions,
-                    Name = "Host",
+                    Instructions = AgentInstructions,
+                    Name = AgentName,
                     Metadata = AssistantSampleMetadata,
                 });
 
@@ -73,6 +75,36 @@ await OpenAIAssistantAgent.CreateAsync(
         await DisplayChatHistoryAsync(agent, threadId);
     }
 
+    [Fact]
+    public async Task UseStreamingAssistantWithCodeInterpreterAsync()
+    {
+        const string AgentName = "MathGuy";
+        const string AgentInstructions = "Solve math problems with code.";
+
+        // Define the agent
+        OpenAIAssistantAgent agent =
+            await OpenAIAssistantAgent.CreateAsync(
+                kernel: new(),
+                clientProvider: this.GetClientProvider(),
+                definition: new OpenAIAssistantDefinition(this.Model)
+                {
+                    Instructions = AgentInstructions,
+                    Name = AgentName,
+                    EnableCodeInterpreter = true,
+                    Metadata = AssistantSampleMetadata,
+                });
+
+        // Create a thread for the agent conversation.
+        string threadId = await agent.CreateThreadAsync(new OpenAIThreadCreationOptions { Metadata = AssistantSampleMetadata });
+
+        // Respond to user input
+        await InvokeAgentAsync(agent, threadId, "Is 191 a prime number?");
+        await InvokeAgentAsync(agent, threadId, "Determine the values in the Fibonacci sequence that are less than the value of 101");
+
+        // Output the entire chat history
+        await DisplayChatHistoryAsync(agent, threadId);
+    }
+
     // Local function to invoke agent and display the conversation messages.
     private async Task InvokeAgentAsync(OpenAIAssistantAgent agent, string threadId, string input)
     {
@@ -83,6 +115,7 @@ private async Task InvokeAgentAsync(OpenAIAssistantAgent agent, string threadId,
         ChatHistory history = [];
         bool isFirst = false;
+        bool isCode = false;
         await foreach (StreamingChatMessageContent response in agent.InvokeStreamingAsync(threadId, messages: history))
         {
             if (string.IsNullOrEmpty(response.Content))
             {
@@ -90,6 +123,13 @@ private async Task InvokeAgentAsync(OpenAIAssistantAgent agent, string threadId,
                 continue;
             }
 
+            // Differentiate between assistant and tool messages
+            if (isCode != (response.Metadata?.ContainsKey(OpenAIAssistantAgent.CodeInterpreterMetadataKey) ?? false))
+            {
+                isFirst = false;
+                isCode = !isCode;
+            }
+
             if (!isFirst)
             {
                 Console.WriteLine($"\n# {response.Role} - {response.AuthorName ?? "*"}:");
diff --git a/dotnet/src/Agents/OpenAI/Internal/AssistantThreadActions.cs b/dotnet/src/Agents/OpenAI/Internal/AssistantThreadActions.cs
index 6d1d589ea290..fbefc51642d3 100644
--- a/dotnet/src/Agents/OpenAI/Internal/AssistantThreadActions.cs
+++ b/dotnet/src/Agents/OpenAI/Internal/AssistantThreadActions.cs
@@ -406,6 +406,14 @@ public static async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamin
                         break;
                 }
             }
+            else if (update is RunStepDetailsUpdate detailsUpdate)
+            {
+                StreamingChatMessageContent? toolContent = GenerateStreamingCodeInterpreterContent(agent.GetName(), detailsUpdate);
+                if (toolContent != null)
+                {
+                    yield return toolContent;
+                }
+            }
             else if (update is RunStepUpdate stepUpdate)
             {
                 switch (stepUpdate.UpdateKind)
                 {
@@ -416,6 +424,8 @@ public static async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamin
                     case StreamingUpdateReason.RunStepCompleted:
                         currentStep = null;
                         break;
+                    default:
+                        break;
                 }
             }
         }
@@ -571,6 +581,35 @@ private static StreamingChatMessageContent GenerateStreamingMessageContent(strin
         return content;
     }
 
+    private static StreamingChatMessageContent? GenerateStreamingCodeInterpreterContent(string? assistantName, RunStepDetailsUpdate update)
+    {
+        StreamingChatMessageContent content =
+            new(AuthorRole.Assistant, content: null)
+            {
+                AuthorName = assistantName,
+            };
+
+        // Process text content
+        if (update.CodeInterpreterInput != null)
+        {
+            content.Items.Add(new StreamingTextContent(update.CodeInterpreterInput));
+            content.Metadata = new Dictionary<string, object?> { { OpenAIAssistantAgent.CodeInterpreterMetadataKey, true } };
+        }
+
+        if ((update.CodeInterpreterOutputs?.Count ?? 0) > 0)
+        {
+            foreach (var output in update.CodeInterpreterOutputs!)
+            {
+                if (output.ImageFileId != null)
+                {
+                    content.Items.Add(new StreamingFileReferenceContent(output.ImageFileId));
+                }
+            }
+        }
+
+        return content.Items.Count > 0 ? content : null;
+    }
+
     private static AnnotationContent GenerateAnnotationContent(TextAnnotation annotation)
     {
         string? fileId = null;
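
---

Editor's note (not part of the patch): as a quick reference for consumers of this change, below is a minimal sketch of how a caller might separate streamed code-interpreter content from regular assistant text using the `OpenAIAssistantAgent.CodeInterpreterMetadataKey` metadata entry introduced above. The streaming calls mirror the updated sample; the `StreamingDemo` class, the `WriteStreamAsync` helper name, and the using directives are assumptions for illustration only.

```csharp
using System;
using System.Threading.Tasks;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Agents.OpenAI;
using Microsoft.SemanticKernel.ChatCompletion;

public static class StreamingDemo
{
    // Illustrative helper: writes streamed chunks to the console, marking the
    // transitions between assistant text and code generated by the code interpreter.
    public static async Task WriteStreamAsync(OpenAIAssistantAgent agent, string threadId)
    {
        ChatHistory history = [];
        bool isCode = false;
        bool isFirst = true;

        await foreach (StreamingChatMessageContent response in agent.InvokeStreamingAsync(threadId, messages: history))
        {
            if (string.IsNullOrEmpty(response.Content))
            {
                continue;
            }

            // The patch tags code-interpreter chunks with this metadata key.
            bool chunkIsCode = response.Metadata?.ContainsKey(OpenAIAssistantAgent.CodeInterpreterMetadataKey) ?? false;

            // Print a header whenever the stream switches between text and code.
            if (isFirst || chunkIsCode != isCode)
            {
                Console.WriteLine($"\n# {(chunkIsCode ? "code" : response.Role.ToString())} - {response.AuthorName ?? "*"}:");
                isCode = chunkIsCode;
                isFirst = false;
            }

            Console.Write(response.Content);
        }
    }
}
```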