Skip to content

Commit

Permalink
update logits in place for GuideLogitsProcessor
Browse files Browse the repository at this point in the history
  • Loading branch information
lapp0 committed Oct 7, 2024
1 parent ef4e819 commit 36875a0
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions outlines/processors/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,13 @@ def process_logits(

sequence_states.append(self._guide_states[curr_state_key])

mask = torch.full_like(logits, -math.inf)
mask = torch.ones_like(logits, dtype=torch.bool)
for i, guide_state in enumerate(sequence_states):
allowed_tokens = self.guide.get_next_instruction(guide_state).tokens
mask[i, allowed_tokens] = logits[i, allowed_tokens]
mask[i, allowed_tokens] = False
logits.masked_fill_(mask, float("-inf"))

return mask
return logits

def copy(self) -> "GuideLogitsProcessor":
"""Return a copy of the logits processor."""
Expand Down

0 comments on commit 36875a0

Please sign in to comment.