working OpenAI classes!
eavanvalkenburg committed Jan 22, 2024
1 parent 0c3c771 commit 75e8f26
Showing 14 changed files with 245 additions and 133 deletions.
20 changes: 13 additions & 7 deletions python/samples/kernel-syntax-examples/openai_function_calling.py
@@ -96,15 +96,21 @@
 async def main() -> None:
     context = kernel.create_new_context()
     context.variables["user_input"] = "I want to find a hotel in Seattle with free wifi and a pool."
 
+    messages = []
+    tool_call = None
     response = chat_function.invoke_stream_async(context=context)
     async for message in response:
-        print(message)
-    if tool_call_choices := context.objects["tool_calls"]:
-        for tool_call in tool_call_choices:
-            for tool in tool_call.values():
-                print(f"Function to be called: {tool['function'].name}")
-                print(f"Function parameters: \n{tool['function'].parse_arguments()}")
-                return
+        current = message[0]
+        messages.append(current)
+        if current.tool_calls:
+            if tool_call is None:
+                tool_call = current.tool_calls[0]
+            # continue
+            # tool_call.update(current.tool_calls[0])
+    if tool_call:
+        print(f"Function to be called: {tool_call.function.name}")
+        print(f"Function parameters: \n{tool_call.function.parse_arguments()}")
+        return
     print("No function was called")
     print(f"Output was: {str(context)}")
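The updated sample reads tool calls from the streamed message content itself rather than from context.objects["tool_calls"]. The commented-out tool_call.update(...) line hints at merging later argument fragments into the first chunk; a minimal sketch of that merge loop, assuming the ToolCall.update method added later in this commit (surrounding names mirror the sample):

    tool_call = None
    async for message in response:
        current = message[0]
        if current.tool_calls:
            if tool_call is None:
                tool_call = current.tool_calls[0]
            else:
                # fold this chunk's argument fragment into the first tool call
                tool_call.update(current.tool_calls[0])

    if tool_call:
        print(f"Function to be called: {tool_call.function.name}")
        print(f"Function parameters: \n{tool_call.function.parse_arguments()}")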
6 changes: 3 additions & 3 deletions python/samples/kernel-syntax-examples/self-critique_rag.py
@@ -40,8 +40,8 @@ async def main() -> None:

config = dotenv_values(".env")

AZURE_COGNITIVE_SEARCH_ENDPOINT = config["AZURE_COGNITIVE_SEARCH_ENDPOINT"]
AZURE_COGNITIVE_SEARCH_ADMIN_KEY = config["AZURE_COGNITIVE_SEARCH_ADMIN_KEY"]
AZURE_COGNITIVE_SEARCH_ENDPOINT = config["AZURE_AISEARCH_URL"]
AZURE_COGNITIVE_SEARCH_ADMIN_KEY = config["AZURE_AISEARCH_API_KEY"]
AZURE_OPENAI_API_KEY = config["AZURE_OPENAI_API_KEY"]
AZURE_OPENAI_ENDPOINT = config["AZURE_OPENAI_ENDPOINT"]
vector_size = 1536
@@ -72,7 +72,7 @@ async def main() -> None:
     kernel.register_memory_store(memory_store=connector)
 
     print("Populating memory...")
-    await populate_memory(kernel)
+    # await populate_memory(kernel)
 
     sk_prompt_rag = """
     Assistant can have a conversation with you about any topic.
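With this rename, the sample expects the Azure AI Search settings under new .env keys. A minimal sketch of the lookup, with the renamed keys spelled out (values and endpoint formats are placeholders):

    from dotenv import dotenv_values

    # .env is expected to contain, e.g.:
    #   AZURE_AISEARCH_URL=https://<your-search-service>.search.windows.net
    #   AZURE_AISEARCH_API_KEY=<your-admin-key>
    config = dotenv_values(".env")
    search_endpoint = config["AZURE_AISEARCH_URL"]
    search_key = config["AZURE_AISEARCH_API_KEY"]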
@@ -1,6 +1,6 @@
"""Class to hold chat messages."""
import json
from typing import Dict, Tuple
from typing import Dict, Optional, Tuple

from semantic_kernel.orchestration.context_variables import ContextVariables
from semantic_kernel.sk_pydantic import SKBaseModel
@@ -9,9 +9,21 @@
 class FunctionCall(SKBaseModel):
     """Class to hold a function call response."""
 
-    name: str
-    arguments: str
-    id: str
+    name: Optional[str] = None
+    arguments: Optional[str] = None
+    # TODO: check if needed
+    id: Optional[str] = None
+
+    def update(self, name: str, arguments: Optional[str]):
+        """Update the function call."""
+        if name:
+            if name != self.name:
+                self.name = name
+        if arguments:
+            if self.arguments is None:
+                self.arguments = arguments
+            else:
+                self.arguments += arguments
 
     def parse_arguments(self) -> Dict[str, str]:
         """Parse the arguments into a dictionary."""
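The new update method is built for streaming: the model sends the function name once and the JSON arguments as string fragments, which update concatenates. A minimal sketch of the intended accumulation (fragment values are illustrative):

    fc = FunctionCall(name="find_hotels")
    fc.update(name=None, arguments='{"city": "Sea')  # first fragment sets arguments
    fc.update(name=None, arguments='ttle"}')         # later fragments are appended
    print(fc.arguments)          # {"city": "Seattle"}
    print(fc.parse_arguments())  # {'city': 'Seattle'}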
@@ -0,0 +1,18 @@
"""Class to hold chat messages."""
from typing import Literal, Optional

from semantic_kernel.connectors.ai.open_ai.models.chat.function_call import FunctionCall
from semantic_kernel.sk_pydantic import SKBaseModel


class ToolCall(SKBaseModel):
"""Class to hold a tool call response."""

id: Optional[str] = None
type: Optional[Literal["function"]] = "function"
function: Optional[FunctionCall] = None

def update(self, chunk: "ToolCall"):
"""Update the function call."""
if self.function:
self.function.update(chunk.function.name, chunk.function.arguments)
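ToolCall.update delegates to FunctionCall.update, so a stream of tool-call chunks can be folded into the first chunk received; note it only merges when self.function is already set and ignores the incoming id. An illustrative sketch (values hypothetical):

    first = ToolCall(id="call_0", function=FunctionCall(name="find_hotels", arguments='{"city": "Sea'))
    later = ToolCall(function=FunctionCall(arguments='ttle"}'))
    first.update(later)
    print(first.function.arguments)  # {"city": "Seattle"}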
@@ -6,6 +6,8 @@
 from openai.types.chat import ChatCompletion
 from openai.types.chat.chat_completion import Choice
 
+from semantic_kernel.connectors.ai.open_ai.models.chat.function_call import FunctionCall
+from semantic_kernel.connectors.ai.open_ai.models.chat.tool_calls import ToolCall
 from semantic_kernel.connectors.ai.open_ai.responses.open_ai_chat_message_content import OpenAIChatMessageContent
 
 if TYPE_CHECKING:
@@ -23,8 +25,8 @@ def __init__(
         response: ChatCompletion,
         metadata: Dict[str, Any],
         request_settings: "AIRequestSettings",
-        function_call: Optional[Dict[str, Any]],
-        tool_calls: Optional[List[Dict[str, Any]]],
+        function_call: Optional[FunctionCall],
+        tool_calls: Optional[List[ToolCall]],
         tool_message: Optional[str],
     ):
         """Initialize a chat response from Azure OpenAI."""
@@ -3,6 +3,8 @@

 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice
 
+from semantic_kernel.connectors.ai.open_ai.models.chat.function_call import FunctionCall
+from semantic_kernel.connectors.ai.open_ai.models.chat.tool_calls import ToolCall
 from semantic_kernel.connectors.ai.open_ai.responses.open_ai_streaming_chat_message_content import (
     OpenAIStreamingChatMessageContent,
 )
@@ -33,8 +35,8 @@ def __init__(
         chunk: ChatCompletionChunk,
         metadata: Dict[str, Any],
         request_settings: "AIRequestSettings",
-        function_call: Optional[Dict[str, Any]],
-        tool_calls: Optional[List[Dict[str, Any]]],
+        function_call: Optional[FunctionCall],
+        tool_calls: Optional[List[ToolCall]],
         tool_message: Optional[str],
     ):
         """Initialize a chat response from Azure OpenAI."""
@@ -4,6 +4,8 @@

 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice
 
+from semantic_kernel.connectors.ai.open_ai.models.chat.function_call import FunctionCall
+from semantic_kernel.connectors.ai.open_ai.models.chat.tool_calls import ToolCall
 from semantic_kernel.models.contents import StreamingChatMessageContent
 
 if TYPE_CHECKING:
@@ -25,17 +27,17 @@ class OpenAIStreamingChatMessageContent(StreamingChatMessageContent):
"""

inner_content: ChatCompletionChunk
function_call: Optional[Dict[str, Any]] = None
tool_calls: Optional[List[Dict[str, Any]]] = None
function_call: Optional[FunctionCall] = None
tool_calls: Optional[List[ToolCall]] = None

def __init__(
self,
choice: Choice,
chunk: ChatCompletionChunk,
metadata: Dict[str, Any],
request_settings: "AIRequestSettings",
function_call: Optional[Dict[str, Any]],
tool_calls: Optional[List[Dict[str, Any]]],
function_call: Optional[FunctionCall],
tool_calls: Optional[List[ToolCall]],
):
"""Initialize a chat response from OpenAI."""
super().__init__(
@@ -1,16 +1,12 @@
 # Copyright (c) Microsoft. All rights reserved.
 
-from typing import AsyncGenerator, Dict, List, Optional, Union
-
-from openai import AsyncStream
 from openai.types import Completion
-from openai.types.chat.chat_completion_chunk import Choice
-from pydantic import PrivateAttr
 
-from semantic_kernel.models.contents.kernel_content import KernelContent
+from semantic_kernel.models.contents import TextContent
 
 
-class OpenAITextResponse(KernelContent):
+class OpenAITextContent(TextContent):
     """A text completion response from OpenAI.
     For streaming responses, make sure to async loop through parse_stream before trying anything else.
@@ -20,47 +16,4 @@ class OpenAITextResponse(KernelContent):
     - parse_stream: get the streaming content of the response.
     """
 
-    raw_response: Union[Completion, AsyncStream[Completion]]
-    _parsed_content: Optional[Dict[str, str]] = PrivateAttr(default_factory=dict)
-
-    @property
-    def content(self) -> Optional[str]:
-        """Get the content of the response.
-        Content can be None when no text response was given, check if there are tool calls instead.
-        """
-        if not isinstance(self.raw_response, Completion):
-            raise ValueError("content is not available for streaming responses, use stream_content instead.")
-        return self.raw_response.choices[0].text
-
-    @property
-    def all_content(self) -> List[Optional[str]]:
-        """Get the content of the response.
-        Some or all content might be None, check if there are tool calls instead.
-        """
-        if not isinstance(self.raw_response, Completion):
-            if self._parsed_content is not {}:
-                return list(self._parsed_content.values())
-            raise ValueError("all_content is not available for streaming responses, use stream_content instead.")
-        return [choice.text for choice in self.raw_response.choices]
-
-    async def parse_stream(self) -> AsyncGenerator[str, None]:
-        """Get the streaming content of the response."""
-        if isinstance(self.raw_response, Completion):
-            raise ValueError("streaming_content is not available for regular responses, use content instead.")
-        async for chunk in self.raw_response:
-            if len(chunk.choices) == 0:
-                continue
-            for choice in chunk.choices:
-                self.parse_choice(choice)
-            if chunk.choices[0].delta.text:
-                yield chunk.choices[0].delta.text
-
-    def parse_choice(self, choice: Choice) -> None:
-        """Parse a choice and store the text."""
-        if choice.delta.content is not None:
-            if choice.index in self._parsed_content:
-                self._parsed_content[choice.index] += choice.delta.text
-            else:
-                self._parsed_content[choice.index] = choice.delta.text
+    inner_content: Completion
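The deleted helpers mixed choice.delta.content with choice.delta.text, so parse_choice could not have worked as written; after this change the class simply wraps the non-streaming Completion as inner_content. For reference, the per-choice accumulation the old code attempted reduces to a plain pattern like this (a standalone sketch, not code from this commit; assumes text-completion chunks exposing .choices with .index and .text):

    from collections import defaultdict
    from typing import Dict

    def accumulate_text(chunks) -> Dict[int, str]:
        """Collect streamed completion text per choice index."""
        parsed: Dict[int, str] = defaultdict(str)
        for chunk in chunks:
            for choice in chunk.choices:
                if choice.text:
                    parsed[choice.index] += choice.text
        return dict(parsed)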
@@ -340,31 +340,53 @@ async def complete_chat_stream_async(
         if not isinstance(response, AsyncStream):
             raise ValueError("Expected an AsyncStream[ChatCompletionChunk] response.")
 
-        content = [""] * settings.number_of_responses
+        out_messages = {}
+        tool_messages_by_index = {}
         tool_call_ids_by_index = {}
-        function_name_by_index = {}
-        function_arguments_by_index = {}
+        function_call_by_index = {}
 
         async for chunk in response:
             if len(chunk.choices) == 0:
                 continue
             chunk_metadata = self.get_metadata_from_streaming_chat_response(chunk)
 
             contents = [
                 self._create_return_content_stream(chunk, choice, chunk_metadata, settings) for choice in chunk.choices
             ]
-            for index, content in enumerate(contents):
-                if content.content is not None:
-                    content[index] += str(content)
-                if content.tool_calls is not None:
-                    tool_call_ids_by_index[index] += content.tool_calls
-                if content.function_call is not None:
-                    if content.function_call["name"] is not None:
-                        function_name_by_index[index] = content.function_call["name"]
-
-                    function_arguments_by_index[index] += content.function_call["arguments"]
+            self._handle_updates(
+                contents, out_messages, tool_call_ids_by_index, function_call_by_index, tool_messages_by_index
+            )
             yield contents
 
+    def _handle_updates(
+        self, contents, out_messages, tool_call_ids_by_index, function_call_by_index, tool_messages_by_index
+    ):
+        """Handle updates to the messages, tool_calls and function_calls.
+        This will be used for auto-invoking tools.
+        """
+        for index, content in enumerate(contents):
+            if content.content is not None:
+                if index not in out_messages:
+                    out_messages[index] = str(content)
+                else:
+                    out_messages[index] += str(content)
+            if content.tool_calls is not None:
+                if index not in tool_call_ids_by_index:
+                    tool_call_ids_by_index[index] = content.tool_calls
+                else:
+                    for tc_index, tool_call in enumerate(content.tool_calls):
+                        tool_call_ids_by_index[index][tc_index].update(tool_call)
+            if content.function_call is not None:
+                if index not in function_call_by_index:
+                    function_call_by_index[index] = content.function_call
+                else:
+                    function_call_by_index[index].update(content.function_call)
+            if content.tool_message is not None:
+                if index not in tool_messages_by_index:
+                    tool_messages_by_index[index] = content.tool_message
+                else:
+                    tool_messages_by_index[index] += content.tool_message
+
     def _create_return_content_stream(
         self,
         chunk: ChatCompletionChunk,
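_handle_updates folds each streamed chunk into per-choice-index accumulators (message text, tool calls, function call, tool message) that later auto-invocation can read. One caveat visible in the diff: it calls function_call_by_index[index].update(content.function_call), while FunctionCall.update as added above takes (name, arguments). Reduced to plain dictionaries, the text branch of the accumulation pattern looks like this (illustrative only):

    from typing import Dict, Optional

    def merge_text(out_messages: Dict[int, str], index: int, text: Optional[str]) -> None:
        """Accumulate one choice's streamed text under its index."""
        if text is None:
            return
        if index not in out_messages:
            out_messages[index] = text
        else:
            out_messages[index] += text

    out: Dict[int, str] = {}
    for chunk in (["Hel", "Wor"], ["lo", "ld"]):  # two chunks, two choices each
        for index, text in enumerate(chunk):
            merge_text(out, index, text)
    print(out)  # {0: 'Hello', 1: 'World'}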