Skip to content

Commit

Permalink
Create optional strict=False for generate.json
Browse files Browse the repository at this point in the history
  • Loading branch information
lapp0 committed Sep 23, 2024
1 parent 77c6d67 commit 3ea174d
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 3 deletions.
8 changes: 8 additions & 0 deletions docs/reference/generation/json.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ print(result)
generator = generate.json(model, User, whitespace_pattern=r"[\n\t ]*")
```

!!! Note "Non-Strict Mode"
Because models may exhaust their context window before a valid schema is generated, an error resulting from from an invalid generation may occur. This is particularly troublesome when an error interrupts a batch workload. To ensure `generate.json` returns a dict containing error details for invalid sequences rather than raising an error, use the following:

```python
generator = generate.json(model, User, strict=False)
```


!!! Note "Performance"

`generation.json` computes an index that helps Outlines guide generation. This can take some time, but only needs to be done once. If you want to generate several times with the same schema make sure that you only call `generate.json` once.
Expand Down
29 changes: 26 additions & 3 deletions outlines/generate/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def json(
schema_object: Union[str, object, Callable],
sampler: Sampler = multinomial(),
whitespace_pattern: Optional[str] = None,
strict=True,
) -> SequenceGeneratorAdapter:
"""
Generate structured JSON data with a `Transformer` model based on a specified JSON Schema.
Expand All @@ -36,28 +37,50 @@ def json(
whitespace_pattern
Pattern to use for JSON syntactic whitespace (doesn't impact string literals)
Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`
strict
If strict mode is enabled, generations which don't conform to the schema or aren't
valid JSON will result in an error. Outlines guarantees generation complies with a schema,
but schemas often allow for infinite repetition and exhaust the model_max_length.
Returns
-------
A `SequenceGenerator` instance that generates text constrained by the schema_object and
transforms the result if BaseModel is used.
"""

def maybe_strict_formatter(formatter):
"""If strict, use normal formatter. Otherwise, return error dict on failure"""
if strict:
return formatter

def allow_fail_formatter(generated_output):
try:
return formatter(generated_output)
except Exception as e:
return {
"error": str(e),
"error_type": type(e).__name__,
"output": generated_output,
}

return allow_fail_formatter

if isinstance(schema_object, type(BaseModel)):
schema = pyjson.dumps(schema_object.model_json_schema())
regex_str = build_regex_from_schema(schema, whitespace_pattern)
generator = regex(model, regex_str, sampler)
generator.format_sequence = lambda x: schema_object.parse_raw(x)
generator.format_sequence = maybe_strict_formatter(schema_object.parse_raw)
elif callable(schema_object):
schema = pyjson.dumps(get_schema_from_signature(schema_object))
regex_str = build_regex_from_schema(schema, whitespace_pattern)
generator = regex(model, regex_str, sampler)
generator.format_sequence = lambda x: pyjson.loads(x)
generator.format_sequence = maybe_strict_formatter(pyjson.loads)
elif isinstance(schema_object, str):
schema = schema_object
regex_str = build_regex_from_schema(schema, whitespace_pattern)
generator = regex(model, regex_str, sampler)
generator.format_sequence = lambda x: pyjson.loads(x)
generator.format_sequence = maybe_strict_formatter(pyjson.loads)
else:
raise ValueError(
f"Cannot parse schema {schema_object}. The schema must be either "
Expand Down
115 changes: 115 additions & 0 deletions tests/generate/test_generate_json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import json
import string

import pytest
from pydantic import BaseModel, ValidationError

from outlines import generate


class MockCharacterTokenizer:
def __init__(self):
characters = set(
string.ascii_letters
+ string.digits
+ string.punctuation
+ string.whitespace
)
self.vocabulary = {tok: tok_id for tok_id, tok in enumerate(characters)}
self.vocabulary["eos"] = len(characters)
self.special_tokens = {"eos"}
self.eos_token_id = len(characters)

def convert_token_to_string(self, token):
return token


class MockModel:
def __init__(self, generated):
self.generated = generated
self.tokenizer = MockCharacterTokenizer()

def generate(self, *args, **kwargs):
return self.generated


mock_json_schema = json.dumps(
{
"type": "object",
"properties": {"message": {"type": "string"}},
"required": ["message"],
"additionalProperties": False,
}
)


class MockPydanticModel(BaseModel):
message: str


@pytest.mark.parametrize("schema", [mock_json_schema, MockPydanticModel])
def test_generate_strict_success(schema):
model = MockModel(generated='{"message": "foo"}')
generator = generate.json(model, schema)
generator("hi")


@pytest.mark.parametrize("schema", [mock_json_schema, MockPydanticModel])
def test_generate_strict_success_batch(schema):
model = MockModel(
generated=[
'{"message": "foo"}',
'{"message": "basteuhotuhnoethunoteuhntoeuhntoehuotn"}',
]
)
generator = generate.json(model, schema)
for output in generator("hi"):
pass


@pytest.mark.parametrize("schema", [mock_json_schema, MockPydanticModel])
def test_generate_strict_fail(schema):
model = MockModel(generated='{"message": "foo')
generator = generate.json(model, schema)
with pytest.raises((json.decoder.JSONDecodeError, ValidationError)):
generator("hi")


@pytest.mark.parametrize("schema", [mock_json_schema, MockPydanticModel])
def test_generate_strict_fail_batch(schema):
model = MockModel(
generated=[
'{"message": "foo"}',
'{"message": "basteuhotuhnoethunoteuhntoeuhntoehuotn"',
]
)
generator = generate.json(model, schema)
with pytest.raises((json.decoder.JSONDecodeError, ValidationError)):
generator("hi")


@pytest.mark.parametrize("schema", [mock_json_schema, MockPydanticModel])
def test_generate_non_strict_evade_failure(schema):
model = MockModel(generated='{"message": "foo')
generator = generate.json(model, schema, strict=False)
result = generator("hi")
assert result["error_type"] in ("JSONDecodeError", "ValidationError")
assert result["output"] == model.generated


@pytest.mark.parametrize("schema", [mock_json_schema, MockPydanticModel])
def test_generate_non_strict_evade_failure_batch(schema):
model = MockModel(
generated=[
'{"message": "foo"}',
'{"message": "basteuhotuhnoethunoteuhntoeuhntoehuotn"',
]
)
generator = generate.json(model, schema, strict=False)
result = generator("hi")
if isinstance(schema, str):
assert result[0] == json.loads(model.generated[0])
else:
assert result[0] == schema.parse_raw(model.generated[0])
assert result[1]["error_type"] in ("JSONDecodeError", "ValidationError")
assert result[1]["output"] == model.generated[1]

0 comments on commit 3ea174d

Please sign in to comment.