feat(parsing): add support for pydantic dataclasses (#1655)
RobertCraigie authored and stainless-app[bot] committed Aug 20, 2024
1 parent e8c28f2 commit ecd6e92
Showing 3 changed files with 99 additions and 14 deletions.
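In practical terms, this commit lets `client.beta.chat.completions.parse()` accept a class decorated with `@pydantic.dataclasses.dataclass` as `response_format`, in addition to `pydantic.BaseModel` subclasses. A minimal usage sketch, adapted from the test added in this commit (assumes Pydantic v2 and an `OPENAI_API_KEY` in the environment; model name and prompts are copied from that test):

from typing import List

from openai import OpenAI
from pydantic.dataclasses import dataclass


@dataclass
class CalendarEvent:
    name: str
    date: str
    participants: List[str]


client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

completion = client.beta.chat.completions.parse(
    model="gpt-4o-2024-08-06",
    messages=[
        {"role": "system", "content": "Extract the event information."},
        {"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
    ],
    response_format=CalendarEvent,  # a pydantic dataclass, newly supported by this change
)

# message.parsed is an actual CalendarEvent instance, not a dict
print(completion.choices[0].message.parsed)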
28 changes: 19 additions & 9 deletions src/openai/lib/_parsing/_completions.py
@@ -9,9 +9,9 @@
 from .._tools import PydanticFunctionTool
 from ..._types import NOT_GIVEN, NotGiven
 from ..._utils import is_dict, is_given
-from ..._compat import model_parse_json
+from ..._compat import PYDANTIC_V2, model_parse_json
 from ..._models import construct_type_unchecked
-from .._pydantic import to_strict_json_schema
+from .._pydantic import is_basemodel_type, to_strict_json_schema, is_dataclass_like_type
 from ...types.chat import (
     ParsedChoice,
     ChatCompletion,
@@ -216,14 +216,16 @@ def is_parseable_tool(input_tool: ChatCompletionToolParam) -> bool:
     return cast(FunctionDefinition, input_fn).get("strict") or False
 
 
-def is_basemodel_type(typ: type) -> TypeGuard[type[pydantic.BaseModel]]:
-    return issubclass(typ, pydantic.BaseModel)
-
-
 def _parse_content(response_format: type[ResponseFormatT], content: str) -> ResponseFormatT:
     if is_basemodel_type(response_format):
         return cast(ResponseFormatT, model_parse_json(response_format, content))
 
+    if is_dataclass_like_type(response_format):
+        if not PYDANTIC_V2:
+            raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {response_format}")
+
+        return pydantic.TypeAdapter(response_format).validate_json(content)
+
     raise TypeError(f"Unable to automatically parse response format type {response_format}")

@@ -241,14 +243,22 @@ def type_to_response_format_param(
     # can only be a `type`
     response_format = cast(type, response_format)
 
-    if not is_basemodel_type(response_format):
+    json_schema_type: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any] | None = None
+
+    if is_basemodel_type(response_format):
+        name = response_format.__name__
+        json_schema_type = response_format
+    elif is_dataclass_like_type(response_format):
+        name = response_format.__name__
+        json_schema_type = pydantic.TypeAdapter(response_format)
+    else:
         raise TypeError(f"Unsupported response_format type - {response_format}")
 
     return {
         "type": "json_schema",
         "json_schema": {
-            "schema": to_strict_json_schema(response_format),
-            "name": response_format.__name__,
+            "schema": to_strict_json_schema(json_schema_type),
+            "name": name,
             "strict": True,
         },
     }
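The parsing split above keeps BaseModel subclasses on the existing model_parse_json path, while dataclass-like types are validated via pydantic.TypeAdapter, which only exists in Pydantic v2 (hence the PYDANTIC_V2 guard). A rough standalone sketch of the TypeAdapter path, using an illustrative dataclass that is not part of this commit:

import pydantic
from pydantic.dataclasses import dataclass


@dataclass
class Location:  # illustrative only; field names are assumptions
    city: str
    temperature: float


# TypeAdapter.validate_json parses and validates the JSON payload in one step (Pydantic v2)
adapter = pydantic.TypeAdapter(Location)
loc = adapter.validate_json('{"city": "San Francisco", "temperature": 58.0}')
print(loc)  # Location(city='San Francisco', temperature=58.0)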
26 changes: 22 additions & 4 deletions src/openai/lib/_pydantic.py
@@ -1,17 +1,26 @@
 from __future__ import annotations
 
-from typing import Any
+import inspect
+from typing import Any, TypeVar
 from typing_extensions import TypeGuard
 
 import pydantic
 
 from .._types import NOT_GIVEN
 from .._utils import is_dict as _is_dict, is_list
-from .._compat import model_json_schema
+from .._compat import PYDANTIC_V2, model_json_schema
+
+_T = TypeVar("_T")
 
 
-def to_strict_json_schema(model: type[pydantic.BaseModel]) -> dict[str, Any]:
-    schema = model_json_schema(model)
+def to_strict_json_schema(model: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any]) -> dict[str, Any]:
+    if inspect.isclass(model) and is_basemodel_type(model):
+        schema = model_json_schema(model)
+    elif PYDANTIC_V2 and isinstance(model, pydantic.TypeAdapter):
+        schema = model.json_schema()
+    else:
+        raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {model}")
+
     return _ensure_strict_json_schema(schema, path=(), root=schema)


@@ -117,6 +126,15 @@ def resolve_ref(*, root: dict[str, object], ref: str) -> object:
     return resolved
 
 
+def is_basemodel_type(typ: type) -> TypeGuard[type[pydantic.BaseModel]]:
+    return issubclass(typ, pydantic.BaseModel)
+
+
+def is_dataclass_like_type(typ: type) -> bool:
+    """Returns True if the given type likely used `@pydantic.dataclass`"""
+    return hasattr(typ, "__pydantic_config__")
+
+
 def is_dict(obj: object) -> TypeGuard[dict[str, object]]:
     # just pretend that we know there are only `str` keys
     # as that check is not worth the performance cost
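The is_dataclass_like_type check above leans on the fact that `@pydantic.dataclasses.dataclass` attaches a `__pydantic_config__` attribute to the decorated class, which plain stdlib dataclasses lack, and to_strict_json_schema now derives the schema for such types from TypeAdapter.json_schema(). A small sketch of that distinction (class names here are illustrative, not from the commit):

import dataclasses

import pydantic
from pydantic.dataclasses import dataclass as pydantic_dataclass


@dataclasses.dataclass
class PlainEvent:
    name: str


@pydantic_dataclass
class PydanticEvent:
    name: str


# Only the pydantic dataclass carries __pydantic_config__ (Pydantic v2)
print(hasattr(PlainEvent, "__pydantic_config__"))     # False
print(hasattr(PydanticEvent, "__pydantic_config__"))  # True

# The JSON schema for a non-BaseModel type comes from a TypeAdapter
print(pydantic.TypeAdapter(PydanticEvent).json_schema())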
59 changes: 58 additions & 1 deletion tests/lib/chat/test_completions.py
@@ -3,7 +3,7 @@
 import os
 import json
 from enum import Enum
-from typing import Any, Callable, Optional
+from typing import Any, List, Callable, Optional
 from typing_extensions import Literal, TypeVar
 
 import httpx
@@ -317,6 +317,63 @@ class Location(BaseModel):
     )
 
 
+@pytest.mark.respx(base_url=base_url)
+@pytest.mark.skipif(not PYDANTIC_V2, reason="dataclasses only supported in v2")
+def test_parse_pydantic_dataclass(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
+    from pydantic.dataclasses import dataclass
+
+    @dataclass
+    class CalendarEvent:
+        name: str
+        date: str
+        participants: List[str]
+
+    completion = _make_snapshot_request(
+        lambda c: c.beta.chat.completions.parse(
+            model="gpt-4o-2024-08-06",
+            messages=[
+                {"role": "system", "content": "Extract the event information."},
+                {"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
+            ],
+            response_format=CalendarEvent,
+        ),
+        content_snapshot=snapshot(
+            '{"id": "chatcmpl-9wdGqXkJJARAz7rOrLH5u5FBwLjF3", "object": "chat.completion", "created": 1723761008, "model": "gpt-4o-2024-08-06", "choices": [{"index": 0, "message": {"role": "assistant", "content": "{\\"name\\":\\"Science Fair\\",\\"date\\":\\"Friday\\",\\"participants\\":[\\"Alice\\",\\"Bob\\"]}", "refusal": null}, "logprobs": null, "finish_reason": "stop"}], "usage": {"prompt_tokens": 32, "completion_tokens": 17, "total_tokens": 49}, "system_fingerprint": "fp_2a322c9ffc"}'
+        ),
+        mock_client=client,
+        respx_mock=respx_mock,
+    )
+
+    assert print_obj(completion, monkeypatch) == snapshot(
+        """\
+ParsedChatCompletion[CalendarEvent](
+    choices=[
+        ParsedChoice[CalendarEvent](
+            finish_reason='stop',
+            index=0,
+            logprobs=None,
+            message=ParsedChatCompletionMessage[CalendarEvent](
+                content='{"name":"Science Fair","date":"Friday","participants":["Alice","Bob"]}',
+                function_call=None,
+                parsed=CalendarEvent(name='Science Fair', date='Friday', participants=['Alice', 'Bob']),
+                refusal=None,
+                role='assistant',
+                tool_calls=[]
+            )
+        )
+    ],
+    created=1723761008,
+    id='chatcmpl-9wdGqXkJJARAz7rOrLH5u5FBwLjF3',
+    model='gpt-4o-2024-08-06',
+    object='chat.completion',
+    service_tier=None,
+    system_fingerprint='fp_2a322c9ffc',
+    usage=CompletionUsage(completion_tokens=17, prompt_tokens=32, total_tokens=49)
+)
+"""
+    )
+
+
 @pytest.mark.respx(base_url=base_url)
 def test_pydantic_tool_model_all_types(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
     completion = _make_snapshot_request(
