-
Notifications
You must be signed in to change notification settings - Fork 431
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Andrew Lapp
committed
Feb 6, 2024
1 parent
95074e5
commit a579bb0
Showing
1 changed file
with
64 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import importlib | ||
|
||
import pytest | ||
import torch | ||
|
||
import outlines.logging | ||
from outlines.generate.generator import sequence_generator, token_generator | ||
|
||
|
||
@pytest.mark.parametrize("enable_logger", [True, False])
def test_token_generator_logger_call(enable_logger, mocker):
    """Verify the expensive log_logits() path fires only when explicitly enabled.

    Drives a single generation step through sequence_generator using stub
    FSM/tokenizer/model objects, then asserts on whether
    outlines.logging.logits_logger.info was invoked.
    """

    class StubTokenizer:
        # EOS token id consulted by the generator when finishing sequences.
        eos_token_id = 2

        def decode(self, _):
            return "x"

    class StubFSM:
        # Always reports a final state so one step terminates the sequence.
        def next_state(self, state, next_token_ids):
            return 0

        def allowed_token_ids(self, _):
            return []

        def is_final_state(self, _):
            return True

    class StubModel:
        def __init__(self):
            self.tokenizer = StubTokenizer()

        def __call__(*_):
            # Fixed logits make the run deterministic; None stands in for the KV cache.
            return torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=torch.float), None

    def greedy_sampler(biased_logits, *_):
        return torch.argmax(biased_logits, keepdims=True)

    # Reload to reset module-level logging state, then opt in only when requested.
    importlib.reload(outlines.logging)
    if enable_logger:
        outlines.logging.enable_logits_logging()

    logger_info_spy = mocker.patch("outlines.logging.logits_logger.info")

    token_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]])
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]])
    init_state = (token_ids, attention_mask, None)

    generate = token_generator(StubModel(), greedy_sampler)
    sequence = sequence_generator(
        generate, [StubFSM()], init_state, [0], torch.Generator()
    )
    next(sequence)

    if enable_logger:
        logger_info_spy.assert_called()
    else:
        logger_info_spy.assert_not_called()

    # Reload again so enable_logits_logging() doesn't bleed into other tests.
    importlib.reload(outlines.logging)