Implement prompt/generation alignment #161

Open
rlouf opened this issue Jun 22, 2023 · 5 comments · May be fixed by #531 or #1166
Labels
enhancement text Linked to text generation
Milestone

Comments

@rlouf
Member

rlouf commented Jun 22, 2023

Guidance implements a method called token healing, which corrects for the quirks introduced by the subword encodings used by modern tokenizers, such as BPE. See this notebook for a thorough explanation of why this is necessary. The implementation for Transformers models is here.

It consists of backtracking one or several tokens and starting generation with the constraint that the output reproduces the text of the removed tokens. This can be integrated into the __call__ method of the Sequence class.
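
A minimal sketch of the backtracking step in plain Python, using a made-up toy vocabulary. `VOCAB` and `heal` are hypothetical names, not Guidance's or Outlines' API. The sketch removes the last prompt token(s) and only allows first tokens that start with the removed text. Tokens that are a strict prefix of the removed text (e.g. " d" for " day") would require constraining several generation steps, so they are ignored here.

```python
# Toy vocabulary (hypothetical): token id -> decoded token string
VOCAB = {0: "This", 1: " is", 2: " a", 3: " new", 4: " day", 5: " days",
         6: " daylight", 7: " d", 8: "ay", 9: " dog"}


def heal(prompt_ids: list[int], vocab: dict[int, str],
         n_backtrack: int = 1) -> tuple[list[int], str, list[int]]:
    """Drop the last `n_backtrack` prompt tokens and return the kept ids,
    the removed text, and the ids allowed as the first generated token."""
    kept, removed = prompt_ids[:-n_backtrack], prompt_ids[-n_backtrack:]
    removed_text = "".join(vocab[i] for i in removed)
    # The first generated token must reproduce the removed text, so it has
    # to start with it (it may extend past it, e.g. " day" -> " daylight").
    allowed = [i for i, s in vocab.items() if s.startswith(removed_text)]
    return kept, removed_text, allowed


kept, removed_text, allowed = heal([0, 1, 2, 3, 4], VOCAB)  # "This is a new day"
print(kept, repr(removed_text), [VOCAB[i] for i in allowed])
# [0, 1, 2, 3] ' day' [' day', ' days', ' daylight']
```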

@rlouf rlouf added text Linked to text generation enhancement labels Jun 22, 2023
@rlouf rlouf added this to the 0.1 milestone Jul 13, 2023
@rlouf
Member Author

rlouf commented Jul 15, 2023

First thoughts on how this could be achieved. Let's consider the prompt "This is a prompt".

  1. We loop over the entire vocabulary and use partial matching to find the tokens that cross the prompt boundary, i.e. the ones that start with a suffix of the prompt but are longer than this suffix. This gives us a list of candidate tokens.
  2. For each of these tokens, match the part that lies within the boundary and strip it from the end of the prompt string.
  3. Tokenize the prompt from which we have removed the parts within the boundary, and generate a mask that only allows the previously found token(s).
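
Steps 1 and 2 can be sketched with plain string operations, assuming a toy vocabulary of already-decoded token strings. The helper name is hypothetical, and simple prefix checks on prompt suffixes stand in for FSM-based partial matching. Step 3 needs a real tokenizer and is left out.

```python
from collections import defaultdict


def boundary_crossing_tokens(prompt: str, vocab: list[str]) -> dict[str, list[str]]:
    """Map each truncated prompt to the tokens that could start generation from it."""
    candidates = defaultdict(list)
    for token in vocab:
        for start in range(len(prompt)):
            overlap = prompt[start:]  # non-empty suffix of the prompt
            # Step 1: the token starts with the suffix and extends past the prompt end
            if len(token) > len(overlap) and token.startswith(overlap):
                # Step 2: strip the part of the token that lies within the prompt
                candidates[prompt[:start]].append(token)
    return dict(candidates)


vocab = ["This", " is", " a", " prompt", " prompts", "pt", "pts", "s", " p"]
for truncated, tokens in boundary_crossing_tokens("This is a prompt", vocab).items():
    print(repr(truncated), tokens)
# 'This is a' [' prompts']
# 'This is a prom' ['pts']
```

The toy output already has two truncation points, because two different tokens cross the boundary at different positions. Only suffixes up to the length of the longest vocabulary token need to be checked, which bounds the inner loop.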

What do we do when (2) gives several matches?

@arunpatro
Contributor

In case of multiple matches, we can rank the candidates by the integer id of the token(s). A smaller id implies a more frequently occurring token, since BPE tokenizers typically assign ids in merge order.

Alternatively, how about evaluating the log-probabilities of the candidate sequences? I would expect tokens with a smaller id to have a higher log-probability.
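
The two ranking strategies can be sketched as follows. The token ids and log-probabilities are made-up toy values, and `logprob` is a hypothetical callable standing in for a model call. For candidates that come from different truncation points, one would compare the log-probabilities of the full sequences rather than of single tokens.

```python
from typing import Callable


def rank_by_id(candidates: list[tuple[int, str]]) -> list[tuple[int, str]]:
    """Smaller id first, using the id as a proxy for BPE merge order."""
    return sorted(candidates, key=lambda c: c[0])


def rank_by_logprob(candidates: list[tuple[int, str]],
                    logprob: Callable[[int], float]) -> list[tuple[int, str]]:
    """Highest log-probability first; `logprob` maps a token id to a log-probability."""
    return sorted(candidates, key=lambda c: logprob(c[0]), reverse=True)


candidates = [(4, " daylight"), (1, " day"), (2, " days")]  # toy ids
toy_logprobs = {1: -2.3, 2: -1.1, 4: -6.0}  # made-up values, not model output
print(rank_by_id(candidates))
# [(1, ' day'), (2, ' days'), (4, ' daylight')]
print(rank_by_logprob(candidates, toy_logprobs.get))
# [(2, ' days'), (1, ' day'), (4, ' daylight')]
```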

@rlouf
Member Author

rlouf commented Jul 17, 2023

Here is a quick outline of a solution that uses outlines.text.parsing.find_partial_matches and only loops through the vocabulary once:

from collections import defaultdict

import interegular
from outlines.text.parsing import find_partial_matches
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
vocabulary = tokenizer.get_vocab()

# The BPE tokenizer encodes spaces as Ġ, so anything based on string matching
# has to work on the decoded token strings instead.
sorted_vocabulary = [
    tokenizer.convert_tokens_to_string([k]) for k, v in sorted(vocabulary.items(), key=lambda kv: kv[1])
]

prompt = "This is a new day"
fsm = interegular.parse_pattern(prompt).to_fsm()

tokenized_prompt = tokenizer.encode(prompt)
prompt_tokens = [tokenizer.decode(t) for t in tokenized_prompt]
# Character offset at which each prompt token starts. `rfind` fails if a token string
# appears several times; tracking the position in the prompt string would be more robust.
token_idx_in_prompt = [prompt.rfind(pt) for pt in prompt_tokens]

found = defaultdict(list)
for vocab_str in sorted_vocabulary:
    pmatch = find_partial_matches(
        fsm,
        vocab_str
    )
    if pmatch != set():
        end_idx, states = pmatch.pop()  # We need to loop over the matches instead
        if end_idx is not None and states[-1] == len(prompt):
            if states[0] in token_idx_in_prompt:
                found[token_idx_in_prompt.index(states[0])].append(vocab_str)

print(dict(found))
# {4: [' day', ' days', ' daylight', ' daytime']}

We then need to back up one token, generate the next token using a mask built from the list above, and then generate the rest of the sequence as usual.
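
A minimal sketch of turning a list of allowed strings into a mask for the first generated token, in plain Python. It assumes token ids are contiguous from 0, and uses a toy vocabulary and made-up logits. The function names are hypothetical.

```python
import math


def build_first_token_mask(vocab: dict[str, int], allowed: list[str]) -> list[bool]:
    """Return a boolean mask over token ids that is True only for `allowed` strings."""
    mask = [False] * len(vocab)  # assumes ids are 0..len(vocab)-1
    for token_str in allowed:
        mask[vocab[token_str]] = True
    return mask


def apply_mask(logits: list[float], mask: list[bool]) -> list[float]:
    """Set the logits of disallowed tokens to -inf so they can never be sampled."""
    return [x if keep else -math.inf for x, keep in zip(logits, mask)]


# Toy example: the allowed strings play the role of `found[4]` above
vocab = {" day": 0, " days": 1, " dog": 2, " daytime": 3}
mask = build_first_token_mask(vocab, [" day", " days", " daytime"])
masked = apply_mask([1.0, 0.5, 3.0, 2.0], mask)  # made-up logits
next_id = max(range(len(masked)), key=masked.__getitem__)  # greedy choice
print(masked, next_id)
# [1.0, 0.5, -inf, 2.0] 3
```

In practice it is safer to collect token ids directly rather than decoded strings, since several byte-level tokens can decode to the same string.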

Now I understand why I intuitively leave a space at the end of my prompts: it tells the model that it shouldn't complete the word. As you can see, `found` contains not only " day" but also completions like " daytime".

However, if you run the code above with the prompt "This is a new day ", you can see that it backs up one token (the trailing whitespace) and suggests the ~33,000 tokens that start with whitespace as potential continuations.

@rlouf
Member Author

rlouf commented Jul 17, 2023

Another fun one:

import outlines.models as models
import outlines.text.generate as generate

model = models.transformers("gpt2")

prompt = "Is regex-guided generation useful? "
unguided = generate.continuation(model, max_tokens=30)(prompt)
guided = generate.regex(model, r"(Yes|No)", max_tokens=30)(prompt)

print(guided)
# Is regex-guided generation useful? No

prompt = "Is regex-guided generation useful?"
guided = generate.regex(model, r"(Yes|No)", max_tokens=30)(prompt)
print(guided)
# Is regex-guided generation useful?No

prompt = "Is regex-guided generation useful?"
guided = generate.regex(model, r"( )?(Yes|No)", max_tokens=30)(prompt)
print(guided)
# Is regex-guided generation useful? No

print([k for k in model.tokenizer.vocabulary.keys() if k.endswith("Yes")])
# ['Yes', 'ĠYes']

The "right" prompt here would leave a whitespace after the question, since we don't want "useful?" to be completed. However, " Yes" is probably the most likely answer, since that is how "Is regex-guided generation useful? Yes" would typically be tokenized. So we need to back up one token and add the removed character to the regex. In this case, we should be able to match " Yes", " No", and also a run of whitespace followed by "Yes" or "No".
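
For the trailing-whitespace case, the idea can be sketched with the standard `re` module. `align_regex` is a hypothetical helper. It only moves trailing whitespace from the prompt into the pattern, whereas the general approach would back up the prompt's last token as returned by the tokenizer.

```python
import re


def align_regex(prompt: str, pattern: str) -> tuple[str, str]:
    """Move the prompt's trailing whitespace into the regex (trailing-whitespace case only)."""
    truncated = prompt.rstrip()
    removed = prompt[len(truncated):]
    if not removed:
        return prompt, pattern
    # Generation must first reproduce the removed text; `\s*` additionally
    # allows a run of extra whitespace before the answer.
    return truncated, re.escape(removed) + r"\s*(?:" + pattern + ")"


prompt, pattern = align_regex("Is regex-guided generation useful? ", r"(Yes|No)")
print(repr(prompt))  # 'Is regex-guided generation useful?'
for answer in [" Yes", " No", "   Yes", "Yes", " Maybe"]:
    print(repr(answer), bool(re.fullmatch(pattern, answer)))
# ' Yes' True
# ' No' True
# '   Yes' True
# 'Yes' False
# ' Maybe' False
```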

@rlouf rlouf changed the title Implement token healing Implement token/prompt alignment Jul 17, 2023
@RobinPicard RobinPicard linked a pull request Jan 11, 2024 that will close this issue
@RobinPicard
Contributor

Do you think the right strategy would be to still create the regex_fsm when RegexFSM is initialized, but to only create the states_to_token_maps once the generation function is called with the prompt? We would first modify the regex_fsm to include the states corresponding to the last token of the prompt. The downside is that this adds some overhead to every call of the generation function.
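
A rough sketch of the proposed split, with heavy simplifications. A compiled `re` pattern stands in for `regex_fsm`. The per-prompt token map is a toy that only admits single tokens which reproduce the removed prompt text and then fully match the pattern. `LazyTokenMapFSM` and its methods are hypothetical. Caching on the removed text is one way to limit the overhead mentioned above, because the map is built once per distinct prompt ending rather than once per call.

```python
import re


class LazyTokenMapFSM:
    """Hypothetical sketch: compile the pattern at init, build token maps at call time."""

    def __init__(self, pattern: str, vocab: dict[str, int]):
        self.regex = re.compile(pattern)  # stand-in for regex_fsm, built once
        self.vocab = vocab
        self._cache: dict[str, frozenset] = {}  # removed prompt text -> allowed ids
        self.builds = 0  # counts how often the expensive step actually runs

    def _build(self, removed_text: str) -> frozenset:
        self.builds += 1
        # Toy stand-in for states_to_token_maps: single tokens that reproduce the
        # removed prompt text and whose remainder fully matches the pattern.
        return frozenset(
            i for s, i in self.vocab.items()
            if s.startswith(removed_text) and self.regex.fullmatch(s[len(removed_text):])
        )

    def token_map_for(self, removed_text: str) -> frozenset:
        # Pay the construction overhead once per distinct prompt ending
        if removed_text not in self._cache:
            self._cache[removed_text] = self._build(removed_text)
        return self._cache[removed_text]


vocab = {" Yes": 0, " No": 1, "Yes": 2, "No": 3, " Maybe": 4}
fsm = LazyTokenMapFSM(r"(Yes|No)", vocab)
print(sorted(fsm.token_map_for(" ")))  # prompt ended with " "
print(sorted(fsm.token_map_for("")))   # nothing removed
fsm.token_map_for(" ")                 # served from the cache
print(fsm.builds)
# [0, 1]
# [2, 3]
# 2
```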

@rlouf rlouf linked a pull request Jan 27, 2024 that will close this issue
@rlouf rlouf changed the title Implement token/prompt alignment Implement prompt/generation alignment Feb 11, 2024
@lapp0 lapp0 linked a pull request Sep 22, 2024 that will close this issue
3 participants