Commit 9041004

Update tests and documentation in relation to max_tokens and stop_at
Fix typo in fsm.py
RobinPicard authored and rlouf committed Dec 23, 2023
1 parent 35f001f commit 9041004
Showing 5 changed files with 72 additions and 37 deletions.
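
Across all five files the pattern is the same: sampling parameters such as `max_tokens` and `stop_at` move from generator construction to the generation call. A minimal before/after sketch of the new calling convention (model name borrowed from the README diff below; not itself part of this commit):

``` python
import outlines

model = outlines.models.transformers("mistralai/Mistral-7B-v0.1")

# Before this commit, sampling parameters were fixed at construction time:
# generator = outlines.generate.text(model, max_tokens=30)

# After, the generator is built once and the parameters are passed per call,
# so a single generator can serve calls with different budgets and stops.
generator = outlines.generate.text(model)
answer = generator("Write a short sentence ", max_tokens=30, stop_at=".")
```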
11 changes: 5 additions & 6 deletions README.md
@@ -82,8 +82,8 @@ answer = outlines.generate.format(model, float)(prompt)
 
 ### Efficient regex-guided generation
 
-Outlines also comes with fast regex-guided generation. In fact, the `choice`,
-`integer` and `float` functions above all use regex-guided generation under the
+Outlines also comes with fast regex-guided generation. In fact, the `choice` and
+`format` functions above all use regex-guided generation under the
 hood:
 
 ``` python
@@ -92,12 +92,11 @@ import outlines
 
 model = outlines.models.transformers("mistralai/Mistral-7B-v0.1")
 
 prompt = "What is the IP address of the Google DNS servers? "
-unguided = outlines.generate.text(model, max_tokens=30)(prompt)
+unguided = outlines.generate.text(model)(prompt, max_tokens=30)
 guided = outlines.generate.regex(
     model,
     r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
-    max_tokens=30,
-)(prompt)
+)(prompt, max_tokens=30)
 
 print(unguided)
 # What is the IP address of the Google DNS servers?
@@ -325,7 +324,7 @@ def labelling(to_label, examples):
 
 model = outlines.models.transformers("mistralai/Mistral-7B-v0.1")
 prompt = labelling("Just awesome", examples)
-answer = outlines.generate.text(model, max_tokens=100)(prompt)
+answer = outlines.generate.text(model)(prompt, max_tokens=100)
 ```
 
 ## Join us
4 changes: 2 additions & 2 deletions examples/cfg.py
@@ -51,8 +51,8 @@
 batch_size = 10
 max_tokens = 30
 for grammar in [nlamb_grammar, calc_grammar]:
-    generator = generate.cfg(model, grammar, max_tokens=max_tokens)
-    sequences = generator([" "] * batch_size)
+    generator = generate.cfg(model, grammar)
+    sequences = generator([" "] * batch_size, max_tokens=max_tokens)
     for seq in sequences:
         try:
             parse = generator.fsm.parser.parse(seq)
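
The cfg.py example follows the same convention: the grammar-constrained generator is built from the model and grammar alone, and per-call options travel with the prompts. A hedged sketch of the call shape, using a toy Lark grammar in place of the example's `nlamb_grammar`/`calc_grammar`:

``` python
from outlines import generate, models

model = models.transformers("hf-internal-testing/tiny-random-GPTJForCausalLM")

# Toy Lark grammar standing in for the example's real grammars.
grammar = r"""
?start: "yes" | "no"
"""

generator = generate.cfg(model, grammar)  # no max_tokens at construction
sequences = generator(["Answer yes or no: "] * 2, max_tokens=5)  # passed per call
```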
3 changes: 2 additions & 1 deletion outlines/fsm/fsm.py
@@ -191,7 +191,8 @@ def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
         The new state of the FSM.
         """
-        self.num_tokens_generated += 1
+        if idx == 0:
+            self.num_tokens_generated += 1
 
         if token_id == self.end_token_id:
             return FSMState(-1)
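
The fsm.py hunk makes the token counter advance once per generation step rather than once per sequence: `next_state` appears to be called for every sequence in a batch, with `idx` identifying the sequence, so guarding the increment with `idx == 0` keeps `num_tokens_generated` equal to the number of steps. A standalone sketch of the behavior being fixed (the class and loop are illustrative stand-ins, not the library's code):

``` python
# Illustrative only: a stripped-down stand-in for the FSM's counting logic.
class CounterFSM:
    def __init__(self):
        self.num_tokens_generated = 0

    def next_state(self, state: int, token_id: int, idx: int = 0) -> int:
        # Called once per sequence at each step; without the idx == 0 guard,
        # a batch of 4 sequences would advance the counter by 4 per step.
        if idx == 0:
            self.num_tokens_generated += 1
        return state


fsm = CounterFSM()
for _step in range(3):      # three generation steps
    for idx in range(4):    # four sequences in the batch
        fsm.next_state(0, 42, idx=idx)

assert fsm.num_tokens_generated == 3  # counts steps, not steps * sequences
```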
17 changes: 13 additions & 4 deletions outlines/generate/api.py
@@ -254,7 +254,9 @@ def token_generator() -> Iterator[Union[List[str], str]]:
             num_generated = 0
             is_stop_at_reached = [False for _ in range(num_sequences)]
             while True:
-                if (max_tokens and num_generated >= max_tokens) or all(is_stop_at_reached):
+                if (max_tokens and num_generated >= max_tokens) or all(
+                    is_stop_at_reached
+                ):
                     return
                 try:
                     sequence = next(states)
@@ -266,14 +268,21 @@
                 next_tokens = [
                     token[len(sequence) :] if not stop else ""
                     for token, sequence, stop in zip(
-                        generated_sequences, previously_generated_sequences, is_stop_at_reached
+                        generated_sequences,
+                        previously_generated_sequences,
+                        is_stop_at_reached,
                     )
                 ]
                 previously_generated_sequences = generated_sequences
                 if stop_sequences:
                     is_stop_at_reached = [
-                        stop or self.is_stop_sequence_reached([generated_sequence], stop_sequences)
-                        for generated_sequence, stop in zip(generated_sequences, is_stop_at_reached)
+                        stop
+                        or self.is_stop_sequence_reached(
+                            [generated_sequence], stop_sequences
+                        )
+                        for generated_sequence, stop in zip(
+                            generated_sequences, is_stop_at_reached
+                        )
                     ]
                 yield next_tokens
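
Both api.py hunks are pure line-wrapping, but they expose the stopping rule the new keyword arguments feed into: the token loop ends once `max_tokens` is reached or every sequence in the batch has matched a stop sequence, and an already-stopped sequence yields only empty strings afterwards. A self-contained sketch of that condition (names are local to this example, not the library's API):

``` python
from typing import List

def should_stop(num_generated: int, max_tokens: int, stopped: List[bool]) -> bool:
    # Mirrors the loop condition above: stop on the token budget,
    # or once every sequence in the batch has hit a stop sequence.
    return bool(max_tokens and num_generated >= max_tokens) or all(stopped)

assert should_stop(10, 10, [False, False])    # budget exhausted
assert should_stop(3, 10, [True, True])       # every sequence stopped
assert not should_stop(3, 10, [True, False])  # one sequence still running
```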
74 changes: 50 additions & 24 deletions tests/generate/test_integration_transfomers.py
@@ -20,16 +20,40 @@ def test_deprecation():
     model = models.transformers(model_name, device="cpu")
 
     with pytest.warns(DeprecationWarning):
-        outlines.text.generate.continuation(model, max_tokens=10)
+        outlines.text.generate.continuation(model)
 
     with pytest.warns(DeprecationWarning):
-        outlines.text.generate.choice(model, ["A", "B"], max_tokens=10)
+        outlines.text.generate.continuation(model)
 
     with pytest.warns(DeprecationWarning):
-        outlines.text.generate.regex(model, "[0-9]", max_tokens=10)
+        outlines.text.generate.choice(model, ["A", "B"])
 
     with pytest.warns(DeprecationWarning):
-        outlines.text.generate.format(model, int, max_tokens=10)
+        outlines.text.generate.regex(model, "[0-9]")
+
+    with pytest.warns(DeprecationWarning):
+        outlines.text.generate.format(model, int)
+
+    with pytest.warns(DeprecationWarning):
+        outlines.generate.text(model, max_tokens=10)
+
+    with pytest.warns(DeprecationWarning):
+        outlines.generate.text(model, stop_at=["."])
+
+    with pytest.warns(DeprecationWarning):
+        outlines.generate.regex(model, "[0-9]", max_tokens=10)
+
+    with pytest.warns(DeprecationWarning):
+        outlines.generate.format(model, int, max_tokens=10)
+
+    with pytest.warns(DeprecationWarning):
+        outlines.generate.choice(model, ["A", "B"], max_tokens=10)
+
+    class Character(BaseModel):
+        name: str
+
+    with pytest.warns(DeprecationWarning):
+        outlines.generate.json(model, Character, max_tokens=10)
 
     with pytest.warns(DeprecationWarning):
@@ -49,11 +73,13 @@ def test_transformers_integration_text():
     assert isinstance(sequence, str)
     assert model.tokenizer.eos_token not in sequence
 
-    sequence = generate.text(model, max_tokens=10)("Write a short sentence ", rng=rng)
+    sequence = generate.text(model)(
+        "Write a short sentence ", max_tokens=10, stop_at=".", rng=rng
+    )
     assert isinstance(sequence, str)
 
     prompts = ["Write a short sentence ", "And another one "]
-    sequence = generate.text(model, max_tokens=10)(prompts, rng=rng)
+    sequence = generate.text(model)(prompts, max_tokens=10, stop_at=[".", ","], rng=rng)
     assert isinstance(sequence, list)
     assert len(sequence) == 2
     assert isinstance(sequence[0], str)
@@ -65,8 +91,8 @@ def test_transformers_integration_streaming():
 
     model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
     model = models.transformers(model_name, device="cpu")
-    sequence = generate.text(model, max_tokens=10).stream(
-        "Write a short sentence ", rng=rng
+    sequence = generate.text(model).stream(
+        "Write a short sentence ", max_tokens=10, stop_at=[".", ","], rng=rng
     )
 
     token = next(sequence)
@@ -76,8 +102,8 @@
     remaining = "".join([token[0] for token in sequence])
     assert isinstance(remaining, str)
 
-    sequence = generate.text(model, max_tokens=10).stream(
-        ["Prompt1", "Prompt2"], rng=rng
+    sequence = generate.text(model).stream(
+        ["Prompt1", "Prompt2"], max_tokens=10, stop_at=[".", ","], rng=rng
     )
     tokens = next(sequence)
     assert isinstance(tokens, list)
@@ -93,7 +119,7 @@ def test_transformers_integration_text_stop():
    model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
    model = models.transformers(model_name, device="cpu")
    prompt = "Write a short sentence "
-    sequence = generate.text(model, stop="a")(prompt, rng=rng)
+    sequence = generate.text(model)(prompt, stop_at="a", rng=rng)
    assert sequence[len(prompt) :].find("a") == -1
 
 
@@ -105,7 +131,7 @@ def test_transformers_integration_text_array_samples():
     model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
     model = models.transformers(model_name, device="cpu")
     prompts = ["Write a short sentence", "And another one"]
-    _ = generate.text(model, max_tokens=10)(prompts, rng=rng, samples=3)
+    _ = generate.text(model)(prompts, max_tokens=10, rng=rng, samples=3)
 
 
 def test_transformers_various_regexes():
@@ -146,7 +172,7 @@ def test_transformers_integration_integer():
     model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
     model = models.transformers(model_name)
     prompt = "Write a short sentence"
-    sequence = generate.format(model, int, max_tokens=10)(prompt, rng=rng)
+    sequence = generate.format(model, int)(prompt, max_tokens=10, rng=rng)
 
     assert sequence != ""
     int(sequence)
@@ -159,7 +185,7 @@ def test_transformers_integration_integer_array():
     model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
     model = models.transformers(model_name)
     prompts = ["Give me a number", "And another one"]
-    sequence = generate.format(model, int, max_tokens=10)(prompts, rng=rng)
+    sequence = generate.format(model, int)(prompts, max_tokens=10, rng=rng)
     assert isinstance(sequence, list)
     assert len(sequence) == 2
     int(sequence[0])
@@ -173,7 +199,7 @@ def test_transformers_integration_float():
     model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
     model = models.transformers(model_name)
     prompt = "Write a short sentence"
-    sequence = generate.format(model, float, max_tokens=10)(prompt, rng=rng)
+    sequence = generate.format(model, float)(prompt, max_tokens=10, rng=rng)
 
     assert sequence != ""
     float(sequence)
@@ -186,7 +212,7 @@ def test_transformers_integration_bool():
     model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
     model = models.transformers(model_name)
     prompt = "Is this True or False?"
-    sequence = generate.format(model, bool, max_tokens=10)(prompt, rng=rng)
+    sequence = generate.format(model, bool)(prompt, max_tokens=10, rng=rng)
 
     assert sequence != ""
     bool(sequence)
@@ -199,7 +225,7 @@ def test_transformers_integration_date():
     model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
     model = models.transformers(model_name)
     prompt = "What day is it today?"
-    sequence = generate.format(model, datetime.date, max_tokens=10)(prompt, rng=rng)
+    sequence = generate.format(model, datetime.date)(prompt, max_tokens=10, rng=rng)
 
     assert sequence != ""
     datetime.datetime.strptime(sequence, "%Y-%m-%d")
@@ -212,7 +238,7 @@ def test_transformers_integration_time():
     model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
     model = models.transformers(model_name)
     prompt = "What time is it?"
-    sequence = generate.format(model, datetime.time, max_tokens=10)(prompt, rng=rng)
+    sequence = generate.format(model, datetime.time)(prompt, max_tokens=10, rng=rng)
 
     assert sequence != ""
     datetime.datetime.strptime(sequence, "%H:%M:%S")
@@ -225,7 +251,7 @@ def test_transformers_integration_datetime():
     model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
     model = models.transformers(model_name)
     prompt = "What time is it?"
-    sequence = generate.format(model, datetime.datetime, max_tokens=20)(prompt, rng=rng)
+    sequence = generate.format(model, datetime.datetime)(prompt, max_tokens=20, rng=rng)
 
     assert sequence != 0
     datetime.datetime.strptime(sequence, "%Y-%m-%d %H:%M:%S")
@@ -264,7 +290,7 @@ class Spam(BaseModel):
     rng = torch.Generator()
     rng.manual_seed(0)  # make sure that `bar` is not an int
 
-    result = generate.json(model, Spam, max_tokens=500)(prompt, rng=rng)
+    result = generate.json(model, Spam)(prompt, max_tokens=500, rng=rng)
     assert isinstance(result, BaseModel)
     assert isinstance(result.foo, int)
     assert isinstance(result.bar, float)
@@ -291,7 +317,7 @@ def test_transformers_json_schema():
     rng = torch.Generator()
     rng.manual_seed(0)  # make sure that `bar` is not an int
 
-    result = generate.json(model, schema, max_tokens=500)(prompt, rng=rng)
+    result = generate.json(model, schema)(prompt, max_tokens=500, rng=rng)
     assert isinstance(result, dict)
     assert isinstance(result["foo"], int)
     assert isinstance(result["bar"], str)
@@ -311,7 +337,7 @@ class Spam(BaseModel):
     rng = torch.Generator()
     rng.manual_seed(0)  # make sure that `bar` is not an int
 
-    result = generate.json(model, Spam, max_tokens=500)(prompts, rng=rng)
+    result = generate.json(model, Spam)(prompts, max_tokens=500, rng=rng)
     assert isinstance(result[0], BaseModel)
     assert isinstance(result[1], BaseModel)
@@ -393,7 +419,7 @@ class Spam(BaseModel):
     rng = torch.Generator()
     rng.manual_seed(4)
 
-    result = generate.json(model, Spam, max_tokens=100)(prompt, rng=rng)
+    result = generate.json(model, Spam)(prompt, max_tokens=100, rng=rng)
     assert isinstance(result, BaseModel)
     assert (
         isinstance(result.bar, int)
@@ -413,7 +439,7 @@ def function(foo: int, bar: List[int]):
     rng = torch.Generator()
     rng.manual_seed(4)
 
-    sequence = generate.json(model, function, max_tokens=100)(prompt, rng=rng)
+    sequence = generate.json(model, function)(prompt, max_tokens=100, rng=rng)
     assert isinstance(sequence, dict)
     assert isinstance(function(**sequence), int)
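
Every updated test threads a seeded `torch.Generator` through the call, so moving `max_tokens` and `stop_at` to call time leaves generation reproducible. A small usage sketch following the tests' own pattern (tiny test model name taken from the diff):

``` python
import torch
from outlines import generate, models

model = models.transformers("hf-internal-testing/tiny-random-GPTJForCausalLM")

rng = torch.Generator()
rng.manual_seed(0)  # same seed, same sample

# Constraints are applied at call time, exactly as in the updated tests.
sequence = generate.format(model, int)("Give me a number", max_tokens=10, rng=rng)
int(sequence)  # parses, because generation was constrained to an integer
```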
