-
Notifications
You must be signed in to change notification settings - Fork 93
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add DeepSeek-V2-MoE Mcore implementation (#240)
Co-authored-by: 同润 <jerry.lp@alibaba-inc.com>
- Loading branch information
1 parent
d0c9c0b
commit a002437
Showing
30 changed files
with
4,815 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,226 @@ | ||
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. | ||
"""Pretrain GPT.""" | ||
|
||
import os | ||
import torch | ||
from functools import partial | ||
from typing import Union | ||
from megatron.training import get_args | ||
from megatron.training import print_rank_0 | ||
from megatron.training import get_timers | ||
from megatron.core import mpu | ||
from megatron.core.enums import ModelType | ||
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder | ||
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig | ||
from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset | ||
import megatron.legacy.model | ||
from megatron.training import pretrain | ||
from megatron.core.transformer.spec_utils import import_module | ||
from megatron.training.utils import ( | ||
get_batch_on_this_cp_rank, | ||
get_batch_on_this_tp_rank, | ||
average_losses_across_data_parallel_group | ||
) | ||
from megatron.training.arguments import core_transformer_config_from_args | ||
|
||
from megatron_patch.data.utils import get_batch_on_this_tp_rank_original | ||
from megatron_patch.data import build_pretrain_dataset_from_original | ||
from megatron_patch.model.deepseek_v2.layer_specs import get_gpt_layer_with_transformer_engine_spec | ||
from megatron_patch.model.deepseek_v2.model import GPTModel | ||
from megatron_patch.model.deepseek_v2.transformer_config import DeepSeekV2TransformerConfig | ||
from megatron_patch.arguments import get_patch_args | ||
from megatron_patch.tokenizer import get_tokenizer, build_tokenizer | ||
import torch._dynamo

# Ask TorchDynamo to suppress compilation errors rather than raise them
# (per the PyTorch docs, the affected code then falls back to eager mode).
torch._dynamo.config.suppress_errors = True
|
||
def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
    """Construct the GPT model used for pretraining.

    When ``args.use_mcore_models`` is set, the Megatron-Core ``GPTModel`` from
    the DeepSeek-V2 patch is returned; otherwise the legacy Megatron
    ``GPTModel`` is built instead.

    Args:
        pre_process (bool, optional): Set to True if this stage needs to compute embeddings. Defaults to True.
        post_process (bool, optional): Set to True if this stage needs to compute output logits/loss. Defaults to True.

    Returns:
        Union[GPTModel, megatron.legacy.model.GPTModel]: The constructed model.
    """
    args = get_args()
    # The tokenizer is built here, before any model is constructed.
    build_tokenizer(args)
    print_rank_0('building GPT model ...')

    # Derive the transformer config from the parsed arguments, using the
    # DeepSeek-V2 config class.
    config = core_transformer_config_from_args(args, DeepSeekV2TransformerConfig)

    if not args.use_mcore_models:
        # Legacy path: context parallelism is not available here.
        assert args.context_parallel_size == 1, "Context parallelism is only supported with Megatron Core!"
        return megatron.legacy.model.GPTModel(
            config,
            num_tokentypes=0,
            parallel_output=True,
            pre_process=pre_process,
            post_process=post_process,
        )

    # Megatron-Core path: build the layer spec first, then the model.
    layer_spec = get_gpt_layer_with_transformer_engine_spec(
        args.num_experts,
        args.moe_grouped_gemm,
        qk_layernorm=True,
    )
    return GPTModel(
        config=config,
        transformer_layer_spec=layer_spec,
        vocab_size=args.padded_vocab_size,
        max_sequence_length=args.max_position_embeddings,
        pre_process=pre_process,
        post_process=post_process,
        fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
        parallel_output=True,
        share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
        position_embedding_type=args.position_embedding_type,
        rotary_percent=args.rotary_percent,
        rotary_base=args.rotary_base,
        seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor,
    )
|
||
|
||
def get_batch(data_iterator):
    """Produce the next batch for the current rank.

    Args:
        data_iterator: Iterator handed to the tensor-parallel batch loader.

    Returns:
        A tuple of five ``None`` values on pipeline stages that are neither
        first nor last; otherwise the values of the batch dict after the
        tensor-parallel fetch and the context-parallel slicing.

    Raises:
        ValueError: If ``args.dataset`` contains neither "-Raw" nor "-Idxmap".
    """
    # TODO: this is pretty hacky, find a better way
    if not (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()):
        return None, None, None, None, None

    args = get_args()

    # Pick the loader that fetches the batch for this TP rank; the choice
    # depends on which dataset flavour was requested.
    if "-Raw" in args.dataset:
        tp_loader = get_batch_on_this_tp_rank_original
    elif "-Idxmap" in args.dataset:
        tp_loader = get_batch_on_this_tp_rank
    else:
        raise ValueError("please set correct --dataset ")

    tp_batch = tp_loader(data_iterator)
    # Slice the batch along the sequence dimension for context parallelism.
    cp_batch = get_batch_on_this_cp_rank(tp_batch)

    return cp_batch.values()
|
||
def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
    """Compute the masked mean loss and the value reported for logging.

    Args:
        loss_mask (torch.Tensor): Mask applied to the losses; flattened and cast to float here.
        output_tensor (torch.Tensor): The tensor with the losses.

    Returns:
        A tuple ``(loss * context_parallel_size, {'lm loss': averaged})`` where
        ``averaged`` is the first entry returned by
        ``average_losses_across_data_parallel_group``.
    """
    args = get_args()

    flat_losses = output_tensor.float().view(-1)
    flat_mask = loss_mask.view(-1).float()
    masked_total = torch.sum(flat_losses * flat_mask)

    if args.context_parallel_size > 1:
        # Sum numerator and denominator over the context-parallel group
        # before dividing.
        totals = torch.cat([masked_total.view(1), flat_mask.sum().view(1)])
        torch.distributed.all_reduce(totals, group=mpu.get_context_parallel_group())
        loss = totals[0] / totals[1]
    else:
        loss = masked_total / flat_mask.sum()

    # Check individual rank losses are not NaN prior to DP all-reduce.
    if args.check_for_nan_in_loss_and_grad:
        global_rank = torch.distributed.get_rank()
        assert not loss.isnan(), (
            f'Rank {global_rank}: found NaN in local forward loss calculation. '
            f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}'
        )

    # Reduce loss for logging.
    averaged_loss = average_losses_across_data_parallel_group([loss])

    scaled_loss = loss * args.context_parallel_size
    report = {'lm loss': averaged_loss[0]}
    return scaled_loss, report
