Skip to content

Commit

Permalink
Turn State Infrastructure (#232)
Browse files Browse the repository at this point in the history
Added support for the turn state. resolves #209 

Additions
- Turn State
- Turn State Manager
- Conversation history
- Integration of Turn State with AI, Application, OpenAIPlanner classes

Missing that will come in subsequent PRs:
- `Application` unit tests
- `AI` class unit tests

A few major differences between this and the JS SDK:
- Dropped `Default` from `DefaultTurnState` to get `TurnState`. Not sure
why `Default` was there in the first place, I presume it's to signal to
the user that this component is configurable. We do that implictly in C#
with interfaces (ex. `ITurnStateManager`) and unsealed class (ex.
`TurnState`).
- `Application` class now has two generic parameters, `TState` and
`TTurnStateManager`. This was to solve the problem that there is
virtually no covariance in C# generic types (i.e
`ITurnStateManager<TurnState> var = TurnStateManager<ChildTurnState>`
throws compiler errors, even though `ChildTurnState` extends
`TurnState`).
  • Loading branch information
singhk97 authored Jul 11, 2023
1 parent cd540a5 commit ef713b6
Show file tree
Hide file tree
Showing 63 changed files with 2,471 additions and 618 deletions.
Original file line number Diff line number Diff line change
@@ -1,26 +1,22 @@
using Microsoft.Bot.Builder.M365.AI.Action;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.Bot.Builder.M365.Tests.TestUtils;

