Skip to content

Commit

Permalink
Support recursive arrays in JSON when an item is an array
Browse files Browse the repository at this point in the history
  • Loading branch information
AL-377 authored and brandonwillard committed Oct 19, 2023
1 parent f847632 commit 07c47db
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 4 deletions.
22 changes: 18 additions & 4 deletions outlines/text/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,18 @@ def build_schedule_from_schema(schema: str):
return reduced_schedule


def expand_item_json_schema(expanded_property: Dict, resolver: Callable[[str], Dict]):
"""Recursively expand "$ref"s in "item"s."""
if "items" not in expanded_property.keys():
return
elif "$ref" in expanded_property["items"]:
expanded_property["items"] = expand_json_schema(
resolver(expanded_property["items"]["$ref"]), resolver
)
else:
expand_item_json_schema(expanded_property["items"], resolver)


def expand_json_schema(
raw_schema: Dict,
resolver: Callable[[str], Dict],
Expand Down Expand Up @@ -166,12 +178,14 @@ def expand_json_schema(
)
elif "type" in value and value["type"] == "array": # if item is a list
expanded_properties[name] = value
if "$ref" in value["items"]:
expanded_properties[name]["items"] = expand_json_schema(
resolver(value["items"]["$ref"]), resolver
)

if "$ref" in value["items"] or (
"type" in value["items"] and value["items"]["type"] == "array"
):
expand_item_json_schema(expanded_properties[name], resolver)
else:
expanded_properties[name]["items"] = value["items"]

else:
expanded_properties[name] = value

Expand Down
65 changes: 65 additions & 0 deletions tests/text/test_json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,71 @@ class Spam(BaseModel):
]


def test_pydantic_recursive_list_object():
class ItemModel(BaseModel):
name: str

class ArrayModel1(BaseModel):
item_model_lists: List[List[ItemModel]]

class ArrayModel2(BaseModel):
nums: List[List[int]]

class ArrayModel3(BaseModel):
array_model_lists: List[List[ArrayModel1]]

schema = json.dumps(ArrayModel1.model_json_schema())
schedule = build_schedule_from_schema(schema)
array_model_1_schema = {
"items": {
"items": {
"title": "ItemModel",
"type": "object",
"properties": {"name": {"title": "Name", "type": "string"}},
},
"type": "array",
},
"title": "Item Model Lists",
"type": "array",
}
assert schedule == [
'\\{[\\n ]*"item_model_lists"[\\n ]*:[\\n ]*',
array_model_1_schema,
"[\\n ]*\\}",
]

schema = json.dumps(ArrayModel2.model_json_schema())
schedule = build_schedule_from_schema(schema)
assert schedule == [
'\\{[\\n ]*"nums"[\\n ]*:[\\n ]*',
{
"items": {"items": {"type": "integer"}, "type": "array"},
"title": "Nums",
"type": "array",
},
"[\\n ]*\\}",
]

schema = json.dumps(ArrayModel3.model_json_schema())
schedule = build_schedule_from_schema(schema)
assert schedule == [
'\\{[\\n ]*"array_model_lists"[\\n ]*:[\\n ]*',
{
"items": {
"items": {
"title": "ArrayModel1",
"type": "object",
"properties": {"item_model_lists": array_model_1_schema},
},
"type": "array",
},
"title": "Array Model Lists",
"type": "array",
},
"[\\n ]*\\}",
]


def test_pydantic_union():
"""Schemas with Union types."""

Expand Down

0 comments on commit 07c47db

Please sign in to comment.