[ColossalEval] support for vllm (#6056)
* support vllm

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* modify vllm and update readme

* run pre-commit

* remove duplicated lines and refine code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update param name

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refine code

* update readme

* refine code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Camille7777 and pre-commit-ci[bot] authored Sep 18, 2024
1 parent 4fa6b95 commit f9546ba
Showing 19 changed files with 576 additions and 35 deletions.
49 changes: 40 additions & 9 deletions applications/ColossalEval/README.md
@@ -154,7 +154,7 @@ inference_kwargs = {
"calculate_loss": True,
"all_classes": ["A", "B", "C", "D"],
"language": "Chinese",
"pretrain": False,
"calculate_overall_loss": False,
"max_new_tokens": 32
}
```
@@ -163,7 +163,7 @@ The `inference_kwargs` currently contains 5 fields:
- `calculate_loss` (bool, compulsory): Whether the loss on target tokens will be calculated
- `all_classes` (Optional[list], compulsory): Whether the subcategory is a single-choice question. Specify all available options in a list or otherwise None.
- `language` (str, compulsory): The language for the subcategory.
-- `pretrain` (bool, compulsory): Whether the dataset is a pretrain dataset or not. It is usually used for calculate perplexity when you want to evaluate a model with extended context length.
+- `calculate_overall_loss` (bool, compulsory): Whether to calculate the overall loss over whole sentences when the dataset is a pretrain dataset. It is usually used for calculating perplexity when you want to evaluate a model with an extended context length.
- `max_new_tokens` (int, compulsory): The number of new tokens to generate during inference.

For example, for dataset MMLU, each subcategory consists of single-choice questions with options A, B, C and D by default, so we can assign the value `["A", "B", "C", "D"]` to the key `all_classes`. For dataset C-Eval, target answers aren't provided in the test split, so `calculate_loss` should be set to False. However, other datasets such as GAOKAO-Bench contain different formats of questions and lack keys or metadata that reveal what type of question (single-choice or multi-choice) each one is. Before assigning inference arguments, we first parse the dataset to decide which type of question the subcategory contains and set the inference arguments accordingly.
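
To make this concrete, a minimal sketch of the resulting per-dataset settings (the values mirror the defaults visible in `colossal_eval/dataset/mmlu.py` and `ceval.py` later in this diff; the variable names are illustrative):

```python
# Illustrative only: inference_kwargs for MMLU (single-choice, targets known)
# versus C-Eval (no target answers in the test split).
mmlu_inference_kwargs = {
    "calculate_loss": True,
    "all_classes": ["A", "B", "C", "D"],
    "language": "English",
    "calculate_overall_loss": False,
    "max_new_tokens": 32,
}

ceval_inference_kwargs = {
    "calculate_loss": False,  # test split provides no target answers
    "all_classes": ["A", "B", "C", "D"],
    "language": "Chinese",
    "calculate_overall_loss": False,
    "max_new_tokens": 32,
}
```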
@@ -230,7 +230,7 @@ Example:
In this step, you will configure your tokenizer and model arguments to infer on the given datasets.

A config file consists of two parts.
-1. Model config. In model config, you need to specify model name, model path, model class, tokenizer arguments and model arguments. For model class, currently we support `HuggingFaceModel`, `HuggingFaceCausalLM`, `ChatGLMModel` and `ChatGLMModel2`. `HuggingFaceModel` is for models that can be loaded with `AutoModel` and `HuggingFaceCausalLM` is for models that can be loaded with `AutoModelForCausalLM`. `ChatGLMModel` and `ChatGLMModel2` are for ChatGLM and ChatGLM2 models respectively. You can check all model classes in `colossal_eval/models/__init__.py`. If your model should set `trust_remote_code` as true, specify it in the `tokenizer_kwargs` and `model_kwargs` fields.
+1. Model config. In model config, you need to specify model name, model path, model class, tokenizer arguments and model arguments. For model class, currently we support `HuggingFaceModel`, `HuggingFaceCausalLM`, `ChatGLMModel`, `ChatGLMModel2` and `vLLMModel`. `HuggingFaceModel` is for models that can be loaded with `AutoModel` and `HuggingFaceCausalLM` is for models that can be loaded with `AutoModelForCausalLM`. `ChatGLMModel` and `ChatGLMModel2` are for ChatGLM and ChatGLM2 models respectively. `vLLMModel` is for models that can be loaded with the vLLM offline-inference `LLM` class. You can check all model classes in `colossal_eval/models/__init__.py`. If your model requires `trust_remote_code=True`, specify it in the `tokenizer_kwargs` and `model_kwargs` fields.
2. Dataset config. In dataset config, you need to specify dataset name, path and dataset class. Currently, we support zero-shot on datasets MMLU, CMMLU, AGIEval, GAOKAO-Bench, GSM8K and LongBench, and few-shot on datasets MMLU, CMMLU, AGIEval and GSM8K. If you want to enable few-shot, set `few_shot` as true. You can check all dataset classes in `colossal_eval/dataset/__init__.py`.

Once you have all config ready, the program will run inference on all the given datasets on all the given models.
@@ -272,7 +272,42 @@ An example config using model class `HuggingFaceCausalLM` and dataset class `CMM
}
```

-Currently, we support Hugging Face models. The `tokenizer_kwargs` is the arguments used in `AutoTokenizer.from_pretrained()`. The `model_kwargs` is the arguments used in `AutoModel.from_pretrained` or `AutoModelForCausalLM.from_pretrained()`. `few_shot` will be set true if you want to enable few-shot prompting for the dataset. `debug` will be set true if you want to verify whether your prompt is right or wrong.
+An example config using model class `vLLMModel` and dataset class `CMMLUDataset` can be:
+```json
+{
+    "model": [
+        {
+            "name": "model name",
+            "model_class": "vLLMModel",
+            "parameters": {
+                "path": "path to model",
+                "model_max_length": 2048,
+                "tokenizer_path": "",
+                "tokenizer_kwargs": {
+                    "trust_remote_code": true
+                },
+                "model_kwargs": {
+                    "trust_remote_code": true
+                },
+                "prompt_template": "plain",
+                "batch_size": 4
+            }
+        }
+    ],
+    "dataset": [
+        {
+            "name": "dataset name",
+            "dataset_class": "CMMLUDataset",
+            "debug": false,
+            "few_shot": true,
+            "path": "path to original dataset",
+            "save_path": "path to save converted dataset"
+        }
+    ]
+}
+```
+
+Currently, we support Hugging Face models as well as vLLM models. For Hugging Face models, the `tokenizer_kwargs` is the arguments used in `AutoTokenizer.from_pretrained()`. The `model_kwargs` is the arguments used in `AutoModel.from_pretrained()` or `AutoModelForCausalLM.from_pretrained()`. For vLLM models, the `tokenizer_kwargs` and `model_kwargs` are loaded together in the `LLM` class. `few_shot` should be set to true if you want to enable few-shot prompting for the dataset. `debug` should be set to true if you want to verify whether your prompt is right or wrong.
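
As a rough sketch of what "loaded together" means (assuming the public vLLM offline-inference API; the exact wiring inside `vLLMModel` may differ):

```python
# Hedged sketch: merging tokenizer_kwargs and model_kwargs from the config
# above into a single call to vllm's LLM class. Paths are placeholders.
from vllm import LLM, SamplingParams

tokenizer_kwargs = {"trust_remote_code": True}
model_kwargs = {"trust_remote_code": True}

llm = LLM(
    model="path to model",
    tokenizer="path to model",  # or the config's tokenizer_path, if set
    **{**tokenizer_kwargs, **model_kwargs},  # duplicate keys collapse into one
)

outputs = llm.generate(["Question: ...\nAnswer:"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```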

> For the GSM8K dataset, you can set the additional flags `load_train` or `load_reference` in the dataset configuration to true; during the inference process, the program will then calculate the loss summation over all tokens for each data sample. During the evaluation process, you can use the metric `loss_over_all_tokens` to calculate the overall loss and use it for data-leakage evaluation.
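
A hedged sketch of such a GSM8K dataset entry (`load_train` and `load_reference` come from the note above; the `dataset_class` name is an assumption, so check `colossal_eval/dataset/__init__.py` for the real one):

```python
# Illustrative dataset entry for data-leakage evaluation on GSM8K.
gsm8k_dataset_config = {
    "name": "gsm8k",
    "dataset_class": "GSMDataset",  # assumed name; verify in dataset/__init__.py
    "debug": False,
    "few_shot": False,
    "load_train": True,       # also compute loss summation over the training split
    "load_reference": True,   # and over reference samples
    "path": "path to original dataset",
    "save_path": "path to save converted dataset",
}
```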
@@ -287,7 +322,7 @@ torchrun --nproc_per_node=4 inference.py \
--inference_save_path "path to save inference results"
```

-You should specify the path to config file in `config`. You can run the script without specifying `load_dataset` if you already save the converted dataset or otherwise set it to first load the original dataset and save the converted dataset. You should specify the path to save inference results in `inference_save_path`. If you want to use tensor parallel inference, specify the tensor parallel size in `--tp_size` and the process will automatically calculate data parallel size.
+You should specify the path to the config file in `config`. You can run the script without specifying `load_dataset` if you have already saved the converted dataset, or otherwise set it to first load the original dataset and save the converted dataset. You should specify the path to save inference results in `inference_save_path`. If you want to use tensor parallel inference, specify the tensor parallel size in `--tp_size` and the process will automatically calculate the data parallel size (currently not supported for `vLLMModel`).

### Evaluation

@@ -530,10 +565,6 @@ class CustomizedModel(BaseModel):

Once you have successfully added your own model, you can specify your model class in your inference config.
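
For reference, a minimal skeleton of such a class might look like the following; the method signatures are inferred from the `chatglm.py` and `huggingface.py` changes in this commit, so treat it as a sketch rather than the actual `BaseModel` contract:

```python
from typing import List

from colossal_eval.models import BaseModel


class CustomizedModel(BaseModel):
    """Sketch of a custom model wrapper for ColossalEval."""

    def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str]:
        # Return one decoded completion per input prompt.
        raise NotImplementedError

    def get_loss(
        self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool = False
    ) -> List[List[float]]:
        # Loss on target tokens only, or over whole sentences when
        # calculate_overall_loss is True (see inference_kwargs above).
        raise NotImplementedError
```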

-## To do
-
-- [ ] Add visualization code for evaluation results on public dataset
-- [ ] Improve the way to label target tokens

## Citations

2 changes: 1 addition & 1 deletion applications/ColossalEval/colossal_eval/dataset/agieval.py
@@ -47,7 +47,7 @@
"calculate_loss": True,
"all_classes": None,
"language": "Chinese",
"pretrain": False,
"calculate_overall_loss": False,
"max_new_tokens": 32,
}

2 changes: 1 addition & 1 deletion applications/ColossalEval/colossal_eval/dataset/ceval.py
@@ -70,7 +70,7 @@
"calculate_loss": False,
"all_classes": ["A", "B", "C", "D"],
"language": "Chinese",
"pretrain": False,
"calculate_overall_loss": False,
"max_new_tokens": 32,
}

2 changes: 1 addition & 1 deletion applications/ColossalEval/colossal_eval/dataset/cmmlu.py
@@ -81,7 +81,7 @@
"calculate_loss": True,
"all_classes": ["A", "B", "C", "D"],
"language": "Chinese",
"pretrain": False,
"calculate_overall_loss": False,
"max_new_tokens": 32,
}

@@ -12,7 +12,7 @@
"calculate_loss": False,
"all_classes": None,
"language": "Chinese",
"pretrain": False,
"calculate_overall_loss": False,
"max_new_tokens": 256,
}

2 changes: 1 addition & 1 deletion applications/ColossalEval/colossal_eval/dataset/cvalues.py
@@ -15,7 +15,7 @@
"calculate_loss": False,
"all_classes": ["A", "B"],
"language": LANGUAGE,
"pretrain": False,
"calculate_overall_loss": False,
"max_new_tokens": 32,
}

@@ -36,7 +36,7 @@
"calculate_loss": True,
"all_classes": None,
"language": "Chinese",
"pretrain": False,
"calculate_overall_loss": False,
"max_new_tokens": 32,
}

4 changes: 2 additions & 2 deletions applications/ColossalEval/colossal_eval/dataset/gsm.py
@@ -72,7 +72,7 @@
"calculate_loss": True,
"all_classes": None,
"language": "English",
"pretrain": False,
"calculate_overall_loss": False,
"max_new_tokens": 256,
}

@@ -114,7 +114,7 @@ def load(
dataset[split][subject]["inference_kwargs"] = copy.deepcopy(default_inference_kwargs)

if forward_only:
-dataset[split][subject]["inference_kwargs"]["pretrain"] = True
+dataset[split][subject]["inference_kwargs"]["calculate_overall_loss"] = True

if split == "test" and few_shot:
dataset[split][subject]["inference_kwargs"]["few_shot_data"] = get_few_shot_data()
@@ -60,7 +60,7 @@
"calculate_loss": True,
"all_classes": None,
"language": "Chinese",
"pretrain": False,
"calculate_overall_loss": False,
"max_new_tokens": 32,
}

2 changes: 1 addition & 1 deletion applications/ColossalEval/colossal_eval/dataset/mmlu.py
@@ -11,7 +11,7 @@
"calculate_loss": True,
"all_classes": ["A", "B", "C", "D"],
"language": "English",
"pretrain": False,
"calculate_overall_loss": False,
"max_new_tokens": 32,
}

2 changes: 1 addition & 1 deletion applications/ColossalEval/colossal_eval/dataset/mtbench.py
@@ -14,7 +14,7 @@
"calculate_loss": False,
"all_classes": None,
"language": "English",
"pretrain": False,
"calculate_overall_loss": False,
"max_new_tokens": 1024,
"turns": 2,
}
@@ -28,7 +28,7 @@
"calculate_loss": False,
"all_classes": ["A", "B", "C", "D"],
"language": LANGUAGE,
"pretrain": False,
"calculate_overall_loss": False,
"max_new_tokens": 32,
}

@@ -28,7 +28,7 @@
"calculate_loss": False,
"all_classes": ["A", "B", "C", "D"],
"language": LANGUAGE,
"pretrain": False,
"calculate_overall_loss": False,
"max_new_tokens": 32,
}

3 changes: 2 additions & 1 deletion applications/ColossalEval/colossal_eval/models/__init__.py
@@ -1,5 +1,6 @@
from .base import BaseModel
from .chatglm import ChatGLM2Model, ChatGLMModel
from .huggingface import HuggingFaceCausalLM, HuggingFaceModel
+from .vllm import vLLMModel

-__all__ = ["BaseModel", "HuggingFaceModel", "HuggingFaceCausalLM", "ChatGLMModel", "ChatGLM2Model"]
+__all__ = ["BaseModel", "HuggingFaceModel", "HuggingFaceCausalLM", "ChatGLMModel", "ChatGLM2Model", "vLLMModel"]
4 changes: 2 additions & 2 deletions applications/ColossalEval/colossal_eval/models/chatglm.py
@@ -28,7 +28,7 @@ def _get_truncated_prompts(self, inputs: List[str], max_new_tokens: int) -> List

@torch.no_grad()
def get_loss(
-self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool = False
+self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool = False
) -> List[List[float]]:
"""
Calculate loss only on target tokens.
@@ -225,7 +225,7 @@ def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str

@torch.no_grad()
def get_loss(
-self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool = False
+self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool = False
) -> List[List[float]]:
"""
Calculate loss only on target tokens.
28 changes: 19 additions & 9 deletions applications/ColossalEval/colossal_eval/models/huggingface.py
@@ -105,6 +105,12 @@ def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], tokenizer_kw
elif hasattr(self.tokenizer, "eod_id"):
# Qwen has an eod token "<|endoftext|>".
self.tokenizer.pad_token_id = self.tokenizer.eod_id
+else:
+    self.logger.error("Neither eos_token nor eod_id is available for setting pad_token_id.")
+    raise ValueError(
+        "The tokenizer does not have a pad_token_id, eos_token, or eod_id. "
+        "Please set pad_token_id manually."
+    )

def _load_model(
self, path: str, model_kwargs: dict, peft_path: Optional[str] = None, shard_config: ShardConfig = None
@@ -245,7 +251,7 @@ def _get_input_ids_and_labels_pretrain(self, batch_prompt: List[str]) -> Tuple[L
return input_ids_list, labels_list, bytes_list

def _get_input_ids_and_labels(
-self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool
+self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool
) -> Tuple[List[torch.LongTensor]]:
"""
Get input_ids and labels for the given data.
@@ -258,7 +264,7 @@ def _get_input_ids_and_labels(
Input_ids and labels for the given batch.
"""
-if pretrain:
+if calculate_overall_loss:
batch = []
# Concatenate prompt and target answers.
# You should decide the concatenation character in the corresponding dataset script in dataset folder. For example, in line 119 dataset/gsm.py, the concatenation character is space.
@@ -342,7 +348,7 @@ def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], d
calculate_loss = inference_kwargs["calculate_loss"]
classes = inference_kwargs["all_classes"]
language = inference_kwargs["language"]
-pretrain = inference_kwargs["pretrain"]
+calculate_overall_loss = inference_kwargs["calculate_overall_loss"]
max_new_tokens = inference_kwargs["max_new_tokens"]
few_shot_data = inference_kwargs.get("few_shot_data", None)

@@ -384,12 +390,12 @@ def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], d
self.logger.info("-" * 120)
self.logger.info(batch_prompt[0] + batch_target[0][0])

-if not pretrain:
+if not calculate_overall_loss:
batch_decodes, scores = self.generate(batch_prompt, max_new_tokens)

if calculate_loss:
batch_losses, batch_target_token_nums, batch_bytes_nums = self.get_loss(
-batch_prompt, batch_target, pretrain
+batch_prompt, batch_target, calculate_overall_loss
)

probs = []
@@ -409,7 +415,7 @@ def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], d
]

for j in range(len(batch)):
-if not pretrain:
+if not calculate_overall_loss:
if isinstance(batch[j]["output"], list):
batch[j]["output"].append(batch_decodes[j].strip())
else:
@@ -496,7 +502,9 @@ def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str
return decoded_sequences, scores

@torch.no_grad()
-def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool) -> List[List[float]]:
+def get_loss(
+    self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool
+) -> List[List[float]]:
"""
Calculate loss only on target tokens.
@@ -513,13 +521,15 @@ def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretr
# We don't need to generate new tokens.
# Target answer's length is usually << model_max_length, but we still call it in case.
# We don't call self._get_truncated_prompts for batch_prompt because we need target answer's length first to reserve some space for target answer's tokens.
-if not pretrain:
+if not calculate_overall_loss:
batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target]

# Get the number of target answers for different questions
batch_target_nums = [len(prompt_target) for prompt_target in batch_target]

-input_ids_list, labels_list, bytes_list = self._get_input_ids_and_labels(batch_prompt, batch_target, pretrain)
+input_ids_list, labels_list, bytes_list = self._get_input_ids_and_labels(
+    batch_prompt, batch_target, calculate_overall_loss
+)

# Because of multiple target answers, the final batch size may be greater than self.batch_size.
# We will generate new batches.