diff --git a/applications/ColossalEval/README.md b/applications/ColossalEval/README.md index 890b1fed3912..bc5394a69a44 100644 --- a/applications/ColossalEval/README.md +++ b/applications/ColossalEval/README.md @@ -154,7 +154,7 @@ inference_kwargs = { "calculate_loss": True, "all_classes": ["A", "B", "C", "D"], "language": "Chinese", - "pretrain": False, + "calculate_overall_loss": False, "max_new_tokens": 32 } ``` @@ -163,7 +163,7 @@ The `inference_kwargs` currently contains 5 fields: - `calculate_loss` (bool, compulsory): Whether the loss on target tokens will be calculated - `all_classes` (Optional[list], compulsory): Whether the subcategory is a single-choice question. Specify all available options in a list or otherwise None. - `language` (str, compulsory): The language for the subcategory. -`pretrain` (bool, compulsory): Whether the dataset is a pretrain dataset or not. It is usually used for calculate perplexity when you want to evaluate a model with extended context length. +- `calculate_overall_loss` (bool, compulsory): Whether to calculate the overall loss over whole sentences when the dataset is a pretrain dataset. It is usually used to calculate perplexity when you want to evaluate a model with an extended context length. - `max_new_tokens` (int, compulsory): The number of new tokens to generate during inference. For example, for dataset MMLU, each subcategory consists of single-choice questions with options A, B, C and D by default and we can assign value `["A", "B", "C", "D"]` to key`all_classes`. For dataset C-Eval, target answers aren't provided in the test split so `calculate_loss` should be set as False. However, other dataset such as GAOKAO-bench contains different formats of questions and lacks some keys or metadata which can reveal what type (single-choice or multi-choice) of questions it is. Before assigning inference arguments, we first parse the dataset to decide which type of questions the subcategory belongs to and set the inference arguments accordingly. @@ -230,7 +230,7 @@ Example: In this step, you will configure your tokenizer and model arguments to infer on the given datasets. A config file consists of two parts. -1. Model config. In model config, you need to specify model name, model path, model class, tokenizer arguments and model arguments. For model class, currently we support `HuggingFaceModel`, `HuggingFaceCausalLM`, `ChatGLMModel` and `ChatGLMModel2`. `HuggingFaceModel` is for models that can be loaded with `AutoModel` and `HuggingFaceCausalLM` is for models that can be loaded with `AutoModelForCausalLM`. `ChatGLMModel` and `ChatGLMModel2` are for ChatGLM and ChatGLM2 models respectively. You can check all model classes in `colossal_eval/models/__init__.py`. If your model should set `trust_remote_code` as true, specify it in the `tokenizer_kwargs` and `model_kwargs` fields. +1. Model config. In model config, you need to specify model name, model path, model class, tokenizer arguments and model arguments. For model class, currently we support `HuggingFaceModel`, `HuggingFaceCausalLM`, `ChatGLMModel`, `ChatGLMModel2` and `vLLMModel`. `HuggingFaceModel` is for models that can be loaded with `AutoModel` and `HuggingFaceCausalLM` is for models that can be loaded with `AutoModelForCausalLM`. `ChatGLMModel` and `ChatGLMModel2` are for ChatGLM and ChatGLM2 models respectively. `vLLMModel` is for models that can be loaded with vLLM's offline inference `LLM` class. You can check all model classes in `colossal_eval/models/__init__.py`.
If your model should set `trust_remote_code` as true, specify it in the `tokenizer_kwargs` and `model_kwargs` fields. 2. Dataset config. In dataset config, you need to specify dataset name, path and dataset class. Currently, we support zero-shot on dataset MMLU, CMMLU, AGIEval, GAOKAO-Bench, GSM8K and LongBench and few-shot on dataset MMLU, CMMLU AGIEval and GSM8K. If you want to enable few shot, set `few_shot` as true. You can check all model classes in `colossal_eval/dataset/__init__.py`. Once you have all config ready, the program will run inference on all the given datasets on all the given models. @@ -272,7 +272,42 @@ An example config using model class `HuggingFaceCausalLM` and dataset class `CMM } ``` -Currently, we support Hugging Face models. The `tokenizer_kwargs` is the arguments used in `AutoTokenizer.from_pretrained()`. The `model_kwargs` is the arguments used in `AutoModel.from_pretrained` or `AutoModelForCausalLM.from_pretrained()`. `few_shot` will be set true if you want to enable few-shot prompting for the dataset. `debug` will be set true if you want to verify whether your prompt is right or wrong. +An example config using model class `vLLMModel` and dataset class `CMMLUDataset` can be: +```json +{ + "model": [ + { + "name": "model name", + "model_class": "vLLMModel", + "parameters": { + "path": "path to model", + "model_max_length": 2048, + "tokenizer_path": "", + "tokenizer_kwargs": { + "trust_remote_code": true + }, + "model_kwargs": { + "trust_remote_code": true + }, + "prompt_template": "plain", + "batch_size": 4 + } + } + ], + "dataset": [ + { + "name": "dataset name", + "dataset_class": "CMMLUDataset", + "debug": false, + "few_shot": true, + "path": "path to original dataset", + "save_path": "path to save converted dataset" + } + ] +} +``` + +Currently, we support Hugging Face models as well as vLLM models. For Hugging Face models, the `tokenizer_kwargs` are the arguments used in `AutoTokenizer.from_pretrained()` and the `model_kwargs` are the arguments used in `AutoModel.from_pretrained()` or `AutoModelForCausalLM.from_pretrained()`. For vLLM models, the `tokenizer_kwargs` and `model_kwargs` are passed together to the `LLM` class (a short sketch follows the inference command below). Set `few_shot` to true if you want to enable few-shot prompting for the dataset, and set `debug` to true if you want to verify whether your prompt is correct. > For GSM8K dataset, you can set additional flags `load_train` or `load_reference` for dataset configuration as true and during the inference process, the program will calculate loss summation over all tokens for each data sample. During the evaluation process, you can use metric `loss_over_all_tokens` to calculate the overall loss and use it for data leakage evaluation. @@ -287,7 +322,7 @@ torchrun --nproc_per_node=4 inference.py \ --inference_save_path "path to save inference results" ``` -You should specify the path to config file in `config`. You can run the script without specifying `load_dataset` if you already save the converted dataset or otherwise set it to first load the original dataset and save the converted dataset. You should specify the path to save inference results in `inference_save_path`. If you want to use tensor parallel inference, specify the tensor parallel size in `--tp_size` and the process will automatically calculate data parallel size. +You should specify the path to the config file in `config`.
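To make the vLLM side concrete, here is a rough sketch (an editor's illustration based on the `vLLMModel` wrapper added by this PR, not a copy of its source; the model path and values are placeholders) of how the `model_kwargs` and `tokenizer_kwargs` from the config end up in a single `vllm.LLM(...)` call rather than in separate `AutoTokenizer`/`AutoModelForCausalLM` calls:

```python
from vllm import LLM

# Values mirror the example vLLMModel config above; the model path is a placeholder.
model_kwargs = {"trust_remote_code": True}
tokenizer_kwargs = {"trust_remote_code": True}

# The wrapper pulls trust_remote_code out of both dicts so it is passed only once,
# then forwards everything else straight into the single vllm.LLM constructor.
trust_mod = model_kwargs.pop("trust_remote_code", False)
trust_tok = tokenizer_kwargs.pop("trust_remote_code", False)
trust_remote_code = trust_mod or trust_tok

llm = LLM(
    model="path to model",
    trust_remote_code=trust_remote_code,
    tensor_parallel_size=1,      # vLLM's own tensor parallelism; the --tp_size flag of inference.py does not apply here
    gpu_memory_utilization=0.5,  # default used by the wrapper; tune for your GPU
    **model_kwargs,
    **tokenizer_kwargs,
)
tokenizer = llm.get_tokenizer()  # the wrapper reuses vLLM's own tokenizer
```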
You can run the script without specifying `load_dataset` if you have already saved the converted dataset, or otherwise set it to first load the original dataset and save the converted dataset. You should specify the path to save inference results in `inference_save_path`. If you want to use tensor parallel inference, specify the tensor parallel size in `--tp_size` and the process will automatically calculate the data parallel size (currently not supported for `vLLMModel`). ### Evaluation @@ -530,10 +565,6 @@ class CustomizedModel(BaseModel): Once you have successfully added your own model, you can specify your model class in your inference config. -## To do - -- [ ] Add visualization code for evaluation results on public dataset -- [ ] Improve the way to label target tokens ## Citations diff --git a/applications/ColossalEval/colossal_eval/dataset/agieval.py b/applications/ColossalEval/colossal_eval/dataset/agieval.py index c1cfe37d7599..07597048d7f9 100644 --- a/applications/ColossalEval/colossal_eval/dataset/agieval.py +++ b/applications/ColossalEval/colossal_eval/dataset/agieval.py @@ -47,7 +47,7 @@ "calculate_loss": True, "all_classes": None, "language": "Chinese", - "pretrain": False, + "calculate_overall_loss": False, "max_new_tokens": 32, } diff --git a/applications/ColossalEval/colossal_eval/dataset/ceval.py b/applications/ColossalEval/colossal_eval/dataset/ceval.py index 1023d1e23c1f..b15dd93afc87 100644 --- a/applications/ColossalEval/colossal_eval/dataset/ceval.py +++ b/applications/ColossalEval/colossal_eval/dataset/ceval.py @@ -70,7 +70,7 @@ "calculate_loss": False, "all_classes": ["A", "B", "C", "D"], "language": "Chinese", - "pretrain": False, + "calculate_overall_loss": False, "max_new_tokens": 32, } diff --git a/applications/ColossalEval/colossal_eval/dataset/cmmlu.py b/applications/ColossalEval/colossal_eval/dataset/cmmlu.py index 05752c2486fa..402a2d4c8eab 100644 --- a/applications/ColossalEval/colossal_eval/dataset/cmmlu.py +++ b/applications/ColossalEval/colossal_eval/dataset/cmmlu.py @@ -81,7 +81,7 @@ "calculate_loss": True, "all_classes": ["A", "B", "C", "D"], "language": "Chinese", - "pretrain": False, + "calculate_overall_loss": False, "max_new_tokens": 32, } diff --git a/applications/ColossalEval/colossal_eval/dataset/colossalai.py b/applications/ColossalEval/colossal_eval/dataset/colossalai.py index 0337454fa788..266eaef3f486 100644 --- a/applications/ColossalEval/colossal_eval/dataset/colossalai.py +++ b/applications/ColossalEval/colossal_eval/dataset/colossalai.py @@ -12,7 +12,7 @@ "calculate_loss": False, "all_classes": None, "language": "Chinese", - "pretrain": False, + "calculate_overall_loss": False, "max_new_tokens": 256, } diff --git a/applications/ColossalEval/colossal_eval/dataset/cvalues.py b/applications/ColossalEval/colossal_eval/dataset/cvalues.py index 4023a4c76322..f5b81f90ed3f 100644 --- a/applications/ColossalEval/colossal_eval/dataset/cvalues.py +++ b/applications/ColossalEval/colossal_eval/dataset/cvalues.py @@ -15,7 +15,7 @@ "calculate_loss": False, "all_classes": ["A", "B"], "language": LANGUAGE, - "pretrain": False, + "calculate_overall_loss": False, "max_new_tokens": 32, } diff --git a/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py b/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py index 44ccea9cfa2c..533e9b4bfa52 100644 --- a/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py +++ b/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py @@ -36,7 +36,7 @@ "calculate_loss": True, "all_classes": None,
"language": "Chinese", - "pretrain": False, + "calculate_overall_loss": False, "max_new_tokens": 32, } diff --git a/applications/ColossalEval/colossal_eval/dataset/gsm.py b/applications/ColossalEval/colossal_eval/dataset/gsm.py index 775c5843ff79..a639201053ef 100644 --- a/applications/ColossalEval/colossal_eval/dataset/gsm.py +++ b/applications/ColossalEval/colossal_eval/dataset/gsm.py @@ -72,7 +72,7 @@ "calculate_loss": True, "all_classes": None, "language": "English", - "pretrain": False, + "calculate_overall_loss": False, "max_new_tokens": 256, } @@ -114,7 +114,7 @@ def load( dataset[split][subject]["inference_kwargs"] = copy.deepcopy(default_inference_kwargs) if forward_only: - dataset[split][subject]["inference_kwargs"]["pretrain"] = True + dataset[split][subject]["inference_kwargs"]["calculate_overall_loss"] = True if split == "test" and few_shot: dataset[split][subject]["inference_kwargs"]["few_shot_data"] = get_few_shot_data() diff --git a/applications/ColossalEval/colossal_eval/dataset/longbench.py b/applications/ColossalEval/colossal_eval/dataset/longbench.py index eb61efaa0d7c..e663e5e108e6 100644 --- a/applications/ColossalEval/colossal_eval/dataset/longbench.py +++ b/applications/ColossalEval/colossal_eval/dataset/longbench.py @@ -60,7 +60,7 @@ "calculate_loss": True, "all_classes": None, "language": "Chinese", - "pretrain": False, + "calculate_overall_loss": False, "max_new_tokens": 32, } diff --git a/applications/ColossalEval/colossal_eval/dataset/mmlu.py b/applications/ColossalEval/colossal_eval/dataset/mmlu.py index e9465c91b3ce..5e3ff6af6ef3 100644 --- a/applications/ColossalEval/colossal_eval/dataset/mmlu.py +++ b/applications/ColossalEval/colossal_eval/dataset/mmlu.py @@ -11,7 +11,7 @@ "calculate_loss": True, "all_classes": ["A", "B", "C", "D"], "language": "English", - "pretrain": False, + "calculate_overall_loss": False, "max_new_tokens": 32, } diff --git a/applications/ColossalEval/colossal_eval/dataset/mtbench.py b/applications/ColossalEval/colossal_eval/dataset/mtbench.py index ef474ec4ca23..abec8ebfb038 100644 --- a/applications/ColossalEval/colossal_eval/dataset/mtbench.py +++ b/applications/ColossalEval/colossal_eval/dataset/mtbench.py @@ -14,7 +14,7 @@ "calculate_loss": False, "all_classes": None, "language": "English", - "pretrain": False, + "calculate_overall_loss": False, "max_new_tokens": 1024, "turns": 2, } diff --git a/applications/ColossalEval/colossal_eval/dataset/safetybench_en.py b/applications/ColossalEval/colossal_eval/dataset/safetybench_en.py index 8056c3dfd8bf..494bb0993ccf 100644 --- a/applications/ColossalEval/colossal_eval/dataset/safetybench_en.py +++ b/applications/ColossalEval/colossal_eval/dataset/safetybench_en.py @@ -28,7 +28,7 @@ "calculate_loss": False, "all_classes": ["A", "B", "C", "D"], "language": LANGUAGE, - "pretrain": False, + "calculate_overall_loss": False, "max_new_tokens": 32, } diff --git a/applications/ColossalEval/colossal_eval/dataset/safetybench_zh.py b/applications/ColossalEval/colossal_eval/dataset/safetybench_zh.py index f5f17e64c991..8c41664c02c8 100644 --- a/applications/ColossalEval/colossal_eval/dataset/safetybench_zh.py +++ b/applications/ColossalEval/colossal_eval/dataset/safetybench_zh.py @@ -28,7 +28,7 @@ "calculate_loss": False, "all_classes": ["A", "B", "C", "D"], "language": LANGUAGE, - "pretrain": False, + "calculate_overall_loss": False, "max_new_tokens": 32, } diff --git a/applications/ColossalEval/colossal_eval/models/__init__.py b/applications/ColossalEval/colossal_eval/models/__init__.py index 
8f6c9b414145..ec557571ca07 100644 --- a/applications/ColossalEval/colossal_eval/models/__init__.py +++ b/applications/ColossalEval/colossal_eval/models/__init__.py @@ -1,5 +1,6 @@ from .base import BaseModel from .chatglm import ChatGLM2Model, ChatGLMModel from .huggingface import HuggingFaceCausalLM, HuggingFaceModel +from .vllm import vLLMModel -__all__ = ["BaseModel", "HuggingFaceModel", "HuggingFaceCausalLM", "ChatGLMModel", "ChatGLM2Model"] +__all__ = ["BaseModel", "HuggingFaceModel", "HuggingFaceCausalLM", "ChatGLMModel", "ChatGLM2Model", "vLLMModel"] diff --git a/applications/ColossalEval/colossal_eval/models/chatglm.py b/applications/ColossalEval/colossal_eval/models/chatglm.py index 9c70c0d2a1ad..4a48f4c0ed3e 100644 --- a/applications/ColossalEval/colossal_eval/models/chatglm.py +++ b/applications/ColossalEval/colossal_eval/models/chatglm.py @@ -28,7 +28,7 @@ def _get_truncated_prompts(self, inputs: List[str], max_new_tokens: int) -> List @torch.no_grad() def get_loss( - self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool = False + self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool = False ) -> List[List[float]]: """ Calculate loss only on target tokens. @@ -225,7 +225,7 @@ def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str @torch.no_grad() def get_loss( - self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool = False + self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool = False ) -> List[List[float]]: """ Calculate loss only on target tokens. diff --git a/applications/ColossalEval/colossal_eval/models/huggingface.py b/applications/ColossalEval/colossal_eval/models/huggingface.py index e91743525f0e..200e282e7b2b 100644 --- a/applications/ColossalEval/colossal_eval/models/huggingface.py +++ b/applications/ColossalEval/colossal_eval/models/huggingface.py @@ -105,6 +105,12 @@ def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], tokenizer_kw elif hasattr(self.tokenizer, "eod_id"): # Qwen has an eod token "<|endoftext|>". self.tokenizer.pad_token_id = self.tokenizer.eod_id + else: + self.logger.error("Neither eos_token nor eod_id is available for setting pad_token_id.") + raise ValueError( + "The tokenizer does not have a pad_token_id, eos_token, or eod_id. " + "Please set pad_token_id manually." + ) def _load_model( self, path: str, model_kwargs: dict, peft_path: Optional[str] = None, shard_config: ShardConfig = None @@ -245,7 +251,7 @@ def _get_input_ids_and_labels_pretrain(self, batch_prompt: List[str]) -> Tuple[L return input_ids_list, labels_list, bytes_list def _get_input_ids_and_labels( - self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool + self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool ) -> Tuple[List[torch.LongTensor]]: """ Get input_ids and labels for the given data. @@ -258,7 +264,7 @@ def _get_input_ids_and_labels( Input_ids and labels for the given batch. """ - if pretrain: + if calculate_overall_loss: batch = [] # Concatenate prompt and target answers. # You should decide the concatenation character in the corresponding dataset script in dataset folder. For example, in line 119 dataset/gsm.py, the concatenation character is space. 
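# [Editorial sketch, not part of the patch] A toy illustration of the two labeling modes that the
# renamed `calculate_overall_loss` flag switches between (token ids below are made up):
IGNORE_INDEX = -100
prompt_ids = [101, 102, 103]  # tokens of the prompt
target_ids = [201, 202]       # tokens of the target answer
input_ids = prompt_ids + target_ids
# calculate_overall_loss=False: loss is computed only on the target tokens, prompt positions are masked.
labels_target_only = [IGNORE_INDEX] * len(prompt_ids) + target_ids
# calculate_overall_loss=True: prompt and target are concatenated (see the comment above) and every
# token contributes to the loss, which is what per-byte perplexity on pretrain data needs.
labels_overall = list(input_ids)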
@@ -342,7 +348,7 @@ def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], d calculate_loss = inference_kwargs["calculate_loss"] classes = inference_kwargs["all_classes"] language = inference_kwargs["language"] - pretrain = inference_kwargs["pretrain"] + calculate_overall_loss = inference_kwargs["calculate_overall_loss"] max_new_tokens = inference_kwargs["max_new_tokens"] few_shot_data = inference_kwargs.get("few_shot_data", None) @@ -384,12 +390,12 @@ def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], d self.logger.info("-" * 120) self.logger.info(batch_prompt[0] + batch_target[0][0]) - if not pretrain: + if not calculate_overall_loss: batch_decodes, scores = self.generate(batch_prompt, max_new_tokens) if calculate_loss: batch_losses, batch_target_token_nums, batch_bytes_nums = self.get_loss( - batch_prompt, batch_target, pretrain + batch_prompt, batch_target, calculate_overall_loss ) probs = [] @@ -409,7 +415,7 @@ def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], d ] for j in range(len(batch)): - if not pretrain: + if not calculate_overall_loss: if isinstance(batch[j]["output"], list): batch[j]["output"].append(batch_decodes[j].strip()) else: @@ -496,7 +502,9 @@ def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str return decoded_sequences, scores @torch.no_grad() - def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool) -> List[List[float]]: + def get_loss( + self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool + ) -> List[List[float]]: """ Calculate loss only on target tokens. @@ -513,13 +521,15 @@ def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretr # We don't need to generate new tokens. # Target answer's length is usually << model_max_length, but we still call it in case. # We don't call self._get_truncated_prompts for batch_prompt because we need target answer's length first to reserve some space for target answer's tokens. - if not pretrain: + if not calculate_overall_loss: batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target] # Get the number of target answers for different questions batch_target_nums = [len(prompt_target) for prompt_target in batch_target] - input_ids_list, labels_list, bytes_list = self._get_input_ids_and_labels(batch_prompt, batch_target, pretrain) + input_ids_list, labels_list, bytes_list = self._get_input_ids_and_labels( + batch_prompt, batch_target, calculate_overall_loss + ) # Because of multiple target answers, the final batch size may be greater than self.batch_size. # We will generate new batches. diff --git a/applications/ColossalEval/colossal_eval/models/vllm.py b/applications/ColossalEval/colossal_eval/models/vllm.py new file mode 100644 index 000000000000..2cbdb6e1b767 --- /dev/null +++ b/applications/ColossalEval/colossal_eval/models/vllm.py @@ -0,0 +1,498 @@ +import copy +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +from colossal_eval.utils import Conversation, get_batch_prompt, is_rank_0 +from torch.utils.data import DataLoader +from tqdm import tqdm +from vllm import LLM, SamplingParams + +from colossalai.logging import DistributedLogger + +from .huggingface import HuggingFaceModel + +IGNORE_INDEX = -100 + + +class vLLMModel(HuggingFaceModel): + """ + Model wrapper around vLLM models. + + Args: + path: The path to a vLLM model. 
+ model_max_length: The maximum sequence length of the model. + tokenizer_path: The path to the tokenizer. + tokenizer_kwargs: Keyword arguments for the tokenizer. + model_kwargs: Keyword arguments for the model. + prompt_template: The model's prompt template. + batch_size: Batch size for inference. + logger: Logger for the model. + trust_remote_code: Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer. + tensor_parallel_size: The number of GPUs to use for distributed execution with tensor parallelism. + quantization: The method used to quantize the model weights + gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache. + swap_space: The size (GiB) of CPU memory per GPU to use as swap space. + cpu_offload_gb: The size (GiB) of CPU memory to use for offloading the model weights. + enforce_eager: Whether to enforce eager execution. + max_context_len_to_capture: Maximum context len covered by CUDA graphs. + max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. + disable_custom_all_reduce: See ParallelConfig + """ + + def __init__( + self, + path: str, + model_max_length: int = 2048, + tokenizer_path: Optional[str] = None, + tokenizer_kwargs: Dict = None, + model_kwargs: Dict = None, + prompt_template: Conversation = None, + batch_size: int = 1, + logger: DistributedLogger = None, + trust_remote_code: bool = False, + tensor_parallel_size: int = 1, + quantization: Optional[str] = None, + gpu_memory_utilization: float = 0.5, + swap_space: float = 4, + cpu_offload_gb: float = 0, + enforce_eager: Optional[bool] = None, + max_context_len_to_capture: Optional[int] = None, + max_seq_len_to_capture: int = 8192, + disable_custom_all_reduce: bool = False, + **kwargs, + ): + super().__init__( + path=path, + model_max_length=model_max_length, + prompt_template=prompt_template, + batch_size=batch_size, + logger=logger, + ) + + self._load_model( + path=path, + model_kwargs=model_kwargs, + tokenizer_kwargs=tokenizer_kwargs, + tokenizer_path=tokenizer_path if tokenizer_path else None, + trust_remote_code=trust_remote_code, + tensor_parallel_size=tensor_parallel_size, + quantization=quantization, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=swap_space, + cpu_offload_gb=cpu_offload_gb, + enforce_eager=enforce_eager, + max_context_len_to_capture=max_context_len_to_capture, + max_seq_len_to_capture=max_seq_len_to_capture, + disable_custom_all_reduce=disable_custom_all_reduce, + ) + + def _load_model( + self, + path: str, + model_kwargs: dict, + tokenizer_kwargs: dict, + tokenizer_path: Optional[str] = None, + trust_remote_code: bool = False, + tensor_parallel_size: int = 1, + quantization: Optional[str] = None, + gpu_memory_utilization: float = 0.9, + swap_space: float = 4, + cpu_offload_gb: float = 0, + enforce_eager: Optional[bool] = None, + max_context_len_to_capture: Optional[int] = None, + max_seq_len_to_capture: int = 8192, + disable_custom_all_reduce: bool = False, + ): + """ + Load model. + + Args: + path: The path to the model. + model_kwargs: Keyword arguments for the model. + tokenizer_kwargs: Keyword arguments for the tokenizer. + tokenizer_path: The path to the tokenizer. + trust_remote_code: Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer. + tensor_parallel_size: The number of GPUs to use for distributed execution with tensor parallelism. 
+ quantization: The method used to quantize the model weights + gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache. + swap_space: The size (GiB) of CPU memory per GPU to use as swap space. + cpu_offload_gb: The size (GiB) of CPU memory to use for offloading the model weights. + enforce_eager: Whether to enforce eager execution. + max_context_len_to_capture: Maximum context len covered by CUDA graphs. + max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. + disable_custom_all_reduce: See ParallelConfig + + """ + if "torch_dtype" in model_kwargs: + model_kwargs["dtype"] = eval(model_kwargs["torch_dtype"]) + model_kwargs.pop("torch_dtype") + else: + model_kwargs.setdefault("dtype", torch.float16) + + if "trust_remote_code" in model_kwargs: + trust_remote_code = model_kwargs["trust_remote_code"] + model_kwargs.pop("trust_remote_code") + + if "trust_remote_code" in tokenizer_kwargs: + trust_remote_code = tokenizer_kwargs["trust_remote_code"] + tokenizer_kwargs.pop("trust_remote_code") + + self.model = LLM( + model=path, + trust_remote_code=trust_remote_code, + tensor_parallel_size=tensor_parallel_size, + quantization=quantization, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=swap_space, + cpu_offload_gb=cpu_offload_gb, + enforce_eager=enforce_eager, + max_context_len_to_capture=max_context_len_to_capture, + max_seq_len_to_capture=max_seq_len_to_capture, + disable_custom_all_reduce=disable_custom_all_reduce, + **model_kwargs, + **tokenizer_kwargs, + ) + + self.tokenizer = self.model.get_tokenizer() + + if self.batch_size > 1: + self.tokenizer.padding_side = "left" + self.tokenizer.truncation_side = "left" + + if self.tokenizer.pad_token_id is None: + self.logger.warning("pad_token_id is not set for the tokenizer. " "Using eos_token_id as pad_token_id.") + if self.tokenizer.eos_token: + self.tokenizer.pad_token = self.tokenizer.eos_token + elif hasattr(self.tokenizer, "eod_id"): + # Qwen has an eod token "<|endoftext|>". + self.tokenizer.pad_token_id = self.tokenizer.eod_id + else: + self.logger.error("Neither eos_token nor eod_id is available for setting pad_token_id.") + raise ValueError( + "The tokenizer does not have a pad_token_id, eos_token, or eod_id. " + "Please set pad_token_id manually." + ) + + def _calculate_loss(self, inputs: List[str], labels: List[str]) -> Tuple[List]: + """ + Calculate loss on target tokens. Adapted from https://github.com/open-compass/opencompass/blob/c2bcd8725e615ec455bf5b7301f8d09962cd64e3/opencompass/models/vllm.py#L110 + + Args: + input_ids_list: A batch of input string. + labels: A batch of labels. + + Returns: + A list of loss and a list of label length. 
+ + """ + batch_size = len(inputs) + sampling_kwargs = SamplingParams(logprobs=1) + outputs = self.model.generate(inputs, sampling_kwargs) + ce_loss = [] + + if labels is not None: + lens = [len(self.tokenizer.encode(label, add_special_tokens=False)) for label in labels] + else: + lens = [1] * batch_size + + for i in range(batch_size): + logprobs = outputs[i].outputs[0].logprobs + token_ids = outputs[i].outputs[0].token_ids + + logprobs_list = [logprobs[i][token_ids[i]] for i in range(len(logprobs))] + logprobs_list = [i.logprob for i in logprobs_list] + logprobs_list = np.array(logprobs_list) + + if lens is not None: + logprobs_list = logprobs_list[: lens[i]] + + loss = -logprobs_list.sum(axis=-1) / lens[i] + ce_loss.append(loss) + + batch_loss = np.array(ce_loss) + + return batch_loss, lens + + def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], debug: bool = False) -> List[Dict]: + """ + Infer the given data. + This function will call self.generate() to get model outputs and use LogitsProcessor param to get specific logits. + + Args: + data: The data for inference. + inference_kwargs: Arguments for inference. + debug: Whether to display generated prompt for debugging. + + Returns: + Inference results. + + """ + calculate_loss = inference_kwargs["calculate_loss"] + classes = inference_kwargs["all_classes"] + language = inference_kwargs["language"] + calculate_overall_loss = inference_kwargs["calculate_overall_loss"] + max_new_tokens = inference_kwargs["max_new_tokens"] + few_shot_data = inference_kwargs.get("few_shot_data", None) + + # Some classification questions' options are texts not a single letter such as A, B, C and D. + # If the text length is greater than 1, we won't calculate loss over choices. + if classes is not None and any(len(c) > 1 for c in classes): + classes = None + + self.choices = classes + self.indices_for_choices = None + if self.choices: + # Get indices for each choice + self._get_choices_indices(language) + + self.str_label_map = {choice: idx for idx, choice in enumerate(self.choices)} + + bar = tqdm( + range(len(data_loader)), + desc=f"{inference_kwargs['dataset']}-{inference_kwargs['category']} Inference steps", + disable=not is_rank_0(), + ) + loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + + answers = [] + + for i, batch in enumerate(data_loader): + batch_prompt, batch_target = get_batch_prompt( + self.prompt_template, batch, few_shot_data, self.tokenizer, self.model_max_length + ) + + if is_rank_0() and debug and i == 0: + self.logger.info( + f"Inference arguments for dataset {batch[0]['dataset']} category {batch[0]['category']} is:\n{inference_kwargs}" + ) + self.logger.info("-" * 120) + self.logger.info("An example prompt and prompt with target is:") + self.logger.info("-" * 120) + self.logger.info(batch_prompt[0]) + self.logger.info("-" * 120) + self.logger.info(batch_prompt[0] + batch_target[0][0]) + + if not calculate_overall_loss: + batch_decodes, scores = self.generate(batch_prompt, max_new_tokens) + + if calculate_loss: + batch_losses, batch_target_token_nums, batch_bytes_nums = self.get_loss( + batch_prompt, batch_target, calculate_overall_loss + ) + + probs = [] + if self.indices_for_choices: + scores = scores.to(torch.float32) + # If we have indices_for_choices(must be single-choice question), there will be only one target answer for one data sample. + # Otherwise this will violate the single-choice setting. 
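# [Editorial note] `scores` collected by GetTokenLogitsProcessor (defined at the end of this file)
# holds, for each sequence, the logits of the first generated token restricted to the choice tokens.
# The block below turns them into the per-choice `logits_over_choices` dict and, when targets are
# available, into `loss_over_choices` via cross-entropy over the choices.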
+ + if calculate_loss: + labels = [self.str_label_map[batch[j]["target"]] for j in range(len(batch))] + + loss_over_choices = loss_fct(scores, torch.tensor(labels, dtype=torch.long)).numpy().tolist() + + probs = scores.numpy().tolist() + probs = [ + {choice: probs[i][self.str_label_map[choice]] for choice in self.choices} for i in range(len(probs)) + ] + + for j in range(len(batch)): + if not calculate_overall_loss: + if isinstance(batch[j]["output"], list): + batch[j]["output"].append(batch_decodes[j].strip()) + else: + batch[j]["output"] = batch_decodes[j].strip() + + if isinstance(scores, torch.Tensor): + batch[j]["logits_over_choices"] = probs[j] + + if calculate_loss: + batch[j]["loss_over_choices"] = loss_over_choices[j] + + if calculate_loss: + batch[j]["loss"] = (np.array(batch_losses[j]) / np.array(batch_target_token_nums[j])).tolist() + + # loss_sum is specially used for pertrain dataset for calculating per-byte-perplexity. + # However, loss (which is per sample loss) suffices for most cases. + batch[j]["loss_sum"] = batch_losses[j] + batch[j]["token_num"] = batch_target_token_nums[j] + + if batch_bytes_nums: + batch[j]["byte_num"] = batch_bytes_nums[j] + answers.extend(batch) + + bar.update() + + return answers + + @torch.no_grad() + def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str]: + """Generate results given a list of inputs and get logits of the first new token over choices. + + Args: + inputs: A list of strings. + max_new_tokens: Max new tokens for generation. + kwargs: Key arguments for generation + + Returns: + A list of generated strings and logits over choices. + + Note: + Currently the function only returns the logits of the first new token. + It is used for single choice question. + For multiple choices question, please avoid using the loss over choices. + You should set argument choices as None in self.inference(). + + """ + truncated_inputs = self._get_truncated_prompts(inputs, max_new_tokens) + + generation_kwargs = kwargs.copy() + generation_kwargs.update({"max_tokens": max_new_tokens}) + logits_processor = GetTokenLogitsProcessor(self.indices_for_choices) + + sampling_kwargs = SamplingParams(logits_processors=[logits_processor], **generation_kwargs) + + outputs = self.model.generate(truncated_inputs, sampling_kwargs) + output_strs = [] + for output in outputs: + generated_text = output.outputs[0].text + output_strs.append(generated_text) + scores = logits_processor.get_target_logits() + + return output_strs, scores + + @torch.no_grad() + def get_loss( + self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool + ) -> List[List[float]]: + """ + Calculate loss only on target tokens. + + Args: + batch: A batch of prompt without target answer. + batch_target: A batch of target answer. Sometimes one question can have multiple target answers. + + Returns: + Loss. + + """ + + # We set max_new_tokens in self._get_truncated_prompts to 0 because we only need logits to calculate loss. + # We don't need to generate new tokens. + # Target answer's length is usually << model_max_length, but we still call it in case. + # We don't call self._get_truncated_prompts for batch_prompt because we need target answer's length first to reserve some space for target answer's tokens. 
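# [Editorial note] In the non-overall branch below, every (prompt, target) pair becomes its own row:
# a question with k reference answers is expanded into k prompt/target entries, which is why the
# effective batch handed to _calculate_loss can exceed self.batch_size.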
+ if not calculate_overall_loss: + batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target] + + # Get the number of target answers for different questions + batch_target_nums = [len(prompt_target) for prompt_target in batch_target] + + if calculate_overall_loss: + batch = [] + bytes_list = [] + batch_prompt_pretrain = [] + for p, b in zip(batch_prompt, batch_target): + batch.append(p + b[0]) + + for input in batch: + # Pretrain data tends to be very long, sometimes much larger than the model_max_length, we only tokenize 1/ratio of the data first to accelerate the tokenization process. + # Once the length of the result is greater or equal to model_max_length, we stop iterating on ratios and use the result as input_ids and labels. + # After all, the rest of the original string doesn't need to be tokenized at the first place. + # Pretrain data tends to be very long, sometimes much larger than the model_max_length, we only tokenize 1/ratio of the data first to accelerate the tokenization process. + # Once the length of the result is greater or equal to model_max_length, we stop iterating on ratios and use the result as input_ids and labels. + # After all, the rest of the original string doesn't need to be tokenized at the first place. + ratio = [16, 8, 4, 2, 1] + tokenized = None + for r in ratio: + tokenized = self.tokenizer( + [input[0 : len(input) // r]], + truncation=True, + max_length=self.model_max_length, + return_tensors="pt", + ) + if tokenized.input_ids.size(1) >= self.model_max_length: + break + + string = self.tokenizer.decode(tokenized.input_ids[0], skip_special_tokens=True) + batch_prompt_pretrain.append(string) + bytes_list.append(len(string.encode("utf-8"))) + + batch_prompt = copy.deepcopy(batch_prompt_pretrain) + batch_target = None + else: + batch_prompt_processed = [] + batch_target_processed = [] + for prompt, targets in zip(batch_prompt, batch_target): + for target in targets: + target_tokenized = self.tokenizer( + [target], truncation=True, max_length=self.model_max_length, return_tensors="pt" + ) + max_new_tokens = target_tokenized["input_ids"][0].size(0) + prompt_with_correct_length = self._get_truncated_prompts([prompt], max_new_tokens)[0] + batch_prompt_processed.append(prompt_with_correct_length) + batch_target_processed.append(target) + + batch_prompt = copy.deepcopy(batch_prompt_processed) + batch_target = copy.deepcopy(batch_target_processed) + bytes_list = None + + # Because of multiple target answers, the final batch size may be greater than self.batch_size. + # We will generate new batches. 
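# [Editorial note] Unlike the Hugging Face implementation, the expanded prompt/target lists are
# scored in a single _calculate_loss call below; the flat `losses` list is then regrouped per
# original sample using `batch_target_nums`.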
+ losses = [] + target_token_nums = [] + + losses_per_batch, target_token_num_per_batch = self._calculate_loss(batch_prompt, batch_target) + losses.extend(losses_per_batch) + target_token_nums.extend(target_token_num_per_batch) + + start_indice = 0 + losses_per_sample = [] + + target_token_nums_per_sample = [] + bytes_nums_per_sample = [] + for length in batch_target_nums: + losses_per_sample.append(losses[start_indice : start_indice + length]) + target_token_nums_per_sample.append(target_token_nums[start_indice : start_indice + length]) + + if bytes_list: + bytes_nums_per_sample.append(bytes_list[start_indice : start_indice + length]) + + start_indice += length + + if bytes_list: + return losses_per_sample, target_token_nums_per_sample, bytes_nums_per_sample + + return losses_per_sample, target_token_nums_per_sample, None + + +class GetTokenLogitsProcessor: + """ + LogitsProcessor to get specific logits + + Args: + indices_for_choices: token indices of required tokens + target_logits: store all the target logits + """ + + def __init__( + self, + indices_for_choices: List[List[int]], + ): + self.indices_for_choices = (indices_for_choices,) + self.target_logits = [] + + def __call__(self, input_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor: + choice_scores = [] + + if not input_ids: + for option_indices in self.indices_for_choices[0]: + choice_scores.append(logits[option_indices].detach().cpu()) + + choice_scores = torch.max(torch.stack(choice_scores), dim=0)[0] + self.target_logits.append(choice_scores) + + return logits + + def get_target_logits(self) -> torch.Tensor: + return torch.stack(self.target_logits) if self.target_logits else torch.tensor([]) diff --git a/applications/ColossalEval/examples/dataset_evaluation/inference.py b/applications/ColossalEval/examples/dataset_evaluation/inference.py index c651970ee37c..1d3f13745474 100644 --- a/applications/ColossalEval/examples/dataset_evaluation/inference.py +++ b/applications/ColossalEval/examples/dataset_evaluation/inference.py @@ -69,7 +69,7 @@ def rm_and_merge( os.remove(directory) except Exception as e: print(e) - print(len(answers["data"])) + all_answers[category] = answers all_answers_with_dataset_class["inference_results"] = all_answers diff --git a/applications/ColossalEval/requirements.txt b/applications/ColossalEval/requirements.txt index c5b9bad549e2..f9985b49f9ed 100644 --- a/applications/ColossalEval/requirements.txt +++ b/applications/ColossalEval/requirements.txt @@ -10,3 +10,4 @@ matplotlib pandas seaborn scikit-learn +vllm==0.5.5
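As a closing illustration, the sketch below shows how the `GetTokenLogitsProcessor` added in `colossal_eval/models/vllm.py` can be wired into vLLM generation. This is an editor's example, not code from the patch: the model path and choice token ids are placeholders, the import path simply mirrors the patch's module layout, and `vLLMModel.generate()` performs the equivalent wiring internally.

```python
from vllm import LLM, SamplingParams

from colossal_eval.models.vllm import GetTokenLogitsProcessor  # module added by this patch

llm = LLM(model="path to model")

# One inner list per tokenization form, each holding one token id per choice
# (e.g. the ids of "A", "B", "C", "D"); the ids here are hypothetical.
indices_for_choices = [[319, 350, 315, 360]]

processor = GetTokenLogitsProcessor(indices_for_choices)
# Only the first new token's logits are needed to score the choices.
params = SamplingParams(max_tokens=1, logits_processors=[processor])

outputs = llm.generate(["Question: ...\nAnswer:"], params)

# One row per generated sequence, one column per choice: the first new token's
# logits at the choice positions, as captured by the processor.
choice_logits = processor.get_target_logits()
```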