From 3bfee7b64be60624a9aa0bb48c008dcac4d57fca Mon Sep 17 00:00:00 2001 From: Eduard van Valkenburg Date: Thu, 29 Aug 2024 18:01:53 +0200 Subject: [PATCH] Python: fix discriminator field for CMC (#8417) ### Motivation and Context Turns out we had a mistake in the way CMC discriminates content types; this PR fixes that. ### Description Adds discriminator field to the item in the list rather than the list itself. Adds additional tests. ### Contribution Checklist - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone :smile: --- python/.vscode/tasks.json | 2 +- .../contents/chat_message_content.py | 17 +--- .../contents/test_chat_message_content.py | 92 ++++++++++++++++++- 3 files changed, 97 insertions(+), 14 deletions(-) diff --git a/python/.vscode/tasks.json b/python/.vscode/tasks.json index 9a6a7ecbfd69..dbd972976939 100644 --- a/python/.vscode/tasks.json +++ b/python/.vscode/tasks.json @@ -122,7 +122,7 @@ "pytest", "tests/unit/", "--last-failed", - "-v" + "-vv" ], "group": "test", "presentation": { diff --git a/python/semantic_kernel/contents/chat_message_content.py b/python/semantic_kernel/contents/chat_message_content.py index ced273de75a2..1b52a8c9ea65 100644 --- a/python/semantic_kernel/contents/chat_message_content.py +++ b/python/semantic_kernel/contents/chat_message_content.py @@ -3,7 +3,7 @@ import logging from enum import Enum from html import unescape -from typing import Any, ClassVar, Literal, Union, overload +from typing import Annotated, Any, ClassVar, Literal, overload from xml.etree.ElementTree import Element # nosec from defusedxml import ElementTree @@ -26,7 +26,6 @@ from 
semantic_kernel.contents.function_result_content import FunctionResultContent from semantic_kernel.contents.image_content import ImageContent from semantic_kernel.contents.kernel_content import KernelContent -from semantic_kernel.contents.streaming_text_content import StreamingTextContent from semantic_kernel.contents.text_content import TextContent from semantic_kernel.contents.utils.author_role import AuthorRole from semantic_kernel.contents.utils.finish_reason import FinishReason @@ -41,15 +40,9 @@ IMAGE_CONTENT_TAG: ImageContent, } -ITEM_TYPES = Union[ - AnnotationContent, - ImageContent, - TextContent, - StreamingTextContent, - FunctionResultContent, - FunctionCallContent, - FileReferenceContent, -] +ITEM_TYPES = ( + AnnotationContent | ImageContent | TextContent | FunctionResultContent | FunctionCallContent | FileReferenceContent +) logger = logging.getLogger(__name__) @@ -78,7 +71,7 @@ class ChatMessageContent(KernelContent): tag: ClassVar[str] = CHAT_MESSAGE_CONTENT_TAG role: AuthorRole name: str | None = None - items: list[ITEM_TYPES] = Field(default_factory=list, discriminator=DISCRIMINATOR_FIELD) + items: list[Annotated[ITEM_TYPES, Field(..., discriminator=DISCRIMINATOR_FIELD)]] = Field(default_factory=list) encoding: str | None = None finish_reason: FinishReason | None = None diff --git a/python/tests/unit/contents/test_chat_message_content.py b/python/tests/unit/contents/test_chat_message_content.py index 10997b9a0d98..9e7dcaa07b8a 100644 --- a/python/tests/unit/contents/test_chat_message_content.py +++ b/python/tests/unit/contents/test_chat_message_content.py @@ -284,8 +284,98 @@ def test_cmc_to_dict_keys(): "content": [{"type": "text", "text": "Hello, "}, {"type": "text", "text": "world!"}], }, ), + ( + { + "role": "user", + "items": [ + {"content_type": "text", "text": "Hello, "}, + {"content_type": "text", "text": "world!"}, + ], + }, + { + "role": "user", + "content": [{"type": "text", "text": "Hello, "}, {"type": "text", "text": "world!"}], + }, 
+ ), + ( + { + "role": "user", + "items": [ + {"content_type": "annotation", "file_id": "test"}, + ], + }, + { + "role": "user", + "content": [{"type": "text", "text": "test None (Start Index=None->End Index=None)"}], + }, + ), + ( + { + "role": "user", + "items": [ + {"content_type": "file_reference", "file_id": "test"}, + ], + }, + { + "role": "user", + "content": [{"file_id": "test"}], + }, + ), + ( + { + "role": "user", + "items": [ + {"content_type": "function_call", "name": "test-test"}, + ], + }, + { + "role": "user", + "content": [{"id": None, "type": "function", "function": {"name": "test-test", "arguments": None}}], + }, + ), + ( + { + "role": "user", + "items": [ + {"content_type": "function_call", "name": "test-test"}, + {"content_type": "function_result", "name": "test-test", "result": "test", "id": "test"}, + ], + }, + { + "role": "user", + "content": [ + {"id": None, "type": "function", "function": {"name": "test-test", "arguments": None}}, + {"tool_call_id": "test", "content": "test"}, + ], + }, + ), + ( + { + "role": "user", + "items": [ + {"content_type": "image", "uri": "http://test"}, + ], + }, + { + "role": "user", + "content": [{"image_url": {"url": "http://test/"}, "type": "image_url"}], + }, + ), + ], + ids=[ + "user_content", + "user_with_name", + "user_item", + "function_call", + "function_result", + "multiple_items", + "multiple_items_serialize", + "annotations_serialize", + "file_reference_serialize", + "function_call_serialize", + "function_result_serialize", + "image_serialize", ], - ids=["user_content", "user_with_name", "user_item", "function_call", "function_result", "multiple_items"], ) def test_cmc_to_dict_items(input_args, expected_dict): message = ChatMessageContent(**input_args)