Skip to content

Commit

Permalink
Fix tokens array dimensions in is_finished test
Browse files Browse the repository at this point in the history
The dimensions of the mock decoder output and the test input token ids did not
match and were causing PyTorch to issue a warning/error in CI.
  • Loading branch information
brandonwillard committed Oct 12, 2023
1 parent 13a0d6d commit f6e33dd
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tests/text/generate/test_continuation.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_continuation_stop_is_finished():

model = continuation(model, stop=["\n"])

token_ids = torch.tensor([[2, 3]])
token_ids = torch.tensor([[2, 3], [2, 3]])
result = model.is_finished(token_ids)
assert torch.equal(result, torch.tensor([True, False]))

Expand Down

0 comments on commit f6e33dd

Please sign in to comment.