Skip to content

Commit

Permalink
Implement AlignmentGuide
Browse files Browse the repository at this point in the history
  • Loading branch information
lapp0 committed Sep 21, 2024
1 parent 289ef5d commit 0d68474
Showing 1 changed file with 242 additions and 0 deletions.
242 changes: 242 additions & 0 deletions outlines/fsm/guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,37 @@ def is_final_state(self, state: Any) -> bool:
def copy(self) -> "Guide":
...

def accepts(self, token_ids: List[int], state=None) -> bool:
"""
Determine whether the sequence, `token_ids`, is accepted by the Guide.
`token_ids` doesn't need to complete the guide to be accepted.
"""
derived = self.derive(token_ids, state)
return derived is not None

def derive(self, token_ids: List[int], state=None) -> Union["Guide", None]:
if state is None:
state = self.initial_state
for token_id in token_ids:
instruction = self.get_next_instruction(state)

# determine if token_id allowed by instruction
if isinstance(instruction, Write):
raise NotImplementedError("TODO")
elif isinstance(instruction, Generate):
if (
instruction.tokens is not None
and token_id not in instruction.tokens
):
return None
else:
raise TypeError(f"Expected instruction, got {instruction}")

# advance state
state = self.get_next_state(state, token_id)

return state


class StopAtEOSGuide(Guide):
"""Guide to generate tokens until the EOS token has been generated."""
Expand Down Expand Up @@ -487,3 +518,214 @@ def must_terminate_state(self, state: CFGState) -> bool:
def copy(self) -> "CFGGuide":
"""Create a copy of the Guide."""
return CFGGuide(self.cfg_string, self.tokenizer)


@cache()
def build_vocab_prefix_map(tokenizer: "Tokenizer") -> Dict[str, Set[Tuple[str, Tuple]]]:
"""Build a map from token prefix to Set[Tuple[suffix, aligment_token_id, suffix_token_ids]]"""

# precompute the token ids of all vocab suffixes
suffixes = list(
{tok[i:] for tok in tokenizer.vocabulary for i in range(1, len(tok))}
)
encoded_suffixes, _ = tokenizer.encode(suffixes)
encoded_suffixes = [
[tok for tok in seq_ids if tok != tokenizer.pad_token_id]
for seq_ids in encoded_suffixes.tolist()
]
suffix_map = dict(zip(suffixes, map(tuple, encoded_suffixes)))
suffix_map[""] = tuple()

# compute prefix-suffix map for all tokens, s.t. prefix + suffix = token
prefix_map = collections.defaultdict(set)
for token, token_id in tokenizer.vocabulary.items():
for i in range(1, len(token) + 1):
prefix_map[token[:i]].add((token[i:], suffix_map[token[i:]]))
return prefix_map


AlignmentGuideState = collections.namedtuple(
"AlignmentGuideState", ["legal_path_map", "child_guide_state"]
)


class AlignmentGuide(Guide):
def __init__(
self, prompt: str, tokenizer: "Tokenizer", child_guide: Optional[Guide] = None
):
"""
Initialize the AlignmentGuide with a prompt, tokenizer, and an optional child guide.
Parameters
----------
prompt : str
The prompt text to be aligned with the generated tokens.
tokenizer : Tokenizer
Tokenizer used to align the prompt.
child_guide : Guide, optional
A guide to take control after alignment is complete. None -> Unconstrained after alignment
"""
self.prompt = prompt
self.tokenizer = tokenizer
self.child_guide = child_guide

alignment_seqs, child_guide_ids = self._get_alignment_sequences(
prompt, tokenizer, child_guide
)
alignment_prompt_ids, common_prompt_len = self._get_longest_common_prompt_ids(
alignment_seqs
)

self.alignment_prompt = self.tokenizer.decode(
[alignment_seqs[0, :common_prompt_len]]
)[0]

# calculate map of alignment_prompt continuation tokens -> child_guide advancement tokens
legal_paths = [
tuple([t for t in seq if t != tokenizer.pad_token_id])
for seq in alignment_seqs[:, common_prompt_len:].tolist()
]
legal_path_map = dict(zip(legal_paths, child_guide_ids))

self.initial_state = AlignmentGuideState(
legal_path_map=legal_path_map, child_guide_state=None
)

