Skip to content

Commit

Permalink
Fix whitespace and control character handling in JSON guidance
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Sep 16, 2023
1 parent b5c2241 commit ff4ebb3
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 46 deletions.
25 changes: 12 additions & 13 deletions outlines/text/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
from typing import Dict

STRING_INNER = r'(?:[^"\\]|\\.)'
STRING_INNER = r'(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)'
STRING = f'"{STRING_INNER}*"'
INTEGER = r"(0|[1-9][0-9]*)"
NUMBER = rf"(-)?({INTEGER})(\.[0-9]+)?([eE][+-][0-9]+)?"
Expand Down Expand Up @@ -142,7 +142,7 @@ def expand_json_schema(raw_schema: Dict, definitions: Dict):
return raw_schema


def build_schedule_from_instance(instance: Dict, indent: int = 0):
def build_schedule_from_instance(instance: Dict):
"""Build a generation schedule from a instance.
This recursively follows the references to other instances.
Expand All @@ -163,27 +163,26 @@ def build_schedule_from_instance(instance: Dict, indent: int = 0):
"""
schedule = []
if "properties" in instance:
schedule.append("{\n")
schedule += build_schedule_from_instance(instance["properties"], indent + 2)
if indent > 0:
schedule.append(" " * indent)
schedule.append("}")
schedule.append(r"\{")
schedule += build_schedule_from_instance(instance["properties"])
schedule.append(r"\}")
else:
for i, (name, annotation) in enumerate(instance.items()):
schedule.append(" " * indent)
schedule.append(f'"{name}": ')
whitespace = r"[\n ]*"
schedule.append(f'{whitespace}"{name}"{whitespace}:{whitespace}')

if "anyOf" in annotation:
schedule.append(annotation)
elif annotation["type"] == "object":
schedule += build_schedule_from_instance(annotation, indent)
schedule += build_schedule_from_instance(annotation)
else:
schedule.append(annotation)

# We cannot add commas after the last key-value pair in JSON
if i == len(instance) - 1:
schedule.append("\n")
schedule.append(whitespace)
else:
schedule.append(",\n")
schedule.append(f"{whitespace},")

return schedule

