From 3508cd98dd6de3e5acfc9a9e3bf1e4ddb6d7ac38 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Louf?=
Date: Sat, 23 Dec 2023 14:30:07 +0100
Subject: [PATCH] Rename and document new methods

---
 outlines/generate/api.py | 117 +++++++++++++++++++++++++++++----------
 1 file changed, 88 insertions(+), 29 deletions(-)

diff --git a/outlines/generate/api.py b/outlines/generate/api.py
index d3c58278d..3be3af790 100644
--- a/outlines/generate/api.py
+++ b/outlines/generate/api.py
@@ -18,12 +18,13 @@ class SequenceGenerator:
-    def __init__(self, fsm, model, sampler, device, stop_at=None, max_tokens=None):
+    def __init__(self, fsm, model, sampler, device, max_tokens=None, stop_at=None):
         self.generate_token = token_generator(model, sampler)
         self.fsm = fsm
         self.tokenizer = model.tokenizer
         self.device = device
         self.max_tokens = max_tokens
+
         if isinstance(stop_at, str):
             stop_at = [stop_at]
         self.stop_sequences = stop_at
@@ -49,21 +50,51 @@ def get_generated_token_ids(
         prompts: List[str],
         last_state: GenerationState,
     ) -> List[torch.Tensor]:
-        """Give the tokens generated (so the current sequences without the initial user prompts)"""
-        # Get the number of tokens in the prompts
+        """Get the tokens generated so far.
+
+        Parameters
+        ----------
+        init_state
+            The initial state of the generation.
+        prompts
+            The prompts passed to the generator.
+        last_state
+            The current state of the generation.
+
+        Returns
+        -------
+        A list of tensors that contain the token ids generated so far.
+
+        """
         prompt_token_ids = init_state[0]
         prompt_lengths = [len(prompt_token_ids[i]) for i in range(len(prompts))]
-        # Remove the prompts from the generated sequences
+
         token_ids = [
            cur_token_ids[length:]
            for cur_token_ids, length in zip(last_state.token_ids, prompt_lengths)
        ]
+
         return token_ids

-    def is_stop_sequence_reached(
+    def is_stop_sequence_found(
         self, generated_sequences: List[str], stop_sequences: List[str]
     ) -> bool:
-        """True if at least one of the stop sequences is found in each generated sequence"""
+        """Determine whether one of the stop sequences has been generated.
+
+        Parameters
+        ----------
+        generated_sequences
+            The list of sequences generated so far.
+        stop_sequences
+            The list of sequences that stop the generation when
+            found.
+
+        Returns
+        -------
+        True if at least one of the stop sequences has been found in each
+        generated sequence.
+
+        """
         return all(
             [
                 any([seq in generated for seq in stop_sequences])
@@ -71,8 +102,18 @@ def is_stop_sequence_reached(
             ]
         )

-    def format_sequence(self, sequence: str, stop_sequences: List[str]) -> str:
-        """Format the text sequence generated before returning it to the user"""
+    def strip_stop_sequences(self, sequence: str, stop_sequences: List[str]) -> str:
+        """Remove the stop sequences from a generated sequence.
+
+        Parameters
+        ----------
+        sequence
+            One of the generated sequences.
+        stop_sequences
+            The list of sequences that stop the generation when
+            found.
+
+        """
         if stop_sequences:
             match_indexes = [sequence.find(seq) for seq in stop_sequences]
             if any([index != -1 for index in match_indexes]):
@@ -83,10 +124,25 @@ def format_sequence(self, sequence: str, stop_sequences: List[str]) -> str:
                 : match_indexes[min_match_index_pos]
                 + len(stop_sequences[min_match_index_pos])
             ]
-        return self.structure_sequence(sequence)

-    def structure_sequence(self, sequence: str) -> str:
-        """Modify the structure/type of the sequence, is overriden in some generate functions"""
+        return sequence
+
+    def format_sequence(self, sequence: str) -> str:
+        """Translate the generated sequence to another type.
+
+        This method is for instance overridden when generating JSON to either
+        return a dictionary or a Pydantic model.
+
+        Parameters
+        ----------
+        sequence
+            A generated sequence.
+
+        Returns
+        -------
+        The formatted sequence.
+
+        """
         return sequence

     def __call__(
@@ -135,9 +191,10 @@ def __call__(
         if isinstance(stop_at, str):
             stop_at = [stop_at]

-        stop_sequences = stop_at or self.stop_sequences
+        stop_sequences = stop_at or self.stop_sequences
         max_tokens = max_tokens or self.max_tokens
+        num_sequences = len(prompts)

         if rng is None:
             rng = torch.Generator(device=self.device)
@@ -146,7 +203,6 @@ def __call__(
         init_state = init_generator_state(
             self.tokenizer, self.device, prompts, kv_cache
         )
-        num_sequences = len(prompts)
         init_fsm_states = [FSMState(0) for _ in range(num_sequences)]

         states = sequence_generator(
@@ -157,25 +213,28 @@ def __call__(
             try:
                 last_state = next(states)
                 if max_tokens or stop_sequences:
-                    token_ids = self.get_generated_token_ids(
+                    generated_token_ids = self.get_generated_token_ids(
                         init_state, prompts, last_state
                     )
-                    if max_tokens and len(token_ids[0]) >= max_tokens:
+                    if max_tokens and len(generated_token_ids[0]) >= max_tokens:
                         break
-                    if stop_sequences and self.is_stop_sequence_reached(
-                        self.tokenizer.decode(token_ids), stop_sequences
+                    if stop_sequences and self.is_stop_sequence_found(
+                        self.tokenizer.decode(generated_token_ids), stop_sequences
                     ):
                         break
             except StopIteration:
                 break

-        token_ids = self.get_generated_token_ids(init_state, prompts, last_state)
-        generated = self.tokenizer.decode(token_ids)
-
+        generated_token_ids = self.get_generated_token_ids(
+            init_state, prompts, last_state
+        )
+        generated = self.tokenizer.decode(generated_token_ids)
+        stripped = [
+            self.strip_stop_sequences(sequence, stop_sequences)
+            for sequence in generated
+        ]
         try:
-            formatted = [
-                self.format_sequence(sequence, stop_sequences) for sequence in generated
-            ]
+            formatted = [self.format_sequence(sequence) for sequence in stripped]
         except pyjson.decoder.JSONDecodeError:
             raise TypeError(
                 "Could not format the output of the model into a dictionary or a Pydantic model."
@@ -277,7 +336,7 @@ def token_generator() -> Iterator[Union[List[str], str]]:
                 if stop_sequences:
                     is_stop_at_reached = [
                         stop
-                        or self.is_stop_sequence_reached(
+                        or self.is_stop_sequence_found(
                             [generated_sequence], stop_sequences
                         )
                         for generated_sequence, stop in zip(
@@ -301,7 +360,7 @@ def text(
     device = model.device

     generator = SequenceGenerator(
-        fsm, model, sampler, device, stop_at=stop_at, max_tokens=max_tokens
+        fsm, model, sampler, device, max_tokens=max_tokens, stop_at=stop_at
     )

     return generator
@@ -332,7 +391,7 @@ def cfg(
     device = model.device

     generator = SequenceGenerator(
-        fsm, model, sampler, device, stop_at=stop_at, max_tokens=max_tokens
+        fsm, model, sampler, device, max_tokens=max_tokens, stop_at=stop_at
     )

     return generator
@@ -368,17 +427,17 @@ def json(
         schema = pyjson.dumps(schema_object.model_json_schema())
         regex_str = build_regex_from_object(schema)
         generator = regex(model, regex_str, max_tokens, sampler)
-        generator.structure_sequence = lambda x: schema_object.parse_raw(x)
+        generator.format_sequence = lambda x: schema_object.parse_raw(x)
     elif callable(schema_object):
         schema = pyjson.dumps(get_schema_from_signature(schema_object))
         regex_str = build_regex_from_object(schema)
         generator = regex(model, regex_str, max_tokens, sampler)
-        generator.structure_sequence = lambda x: pyjson.loads(x)
+        generator.format_sequence = lambda x: pyjson.loads(x)
     elif isinstance(schema_object, str):
         schema = schema_object
         regex_str = build_regex_from_object(schema)
         generator = regex(model, regex_str, max_tokens, sampler)
-        generator.structure_sequence = lambda x: pyjson.loads(x)
+        generator.format_sequence = lambda x: pyjson.loads(x)
     else:
         raise ValueError(
             f"Cannot parse schema {schema_object}. The schema must be either "
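
Note for reviewers: a standalone sketch of the truncation rule that `strip_stop_sequences` implements. The `min_match_index_*` lines are unchanged context that falls between the hunks above, so they are reconstructed here from the surrounding code; this plain-function copy is for illustration only, not part of the patch.

    from typing import List, Optional

    def strip_stop_sequences(sequence: str, stop_sequences: Optional[List[str]]) -> str:
        # Cut the sequence right after the earliest stop sequence found.
        # Note that the matched stop sequence itself is kept in the output.
        if stop_sequences:
            match_indexes = [sequence.find(seq) for seq in stop_sequences]
            if any(index != -1 for index in match_indexes):
                min_match_index_value = min(i for i in match_indexes if i != -1)
                min_match_index_pos = match_indexes.index(min_match_index_value)
                sequence = sequence[
                    : match_indexes[min_match_index_pos]
                    + len(stop_sequences[min_match_index_pos])
                ]
        return sequence

    assert strip_stop_sequences("Hello world. Goodbye.", ["."]) == "Hello world."
    assert strip_stop_sequences("no stop here", ["."]) == "no stop here"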
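A quick end-to-end sketch of how the renamed hooks compose, using the `text` and `json` constructors touched by this patch. The `models.transformers("gpt2")` call and the prompts are illustrative assumptions; the `max_tokens`/`stop_at` keywords and the `format_sequence` override come from the diff itself.

    from pydantic import BaseModel

    import outlines.generate as generate
    import outlines.models as models

    model = models.transformers("gpt2")  # assumed model constructor

    # Stop sequences are stripped by `strip_stop_sequences` before
    # `format_sequence` runs; for plain text, `format_sequence` is the identity.
    generator = generate.text(model, max_tokens=64, stop_at=".")
    sentence = generator("Write one sentence:")

    # `json` overrides `format_sequence` (formerly `structure_sequence`), so the
    # generated string is parsed into a Pydantic model before being returned.
    class Character(BaseModel):
        name: str
        age: int

    character = generate.json(model, Character)("Create a character:")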