Stop generation with Continuation when a specific string was generated
rlouf committed Jul 15, 2023
1 parent 0bdcc56 commit bfa0e94
Showing 5 changed files with 167 additions and 27 deletions.
75 changes: 64 additions & 11 deletions outlines/text/generate/continuation.py
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List, Optional, Union

 import torch

@@ -17,36 +17,89 @@ class Continuation(Sequence):
"""

def __init__(self, model, max_tokens: Optional[int]):
def __init__(
self, model, max_tokens: Optional[int] = None, stop: Union[str, List[str]] = []
):
super().__init__(model, max_tokens)
self.eos_token_id = torch.tensor(
[self.model.tokenizer.eos_token_id], device=self.device
)

if isinstance(stop, str):
stop = [stop]

self.stop_sequences = stop

     def is_finished(self, token_ids: torch.LongTensor) -> torch.BoolTensor:
         """Determine whether the sequences reached maximum length or end with
         an EOS token.
-        In practice, `Sequence`'s `__call__` methods only passed the `token_ids`
-        of the sequences that haven't been marked as finished already, which is
-        why we only need to look for the EOS token in the last element rather
-        than in the whole sequence.
+        We only need to look for the EOS token in the last element rather than
+        in the whole sequence. Indeed, (1) EOS is a single token and (2)
+        `Sequence`'s `__call__` method only passes the `token_ids` of the
+        sequences that haven't been marked as finished already.
         Parameters
         ----------
         token_ids
             The input sequences.
         """
-        return token_ids[:, -1] == self.model.tokenizer.eos_token_id

+        sequences = self.model.tokenizer.decode(token_ids)
+        contains_stop_sequence = []
+        for sequence in sequences:
+            found = False
+            for stop_str in self.stop_sequences:
+                if stop_str in sequence:
+                    found = True
+
+            contains_stop_sequence.append(found)
+
+        contains_stop_sequence = torch.tensor(contains_stop_sequence, dtype=torch.bool)
+        contains_eos = token_ids[:, -1] == self.model.tokenizer.eos_token_id
+
+        return torch.logical_or(contains_eos, contains_stop_sequence)
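In words: a sequence is finished once its decoded text contains one of the stop strings, or its last token is EOS. A minimal standalone sketch of the same check, with a toy decode and invented values:

import torch

def is_finished_sketch(decoded, last_token_ids, stop_sequences, eos_token_id=0):
    # One boolean per sequence: does the decoded text contain a stop string?
    contains_stop = torch.tensor(
        [any(s in text for s in stop_sequences) for text in decoded],
        dtype=torch.bool,
    )
    # One boolean per sequence: is the last generated token the EOS token?
    contains_eos = last_token_ids == eos_token_id
    return torch.logical_or(contains_eos, contains_stop)

print(is_finished_sketch(["Hello\n", "Hello"], torch.tensor([3, 0]), ["\n"]))
# tensor([True, True]) -- the first stops on "\n", the second on EOS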

     def postprocess_completions(self, completions: List[str]) -> List[str]:
-        """Remove the EOS token from the completion."""
-        return [
+        """Remove the EOS token from the completion.
+        Sequences in `stop` take precedence over EOS. For instance, if
+        `stop=["\n"]` and the generated sequence is `One\nTwo<EOS>`,
+        `Continuation.postprocess_completions` will return `One`.
+        """
+        completions_without_eos = [
             completion.replace(self.model.tokenizer.eos_token, "")
             for completion in completions
         ]
+
+        completions_without_stop = []
+        for completion in completions_without_eos:
+            for stop_str in self.stop_sequences:
+                idx = completion.rfind(stop_str)  # ignore the prompt
+                if idx > 0:
+                    completion = completion[:idx]
+
+            completions_without_stop.append(completion)
+
+        return completions_without_stop
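A standalone sketch of the same trimming logic, with invented inputs; note that an occurrence at index 0 survives because of the `idx > 0` guard:

def trim_at_stop(completion, stop_sequences):
    # Cut at the last occurrence of each stop string in turn; an
    # occurrence at index 0 is deliberately kept, mirroring the guard above.
    for stop_str in stop_sequences:
        idx = completion.rfind(stop_str)
        if idx > 0:
            completion = completion[:idx]
    return completion

print(trim_at_stop("One\nTwo", ["\n"]))  # prints "One"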


-def continuation(model, max_tokens: Optional[int] = None):
-    return Continuation(model, max_tokens)
+def continuation(
+    model, max_tokens: Optional[int] = None, *, stop: Union[str, List[str]] = []
+):
+    """Generate text sequences.
+    Parameters
+    ----------
+    model
+        The model used to compute the next-token logits.
+    max_tokens
+        The maximum number of tokens to generate.
+    stop
+        A string or list of strings which, when generated, stops
+        the generation for this sequence.
+    """
+    return Continuation(model, max_tokens, stop)
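A minimal usage sketch of the new keyword. The import aliases and the prompt are assumptions; the model name and the `models.transformers` / `generate.continuation` calls are taken from the integration test further down:

import outlines.models as models
import outlines.text.generate as generate

model = models.transformers(
    "hf-internal-testing/tiny-random-GPTJForCausalLM", device="cpu"
)
# Generation halts once "\n" (or EOS) is produced, and the completion
# is trimmed so the stop string itself is not returned.
answer = generate.continuation(model, max_tokens=100, stop="\n")("Name a color: ")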
30 changes: 30 additions & 0 deletions outlines/text/generate/regex.py
@@ -135,6 +135,18 @@ def create_proposal(


 def regex(model, regex_string: str, max_tokens: Optional[int] = None):
+    """Generate text sequences that match the input regex.
+    Parameters
+    ----------
+    model
+        The model used to compute the next-token logits.
+    regex_string
+        The regular expression that generated expressions must match.
+    max_tokens
+        The maximum number of tokens to generate.
+    """
     return Regex(model, regex_string, max_tokens)


@@ -145,6 +157,15 @@ def integer(model, max_tokens: Optional[int] = None):
     signs and forbids leading zeros (even if the `int` function in Python allows
     them).
+    Parameters
+    ----------
+    model
+        The model used to compute the next-token logits.
+    max_tokens
+        The maximum number of tokens to generate.
+    """
     return Regex(model, r"[-+]?\d+", max_tokens)

@@ -156,5 +177,14 @@ def float(model, max_tokens: Optional[int] = None):
     signs, and forbids leading zeros (even if the `float` function in Python
     allows them).
+    Parameters
+    ----------
+    model
+        The model used to compute the next-token logits.
+    max_tokens
+        The maximum number of tokens to generate.
+    """
     return Regex(model, r"([+-]?((0|[1-9]+)([.][0-9]*)?)|([.][0-9]+))", max_tokens)
4 changes: 3 additions & 1 deletion outlines/text/generate/sequence.py
@@ -229,7 +229,9 @@ def __call__(
             )
             token_ids = self.update_token_ids(is_finished, token_ids, updated_token_ids)
             attention_mask = self.expand_attention_mask(attention_mask)
-            is_finished[~is_finished] = self.is_finished(updated_token_ids).flatten()
+            is_finished[~is_finished] = self.is_finished(
+                updated_token_ids[:, num_prompt_tokens:]
+            ).flatten()

         result = self.model.tokenizer.decode(token_ids)
         result = self.postprocess_completions(result)
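This slicing matters for the new stop check: `is_finished` now decodes token ids and searches the text, so passing the prompt tokens along would let a stop string that appears in the prompt finish the sequence immediately. A toy illustration (values hypothetical):

import torch

token_ids = torch.tensor([[11, 12, 13, 21, 22]])   # 3 prompt tokens + 2 generated
num_prompt_tokens = 3
generated_only = token_ids[:, num_prompt_tokens:]  # tensor([[21, 22]])
# Only the generated suffix is decoded and scanned for stop strings,
# so stop text occurring in the prompt cannot finish the sequence.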
75 changes: 63 additions & 12 deletions tests/text/generate/test_continuation.py
@@ -1,5 +1,4 @@
-import numpy as np
-from numpy.testing import assert_array_equal
 import torch

 from outlines.text.generate.continuation import Continuation, continuation

@@ -9,35 +8,87 @@ class Tokenizer:
     eos_token_id = 0
     pad_token_id = -1

+    def decode(self, token_ids):
+        return ["Test"] * token_ids.shape[0]


 class Model:
     tokenizer = Tokenizer()
     device = "cpu"


-def test_continuation_is_finished():
-    model = continuation(Model(), 10)
+def test_continuation_eos_is_finished():
+    model = continuation(Model())
     assert isinstance(model, Continuation)

-    token_ids = np.array([[3, 2]])
+    token_ids = torch.tensor([[3, 2]])
     result = model.is_finished(token_ids)
-    assert_array_equal(result, [False])
+    assert torch.equal(result, torch.tensor([False]))

-    token_ids = np.array([[3, 2, 0]])
+    token_ids = torch.tensor([[3, 2, 0]])
     result = model.is_finished(token_ids)
-    assert_array_equal(result, [True])
+    assert torch.equal(result, torch.tensor([True]))

-    token_ids = np.array([[3, 2, 1], [3, 2, 0]])
+    token_ids = torch.tensor([[3, 2, 1], [3, 2, 0]])
     result = model.is_finished(token_ids)
-    assert_array_equal(result, [False, True])
+    assert torch.equal(result, torch.tensor([False, True]))

-    token_ids = np.array([[3, 2, 1, 0], [3, 2, 0, -1]])
+    token_ids = torch.tensor([[3, 2, 1, 0], [3, 2, 0, -1]])
     result = model.is_finished(token_ids)
-    assert_array_equal(result, [True, False])
+    assert torch.equal(result, torch.tensor([True, False]))


+def test_continuation_postprocess():
+    model = continuation(Model())
+    result = model.postprocess_completions(["Here<EOS>"])
+    assert len(result) == 1
+    assert result[0] == "Here"

+def test_continuation_stop_is_finished():
+    tokenizer = Tokenizer()
+    tokenizer.decode = lambda x: ["finished \n", "not_finished"]
+    model = Model()
+    model.tokenizer = tokenizer
+
+    model = continuation(model, stop=["\n"])
+
+    token_ids = torch.tensor([[2, 3], [2, 3]])  # two sequences, neither ends in EOS
+    result = model.is_finished(token_ids)
+    assert torch.equal(result, torch.tensor([True, False]))


+def test_continuation_stop_postprocess():
+    model = Continuation(Model(), stop="\n")
+    result = model.postprocess_completions(["Stop\n"])
+    assert len(result) == 1
+    assert result[0] == "Stop"
+
+    model = Continuation(Model(), stop=["\n", ","])
+    result = model.postprocess_completions(["Stop"])
+    assert len(result) == 1
+    assert result[0] == "Stop"
+
+    result = model.postprocess_completions(["Stop\n"])
+    assert len(result) == 1
+    assert result[0] == "Stop"
+
+    result = model.postprocess_completions(["Stop\naaa"])
+    assert len(result) == 1
+    assert result[0] == "Stop"
+
+    result = model.postprocess_completions(["Stop,aa\naaa"])
+    assert len(result) == 1
+    assert result[0] == "Stop"
+
+    result = model.postprocess_completions(["Stop\naa,a"])
+    assert len(result) == 1
+    assert result[0] == "Stop"
+
+    result = model.postprocess_completions(["Stop\n", "Nonstop"])
+    assert len(result) == 2
+    assert result == ["Stop", "Nonstop"]
+
+    result = model.postprocess_completions(["StopHere\nNoHere<EOS>"])
+    assert len(result) == 1
+    assert result[0] == "StopHere"
10 changes: 7 additions & 3 deletions tests/text/generate/test_integration_transfomers.py
@@ -13,21 +13,25 @@ def test_transformers_integration_continuation():

     model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
     model = models.transformers(model_name, device="cpu")
-    sequence = generate.continuation(model)("Write a short sentence", rng=rng)
+    sequence = generate.continuation(model)("Write a short sentence ", rng=rng)
     assert isinstance(sequence, str)
     assert model.tokenizer.eos_token not in sequence

     sequence = generate.continuation(model, max_tokens=10)(
-        "Write a short sentence", rng=rng
+        "Write a short sentence ", rng=rng
     )
     assert isinstance(sequence, str)

-    prompts = ["Write a short sentence", "And another one"]
+    prompts = ["Write a short sentence ", "And another one "]
     sequence = generate.continuation(model, max_tokens=10)(prompts, rng=rng)
     assert isinstance(sequence, list)
     assert len(sequence) == 2
     assert isinstance(sequence[0], str)

+    prompt = "Write a short sentence "
+    sequence = generate.continuation(model, stop="a")(prompt, rng=rng)
+    assert sequence[len(prompt) :].find("a") == -1


@pytest.mark.xfail
def test_transformers_integration_continuation_array_samples():