Expand All @@ -205,7 +204,7 @@ def match_step_to_regex(step):
"""
match step:
case str() as step:
return re.escape(step)
return step

case {"enum": choices, "type": "string"}:
choices = [f'"{re.escape(choice)}"' for choice in choices]
Expand Down
2 changes: 1 addition & 1 deletion tests/text/generate/test_integration_transfomers.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ class Spam(BaseModel):
sequence = generate.json(model, Spam, max_tokens=1000)(prompt, rng=rng)
parsed = json.loads(sequence)
assert isinstance(parsed["foo"], int)
assert isinstance(parsed["bar"], float)
assert isinstance(parsed["bar"], int)
assert isinstance(parsed["spam"], str)
assert isinstance(parsed["fuzz"], bool)
assert len(parsed["spam"]) == 10
Expand Down
62 changes: 30 additions & 32 deletions tests/text/test_json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,19 @@ class User(BaseModel):
schema = json.dumps(User.model_json_schema())
schedule = build_schedule_from_schema(schema)
assert schedule == [
'{\n "user_id": ',
'\\{[\\n ]*"user_id"[\\n ]*:[\\n ]*',
{"title": "User Id", "type": "integer"},
',\n "name": ',
'[\\n ]*,[\\n ]*"name"[\\n ]*:[\\n ]*',
{"title": "Name", "type": "string"},
',\n "maxlength_name": ',
'[\\n ]*,[\\n ]*"maxlength_name"[\\n ]*:[\\n ]*',
{"title": "Maxlength Name", "type": "string", "maxLength": 10},
',\n "minlength_name": ',
'[\\n ]*,[\\n ]*"minlength_name"[\\n ]*:[\\n ]*',
{"title": "Minlength Name", "type": "string", "minLength": 10},
',\n "value": ',
'[\\n ]*,[\\n ]*"value"[\\n ]*:[\\n ]*',
{"title": "Value", "type": "number"},
',\n "is_true": ',
'[\\n ]*,[\\n ]*"is_true"[\\n ]*:[\\n ]*',
{"title": "Is True", "type": "boolean"},
"\n}",
"[\\n ]*\\}",
]


Expand All @@ -53,9 +53,9 @@ class Foo(BaseModel):
schema = json.dumps(Foo.model_json_schema())
schedule = build_schedule_from_schema(schema)
assert schedule == [
'{\n "bar": ',
'\\{[\\n ]*"bar"[\\n ]*:[\\n ]*',
{"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Bar"},
"\n}",
"[\\n ]*\\}",
]


Expand All @@ -67,11 +67,11 @@ class User(BaseModel):
schema = json.dumps(User.model_json_schema())
schedule = build_schedule_from_schema(schema)
assert schedule == [
'{\n "user_id": ',
'\\{[\\n ]*"user_id"[\\n ]*:[\\n ]*',
{"title": "User Id", "type": "integer"},
',\n "value": ',
'[\\n ]*,[\\n ]*"value"[\\n ]*:[\\n ]*',
{"title": "Value", "type": "array", "items": {"type": "number"}},
"\n}",
"[\\n ]*\\}",
]


Expand All @@ -88,15 +88,15 @@ class User(BaseModel):
schema = json.dumps(User.model_json_schema())
schedule = build_schedule_from_schema(schema)
assert schedule == [
'{\n "user_id": ',
'\\{[\\n ]*"user_id"[\\n ]*:[\\n ]*',
{"title": "User Id", "type": "integer"},
',\n "name": ',
'[\\n ]*,[\\n ]*"name"[\\n ]*:[\\n ]*',
{
"title": "Name",
"enum": ["John", "Marc", "Michel"],
"type": "string",
},
"\n}",
"[\\n ]*\\}",
]


Expand All @@ -122,15 +122,15 @@ class Spam(BaseModel):
schema = json.dumps(Spam.model_json_schema())
schedule = build_schedule_from_schema(schema)
assert schedule == [
'{\n "foo": {\n "count": ',
'\\{[\\n ]*"foo"[\\n ]*:[\\n ]*\\{[\\n ]*"count"[\\n ]*:[\\n ]*',
{"title": "Count", "type": "integer"},
',\n "size": {\n "buzz": ',
'[\\n ]*,[\\n ]*"size"[\\n ]*:[\\n ]*\\{[\\n ]*"buzz"[\\n ]*:[\\n ]*',
{"title": "Buzz", "type": "string"},
'\n }\n },\n "bars": {\n "apple": ',
'[\\n ]*\\}[\\n ]*\\}[\\n ]*,[\\n ]*"bars"[\\n ]*:[\\n ]*\\{[\\n ]*"apple"[\\n ]*:[\\n ]*',
{"title": "Apple", "type": "string"},
',\n "banana": ',
'[\\n ]*,[\\n ]*"banana"[\\n ]*:[\\n ]*',
{"title": "Banana", "type": "string"},
"\n }\n}",
"[\\n ]*\\}[\\n ]*\\}",
]


Expand All @@ -145,7 +145,7 @@ class Spam(BaseModel):
schema = json.dumps(Spam.model_json_schema())
schedule = build_schedule_from_schema(schema)
assert schedule == [
'{\n "foo": ',
'\\{[\\n ]*"foo"[\\n ]*:[\\n ]*',
{
"items": {
"title": "Foo",
Expand All @@ -155,7 +155,7 @@ class Spam(BaseModel):
"title": "Foo",
"type": "array",
},
"\n}",
"[\\n ]*\\}",
]


Expand All @@ -169,23 +169,23 @@ class Spam(BaseModel):
schema = json.dumps(Spam.model_json_schema())
schedule = build_schedule_from_schema(schema)
assert schedule == [
'{\n "foo": ',
'\\{[\\n ]*"foo"[\\n ]*:[\\n ]*',
{"title": "Foo", "type": "integer"},
',\n "bar": ',
'[\\n ]*,[\\n ]*"bar"[\\n ]*:[\\n ]*',
{"title": "Bar", "anyOf": [{"type": "number"}, {"type": "string"}]},
"\n}",
"[\\n ]*\\}",
]


def test_json_schema():
schema = '{"title": "User", "type": "object", "properties": {"user_id": {"title": "User Id", "type": "integer"}, "name": {"title": "Name", "type": "string"}}, "required": ["user_id", "name"]}'
schedule = build_schedule_from_schema(schema)
assert schedule == [
'{\n "user_id": ',
'\\{[\\n ]*"user_id"[\\n ]*:[\\n ]*',
{"title": "User Id", "type": "integer"},
',\n "name": ',
'[\\n ]*,[\\n ]*"name"[\\n ]*:[\\n ]*',
{"title": "Name", "type": "string"},
"\n}",
"[\\n ]*\\}",
]


Expand Down Expand Up @@ -317,7 +317,7 @@ def test_match_number(pattern, does_match):
"type": "object",
"properties": {"count": {"title": "Count", "type": "integer"}},
},
'\\{\\\n\\ \\ "count":\\ ' + INTEGER + "\\\n\\}",
'\\{[\\n ]*"count"[\\n ]*:[\\n ]*(0|[1-9][0-9]*)[\\n ]*\\}',
[('{\n "count": 100\n}', True)],
),
(
Expand Down Expand Up @@ -346,9 +346,7 @@ def test_match_number(pattern, does_match):
}
},
},
'\\{\\\n\\ \\ "fuzz":\\ \\{\\\n\\ \\ \\ \\ "spam":\\ '
+ INTEGER
+ "\\\n\\ \\ \\}\\\n\\}",
f'\\{{[\\n ]*"fuzz"[\\n ]*:[\\n ]*\\{{[\\n ]*"spam"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*\\}}[\\n ]*\\}}',
[('{\n "fuzz": {\n "spam": 100\n }\n}', True)],
),
],
Expand Down

0 comments on commit ff4ebb3

Please sign in to comment.