diff --git a/modules/text_generation.py b/modules/text_generation.py index d64481b24e..70a51d9165 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -122,7 +122,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi input_ids = encode(question, max_new_tokens) original_input_ids = input_ids output = input_ids[0] - cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()" + cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen)) eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else [] if eos_token is not None: eos_token_ids.append(int(encode(eos_token)[0][-1])) @@ -132,45 +132,48 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi t = encode(stopping_string, 0, add_special_tokens=False) stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0]))) + generate_params = {} if not shared.args.flexgen: - generate_params = [ - f"max_new_tokens=max_new_tokens", - f"eos_token_id={eos_token_ids}", - f"stopping_criteria=stopping_criteria_list", - f"do_sample={do_sample}", - f"temperature={temperature}", - f"top_p={top_p}", - f"typical_p={typical_p}", - f"repetition_penalty={repetition_penalty}", - f"top_k={top_k}", - f"min_length={min_length if shared.args.no_stream else 0}", - f"no_repeat_ngram_size={no_repeat_ngram_size}", - f"num_beams={num_beams}", - f"penalty_alpha={penalty_alpha}", - f"length_penalty={length_penalty}", - f"early_stopping={early_stopping}", - ] + generate_params.update({ + "max_new_tokens": max_new_tokens, + "eos_token_id": eos_token_ids, + "stopping_criteria": stopping_criteria_list, + "do_sample": do_sample, + "temperature": temperature, + "top_p": top_p, + "typical_p": typical_p, + "repetition_penalty": repetition_penalty, + "top_k": top_k, + "min_length": min_length if shared.args.no_stream else 0, + "no_repeat_ngram_size": no_repeat_ngram_size, + "num_beams": num_beams, + "penalty_alpha": penalty_alpha, + "length_penalty": length_penalty, + "early_stopping": early_stopping, + }) else: - generate_params = [ - f"max_new_tokens={max_new_tokens if shared.args.no_stream else 8}", - f"do_sample={do_sample}", - f"temperature={temperature}", - f"stop={eos_token_ids[-1]}", - ] + generate_params.update({ + "max_new_tokens": max_new_tokens if shared.args.no_stream else 8, + "do_sample": do_sample, + "temperature": temperature, + "stop": eos_token_ids[-1], + }) if shared.args.deepspeed: - generate_params.append("synced_gpus=True") + generate_params.update({"synced_gpus": True}) if shared.soft_prompt: inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids) - generate_params.insert(0, "inputs_embeds=inputs_embeds") - generate_params.insert(0, "inputs=filler_input_ids") + generate_params.update({"inputs_embeds": inputs_embeds}) + generate_params.update({"inputs": filler_input_ids}) else: - generate_params.insert(0, "inputs=input_ids") + generate_params.update({"inputs": input_ids}) try: # Generate the entire reply at once. if shared.args.no_stream: with torch.no_grad(): - output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0] + output = shared.model.generate(**generate_params)[0] + if cuda: + output = output.cuda() if shared.soft_prompt: output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) @@ -194,7 +197,7 @@ def generate_with_streaming(**kwargs): return Iteratorize(generate_with_callback, kwargs, callback=None) yield formatted_outputs(original_question, shared.model_name) - with eval(f"generate_with_streaming({', '.join(generate_params)})") as generator: + with generate_with_streaming(**generate_params) as generator: for output in generator: if shared.soft_prompt: output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) @@ -214,7 +217,7 @@ def generate_with_streaming(**kwargs): for i in range(max_new_tokens//8+1): clear_torch_cache() with torch.no_grad(): - output = eval(f"shared.model.generate({', '.join(generate_params)})")[0] + output = shared.model.generate(**generate_params)[0] if shared.soft_prompt: output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) reply = decode(output)