From 3ea174d913a672b9d5256881ed701be1a2cf8e58 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Tue, 17 Sep 2024 16:14:08 -0400 Subject: [PATCH] Create optional strict=False for generate.json --- docs/reference/generation/json.md | 8 ++ outlines/generate/json.py | 29 ++++++- tests/generate/test_generate_json.py | 115 +++++++++++++++++++++++++++ 3 files changed, 149 insertions(+), 3 deletions(-) create mode 100644 tests/generate/test_generate_json.py diff --git a/docs/reference/generation/json.md b/docs/reference/generation/json.md index da9f14729..0f75a198c 100644 --- a/docs/reference/generation/json.md +++ b/docs/reference/generation/json.md @@ -42,6 +42,14 @@ print(result) generator = generate.json(model, User, whitespace_pattern=r"[\n\t ]*") ``` +!!! Note "Non-Strict Mode" + Because models may exhaust their context window before a valid schema is generated, an error resulting from an invalid generation may occur. This is particularly troublesome when an error interrupts a batch workload. To ensure `generate.json` returns a dict containing error details for invalid sequences rather than raising an error, use the following: + + ```python + generator = generate.json(model, User, strict=False) + ``` + + !!! Note "Performance" `generation.json` computes an index that helps Outlines guide generation. This can take some time, but only needs to be done once. If you want to generate several times with the same schema make sure that you only call `generate.json` once. diff --git a/outlines/generate/json.py b/outlines/generate/json.py index f75878d29..58e9b34d9 100644 --- a/outlines/generate/json.py +++ b/outlines/generate/json.py @@ -18,6 +18,7 @@ def json( schema_object: Union[str, object, Callable], sampler: Sampler = multinomial(), whitespace_pattern: Optional[str] = None, + strict=True, ) -> SequenceGeneratorAdapter: """ Generate structured JSON data with a `Transformer` model based on a specified JSON Schema. 
@@ -36,6 +37,10 @@ def json( whitespace_pattern Pattern to use for JSON syntactic whitespace (doesn't impact string literals) Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` + strict + If strict mode is enabled, generations which don't conform to the schema or aren't + valid JSON will result in an error. Outlines guarantees generation complies with a schema, + but schemas often allow for infinite repetition and exhaust the model_max_length. Returns ------- @@ -43,21 +48,39 @@ def json( transforms the result if BaseModel is used. """ + + def maybe_strict_formatter(formatter): + """If strict, use normal formatter. Otherwise, return error dict on failure""" + if strict: + return formatter + + def allow_fail_formatter(generated_output): + try: + return formatter(generated_output) + except Exception as e: + return { + "error": str(e), + "error_type": type(e).__name__, + "output": generated_output, + } + + return allow_fail_formatter + if isinstance(schema_object, type(BaseModel)): schema = pyjson.dumps(schema_object.model_json_schema()) regex_str = build_regex_from_schema(schema, whitespace_pattern) generator = regex(model, regex_str, sampler) - generator.format_sequence = lambda x: schema_object.parse_raw(x) + generator.format_sequence = maybe_strict_formatter(schema_object.parse_raw) elif callable(schema_object): schema = pyjson.dumps(get_schema_from_signature(schema_object)) regex_str = build_regex_from_schema(schema, whitespace_pattern) generator = regex(model, regex_str, sampler) - generator.format_sequence = lambda x: pyjson.loads(x) + generator.format_sequence = maybe_strict_formatter(pyjson.loads) elif isinstance(schema_object, str): schema = schema_object regex_str = build_regex_from_schema(schema, whitespace_pattern) generator = regex(model, regex_str, sampler) - generator.format_sequence = lambda x: pyjson.loads(x) + generator.format_sequence = maybe_strict_formatter(pyjson.loads) else: raise ValueError( f"Cannot parse schema 
{schema_object}. The schema must be either " diff --git a/tests/generate/test_generate_json.py b/tests/generate/test_generate_json.py new file mode 100644 index 000000000..00232ef2f --- /dev/null +++ b/tests/generate/test_generate_json.py @@ -0,0 +1,115 @@ +import json +import string + +import pytest +from pydantic import BaseModel, ValidationError + +from outlines import generate + + +class MockCharacterTokenizer: + def __init__(self): + characters = set( + string.ascii_letters + + string.digits + + string.punctuation + + string.whitespace + ) + self.vocabulary = {tok: tok_id for tok_id, tok in enumerate(characters)} + self.vocabulary["eos"] = len(characters) + self.special_tokens = {"eos"} + self.eos_token_id = len(characters) + + def convert_token_to_string(self, token): + return token + + +class MockModel: + def __init__(self, generated): + self.generated = generated + self.tokenizer = MockCharacterTokenizer() + + def generate(self, *args, **kwargs): + return self.generated + + +mock_json_schema = json.dumps( + { + "type": "object", + "properties": {"message": {"type": "string"}}, + "required": ["message"], + "additionalProperties": False, + } +) + + +class MockPydanticModel(BaseModel): + message: str + + +@pytest.mark.parametrize("schema", [mock_json_schema, MockPydanticModel]) +def test_generate_strict_success(schema): + model = MockModel(generated='{"message": "foo"}') + generator = generate.json(model, schema) + generator("hi") + + +@pytest.mark.parametrize("schema", [mock_json_schema, MockPydanticModel]) +def test_generate_strict_success_batch(schema): + model = MockModel( + generated=[ + '{"message": "foo"}', + '{"message": "basteuhotuhnoethunoteuhntoeuhntoehuotn"}', + ] + ) + generator = generate.json(model, schema) + for output in generator("hi"): + pass + + +@pytest.mark.parametrize("schema", [mock_json_schema, MockPydanticModel]) +def test_generate_strict_fail(schema): + model = MockModel(generated='{"message": "foo') + generator = 
generate.json(model, schema) + with pytest.raises((json.decoder.JSONDecodeError, ValidationError)): + generator("hi") + + +@pytest.mark.parametrize("schema", [mock_json_schema, MockPydanticModel]) +def test_generate_strict_fail_batch(schema): + model = MockModel( + generated=[ + '{"message": "foo"}', + '{"message": "basteuhotuhnoethunoteuhntoeuhntoehuotn"', + ] + ) + generator = generate.json(model, schema) + with pytest.raises((json.decoder.JSONDecodeError, ValidationError)): + generator("hi") + + +@pytest.mark.parametrize("schema", [mock_json_schema, MockPydanticModel]) +def test_generate_non_strict_evade_failure(schema): + model = MockModel(generated='{"message": "foo') + generator = generate.json(model, schema, strict=False) + result = generator("hi") + assert result["error_type"] in ("JSONDecodeError", "ValidationError") + assert result["output"] == model.generated + + +@pytest.mark.parametrize("schema", [mock_json_schema, MockPydanticModel]) +def test_generate_non_strict_evade_failure_batch(schema): + model = MockModel( + generated=[ + '{"message": "foo"}', + '{"message": "basteuhotuhnoethunoteuhntoeuhntoehuotn"', + ] + ) + generator = generate.json(model, schema, strict=False) + result = generator("hi") + if isinstance(schema, str): + assert result[0] == json.loads(model.generated[0]) + else: + assert result[0] == schema.parse_raw(model.generated[0]) + assert result[1]["error_type"] in ("JSONDecodeError", "ValidationError") + assert result[1]["output"] == model.generated[1]