Skip to content

Commit

Permalink
Merge pull request #1569 from tkellogg/chat-adapter-pydantic
Browse files Browse the repository at this point in the history
chat_adapter: Format fields as JSON
  • Loading branch information
okhat authored Oct 7, 2024
2 parents 6a00c85 + b73a397 commit 58c2071
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 2 deletions.
12 changes: 11 additions & 1 deletion dspy/adapters/chat_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import textwrap

from pydantic import TypeAdapter
import pydantic
from .base import Adapter
from typing import get_origin, get_args

Expand Down Expand Up @@ -73,10 +74,19 @@ def format_list(items):
return "\n".join([f"[{idx+1}] {format_blob(txt)}" for idx, txt in enumerate(items)])


def _format_field_value(value) -> str:
if isinstance(value, list):
return format_list(value)
elif isinstance(value, pydantic.BaseModel):
return value.model_dump_json()
else:
return str(value)


def format_fields(fields):
output = []
for k, v in fields.items():
v = v if not isinstance(v, list) else format_list(v)
v = _format_field_value(v)
output.append(f"[[ ## {k} ## ]]\n{v}")

return '\n\n'.join(output).strip()
Expand Down
32 changes: 31 additions & 1 deletion tests/functional/test_signature_typed.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Optional, Union
from dspy.adapters.chat_adapter import _format_field_value
import pytest

import pydantic
Expand Down Expand Up @@ -114,6 +115,9 @@ class MySignature(dspy.Signature):
instance = build_model_instance()
parsed_instance = parser(instance.model_dump_json())

formatted_instance = _format_field_value(instance.model_dump_json())
assert formatted_instance == instance.model_dump_json(), f"{formatted_instance} != {instance.model_dump_json()}"

assert parsed_instance == instance, f"{instance} != {parsed_instance}"


Expand All @@ -128,13 +132,36 @@ class MySignature(dspy.Signature):
parsed_instance = parser(instance.model_dump_json())
assert parsed_instance == instance, f"{instance} != {parsed_instance}"

formatted_instance = _format_field_value(instance.model_dump_json())
assert formatted_instance == instance.model_dump_json(), f"{formatted_instance} != {instance.model_dump_json()}"

# Check null case
parsed_instance = parser("null")
assert parsed_instance == None, "Optional[MyModel] should be None"


def test_nested_pydantic():
class NestedModel(pydantic.BaseModel):
model: MyModel

class MySignature(dspy.Signature):
question: str = dspy.InputField()
answer: NestedModel = dspy.OutputField()

_, parser = get_field_and_parser(MySignature)

instance = NestedModel(model=build_model_instance())
parsed_instance = parser(instance.model_dump_json())

formatted_instance = _format_field_value(instance.model_dump_json())
assert formatted_instance == instance.model_dump_json(), f"{formatted_instance} != {instance.model_dump_json()}"

assert parsed_instance == instance, f"{instance} != {parsed_instance}"


def test_dataclass():
from dataclasses import dataclass
from dataclasses import dataclass, asdict
import ujson

@dataclass(frozen=True)
class MyDataclass:
Expand All @@ -152,3 +179,6 @@ class MySignature(dspy.Signature):
instance = MyDataclass("foobar", 42, 3.14, True)
parsed_instance = parser('{"string": "foobar", "number": 42, "floating": 3.14, "boolean": true}')
assert parsed_instance == instance, f"{instance} != {parsed_instance}"

formatted_instance = _format_field_value(ujson.dumps(asdict(instance)))
assert formatted_instance == ujson.dumps(asdict(instance)), f"{formatted_instance} != {ujson.dumps(asdict(instance))}"

0 comments on commit 58c2071

Please sign in to comment.