diff --git a/outlines/text/generate/__init__.py b/outlines/text/generate/__init__.py
index dc39d81bf..bb752264a 100644
--- a/outlines/text/generate/__init__.py
+++ b/outlines/text/generate/__init__.py
@@ -1,2 +1,2 @@
 from .continuation import continuation
-from .integer import integer
+from .regex import float, integer, regex
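For orientation, here is a minimal usage sketch of the re-exported entry points. The model name is the tiny test model used by the integration tests below; the prompts and the yes/no regex are illustrative only.

    import torch

    import outlines.models as models
    import outlines.text.generate as generate

    rng = torch.Generator()
    rng.manual_seed(0)

    model = models.transformers(
        "hf-internal-testing/tiny-random-GPTJForCausalLM", device="cpu"
    )

    # Constrain generation with an arbitrary regex...
    answer = generate.regex(model, r"(yes|no)")("Is water wet? ", rng=rng)
    # ...or use the numeric shorthands built on top of it.
    count = generate.integer(model, max_tokens=10)("How many? ", rng=rng)
    price = generate.float(model, max_tokens=10)("How much? ", rng=rng)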
diff --git a/outlines/text/generate/integer.py b/outlines/text/generate/integer.py
deleted file mode 100644
index f138b047d..000000000
--- a/outlines/text/generate/integer.py
+++ /dev/null
@@ -1,96 +0,0 @@
-import math
-from typing import List, Optional, Tuple
-
-import interegular
-import torch
-
-from outlines.text.generate.continuation import Continuation
-from outlines.text.parsing import find_partial_matches, map_partial_states_to_vocab
-
-
-class Integer(Continuation):
-    """Represents a integer generation model.
-
-    `Integer` instances are constrained generation models that only
-    generate integer values. Leading zeros are fobidden. EOS tokens
-    are only allowed after at least one digit has been generated.
-
-    >>> import outlines.text as text
-    >>> sequence = text.generate.integer(model)("Return an integer between 0 and 10")
-
-    """
-
-    def __init__(self, model, max_tokens: Optional[int]):
-        super().__init__(model, max_tokens)
-
-        vocabulary = model.tokenizer.vocabulary
-        sorted_vocabulary = [
-            k for k, v in sorted(vocabulary.items(), key=lambda kv: kv[1])
-        ]
-
-        int_regex_string = r"(0|[1-9][0-9]+)"
-        int_regex_pattern = interegular.parse_pattern(int_regex_string)
-        self.int_regex_fsm = int_regex_pattern.simplify().to_fsm()
-
-        def partial_match_filter(string, end_idx, state_seq):
-            if end_idx is not None and end_idx < len(string) - 1:
-                return False
-            return True
-
-        pstate_to_vocab = map_partial_states_to_vocab(
-            list(sorted_vocabulary),
-            {"INT": self.int_regex_fsm},
-            True,
-            partial_match_filter,
-        )
-        self.pstate_to_vocab = {k: list(v) for k, v in pstate_to_vocab.items()}
-
-    def create_proposal(
-        self, generated_token_ids: torch.LongTensor, logits: torch.DoubleTensor
-    ) -> torch.DoubleTensor:
-        """Modify the next-token logits so that only integers can be generated.
-
-        Parameters
-        ----------
-        generated_token_ids
-            The token ids generated so far.
-        logits
-            The next-token logits.
-
-        """
-        if generated_token_ids.shape[-1] > 0:
-            # TODO Make this work for `generated_token_ids` of arbitrary shape
-            sampled_sequences = self.model.tokenizer.decode(generated_token_ids)
-            if isinstance(sampled_sequences, str):
-                sampled_sequences = [sampled_sequences]
-            partial_matches = [
-                find_partial_matches(self.int_regex_fsm, sequence)
-                for sequence in sampled_sequences
-            ]
-            pmatches = [
-                max(partial_match, key=lambda x: x[0] if x[0] is not None else -1)
-                for partial_match in partial_matches
-            ]
-            self.pstates: List[Tuple[str, int]] = [
-                (self.pstates[0][0], pmatch[1][-1]) for pmatch in pmatches
-            ]
-        else:
-            self.pstates = [
-                ("INT", self.int_regex_fsm.initial)
-                for _ in range(generated_token_ids.shape[0])
-            ]
-
-        masks = []
-        for pstate in self.pstates:
-            next_support = self.pstate_to_vocab[pstate]
-            mask = torch.full((len(self.model.tokenizer.vocabulary),), -math.inf)
-            mask[next_support] = 0
-            masks.append(mask.unsqueeze(0))
-
-        mask = torch.concatenate(masks, dim=0)
-
-        return logits + mask
-
-
-def integer(model, max_tokens: Optional[int] = None):
-    return Integer(model, max_tokens)
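The class above hard-wires a single integer FSM; the `Regex` class that replaces it takes the pattern as an argument instead. As a quick sketch of the FSM plumbing both versions rely on, using only the `interegular` calls that appear in this diff:

    import interegular

    # Build and minimize the finite state machine for a pattern.
    fsm = interegular.parse_pattern(r"(0|[1-9][0-9]*)").to_fsm().reduce()

    print(fsm.initial)  # start state, the first `pstate` in `create_proposal`
    print(fsm.states)   # `Regex.__init__` checks the vocabulary covers these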
diff --git a/outlines/text/generate/regex.py b/outlines/text/generate/regex.py
new file mode 100644
index 000000000..0bcb86ecf
--- /dev/null
+++ b/outlines/text/generate/regex.py
@@ -0,0 +1,160 @@
+import math
+from typing import List, Optional, Tuple
+
+import interegular
+import torch
+
+from outlines.text.generate.continuation import Continuation
+from outlines.text.parsing import find_partial_matches, map_partial_states_to_vocab
+
+
+class Regex(Continuation):
+    """Represents a regex-based generation model.
+
+    `Regex` instances are constrained generation models that only generate
+    sequences that match an input regex. Generation may terminate (but does
+    not have to) whenever the finite state machine that corresponds to the
+    regex is in an accepting state.
+
+    >>> import outlines.text as text
+    >>> sequence = text.generate.regex(model, "(0|[1-9][0-9]*)")("Return an integer between 0 and 10")
+
+    """
+
+    def __init__(self, model, regex_string: str, max_tokens: Optional[int]):
+        super().__init__(model, max_tokens)
+
+        vocabulary = model.tokenizer.vocabulary
+        sorted_vocabulary = [
+            k for k, v in sorted(vocabulary.items(), key=lambda kv: kv[1])
+        ]
+
+        regex_pattern = interegular.parse_pattern(regex_string)
+        self.regex_fsm = regex_pattern.to_fsm().reduce()
+
+        def partial_match_filter(string, end_idx, state_seq):
+            if end_idx is not None and end_idx < len(string) - 1:
+                return False
+            return True
+
+        pstate_to_vocab = map_partial_states_to_vocab(
+            list(sorted_vocabulary),
+            {"REGEX": self.regex_fsm},
+            partial_match_filter,
+            final_state_string=model.tokenizer.eos_token,
+        )
+
+        # TODO: This check might be a little too strict: even when the
+        # vocabulary makes some states unreachable (so that they are missing
+        # from the set difference below), there could still be paths to
+        # terminal states from the states that remain reachable.
+        states_with_transition = {x[1] for x in pstate_to_vocab.keys()}
+        if len(self.regex_fsm.states.difference(states_with_transition)) > 0:
+            raise ValueError(
+                "The vocabulary does not allow us to build a sequence that matches the input regex"
+            )
+
+        self.pstate_to_vocab = {k: list(v) for k, v in pstate_to_vocab.items()}
+        # These tuples consist of the FSM name, the last FSM state, and the
+        # number of processed tokens.
+        # When an EOS is observed, the last FSM state becomes `-1`.
+        self.pstates: List[Tuple[str, int, int]] = []
+
+    def create_proposal(
+        self, generated_token_ids: torch.LongTensor, logits: torch.DoubleTensor
+    ) -> torch.DoubleTensor:
+        """Modify the next-token logits so that only sequences that match
+        the regex can be generated.
+
+        Parameters
+        ----------
+        generated_token_ids
+            The token ids generated so far.
+        logits
+            The next-token logits.
+
+        """
+
+        if len(self.pstates) == 0:
+            self.pstates = [
+                ("REGEX", self.regex_fsm.initial, 0)
+                for _ in range(generated_token_ids.shape[0])
+            ]
+
+        if generated_token_ids.shape[-1] > 0:
+            new_pstates = []
+            for token_seq, (_, last_fsm_state, last_token_idx) in zip(
+                generated_token_ids,
+                self.pstates,
+            ):
+                # Get the tokens we haven't already processed,
+                readable_tokens = token_seq[last_token_idx:]
+                # excluding any EOS tokens.
+                not_eos_mask = [
+                    tk != self.model.tokenizer.eos_token_id for tk in readable_tokens
+                ]
+                readable_tokens = readable_tokens[not_eos_mask]
+                if len(readable_tokens) > 0:
+                    # If we previously ended with an EOS, we shouldn't be
+                    # getting/sampling any more non-EOS tokens.
+                    assert last_fsm_state > -1
+
+                    sequence = self.model.tokenizer.decode(readable_tokens)
+
+                    ((_, state_seq),) = find_partial_matches(
+                        self.regex_fsm,
+                        "".join(sequence),
+                        start_state=last_fsm_state,
+                    )
+                    pstate = (
+                        "REGEX",
+                        state_seq[-1],
+                        last_token_idx + len(sequence),
+                    )
+                else:
+                    pstate = ("REGEX", -1, last_token_idx)
+
+                new_pstates.append(pstate)
+
+            self.pstates = new_pstates
+
+        masks = []
+        for pstate in self.pstates:
+            mask = torch.full((len(self.model.tokenizer.vocabulary),), -math.inf)
+
+            if pstate[1] > -1:
+                next_support = self.pstate_to_vocab[pstate[:2]]
+            else:
+                next_support = [self.model.tokenizer.eos_token_id]
+
+            mask[next_support] = 0
+            masks.append(mask.unsqueeze(0))
+
+        mask = torch.concatenate(masks, dim=0)
+
+        return logits + mask
+
+
+def regex(model, regex_string: str, max_tokens: Optional[int] = None):
+    return Regex(model, regex_string, max_tokens)
+
+
+def integer(model, max_tokens: Optional[int] = None):
+    """Generate integers.
+
+    The regex used to constrain the generation optionally matches plus or
+    minus signs; leading zeros are allowed, as they are by Python's `int`.
+
+    """
+    return Regex(model, r"[-+]?\d+", max_tokens)
+
+
+def float(model, max_tokens: Optional[int] = None):
+    """Generate floating-point numbers.
+
+    The regex used to constrain the generation optionally matches plus or
+    minus signs, and forbids leading zeros (even if the `float` function in
+    Python allows them).
+
+    """
+    return Regex(model, r"[+-]?((0|[1-9][0-9]*)([.][0-9]*)?|[.][0-9]+)", max_tokens)
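The heart of `create_proposal` is an additive mask over the logits: token ids the FSM allows next keep their logit, all others are driven to minus infinity. A self-contained toy illustration (the vocabulary size and support set are made up):

    import math

    import torch

    logits = torch.tensor([0.5, 1.2, -0.3, 2.0, 0.1])
    next_support = [1, 3]  # token ids the FSM allows next

    mask = torch.full((5,), -math.inf)
    mask[next_support] = 0
    masked = logits + mask  # only ids 1 and 3 keep a finite logit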
diff --git a/tests/text/generate/test_integer.py b/tests/text/generate/test_integer.py
deleted file mode 100644
index d5ae7548f..000000000
--- a/tests/text/generate/test_integer.py
+++ /dev/null
@@ -1,70 +0,0 @@
-import math
-
-import torch
-
-from outlines.text.generate.integer import integer
-
-
-class Tokenizer:
-    eos_token = "<EOS>"
-    eos_token_id = 0
-    pad_token_id = -1
-    vocabulary = {"<EOS>": 0, "00": 1, "1": 2, "0.": 3, "431": 4, "a": 5}
-    tokens = list(vocabulary.keys())
-
-    def decode(self, token_ids):
-        decoded = []
-        for i in range(token_ids.shape[0]):
-            decoded.append("".join([self.tokens[idx] for idx in token_ids[i]]))
-
-        return decoded
-
-
-class Model:
-    tokenizer = Tokenizer()
-    device = "cpu"
-
-
-def test_integer_proposal():
-    model = Model()
-    generator = integer(model)
-
-    logits = torch.ones(len(model.tokenizer.vocabulary))
-    result = generator.create_proposal(torch.tensor([[]]), logits)
-    assert torch.equal(
-        result, torch.tensor([[-math.inf, -math.inf, 1.0, -math.inf, 1.0, -math.inf]])
-    )
-
-    logits = torch.ones(len(model.tokenizer.vocabulary))
-    result = generator.create_proposal(torch.tensor([[2]]), logits)
-    assert torch.equal(
-        result, torch.tensor([[-math.inf, 1.0, 1.0, -math.inf, 1.0, -math.inf]])
-    )
-
-    logits = torch.ones(len(model.tokenizer.vocabulary))
-    result = generator.create_proposal(torch.tensor([[4]]), logits)
-    assert torch.equal(
-        result, torch.tensor([[-math.inf, 1.0, 1.0, -math.inf, 1.0, -math.inf]])
-    )
-
-    logits = torch.ones(len(model.tokenizer.vocabulary))
-    result = generator.create_proposal(torch.tensor([[4], [2]]), logits)
-    assert torch.equal(
-        result,
-        torch.tensor(
-            [
-                [-math.inf, 1.0, 1.0, -math.inf, 1.0, -math.inf],
-                [-math.inf, 1.0, 1.0, -math.inf, 1.0, -math.inf],
-            ]
-        ),
-    )
-
-    logits = torch.ones((4, len(model.tokenizer.vocabulary)))
-    result = generator.create_proposal(torch.tensor([[]]), logits)
-    assert torch.equal(
-        result,
-        torch.tile(
-            torch.tensor([[-math.inf, -math.inf, 1.0, -math.inf, 1.0, -math.inf]]),
-            (4, 1),
-        ),
-    )
diff --git a/tests/text/generate/test_integration_transfomers.py b/tests/text/generate/test_integration_transfomers.py
index 307da6d44..5c055bc43 100644
--- a/tests/text/generate/test_integration_transfomers.py
+++ b/tests/text/generate/test_integration_transfomers.py
@@ -1,9 +1,10 @@
+import re
+
 import pytest
 import torch
 
 import outlines.models as models
-from outlines.text.generate.continuation import continuation
-from outlines.text.generate.integer import integer
+import outlines.text.generate as generate
 
 
 def test_transformers_integration_continuation():
@@ -12,15 +13,17 @@ def test_transformers_integration_continuation():
     model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
     model = models.transformers(model_name, device="cpu")
 
-    sequence = continuation(model)("Write a short sentence", rng=rng)
+    sequence = generate.continuation(model)("Write a short sentence", rng=rng)
     assert isinstance(sequence, str)
     assert model.tokenizer.eos_token not in sequence
 
-    sequence = continuation(model, max_tokens=10)("Write a short sentence", rng=rng)
+    sequence = generate.continuation(model, max_tokens=10)(
+        "Write a short sentence", rng=rng
+    )
     assert isinstance(sequence, str)
 
     prompts = ["Write a short sentence", "And another one"]
-    sequence = continuation(model, max_tokens=10)(prompts, rng=rng)
+    sequence = generate.continuation(model, max_tokens=10)(prompts, rng=rng)
     assert isinstance(sequence, list)
     assert len(sequence) == 2
     assert isinstance(sequence[0], str)
@@ -34,7 +37,19 @@ def test_transformers_integration_continuation_array_samples():
     model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
     model = models.transformers(model_name, device="cpu")
     prompts = ["Write a short sentence", "And another one"]
-    _ = continuation(model, max_tokens=10)(prompts, rng=rng, samples=3)
+    _ = generate.continuation(model, max_tokens=10)(prompts, rng=rng, samples=3)
+
+
+def test_transformers_various_regexes():
+    rng = torch.Generator()
+    rng.manual_seed(0)
+
+    model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
+    model = models.transformers(model_name, device="cpu")
+    prompt = "Write an email address"
+    regex_str = r"([a-z]{10})@([a-z]{5})\.([a-z]{3})"
+    sequence = generate.regex(model, regex_str)(prompt, rng=rng)
+    assert re.fullmatch(regex_str, sequence[len(prompt) :]) is not None
 
 
 def test_transformers_integration_integer():
@@ -44,7 +59,7 @@ def test_transformers_integration_integer():
     model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
     model = models.transformers(model_name, device="cpu")
     prompt = "Write a short sentence"
-    sequence = integer(model, max_tokens=10)(prompt, rng=rng)
+    sequence = generate.integer(model, max_tokens=10)(prompt, rng=rng)
 
     generated = sequence[len(prompt) :]
     assert generated[0] != 0
@@ -58,13 +73,27 @@ def test_transformers_integration_integer_array():
     model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
     model = models.transformers(model_name, device="cpu")
     prompts = ["Give me a number", "And another one"]
-    sequence = integer(model, max_tokens=10)(prompts, rng=rng)
+    sequence = generate.integer(model, max_tokens=10)(prompts, rng=rng)
     assert isinstance(sequence, list)
     assert len(sequence) == 2
     int(sequence[0][len(prompts[0]) :])
     int(sequence[1][len(prompts[1]) :])
 
 
+def test_transformers_integration_float():
+    rng = torch.Generator()
+    rng.manual_seed(0)
+
+    model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
+    model = models.transformers(model_name, device="cpu")
+    prompt = "Write a short sentence"
+    sequence = generate.float(model, max_tokens=10)(prompt, rng=rng)
+
+    generated = sequence[len(prompt) :]
+    assert generated[0] != 0
+    float(generated)
+
+
 def test_transformers_integration_with_pad_token():
     model_name = "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM"
     model = models.transformers(model_name, device="cpu")
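`test_transformers_various_regexes` generalizes directly to other patterns. A hypothetical date-shaped constraint (not part of this diff) would follow the same shape:

    import re

    import torch

    import outlines.models as models
    import outlines.text.generate as generate

    rng = torch.Generator()
    rng.manual_seed(0)

    model = models.transformers(
        "hf-internal-testing/tiny-random-GPTJForCausalLM", device="cpu"
    )
    prompt = "What is the date?"
    regex_str = r"20[0-9]{2}-(0[1-9]|1[0-2])-(0[1-9]|[12][0-9]|3[01])"
    sequence = generate.regex(model, regex_str)(prompt, rng=rng)
    assert re.fullmatch(regex_str, sequence[len(prompt) :]) is not None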
+        (r"\d+", 2, [-math.inf, -math.inf, 1.0, -math.inf, 1.0, -math.inf, -math.inf]),
+        (r"\d+\.", 3, [-math.inf, -math.inf, 1.0, 1.0, 1.0, -math.inf, -math.inf]),
+    ],
+)
+def test_regex_proposal(regex_string, valid_first_token, proposal):
+    model = Model()
+    generator = generate.regex(model, regex_string)
+
+    logits = torch.ones(len(model.tokenizer.vocabulary))
+    result = generator.create_proposal(torch.tensor([[]]), logits)
+    assert torch.equal(result.squeeze(), torch.tensor(proposal))
+    assert result.squeeze()[0] == -math.inf
+
+    # The EOS token can be generated once the FSM is in an accept state
+    result = generator.create_proposal(torch.tensor([[valid_first_token]]), logits)
+    assert result.squeeze()[0] == 1
+
+
+def test_regex_no_valid_transition():
+    model = Model()
+    with pytest.raises(ValueError, match="The vocabulary does not allow"):
+        generate.regex(model, "aw")
+
+
+@pytest.mark.parametrize(
+    "input_ids, proposal",
+    [
+        ([[]], [[-math.inf, 1.0, 1.0, -math.inf, 1.0, -math.inf, -math.inf]]),
+        ([[1]], [[-math.inf, -math.inf, 1.0, -math.inf, 1.0, -math.inf, -math.inf]]),
+        ([[4]], [[1.0, -math.inf, 1.0, -math.inf, 1.0, -math.inf, -math.inf]]),
+        (
+            [[4], [2]],
+            [
+                [1.0, -math.inf, 1.0, -math.inf, 1.0, -math.inf, -math.inf],
+                [1.0, -math.inf, 1.0, -math.inf, 1.0, -math.inf, -math.inf],
+            ],
+        ),
+        (
+            [[4, 0], [1, 2]],
+            [
+                [1.0, -math.inf, -math.inf, -math.inf, -math.inf, -math.inf, -math.inf],
+                [1.0, -math.inf, 1.0, -math.inf, 1.0, -math.inf, -math.inf],
+            ],
+        ),
+    ],
+)
+def test_integer_proposal(input_ids, proposal):
+    model = Model()
+    generator = generate.integer(model)
+
+    logits = torch.ones(len(model.tokenizer.vocabulary))
+    result = generator.create_proposal(torch.tensor(input_ids), logits)
+    assert torch.equal(
+        result,
+        torch.tensor(proposal),
+    )
+
+
+@pytest.mark.parametrize(
+    "input_ids, proposal",
+    [
+        ([[]], [[-math.inf, 1.0, 1.0, 1.0, 1.0, -math.inf, -math.inf]]),
+        ([[3]], [[1.0, -math.inf, 1.0, -math.inf, 1.0, -math.inf, -math.inf]]),
+    ],
+)
+def test_float_proposal(input_ids, proposal):
+    model = Model()
+    generator = generate.float(model)
+
+    logits = torch.ones(len(model.tokenizer.vocabulary))
+    result = generator.create_proposal(torch.tensor(input_ids), logits)
+    assert torch.equal(
+        result,
+        torch.tensor(proposal),
+    )
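To read the expected `proposal` rows in these tests: index `i` holds the additive mask for the token with id `i` in the toy vocabulary. A small decoding sketch for the integer case `[[4]]` (the token "431"):

    import math

    vocabulary = {"<EOS>": 0, "-": 1, "1": 2, "0.": 3, "431": 4, "a": 5, "A": 6}
    row = [1.0, -math.inf, 1.0, -math.inf, 1.0, -math.inf, -math.inf]

    allowed = [token for token, idx in vocabulary.items() if row[idx] != -math.inf]
    # allowed == ["<EOS>", "1", "431"]: "431" leaves the FSM in an accepting
    # state, so EOS becomes legal alongside further digits.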