Commit f342459
Merge pull request #147 from OyvindTafjord/gsm-allow-newline
Stop GSM8k generation at double new line
hamishivi committed Apr 27, 2024
2 parents 28badc2 + 68976cc commit f342459
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions eval/gsm/run_eval.py
@@ -85,10 +85,11 @@ def apply_chat_format(example, tokenizer):
tokenizer_mode="slow" if args.use_slow_tokenizer else "auto",
tensor_parallel_size=torch.cuda.device_count(),
)
+ stop_string = "\n\n" if args.stop_at_double_newline else "\n"
sampling_params = vllm.SamplingParams(
temperature=0,
max_tokens=512,
- stop=["\n"] if not args.use_chat_format else None, # we only use stop token for non-chat format (usually applied to vanilla pretrained language models). For chat format, we will rely on the model knowing when to stop.
+ stop=[stop_string] if not args.use_chat_format else None, # we only use stop token for non-chat format (usually applied to vanilla pretrained language models). For chat format, we will rely on the model knowing when to stop.
)
if args.use_chat_format:
prompts = [apply_chat_format(example, tokenizer) for example in test_data]
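For context on why the stop string matters (an illustrative aside, not part of the commit): GSM8k chain-of-thought answers span several lines, so a single "\n" stop cuts generation off after the first reasoning step, while "\n\n" lets the whole answer through and only drops the extra questions the model tends to invent afterwards. vLLM's stop strings truncate the completion at the first occurrence of any stop string; the sketch below mimics that truncation in plain Python, with a hypothetical completion.

# Illustrative sketch only (not from the repo); the completion text is made up.
completion = (
    "Natalia sold 48 / 2 = 24 clips in May.\n"
    "Altogether she sold 48 + 24 = 72 clips.\n"
    "So the answer is 72.\n"
    "\n"
    "Question: a follow-up problem the model starts to invent..."
)

def truncate_at(text, stop):
    """Return the text before the first occurrence of the stop string."""
    idx = text.find(stop)
    return text if idx == -1 else text[:idx]

print(repr(truncate_at(completion, "\n")))    # only the first reasoning step survives
print(repr(truncate_at(completion, "\n\n")))  # the full multi-line answer survives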
@@ -116,13 +117,21 @@ def apply_chat_format(example, tokenizer):
else:
prompts = [prompt_prefix + "Question: " + example["question"].strip() + "\nAnswer:" for example in test_data]
new_line_token = tokenizer.encode("\n", add_special_tokens=False)[-1] # get the last token because the tokenizer may add space tokens at the start.
+ stop_tokens = [new_line_token]
+ if args.stop_at_double_newline:
+     # We'll stop generation at double new line (check if that's 1 or 2 tokens)
+     double_new_line_token = tokenizer.encode("\n\n", add_special_tokens=False)[-1]
+     if new_line_token == double_new_line_token:
+         stop_tokens = [new_line_token, new_line_token] # double new line is two new line tokens
+     else:
+         stop_tokens = [double_new_line_token] # double new line has its own token
outputs = generate_completions(
model=model,
tokenizer=tokenizer,
prompts=prompts,
max_new_tokens=512,
batch_size=args.eval_batch_size,
- stop_id_sequences=[[new_line_token]] if not args.use_chat_format else None, # we only use stop token for non-chat format (usually applied to vanilla pretrained language models). For chat format, we will rely on the model knowing when to stop.
+ stop_id_sequences=[stop_tokens] if not args.use_chat_format else None, # we only use stop token for non-chat format (usually applied to vanilla pretrained language models). For chat format, we will rely on the model knowing when to stop.
do_sample=False,
)
else:
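The branching in the added block exists because tokenizers disagree on how "\n\n" is encoded. Below is a minimal standalone sketch of the same check, assuming a Hugging Face tokenizer can be loaded; "gpt2" is only an example checkpoint.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example checkpoint only

# Take the last token of each encoding because some tokenizers prepend extra tokens.
new_line_token = tokenizer.encode("\n", add_special_tokens=False)[-1]
double_new_line_token = tokenizer.encode("\n\n", add_special_tokens=False)[-1]

if new_line_token == double_new_line_token:
    # "\n\n" is encoded as two single-newline tokens, so the stop sequence needs both.
    stop_tokens = [new_line_token, new_line_token]
else:
    # "\n\n" has its own merged token (GPT-2's byte-level BPE works this way).
    stop_tokens = [double_new_line_token]

print(stop_tokens)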
@@ -251,6 +260,11 @@ def apply_chat_format(example, tokenizer):
default="eval.templates.create_prompt_with_tulu_chat_format",
help="The function to use to create the chat format. This function will be dynamically imported. Please see examples in `eval/templates.py`."
)
+ parser.add_argument(
+     "--stop_at_double_newline",
+     action="store_true",
+     help="If given, will stop generation at double newline instead of single."
+ )
args = parser.parse_args()

# model_name_or_path and openai_engine cannot be both None or both not None.
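A quick sanity check of the new flag (a standalone sketch that mirrors only the argument added above, not the script's full argument list): store_true means the flag defaults to False, so existing invocations keep the single-newline stop behavior.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--stop_at_double_newline",
    action="store_true",
    help="If given, will stop generation at double newline instead of single."
)

print(parser.parse_args([]).stop_at_double_newline)                            # False
print(parser.parse_args(["--stop_at_double_newline"]).stop_at_double_newline)  # True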