From 411eaaf2d9426a56f7fdb99a1d7073dadd463806 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 24 May 2024 01:12:31 -0500 Subject: [PATCH] Use less problematic whitespace token (#916) Fixes #839 #908 #690 #450 ## Problem A major problem, especially with smaller language models, is the repetition problem. For example, let's say a model is generating json and must provide 12 space tokens for indentation in json output. Often a language model will assign a high probability to a 13th space token, and do the same for a 14th space, and then enter an infinite space generation loop. This is a problem with NLG that has been known for half a decade, but only has mitigations (mirostat, repetition penalty, using hundreds of billions of weights, etc), no absolute solutions (except for **structured generation**) ## Solution For structured json generation, we set a sane default whitespace pattern of `r"[ ]?"`. This removes all newlines and indentation. It disallows any syntactic whitespace beyond a single space separator. Users can still set the argument `whitespace_pattern=` if they want different behavior --- docs/reference/json.md | 4 ++-- outlines/fsm/json_schema.py | 2 +- tests/fsm/test_json_schema.py | 45 ++++++++++++----------------------- 3 files changed, 18 insertions(+), 33 deletions(-) diff --git a/docs/reference/json.md b/docs/reference/json.md index 3b5976f19..85e1a846a 100644 --- a/docs/reference/json.md +++ b/docs/reference/json.md @@ -36,10 +36,10 @@ print(result) !!! Note "JSON and whitespaces" - By default Outlines lets model choose the number of linebreaks and white spaces used to structure the JSON. Small models tend to struggle with this, in which case we recommend to set the value of the parameter `whitespace_pattern` to the empty string: + By default Outlines prevents the model from generating json with syntactic newlines, tabs, or multiple spaces. The default `whitespace_pattern` is `r"[ ]?"`. Small models tend to enter an infinite repetition loop if the `whitespace_pattern` allows infinite spacing. If you would like to allow the model to generate multiple tabs, newlines, and spaces, you can set the whitespace pattern as follows: ```python - generator = generate.json(model, User, whitespace_pattern="") + generator = generate.json(model, User, whitespace_pattern=r"[\n\t ]*") ``` !!! Note "Performance" diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py index dbd2baa40..0e0d25bfc 100644 --- a/outlines/fsm/json_schema.py +++ b/outlines/fsm/json_schema.py @@ -16,7 +16,7 @@ NUMBER = rf"({INTEGER})(\.[0-9]+)?([eE][+-][0-9]+)?" BOOLEAN = r"(true|false)" NULL = r"null" -WHITESPACE = r"[\n ]*" +WHITESPACE = r"[ ]?" type_to_regex = { "string": STRING, diff --git a/tests/fsm/test_json_schema.py b/tests/fsm/test_json_schema.py index b12f9576e..bc836ac8b 100644 --- a/tests/fsm/test_json_schema.py +++ b/tests/fsm/test_json_schema.py @@ -215,8 +215,8 @@ def test_match_number(pattern, does_match): "properties": {"count": {"title": "Count", "type": "integer"}}, "required": ["count"], }, - '\\{[\\n ]*"count"[\\n ]*:[\\n ]*(-)?(0|[1-9][0-9]*)[\\n ]*\\}', - [('{\n "count": 100\n}', True)], + '\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]*)[ ]?\\}', + [('{ "count": 100 }', True)], ), # array ( @@ -277,7 +277,7 @@ def test_match_number(pattern, does_match): rf"""\{{{WHITESPACE}"test_dict"{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{STRING}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{STRING}){{0,}})?{WHITESPACE}\}}{WHITESPACE}\}}""", [ ("""{ "test_dict":{"foo":"bar","baz": "bif"}}""", True), - ("""{ "test_dict":{"foo":"bar"\n}}""", True), + ("""{ "test_dict":{"foo":"bar" }}""", True), ("""{ "test_dict":{}}""", True), ("""{ "WRONG_KEY":{}}""", False), ("""{ "test_dict":{"wrong_type" 1}}""", False), @@ -369,8 +369,8 @@ def test_match_number(pattern, does_match): }, "required": ["fuzz"], }, - f'\\{{[\\n ]*"fuzz"[\\n ]*:[\\n ]*\\{{[\\n ]*"spam"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*\\}}[\\n ]*\\}}', - [('{\n "fuzz": {\n "spam": 100\n }\n}', True)], + f'\\{{[ ]?"fuzz"[ ]?:[ ]?\\{{[ ]?"spam"[ ]?:[ ]?{INTEGER}[ ]?\\}}[ ]?\\}}', + [('{ "fuzz": { "spam": 100 }}', True)], ), # Schema with a reference ( @@ -384,7 +384,7 @@ def test_match_number(pattern, does_match): }, "required": ["user_id", "name", "a"], }, - f'\\{{[\\n ]*"user_id"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*,[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,[\\n ]*"a"[\\n ]*:[\\n ]*{STRING}[\\n ]*\\}}', + f'\\{{[ ]?"user_id"[ ]?:[ ]?{INTEGER}[ ]?,[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"a"[ ]?:[ ]?{STRING}[ ]?\\}}', [('{"user_id": 100, "name": "John", "a": "Marc"}', True)], ), ( @@ -399,7 +399,7 @@ def test_match_number(pattern, does_match): }, "required": ["user_id", "name", "name2"], }, - f'\\{{[\\n ]*"user_id"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*,[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,[\\n ]*"name2"[\\n ]*:[\\n ]*{STRING}[\\n ]*\\}}', + f'\\{{[ ]?"user_id"[ ]?:[ ]?{INTEGER}[ ]?,[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"name2"[ ]?:[ ]?{STRING}[ ]?\\}}', [('{"user_id": 100, "name": "John", "name2": "Marc"}', True)], ), ( @@ -441,7 +441,7 @@ def test_match_number(pattern, does_match): } }, }, - f'\\{{[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,[\\n ]*"last_name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,[\\n ]*"address"[\\n ]*:[\\n ]*\\{{[\\n ]*"city"[\\n ]*:[\\n ]*{STRING}[\\n ]*\\}}[\\n ]*\\}}', + f'\\{{[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"last_name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"address"[ ]?:[ ]?\\{{[ ]?"city"[ ]?:[ ]?{STRING}[ ]?\\}}[ ]?\\}}', [ ( '{"name": "John", "last_name": "Doe", "address": {"city": "Paris"}}', @@ -462,7 +462,7 @@ def test_match_number(pattern, does_match): "title": "Character", "type": "object", }, - f'\\{{[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}([\\n ]*,[\\n ]*"age"[\\n ]*:[\\n ]*({INTEGER}|null))?([\\n ]*,[\\n ]*"weapon"[\\n ]*:[\\n ]*({STRING}|null))?[\\n ]*\\}}', + f'\\{{[ ]?"name"[ ]?:[ ]?{STRING}([ ]?,[ ]?"age"[ ]?:[ ]?({INTEGER}|null))?([ ]?,[ ]?"weapon"[ ]?:[ ]?({STRING}|null))?[ ]?\\}}', [ ('{ "name" : "Player" }', True), ('{ "name" : "Player", "weapon" : "sword" }', True), @@ -482,7 +482,7 @@ def test_match_number(pattern, does_match): "title": "Character", "type": "object", }, - f'\\{{[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,([\\n ]*"age"[\\n ]*:[\\n ]*({INTEGER}|null)[\\n ]*,)?[\\n ]*"weapon"[\\n ]*:[\\n ]*{STRING}([\\n ]*,[\\n ]*"strength"[\\n ]*:[\\n ]*({INTEGER}|null))?[\\n ]*\\}}', + f'\\{{[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,([ ]?"age"[ ]?:[ ]?({INTEGER}|null)[ ]?,)?[ ]?"weapon"[ ]?:[ ]?{STRING}([ ]?,[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?[ ]?\\}}', [ ('{ "name" : "Player" , "weapon" : "sword" }', True), ( @@ -506,7 +506,7 @@ def test_match_number(pattern, does_match): "title": "Character", "type": "object", }, - f'\\{{([\\n ]*"name"[\\n ]*:[\\n ]*({STRING}|null)[\\n ]*,)?[\\n ]*"age"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*,[\\n ]*"armor"[\\n ]*:[\\n ]*{STRING}[\\n ]*,([\\n ]*"strength"[\\n ]*:[\\n ]*({INTEGER}|null)[\\n ]*,)?[\\n ]*"weapon"[\\n ]*:[\\n ]*{STRING}[\\n ]*\\}}', + f'\\{{([ ]?"name"[ ]?:[ ]?({STRING}|null)[ ]?,)?[ ]?"age"[ ]?:[ ]?{INTEGER}[ ]?,[ ]?"armor"[ ]?:[ ]?{STRING}[ ]?,([ ]?"strength"[ ]?:[ ]?({INTEGER}|null)[ ]?,)?[ ]?"weapon"[ ]?:[ ]?{STRING}[ ]?\\}}', [ ( '{ "name" : "Player", "age" : 10, "armor" : "plate", "strength" : 11, "weapon" : "sword" }', @@ -530,7 +530,7 @@ def test_match_number(pattern, does_match): "title": "Character", "type": "object", }, - f'\\{{([\\n ]*"name"[\\n ]*:[\\n ]*({STRING}|null)([\\n ]*,[\\n ]*"age"[\\n ]*:[\\n ]*({INTEGER}|null))?([\\n ]*,[\\n ]*"strength"[\\n ]*:[\\n ]*({INTEGER}|null))?|([\\n ]*"name"[\\n ]*:[\\n ]*({STRING}|null)[\\n ]*,)?[\\n ]*"age"[\\n ]*:[\\n ]*({INTEGER}|null)([\\n ]*,[\\n ]*"strength"[\\n ]*:[\\n ]*({INTEGER}|null))?|([\\n ]*"name"[\\n ]*:[\\n ]*({STRING}|null)[\\n ]*,)?([\\n ]*"age"[\\n ]*:[\\n ]*({INTEGER}|null)[\\n ]*,)?[\\n ]*"strength"[\\n ]*:[\\n ]*({INTEGER}|null))?[\\n ]*\\}}', + f'\\{{([ ]?"name"[ ]?:[ ]?({STRING}|null)([ ]?,[ ]?"age"[ ]?:[ ]?({INTEGER}|null))?([ ]?,[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?|([ ]?"name"[ ]?:[ ]?({STRING}|null)[ ]?,)?[ ]?"age"[ ]?:[ ]?({INTEGER}|null)([ ]?,[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?|([ ]?"name"[ ]?:[ ]?({STRING}|null)[ ]?,)?([ ]?"age"[ ]?:[ ]?({INTEGER}|null)[ ]?,)?[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?[ ]?\\}}', [ ('{ "name" : "Player" }', True), ('{ "name" : "Player", "age" : 10, "strength" : 10 }', True), @@ -710,19 +710,6 @@ def test_format(schema, regex, examples): ('{"time":20:20:39Z}', False), # missing quotes for value ], ), - # Unconstrained Object - ( - { - "title": "Foo", - "type": "object", - }, - [ - ("{}", True), - ('{"a": 1, "b": null}', True), - ('{"a": {"z": {"g": 4}}, "b": null}', True), - ("1234", False), # not an object - ], - ), ], ) def test_format_without_regex(schema, examples): @@ -737,7 +724,7 @@ def test_format_without_regex(schema, examples): assert match is None -@pytest.mark.parametrize("whitespace_pattern", [None, r"[\n ]?", "abc"]) +@pytest.mark.parametrize("whitespace_pattern", [None, r"[\n ]*", "abc"]) def test_json_schema_custom_whitespace_pattern(whitespace_pattern): """assert whitespace_pattern setting respected""" @@ -759,13 +746,11 @@ class MockModel(BaseModel): ) mock_result_maybe_ws = """{"foo" : 4 ,"bar":"baz baz baz bar"}""" - match_default_ws = re.fullmatch(pattern, mock_result_mult_ws) + match_default_ws = re.fullmatch(pattern, mock_result_maybe_ws) if whitespace_pattern is None: assert match_default_ws else: - assert match_default_ws is None - - assert re.fullmatch(pattern, mock_result_maybe_ws) + assert re.fullmatch(pattern, mock_result_mult_ws) def test_one_of_doesnt_produce_illegal_lookaround():