diff --git a/outlines/processors/base_logits_processor.py b/outlines/processors/base_logits_processor.py
index feedf5253..9a52abecd 100644
--- a/outlines/processors/base_logits_processor.py
+++ b/outlines/processors/base_logits_processor.py
@@ -75,10 +75,10 @@ def __call__(

         # Guarantee passed as 2D Tensors, then covert back to original (1D or 2D) shape
         if len(torch_logits.shape) == 2:
-            processed_logits = self.process_logits(input_ids.tolist(), torch_logits)
+            processed_logits = self.process_logits(input_ids, torch_logits)
         elif len(torch_logits.shape) == 1:
             processed_logits = self.process_logits(
-                [input_ids.tolist()], torch_logits.unsqueeze(0)
+                input_ids.unsqueeze(0), torch_logits.unsqueeze(0)
             ).squeeze(0)

         # return logits as passed array type
diff --git a/outlines/processors/structured.py b/outlines/processors/structured.py
index d9a912c89..e3b9e60d3 100644
--- a/outlines/processors/structured.py
+++ b/outlines/processors/structured.py
@@ -70,7 +70,7 @@ def __init__(self, tokenizer: "Tokenizer", guide: Guide):
         self._seq_start_idx = None

     def process_logits(
-        self, input_ids: List[List[int]], logits: torch.Tensor
+        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
     ) -> torch.Tensor:
         """Use the Guide to bias the logits before sampling the next token.

@@ -93,19 +93,32 @@ def process_logits(
         for seq_ids in input_ids:
             gen_ids = seq_ids[self._seq_start_idx :]
-            curr_state_key = hash(tuple(gen_ids))
+            curr_state_key = hash(tuple(gen_ids.tolist()))

             if curr_state_key not in self._guide_states:
-                prev_state = self._guide_states[hash(tuple(gen_ids[:-1]))]
-                curr_state = self.guide.get_next_state(prev_state, gen_ids[-1])
+                prev_state = self._guide_states[hash(tuple(gen_ids[:-1].tolist()))]
+                curr_state = self.guide.get_next_state(prev_state, gen_ids[-1].item())
                 self._guide_states[curr_state_key] = curr_state

             sequence_states.append(self._guide_states[curr_state_key])

         mask = torch.ones_like(logits, dtype=torch.bool)
+
+        allowed_tokens_batch = []
+        batch_indices = []
         for i, guide_state in enumerate(sequence_states):
-            allowed_tokens = self.guide.get_next_instruction(guide_state).tokens
-            mask[i, allowed_tokens] = False
+            allowed_tokens = self.guide.get_next_instruction(guide_state).tokens.to(
+                mask.device, non_blocking=True
+            )
+            allowed_tokens_batch.append(allowed_tokens)
+            batch_indices.append(
+                torch.full_like(allowed_tokens, i)
+            )  # Store batch index for each allowed token
+
+        allowed_tokens_concat = torch.cat(allowed_tokens_batch)
+        batch_indices_concat = torch.cat(batch_indices)
+
+        mask[batch_indices_concat, allowed_tokens_concat] = False

         logits.masked_fill_(mask, float("-inf"))

         return logits
@@ -202,7 +215,7 @@ def __init__(self, cfg_str: str, tokenizer: "Tokenizer"):
         super().__init__(tokenizer=tokenizer, guide=cfg_guide)

     def process_logits(
-        self, input_ids: List[List[int]], logits: torch.Tensor
+        self, input_ids: torch.LongTensor, logits: torch.Tensor
     ) -> torch.Tensor:
         """Same behavior as GuideLogitsProcessor, but uses rejection sampling"""
         if self._seq_start_idx is None:
@@ -212,11 +225,11 @@ def process_logits(

         for seq_ids in input_ids:
             gen_ids = seq_ids[self._seq_start_idx :]
-            curr_state_key = hash(tuple(gen_ids))
+            curr_state_key = hash(tuple(gen_ids.tolist()))

             if curr_state_key not in self._guide_states:
-                prev_state = self._guide_states[hash(tuple(gen_ids[:-1]))]
-                curr_state = self.guide.get_next_state(prev_state, gen_ids[-1])
+                prev_state = self._guide_states[hash(tuple(gen_ids[:-1].tolist()))]
+                curr_state = self.guide.get_next_state(prev_state, gen_ids[-1].item())
                 self._guide_states[curr_state_key] = curr_state

             sequence_states.append(self._guide_states[curr_state_key])
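For context on the second hunk: it replaces the per-row `mask[i, allowed_tokens] = False` loop with a single advanced-indexing write, concatenating the allowed-token ids of all sequences into one flat column-index tensor and pairing it with a matching row-index tensor built via `torch.full_like`. Below is a minimal, self-contained sketch of that pattern; the vocabulary size and the per-sequence allowed-token lists are made up for illustration and stand in for real `Guide` instructions.

```python
import torch

# Toy stand-ins: a batch of 3 sequences over a 10-token vocabulary
# (both numbers are hypothetical, chosen only for this sketch).
vocab_size = 10
logits = torch.randn(3, vocab_size)

# Hypothetical per-sequence allowed token ids, standing in for
# self.guide.get_next_instruction(state).tokens in the diff above.
allowed_per_seq = [
    torch.tensor([1, 4]),
    torch.tensor([0, 2, 7]),
    torch.tensor([5]),
]

mask = torch.ones_like(logits, dtype=torch.bool)

# Instead of one `mask[i, allowed] = False` assignment per row, build flat
# (row, column) coordinate tensors and scatter False in a single write.
allowed_tokens_concat = torch.cat(allowed_per_seq)
batch_indices_concat = torch.cat(
    [torch.full_like(allowed, i) for i, allowed in enumerate(allowed_per_seq)]
)
mask[batch_indices_concat, allowed_tokens_concat] = False

# Everything still masked gets -inf, so sampling can only pick allowed ids.
logits.masked_fill_(mask, float("-inf"))
```

Collapsing the loop into one indexed assignment keeps the masking to a single vectorized operation regardless of batch size, and moving the token tensors with `non_blocking=True` (as the diff does) allows the host-to-device copies to overlap with compute when the source tensors are in pinned memory.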