|
||
|
||
def forward_step(data_iterator, model: GPTModel):
    """Run one forward pass and return the output with its loss function.

    Args:
        data_iterator: Input data iterator.
        model (GPTModel): The GPT model to call.

    Returns:
        A tuple of the model output and ``loss_func`` with this batch's loss
        mask already bound as its first argument.
    """
    timers = get_timers()

    # Time only the batch fetch.
    timers('batch-generator', log_level=2).start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
    timers('batch-generator').stop()

    output_tensor = model(tokens, position_ids, attention_mask, labels=labels)

    masked_loss_func = partial(loss_func, loss_mask)
    return output_tensor, masked_loss_func
|
||
|
||
def is_dataset_built_on_rank():
    """Report whether the current rank should build the dataset.

    True only on the first or last pipeline stage, and only for
    tensor-model-parallel rank 0.
    """
    on_boundary_stage = mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()
    return on_boundary_stage and mpu.get_tensor_model_parallel_rank() == 0
|
||
|
||
def core_gpt_dataset_config_from_args(args):
    """Build a ``GPTDatasetConfig`` from the parsed arguments.

    Args:
        args: Parsed argument namespace supplying the seed, sequence length,
            data paths, split and dataset options read below.

    Returns:
        GPTDatasetConfig: Config populated from ``args`` and the tokenizer
        returned by ``get_tokenizer()``.
    """
    tokenizer = get_tokenizer()

    # Collect the constructor arguments first so the mapping from args to
    # config fields is visible in one place.
    config_kwargs = {
        "random_seed": args.seed,
        "sequence_length": args.seq_length,
        "blend": args.data_path,
        "blend_per_split": [args.train_data_path, args.valid_data_path, args.test_data_path],
        "split": args.split,
        "path_to_cache": args.data_cache_path,
        "mock": args.mock_data,
        "mmap_bin_files": args.mmap_bin_files,
        "tokenizer": tokenizer,
        "reset_position_ids": args.reset_position_ids,
        "reset_attention_mask": args.reset_attention_mask,
        "eod_mask_loss": args.eod_mask_loss,
        "create_attention_mask": args.create_attention_mask_in_dataloader,
    }
    return GPTDatasetConfig(**config_kwargs)
|
||
|
||
def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build the train, validation and test datasets.

    Args:
        train_val_test_num_samples: A list containing the number of samples in train, test and validation.

    Returns:
        The ``(train_ds, valid_ds, test_ds)`` triple.
    """
    args = get_args()
    print_rank_0("> building train, validation, and test datasets for GPT ...")

    if "-Raw" in args.dataset:
        # Raw datasets come from the patch's own builder.
        datasets = build_pretrain_dataset_from_original(args.dataset)
    else:
        config = core_gpt_dataset_config_from_args(args)
        dataset_type = MockGPTDataset if config.mock else GPTDataset
        builder = BlendedMegatronDatasetBuilder(
            dataset_type,
            train_val_test_num_samples,
            is_dataset_built_on_rank,
            config,
        )
        datasets = builder.build()

    train_ds, valid_ds, test_ds = datasets

    print_rank_0("> finished creating GPT datasets ...")

    return train_ds, valid_ds, test_ds
|
||
|
||
if __name__ == "__main__":
    # Attribute set on the provider function before it is handed to
    # pretrain(). NOTE(review): presumably it tells the training loop that
    # datasets are built on several ranks; confirm against megatron.training.
    train_valid_test_datasets_provider.is_distributed = True

    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        ModelType.encoder_or_decoder,
        forward_step,
        extra_args_provider=get_patch_args,
    )
Oops, something went wrong.