Commit

construct logits mask in batch operation
lapp0 committed Oct 7, 2024
1 parent 36875a0 commit 094af23
Showing 2 changed files with 25 additions and 12 deletions.
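
In short: the mask over disallowed tokens used to be cleared one row at a time inside a Python loop (mask[i, allowed_tokens] = False per sequence); this commit instead collects every sequence's allowed token ids, concatenates them, and clears the whole mask with a single advanced-indexing assignment. Below is a minimal sketch of that idea, with a hypothetical build_logits_mask helper standing in for the relevant part of GuideLogitsProcessor.process_logits:

import torch

# A minimal sketch of the batched mask construction, assuming `logits` is a
# (batch, vocab) float tensor and `allowed_per_seq` is a hypothetical list of
# 1D LongTensors of allowed token ids, one entry per sequence in the batch.
def build_logits_mask(logits: torch.Tensor, allowed_per_seq: list) -> torch.Tensor:
    mask = torch.ones_like(logits, dtype=torch.bool)

    allowed_tokens_batch = []
    batch_indices = []
    for i, allowed in enumerate(allowed_per_seq):
        allowed = allowed.to(mask.device, non_blocking=True)
        allowed_tokens_batch.append(allowed)
        # Repeat the row index once per allowed token so the two concatenated
        # index tensors line up element-wise.
        batch_indices.append(torch.full_like(allowed, i))

    # One advanced-indexing write for the whole batch instead of one write per row.
    mask[torch.cat(batch_indices), torch.cat(allowed_tokens_batch)] = False
    return logits.masked_fill(mask, float("-inf"))

# Example: two sequences over an 8-token vocabulary.
biased = build_logits_mask(torch.randn(2, 8), [torch.tensor([1, 3]), torch.tensor([0, 2, 7])])

The batched form replaces a per-row tensor write for each sequence with one indexing operation over the concatenated indices.
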
outlines/processors/base_logits_processor.py (4 changes: 2 additions & 2 deletions)

@@ -75,10 +75,10 @@ def __call__(

         # Guarantee passed as 2D Tensors, then covert back to original (1D or 2D) shape
         if len(torch_logits.shape) == 2:
-            processed_logits = self.process_logits(input_ids.tolist(), torch_logits)
+            processed_logits = self.process_logits(input_ids, torch_logits)
         elif len(torch_logits.shape) == 1:
             processed_logits = self.process_logits(
-                [input_ids.tolist()], torch_logits.unsqueeze(0)
+                input_ids.unsqueeze(0), torch_logits.unsqueeze(0)
             ).squeeze(0)

         # return logits as passed array type
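
The change above drops the tensor-to-list conversion in the base processor: input_ids is now passed through as a tensor, and a single (1D) sequence is unsqueezed to a batch of one before process_logits and squeezed back afterwards. A rough sketch of that wrapping, assuming process_logits expects 2D (batch, ...) tensors; the helper name call_processor is illustrative, not the actual method:

import torch

# Sketch of the 1D/2D normalisation done in __call__, assuming `process_logits`
# expects (batch, seq_len) input_ids and (batch, vocab) logits.
def call_processor(process_logits, input_ids: torch.LongTensor, logits: torch.Tensor) -> torch.Tensor:
    if logits.dim() == 2:
        # Already batched: tensors pass straight through, no .tolist() round trip.
        return process_logits(input_ids, logits)
    if logits.dim() == 1:
        # Single sequence: add a batch dimension, process, then drop it again.
        return process_logits(input_ids.unsqueeze(0), logits.unsqueeze(0)).squeeze(0)
    raise ValueError("expected 1D or 2D logits")
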
outlines/processors/structured.py (33 changes: 23 additions & 10 deletions)

@@ -70,7 +70,7 @@ def __init__(self, tokenizer: "Tokenizer", guide: Guide):
         self._seq_start_idx = None

     def process_logits(
-        self, input_ids: List[List[int]], logits: torch.Tensor
+        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
     ) -> torch.Tensor:
         """Use the Guide to bias the logits before sampling the next token.

@@ -93,19 +93,32 @@ def process_logits(

         for seq_ids in input_ids:
             gen_ids = seq_ids[self._seq_start_idx :]
-            curr_state_key = hash(tuple(gen_ids))
+            curr_state_key = hash(tuple(gen_ids.tolist()))

             if curr_state_key not in self._guide_states:
-                prev_state = self._guide_states[hash(tuple(gen_ids[:-1]))]
-                curr_state = self.guide.get_next_state(prev_state, gen_ids[-1])
+                prev_state = self._guide_states[hash(tuple(gen_ids[:-1].tolist()))]
+                curr_state = self.guide.get_next_state(prev_state, gen_ids[-1].item())
                 self._guide_states[curr_state_key] = curr_state

             sequence_states.append(self._guide_states[curr_state_key])

         mask = torch.ones_like(logits, dtype=torch.bool)

+        allowed_tokens_batch = []
+        batch_indices = []
         for i, guide_state in enumerate(sequence_states):
-            allowed_tokens = self.guide.get_next_instruction(guide_state).tokens
-            mask[i, allowed_tokens] = False
+            allowed_tokens = self.guide.get_next_instruction(guide_state).tokens.to(
+                mask.device, non_blocking=True
+            )
+            allowed_tokens_batch.append(allowed_tokens)
+            batch_indices.append(
+                torch.full_like(allowed_tokens, i)
+            )  # Store batch index for each allowed token
+
+        allowed_tokens_concat = torch.cat(allowed_tokens_batch)
+        batch_indices_concat = torch.cat(batch_indices)
+
+        mask[batch_indices_concat, allowed_tokens_concat] = False
         logits.masked_fill_(mask, float("-inf"))

         return logits

@@ -202,7 +215,7 @@ def __init__(self, cfg_str: str, tokenizer: "Tokenizer"):
         super().__init__(tokenizer=tokenizer, guide=cfg_guide)

     def process_logits(
-        self, input_ids: List[List[int]], logits: torch.Tensor
+        self, input_ids: torch.LongTensor, logits: torch.Tensor
     ) -> torch.Tensor:
         """Same behavior as GuideLogitsProcessor, but uses rejection sampling"""
         if self._seq_start_idx is None:

@@ -212,11 +225,11 @@

         for seq_ids in input_ids:
             gen_ids = seq_ids[self._seq_start_idx :]
-            curr_state_key = hash(tuple(gen_ids))
+            curr_state_key = hash(tuple(gen_ids.tolist()))

             if curr_state_key not in self._guide_states:
-                prev_state = self._guide_states[hash(tuple(gen_ids[:-1]))]
-                curr_state = self.guide.get_next_state(prev_state, gen_ids[-1])
+                prev_state = self._guide_states[hash(tuple(gen_ids[:-1].tolist()))]
+                curr_state = self.guide.get_next_state(prev_state, gen_ids[-1].item())
                 self._guide_states[curr_state_key] = curr_state

             sequence_states.append(self._guide_states[curr_state_key])
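
A side effect of passing tensors down to process_logits is the recurring .tolist() / .item() changes: gen_ids is now a tensor slice, and hashing a tuple of its elements would key the guide-state cache on tensor object identity rather than on token values. A small illustration of why converting to Python ints matters (the token values here are arbitrary):

import torch

gen_ids = torch.tensor([5, 17, 42])  # arbitrary generated token ids

key_by_identity = hash(tuple(gen_ids))        # tuple of 0-dim tensors: id-based hash
key_by_value = hash(tuple(gen_ids.tolist()))  # tuple of Python ints: value-based hash

same_ids = torch.tensor([5, 17, 42])
assert hash(tuple(same_ids.tolist())) == key_by_value
# hash(tuple(same_ids)) == key_by_identity is NOT guaranteed, so the tensor form
# would miss the cache for an identical sequence.

Likewise, gen_ids[-1] is a zero-dimensional tensor, so .item() keeps passing a plain int to guide.get_next_state, as the earlier list-based code did.
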
