From 4b8312c08e8d05a5f41453d63c8671aab601ed1c Mon Sep 17 00:00:00 2001 From: Camille Zhong <44392324+Camille7777@users.noreply.github.com> Date: Fri, 1 Mar 2024 17:27:50 +0800 Subject: [PATCH] fix sft single turn inference example (#5416) --- applications/Colossal-LLaMA-2/inference_example.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/applications/Colossal-LLaMA-2/inference_example.py b/applications/Colossal-LLaMA-2/inference_example.py index 77e18d8b5939..63ce91e50432 100644 --- a/applications/Colossal-LLaMA-2/inference_example.py +++ b/applications/Colossal-LLaMA-2/inference_example.py @@ -15,7 +15,7 @@ def load_model(model_path, device="cuda", **kwargs): model.to(device) try: - tokenizer = AutoTokenizer.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side='left') except OSError: raise ImportError("Tokenizer not found. Please check if the tokenizer exists or the model path is correct.") @@ -29,6 +29,7 @@ def generate(args): if args.prompt_style == "sft": conversation = default_conversation.copy() conversation.append_message("Human", args.input_txt) + conversation.append_message("Assistant", None) input_txt = conversation.get_prompt() else: BASE_INFERENCE_SUFFIX = "\n\n->\n\n" @@ -46,7 +47,7 @@ def generate(args): num_return_sequences=1, ) response = tokenizer.decode(output.cpu()[0, num_input_tokens:], skip_special_tokens=True) - logger.info(f"Question: {input_txt} \n\n Answer: \n{response}") + logger.info(f"\nHuman: {args.input_txt} \n\nAssistant: \n{response}") return response