Skip to content

Commit

Permalink
Give biased_logits the same dtype as logits
Browse files Browse the repository at this point in the history
Exllama generates logits in torch Half-dtype, but Outlines requires the
Float-dtype.

This small change converts the logits to the required dtype (whatever
that might be), solving issue #583.

Tested with Exllama on the example code on the github front page, and
#583 is resolved.
  • Loading branch information
dnhkng authored Jan 25, 2024
1 parent eb692f6 commit 0cd9608
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion outlines/generate/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def bias_logits(logits: torch.Tensor, allowed_token_ids: List) -> torch.Tensor:
A view of the original logits tensor where some values are masked.
"""
biased_logits = torch.full(logits.shape, -math.inf, device=logits.device)
biased_logits = torch.full_like(logits, -math.inf, device=logits.device)
for i, ids in enumerate(allowed_token_ids):
biased_logits[i, ids] = logits[i, ids]
return biased_logits

0 comments on commit 0cd9608

Please sign in to comment.