From 07c47dbf409fa7518affccf8886f5cb87782304b Mon Sep 17 00:00:00 2001
From: AL-377 <535338194@qq.com>
Date: Tue, 10 Oct 2023 17:51:47 +0800
Subject: [PATCH] Support recursive arrays in JSON when an item is an array

---
 outlines/text/json_schema.py   | 22 +++++++++---
 tests/text/test_json_schema.py | 65 ++++++++++++++++++++++++++++++++++
 2 files changed, 83 insertions(+), 4 deletions(-)

diff --git a/outlines/text/json_schema.py b/outlines/text/json_schema.py
index ddccde5c5..b9d6a84c8 100644
--- a/outlines/text/json_schema.py
+++ b/outlines/text/json_schema.py
@@ -125,6 +125,18 @@ def build_schedule_from_schema(schema: str):
     return reduced_schedule
 
 
+def expand_item_json_schema(expanded_property: Dict, resolver: Callable[[str], Dict]):
+    """Recursively expand "$ref"s in "item"s."""
+    if "items" not in expanded_property.keys():
+        return
+    elif "$ref" in expanded_property["items"]:
+        expanded_property["items"] = expand_json_schema(
+            resolver(expanded_property["items"]["$ref"]), resolver
+        )
+    else:
+        expand_item_json_schema(expanded_property["items"], resolver)
+
+
 def expand_json_schema(
     raw_schema: Dict,
     resolver: Callable[[str], Dict],
@@ -166,12 +178,14 @@ def expand_json_schema(
                 )
             elif "type" in value and value["type"] == "array":  # if item is a list
                 expanded_properties[name] = value
-                if "$ref" in value["items"]:
-                    expanded_properties[name]["items"] = expand_json_schema(
-                        resolver(value["items"]["$ref"]), resolver
-                    )
+
+                if "$ref" in value["items"] or (
+                    "type" in value["items"] and value["items"]["type"] == "array"
+                ):
+                    expand_item_json_schema(expanded_properties[name], resolver)
                 else:
                     expanded_properties[name]["items"] = value["items"]
+
             else:
                 expanded_properties[name] = value
 
diff --git a/tests/text/test_json_schema.py b/tests/text/test_json_schema.py
index 19c01c3ec..239effb61 100644
--- a/tests/text/test_json_schema.py
+++ b/tests/text/test_json_schema.py
@@ -159,6 +159,71 @@ class Spam(BaseModel):
     ]
 
 
+def test_pydantic_recursive_list_object():
+    class ItemModel(BaseModel):
+        name: str
+
+    class ArrayModel1(BaseModel):
+        item_model_lists: List[List[ItemModel]]
+
+    class ArrayModel2(BaseModel):
+        nums: List[List[int]]
+
+    class ArrayModel3(BaseModel):
+        array_model_lists: List[List[ArrayModel1]]
+
+    schema = json.dumps(ArrayModel1.model_json_schema())
+    schedule = build_schedule_from_schema(schema)
+    array_model_1_schema = {
+        "items": {
+            "items": {
+                "title": "ItemModel",
+                "type": "object",
+                "properties": {"name": {"title": "Name", "type": "string"}},
+            },
+            "type": "array",
+        },
+        "title": "Item Model Lists",
+        "type": "array",
+    }
+    assert schedule == [
+        '\\{[\\n ]*"item_model_lists"[\\n ]*:[\\n ]*',
+        array_model_1_schema,
+        "[\\n ]*\\}",
+    ]
+
+    schema = json.dumps(ArrayModel2.model_json_schema())
+    schedule = build_schedule_from_schema(schema)
+    assert schedule == [
+        '\\{[\\n ]*"nums"[\\n ]*:[\\n ]*',
+        {
+            "items": {"items": {"type": "integer"}, "type": "array"},
+            "title": "Nums",
+            "type": "array",
+        },
+        "[\\n ]*\\}",
+    ]
+
+    schema = json.dumps(ArrayModel3.model_json_schema())
+    schedule = build_schedule_from_schema(schema)
+    assert schedule == [
+        '\\{[\\n ]*"array_model_lists"[\\n ]*:[\\n ]*',
+        {
+            "items": {
+                "items": {
+                    "title": "ArrayModel1",
+                    "type": "object",
+                    "properties": {"item_model_lists": array_model_1_schema},
+                },
+                "type": "array",
+            },
+            "title": "Array Model Lists",
+            "type": "array",
+        },
+        "[\\n ]*\\}",
+    ]
+
+
 def test_pydantic_union():
     """Schemas with Union types."""