namespace Microsoft.Bot.Builder.M365.Tests.AI
namespace Microsoft.Bot.Builder.M365.Tests.AITests
{
public class ActionCollectionTests
{
[Fact]
public void Test_Simple()
{
// Arrange
IActionCollection<TurnState> actionCollection = new ActionCollection<TurnState>();
IActionCollection<TestTurnState> actionCollection = new ActionCollection<TestTurnState>();
string name = "action";
ActionHandler<TurnState> handler = (turnContext, turnState, data, action) => Task.FromResult(true);
ActionHandler<TestTurnState> handler = (turnContext, turnState, data, action) => Task.FromResult(true);
bool allowOverrides = true;

// Act
actionCollection.SetAction(name, handler, allowOverrides);
ActionEntry<TurnState> entry = actionCollection.GetAction(name);
ActionEntry<TestTurnState> entry = actionCollection.GetAction(name);

// Assert
Assert.True(actionCollection.HasAction(name));
Expand All @@ -34,9 +30,9 @@ public void Test_Simple()
public void Test_Set_NonOverridable_Action_Throws_Exception()
{
// Arrange
IActionCollection<TurnState> actionCollection = new ActionCollection<TurnState>();
IActionCollection<TestTurnState> actionCollection = new ActionCollection<TestTurnState>();
string name = "action";
ActionHandler<TurnState> handler = (turnContext, turnState, data, action) => Task.FromResult(true);
ActionHandler<TestTurnState> handler = (turnContext, turnState, data, action) => Task.FromResult(true);
bool allowOverrides = false;
actionCollection.SetAction(name, handler, allowOverrides);

Expand All @@ -52,7 +48,7 @@ public void Test_Set_NonOverridable_Action_Throws_Exception()
public void Test_Get_NonExistent_Action()
{
// Arrange
IActionCollection<TurnState> actionCollection = new ActionCollection<TurnState>();
IActionCollection<TestTurnState> actionCollection = new ActionCollection<TestTurnState>();
var nonExistentAction = "non existent action";

// Act
Expand All @@ -67,7 +63,7 @@ public void Test_Get_NonExistent_Action()
public void Test_HasAction_False()
{
// Arrange
IActionCollection<TurnState> actionCollection = new ActionCollection<TurnState>();
IActionCollection<TestTurnState> actionCollection = new ActionCollection<TestTurnState>();
var nonExistentAction = "non existent action";

// Act
Expand All @@ -81,8 +77,8 @@ public void Test_HasAction_False()
public void Test_HasAction_True()
{
// Arrange
IActionCollection<TurnState> actionCollection = new ActionCollection<TurnState>();
ActionHandler<TurnState> handler = (turnContext, turnState, data, action) => Task.FromResult(true);
IActionCollection<TestTurnState> actionCollection = new ActionCollection<TestTurnState>();
ActionHandler<TestTurnState> handler = (turnContext, turnState, data, action) => Task.FromResult(true);
var name = "actionName";

// Act
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using Moq;
using System.Reflection;
using Microsoft.Bot.Builder.M365.Exceptions;
using Microsoft.Bot.Builder.M365.Tests.TestUtils;

namespace Microsoft.Bot.Builder.M365.Tests.AITests
{
Expand All @@ -22,13 +23,12 @@ public async void Test_ReviewPrompt_ThrowsException()
var endpoint = "randomEndpoint";

var botAdapterMock = new Mock<BotAdapter>();
// TODO: when TurnState is implemented, get the user input
var activity = new Activity()
{
Text = "input",
};
var turnContext = new TurnContext(botAdapterMock.Object, activity);
var turnStateMock = new Mock<TurnState>();
var turnStateMock = new Mock<TestTurnState>();
var promptTemplate = new PromptTemplate(
"prompt",
new PromptTemplateConfiguration
Expand All @@ -47,7 +47,7 @@ public async void Test_ReviewPrompt_ThrowsException()
clientMock.Setup(client => client.ExecuteTextModeration(It.IsAny<AzureContentSafetyTextAnalysisRequest>())).ThrowsAsync(exception);

var options = new AzureContentSafetyModeratorOptions(apiKey, endpoint, ModerationType.Both);
var moderator = new AzureContentSafetyModerator<TurnState>(options);
var moderator = new AzureContentSafetyModerator<TestTurnState>(options);
moderator.GetType().GetField("_client", BindingFlags.Instance | BindingFlags.NonPublic).SetValue(moderator, clientMock.Object);

// Act
Expand All @@ -68,13 +68,13 @@ public async void Test_ReviewPrompt_Flagged(ModerationType moderate)
var endpoint = "randomEndpoint";

var botAdapterMock = new Mock<BotAdapter>();
// TODO: when TurnState is implemented, get the user input
// TODO: when TestTurnState is implemented, get the user input
var activity = new Activity()
{
Text = "input",
};
var turnContext = new TurnContext(botAdapterMock.Object, activity);
var turnStateMock = new Mock<TurnState>();
var turnStateMock = new Mock<TestTurnState>();
var promptTemplate = new PromptTemplate(
"prompt",
new PromptTemplateConfiguration
Expand All @@ -100,7 +100,7 @@ public async void Test_ReviewPrompt_Flagged(ModerationType moderate)
clientMock.Setup(client => client.ExecuteTextModeration(It.IsAny<AzureContentSafetyTextAnalysisRequest>())).ReturnsAsync(response);

var options = new AzureContentSafetyModeratorOptions(apiKey, endpoint, moderate);
var moderator = new AzureContentSafetyModerator<TurnState>(options);
var moderator = new AzureContentSafetyModerator<TestTurnState>(options);
moderator.GetType().GetField("_client", BindingFlags.Instance | BindingFlags.NonPublic).SetValue(moderator, clientMock.Object);

// Act
Expand Down Expand Up @@ -133,13 +133,13 @@ public async void Test_ReviewPrompt_NotFlagged(ModerationType moderate)
var endpoint = "randomEndpoint";

var botAdapterMock = new Mock<BotAdapter>();
// TODO: when TurnState is implemented, get the user input
// TODO: when TestTurnState is implemented, get the user input
var activity = new Activity()
{
Text = "input",
};
var turnContext = new TurnContext(botAdapterMock.Object, activity);
var turnStateMock = new Mock<TurnState>();
var turnStateMock = new Mock<TestTurnState>();
var promptTemplate = new PromptTemplate(
"prompt",
new PromptTemplateConfiguration
Expand All @@ -165,7 +165,7 @@ public async void Test_ReviewPrompt_NotFlagged(ModerationType moderate)
clientMock.Setup(client => client.ExecuteTextModeration(It.IsAny<AzureContentSafetyTextAnalysisRequest>())).ReturnsAsync(response);

var options = new AzureContentSafetyModeratorOptions(apiKey, endpoint, moderate);
var moderator = new AzureContentSafetyModerator<TurnState>(options);
var moderator = new AzureContentSafetyModerator<TestTurnState>(options);
moderator.GetType().GetField("_client", BindingFlags.Instance | BindingFlags.NonPublic).SetValue(moderator, clientMock.Object);

// Act
Expand All @@ -183,7 +183,7 @@ public async void Test_ReviewPlan_ThrowsException()
var endpoint = "randomEndpoint";

var turnContextMock = new Mock<ITurnContext>();
var turnStateMock = new Mock<TurnState>();
var turnStateMock = new Mock<TestTurnState>();
var plan = new Plan(new List<IPredictedCommand>()
{
new PredictedDoCommand("action"),
Expand All @@ -195,7 +195,7 @@ public async void Test_ReviewPlan_ThrowsException()
clientMock.Setup(client => client.ExecuteTextModeration(It.IsAny<AzureContentSafetyTextAnalysisRequest>())).ThrowsAsync(exception);

var options = new AzureContentSafetyModeratorOptions(apiKey, endpoint, ModerationType.Both);
var moderator = new AzureContentSafetyModerator<TurnState>(options);
var moderator = new AzureContentSafetyModerator<TestTurnState>(options);
moderator.GetType().GetField("_client", BindingFlags.Instance | BindingFlags.NonPublic).SetValue(moderator, clientMock.Object);

// Act
Expand All @@ -216,7 +216,7 @@ public async void Test_ReviewPlan_Flagged(ModerationType moderate)
var endpoint = "randomEndpoint";

var turnContextMock = new Mock<ITurnContext>();
var turnStateMock = new Mock<TurnState>();
var turnStateMock = new Mock<TestTurnState>();
var plan = new Plan(new List<IPredictedCommand>()
{
new PredictedDoCommand("action"),
Expand All @@ -235,7 +235,7 @@ public async void Test_ReviewPlan_Flagged(ModerationType moderate)
clientMock.Setup(client => client.ExecuteTextModeration(It.IsAny<AzureContentSafetyTextAnalysisRequest>())).ReturnsAsync(response);

var options = new AzureContentSafetyModeratorOptions(apiKey, endpoint, moderate);
var moderator = new AzureContentSafetyModerator<TurnState>(options);
var moderator = new AzureContentSafetyModerator<TestTurnState>(options);
moderator.GetType().GetField("_client", BindingFlags.Instance | BindingFlags.NonPublic).SetValue(moderator, clientMock.Object);

// Act
Expand Down Expand Up @@ -268,7 +268,7 @@ public async void Test_ReviewPlan_NotFlagged(ModerationType moderate)
var endpoint = "randomEndpoint";

var turnContextMock = new Mock<ITurnContext>();
var turnStateMock = new Mock<TurnState>();
var turnStateMock = new Mock<TestTurnState>();
var plan = new Plan(new List<IPredictedCommand>()
{
new PredictedDoCommand("action"),
Expand All @@ -287,7 +287,7 @@ public async void Test_ReviewPlan_NotFlagged(ModerationType moderate)
clientMock.Setup(client => client.ExecuteTextModeration(It.IsAny<AzureContentSafetyTextAnalysisRequest>())).ReturnsAsync(response);

var options = new AzureContentSafetyModeratorOptions(apiKey, endpoint, moderate);
var moderator = new AzureContentSafetyModerator<TurnState>(options);
var moderator = new AzureContentSafetyModerator<TestTurnState>(options);
moderator.GetType().GetField("_client", BindingFlags.Instance | BindingFlags.NonPublic).SetValue(moderator, clientMock.Object);

// Act
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
using Microsoft.Bot.Schema;
using Moq;
using System.Reflection;
using Microsoft.Bot.Builder.M365.State;
using Microsoft.Bot.Builder.M365.Tests.TestUtils;

namespace Microsoft.Bot.Builder.M365.Tests.AITests
{
Expand All @@ -21,13 +23,12 @@ public async void Test_ReviewPrompt_ThrowsException()
var apiKey = "randomApiKey";

var botAdapterMock = new Mock<BotAdapter>();
// TODO: when TurnState is implemented, get the user input
var activity = new Activity()
{
Text = "input",
};
var turnContext = new TurnContext(botAdapterMock.Object, activity);
var turnStateMock = new Mock<TurnState>();
var turnStateMock = new Mock<TestTurnState>();
var promptTemplate = new PromptTemplate(
"prompt",
new PromptTemplateConfiguration
Expand All @@ -46,7 +47,7 @@ public async void Test_ReviewPrompt_ThrowsException()
clientMock.Setup(client => client.ExecuteTextModeration(It.IsAny<string>(), It.IsAny<string>())).ThrowsAsync(exception);

var options = new OpenAIModeratorOptions(apiKey, ModerationType.Both);
var moderator = new OpenAIModerator<TurnState>(options);
var moderator = new OpenAIModerator<TestTurnState>(options);
moderator.GetType().GetField("_client", BindingFlags.Instance | BindingFlags.NonPublic).SetValue(moderator, clientMock.Object);

// Act
Expand All @@ -66,13 +67,13 @@ public async void Test_ReviewPrompt_Flagged(ModerationType moderate)
var apiKey = "randomApiKey";

var botAdapterMock = new Mock<BotAdapter>();
// TODO: when TurnState is implemented, get the user input
// TODO: when TestTurnState is implemented, get the user input
var activity = new Activity()
{
Text = "input",
};
var turnContext = new TurnContext(botAdapterMock.Object, activity);
var turnStateMock = new Mock<TurnState>();
var turnStateMock = new Mock<TestTurnState>();
var promptTemplate = new PromptTemplate(
"prompt",
new PromptTemplateConfiguration
Expand Down Expand Up @@ -122,7 +123,7 @@ public async void Test_ReviewPrompt_Flagged(ModerationType moderate)
clientMock.Setup(client => client.ExecuteTextModeration(It.IsAny<string>(), It.IsAny<string>())).ReturnsAsync(response);

var options = new OpenAIModeratorOptions(apiKey, moderate);
var moderator = new OpenAIModerator<TurnState>(options);
var moderator = new OpenAIModerator<TestTurnState>(options);
moderator.GetType().GetField("_client", BindingFlags.Instance | BindingFlags.NonPublic).SetValue(moderator, clientMock.Object);

// Act
Expand Down Expand Up @@ -151,7 +152,7 @@ public async void Test_ReviewPlan_ThrowsException()
var apiKey = "randomApiKey";

var turnContextMock = new Mock<ITurnContext>();
var turnStateMock = new Mock<TurnState>();
var turnStateMock = new Mock<TestTurnState>();
var plan = new Plan(new List<IPredictedCommand>()
{
new PredictedDoCommand("action"),
Expand All @@ -163,7 +164,7 @@ public async void Test_ReviewPlan_ThrowsException()
clientMock.Setup(client => client.ExecuteTextModeration(It.IsAny<string>(), It.IsAny<string>())).ThrowsAsync(exception);

var options = new OpenAIModeratorOptions(apiKey, ModerationType.Both);
var moderator = new OpenAIModerator<TurnState>(options);
var moderator = new OpenAIModerator<TestTurnState>(options);
moderator.GetType().GetField("_client", BindingFlags.Instance | BindingFlags.NonPublic).SetValue(moderator, clientMock.Object);

// Act
Expand All @@ -183,7 +184,7 @@ public async void Test_ReviewPlan_Flagged(ModerationType moderate)
var apiKey = "randomApiKey";

var turnContextMock = new Mock<ITurnContext>();
var turnStateMock = new Mock<TurnState>();
var turnStateMock = new Mock<TestTurnState>();
var plan = new Plan(new List<IPredictedCommand>()
{
new PredictedDoCommand("action"),
Expand Down Expand Up @@ -226,7 +227,7 @@ public async void Test_ReviewPlan_Flagged(ModerationType moderate)
clientMock.Setup(client => client.ExecuteTextModeration(It.IsAny<string>(), It.IsAny<string>())).ReturnsAsync(response);

var options = new OpenAIModeratorOptions(apiKey, moderate);
var moderator = new OpenAIModerator<TurnState>(options);
var moderator = new OpenAIModerator<TestTurnState>(options);
moderator.GetType().GetField("_client", BindingFlags.Instance | BindingFlags.NonPublic).SetValue(moderator, clientMock.Object);

// Act
Expand Down
Loading

0 comments on commit ef713b6

Please sign in to comment.