@staticmethod
def _get_alignment_sequences(
prompt: str, tokenizer: "Tokenizer", child_guide: Optional[Guide] = None
):
"""
Calculate all possible sequences which are valid with a prompt + child_guide
E.g. prompt="hello wo", child guide accepts "rld" -> tokenization ["hello", "world"] is valid
Returns tuple of (alignment_seqs, child_guide_ids) of same length
- alignment_seqs:
All token sequences which can represent `prompt` + start of generation. The last token
must represent the end of the prompt can extend beyond the prompt to start generation.
Sequences are only included if the start of generation portion is legal with child guide.
- child_guide_ids:
Token to send to the child guide to simulate the start of generation. In the example above
"world" is the last alignment seq token, therefore we must advance the state of the child
guide with the tokenization of "rld" in order to continue generation with the child guide.
"""
guide_accepts: Dict[
Tuple[int], bool
] = {} # cache of suffix acceptance for child_guide.accepts()

# prompts with alignment tokens at end
aligned_prompt_completions: List[str] = []
# tokens to feed child guide once alignment completes
child_guide_ids: List[Tuple] = []

# compute alignment seqs which are valid with prompt and child guide
for prefix, alignment_details in build_vocab_prefix_map(tokenizer).items():
if prompt.endswith(prefix):
for suffix, suffix_ids in alignment_details:
if child_guide is None:
aligned_prompt_completions.append(prompt + suffix)
child_guide_ids.append(tuple())
elif guide_accepts.setdefault(
suffix_ids, child_guide.accepts(suffix_ids)
):
aligned_prompt_completions.append(prompt + suffix)
child_guide_ids.append(suffix_ids)

alignment_seqs, _ = tokenizer.encode(aligned_prompt_completions)
return alignment_seqs, child_guide_ids

@staticmethod
def _get_longest_common_prompt_ids(alignment_seqs):
"""
Among all candidate prompt alignment seqs, get the longest shared prefix and their length
"""
# get longest common prefix among alignment sequences, which will form our alignment prompt
common = (
(alignment_seqs.unsqueeze(1) == alignment_seqs.unsqueeze(0))
.all(0)
.cumprod(1)
)
common_len = common.sum(1).max().item()
return alignment_seqs[0, :common_len], common_len

def get_next_instruction(self, state: AlignmentGuideState) -> Instruction:
"""
Return the next set of valid tokens for generation based on the current state.
If alignment hasn't completed:
tokens which continue one of the candidate alignment paths are legal
If alignment has completed:
get instruction from the child guide
"""
if state.legal_path_map is not None:
return Generate(
sorted({token_ids[0] for token_ids in state.legal_path_map.keys()})
)
elif self.child_guide is None:
return Generate(None)
else:
return self.child_guide.get_next_instruction(state.child_guide_state)

def get_next_state(
self, state: AlignmentGuideState, token_id: int
) -> AlignmentGuideState:
"""
Get AlignmentGuideState advanced by token ID.
If alignment has completed:
get instruction from the child guide
If alignment hasn't completed:
Filter out alignment paths which don't start with token_id
Remove First token from remaining paths
If advancing the state completes alignment:
Advance the child_guide state
"""
if state.legal_path_map is None:
if self.child_guide is not None:
return AlignmentGuideState(
legal_path_map=None,
child_guide_state=self.child_guide.get_next_state(
state.child_guide_state, token_id
),
)
else:
return AlignmentGuideState(None, None)
else:
next_state_legal_path_map = {
key[1:]: value
for key, value in state.legal_path_map.items()
if key[0] == token_id
}
# if none remaining, advance the child guide
if not any(next_state_legal_path_map):
if self.child_guide is not None:
child_guide_advancement_ids = next(
iter(next_state_legal_path_map.values())
)
return AlignmentGuideState(
legal_path_map=None,
child_guide_state=self.child_guide.derive(
child_guide_advancement_ids, state.child_guide_state
),
)
else:
return AlignmentGuideState(None, None)

# if paths remaining, return advanced legal_path_map
else:
return AlignmentGuideState(
legal_path_map=next_state_legal_path_map,
child_guide_state=state.child_guide_state,
)

def is_final_state(self, state: AlignmentGuideState) -> bool:
if state.legal_path_map is not None:
return False
elif self.child_guide is None:
return True
else:
return self.child_guide.is_final_state(state.child_guide_state)

def copy(self):
"""AlignmentGuide isn't mutated"""
return self

0 comments on commit 0d68474

Please sign in to comment.