Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrew Lapp committed Feb 6, 2024
1 parent 95074e5 commit a579bb0
Showing 1 changed file with 64 additions and 0 deletions.
64 changes: 64 additions & 0 deletions tests/test_logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import importlib

import pytest
import torch

import outlines.logging
from outlines.generate.generator import sequence_generator, token_generator


@pytest.mark.parametrize("enable_logger", [True, False])
def test_token_generator_logger_call(enable_logger, mocker):
"""log_logits() is expensive, assert only called when explicitly enabled"""

class MockFSM:
def next_state(self, state, next_token_ids):
return 0

def allowed_token_ids(self, _):
return []

def is_final_state(self, _):
return True

class MockTokenizer:
eos_token_id = 2

def decode(self, _):
return "x"

class MockModel:
def __init__(self):
self.tokenizer = MockTokenizer()

def __call__(*_):
return torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=torch.float), None

def sampler(biased_logits, *_):
return torch.argmax(biased_logits, keepdims=True)

importlib.reload(outlines.logging)
if enable_logger:
outlines.logging.enable_logits_logging()

mock_logits_logger_info = mocker.patch("outlines.logging.logits_logger.info")

init_state = (
torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]),
torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]]),
None,
)
init_fsm_states = [0]
generate = token_generator(MockModel(), sampler)
sequence = sequence_generator(
generate, [MockFSM()], init_state, init_fsm_states, torch.Generator()
)
next(sequence)

if enable_logger:
mock_logits_logger_info.assert_called()
else:
mock_logits_logger_info.assert_not_called()

# ensure enable_logits_logging() doesn't bleed into other tests
importlib.reload(outlines.logging)

0 comments on commit a579bb0

Please sign in to comment.