From bca550c5fd5b5de4aaf9e2a5cc67293ec6b9f2ee Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Sat, 27 Jul 2024 11:16:54 -0500 Subject: [PATCH] Add benchmark: CFG rejection sampling + CFG no rejection sampling --- benchmarks/bench_cfg_guide.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/benchmarks/bench_cfg_guide.py b/benchmarks/bench_cfg_guide.py index 8f6de914a..14dc31c73 100644 --- a/benchmarks/bench_cfg_guide.py +++ b/benchmarks/bench_cfg_guide.py @@ -38,23 +38,31 @@ def setup(self, grammar_name): ) @staticmethod - def _run_random_cfg(guide): + def _run_random_cfg(guide, rejection_sampling=True): state = guide.initial_state token_ids = list(guide.tokenizer.vocabulary.values()) for i in range(40): # simulate ordering of logits top prob to lowest prob random.shuffle(token_ids) # simulate sampling and state update - next_token_id = next(guide.iter_valid_token_ids(state, token_ids)) - state = guide.get_next_state(state, next_token_id) + if rejection_sampling: + next_token_id = next(guide.iter_valid_token_ids(state, token_ids)) + state = guide.get_next_state(state, next_token_id) + else: + next_token_id = random.choice(guide.get_next_instruction(state).tokens) + state = guide.get_next_state(state, next_token_id) @cache_disabled() def time_cfg_guide_setup(self, grammar_name): CFGGuide(benched_grammars[grammar_name], self.tokenizer) + @cache_disabled() + def time_cfg_guide_run_rejection_sampling(self, grammar): + self._run_random_cfg(self.prebuilt_cfg_guide, rejection_sampling=True) + @cache_disabled() def time_cfg_guide_run(self, grammar): - self._run_random_cfg(self.prebuilt_cfg_guide) + self._run_random_cfg(self.prebuilt_cfg_guide, rejection_sampling=False) @cache_disabled() def peakmem_cfg_guide_run(self, grammar):