From 84e0355191cfd6fcc5951151c53f394cda8bfd8d Mon Sep 17 00:00:00 2001 From: Jerry Li Date: Wed, 27 Dec 2023 16:16:18 +0800 Subject: [PATCH] Add Mixtral MoE and Qwen-vl (#105) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 同润 --- README.md | 2 + README_zh-CN.md | 2 + .../mistral/evaluate_huggingface_mistral.py | 14 +- examples/mistral/evaluate_megatron_mistral.py | 12 +- .../run_evaluate_huggingface_mistral.sh | 2 +- .../mistral/run_evaluate_megatron_mistral.sh | 10 +- .../mixtral/evaluate_huggingface_mixtral.py | 136 ++ examples/mixtral/evaluate_megatron_mixtral.py | 173 ++ examples/mixtral/finetune_megatron_mixtral.py | 92 + examples/mixtral/pretrain_megatron_mixtral.py | 129 ++ .../run_evaluate_huggingface_mixtral.sh | 103 + .../mixtral/run_evaluate_megatron_mixtral.sh | 167 ++ .../mixtral/run_finetune_megatron_mixtral.sh | 201 ++ .../mixtral/run_pretrain_megatron_mixtral.sh | 218 ++ examples/qwen_vl/finetune_megatron_qwen_vl.py | 90 + .../qwen_vl/run_finetune_megatron_qwen_vl.sh | 221 ++ megatron_patch/arguments.py | 51 +- megatron_patch/checkpointing.py | 142 +- megatron_patch/data/__init__.py | 26 +- megatron_patch/data/llama.py | 2 +- megatron_patch/data/mistral.py | 2 +- megatron_patch/data/qwen_vl.py | 122 ++ megatron_patch/distributed.py | 217 ++ megatron_patch/expert_parallel_state.py | 79 + megatron_patch/initialize.py | 159 ++ megatron_patch/model/llava/language_model.py | 43 +- megatron_patch/model/mixtral/__init__.py | 13 + megatron_patch/model/mixtral/all2all.py | 62 + megatron_patch/model/mixtral/experts.py | 45 + megatron_patch/model/mixtral/gpt_model.py | 132 ++ .../model/mixtral/language_model.py | 694 ++++++ megatron_patch/model/mixtral/layer.py | 138 ++ .../model/mixtral/moe_parallel_linear.py | 470 +++++ megatron_patch/model/mixtral/router.py | 787 +++++++ megatron_patch/model/mixtral/transformer.py | 1750 +++++++++++++++ megatron_patch/model/qwen_vl/__init__.py | 13 + megatron_patch/model/qwen_vl/gpt_model.py | 133 ++ .../model/qwen_vl/language_model.py | 680 ++++++ megatron_patch/model/qwen_vl/transformer.py | 1870 +++++++++++++++++ megatron_patch/model/qwen_vl/visual.py | 425 ++++ megatron_patch/optimizer/__init__.py | 93 + megatron_patch/optimizer/distrib_optimizer.py | 707 +++++++ megatron_patch/tokenizer/__init__.py | 26 +- .../tokenizer/tokenization_mistral.py | 36 - .../tokenizer/tokenization_qwen_vl.py | 587 ++++++ megatron_patch/training.py | 45 +- ...eckpoint_reshaping_and_interoperability.py | 739 +++++++ .../mixtral/model_convertor.sh | 40 + 48 files changed, 11783 insertions(+), 117 deletions(-) create mode 100644 examples/mixtral/evaluate_huggingface_mixtral.py create mode 100644 examples/mixtral/evaluate_megatron_mixtral.py create mode 100644 examples/mixtral/finetune_megatron_mixtral.py create mode 100644 examples/mixtral/pretrain_megatron_mixtral.py create mode 100644 examples/mixtral/run_evaluate_huggingface_mixtral.sh create mode 100644 examples/mixtral/run_evaluate_megatron_mixtral.sh create mode 100644 examples/mixtral/run_finetune_megatron_mixtral.sh create mode 100644 examples/mixtral/run_pretrain_megatron_mixtral.sh create mode 100644 examples/qwen_vl/finetune_megatron_qwen_vl.py create mode 100644 examples/qwen_vl/run_finetune_megatron_qwen_vl.sh create mode 100644 megatron_patch/data/qwen_vl.py create mode 100644 megatron_patch/distributed.py create mode 100644 megatron_patch/expert_parallel_state.py create mode 100644 megatron_patch/initialize.py create mode 100644 megatron_patch/model/mixtral/__init__.py create mode 100644 megatron_patch/model/mixtral/all2all.py create mode 100644 megatron_patch/model/mixtral/experts.py create mode 100644 megatron_patch/model/mixtral/gpt_model.py create mode 100644 megatron_patch/model/mixtral/language_model.py create mode 100644 megatron_patch/model/mixtral/layer.py create mode 100644 megatron_patch/model/mixtral/moe_parallel_linear.py create mode 100644 megatron_patch/model/mixtral/router.py create mode 100644 megatron_patch/model/mixtral/transformer.py create mode 100644 megatron_patch/model/qwen_vl/__init__.py create mode 100644 megatron_patch/model/qwen_vl/gpt_model.py create mode 100644 megatron_patch/model/qwen_vl/language_model.py create mode 100644 megatron_patch/model/qwen_vl/transformer.py create mode 100644 megatron_patch/model/qwen_vl/visual.py create mode 100644 megatron_patch/optimizer/__init__.py create mode 100644 megatron_patch/optimizer/distrib_optimizer.py delete mode 100644 megatron_patch/tokenizer/tokenization_mistral.py create mode 100644 megatron_patch/tokenizer/tokenization_qwen_vl.py create mode 100644 toolkits/model_checkpoints_convertor/mixtral/checkpoint_reshaping_and_interoperability.py create mode 100644 toolkits/model_checkpoints_convertor/mixtral/model_convertor.sh diff --git a/README.md b/README.md index 04edf349..c5048176 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,8 @@ English | [简体中文](./README_zh-CN.md) Pai-Megatron-Patch (https://github.com/alibaba/Pai-Megatron-Patch) is a deep learning training toolkit built for developers to train and predict large language models (LLMs) by using MegatronLM framework easily. With the continuous development of LLMs, the model structure and scale are rapidly evolving. Although these models can be conveniently manufactured using Transformers or DeepSpeed training framework, the training efficiency is comparably low. This phenomenon becomes even severer when the model scale exceeds 10 billion. The primary objective of Pai-Megatron-Patch is to effectively utilize the computational power of GPUs for LLM. This tool allows convenient training of commonly used LLM with all the accelerating techniques provided by Megatron-LM. What's New: +- **Support fine-tuning mixtral-8x7b moe model by using Megatron-LM.** [🔥🔥 2023.12.27] +- **Support fine-tuning qwen-vl multimodel by using Megatron-LM.** [🔥🔥 2023.12.15] - **Support fine-tuning LLava multimodel by using Megatron-LM.** [🔥🔥 2023.12.01] - **Support fine-tuning deepseek model by using Megatron-LM.** [🔥🔥 2023.11.24] - **Support fine-tuning qwen-72B model by using Megatron-LM.** [🔥🔥 2023.11.23] diff --git a/README_zh-CN.md b/README_zh-CN.md index e8877e3b..ab11a6d8 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -17,6 +17,8 @@ Pai-Megatron-Patch工具是阿里云机器学习平台PAI算法团队研发, - [阿里云PAI获得FewCLUE基于大模型的小样本学习双料冠军](https://developer.aliyun.com/article/788081?spm=a2c6h.12873639.article-detail.17.11c5383cHpFZks&tlog=yuekan_8) 新功能: +- **支持用MegatronLM框架训练mixtral-8x7b MoE稀疏模型** [🔥🔥 2023.12.27] +- **支持用MegatronLM框架微调多模态大模型qwen-vl.** [🔥🔥 2023.12.15] - **支持用MegatronLM框架微调多模态大模型LLava.** [🔥🔥 2023.12.01] - **支持用MegatronLM框架训练deepseek系列模型.** [🔥🔥 2023.11.24] - **支持用MegatronLM框架微调qwen-72B模型.** [🔥🔥 2023.11.23] diff --git a/examples/mistral/evaluate_huggingface_mistral.py b/examples/mistral/evaluate_huggingface_mistral.py index ecb784cf..307f4a68 100644 --- a/examples/mistral/evaluate_huggingface_mistral.py +++ b/examples/mistral/evaluate_huggingface_mistral.py @@ -12,14 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math import torch from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP from megatron.core.enums import ModelType from megatron import get_args from megatron import print_rank_0 -from megatron import is_last_rank from megatron.core import parallel_state from megatron.core.pipeline_parallel.p2p_communication import send_forward from megatron.initialize import initialize_megatron @@ -28,23 +26,19 @@ from megatron.utils import unwrap_model from megatron.arguments import core_transformer_config_from_args -from megatron_patch.data.evaluate_dataset import build_evaluation_dataset +from megatron_patch.data import build_evaluation_dataset from megatron_patch.finetune_utils import build_data_loader -from megatron_patch.tokenizer import build_tokenizer from megatron_patch.tokenizer import get_tokenizer from megatron_patch.training import get_model from megatron_patch.arguments import get_tasks_args -from megatron_patch.model.mistral.modeling_mistral import MistralForCausalLM - +from transformers import AutoModelForCausalLM def get_model_provider(): """Based on evaluation metric set the parallel-output flag and return the model provider.""" def model_provider(pre_process=True, post_process=True): args = get_args() - tokenizer = build_tokenizer(args) - model = MistralForCausalLM.from_pretrained(args.load, - trust_remote_code=False) + model = AutoModelForCausalLM.from_pretrained(args.load, device_map="auto") return model return model_provider @@ -56,7 +50,7 @@ def forward_step(batch, model): # Get the batch. input_ids = batch['input_ids'].long().cuda() labels = batch['labels'].long().cuda() - labels[labels == -1] = -100 + labels[labels == 0] = -100 attention_mask = input_ids.ne(tokenizer.pad_token_id) # Tell the model what our actual batch size will be diff --git a/examples/mistral/evaluate_megatron_mistral.py b/examples/mistral/evaluate_megatron_mistral.py index 140fd463..2f1a70d1 100644 --- a/examples/mistral/evaluate_megatron_mistral.py +++ b/examples/mistral/evaluate_megatron_mistral.py @@ -13,9 +13,6 @@ # limitations under the License. import torch -from megatron_patch.data import \ - build_pretrain_dataset_from_original, build_pretrain_dataset_from_idxmap - from megatron.core.enums import ModelType from megatron import get_args from megatron import print_rank_0 @@ -27,6 +24,7 @@ from megatron.utils import get_ltor_masks_and_position_ids from megatron.arguments import core_transformer_config_from_args +from megatron_patch.data import build_evaluation_dataset from megatron_patch.checkpointing import load_checkpoint from megatron_patch.finetune_utils import build_data_loader from megatron_patch.model.mistral.gpt_model import GPTModel @@ -67,8 +65,8 @@ def get_batch(batch): tokens = tokens[:, :-1].contiguous() labels = labels[:, 1:].contiguous() - - attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + attention_mask = tokens.ne(tokenizer.pad_token_id) + _, loss_mask, position_ids = get_ltor_masks_and_position_ids( labels, tokenizer.pad_token_id, args.reset_position_ids, @@ -145,9 +143,7 @@ def main(): exit() # Data stuff. - #dataset = build_evaluation_dataset(args.dataset) - dataset, _, _ = \ - build_pretrain_dataset_from_original(args.dataset) + dataset = build_evaluation_dataset(args.dataset) dataloader = build_data_loader(dataset, args.micro_batch_size, args.num_workers, diff --git a/examples/mistral/run_evaluate_huggingface_mistral.sh b/examples/mistral/run_evaluate_huggingface_mistral.sh index 26fc99f6..7b71866d 100644 --- a/examples/mistral/run_evaluate_huggingface_mistral.sh +++ b/examples/mistral/run_evaluate_huggingface_mistral.sh @@ -74,7 +74,7 @@ fi megatron_options=" \ --transformer-type huggingface \ - --data-path ${DATASET_PATH} + --valid-data-path ${DATASET_PATH} --micro-batch-size ${BATCH_SIZE} \ --num-layers ${NUM_LAYERS} \ --hidden-size ${HIDDEN_SIZE} \ diff --git a/examples/mistral/run_evaluate_megatron_mistral.sh b/examples/mistral/run_evaluate_megatron_mistral.sh index 13c8f300..c6358ead 100644 --- a/examples/mistral/run_evaluate_megatron_mistral.sh +++ b/examples/mistral/run_evaluate_megatron_mistral.sh @@ -1,5 +1,5 @@ #!/bin/bash -#sh run_evaluate_megatron_mistral.sh dsw /workspace/Pai-Megatron-Patch 7B 1 80 80 0 bf16 2 1 sel true false true false /mnt/llama2-datasets/alpaca_data.json /mnt/mistral-ckpts/Mistral-7B-v0.1-to-mg-tp2-pp1/ +#sh run_evaluate_megatron_mistral.sh dsw ../.. 7B 1 81 81 0 bf16 2 1 sel true false true false /mnt/llama2-datasets/alpaca_data.json /mnt/mistral-ckpts/Mistral-7B-v0.1-to-mg-tp2-pp1/ set -e ENV=$1 MEGATRON_PATCH_PATH=$2 @@ -7,12 +7,12 @@ MEGATRON_PATH=${MEGATRON_PATCH_PATH}/Megatron-LM-main export PYTHONPATH=${MEGATRON_PATH}:${MEGATRON_PATCH_PATH}:$PYTHONPATH export CUDA_DEVICE_MAX_CONNECTIONS=1 if [ $ENV = dsw ]; then -export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export CUDA_VISIBLE_DEVICES=0 MASTER_ADDR=localhost MASTER_PORT=$(shuf -n 1 -i 10000-65535) NNODES=1 NODE_RANK=0 -GPUS_PER_NODE=8 +GPUS_PER_NODE=1 elif [ $ENV = dlc ]; then @@ -125,7 +125,7 @@ fi megatron_options=" \ - --train-data-path ${DATASET_PATH} + --valid-data-path ${DATASET_PATH} --micro-batch-size ${BATCH_SIZE} \ --num-layers ${NUM_LAYERS} \ --hidden-size ${HIDDEN_SIZE} \ @@ -145,7 +145,7 @@ megatron_options=" \ --max-padding-length ${PAD_LEN} \ --extra-vocab-size ${EXTRA_VOCAB_SIZE} \ --patch-tokenizer-type MistralTokenizer \ - --dataset LLama-Pretrain-Raw \ + --dataset Mistral-SFT \ --sliding-window ${SLW} \ --swiglu \ --normalization RMSNorm \ diff --git a/examples/mixtral/evaluate_huggingface_mixtral.py b/examples/mixtral/evaluate_huggingface_mixtral.py new file mode 100644 index 00000000..63e52a9b --- /dev/null +++ b/examples/mixtral/evaluate_huggingface_mixtral.py @@ -0,0 +1,136 @@ +# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP + +from megatron.core.enums import ModelType +from megatron import get_args +from megatron import print_rank_0 +from megatron.core import parallel_state +from megatron.core.pipeline_parallel.p2p_communication import send_forward +from megatron.initialize import initialize_megatron +from megatron.model import DistributedDataParallel as LocalDDP +from megatron.model import Float16Module +from megatron.utils import unwrap_model +from megatron.arguments import core_transformer_config_from_args + +from megatron_patch.data import build_evaluation_dataset +from megatron_patch.finetune_utils import build_data_loader +from megatron_patch.tokenizer import get_tokenizer +from megatron_patch.training import get_model +from megatron_patch.arguments import get_tasks_args +from transformers import AutoModelForCausalLM + +def get_model_provider(): + """Based on evaluation metric set the parallel-output flag and + return the model provider.""" + def model_provider(pre_process=True, post_process=True): + args = get_args() + """ + from accelerate import load_checkpoint_and_dispatch + from accelerate import init_empty_weights + with init_empty_weights(): + config = MixtralConfig() + model = MixtralForCausalLM(config=config) + model = load_checkpoint_and_dispatch(model, checkpoint=args.load, device_map=device_map, dtype=torch.bfloat16) + """ + model = AutoModelForCausalLM.from_pretrained(args.load, torch_dtype=torch.bfloat16, device_map="auto") + return model + + return model_provider + + +def forward_step(batch, model): + """Forward step.""" + tokenizer = get_tokenizer() + # Get the batch. + input_ids = batch['input_ids'].long().cuda() + labels = batch['labels'].long().cuda() + labels[labels == 0] = -100 + attention_mask = input_ids.ne(tokenizer.pad_token_id) + + # Tell the model what our actual batch size will be + args = get_args() + + # Forward pass through the model. + unwrapped_model = unwrap_model(model, (torchDDP, LocalDDP, Float16Module)) + output = unwrapped_model(input_ids=input_ids, + labels=labels, + attention_mask=attention_mask) + config = core_transformer_config_from_args(args) + send_forward(output, config) + if parallel_state.is_pipeline_last_stage(): + print_rank_0(output.loss) + return output.loss + + return None + + +def evaluate(data_loader, model): + """Evaluation.""" + args = get_args() + + # Turn on evaluation mode which disables dropout. + model.eval() + + total_output = 0.0 + with torch.no_grad(): + # For all the batches in the dataset. + for iteration, batch in enumerate(data_loader): + if iteration % args.log_interval == 0: + print_rank_0('> working on iteration: {}'.format(iteration)) + # Forward evaluation. + output = forward_step(batch, model) + + # Reduce across processes. + if parallel_state.is_pipeline_last_stage(): + torch.distributed.all_reduce( + output, group=parallel_state.get_data_parallel_group()) + + total_output += output + + return total_output + +def main(): + """Main program.""" + args = get_args() + if args.num_layers_per_virtual_pipeline_stage is not None: + print('Interleaved pipeline schedule ' + 'is not yet supported for text generation.') + exit() + + # Set up model and load checkpoint. + model = get_model(get_model_provider(), + model_type=ModelType.encoder_or_decoder, + wrap_with_ddp=False) + + assert len(model) == 1, 'Above condition should have caught this' + model = model[0] + + # Data stuff. + dataset = build_evaluation_dataset(args.dataset) + dataloader = build_data_loader(dataset, + args.micro_batch_size, + args.num_workers, + drop_last=False) + + # Run evaluation. + evaluate(dataloader, model) + print_rank_0('done :-)') + + +if __name__ == '__main__': + initialize_megatron(extra_args_provider=get_tasks_args) + main() diff --git a/examples/mixtral/evaluate_megatron_mixtral.py b/examples/mixtral/evaluate_megatron_mixtral.py new file mode 100644 index 00000000..7409bfba --- /dev/null +++ b/examples/mixtral/evaluate_megatron_mixtral.py @@ -0,0 +1,173 @@ +# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from megatron.core.enums import ModelType +from megatron import get_args +from megatron import print_rank_0 +from megatron.core import parallel_state, tensor_parallel +from megatron.core.pipeline_parallel.p2p_communication import recv_forward +from megatron.core.pipeline_parallel.p2p_communication import send_forward +from megatron.utils import unwrap_model +from megatron.utils import get_ltor_masks_and_position_ids +from megatron.arguments import core_transformer_config_from_args + +from megatron_patch.data import build_evaluation_dataset +from megatron_patch.checkpointing import load_checkpoint +from megatron_patch.finetune_utils import build_data_loader +from megatron_patch.model.mixtral.gpt_model import GPTModel +from megatron_patch.arguments import get_tasks_args +from megatron_patch.tokenizer import get_tokenizer +from megatron_patch.training import get_model +from megatron_patch.initialize import initialize_megatron + +def get_model_provider(): + def model_provider(pre_process=True, post_process=True): + """Build the model.""" + args = get_args() + config = core_transformer_config_from_args(args) + model = GPTModel( + config, + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process + ) + + return model + + return model_provider + +def get_batch(batch): + """Generate a batch""" + args = get_args() + tokenizer = get_tokenizer() + + # Items and their type. + keys = ['input_ids', 'labels'] + datatype = torch.int64 + + data_b = tensor_parallel.broadcast_data(keys, batch, datatype) + + tokens = data_b['input_ids'].long().cuda().contiguous() + labels = data_b['labels'].long().cuda().contiguous() + + tokens = tokens[:, :-1].contiguous() + labels = labels[:, 1:].contiguous() + attention_mask = tokens.ne(tokenizer.pad_token_id) + _, loss_mask, position_ids = get_ltor_masks_and_position_ids( + labels, + tokenizer.pad_token_id, + args.reset_position_ids, + args.reset_attention_mask, + True) + + return tokens, labels, loss_mask, attention_mask, position_ids + +def forward_step(batch, model): + """Forward step.""" + + tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + batch) + + # Tell the model what our actual batch size will be + args = get_args() + args.micro_batch_size = len(labels) + config = core_transformer_config_from_args(args) + tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size) + input_tensor = recv_forward(tensor_shape, config) + + # Forward pass through the model. + unwrapped_model = unwrap_model(model) + unwrapped_model.set_input_tensor(input_tensor) + output = model(tokens, position_ids, attention_mask) + send_forward(output, config) + #if parallel_state.is_pipeline_last_stage(): + if output.shape[-1] != args.hidden_size: + loss_mask = loss_mask.view(-1).float() + # For loss, return the unreduced loss. + losses = tensor_parallel.vocab_parallel_cross_entropy( + output.contiguous().float(), labels.contiguous()) + loss = torch.sum( + losses.view(-1) * loss_mask.contiguous().view(-1).float()) / loss_mask.sum() + print(loss) + print_rank_0(loss) + return loss + + return None + + +def evaluate(data_loader, model): + """Evaluation.""" + args = get_args() + + # Turn on evaluation mode which disables dropout. + model.eval() + + total_output = 0.0 + with torch.no_grad(): + # For all the batches in the dataset. + for iteration, batch in enumerate(data_loader): + if iteration % args.log_interval == 0: + print_rank_0('> working on iteration: {}'.format(iteration)) + # Forward evaluation. + output = forward_step(batch, model) + + # Reduce across processes. + if parallel_state.is_pipeline_last_stage(): + torch.distributed.all_reduce( + output, group=parallel_state.get_data_parallel_group()) + + total_output += output + + return total_output + + +def main(): + """Main program.""" + args = get_args() + if args.num_layers_per_virtual_pipeline_stage is not None: + print('Interleaved pipeline schedule ' + 'is not yet supported for text generation.') + exit() + + # Data stuff. + dataset = build_evaluation_dataset(args.dataset) + dataloader = build_data_loader(dataset, + args.micro_batch_size, + args.num_workers, + drop_last=False) + + + # Set up model and load checkpoint. + model = get_model(get_model_provider(), + model_type=ModelType.encoder_or_decoder, + wrap_with_ddp=False) + + if args.load is not None: + load_checkpoint(model, None, None) + + assert len(model) == 1, 'Above condition should have caught this' + model = model[0] + + + + # Run evaluation. + evaluate(dataloader, model) + print_rank_0('done :-)') + + +if __name__ == '__main__': + initialize_megatron(extra_args_provider=get_tasks_args) + main() diff --git a/examples/mixtral/finetune_megatron_mixtral.py b/examples/mixtral/finetune_megatron_mixtral.py new file mode 100644 index 00000000..e59caf23 --- /dev/null +++ b/examples/mixtral/finetune_megatron_mixtral.py @@ -0,0 +1,92 @@ +# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +import torch + +from megatron import get_args +from megatron.core import parallel_state, tensor_parallel +from megatron.initialize import initialize_megatron +from megatron.utils import average_losses_across_data_parallel_group +from megatron.utils import get_ltor_masks_and_position_ids +from megatron_patch.data import build_finetune_dataset +from megatron_patch.finetune_utils import finetune +from megatron_patch.model.mixtral.gpt_model import GPTModel +from megatron_patch.tokenizer import get_tokenizer +from megatron_patch.arguments import get_tasks_args +from megatron.arguments import core_transformer_config_from_args + + +def model_provider(pre_process=True, post_process=True): + config = core_transformer_config_from_args(get_args()) + model = GPTModel( + config, + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process + ) + return model + + +def train_valid_datasets_provider(): + args = get_args() + train_dataset, valid_dataset = build_finetune_dataset(args.dataset) + return train_dataset, valid_dataset + + +def forward_step(data_iterator, model): + args = get_args() + tokenizer = get_tokenizer() + + try: + data_iterator = next(data_iterator) + except BaseException: + data_iterator = data_iterator + + tokens = data_iterator['input_ids'].long().cuda().contiguous() + labels = data_iterator['labels'].long().cuda().contiguous() + + tokens = tokens[:, :-1].contiguous() + labels = labels[:, 1:].contiguous() + + attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + labels, + tokenizer.pad_token_id, + args.reset_position_ids, + args.reset_attention_mask, + True) + + logits = model(input_ids=tokens, + position_ids=position_ids, + attention_mask=attention_mask) + + def loss_func(loss_mask, logits): + losses = tensor_parallel.vocab_parallel_cross_entropy( + logits.contiguous().float(), labels.contiguous()) + loss_mask = loss_mask.view(-1).float() + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() + averaged_loss = average_losses_across_data_parallel_group([loss]) + return loss, {'lm loss': averaged_loss[0]} + + return logits, partial(loss_func, loss_mask) + + +if __name__ == '__main__': + + initialize_megatron(extra_args_provider=get_tasks_args) + + finetune(train_valid_datasets_provider=train_valid_datasets_provider, + model_provider=model_provider, + forward_step=forward_step) diff --git a/examples/mixtral/pretrain_megatron_mixtral.py b/examples/mixtral/pretrain_megatron_mixtral.py new file mode 100644 index 00000000..0eb8d18a --- /dev/null +++ b/examples/mixtral/pretrain_megatron_mixtral.py @@ -0,0 +1,129 @@ +# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +import torch +import os + +from megatron.core.enums import ModelType +from megatron.utils import get_ltor_masks_and_position_ids +from megatron.arguments import core_transformer_config_from_args +from megatron import get_args +from megatron import get_timers +from megatron.core import tensor_parallel +from megatron.utils import average_losses_across_data_parallel_group + +from megatron_patch.data import \ + build_pretrain_dataset_from_original, build_pretrain_dataset_from_idxmap +from megatron_patch.model.mixtral.gpt_model import GPTModel +from megatron_patch.tokenizer import get_tokenizer, build_tokenizer +from megatron_patch.training import pretrain +from megatron_patch.arguments import get_tasks_args + + +def model_provider(pre_process=True, post_process=True): + args = get_args() + build_tokenizer(args) + config = core_transformer_config_from_args(get_args()) + model = GPTModel( + config, + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process + ) + return model + +def get_batch(data_iterator): + """Generate a batch""" + args = get_args() + tokenizer = get_tokenizer() + datatype = torch.int64 + + keys = ['input_ids', 'labels'] + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + data_b = tensor_parallel.broadcast_data(keys, data, datatype) + tokens_ = data_b['input_ids'].long() + + labels = tokens_[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() + attention_mask = tokens.ne(tokenizer.pad_token_id) + # Get the masks and postition ids. + _, loss_mask, position_ids = get_ltor_masks_and_position_ids( + tokens, + tokenizer.pad_token_id, + args.reset_position_ids, + args.reset_attention_mask, + True) + + return tokens, labels, loss_mask, attention_mask, position_ids + +def loss_func(loss_mask, output_tensor): + losses = output_tensor.float() + loss_mask = loss_mask.view(-1).float() + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() + + # Reduce loss for logging. + averaged_loss = average_losses_across_data_parallel_group([loss]) + + return loss, {'lm loss': averaged_loss[0]} + + +def forward_step(data_iterator, model): + """Forward step.""" + timers = get_timers() + + # Get the batch. + timers('batch-generator', log_level=2).start() + tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + data_iterator) + timers('batch-generator').stop() + + output_tensor = model(tokens, position_ids, attention_mask, + labels=labels) + + return output_tensor, partial(loss_func, loss_mask) + + +def train_valid_test_datasets_provider(train_val_test_num_samples): + """Build train, valid, and test datasets.""" + args = get_args() + + if os.path.isfile(args.train_data_path[0]): + train_ds, valid_ds, test_ds = \ + build_pretrain_dataset_from_original(args.dataset) + else: + train_ds, valid_ds, test_ds = \ + build_pretrain_dataset_from_idxmap( + data_prefix=args.train_data_path, + max_padding_length=args.max_padding_length, + dataset_type=args.dataset, + splits_string=args.split, + train_valid_test_num_samples=train_val_test_num_samples, + seed=args.seed, + skip_warmup=(not args.mmap_warmup) + ) + + return train_ds, valid_ds, test_ds + + +if __name__ == "__main__": + pretrain(train_valid_test_datasets_provider, model_provider, + ModelType.encoder_or_decoder, + forward_step, + extra_args_provider=get_tasks_args, + moe=True) diff --git a/examples/mixtral/run_evaluate_huggingface_mixtral.sh b/examples/mixtral/run_evaluate_huggingface_mixtral.sh new file mode 100644 index 00000000..6156bc6b --- /dev/null +++ b/examples/mixtral/run_evaluate_huggingface_mixtral.sh @@ -0,0 +1,103 @@ +#!/bin/bash +# sh run_evaluate_huggingface_mixtral.sh dsw ../.. 7B 1 80 80 0 bf16 /mnt/llama2-datasets/alpaca_data.json /mnt/mixtral-ckpts/Mixtral-8x7B-v0.1 + +set -e +ENV=$1 +MEGATRON_PATCH_PATH=$2 +MEGATRON_PATH=${MEGATRON_PATCH_PATH}/Megatron-LM-main +export PYTHONPATH=${MEGATRON_PATH}:${MEGATRON_PATCH_PATH}:$PYTHONPATH +export CUDA_DEVICE_MAX_CONNECTIONS=1 +if [ $ENV = dsw ]; then +export CUDA_VISIBLE_DEVICES=0 +MASTER_ADDR=localhost +MASTER_PORT=$(shuf -n 1 -i 10000-65535) +NNODES=1 +NODE_RANK=0 +GPUS_PER_NODE=1 + +elif [ $ENV = dlc ]; then + +NNODES=${WORLD_SIZE} +NODE_RANK=${RANK} +GPUS_PER_NODE=${KUBERNETES_CONTAINER_RESOURCE_GPU} + +fi + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + +MODEL_SIZE=$3 +BATCH_SIZE=$4 +SEQ_LEN=$5 +PAD_LEN=$6 +EXTRA_VOCAB_SIZE=$7 +PR=$8 +DATASET_PATH=$9 +PRETRAIN_CHECKPOINT_PATH=${10} + + +if [ $MODEL_SIZE = 7B ]; then + +NUM_LAYERS=32 +HIDDEN_SIZE=4096 +NUM_ATTN_HEADS=32 +INTERMEDIATE_SIZE=11008 + +elif [ $MODEL_SIZE = 13B ]; then + +NUM_LAYERS=40 +HIDDEN_SIZE=5120 +NUM_ATTN_HEADS=40 +INTERMEDIATE_SIZE=13824 + +elif [ $MODEL_SIZE = 70B ]; then + +NUM_LAYERS=80 +HIDDEN_SIZE=8192 +NUM_ATTN_HEADS=64 +INTERMEDIATE_SIZE=28672 + +fi + +if [ $PR = fp16 ]; then + pr_options=" \ + --fp16" +elif [ $PR = bf16 ]; then + pr_options=" \ + --bf16" +fi + +if [ $PRETRAIN_CHECKPOINT_PATH != none ]; then + load_options=" \ + --load $PRETRAIN_CHECKPOINT_PATH" +fi + + +megatron_options=" \ + --transformer-type huggingface \ + --valid-data-path ${DATASET_PATH} + --micro-batch-size ${BATCH_SIZE} \ + --num-layers ${NUM_LAYERS} \ + --hidden-size ${HIDDEN_SIZE} \ + --num-attention-heads ${NUM_ATTN_HEADS} \ + --seq-length ${SEQ_LEN} \ + --max-position-embeddings ${SEQ_LEN} \ + --log-interval 1 \ + --eval-interval 100 \ + --eval-iters 10 \ + --tensor-model-parallel-size 1 \ + --pipeline-model-parallel-size 1 \ + --no-load-optim \ + --num-workers 0 \ + --dataset Mistral-SFT \ + --use-distributed-optimizer \ + --max-padding-length ${PAD_LEN} \ + --extra-vocab-size ${EXTRA_VOCAB_SIZE} \ + --patch-tokenizer-type MistralTokenizer + " + +run_cmd="torchrun $DISTRIBUTED_ARGS evaluate_huggingface_mixtral.py + ${megatron_options} ${pr_options} ${load_options}" + +echo ${run_cmd} +eval ${run_cmd} +set +x diff --git a/examples/mixtral/run_evaluate_megatron_mixtral.sh b/examples/mixtral/run_evaluate_megatron_mixtral.sh new file mode 100644 index 00000000..c97db2f0 --- /dev/null +++ b/examples/mixtral/run_evaluate_megatron_mixtral.sh @@ -0,0 +1,167 @@ +#!/bin/bash +#sh run_evaluate_megatron_mixtral.sh dsw ../.. 7B 1 80 80 0 bf16 1 1 sel true false false false /mnt/llama2-datasets/alpaca_data.json /mnt/mixtral-ckpts/Mixtral-8x7B-v0.1-to-mg-tp1-pp1/ +set -e +ENV=$1 +MEGATRON_PATCH_PATH=$2 +MEGATRON_PATH=${MEGATRON_PATCH_PATH}/Megatron-LM-main +export PYTHONPATH=${MEGATRON_PATH}:${MEGATRON_PATCH_PATH}:$PYTHONPATH +export CUDA_DEVICE_MAX_CONNECTIONS=1 +if [ $ENV = dsw ]; then +export CUDA_VISIBLE_DEVICES=0,1,2,3 +MASTER_ADDR=localhost +MASTER_PORT=$(shuf -n 1 -i 10000-65535) +NNODES=1 +NODE_RANK=0 +GPUS_PER_NODE=4 + +elif [ $ENV = dlc ]; then + +NNODES=${WORLD_SIZE} +NODE_RANK=${RANK} +GPUS_PER_NODE=${KUBERNETES_CONTAINER_RESOURCE_GPU} + +fi + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + +MODEL_SIZE=$3 +BATCH_SIZE=$4 +SEQ_LEN=$5 +PAD_LEN=$6 +EXTRA_VOCAB_SIZE=$7 +PR=$8 +TP=$9 +PP=${10} +AC=${11} +DO=${12} +FL=${13} +SP=${14} +TE=${15} +DATASET_PATH=${16} +PRETRAIN_CHECKPOINT_PATH=${17} + +if [ $MODEL_SIZE = 7B ]; then + +NUM_LAYERS=32 +HIDDEN_SIZE=4096 +NUM_ATTN_HEADS=32 +INTERMEDIATE_SIZE=14336 +MPE=32768 +SLW=4096 + +gqa_options=" \ + --group-query-attention \ + --num-query-groups 8" + +fi + +if [ $AC = full ]; then + activation_checkpoint_options=" \ + --recompute-method uniform \ + --recompute-granularity full" +elif [ $AC = sel ]; then + activation_checkpoint_options=" \ + --recompute-activations" +elif [ $AC = none ]; then + activation_checkpoint_options=" \ + " +fi + +if [ $PR = fp16 ]; then + pr_options=" \ + --fp16" +elif [ $PR = bf16 ]; then + pr_options=" \ + --bf16" +elif [ $PR = fp8 ]; then + pr_options=" \ + --bf16 + --fp8-hybrid \ + --fp8-amax-compute-algo max \ + --fp8-amax-history-len 1024 \ + --transformer-impl transformer_engine" +fi + +if [ $DO = true ]; then + do_options=" \ + --use-distributed-optimizer" + +elif [ $DO = false ]; then + do_options=" \ + " +fi + +if [ $FL = true ]; then + flash_options=" \ + --use-flash-attn" + +elif [ $FL = false ]; then + flash_options=" \ + " +fi + +if [ $TE = true ]; then + te_options=" \ + --transformer-impl transformer_engine" + +elif [ $TE = false ]; then + te_options=" \ + " +fi + +if [ $SP = true ] && [ $TP -gt 1 ]; then + sp_options=" \ + --sequence-parallel" + +elif [ $SP = false ]; then + sp_options=" \ + " +fi + +if [ $PRETRAIN_CHECKPOINT_PATH != none ]; then + load_options=" \ + --load $PRETRAIN_CHECKPOINT_PATH" +fi + + +megatron_options=" \ + --valid-data-path ${DATASET_PATH} + --micro-batch-size ${BATCH_SIZE} \ + --num-layers ${NUM_LAYERS} \ + --hidden-size ${HIDDEN_SIZE} \ + --num-attention-heads ${NUM_ATTN_HEADS} \ + --seq-length ${SEQ_LEN} \ + --max-position-embeddings ${MPE} \ + --ffn-hidden-size ${INTERMEDIATE_SIZE} \ + --log-interval 1 \ + --eval-interval 100 \ + --eval-iters 10 \ + --tensor-model-parallel-size ${TP} \ + --pipeline-model-parallel-size ${PP} \ + --no-load-optim \ + --no-load-rng \ + --seed 1234 \ + --num-workers 0 \ + --max-padding-length ${PAD_LEN} \ + --extra-vocab-size ${EXTRA_VOCAB_SIZE} \ + --patch-tokenizer-type MistralTokenizer \ + --dataset Mistral-SFT \ + --sliding-window ${SLW} \ + --swiglu \ + --normalization RMSNorm \ + --use-mistral-rotary-position-embeddings \ + --position-embedding-type rope \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --router-type topk \ + --expert-interval 1 \ + --num-experts 8 \ + --moe-topk 2 + " + +run_cmd="torchrun $DISTRIBUTED_ARGS evaluate_megatron_mixtral.py + ${megatron_options} ${pr_options} ${load_options} ${te_options} ${activation_checkpoint_options} ${do_options} ${flash_options} ${sp_options} ${gqa_options}" + +echo ${run_cmd} +eval ${run_cmd} +set +x diff --git a/examples/mixtral/run_finetune_megatron_mixtral.sh b/examples/mixtral/run_finetune_megatron_mixtral.sh new file mode 100644 index 00000000..7636f000 --- /dev/null +++ b/examples/mixtral/run_finetune_megatron_mixtral.sh @@ -0,0 +1,201 @@ +#!/bin/bash +#sh run_finetune_megatron_mistral.sh dsw /workspace/Pai-Megatron-Patch 7B 2 1e-5 1e-6 80 80 0 bf16 8 1 sel true false true false /mnt/llama2-datasets/alpaca_data.json /mnt/llama2-datasets/alpaca_data.json /mnt/mistral-ckpts/Mistral-7B-v0.1-to-mg-tp8-pp1/ 2 /mnt/output_patch_test +set -e +ENV=$1 +MEGATRON_PATCH_PATH=$2 +MEGATRON_PATH=${MEGATRON_PATCH_PATH}/Megatron-LM-main +export PYTHONPATH=${MEGATRON_PATH}:${MEGATRON_PATCH_PATH}:$PYTHONPATH +export CUDA_DEVICE_MAX_CONNECTIONS=1 +if [ $ENV = dsw ]; then +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +MASTER_ADDR=localhost +MASTER_PORT=$(shuf -n 1 -i 10000-65535) +NNODES=1 +NODE_RANK=0 +GPUS_PER_NODE=8 + +elif [ $ENV = dlc ]; then + +NNODES=${WORLD_SIZE} +NODE_RANK=${RANK} +GPUS_PER_NODE=${KUBERNETES_CONTAINER_RESOURCE_GPU} + +fi + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + +MODEL_SIZE=$3 #7B +BATCH_SIZE=$4 +LR=$5 +MIN_LR=$6 +SEQ_LEN=$7 +PAD_LEN=$8 +EXTRA_VOCAB_SIZE=$9 +PR=${10} +TP=${11} +PP=${12} +AC=${13} +DO=${14} +FL=${15} +SP=${16} +TE=${17} +TRAIN_DATASET_PATH=${18} +VALID_DATASET_PATH=${19} +PRETRAIN_CHECKPOINT_PATH=${20} +EPOCH=${21} +OUTPUT_BASEPATH=${22} + +if [ $MODEL_SIZE = 7B ]; then + +NUM_LAYERS=32 +HIDDEN_SIZE=4096 +NUM_ATTN_HEADS=32 +INTERMEDIATE_SIZE=14336 +MPE=32768 +SLW=4096 + +gqa_options=" \ + --group-query-attention \ + --num-query-groups 8" + +fi + +if [ $AC = full ]; then + activation_checkpoint_options=" \ + --recompute-method uniform \ + --recompute-granularity full" +elif [ $AC = sel ]; then + activation_checkpoint_options=" \ + --recompute-activations" +elif [ $AC = none ]; then + activation_checkpoint_options=" \ + " +fi + +if [ $PR = fp16 ]; then + pr_options=" \ + --fp16" +elif [ $PR = bf16 ]; then + pr_options=" \ + --bf16" +elif [ $PR = fp8 ]; then + pr_options=" \ + --bf16 + --fp8-hybrid \ + --fp8-amax-compute-algo max \ + --fp8-amax-history-len 1024 \ + --transformer-impl transformer_engine" +fi + +if [ $DO = true ]; then + do_options=" \ + --use-distributed-optimizer" + +elif [ $DO = false ]; then + do_options=" \ + " +fi + +if [ $FL = true ]; then + flash_options=" \ + --use-flash-attn" + +elif [ $FL = false ]; then + flash_options=" \ + " +fi + +if [ $TE = true ]; then + te_options=" \ + --transformer-impl transformer_engine" + +elif [ $TE = false ]; then + te_options=" \ + " +fi + +if [ $SP = true ] && [ $TP -gt 1 ]; then + sp_options=" \ + --sequence-parallel" + +elif [ $SP = false ]; then + sp_options=" \ + " +fi + +if [ $PRETRAIN_CHECKPOINT_PATH != none ]; then + load_options=" \ + --load $PRETRAIN_CHECKPOINT_PATH" +fi + +FT_NAME="${ENV}-finetune-megatron-llama-${MODEL_SIZE}-lr-${LR}-ep-${EPOCH}-bs-${BATCH_SIZE}-seqlen-${SEQ_LEN}-pr-${PR}--do-${DO}-tp-${TP}-ac-${AC}-sp-${SP}" +mkdir -p "${OUTPUT_BASEPATH}/tensorboard/" +mkdir -p "${OUTPUT_BASEPATH}/checkpoint/" +mkdir -p "${OUTPUT_BASEPATH}/log/" +current_time=$(date "+%Y.%m.%d-%H.%M.%S") +TENSORBOARD_DIR="${OUTPUT_BASEPATH}/tensorboard/${FT_NAME}_${current_time}" +mkdir -p ${TENSORBOARD_DIR} + +FINETUNE_CHECKPOINT_PATH="${OUTPUT_BASEPATH}/checkpoint/${FT_NAME}" + +megatron_options=" \ + --load ${PRETRAIN_CHECKPOINT_PATH} \ + --save ${FINETUNE_CHECKPOINT_PATH} \ + --train-data-path ${TRAIN_DATASET_PATH} \ + --valid-data-path ${VALID_DATASET_PATH} \ + --num-layers ${NUM_LAYERS} \ + --hidden-size ${HIDDEN_SIZE} \ + --num-attention-heads ${NUM_ATTN_HEADS} \ + --seq-length ${SEQ_LEN} \ + --max-position-embeddings ${MPE} \ + --ffn-hidden-size ${INTERMEDIATE_SIZE} \ + --keep-last \ + --micro-batch-size ${BATCH_SIZE} \ + --epochs ${EPOCH} \ + --lr ${LR} \ + --min-lr ${MIN_LR} \ + --lr-decay-style cosine \ + --weight-decay 0.1 \ + --clip-grad 1.0 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --init-method-std 0.01 \ + --num-workers 0\ + --log-interval 1 \ + --eval-interval 1000 \ + --eval-iters 10 \ + --save-interval 1000000 \ + --tensorboard-queue-size 1 \ + --tensorboard-dir ${TENSORBOARD_DIR} \ + --log-timers-to-tensorboard \ + --log-batch-size-to-tensorboard \ + --log-validation-ppl-to-tensorboard \ + --tensor-model-parallel-size ${TP} \ + --pipeline-model-parallel-size ${PP} \ + --finetune \ + --no-load-optim \ + --no-load-rng \ + --seed 1234 \ + --max-padding-length ${PAD_LEN} \ + --extra-vocab-size ${EXTRA_VOCAB_SIZE} \ + --patch-tokenizer-type MistralTokenizer \ + --dataset Mistral-SFT \ + --sliding-window ${SLW} \ + --swiglu \ + --normalization RMSNorm \ + --use-mistral-rotary-position-embeddings \ + --position-embedding-type rope \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --router-type topk \ + --expert-interval 1 \ + --num-experts 8 \ + --moe-topk 2 + " + +run_cmd="torchrun $DISTRIBUTED_ARGS finetune_megatron_mixtral.py + ${megatron_options} ${pr_options} ${load_options} ${te_options} ${activation_checkpoint_options} ${do_options} ${flash_options} ${sp_options} ${gqa_options}" + +echo ${run_cmd} +eval ${run_cmd} +set +x diff --git a/examples/mixtral/run_pretrain_megatron_mixtral.sh b/examples/mixtral/run_pretrain_megatron_mixtral.sh new file mode 100644 index 00000000..24369e29 --- /dev/null +++ b/examples/mixtral/run_pretrain_megatron_mixtral.sh @@ -0,0 +1,218 @@ +#!/bin/bash +#sh run_pretrain_megatron_mixtral.sh dsw ../.. 0.125B 1 8 1e-5 1e-6 2048 2048 0 bf16 8 1 sel true false false false 100000 /mnt/llama2-datasets/alpaca_data.json /mnt/mistral-ckpts/Mistral-7B-v0.1 10000000000 100000000 /mnt/output_patch_test +set -e +ENV=$1 +MEGATRON_PATCH_PATH=$2 +MEGATRON_PATH=${MEGATRON_PATCH_PATH}/Megatron-LM-main +export PYTHONPATH=${MEGATRON_PATH}:${MEGATRON_PATCH_PATH}:$PYTHONPATH +export CUDA_DEVICE_MAX_CONNECTIONS=1 +if [ $ENV = dsw ]; then +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +MASTER_ADDR=localhost +MASTER_PORT=$(shuf -n 1 -i 10000-65535) +NNODES=1 +NODE_RANK=0 +GPUS_PER_NODE=8 + +elif [ $ENV = dlc ]; then + +NNODES=${WORLD_SIZE} +NODE_RANK=${RANK} +GPUS_PER_NODE=${KUBERNETES_CONTAINER_RESOURCE_GPU} + +fi + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + +MODEL_SIZE=$3 +BATCH_SIZE=$4 +GLOBAL_BATCH_SIZE=$5 +LR=$6 +MIN_LR=$7 +SEQ_LEN=$8 +PAD_LEN=$9 +EXTRA_VOCAB_SIZE=${10} +PR=${11} +TP=${12} +PP=${13} +AC=${14} +DO=${15} +FL=${16} +SP=${17} +TE=${18} +SAVE_INTERVAL=${19} +DATASET_PATH=${20} +PRETRAIN_CHECKPOINT_PATH=${21} +TRAIN_TOKENS=${22} +WARMUP_TOKENS=${23} +OUTPUT_BASEPATH=${24} + +if [ $MODEL_SIZE = 0.125B ]; then + +NUM_LAYERS=12 +HIDDEN_SIZE=768 +NUM_ATTN_HEADS=12 +INTERMEDIATE_SIZE=3072 +MPE=32768 +SLW=4096 + + +elif [ $MODEL_SIZE = 7B ]; then + +NUM_LAYERS=32 +HIDDEN_SIZE=4096 +NUM_ATTN_HEADS=32 +INTERMEDIATE_SIZE=14336 +MPE=32768 +SLW=4096 + +gqa_options=" \ + --group-query-attention \ + --num-query-groups 8" + +fi + +if [ $AC = full ]; then + activation_checkpoint_options=" \ + --recompute-method uniform \ + --recompute-granularity full" +elif [ $AC = sel ]; then + activation_checkpoint_options=" \ + --recompute-activations" +elif [ $AC = none ]; then + activation_checkpoint_options=" \ + " +fi + +if [ $PR = fp16 ]; then + pr_options=" \ + --fp16" +elif [ $PR = bf16 ]; then + pr_options=" \ + --bf16" +elif [ $PR = fp8 ]; then + pr_options=" \ + --bf16 + --fp8-hybrid \ + --fp8-amax-compute-algo max \ + --fp8-amax-history-len 1024 \ + --transformer-impl transformer_engine" +fi + +if [ $DO = true ]; then + do_options=" \ + --use-distributed-optimizer" + +elif [ $DO = false ]; then + do_options=" \ + " +fi + +if [ $FL = true ]; then + flash_options=" \ + --use-flash-attn" + +elif [ $FL = false ]; then + flash_options=" \ + " +fi + +if [ $TE = true ]; then + te_options=" \ + --transformer-impl transformer_engine" + +elif [ $TE = false ]; then + te_options=" \ + " +fi + +if [ $SP = true ] && [ $TP -gt 1 ]; then + sp_options=" \ + --sequence-parallel" + +elif [ $SP = false ]; then + sp_options=" \ + " +fi + +if [ $PRETRAIN_CHECKPOINT_PATH != none ]; then + load_options=" \ + --load $PRETRAIN_CHECKPOINT_PATH" +fi + +TRAIN_ITERS=$(( ${TRAIN_TOKENS} / ${GLOBAL_BATCH_SIZE} / ${SEQ_LEN} )) +LR_WARMUP_ITERS=$(( ${WARMUP_TOKENS} / ${GLOBAL_BATCH_SIZE} / ${SEQ_LEN} )) +LR_DECAY_ITERS=$(( ${TRAIN_TOKENS} / ${GLOBAL_BATCH_SIZE} / ${SEQ_LEN} )) + +NAME="${ENV}-pretrain-megatron-gpt3-${MODEL_SIZE}-lr-${LR}-bs-${BATCH_SIZE}-seqlen-${SEQ_LEN}-pr-${PR}-tp-${TP}-pp-${PP}-ac-${AC}-do-${DO}-sp-${SP}-tt-${TRAIN_TOKENS}-wt-${WARMUP_TOKENS}" +mkdir -p "${OUTPUT_BASEPATH}/tensorboard/" +mkdir -p "${OUTPUT_BASEPATH}/checkpoint/" +mkdir -p "${OUTPUT_BASEPATH}/log/" +current_time=$(date "+%Y.%m.%d-%H.%M.%S") +TENSORBOARD_DIR="${OUTPUT_BASEPATH}/tensorboard/${NAME}_${current_time}" +mkdir -p ${TENSORBOARD_DIR} + +SAVED_PRETRAIN_CHECKPOINT_PATH="${OUTPUT_BASEPATH}/checkpoint/${NAME}" + +megatron_options=" \ + --save ${SAVED_PRETRAIN_CHECKPOINT_PATH} \ + --train-data-path ${DATASET_PATH} \ + --valid-data-path ${DATASET_PATH} \ + --test-data-path ${DATASET_PATH} \ + --lr ${LR} \ + --min-lr ${MIN_LR} \ + --lr-decay-style linear \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --weight-decay 0.1 \ + --clip-grad 1.0 \ + --init-method-std 0.006 \ + --lr-decay-iters ${LR_DECAY_ITERS} \ + --lr-warmup-iters ${LR_WARMUP_ITERS} \ + --train-iters ${TRAIN_ITERS} \ + --micro-batch-size ${BATCH_SIZE} \ + --global-batch-size ${GLOBAL_BATCH_SIZE} \ + --num-layers ${NUM_LAYERS} \ + --hidden-size ${HIDDEN_SIZE} \ + --num-attention-heads ${NUM_ATTN_HEADS} \ + --ffn-hidden-size ${INTERMEDIATE_SIZE} \ + --seq-length ${SEQ_LEN} \ + --max-position-embeddings ${MPE} \ + --log-interval 1 \ + --eval-interval 10000 \ + --eval-iters 10 \ + --save-interval ${SAVE_INTERVAL} \ + --tensorboard-queue-size 1 \ + --tensorboard-dir ${TENSORBOARD_DIR} \ + --log-timers-to-tensorboard \ + --log-batch-size-to-tensorboard \ + --log-validation-ppl-to-tensorboard \ + --tensor-model-parallel-size ${TP} \ + --pipeline-model-parallel-size ${PP} \ + --no-load-optim \ + --no-load-rng \ + --num-workers 8 \ + --seed 1234 \ + --max-padding-length ${PAD_LEN} \ + --extra-vocab-size ${EXTRA_VOCAB_SIZE} \ + --patch-tokenizer-type MistralTokenizer \ + --dataset Mistral-Pretrain-Raw \ + --sliding-window ${SLW} \ + --swiglu \ + --normalization RMSNorm \ + --use-mistral-rotary-position-embeddings \ + --position-embedding-type rope \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --router-type topk \ + --expert-interval 1 \ + --num-experts 8 \ + --moe-topk 2 + " + +run_cmd="torchrun $DISTRIBUTED_ARGS pretrain_megatron_mixtral.py + ${megatron_options} ${pr_options} ${load_options} ${te_options} ${activation_checkpoint_options} ${do_options} ${flash_options} ${sp_options} ${gqa_options}" + +echo ${run_cmd} +eval ${run_cmd} +set +x diff --git a/examples/qwen_vl/finetune_megatron_qwen_vl.py b/examples/qwen_vl/finetune_megatron_qwen_vl.py new file mode 100644 index 00000000..d905bfb4 --- /dev/null +++ b/examples/qwen_vl/finetune_megatron_qwen_vl.py @@ -0,0 +1,90 @@ +# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +import torch + +from megatron import get_args +from megatron.core import parallel_state, tensor_parallel +from megatron.initialize import initialize_megatron +from megatron.utils import average_losses_across_data_parallel_group +from megatron.utils import get_ltor_masks_and_position_ids +from megatron_patch.finetune_utils import finetune +from megatron_patch.model.qwen_vl.gpt_model import GPTModel +from megatron_patch.tokenizer import get_tokenizer +from megatron_patch.arguments import get_tasks_args +from megatron.arguments import core_transformer_config_from_args +from megatron_patch.data import build_finetune_dataset +from megatron_patch.data.llava.constants import IGNORE_INDEX + +def model_provider(pre_process=True, post_process=True): + + config = core_transformer_config_from_args(get_args()) + model = GPTModel( + config, + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process + ) + return model + +def train_valid_datasets_provider(): + args = get_args() + train_dataset, valid_dataset = build_finetune_dataset(args.dataset) + return train_dataset, valid_dataset + + +def forward_step(data_iterator, model): + args = get_args() + tokenizer = get_tokenizer() + + try: + data_iterator = next(data_iterator) + except BaseException: + data_iterator = data_iterator + + tokens = data_iterator['input_ids'].long().cuda().contiguous() + labels = data_iterator['labels'].long().cuda().contiguous() + tokens = tokens[:, :-1].contiguous() + labels = labels[:, 1:].contiguous() + attention_mask = tokens.ne(tokenizer.pad_token_id) + _, loss_mask, position_ids = get_ltor_masks_and_position_ids( + labels, + IGNORE_INDEX, + args.reset_position_ids, + args.reset_attention_mask, + True) + logits = model(input_ids=tokens, + position_ids=position_ids, + attention_mask=attention_mask) + + def loss_func(loss_mask, logits): + losses = tensor_parallel.vocab_parallel_cross_entropy( + logits.contiguous().float(), labels.contiguous()) + loss_mask = loss_mask.view(-1).float() + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() + averaged_loss = average_losses_across_data_parallel_group([loss]) + return loss, {'lm loss': averaged_loss[0]} + + return logits, partial(loss_func, loss_mask) + + +if __name__ == '__main__': + + initialize_megatron(extra_args_provider=get_tasks_args) + + finetune(train_valid_datasets_provider=train_valid_datasets_provider, + model_provider=model_provider, + forward_step=forward_step) diff --git a/examples/qwen_vl/run_finetune_megatron_qwen_vl.sh b/examples/qwen_vl/run_finetune_megatron_qwen_vl.sh new file mode 100644 index 00000000..ce48c0bb --- /dev/null +++ b/examples/qwen_vl/run_finetune_megatron_qwen_vl.sh @@ -0,0 +1,221 @@ +#!/bin/bash +#sh run_finetune_megatron_qwen_vl.sh dsw /workspace/Alibaba-internal/PAI-Megatron-Patch 7B 1 1e-3 1e-4 2048 2048 0 bf16 1 1 sel true false true false /mnt/qwen-datasets/qwenvl_train.json /mnt/qwen-datasets/qwenvl_train.json /mnt/qwen-ckpts/Qwen-VL-Chat-small 1 /mnt/output_patch_test +set -e +ENV=$1 +MEGATRON_PATCH_PATH=$2 +MEGATRON_PATH=${MEGATRON_PATCH_PATH}/Megatron-LM-main +export PYTHONPATH=${MEGATRON_PATH}:${MEGATRON_PATCH_PATH}:$PYTHONPATH +export CUDA_DEVICE_MAX_CONNECTIONS=1 +if [ $ENV = dsw ]; then +export CUDA_VISIBLE_DEVICES=7 +MASTER_ADDR=localhost +MASTER_PORT=$(shuf -n 1 -i 10000-65535) +NNODES=1 +NODE_RANK=0 +GPUS_PER_NODE=1 + +elif [ $ENV = dlc ]; then + +NNODES=${WORLD_SIZE} +NODE_RANK=${RANK} +GPUS_PER_NODE=${KUBERNETES_CONTAINER_RESOURCE_GPU} + +fi + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + +MODEL_SIZE=$3 #7B, 13B, 70B +BATCH_SIZE=$4 +LR=$5 +MIN_LR=$6 +SEQ_LEN=$7 +PAD_LEN=$8 +EXTRA_VOCAB_SIZE=$9 +PR=${10} +TP=${11} +PP=${12} +AC=${13} +DO=${14} +FL=${15} +SP=${16} +TE=${17} +TRAIN_DATASET_PATH=${18} +VALID_DATASET_PATH=${19} +PRETRAIN_CHECKPOINT_PATH=${20} +EPOCH=${21} +OUTPUT_BASEPATH=${22} + + +if [ $MODEL_SIZE = 7B ]; then + +NUM_LAYERS=2 +HIDDEN_SIZE=2048 +NUM_ATTN_HEADS=32 +INTERMEDIATE_SIZE=11008 + +gqa_options="" + +elif [ $MODEL_SIZE = 13B ]; then + +NUM_LAYERS=40 +HIDDEN_SIZE=5120 +NUM_ATTN_HEADS=40 +INTERMEDIATE_SIZE=13824 + +gqa_options="" + +elif [ $MODEL_SIZE = 70B ]; then + +NUM_LAYERS=80 +HIDDEN_SIZE=8192 +NUM_ATTN_HEADS=64 +INTERMEDIATE_SIZE=28672 + +gqa_options=" \ + --group-query-attention \ + --num-query-groups 8" + +fi + +if [ $AC = full ]; then + activation_checkpoint_options=" \ + --recompute-method uniform \ + --recompute-granularity full" +elif [ $AC = sel ]; then + activation_checkpoint_options=" \ + --recompute-activations" +elif [ $AC = none ]; then + activation_checkpoint_options=" \ + " +fi + +if [ $PR = fp16 ]; then + pr_options=" \ + --fp16" +elif [ $PR = bf16 ]; then + pr_options=" \ + --bf16" +elif [ $PR = fp8 ]; then + pr_options=" \ + --bf16 + --fp8-hybrid \ + --fp8-amax-compute-algo max \ + --fp8-amax-history-len 1024 \ + --transformer-impl transformer_engine" +fi + +if [ $DO = true ]; then + do_options=" \ + --use-distributed-optimizer" + +elif [ $DO = false ]; then + do_options=" \ + " +fi + +if [ $FL = true ]; then + flash_options=" \ + --use-flash-attn" + +elif [ $FL = false ]; then + flash_options=" \ + " +fi + +if [ $TE = true ]; then + te_options=" \ + --transformer-impl transformer_engine" + +elif [ $TE = false ]; then + te_options=" \ + " +fi + +if [ $SP = true ] && [ $TP -gt 1 ]; then + sp_options=" \ + --sequence-parallel" + +elif [ $SP = false ]; then + sp_options=" \ + " +fi + +if [ $PRETRAIN_CHECKPOINT_PATH != none ]; then + load_options=" \ + --load $PRETRAIN_CHECKPOINT_PATH" +fi + +FT_NAME="${ENV}-finetune-megatron-llama-${MODEL_SIZE}-lr-${LR}-ep-${EPOCH}-bs-${BATCH_SIZE}-seqlen-${SEQ_LEN}-pr-${PR}--do-${DO}-tp-${TP}-ac-${AC}-sp-${SP}" +mkdir -p "${OUTPUT_BASEPATH}/tensorboard/" +mkdir -p "${OUTPUT_BASEPATH}/checkpoint/" +mkdir -p "${OUTPUT_BASEPATH}/log/" +current_time=$(date "+%Y.%m.%d-%H.%M.%S") +TENSORBOARD_DIR="${OUTPUT_BASEPATH}/tensorboard/${FT_NAME}_${current_time}" +mkdir -p ${TENSORBOARD_DIR} + +FINETUNE_CHECKPOINT_PATH="${OUTPUT_BASEPATH}/checkpoint/${FT_NAME}" + +megatron_options=" \ + --save ${FINETUNE_CHECKPOINT_PATH} \ + --image-folder /mnt/llava-datasets/LLaVA-Pretrain/images \ + --vision-tower /mnt/openai-ckpts/clip-vit-large-patch14-336 \ + --image-size 336 \ + --patch-size 14 \ + --version plain \ + --mm-projector-type mlp2x_gelu \ + --tune-mm-mlp-adapter \ + --train-data-path ${TRAIN_DATASET_PATH} \ + --valid-data-path ${VALID_DATASET_PATH} \ + --num-layers ${NUM_LAYERS} \ + --hidden-size ${HIDDEN_SIZE} \ + --num-attention-heads ${NUM_ATTN_HEADS} \ + --seq-length ${SEQ_LEN} \ + --max-position-embeddings ${SEQ_LEN} \ + --ffn-hidden-size ${INTERMEDIATE_SIZE} \ + --keep-last \ + --micro-batch-size ${BATCH_SIZE} \ + --epochs ${EPOCH} \ + --lr ${LR} \ + --min-lr ${MIN_LR} \ + --lr-decay-style cosine \ + --lr-warmup-fraction 0.03 \ + --weight-decay 0.1 \ + --clip-grad 1.0 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --init-method-std 0.01 \ + --num-workers 0\ + --log-interval 1 \ + --eval-interval 1000 \ + --eval-iters 10 \ + --save-interval 1000000 \ + --tensorboard-queue-size 1 \ + --tensorboard-dir ${TENSORBOARD_DIR} \ + --log-timers-to-tensorboard \ + --log-batch-size-to-tensorboard \ + --log-validation-ppl-to-tensorboard \ + --tensor-model-parallel-size ${TP} \ + --pipeline-model-parallel-size ${PP} \ + --finetune \ + --no-load-optim \ + --no-load-rng \ + --seed 1234 \ + --max-padding-length ${PAD_LEN} \ + --extra-vocab-size ${EXTRA_VOCAB_SIZE} \ + --patch-tokenizer-type QwenVLTokenizer \ + --dataset Qwen-VL-SFT \ + --swiglu \ + --normalization RMSNorm \ + --use-llama2-rotary-position-embeddings \ + --position-embedding-type rope \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --no-gradient-accumulation-fusion \ + " + +run_cmd="torchrun $DISTRIBUTED_ARGS finetune_megatron_qwen_vl.py + ${megatron_options} ${pr_options} ${load_options} ${te_options} ${activation_checkpoint_options} ${do_options} ${flash_options} ${sp_options} ${gqa_options}" + +echo ${run_cmd} +eval ${run_cmd} +set +x diff --git a/megatron_patch/arguments.py b/megatron_patch/arguments.py index 3a879d4c..a0b44c67 100644 --- a/megatron_patch/arguments.py +++ b/megatron_patch/arguments.py @@ -14,6 +14,18 @@ import argparse +def validate_moe_args(args, defaults={}): + if args.num_experts is not None: + args.moe = True + if args.moe_expert_parallel_size is None: + args.moe_expert_parallel_size = args.data_parallel_size + if args.tensor_model_parallel_size > 0 and not args.expert_tensor_parallelism: + # EP will use the span of DP*TP + args.moe_expert_parallel_size *= args.tensor_model_parallel_size + if args.rank == 0: + print('Experts set to %s, expert parallel size set to %d' + % (str(args.num_experts), args.moe_expert_parallel_size)) + def get_tasks_args(parser): group = parser.add_argument_group(title='patch') @@ -278,5 +290,42 @@ def get_tasks_args(parser): group.add_argument('--cvcuda-image-processing', action='store_true') - + + group.add_argument('--expert-tensor-parallelism', action='store_true', + default=False, + help="use tensor parallelism for expert layers in MoE") + + group.add_argument('--expert-interval', type=int, default=2, + help='Use experts in every "expert-interval" layers') + + group.add_argument('--moe-topk', type=int, default=1, + help='moe-topk') + + group.add_argument('--moe-expert-parallel-size', type=int, default=None, + help='Degree of the MoE expert parallelism. By default, ' + 'the size of this value will be automatically determined.') + + group.add_argument('--disable-moe-token-dropping', action='store_false', + help='Disable MoE expert token dropping.', + dest='moe_token_dropping') + + group.add_argument('--moe-train-capacity-factor', type=float, default=1.0, + help='The capacity of the MoE expert at training time') + + group.add_argument('--moe-eval-capacity-factor', type=float, default=1.0, + help='The capacity of the MoE expert at eval time.') + + group.add_argument('--moe-min-capacity', type=int, default=4, + help='The minimum capacity per MoE expert regardless of the capacity_factor.') + + group.add_argument('--moe-loss-coeff', type=float, default=0.01, + help='Scaling coefficient for adding MoE loss to model loss') + + group.add_argument('--use-tutel', action='store_true', + help='Use Tutel optimization for MoE') + + group.add_argument('--router-type', type=str, default='topk', + choices=['topk', 'expert_choice'], + help='Options for router type, support top1 & top2 and expert_choice') + return parser diff --git a/megatron_patch/checkpointing.py b/megatron_patch/checkpointing.py index 789907f0..c1cc7e2a 100644 --- a/megatron_patch/checkpointing.py +++ b/megatron_patch/checkpointing.py @@ -15,7 +15,7 @@ import os import random import sys - +from collections import defaultdict import numpy as np import torch from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP @@ -149,7 +149,7 @@ def fix_query_key_value_ordering(model, checkpoint_version): ' checkpoint version {}'.format(checkpoint_version)) -def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False): +def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False, model=None): """ Load the base state_dict from the given directory If rank0 is true, just loads rank 0 checkpoint, ignoring arguments. @@ -195,7 +195,13 @@ def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False): try: model_state_dict = torch.load(model_checkpoint_name, map_location='cpu') - optim_state_dict = None + if not args.no_load_optim: + if use_distributed_optimizer or args.moe: + optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu') + else: + optim_state_dict = model_state_dict + else: + optim_state_dict = None except ModuleNotFoundError: # For backward compatibility. if not rank0: @@ -215,6 +221,19 @@ def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False): print_rank_0(e) sys.exit() + # Load MoE + if args.moe: + if args.expert_tensor_parallelism and \ + mpu.get_tensor_model_parallel_world_size() > 1: + # expert with tensor parallel, save to the mp_rank dir. + moe_checkpoint_dir = os.path.dirname(model_checkpoint_name) + else: + # save to the root dir. + moe_checkpoint_dir = os.path.dirname(os.path.dirname(model_checkpoint_name)) + _load_moe_state_dict(moe_checkpoint_dir, model_state_dict['model'], + model_list=model, mpu=mpu) + + return model_state_dict, optim_state_dict, release, optim_checkpoint_name @@ -236,7 +255,8 @@ def load_checkpoint(model, _load_base_checkpoint( load_dir, use_distributed_optimizer=args.use_distributed_optimizer, - rank0=False) + rank0=False, + model=model) if model_state_dict is None: return 0 @@ -445,6 +465,18 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler): (torchDDP, LocalDDP, Float16Module)) unwrapped_model.save_pretrained(checkpoint_dir) + if args.moe: + if args.expert_tensor_parallelism and \ + mpu.get_tensor_model_parallel_world_size() > 1: + # expert with tensor parallel, save to the mp_rank dir. + moe_checkpoint_dir = os.path.dirname(model_checkpoint_name) + else: + # save to the root dir. + moe_checkpoint_dir = os.path.dirname(os.path.dirname(model_checkpoint_name)) + print_rank_0(' save moe checkpoints to {}'.format(moe_checkpoint_dir)) + ensure_directory_exists(moe_checkpoint_dir) + _save_moe_checkpoint(moe_checkpoint_dir, model) + # Wait so everyone is done (necessary) if torch.distributed.is_initialized(): torch.distributed.barrier() @@ -462,4 +494,104 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler): # Wait so everyone is done (not necessary) if torch.distributed.is_initialized(): - torch.distributed.barrier() \ No newline at end of file + torch.distributed.barrier() + + +def _get_expert_ckpt_name(checkpoints_path, layer_id, expert_id, tag, mpu): + args = get_args() + mp_rank = 0 + if args.expert_tensor_parallelism: + mp_rank = mpu.get_tensor_model_parallel_rank() + # Used to support expert saving and loading. + ckpt_name = os.path.join( + checkpoints_path, + '' if tag is None else str(tag), + f'layer_{layer_id}_expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt' + ) + return ckpt_name + +def _load_moe_state_dict(checkpoint_path, state_dict, model_list, mpu): + from megatron_patch.model.mixtral.layer import MoE + from megatron_patch import expert_parallel_state + moe_state_dict = state_dict['language_model'].setdefault('moe_state_dict', {}) + + # Loop through all the models in the list + for model in model_list: + # Loop through all the modules in the model + for _, module in model.named_modules(): + # Check if the module is an MoE layer + if isinstance(module, MoE): + moe_layer_index = module.get_moe_layer_index() + num_local_experts = module.num_local_experts + + # Get the rank of the current process and calculate the global expert ID + ep_rank = torch.distributed.get_rank(group=expert_parallel_state.get_expert_parallel_group()) + + # Loop through all the local experts + for local_expert_id in range(num_local_experts): + # Calculate the name of the checkpoint file and load the expert state dictionary + global_expert_id = ep_rank * num_local_experts + local_expert_id + expert_ckpt_name = _get_expert_ckpt_name(checkpoint_path, moe_layer_index, + global_expert_id, None, mpu) + expert_state_dict = torch.load(expert_ckpt_name, + map_location=torch.device('cpu')) + + # Update the expert state dictionary with the local expert ID + moe_str_prefix = '.megatron_moe.experts.megatron_experts.' + for key in list(expert_state_dict.keys()): + local_key = key.replace(f'{moe_str_prefix}{global_expert_id}', + f'{moe_str_prefix}{local_expert_id}') + expert_state_dict[local_key] = expert_state_dict.pop(key) + + # Update the MoE state dictionary with the expert state dictionary + moe_state_dict.update(expert_state_dict) + +def _save_moe_checkpoint(save_dir, model_list): + # Using layer_#_export_# to save the model's expert state_dict + from megatron_patch.model.mixtral.layer import MoE + import re + moe_layer_id = 0 + + # Loop through all the models in the list + for model in model_list: + # Loop through all the modules in the model + for name, module in model.named_modules(): + # Check if the module is an MoE layer + if isinstance(module, MoE): + moe_layer_id = module.get_moe_layer_index() + num_local_experts = module.num_local_experts + ep_rank = torch.distributed.get_rank(group=mpu.get_expert_parallel_group()) + + # Extract the state dict of MoE experts + moe_state_dict = { + f"{name}.{n}": p + for n, p in module.state_dict().items() + if "expert" in n and "moe.gate.wg.weight" not in n + } + + # Loop through all the experts and update the state dict with global expert IDs + experts_state_dict = defaultdict(dict) + for key in list(moe_state_dict.keys()): + match = re.match(f".*{name}.megatron_moe.experts.megatron_experts.([0-9]+).*", + key) + + if match is None: + print(f"No expert found in key {key}.") + continue + + local_expert_id = match.group(1) + global_expert_id = ep_rank * num_local_experts + int(local_expert_id) + + expert_key = key.replace( + f"{name}.megatron_moe.experts.megatron_experts.{local_expert_id}", + f"{name}.megatron_moe.experts.megatron_experts.{global_expert_id}") + + experts_state_dict[str(global_expert_id)][expert_key] = moe_state_dict.pop(key) + + # Save the expert state dictionaries + for global_expert_id, expert_state_dict in experts_state_dict.items(): + save_path = _get_expert_ckpt_name(save_dir, moe_layer_id, global_expert_id, + None, mpu) + torch.save(expert_state_dict, save_path) + + moe_layer_id += 1 \ No newline at end of file diff --git a/megatron_patch/data/__init__.py b/megatron_patch/data/__init__.py index bda8cdf3..a5d0d899 100644 --- a/megatron_patch/data/__init__.py +++ b/megatron_patch/data/__init__.py @@ -15,16 +15,14 @@ import numpy as np from megatron.data.dataset_utils import get_datasets_weights_and_num_samples -from megatron.data.blendable_dataset import BlendableDataset from megatron import print_rank_0 -from megatron.data.gpt_dataset import get_indexed_dataset_ -from megatron.data.gpt_dataset import get_train_valid_test_split_ from megatron import get_args from megatron_patch.tokenizer import build_tokenizer from .mistral import MistralRawDataset, MistralIdxMapDataset from .llama import LLamaRawDataset, LLamaIdxMapDataset -from .llava.mm_pretrain_dataset import LazySupervisedDataset +from .llava.mm_pretrain_dataset import LazySupervisedDataset as LLavaSupervisedDataset +from .qwen_vl import LazySupervisedDataset as QwenVLSupervisedDataset def build_evaluation_dataset(dataset): @@ -54,10 +52,17 @@ def build_finetune_dataset(dataset): return train_dataset, valid_dataset elif dataset == 'LLava-SFT': - train_dataset = LazySupervisedDataset(args.train_data_path) - valid_dataset = LazySupervisedDataset(args.valid_data_path) + train_dataset = LLavaSupervisedDataset(args.train_data_path) + valid_dataset = LLavaSupervisedDataset(args.valid_data_path) return train_dataset, valid_dataset + + elif dataset == 'Qwen-VL-SFT': + train_dataset = QwenVLSupervisedDataset(args.train_data_path) + valid_dataset = QwenVLSupervisedDataset(args.valid_data_path) + + return train_dataset, valid_dataset + else: raise NotImplementedError('dataset {} is not implemented.'.format(dataset)) @@ -81,9 +86,9 @@ def build_pretrain_dataset_from_original(dataset): return train_dataset, valid_dataset, test_dataset elif dataset == 'LLava-Pretrain-Raw': - train_dataset = LazySupervisedDataset(args.train_data_path) - valid_dataset = LazySupervisedDataset(args.valid_data_path) - test_dataset = LazySupervisedDataset(args.test_data_path) + train_dataset = LLavaSupervisedDataset(args.train_data_path) + valid_dataset = LLavaSupervisedDataset(args.valid_data_path) + test_dataset = LLavaSupervisedDataset(args.test_data_path) return train_dataset, valid_dataset, test_dataset @@ -148,6 +153,7 @@ def build_pretrain_dataset_from_idxmap(data_prefix, # Blend. blending_train_dataset = None + from megatron.data.blendable_dataset import BlendableDataset if train_datasets: blending_train_dataset = BlendableDataset(train_datasets, weights, train_num_samples) blending_valid_dataset = None @@ -164,6 +170,8 @@ def _build_train_valid_test_datasets(data_prefix, max_padding_length, dataset_ty train_valid_test_num_samples, seed, skip_warmup, return_doc_ids=False): + from megatron.data.gpt_dataset import get_indexed_dataset_ + from megatron.data.gpt_dataset import get_train_valid_test_split_ # Indexed dataset. indexed_dataset = get_indexed_dataset_(data_prefix, skip_warmup) total_num_of_documents = indexed_dataset.sizes.shape[0] diff --git a/megatron_patch/data/llama.py b/megatron_patch/data/llama.py index 4097d9f9..95697550 100644 --- a/megatron_patch/data/llama.py +++ b/megatron_patch/data/llama.py @@ -17,7 +17,6 @@ import copy import json import torch -from megatron.data.gpt_dataset import _build_index_mappings from megatron import get_args from megatron_patch.tokenizer import get_tokenizer @@ -194,6 +193,7 @@ def __init__(self, assert np.max(documents) < indexed_dataset.sizes.shape[0] # Build index mappings. + from megatron.data.gpt_dataset import _build_index_mappings try: self.doc_idx, self.sample_idx, self.shuffle_idx, self.index_prefix = \ _build_index_mappings(self.name, data_prefix, diff --git a/megatron_patch/data/mistral.py b/megatron_patch/data/mistral.py index df47fd0a..95be938e 100644 --- a/megatron_patch/data/mistral.py +++ b/megatron_patch/data/mistral.py @@ -17,7 +17,6 @@ import copy import json import torch -from megatron.data.gpt_dataset import _build_index_mappings from megatron import get_args from megatron_patch.tokenizer import get_tokenizer @@ -195,6 +194,7 @@ def __init__(self, assert np.min(documents) >= 0 assert np.max(documents) < indexed_dataset.sizes.shape[0] + from megatron.data.gpt_dataset import _build_index_mappings # Build index mappings. try: self.doc_idx, self.sample_idx, self.shuffle_idx, self.index_prefix = \ diff --git a/megatron_patch/data/qwen_vl.py b/megatron_patch/data/qwen_vl.py new file mode 100644 index 00000000..32c419c8 --- /dev/null +++ b/megatron_patch/data/qwen_vl.py @@ -0,0 +1,122 @@ +# Copyright (c) 2023 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import Dict +import torch +from torch.utils.data import Dataset +import transformers +from transformers.trainer_pt_utils import LabelSmoother +from megatron import get_args + +from megatron_patch.tokenizer import get_tokenizer + +IGNORE_TOKEN_ID = LabelSmoother.ignore_index + +def preprocess( + sources, + tokenizer: transformers.PreTrainedTokenizer, + max_len: int, + system_message: str = "You are a helpful assistant." +) -> Dict: + """ + Preprocess conversation data for the model input. + + Parameters: + sources (List[Dict]): A list of conversation segments. + tokenizer (PreTrainedTokenizer): A tokenizer instance. + max_len (int): The maximum sequence length. + system_message (str, optional): A default system message. + + Returns: + Dict: A dictionary with 'input_ids', 'labels', and 'attention_mask'. + """ + roles = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"} + im_start = tokenizer.im_start_id + im_end = tokenizer.im_end_id + nl_tokens = tokenizer('\n').input_ids + _system = tokenizer('system').input_ids + nl_tokens + _user = tokenizer('user').input_ids + nl_tokens + _assistant = tokenizer('assistant').input_ids + nl_tokens + + # Apply prompt templates + input_ids, targets = [], [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != roles["user"]: + source = source[1:] + + input_id, target = [], [] + system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens + input_id += system + target += [im_start] + [IGNORE_TOKEN_ID] * (len(system)-3) + [im_end] + nl_tokens + assert len(input_id) == len(target) + for sentence in enumerate(source): + role = roles[sentence["from"]] + _input_id = tokenizer(role).input_ids + nl_tokens + \ + tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens + input_id += _input_id + if role == '<|im_start|>user': + _target = [im_start] + [IGNORE_TOKEN_ID] * (len(_input_id)-3) + [im_end] + nl_tokens + elif role == '<|im_start|>assistant': + _target = [im_start] + [IGNORE_TOKEN_ID] * len(tokenizer(role).input_ids) + \ + _input_id[len(tokenizer(role).input_ids)+1:-2] + [im_end] + nl_tokens + else: + raise NotImplementedError + target += _target + assert len(input_id) == len(target) + input_id += [tokenizer.pad_token_id] * (max_len - len(input_id)) + target += [IGNORE_TOKEN_ID] * (max_len - len(target)) + input_ids.append(input_id[:max_len]) + targets.append(target[:max_len]) + input_ids = torch.tensor(input_ids, dtype=torch.int) + targets = torch.tensor(targets, dtype=torch.int) + + return dict( + input_ids=input_ids, + labels=targets, + attention_mask=input_ids.ne(tokenizer.pad_token_id), + ) + +class LazySupervisedDataset(Dataset): + """ + A dataset class for supervised fine-tuning. + + Attributes: + data_path (str): Path to the JSON file containing the conversational data. + """ + + def __init__(self, data_path: str): + super(LazySupervisedDataset, self).__init__() + self.args = get_args() + self.tokenizer = get_tokenizer() + self.max_len = self.args.max_padding_length + self.raw_data = json.load(open(data_path[0], "r")) + self.cached_data_dict = {} + + def __len__(self): + return len(self.raw_data) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + if i in self.cached_data_dict: + return self.cached_data_dict[i] + + ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer, self.max_len) + ret = dict( + input_ids=ret["input_ids"][0], + labels=ret["labels"][0], + attention_mask=ret["attention_mask"][0], + ) + self.cached_data_dict[i] = ret + + return ret \ No newline at end of file diff --git a/megatron_patch/distributed.py b/megatron_patch/distributed.py new file mode 100644 index 00000000..5fdef29f --- /dev/null +++ b/megatron_patch/distributed.py @@ -0,0 +1,217 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +import torch +import math +from contextlib import contextmanager +from typing import Dict + +from megatron.core import mpu +from megatron.model.distributed import DistributedDataParallelBase +from megatron.model.distributed import GradBuffer + + +def is_moe_param(param: torch.Tensor) -> bool: + if hasattr(param, "allreduce") and not param.allreduce: + return True + return False + +class DistributedDataParallel(DistributedDataParallelBase): + """ + DDP wrapper which stores grads in contiguous buffers. Also has option of + overlapping communication with backprop computation by breaking up full model's + gradients into smaller buckets and running all-reduce / reduce-scatter + on each bucket asynchronously. + This class: + - has the potential to reduce memory fragmentation. + - provides the option to do the gradient accumulation + in a type other than the params type (e.g., fp32). + + Arguments: + module: input model. + data_parallel_group: data-parallel group. + accumulate_allreduce_grads_in_fp32: if true do the gradient accumulation + and communication in float32. + overlap_grad_reduce: if true, overlap communication with backprop + computation by breaking up grads into buckets. If false, single + synchronous communication call is used instead. + use_distributed_optimizer: if true, issue reduce-scatter communication + calls as part of distributed optimizer. If false, issue all-reducde + communication calls. + + """ + + def __init__( + self, + module: torch.nn.Module, + data_parallel_group: torch.distributed.ProcessGroup, + accumulate_allreduce_grads_in_fp32: bool, + overlap_grad_reduce: bool, + use_distributed_optimizer: bool, + bucket_size: int = 40000000, + ): + super(DistributedDataParallel, self).__init__(module) + + # Set bucket_size to infinity if overlap_grad_reduce is False. + self.overlap_grad_reduce = overlap_grad_reduce + self.use_distributed_optimizer = use_distributed_optimizer + + if not self.overlap_grad_reduce: + bucket_size = None + self.bucket_size = bucket_size + + self.module = module + self.grad_buffers = {} + self.expert_params = [] + self.expert_grads = [] + self.grad_buffer_param_index_map = {} + self.param_to_grad_buffer = {} + + # Group parameters by their gradient type. + grad_dtype_to_params = {} + grad_dtype_to_numel = {} + param_to_name = {} + for name, param in self.module.named_parameters(): + if param.requires_grad and getattr(param, 'allreduce', True): + param.grad_added_to_main_grad = False + param_to_name[param] = name + dtype = torch.float if accumulate_allreduce_grads_in_fp32 else param.dtype + + params = grad_dtype_to_params.get(dtype, []) + params.append(param) + grad_dtype_to_params[dtype] = params + + # Calculate number of elements per dtype. + grad_dtype_to_numel[dtype] = ( + grad_dtype_to_numel.get(dtype, 0) + param.data.nelement() + ) + + # Allocate the grad buffers and map the grads. + # The grad buffer under the hood creates buckets as appropriate, depending on + # whether overlap_grad_reduce is True or not. + data_parallel_world_size = torch.distributed.get_world_size(group=data_parallel_group) + for dtype, params in grad_dtype_to_params.items(): + # Pad so size is divisible by the data parallel size. + numel = grad_dtype_to_numel[dtype] + numel_padded = ( + int(math.ceil(numel / data_parallel_world_size)) * data_parallel_world_size + ) + + self.grad_buffers[dtype] = GradBuffer( + numel, + numel_padded, + dtype, + params, + data_parallel_group, + bucket_size, + param_to_name, + self.overlap_grad_reduce, + self.use_distributed_optimizer, + ) + + # Parameters are laid out in the corresponding grad_buffer in reverse + # order, so count indices from the back. + index = grad_dtype_to_numel[dtype] + for param in params: + self.param_to_grad_buffer[param] = self.grad_buffers[dtype] + if dtype not in self.grad_buffer_param_index_map: + self.grad_buffer_param_index_map[dtype] = {} + + index -= param.data.nelement() + # Store the indices / bucket of each param. + self.grad_buffer_param_index_map[dtype][param] = ( + index, + index + param.data.nelement(), + self.grad_buffers[dtype].param_to_bucket_index[param], + ) + + # Allocate discreate buffer for MoE params' grads + for param in self.module.parameters(): + if param.requires_grad and not getattr(param, 'allreduce', True): + dtype = torch.float if accumulate_allreduce_grads_in_fp32 else param.dtype + param.main_grad = \ + torch.zeros(param.data.shape, + dtype=dtype, + device=torch.cuda.current_device(), + requires_grad=False) + self.expert_grads.append(param.main_grad) + self.expert_params.append(param) + + + # Register backward hook. + # Accumulation function for the gradients need to be stored so they + # don't go out of scope. + self.grad_accs = [] + for param in self.module.parameters(): + if param.requires_grad: + # Expand so we get access to grad_fn. + param_tmp = param.expand_as(param) + # Get the gradient accumulator function. + grad_acc = param_tmp.grad_fn.next_functions[0][0] + grad_acc.register_hook(self._make_param_hook(param, self.param_to_grad_buffer)) + self.grad_accs.append(grad_acc) + + def _make_param_hook( + self, param: torch.nn.Parameter, param_to_grad_buffer: Dict[torch.nn.Parameter, GradBuffer] + ): + """Create the all-reduce / reduce-scatter hook for backprop.""" + + def param_hook(*unused): + if param.requires_grad: + if self.overlap_grad_reduce: + assert ( + param.grad is not None + ), 'param.grad being None is not safe when overlap_grad_reduce is True' + if param.grad is not None and not param.grad_added_to_main_grad: + param.main_grad.add_(param.grad.data) + param.grad = None + if self.overlap_grad_reduce: + param_to_grad_buffer[param].mark_grad_as_done(param) + + return param_hook + + @contextmanager + def no_sync(self): + """Context manager that turns off gradient synchronization.""" + for grad_buffer in self.grad_buffers.values(): + grad_buffer.is_last_microbatch = False + try: + yield + finally: + for grad_buffer in self.grad_buffers.values(): + grad_buffer.is_last_microbatch = True + + def grad_sync(self, *unused): + """Method to dispatch grad sync operations.""" + for grad_buffer in self.grad_buffers.values(): + grad_buffer.grad_sync() + + def zero_grad_buffer(self): + """Set the grad buffer data to zero. Needs to be called at the + begining of each iteration.""" + for param in self.module.parameters(): + if param.requires_grad: + param.grad_added_to_main_grad = False + for grad_buffer in self.grad_buffers.values(): + grad_buffer.reset() + for expert_grad in self.expert_grads: + expert_grad.zero_() + + def broadcast_params(self): + """Sync params across all DP ranks.""" + for param in self.module.parameters(): + torch.distributed.broadcast( + param.data, + src=mpu.get_data_parallel_src_rank(), + group=mpu.get_data_parallel_group(), + ) + + def sync_gradients(self): + """ + Reduce gradients across data-parallel ranks. + When overlap_grad_reduce is set to True, waits for asynchronous + communication calls to complete. + When overlap_grad_reduce is set to False, calls synchronous + communication ops. + """ + for grad_buffer in self.grad_buffers.values(): + grad_buffer.done() diff --git a/megatron_patch/expert_parallel_state.py b/megatron_patch/expert_parallel_state.py new file mode 100644 index 00000000..f7101718 --- /dev/null +++ b/megatron_patch/expert_parallel_state.py @@ -0,0 +1,79 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""Expert parallel groups.""" +import torch + +_EXPERT_PARALLEL_GROUP = None +_MPU_EXPERT_PARALLEL_WORLD_SIZE = None + +def initialize_moe_model_parallel( + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + context_parallel_size: int = 1, + expert_tensor_parallelism: bool = False +) -> None: + """Initialize model data parallel groups. + """ + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + + if ( + world_size + % (tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size) + != 0 + ): + raise RuntimeError( + f"world_size ({world_size}) is not divisible by tensor_model_parallel_size " + f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size}) " + f"x context_parallel_size ({context_parallel_size})" + ) + + num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size + rank = torch.distributed.get_rank() + + global _EXPERT_PARALLEL_GROUP + assert _EXPERT_PARALLEL_GROUP is None, \ + 'expert parallel group is already initialized' + # Currently, data parallelism is not supported for experts. + if expert_tensor_parallelism: + # ETP + EP + for i in range(pipeline_model_parallel_size): + start_rank = i * num_pipeline_model_parallel_groups + end_rank = (i + 1) * num_pipeline_model_parallel_groups + for j in range(tensor_model_parallel_size): + ranks = range(start_rank + j, end_rank, + tensor_model_parallel_size) + group = torch.distributed.new_group(ranks) + if rank in ranks: + _EXPERT_PARALLEL_GROUP = group + else: + # Pure EP + for i in range(pipeline_model_parallel_size): + start_rank = i * num_pipeline_model_parallel_groups + end_rank = (i + 1) * num_pipeline_model_parallel_groups + ranks = range(start_rank, end_rank) + group = torch.distributed.new_group(ranks) + if rank in ranks: + _EXPERT_PARALLEL_GROUP = group + + +def get_expert_parallel_group(): + """Get the expert parallel group the caller rank belongs to.""" + assert _EXPERT_PARALLEL_GROUP is not None, \ + 'expert parallel group is not initialized' + return _EXPERT_PARALLEL_GROUP + + +def set_expert_parallel_world_size(world_size): + """Set the expert parallel size""" + global _MPU_EXPERT_PARALLEL_WORLD_SIZE + _MPU_EXPERT_PARALLEL_WORLD_SIZE = world_size + + +def get_expert_parallel_world_size(): + """Return world size for the expert parallel group.""" + global _MPU_EXPERT_PARALLEL_WORLD_SIZE + if _MPU_EXPERT_PARALLEL_WORLD_SIZE is not None: + return _MPU_EXPERT_PARALLEL_WORLD_SIZE + return torch.distributed.get_world_size(group=get_expert_parallel_group()) \ No newline at end of file diff --git a/megatron_patch/initialize.py b/megatron_patch/initialize.py new file mode 100644 index 00000000..37c503f6 --- /dev/null +++ b/megatron_patch/initialize.py @@ -0,0 +1,159 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""Megatron initialization.""" + +import torch +from datetime import timedelta + +from megatron import get_args +from megatron.core import mpu, tensor_parallel +from megatron.arguments import parse_args, validate_args +from megatron.checkpointing import load_args_from_checkpoint +from megatron.global_vars import set_global_variables +from megatron.initialize import _set_random_seed, _init_autoresume +from megatron.initialize import _compile_dependencies + +from .arguments import validate_moe_args +import megatron_patch.expert_parallel_state as moe_mpu + + +def initialize_megatron( + extra_args_provider=None, + args_defaults={}, + ignore_unknown_args=False, + allow_no_cuda=False, +): + """Set global variables, initialize distributed, and + set autoresume and random seeds. + `allow_no_cuda` should not be set unless using megatron for cpu only + data processing. In general this arg should not be set unless you know + what you are doing. + Returns a function to finalize distributed env initialization + (optionally, only when args.lazy_mpu_init == True) + """ + if not allow_no_cuda: + # Make sure cuda is available. + assert torch.cuda.is_available(), "Megatron requires CUDA." + + # Parse arguments + args = parse_args(extra_args_provider, ignore_unknown_args) + + if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False): + assert args.load is not None, "--use-checkpoints-args requires --load argument" + load_args_from_checkpoint(args) + + validate_args(args, args_defaults) + validate_moe_args(args, args_defaults) + + # set global args, build tokenizer, and set adlr-autoresume, + # tensorboard-writer, and timers. + set_global_variables(args) + + # torch.distributed initialization + def finish_mpu_init(): + args = get_args() + # Pytorch distributed. + _initialize_distributed() + + # Random seeds for reproducibility. + if args.rank == 0: + print("> setting random seeds to {} ...".format(args.seed)) + _set_random_seed(args.seed, args.data_parallel_random_init) + + args = get_args() + if args.lazy_mpu_init: + # TODO is this still a necessary option? + args.use_cpu_initialization = True + # delayed initialization of DDP-related stuff + # We only set basic DDP globals + mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size) + # and return function for external DDP manager + # to call when it has DDP initialized + mpu.set_tensor_model_parallel_rank(args.rank) + return finish_mpu_init + else: + # Megatron's MPU is the master. Complete initialization right away. + finish_mpu_init() + + # Autoresume. + _init_autoresume() + + # Compile dependencies. + _compile_dependencies() + + # No continuation function + return None + +def _initialize_distributed(): + """Initialize torch.distributed and core model parallel.""" + args = get_args() + + device_count = torch.cuda.device_count() + if torch.distributed.is_initialized(): + + if args.rank == 0: + print( + "torch distributed is already initialized, " + "skipping initialization ...", + flush=True, + ) + args.rank = torch.distributed.get_rank() + args.world_size = torch.distributed.get_world_size() + + else: + + if args.rank == 0: + print("> initializing torch distributed ...", flush=True) + # Manually set the device ids. + if device_count > 0: + device = args.rank % device_count + if args.local_rank is not None: + assert ( + args.local_rank == device + ), "expected local-rank to be the same as rank % device-count." + else: + args.local_rank = device + torch.cuda.set_device(device) + # Call the init process + torch.distributed.init_process_group( + backend=args.distributed_backend, + world_size=args.world_size, + rank=args.rank, + timeout=timedelta(minutes=args.distributed_timeout_minutes), + ) + + # Set the tensor model-parallel, pipeline model-parallel, and + # data-parallel communicators. + if device_count > 0: + if mpu.model_parallel_is_initialized(): + print("model parallel is already initialized") + else: + mpu.initialize_model_parallel( + args.tensor_model_parallel_size, + args.pipeline_model_parallel_size, + args.virtual_pipeline_model_parallel_size, + args.pipeline_model_parallel_split_rank, + ) + + moe_mpu.initialize_moe_model_parallel( + args.tensor_model_parallel_size, + args.pipeline_model_parallel_size, + expert_tensor_parallelism=args.expert_tensor_parallelism + ) + + if args.rank == 0: + print( + f"> initialized tensor model parallel with size " + f"{mpu.get_tensor_model_parallel_world_size()}" + ) + print( + f"> initialized pipeline model parallel with size " + f"{mpu.get_pipeline_model_parallel_world_size()}" + ) + + + + + + + diff --git a/megatron_patch/model/llava/language_model.py b/megatron_patch/model/llava/language_model.py index aaa54c1a..dfc1c447 100644 --- a/megatron_patch/model/llava/language_model.py +++ b/megatron_patch/model/llava/language_model.py @@ -480,53 +480,54 @@ def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask, enc_hidden_states=None, output_enc_hidden=False, images=None): image_features = self.encode_images(images) + + input_embeds = self.embedding(enc_input_ids, enc_position_ids, + tokentype_ids=tokentype_ids) + input_embeds = input_embeds.permute(1, 0, 2) + new_input_embeds = [] - cur_image_idx = 0 for batch_idx, cur_input_ids in enumerate(enc_input_ids): - cur_position_ids = enc_position_ids[batch_idx] + cur_input_embeds = input_embeds[batch_idx] # s h if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0: # multimodal LLM, but the current sample is not multimodal # FIXME: this is a hacky fix, for deepspeed zero3 to work half_len = cur_input_ids.shape[0] // 2 - cur_image_features = image_features[cur_image_idx] - cur_input_embeds_1 = self.embedding(cur_input_ids[:half_len].unsqueeze(0), cur_position_ids[:half_len].unsqueeze(0)) - cur_input_embeds_2 = self.embedding(cur_input_ids[half_len:].unsqueeze(0), cur_position_ids[half_len:].unsqueeze(0)) + cur_image_features = image_features[batch_idx] + cur_input_embeds_1 = cur_input_embeds[:half_len].unsqueeze(1) + cur_input_embeds_2 = cur_input_embeds[half_len:].unsqueeze(1) cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2], dim=0) new_input_embeds.append(cur_input_embeds) - cur_image_idx += 1 continue image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0] cur_new_input_embeds = [] + cur_start = 0 while image_token_indices.numel() > 0: - cur_image_features = image_features[cur_image_idx].unsqueeze(1) + cur_image_features = image_features[batch_idx].unsqueeze(1) image_token_start = image_token_indices[0] if getattr(self.args, 'tune_mm_mlp_adapter', False) and getattr(self.args, 'mm_use_im_start_end', False): - cur_new_input_embeds.append(self.embedding(cur_input_ids[:image_token_start-1].unsqueeze(0), cur_position_ids[:image_token_start-1].unsqueeze(0)).detach()) - cur_new_input_embeds.append(self.embedding(cur_input_ids[image_token_start-1:image_token_start].unsqueeze(0), cur_position_ids[image_token_start-1:image_token_start].unsqueeze(0))) + cur_new_input_embeds.append(cur_input_embeds[cur_start:image_token_start-1].unsqueeze(1).detach()) + cur_new_input_embeds.append(cur_input_embeds[image_token_start-1:image_token_start].unsqueeze(1)) # special token: cur_new_input_embeds.append(cur_image_features) - cur_new_input_embeds.append(self.embedding(cur_input_ids[image_token_start+1:image_token_start+2].unsqueeze(0), cur_position_ids[image_token_start+1:image_token_start+2].unsqueeze(0))) - + cur_new_input_embeds.append(cur_input_embeds[image_token_start+1:image_token_start+2].unsqueeze(1)) # special token: + cur_start = image_token_start + 2 else: - cur_new_input_embeds.append(self.embedding(cur_input_ids[:image_token_start].unsqueeze(0), cur_position_ids[:image_token_start].unsqueeze(0))) + cur_new_input_embeds.append(cur_input_embeds[cur_start:image_token_start].unsqueeze(1)) cur_new_input_embeds.append(cur_image_features) + cur_start = image_token_start + 1 - cur_image_idx += 1 + image_token_indices = torch.where(cur_input_ids[cur_start:] == IMAGE_TOKEN_INDEX)[0] + if cur_input_ids[cur_start:].numel() > 0: if getattr(self.args, 'tune_mm_mlp_adapter', False) and getattr(self.args, 'mm_use_im_start_end', False): - cur_input_ids = cur_input_ids[image_token_start+2:] + cur_new_input_embeds.append(cur_input_embeds[cur_start:].unsqueeze(1).detach()) else: - cur_input_ids = cur_input_ids[image_token_start+1:] - image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0] - if cur_input_ids.numel() > 0: - if getattr(self.args, 'tune_mm_mlp_adapter', False) and getattr(self.args, 'mm_use_im_start_end', False): - cur_new_input_embeds.append(self.embedding(cur_input_ids.unsqueeze(0), cur_position_ids.unsqueeze(0)).detach()) - else: - cur_new_input_embeds.append(self.embedding(cur_input_ids.unsqueeze(0), cur_position_ids.unsqueeze(0))) + cur_new_input_embeds.append(cur_input_embeds[cur_start:].unsqueeze(1)) cur_new_input_embeds = [x.to(device=enc_input_ids.device) for x in cur_new_input_embeds] cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0) new_input_embeds.append(cur_new_input_embeds) encoder_input = torch.cat(new_input_embeds, dim=1) + if enc_attn_mask is not None: batch_size = enc_input_ids.shape[0] new_enc_attn_mask = _prepare_4d_causal_attention_mask( diff --git a/megatron_patch/model/mixtral/__init__.py b/megatron_patch/model/mixtral/__init__.py new file mode 100644 index 00000000..1f6175dc --- /dev/null +++ b/megatron_patch/model/mixtral/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/megatron_patch/model/mixtral/all2all.py b/megatron_patch/model/mixtral/all2all.py new file mode 100644 index 00000000..8110378c --- /dev/null +++ b/megatron_patch/model/mixtral/all2all.py @@ -0,0 +1,62 @@ +# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + +from megatron_patch.expert_parallel_state import get_expert_parallel_world_size + +class _AllToAll(torch.autograd.Function): + @staticmethod + def forward(ctx, + group, + input, + output_split_sizes, + input_split_sizes): + ctx.group = group + ctx.output_split_sizes = output_split_sizes + ctx.input_split_sizes = input_split_sizes + + world_size = get_expert_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input + + input = input.contiguous() + if output_split_sizes is None: + # Equal split (all2all) + output = torch.empty_like(input) + else: + # Unequal split (all2all-v) + output = input.new_empty( + size=[sum(output_split_sizes)] + list(input.size()[1:]), + dtype=input.dtype, + device=torch.cuda.current_device()) + torch.distributed.all_to_all_single( + output, + input, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group) + return output + + @staticmethod + def backward(ctx, *grad_output): + return (None, _AllToAll.apply( + ctx.group, * grad_output, + ctx.input_split_sizes, + ctx.output_split_sizes), None, None) + +def all_to_all(group, input_, output_split_sizes_=None, input_split_sizes_=None): + return _AllToAll.apply(group, input_, output_split_sizes_, input_split_sizes_) diff --git a/megatron_patch/model/mixtral/experts.py b/megatron_patch/model/mixtral/experts.py new file mode 100644 index 00000000..d345cfe0 --- /dev/null +++ b/megatron_patch/model/mixtral/experts.py @@ -0,0 +1,45 @@ +# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import copy + + +class Experts(torch.nn.Module): + def __init__(self, expert, num_local_experts=1, expert_group_name=None): + super(Experts, self).__init__() + + self.megatron_experts = torch.nn.ModuleList( + [copy.deepcopy(expert) for i in range(num_local_experts)]) + self.num_local_experts = num_local_experts + + # TODO: revisit allreduce for moe.gate... + for expert in self.megatron_experts: + # TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group) + for name, param in expert.named_parameters(): + param.allreduce = False + param.group_name = expert_group_name + + def forward(self, inputs): + chunks = inputs.chunk(self.num_local_experts, dim=1) + expert_outputs = [] + for chunk, expert in zip(chunks, self.megatron_experts): + out = expert(chunk) + if type(out) is tuple: + out = out[0] # Ignore the bias term for now + expert_outputs += [out] + + expert_output = torch.cat(expert_outputs, dim=1) + return expert_output diff --git a/megatron_patch/model/mixtral/gpt_model.py b/megatron_patch/model/mixtral/gpt_model.py new file mode 100644 index 00000000..c43f963c --- /dev/null +++ b/megatron_patch/model/mixtral/gpt_model.py @@ -0,0 +1,132 @@ +# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from megatron import get_args +from megatron.core import tensor_parallel +from megatron.model.module import MegatronModule + +from megatron.model.enums import AttnMaskType +from .language_model import parallel_lm_logits +from .language_model import get_language_model + + +def post_language_model_processing(lm_output, labels, logit_weights, + parallel_output, + fp16_lm_cross_entropy): + + # Output. Format [s b h] + output = parallel_lm_logits( + lm_output, + logit_weights, + parallel_output) + + if labels is None: + # [s b h] => [b s h] + return output.transpose(0,1).contiguous() + else: + # [b s] => [s b] + labels = labels.transpose(0,1).contiguous() + if fp16_lm_cross_entropy: + assert output.dtype == torch.half + loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels) + else: + loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels) + + # [s b] => [b, s] + loss = loss.transpose(0,1).contiguous() + return loss + + +class GPTModel(MegatronModule): + """GPT-2 Language model.""" + + def __init__(self, + config, + num_tokentypes=0, + parallel_output=True, + pre_process=True, + post_process=True): + args = get_args() + super().__init__(config=config, share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights) + + self.parallel_output = parallel_output + self.pre_process = pre_process + self.post_process = post_process + self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy + self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights + + self.language_model, self._language_model_key = get_language_model( + config=config, + num_tokentypes=num_tokentypes, + add_pooler=False, + encoder_attn_mask_type=AttnMaskType.causal, + pre_process=self.pre_process, + post_process=self.post_process) + + if not args.untie_embeddings_and_output_weights: + self.initialize_word_embeddings() + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + self.language_model.set_input_tensor(input_tensor) + + def forward(self, input_ids, position_ids, attention_mask, + retriever_input_ids=None, + retriever_position_ids=None, + retriever_attn_mask=None, + labels=None, tokentype_ids=None, inference_params=None): + + lm_output = self.language_model( + input_ids, + position_ids, + attention_mask, + retriever_input_ids=retriever_input_ids, + retriever_position_ids=retriever_position_ids, + retriever_attn_mask=retriever_attn_mask, + inference_params=inference_params) + + if self.post_process: + return post_language_model_processing( + lm_output, labels, + self.language_model.output_layer.weight if self.untie_embeddings_and_output_weights else self.shared_embedding_or_output_weight(), + self.parallel_output, + self.fp16_lm_cross_entropy) + else: + return lm_output + + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + + state_dict_ = {} + state_dict_[self._language_model_key] \ + = self.language_model.state_dict_for_save_checkpoint( + prefix=prefix, keep_vars=keep_vars) + # Save word_embeddings. + if self.post_process and not self.pre_process and not self.untie_embeddings_and_output_weights: + state_dict_[self._word_embeddings_for_head_key] \ + = self.word_embeddings.state_dict(prefix=prefix, + keep_vars=keep_vars) + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + # Load word_embeddings. + if self.post_process and not self.pre_process and not self.untie_embeddings_and_output_weights: + self.word_embeddings.load_state_dict( + state_dict[self._word_embeddings_for_head_key], strict=strict) + if self._language_model_key in state_dict: + state_dict = state_dict[self._language_model_key] + self.language_model.load_state_dict(state_dict, strict=strict) diff --git a/megatron_patch/model/mixtral/language_model.py b/megatron_patch/model/mixtral/language_model.py new file mode 100644 index 00000000..408b041e --- /dev/null +++ b/megatron_patch/model/mixtral/language_model.py @@ -0,0 +1,694 @@ +# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from megatron import get_args +from megatron.core import mpu, tensor_parallel +from megatron.core.enums import ModelType +from megatron.model.enums import AttnMaskType +from megatron.model.enums import LayerType +from megatron.model.module import MegatronModule +from megatron.model.utils import get_linear_layer +from megatron.model.utils import init_method_normal +from megatron.model.utils import scaled_init_method_normal +from megatron.core.models.common.rotary_pos_embedding import RotaryEmbedding + +from megatron_patch.model.mistral.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from .transformer import ParallelTransformer + + +def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, + bias=None): + """LM logits using word embedding weights.""" + args = get_args() + # Parallel logits. + if args.async_tensor_model_parallel_allreduce or\ + args.sequence_parallel: + input_parallel = input_ + model_parallel = mpu.get_tensor_model_parallel_world_size() > 1 + async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \ + model_parallel and not args.sequence_parallel + else: + input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_) + async_grad_allreduce = False + + # Matrix multiply. + logits_parallel = tensor_parallel.linear_with_grad_accumulation_and_async_allreduce( + input=input_parallel, + weight=word_embeddings_weight, + bias=bias, + gradient_accumulation_fusion=args.gradient_accumulation_fusion, + async_grad_allreduce=async_grad_allreduce, + sequence_parallel=args.sequence_parallel) + # Gather if needed. + + if parallel_output: + return logits_parallel + + return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel) + + +def get_language_model(config, num_tokentypes, add_pooler, + encoder_attn_mask_type, + add_encoder=True, + add_decoder=False, + decoder_attn_mask_type=AttnMaskType.causal, + pre_process=True, post_process=True): + """Build language model and return along with the key to save.""" + args = get_args() + if config.init_method is None: + config.init_method = init_method_normal(config.init_method_std) + + if config.output_layer_init_method is None: + config.output_layer_init_method = scaled_init_method_normal(config.init_method_std, + config.num_layers) + + # Language model. + language_model = TransformerLanguageModel( + config, + encoder_attn_mask_type, + num_tokentypes=num_tokentypes, + add_encoder=add_encoder, + add_decoder=add_decoder, + decoder_attn_mask_type=decoder_attn_mask_type, + add_pooler=add_pooler, + pre_process=pre_process, + post_process=post_process + ) + # key used for checkpoints. + language_model_key = 'language_model' + + return language_model, language_model_key + + +class Pooler(MegatronModule): + """Pooler layer. + + Pool hidden states of a specific token (for example start of the + sequence) and add a linear transformation followed by a tanh. + + Arguments: + hidden_size: hidden size + init_method: weight initialization method for the linear layer. + bias is set to zero. + """ + + def __init__(self, hidden_size, init_method): + super(Pooler, self).__init__() + args = get_args() + self.dense = get_linear_layer(hidden_size, hidden_size, init_method) + self.sequence_parallel = args.sequence_parallel + + + def forward(self, hidden_states, sequence_index=0): + # hidden_states: [s, b, h] + # sequence_index: index of the token to pool. + + # gather data along sequence dimensions + # same pooler is run on all tensor parallel nodes + if self.sequence_parallel: + hidden_states = tensor_parallel.gather_from_sequence_parallel_region( + hidden_states, + tensor_parallel_output_grad=False) + + pooled = hidden_states[sequence_index, :, :] + pooled = self.dense(pooled) + pooled = torch.tanh(pooled) + return pooled + + +class Embedding(MegatronModule): + """Language model embeddings. + + Arguments: + hidden_size: hidden size + vocab_size: vocabulary size + max_sequence_length: maximum size of sequence. This + is used for positional embedding + embedding_dropout_prob: dropout probability for embeddings + init_method: weight initialization method + num_tokentypes: size of the token-type embeddings. 0 value + will ignore this embedding + """ + + def __init__(self, + hidden_size, + vocab_size, + max_sequence_length, + embedding_dropout_prob, + config, + num_tokentypes=0): + super(Embedding, self).__init__() + + self.hidden_size = hidden_size + self.init_method = config.init_method + self.num_tokentypes = num_tokentypes + + args = get_args() + + # Word embeddings (parallel). + self.params_dtype = args.params_dtype + self.word_embeddings = tensor_parallel.VocabParallelEmbedding( + vocab_size, self.hidden_size, config=config, init_method=config.init_method) + self._word_embeddings_key = 'word_embeddings' + + # Position embedding (serial). + self.add_position_embedding = args.position_embedding_type == 'learned_absolute' + if self.add_position_embedding: + self.position_embeddings = torch.nn.Embedding( + max_sequence_length, self.hidden_size) + self._position_embeddings_key = 'position_embeddings' + # Initialize the position embeddings. + if args.perform_initialization: + self.init_method(self.position_embeddings.weight) + + # Token type embedding. + # Add this as an optional field that can be added through + # method call so we can load a pretrain model without + # token types and add them as needed. + self._tokentype_embeddings_key = 'tokentype_embeddings' + if self.num_tokentypes > 0: + self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, + self.hidden_size) + # Initialize the token-type embeddings. + if args.perform_initialization: + self.init_method(self.tokentype_embeddings.weight) + else: + self.tokentype_embeddings = None + + self.fp32_residual_connection = args.fp32_residual_connection + self.sequence_parallel = args.sequence_parallel + # Embeddings dropout + self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob) + + def zero_parameters(self): + """Zero out all parameters in embedding.""" + self.word_embeddings.weight.data.fill_(0) + self.word_embeddings.weight.shared = True + if self.add_position_embedding: + self.position_embeddings.weight.data.fill_(0) + self.position_embeddings.weight.shared = True + if self.num_tokentypes > 0: + self.tokentype_embeddings.weight.data.fill_(0) + self.tokentype_embeddings.weight.shared = True + + def add_tokentype_embeddings(self, num_tokentypes): + """Add token-type embedding. This function is provided so we can add + token-type embeddings in case the pretrained model does not have it. + This allows us to load the model normally and then add this embedding. + """ + if self.tokentype_embeddings is not None: + raise Exception('tokentype embeddings is already initialized') + if torch.distributed.get_rank() == 0: + print('adding embedding for {} tokentypes'.format(num_tokentypes), + flush=True) + self.num_tokentypes = num_tokentypes + self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, + self.hidden_size) + # Initialize the token-type embeddings. + args = get_args() + self.init_method(self.tokentype_embeddings.weight) + + def forward(self, input_ids, position_ids, tokentype_ids=None): + # Embeddings. + words_embeddings = self.word_embeddings(input_ids) + if self.add_position_embedding: + position_embeddings = self.position_embeddings(position_ids) + embeddings = words_embeddings + position_embeddings + else: + embeddings = words_embeddings + + if tokentype_ids is not None: + assert self.tokentype_embeddings is not None + embeddings = embeddings + self.tokentype_embeddings(tokentype_ids) + else: + assert self.tokentype_embeddings is None + + # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. + embeddings = embeddings.transpose(0, 1).contiguous() + + # If the input flag for fp32 residual connection is set, convert for float. + if self.fp32_residual_connection: + embeddings = embeddings.float() + + # Dropout. + if self.sequence_parallel: + embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings) + with tensor_parallel.get_cuda_rng_tracker().fork(): + embeddings = self.embedding_dropout(embeddings) + else: + embeddings = self.embedding_dropout(embeddings) + + return embeddings + + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + """For easy load.""" + + state_dict_ = {} + state_dict_[self._word_embeddings_key] \ + = self.word_embeddings.state_dict(prefix=prefix, + keep_vars=keep_vars) + if self.add_position_embedding: + state_dict_[self._position_embeddings_key] \ + = self.position_embeddings.state_dict(prefix=prefix, + keep_vars=keep_vars) + if self.num_tokentypes > 0: + state_dict_[self._tokentype_embeddings_key] \ + = self.tokentype_embeddings.state_dict(prefix=prefix, + keep_vars=keep_vars) + + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + # Word embedding. + if self._word_embeddings_key in state_dict: + state_dict_ = state_dict[self._word_embeddings_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'word_embeddings' in key: + state_dict_[key.split('word_embeddings.')[1]] \ + = state_dict[key] + self.word_embeddings.load_state_dict(state_dict_, strict=strict) + + # Position embedding. + if self.add_position_embedding: + if self._position_embeddings_key in state_dict: + state_dict_ = state_dict[self._position_embeddings_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'position_embeddings' in key: + state_dict_[key.split('position_embeddings.')[1]] \ + = state_dict[key] + self.position_embeddings.load_state_dict(state_dict_, strict=strict) + + # Tokentype embedding. + if self.num_tokentypes > 0: + state_dict_ = {} + if self._tokentype_embeddings_key in state_dict: + state_dict_ = state_dict[self._tokentype_embeddings_key] + else: + # for backward compatibility. + for key in state_dict.keys(): + if 'tokentype_embeddings' in key: + state_dict_[key.split('tokentype_embeddings.')[1]] \ + = state_dict[key] + if len(state_dict_.keys()) > 0: + self.tokentype_embeddings.load_state_dict(state_dict_, + strict=strict) + else: + print('***WARNING*** expected tokentype embeddings in the ' + 'checkpoint but could not find it', flush=True) + + +class TransformerLanguageModel(MegatronModule): + """Transformer language model. + + Arguments: + transformer_hparams: transformer hyperparameters + vocab_size: vocabulary size + max_sequence_length: maximum size of sequence. This + is used for positional embedding + embedding_dropout_prob: dropout probability for embeddings + num_tokentypes: size of the token-type embeddings. 0 value + will ignore this embedding + """ + + def __init__(self, + config, + encoder_attn_mask_type, + num_tokentypes=0, + add_encoder=True, + add_decoder=False, + decoder_attn_mask_type=AttnMaskType.causal, + add_pooler=False, + pre_process=True, + post_process=True): + args = get_args() + # TODO: passing share_embeddings_and_output_weights=False will not work correctly for T5 and embeddings will not be synced. Fix later for T5. + if args.untie_embeddings_and_output_weights: assert not add_decoder + super(TransformerLanguageModel, self).__init__(share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights) + + self.pre_process = pre_process + self.post_process = post_process + self.hidden_size = config.hidden_size + self.num_tokentypes = num_tokentypes + self.init_method = config.init_method + self.add_encoder = add_encoder + self.encoder_attn_mask_type = encoder_attn_mask_type + self.add_decoder = add_decoder + self.decoder_attn_mask_type = decoder_attn_mask_type + self.add_pooler = add_pooler + self.encoder_hidden_state = None + self.add_retriever = args.retro_add_retriever + self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights + self.sliding_window = args.sliding_window + + # Embeddings. + if self.pre_process: + self.embedding = Embedding(self.hidden_size, + args.padded_vocab_size, + args.max_position_embeddings, + args.hidden_dropout, + config, + self.num_tokentypes) + self._embedding_key = 'embedding' + + # Rotary positional embeddings + if args.use_rotary_position_embeddings: + self.seq_length = args.seq_length + rotary_dim = args.hidden_size // args.num_attention_heads \ + if args.kv_channels is None else args.kv_channels + + if args.rotary_percent < 1.0: + rotary_dim = int(rotary_dim * args.rotary_percent) + + # partial rotary embeddings, which is better than full rotary + # Wang and Komatsuzaki et al + # https://github.com/kingoflolz/mesh-transformer-jax/ + self.rotary_pos_emb = RotaryEmbedding( + rotary_dim, + seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor + ) + self.use_rotary_position_embeddings = True + elif args.use_mistral_rotary_position_embeddings: + self.use_rotary_position_embeddings = False + + + # Encoder (usually set to True, False if part of an encoder-decoder + # architecture and in encoder-only stage). + if self.add_encoder: + self.encoder = ParallelTransformer( + config, + model_type=args.model_type if not args.retro_add_retriever \ + else ModelType.retro_decoder, + self_attn_mask_type=self.encoder_attn_mask_type, + pre_process=self.pre_process, + post_process=self.post_process, + ) + self._encoder_key = 'encoder' + else: + self.encoder = None + + # Decoder (usually set to False, True if part of an encoder-decoder + # architecture and in decoder-only stage). + if self.add_decoder: + self.decoder = ParallelTransformer( + config, + model_type=args.model_type, + layer_type=LayerType.decoder, + self_attn_mask_type=self.decoder_attn_mask_type, + pre_process=self.pre_process, + post_process=self.post_process) + self._decoder_key = 'decoder' + else: + self.decoder = None + + if self.post_process: + # Pooler. + if self.add_pooler: + self.pooler = Pooler(self.hidden_size, self.init_method) + self._pooler_key = 'pooler' + + if self.untie_embeddings_and_output_weights: + self.output_layer = tensor_parallel.ColumnParallelLinear( + args.hidden_size, + args.padded_vocab_size, + config=config, + init_method=self.init_method, + bias=False) # Setting bias to False always to keep it consistent with embedding tying that also does not have a bias. + self._output_layer_key = 'output_layer' + + def set_input_tensor(self, input_tensor): + """ See megatron.model.transformer.set_input_tensor()""" + + # This is usually handled in schedules.py but some inference code still + # gives us non-lists or None + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + + if self.add_encoder and self.add_decoder: + assert len(input_tensor) == 1, \ + 'input_tensor should only be length 1 for stage with both encoder and decoder' + self.encoder.set_input_tensor(input_tensor[0]) + elif self.add_encoder: + assert len(input_tensor) == 1, \ + 'input_tensor should only be length 1 for stage with only encoder' + self.encoder.set_input_tensor(input_tensor[0]) + elif self.add_decoder: + if len(input_tensor) == 2: + self.decoder.set_input_tensor(input_tensor[0]) + self.encoder_hidden_state = input_tensor[1] + elif len(input_tensor) == 1: + self.decoder.set_input_tensor(None) + self.encoder_hidden_state = input_tensor[0] + else: + raise Exception('input_tensor must have either length 1 or 2') + else: + raise Exception('Stage must have at least either encoder or decoder') + + def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask, + dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None, + retriever_input_ids=None, + retriever_position_ids=None, + retriever_attn_mask=None, + enc_dec_attn_mask=None, tokentype_ids=None, + inference_params=None, + pooling_sequence_index=0, + enc_hidden_states=None, output_enc_hidden=False): + + # Encoder embedding. + if self.pre_process: + encoder_input = self.embedding(enc_input_ids, enc_position_ids, + tokentype_ids=tokentype_ids) + else: + encoder_input = None + + # Retriever embedding. + if self.add_retriever and self.pre_process: + retriever_input = self.embedding(retriever_input_ids, + retriever_position_ids, + tokentype_ids=tokentype_ids) + else: + retriever_input = None + + # Rotary positional embeddings + rotary_pos_emb = None + if self.use_rotary_position_embeddings: + if inference_params is not None: + rotary_pos_emb = \ + self.rotary_pos_emb(inference_params.max_sequence_length) + else: + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + + + if enc_position_ids is None: + past_key_values_length = 0 + seq_length = self.seq_length + device = enc_input_ids.device\ + if enc_input_ids is not None else encoder_input.device + position_ids = torch.arange(past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device) + enc_position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + + batch_size = enc_input_ids.shape[0] + seq_length = enc_input_ids.shape[1] + enc_attn_mask = _prepare_4d_causal_attention_mask( + enc_attn_mask, + (batch_size, seq_length), + encoder_input, + 0, + sliding_window=self.sliding_window, + ) + + # Run encoder. + if enc_hidden_states is None: + if self.encoder is not None: + encoder_output = self.encoder( + encoder_input, + enc_attn_mask, + retriever_input=retriever_input, + retriever_attn_mask=retriever_attn_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + position_ids=enc_position_ids + ) + else: + encoder_output = self.encoder_hidden_state + else: + encoder_output = enc_hidden_states.to(encoder_input.dtype) + + if self.post_process: + if self.add_pooler: + pooled_output = self.pooler(encoder_output, + pooling_sequence_index) + + # output_enc_hidden refers to when we just need the encoder's + # output. For example, it is helpful to compute + # similarity between two sequences by average pooling + if not self.add_decoder or output_enc_hidden: + if self.add_pooler and self.post_process: + return encoder_output, pooled_output + else: + return encoder_output + + # Decoder embedding. + if self.pre_process: + decoder_input = self.embedding(dec_input_ids, + dec_position_ids) + else: + decoder_input = None + + # Run decoder. + decoder_output = self.decoder( + decoder_input, + dec_attn_mask, + encoder_output=encoder_output, + enc_dec_attn_mask=enc_dec_attn_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb) + + if self.add_pooler and self.post_process: + return decoder_output, encoder_output, pooled_output + else: + return decoder_output, encoder_output + + def _gather_moe_state_dict(self, state_dict_, enc_or_dec_key): + """Handle MoE states separately""" + moe_state_dict_ = {} + for key in list(state_dict_[enc_or_dec_key].keys()): + if 'megatron_moe' in key and 'megatron_moe.gate.wg.weight' not in key: + moe_state_dict_[enc_or_dec_key + key] = state_dict_[enc_or_dec_key].pop(key) + return moe_state_dict_ + + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + """For easy load.""" + args = get_args() + state_dict_ = {} + if self.pre_process: + state_dict_[self._embedding_key] \ + = self.embedding.state_dict_for_save_checkpoint(prefix=prefix, + keep_vars=keep_vars) + if self.add_encoder: + state_dict_[self._encoder_key] \ + = self.encoder.state_dict_for_save_checkpoint(prefix=prefix, + keep_vars=keep_vars) + + if args.moe: + moe_state_dict_ = {} + moe_state_dict_.update(self._gather_moe_state_dict(state_dict_, self._encoder_key)) + + if self.post_process: + if self.add_pooler: + state_dict_[self._pooler_key] \ + = self.pooler.state_dict_for_save_checkpoint(prefix=prefix, + keep_vars=keep_vars) + if self.untie_embeddings_and_output_weights: + state_dict_[self._output_layer_key] \ + = self.output_layer.state_dict(prefix=prefix, keep_vars=keep_vars) + + if self.add_decoder: + state_dict_[self._decoder_key] \ + = self.decoder.state_dict_for_save_checkpoint(prefix=prefix, + keep_vars=keep_vars) + + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + args = get_args() + # Embedding. + if self.pre_process: + if self._embedding_key in state_dict: + state_dict_ = state_dict[self._embedding_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if '_embeddings' in key: + state_dict_[key] = state_dict[key] + self.embedding.load_state_dict(state_dict_, strict=strict) + + # Encoder. + if self.add_encoder: + if self._encoder_key in state_dict: + state_dict_ = state_dict[self._encoder_key] + # For backward compatibility. + elif 'transformer' in state_dict: + state_dict_ = state_dict['transformer'] + else: + # For backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'transformer.' in key: + state_dict_[key.split('transformer.')[1]] = state_dict[key] + + # For backward compatibility. + state_dict_self_attention = {} + for key in state_dict_.keys(): + if '.attention.' in key: + state_dict_self_attention[key.replace(".attention.", + ".self_attention.")] = state_dict_[key] + else: + state_dict_self_attention[key] = state_dict_[key] + state_dict_ = state_dict_self_attention + + if args.transformer_impl == "transformer_engine": + self.encoder.load_state_dict(state_dict_, strict=False) + else: + self.encoder.load_state_dict(state_dict_, strict=strict) + + # Gather encoder MoE states + if "moe_state_dict" in state_dict: + for key in list(state_dict["moe_state_dict"].keys()): + if self._encoder_key in key: + key_list = key.split('.') + while key_list[0].find('encoder') == -1: + key_list.pop(0) + key_list[0] = key_list[0].replace("encoder", "") + if key_list[0] == "": + key_list.pop(0) + actual_key = '.'.join(key_list) + state_dict_[actual_key] = state_dict["moe_state_dict"].pop(key) + if len(state_dict["moe_state_dict"]) == 0: + del state_dict["moe_state_dict"] + + self.encoder.load_state_dict(state_dict_, strict=strict) + + # Pooler. + if self.post_process: + if self.add_pooler: + assert 'pooler' in state_dict, \ + 'could not find data for pooler in the checkpoint' + self.pooler.load_state_dict(state_dict[self._pooler_key], + strict=strict) + if self.untie_embeddings_and_output_weights: + assert 'output_layer' in state_dict, \ + 'could not find data for output_layer in the checkpoint' + self.output_layer.load_state_dict(state_dict[self._output_layer_key], + strict=strict) + # Decoder. + if self.add_decoder: + assert 'decoder' in state_dict, \ + 'could not find data for pooler in the checkpoint' + self.decoder.load_state_dict(state_dict[self._decoder_key], + strict=strict) diff --git a/megatron_patch/model/mixtral/layer.py b/megatron_patch/model/mixtral/layer.py new file mode 100644 index 00000000..1044e083 --- /dev/null +++ b/megatron_patch/model/mixtral/layer.py @@ -0,0 +1,138 @@ +# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from .router import MOELayer, Router +from .experts import Experts +import typing +from megatron import get_args +from megatron.core.tensor_parallel.mappings import ( + gather_from_sequence_parallel_region, + reduce_from_tensor_model_parallel_region, + reduce_scatter_to_sequence_parallel_region, +) + +class MoE(torch.nn.Module): + def __init__(self, + hidden_size, + expert, + num_experts=1, + ep_size=1, + k=1, + capacity_factor=1., + eval_capacity_factor=1., + min_capacity=4, + use_residual=False, + noisy_gate_policy: typing.Optional[str] = None, + drop_tokens: bool = True, + use_rts=True, + use_tutel: bool = False, + expert_tensor_parallelism: bool = False, + moe_layer_index: int = None): + """Initialize an MoE layer. + + Arguments: + hidden_size (int): the hidden dimension of the model, importantly this is also the input and output dimension. + expert (torch.nn.Module): the torch module that defines the expert (e.g., MLP, torch.linear). + num_experts (int, optional): default=1, the total number of experts per layer. + ep_size (int, optional): default=1, number of ranks in the expert parallel world or group. + k (int, optional): default=1, top-k gating value, only supports k=1 or k=2. + capacity_factor (float, optional): default=1.0, the capacity of the expert at training time. + eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time. + min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor. + use_residual (bool, optional): default=False, make this MoE layer a Residual MoE (https://arxiv.org/abs/2201.05596) layer. + noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample' or 'None'. + drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to infinite capacity). + use_rts (bool, optional): default=True, whether to use Random Token Selection. + use_tutel (bool, optional): default=False, whether to use Tutel optimizations (if installed). + expert_tensor_parallelism (bool, optional): default=False, whether to use tensor parallelism for experts + """ + + super(MoE, self).__init__() + + self.use_residual = use_residual + self.expert_tensor_parallelism = expert_tensor_parallelism + assert num_experts % ep_size == 0, f"Number of experts ({num_experts}) should be divisible by expert parallel size ({ep_size})" + self.ep_size = ep_size + self.expert_group_name = f"ep_size_{self.ep_size}" + self.num_experts = num_experts + self.num_local_experts = num_experts // self.ep_size + self.moe_layer_index = moe_layer_index + + experts = Experts(expert, self.num_local_experts, self.expert_group_name) + self.megatron_moe = MOELayer(Router(hidden_size, + num_experts, + k, + capacity_factor, + eval_capacity_factor, + min_capacity, + noisy_gate_policy, + drop_tokens, + use_rts), + experts, + self.expert_group_name, + self.ep_size, + self.num_local_experts, + use_tutel=use_tutel, + expert_tensor_parallelism=expert_tensor_parallelism) + if self.use_residual: + self.mlp = expert + # coefficient is used for weighted sum of the output of expert and mlp + self.coefficient = torch.nn.Linear(hidden_size, 2) + + def forward(self, hidden_states, used_token=None): + """ MoE forward + + Arguments: + hidden_states (Tensor): input to the layer + used_token (Tensor, optional): default: None, mask only used tokens + + Returns: + A tuple including output, gate loss, and expert count. + + * output (Tensor): output of the model + + * mlp_bias (Tensor): placehoder, no effect + """ + args = get_args() + + # Gathering hidden_states for expert tensor parallel. + if args.expert_tensor_parallelism and args.sequence_parallel: + hidden_states = \ + gather_from_sequence_parallel_region(hidden_states, tensor_parallel_output_grad=False) + + output = self.megatron_moe(hidden_states, used_token) + + if self.use_residual: + # Residual MoE + output_mlp = self.mlp(hidden_states) + if type(output_mlp) is tuple: + output_mlp = output_mlp[0] # Ignore the bias term for now + coef = self.coefficient(hidden_states) + coef = torch.nn.functional.softmax(coef, dim=-1) + output = output * coef[..., 0:1] + output_mlp * coef[..., 1:] + + # Reduce hidden_states after expert tensor parallel + if args.expert_tensor_parallelism: + if args.sequence_parallel: + output = reduce_scatter_to_sequence_parallel_region(output) + else: + output = reduce_from_tensor_model_parallel_region(output) + + mlp_bias = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype) + return output, mlp_bias + + def get_moe_layer_index(self): + return self.moe_layer_index diff --git a/megatron_patch/model/mixtral/moe_parallel_linear.py b/megatron_patch/model/mixtral/moe_parallel_linear.py new file mode 100644 index 00000000..890b8235 --- /dev/null +++ b/megatron_patch/model/mixtral/moe_parallel_linear.py @@ -0,0 +1,470 @@ +# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter +from typing import Callable, Optional + +from megatron.core.model_parallel_config import ModelParallelConfig +from megatron.core.parallel_state import ( + get_global_memory_buffer, + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) + +from megatron.core.tensor_parallel.mappings import ( + copy_to_tensor_model_parallel_region, + gather_from_sequence_parallel_region, + gather_from_tensor_model_parallel_region, + reduce_from_tensor_model_parallel_region, + reduce_scatter_to_sequence_parallel_region, + scatter_to_tensor_model_parallel_region, +) +from megatron.core.tensor_parallel.utils import divide +from megatron.core.tensor_parallel.layers import _initialize_affine_weight_cpu, _initialize_affine_weight_gpu, set_tensor_model_parallel_attributes +from megatron.core.tensor_parallel.layers import linear_with_grad_accumulation_and_async_allreduce, linear_with_frozen_weight + +_grad_accum_fusion_available = True +try: + import fused_weight_gradient_mlp_cuda +except ImportError: + _grad_accum_fusion_available = False + +_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = { + 'tensor_model_parallel': False, + 'partition_dim': -1, + 'partition_stride': 1, +} + +class ColumnParallelLinear(torch.nn.Module): + """Linear layer with column parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + + Keyword Arguments + bias: If true, add bias + gather_output: If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is Y_i = XA_i + init_method: method to initialize weights. Note that bias is always set + to zero. + stride: For the strided linear layers. + keep_master_weight_for_test: This was added for testing and should be + set to False. It returns the master weights + used for initialization. + skip_bias_add: If True, do not add the bias term, instead + return it to be added by the caller. This + enables performance optimations where bias can + be fused with other elementwise operations. + skip_weight_param_allocation: If True, weight parameter is not allocated and must be passed + as a keyword argument `weight` during the forward pass. Note + that this does not affect bias, which will be allocated if + bias is True. Defaults to False. + is_expert: If True, the layer is treated as an MoE expert layer. + config: ModelParallelConfig object + + """ + + def __init__( + self, + input_size, + output_size, + *, + config: ModelParallelConfig, + init_method: Callable, + bias=True, + gather_output=False, + stride=1, + keep_master_weight_for_test=False, + skip_bias_add=False, + skip_weight_param_allocation: bool = False, + is_expert: bool = False + ): + super(ColumnParallelLinear, self).__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.gather_output = gather_output + + # Divide the weight matrix along the last dimension. + world_size = get_tensor_model_parallel_world_size() + self.output_size_per_partition = divide(output_size, world_size) + self.skip_bias_add = skip_bias_add + self.is_expert = is_expert + self.config = config + + # Parameters. + # Note: torch.nn.functional.linear performs XA^T + b and as a result + # we allocate the transpose. + # Initialize weight. + if not skip_weight_param_allocation: + if config.use_cpu_initialization: + self.weight = Parameter( + torch.empty( + self.output_size_per_partition, self.input_size, dtype=config.params_dtype + ) + ) + if config.perform_initialization: + self.master_weight = _initialize_affine_weight_cpu( + self.weight, + self.output_size, + self.input_size, + self.output_size_per_partition, + 0, + init_method, + stride=stride, + return_master_weight=keep_master_weight_for_test, + ) + else: + self.weight = Parameter( + torch.empty( + self.output_size_per_partition, + self.input_size, + device=torch.cuda.current_device(), + dtype=config.params_dtype, + ) + ) + if config.perform_initialization: + _initialize_affine_weight_gpu( + self.weight, + init_method, + partition_dim=0, + stride=stride, + expert_parallel=(self.is_expert and config.expert_parallel), + ) + + setattr(self.weight, 'allreduce', not (self.is_expert and config.expert_parallel)) + else: + self.weight = None + + if bias: + if config.use_cpu_initialization: + self.bias = Parameter( + torch.empty(self.output_size_per_partition, dtype=config.params_dtype) + ) + else: + self.bias = Parameter( + torch.empty( + self.output_size_per_partition, + device=torch.cuda.current_device(), + dtype=config.params_dtype, + ) + ) + set_tensor_model_parallel_attributes(self.bias, True, 0, stride) + if config.perform_initialization: + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + setattr(self.bias, 'allreduce', not (self.is_expert and config.expert_parallel)) + else: + self.register_parameter('bias', None) + + self.async_tensor_model_parallel_allreduce = ( + config.async_tensor_model_parallel_allreduce and world_size > 1 + ) + + self.sequence_parallel = config.sequence_parallel + if self.sequence_parallel and world_size <= 1: + warnings.warn( + f"`sequence_parallel` is set to `True`, but tensor model parallel size is {world_size}. " + f"Disabling sequence parallel." + ) + self.sequence_parallel = False + + if config.gradient_accumulation_fusion and not _grad_accum_fusion_available: + raise RuntimeError( + "ColumnParallelLinear was called with gradient_accumulation_fusion set " + "to True but the custom CUDA extension fused_weight_gradient_mlp_cuda " + "module is not found. To use gradient_accumulation_fusion you must " + "install APEX with --cpp_ext and --cuda_ext. For example: " + "pip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\" " + "Note that the extension requires CUDA>=11. Otherwise, you must turn off " + "gradient accumulation fusion." + ) + self.gradient_accumulation_fusion = config.gradient_accumulation_fusion + + if self.async_tensor_model_parallel_allreduce and self.sequence_parallel: + raise RuntimeError( + "`async_tensor_model_parallel_allreduce` and `sequence_parallel` " + "cannot be enabled at the same time." + ) + + self._forward_impl = linear_with_grad_accumulation_and_async_allreduce + self.explicit_expert_comm = self.is_expert and ( + self.sequence_parallel or config.expert_parallel + ) + + def forward(self, input_: torch.Tensor, weight: Optional[torch.Tensor] = None): + """Forward of ColumnParallelLinear + + Args: + input_: 3D tensor whose order of dimension is [sequence, batch, hidden] + + weight (optional): weight tensor to use, compulsory when + skip_weight_param_allocation is True. + + Returns: + - output + - bias + + """ + if weight is None: + if self.weight is None: + raise RuntimeError( + "weight was not supplied to ColumnParallelLinear forward pass " + "and skip_weight_param_allocation is True." + ) + weight = self.weight + else: + # Check the weight passed in is the correct shape + expected_shape = (self.output_size_per_partition, self.input_size) + if weight.shape != expected_shape: + raise RuntimeError( + f"supplied weight's shape is {tuple(weight.shape)}, " + f"not {expected_shape} as expected" + ) + + bias = self.bias if not self.skip_bias_add else None + + if ( + self.async_tensor_model_parallel_allreduce + or self.sequence_parallel + or self.explicit_expert_comm + ): + input_parallel = input_ + else: + input_parallel = copy_to_tensor_model_parallel_region(input_) + + # Matrix multiply. + if not weight.requires_grad: + self._forward_impl = linear_with_frozen_weight + else: + self._forward_impl = linear_with_grad_accumulation_and_async_allreduce + + if self.is_expert: + output_parallel = F.linear(input_parallel, weight, bias) + else: + output_parallel = self._forward_impl( + input=input_parallel, + weight=weight, + bias=bias, + gradient_accumulation_fusion=self.gradient_accumulation_fusion, + async_grad_allreduce=False + if self.explicit_expert_comm + else self.async_tensor_model_parallel_allreduce, + sequence_parallel=False if self.explicit_expert_comm else self.sequence_parallel, + ) + + if self.gather_output: + # All-gather across the partitions. + assert not self.sequence_parallel + output = gather_from_tensor_model_parallel_region(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + +class RowParallelLinear(torch.nn.Module): + """Linear layer with row parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its first dimension and X along its second dimension as: + - - + | A_1 | + | . | + A = | . | X = [X_1, ..., X_p] + | . | + | A_p | + - - + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + + Keyword Arguments: + bias: If true, add bias. Note that bias is not parallelized. + input_is_parallel: If true, we assume that the input is already + split across the GPUs and we do not split + again. + init_method: method to initialize weights. Note that bias is always set + to zero. + stride: For the strided linear layers. + keep_master_weight_for_test: This was added for testing and should be + set to False. It returns the master weights + used for initialization. + skip_bias_add: If True, do not add the bias term, instead + return it to be added by the caller. This + enables performance optimations where bias can + be fused with other elementwise operations. + is_expert: If True, the layer is treated as an MoE expert layer + config: ModelParallelConfig object + + """ + + def __init__( + self, + input_size: int, + output_size: int, + *, + config: ModelParallelConfig, + init_method: Callable, + bias: bool = True, + input_is_parallel: bool = False, + stride: int = 1, + keep_master_weight_for_test: bool = False, + skip_bias_add: bool = False, + is_expert: bool = False, + ): + super(RowParallelLinear, self).__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.input_is_parallel = input_is_parallel + # Divide the weight matrix along the last dimension. + world_size = get_tensor_model_parallel_world_size() + self.input_size_per_partition = divide(input_size, world_size) + self.skip_bias_add = skip_bias_add + self.config = config + self.is_expert = is_expert + self.gradient_accumulation_fusion = config.gradient_accumulation_fusion + self.sequence_parallel = config.sequence_parallel + if self.sequence_parallel and not self.input_is_parallel: + raise RuntimeError("To enable `sequence_parallel`, `input_is_parallel` must be `True`") + + # Parameters. + # Note: torch.nn.functional.linear performs XA^T + b and as a result + # we allocate the transpose. + # Initialize weight. + if config.use_cpu_initialization: + self.weight = Parameter( + torch.empty( + self.output_size, self.input_size_per_partition, dtype=config.params_dtype + ) + ) + if config.perform_initialization: + self.master_weight = _initialize_affine_weight_cpu( + self.weight, + self.output_size, + self.input_size, + self.input_size_per_partition, + 1, + init_method, + stride=stride, + return_master_weight=keep_master_weight_for_test, + params_dtype=config.params_dtype, + ) + else: + self.weight = Parameter( + torch.empty( + self.output_size, + self.input_size_per_partition, + device=torch.cuda.current_device(), + dtype=config.params_dtype, + ) + ) + if config.perform_initialization: + _initialize_affine_weight_gpu( + self.weight, + init_method, + partition_dim=1, + stride=stride, + expert_parallel=(self.is_expert and config.expert_parallel), + ) + setattr(self.weight, 'allreduce', not (self.is_expert and config.expert_parallel)) + + if bias: + if config.use_cpu_initialization: + self.bias = Parameter(torch.empty(self.output_size, dtype=config.params_dtype)) + else: + self.bias = Parameter( + torch.empty( + self.output_size, + device=torch.cuda.current_device(), + dtype=config.params_dtype, + ) + ) + + if config.perform_initialization: + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + setattr(self.bias, 'allreduce', not (self.is_expert and config.expert_parallel)) + setattr(self.bias, 'sequence_parallel', self.sequence_parallel) + else: + self.register_parameter('bias', None) + + self._forward_impl = linear_with_grad_accumulation_and_async_allreduce + self.explicit_expert_comm = self.is_expert and ( + self.sequence_parallel or config.expert_parallel + ) + + def forward(self, input_): + """Forward of RowParallelLinear + + Args: + input_: 3D tensor whose order of dimension is [sequence, batch, hidden] + + Returns: + - output + - bias + """ + # Set up backprop all-reduce. + if self.input_is_parallel: + input_parallel = input_ + else: + assert not self.sequence_parallel + input_parallel = scatter_to_tensor_model_parallel_region(input_) + # Matrix multiply. + if not self.weight.requires_grad: + self._forward_impl = linear_with_frozen_weight + else: + self._forward_impl = linear_with_grad_accumulation_and_async_allreduce + + if self.is_expert: + output_parallel = F.linear(input_parallel, self.weight, None) + else: + output_parallel = self._forward_impl( + input=input_parallel, + weight=self.weight, + bias=None, + gradient_accumulation_fusion=self.gradient_accumulation_fusion, + async_grad_allreduce=False, + sequence_parallel=False, + ) + + # All-reduce across all the partitions. + if self.explicit_expert_comm: + assert self.skip_bias_add + output_ = output_parallel + elif self.sequence_parallel: + output_ = reduce_scatter_to_sequence_parallel_region(output_parallel) + else: + output_ = reduce_from_tensor_model_parallel_region(output_parallel) + if not self.skip_bias_add: + output = (output_ + self.bias) if self.bias is not None else output_ + output_bias = None + else: + output = output_ + output_bias = self.bias + return output, output_bias diff --git a/megatron_patch/model/mixtral/router.py b/megatron_patch/model/mixtral/router.py new file mode 100644 index 00000000..56d54a61 --- /dev/null +++ b/megatron_patch/model/mixtral/router.py @@ -0,0 +1,787 @@ +# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from collections.abc import Callable +from typing import Dict, TYPE_CHECKING, Any, Optional, Tuple + +import torch +from torch import Tensor +from torch.nn import Module +import torch.nn.functional as F +from megatron import get_args +from megatron.core.tensor_parallel.mappings import ( + copy_to_tensor_model_parallel_region, + gather_from_tensor_model_parallel_region, + gather_from_sequence_parallel_region, + scatter_to_sequence_parallel_region + +) + +from megatron_patch import expert_parallel_state +from .all2all import all_to_all + +if TYPE_CHECKING: + Base = Module[Tensor] +else: + Base = Module + + +uniform_map: Dict[torch.device, Callable] = {} +gumbel_map: Dict[torch.device, Callable] = {} +exp_selection_uniform_map: Dict[torch.device, Callable] = {} + +try: + from tutel import moe as tutel_moe + TUTEL_INSTALLED = True +except: + TUTEL_INSTALLED = False + pass + +# einsum rewrites are on par or more performant +# switch can be bubbled up in future +USE_EINSUM = True +# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity +# See https://arxiv.org/pdf/2006.16668.pdf for details. +def einsum(rule, a, b): + if USE_EINSUM: + return torch.einsum(rule, a, b) + elif rule == 's,se->se': + return a.reshape(a.shape[0], -1) * b + elif rule == 'se,sc->sec': + return a.unsqueeze(2) * b.unsqueeze(1) + elif rule == 'se,sec->sec': + return a.unsqueeze(2) * b + elif rule == 'se,se->s': + return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1) + elif rule == 'sec,sm->ecm': + s = a.shape[0] + e = a.shape[1] + c = a.shape[2] + m = b.shape[1] + return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m) + elif rule == 'sec,ecm->sm': + return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1])) + elif rule == 'ks,ksm->sm': + k = b.shape[0] + s = b.shape[1] + m = b.shape[2] + # [k, s] -> [s, k] -> [s, 1, k] + a = a.t().unsqueeze(1) + # [k,s,m] -> [k, sm] -> [sm, k] -> [s, m, k] + b = b.reshape(k, -1).t().reshape(s, m, k) + # bmm([s, 1, k], [s, m, k]^t) -> [s, m, 1] + return torch.bmm(a, b.transpose(1, 2)).squeeze(2) + else: + return torch.einsum(rule, a, b) + + +def scatter_tokens_to_tensor_parallel_region(input_): + # E, C, M -> C, E, M + input_ = input_.transpose(0, 1).contiguous() + input_ = scatter_to_sequence_parallel_region(input_) + # C, E, M -> E, C, M + input_ = input_.transpose(0, 1).contiguous() + return input_ + +def gather_tokens_from_tensor_parallel_region(input_): + input_ = input_.transpose(0, 1).contiguous() + input_ = gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=False) + input_ = input_.transpose(0, 1).contiguous() + return input_ + +class AuxLossBackwardHook(torch.autograd.Function): + main_loss_backward_scale = 1 + + @staticmethod + def forward(ctx, output, aux_loss): + # Preserve the aux_loss by storing it in the context to avoid garbage collection. + ctx.save_for_backward(aux_loss) + return output + + @staticmethod + def backward(ctx, grad_output): + # Scale the auxiliary loss like the main loss. + args = get_args() + aux_loss, = ctx.saved_tensors + + aux_loss_backward_scale = AuxLossBackwardHook.main_loss_backward_scale + if args.sequence_parallel and not args.expert_tensor_parallelism: + # When using the sequence partitioned activation directly as the input to the Gate, + # we need normalize the loss with regard to the number of input segements + # (tensor_model_parallel_size). See our MR for the details. + aux_loss_backward_scale /= args.tensor_model_parallel_size + + scaled_aux_loss_grad = torch.ones_like(aux_loss) * aux_loss_backward_scale + return grad_output, scaled_aux_loss_grad + + @staticmethod + def set_main_loss_backward_scale(scale): + # No matter how the Main loss scales, the Aux loss needs to be scaled in the same way to + # ensure that the gradients produced by both are scaled equally. + AuxLossBackwardHook.main_loss_backward_scale = scale + +def multiplicative_jitter(x, device: torch.device, epsilon=1e-2): + """ + Modified from switch transformer paper. mesh transformers + Multiply values by a random number between 1-epsilon and 1+epsilon. + Makes models more resilient to rounding errors introduced by bfloat16. + This seems particularly important for logits. + Args: + x: a torch.tensor + device: torch.device + epsilon: a floating point value + Returns: + a jittered x. + """ + if epsilon == 0: + return x + uniform = uniform_map.get(device) + if uniform is None: + uniform = torch.distributions.uniform.Uniform( + low=torch.tensor(1.0 - epsilon, + device=device), + high=torch.tensor(1.0 + epsilon, + device=device)).rsample # type: ignore + uniform_map[device] = uniform + return x * uniform(x.shape) + + +def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor: + gumbel = gumbel_map.get(device) + if gumbel is None: + one = torch.tensor(1.0, device=device) + zero = torch.tensor(0.0, device=device) + gumbel = torch.distributions.gumbel.Gumbel(zero, one).rsample # type: ignore + gumbel_map[device] = gumbel + return gumbel(shape) + + +@torch.jit.script +def _capacity(gates: Tensor, capacity_factor: Tensor, min_capacity: Tensor) -> Tensor: + # gates has shape of SE + num_tokens = gates.shape[0] + num_experts = gates.shape[1] + # to(torch.int64) works around a bug in torch.onnx.export: + # it should cast k to int64 when converting torch.topk but it doesn't. + capacity = torch.ceil((num_tokens / num_experts) * capacity_factor).to(torch.int64) + if capacity < min_capacity: + capacity = min_capacity.to(torch.int64) + return capacity + + +@torch.jit.script +def _top_idx(source, k): + return torch.topk(source, k=k, dim=0)[1] + + +@torch.jit.script +def _one_hot_to_float(x, num_classes): + return F.one_hot(x, num_classes=num_classes).float() + + +# Implemented by refer to this paper: https://arxiv.org/pdf/2202.09368.pdf +def expert_choice(logits: Tensor, + capacity_factor: float, + min_capacity: float, + used_token: Tensor = None, + noisy_gate_policy: Optional[str] = None, + drop_tokens: bool = True, + use_rts: bool = True, + use_tutel: bool = False) -> Tuple[Tensor, + Tensor, + Tensor]: + + """ Implements Expert Choice Routing """ + # min_capacity, used_token noisy_gate_policy, use_rts and use_tutel are not used in this router + # keep them as parameters for compatibility + + scores = F.softmax(logits, dim=1) + # from [T, E] to [E, T] + scores = torch.transpose(scores, 0, 1).contiguous() + k = int(scores.shape[1] * capacity_factor / scores.shape[0]) + gatings, indices = torch.topk(scores, k=k, dim=1) + + return 0, gatings, indices + + +def top1gating(logits: Tensor, capacity_factor: float, min_capacity: int, used_token: Tensor = None, + noisy_gate_policy: Optional[str] = None, drop_tokens: bool = True, use_rts: bool = True, + use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Implements Top1Gating on logits.""" + + if noisy_gate_policy == 'RSample': + logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device) + + gates = F.softmax(logits, dim=1) + + capacity = _capacity(gates, + torch.tensor(capacity_factor), + torch.tensor(min_capacity)) + + # Create a mask for 1st's expert per token + # noisy gating + indices1_s = torch.argmax( + logits_w_noise if noisy_gate_policy == 'RSample' else gates, + dim=1) + num_experts = int(gates.shape[1]) + mask1 = F.one_hot(indices1_s, num_classes=num_experts) + + # mask only used tokens + if used_token is not None: + mask1 *= used_token.unsqueeze(1) + + # if we don't want to drop any tokens + if not drop_tokens: + from torch import distributed as dist + exp_counts = mask1.sum(dim=0).detach().cpu() + new_capacity = torch.max(exp_counts).to(logits.device) + dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group()) + capacity = new_capacity + + # Compute l_aux + me = gates.mean(dim=0) + ce = mask1.float().mean(dim=0) + l_aux = (me * ce).sum() * num_experts + + # Random Token Selection + if use_rts: + uniform = exp_selection_uniform_map.get(logits.device) + if uniform is None: + uniform = torch.distributions.uniform.Uniform( + low=torch.tensor(0.0, device=logits.device), + high=torch.tensor(1.0, device=logits.device)).rsample + exp_selection_uniform_map[logits.device] = uniform + mask1_rand = mask1 * uniform(mask1.shape) + else: + mask1_rand = mask1 + + assert logits.shape[0] >= min_capacity, "No. of tokens (batch-size) should be greater than min_capacity. \ + Either set min_capacity to 0 or increase your batch size." + top_idx = _top_idx(mask1_rand, capacity) + new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1) + gatings, indices = torch.topk(gates.T.contiguous(), k=capacity, dim=1, sorted=False) + return l_aux, gatings, indices, new_mask1 + + +# This function has been adapted from deepspeed file: +# https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/sharded_moe.py +def top1gating_tutel(logits: Tensor, + capacity_factor: float, + min_capacity: int, + used_token: Tensor = None, + noisy_gate_policy: Optional[str] = None, + drop_tokens: bool = True, + use_rts: bool = True, + use_tutel: bool = False) -> Tuple[Tensor, + Tensor, + Tensor]: + + """Implements Top1Gating on logits.""" + # if noisy_gate_policy == 'RSample': + # logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device) + # everything is in fp32 in this function + gates = F.softmax(logits, dim=1) + + capacity = _capacity(gates, + torch.tensor(capacity_factor), + torch.tensor(min_capacity)) + + # Create a mask for 1st's expert per token + # noisy gating + indices1_s = torch.argmax( + logits_w_noise if noisy_gate_policy == 'RSample' else gates, + dim=1) + num_experts = int(gates.shape[1]) + mask1 = F.one_hot(indices1_s, num_classes=num_experts) + + # mask only used tokens + if used_token is not None: + mask1 = torch.einsum("s,se->se", used_token, mask1) + + # gating decisions + exp_counts = torch.sum(mask1, dim=0).detach().to('cpu') + + # if we don't want to drop any tokens + if not drop_tokens: + from torch import distributed as dist + new_capacity = torch.max(exp_counts).to(logits.device) + dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group()) + capacity = new_capacity + + # Compute l_aux + me = torch.mean(gates, dim=0) + ce = torch.mean(mask1.float(), dim=0) + # Revisit: whether to divide l_aux by micro-batches or not? + l_aux = torch.sum(me * ce) * num_experts + + # Random Token Selection + if use_rts: + uniform = exp_selection_uniform_map.get(logits.device) + if uniform is None: + uniform = torch.distributions.uniform.Uniform( + low=torch.tensor(0.0, + device=logits.device), + high=torch.tensor(1.0, + device=logits.device)).rsample + exp_selection_uniform_map[logits.device] = uniform + + mask1_rand = mask1 * uniform(mask1.shape) + else: + mask1_rand = mask1 + + assert logits.shape[0] >= min_capacity, "No. of tokens (batch-size) should be greater than min_capacity. \ + Either set min_capacity to 0 or increase your batch size." + top_idx = _top_idx(mask1_rand, capacity) + + new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1) + mask1 = new_mask1 + + # Tutel doesn't support index values masked with zero + # so we need to replace masked indices with -1 + indices_mask = mask1.sum(dim=1) * num_experts - 1 + indices1_s = torch.min(indices1_s, indices_mask) + + # Compute locations in capacity buffer + + locations1 = tutel_moe.fast_cumsum_sub_one(mask1) + + gates1_s = (gates * mask1).sum(dim=1) + locations1_s = torch.sum(locations1 * mask1, dim=1) + return l_aux, capacity, num_experts, [indices1_s,], [locations1_s,], [gates1_s,] + +def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int, used_token: Tensor = None, + noisy_gate_policy: Optional[str] = None, drop_tokens: bool = True, use_rts: bool = True, + use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Implements Top2Gating on logits.""" + # everything is in fp32 in this function + gates = F.softmax(logits, dim=1) + + capacity = _capacity(gates, + torch.tensor(capacity_factor * 2), + torch.tensor(min_capacity)) + + # Create a mask for 1st's expert per token + indices1_s = torch.argmax(gates, dim=1) + num_experts = int(gates.shape[1]) + mask1 = F.one_hot(indices1_s, num_classes=num_experts) + + # Create a mask for 2nd's expert per token using Gumbel-max trick + # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/ + logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device) + # Replace top-expert with min value + logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf")) + indices2_s = torch.argmax(logits_except1, dim=1) + mask2 = F.one_hot(indices2_s, num_classes=num_experts) + + # Compute locations in capacity buffer + locations1 = torch.cumsum(mask1, dim=0) - 1 + locations2 = torch.cumsum(mask2, dim=0) - 1 + # Update 2nd's location by accounting for locations of 1st + locations2 += torch.sum(mask1, dim=0, keepdim=True) + + # gating decisions + exp_counts = torch.sum(mask1, dim=0).detach().to('cpu') + + # Compute l_aux + me = torch.mean(gates, dim=0) + ce = torch.mean(mask1.float(), dim=0) + l_aux = torch.mean(me * ce) * num_experts * num_experts + + # Remove locations outside capacity from mask + mask1 *= torch.lt(locations1, capacity) + mask2 *= torch.lt(locations2, capacity) + + # Store the capacity location for each token + locations1_s = torch.sum(locations1 * mask1, dim=1) + locations2_s = torch.sum(locations2 * mask2, dim=1) + + # Normalize gate probabilities + mask1_float = mask1.float() + mask2_float = mask2.float() + gates1_s = einsum("se,se->s", gates, mask1_float) + gates2_s = einsum("se,se->s", gates, mask2_float) + denom_s = gates1_s + gates2_s + # Avoid divide-by-zero + denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps) + gates1_s /= denom_s + gates2_s /= denom_s + + # Calculate combine_weights and dispatch_mask + gates1 = einsum("s,se->se", gates1_s, mask1_float) + gates2 = einsum("s,se->se", gates2_s, mask2_float) + locations1_sc = _one_hot_to_float(locations1_s, capacity) + locations2_sc = _one_hot_to_float(locations2_s, capacity) + combine1_sec = einsum("se,sc->sec", gates1, locations1_sc) + combine2_sec = einsum("se,sc->sec", gates2, locations2_sc) + combine_weights = combine1_sec + combine2_sec + dispatch_mask = combine_weights.bool() + + return l_aux, combine_weights, dispatch_mask, exp_counts + +# Copy from Megatron MoE branch https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/tree/moe +def sinkhorn(cost, tol=0.0001): + cost = torch.exp(cost) + d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype) + d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype) + + eps = 0.00000001 + error = 1e9 + d1_old = d1 + while error > tol: + d0 = (1/d0.size(0))*1/(torch.sum(d1*cost,1) + eps) + d1 = (1/d1.size(0))*1/(torch.sum(d0.unsqueeze(1)*cost,0)+eps) + error = torch.mean(torch.abs(d1_old-d1)) + d1_old = d1 + return d1*cost*d0.unsqueeze(1) + + +# Sinkhorn implementation refer to this paper: +# Unified scaling law for large language model (https://arxiv.org/pdf/2202.01169.pdf) +def sinkhornv2(logits, tol=0.01): + f = torch.zeros(logits.size(0), device=logits.device, dtype=logits.dtype) + g = torch.zeros(logits.size(1), device=logits.device, dtype=logits.dtype) + + # ToDo: add iteration early stop + for _ in range(50): + f = -torch.log(1/logits.size(1) * torch.sum(torch.exp(logits + g[None,:]), dim=1)) + g = -torch.log(1/logits.size(0) * torch.sum(torch.exp(logits + f[:, None]), dim=0)) + gates = torch.exp(logits + f[:, None] + g[None,:]) + return gates + + +def s_base(logits: Tensor, + capacity_factor: float, + min_capacity: int, + used_token: Tensor = None, + noisy_gate_policy: Optional[str] = None, + drop_tokens: bool = True, + use_rts: bool = True, + use_tutel: bool = False) -> Tuple[Tensor, + Tensor, + Tensor]: + """Implements s-base on logits.""" + # used_token, drop_token, noisy_gate_policy, drop_tokens and use_rts are not used in this router + # keep them as paramaters for compatibility + + # reference: https://arxiv.org/pdf/2209.15466.pdf + # "As in Clark (2022), we linearly combine the output of experts using a softmax matrix softmax(WX)" + gates = F.softmax(logits, dim=1) + + with torch.no_grad(): + # Both sinkhorn implementations work fine + # we choose sinkhornv2 as default here + sinkroute = sinkhornv2(logits.detach().to(dtype=torch.float32)) + _, indices1_s = torch.max(sinkroute, dim=1) + + capacity = _capacity(logits, + torch.tensor(capacity_factor), + torch.tensor(min_capacity)) + + num_experts = int(logits.shape[1]) + mask1 = F.one_hot(indices1_s, num_classes=num_experts) + + # gating decisions + mask1_rand = mask1 + top_idx = _top_idx(mask1_rand, capacity) + new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1) + mask1 = new_mask1 + + if use_tutel: + # Tutel doesn't support index values masked with zero + # so we need to replace masked indices with -1 + indices_mask = mask1.sum(dim=1) * num_experts - 1 + indices1_s = torch.min(indices1_s, indices_mask) + + # Compute locations in capacity buffer + if use_tutel: + locations1 = tutel_moe.fast_cumsum_sub_one(mask1) + else: + locations1 = torch.cumsum(mask1, dim=0) - 1 + + if use_tutel: + sinkroute1_s = (sinkroute * mask1).sum(dim=1) + locations1_s = torch.sum(locations1 * mask1, dim=1) + return 0, capacity, num_experts, [indices1_s,], [locations1_s,], [sinkroute1_s,] + + # Store the capacity location for each token + locations1_s = torch.sum(locations1 * mask1, dim=1) + + # Normalize gate probabilities + mask1_float = mask1.float() + sinkroute = sinkroute * mask1_float + gates = gates * mask1_float + + locations1_sc = _one_hot_to_float(locations1_s, capacity) + # reference: https://arxiv.org/pdf/2209.15466.pdf + # "As in Clark (2022), we linearly combine the output of experts using a softmax matrix softmax(WX)" + combine_weights = torch.einsum("se,sc->sec", gates, locations1_sc) + dispatch_mask = torch.einsum("se,sc->sec", sinkroute, locations1_sc) + dispatch_mask = dispatch_mask.bool() + + return 0, combine_weights, dispatch_mask + + +class Router(Module): + """ Gate / Router module """ + + def __init__(self, + model_dim: int, + num_experts: int, + k: int = 1, + capacity_factor: float = 1.0, + eval_capacity_factor: float = 1.0, + min_capacity: int = 8, + noisy_gate_policy: Optional[str] = None, + drop_tokens: bool = True, + use_rts: bool = True) -> None: + super().__init__() + + self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float() + setattr(self.wg.weight, "router", True) + self.k = k + self.capacity_factor = capacity_factor + self.eval_capacity_factor = eval_capacity_factor + self.min_capacity = min_capacity + self.noisy_gate_policy = noisy_gate_policy + self.drop_tokens = drop_tokens + self.use_rts = use_rts + args = get_args() + self.sequence_parallel = args.sequence_parallel + self.expert_tensor_parallelism = args.expert_tensor_parallelism + self.tensor_model_parallel_size = args.tensor_model_parallel_size + self.use_tutel = args.use_tutel + self.moe_loss_coeff = args.moe_loss_coeff + self.router_type = args.router_type + + if self.router_type == 'topk': + if self.k == 1: + if self.use_tutel: + self.gate = top1gating_tutel + else: + self.gate = top1gating + elif self.k == 2: + self.gate = top2gating + elif self.router_type == 'expert_choice': + self.gate = expert_choice + + def forward(self, + input: torch.Tensor, + used_token: torch.Tensor = None, + use_tutel: bool = False) -> Tuple[Tensor, + Tensor, + Tensor]: + + if self.wg.weight.dtype != torch.float32: + self.wg = self.wg.float() + setattr(self.wg.weight, 'router', True) + input_fp32 = input.float() + logits = self.wg(input_fp32) + + gate_output = self.gate( + logits, + self.capacity_factor if self.training else self.eval_capacity_factor, + self.min_capacity, + used_token, + self.noisy_gate_policy if self.training else None, + self.drop_tokens, + self.use_rts, + self.use_tutel) + + if self.router_type == 'top1': + gate_output[0].mul_(self.moe_loss_coeff) + + return gate_output + +class MOELayer(Base): + """MOELayer module""" + + def __init__(self, + gate: Module, + experts: Module, + ep_group_name, + ep_size, + num_local_experts: int, + use_tutel: bool = False, + expert_tensor_parallelism: bool = False) -> None: + super().__init__() + self.gate = gate + self.experts = experts + self.ep_group = None + self.ep_size = ep_size + self.ep_group_name = ep_group_name + self.num_local_experts = num_local_experts + self.num_experts = num_local_experts * ep_size + self.ep_group = expert_parallel_state.get_expert_parallel_group() + self.expert_tensor_parallelism = expert_tensor_parallelism + + args = get_args() + self.router_type = args.router_type + self.use_tutel = use_tutel and TUTEL_INSTALLED + if self.use_tutel: + print('Using Tutel optimizations.') + elif use_tutel and not TUTEL_INSTALLED: + print("Tutel optimization requested but not installed. " + "Proceeding without Tutel.") + + if self.router_type == 'topk': + if args.moe_topk == 1: + self.moe_execution_func = self._top1_execution + elif args.moe_topk == 2: + self.moe_execution_func = self._top2_execution + elif self.router_type == 'expert_choice': + self.moe_execution_func = self._ec_execution + + def forward(self, *input: Tensor, **kwargs: Any) -> Tensor: + return self.moe_execution_func(input) + + def dispatch_expert_combine(self, dispatched_input, model_dim): + # token dispatching + dispatched_input = self._before_dispatch_a2a_in_tp(dispatched_input) + dispatched_input = all_to_all(self.ep_group, dispatched_input) + dispatched_input = self._after_dispatch_a2a_in_tp(dispatched_input) + # Re-shape after all-to-all: ecm -> gecm + dispatched_input = dispatched_input.reshape(self.ep_size, + self.num_local_experts, + -1, model_dim) + expert_output = self.experts(dispatched_input) + expert_output = self._before_combine_a2a_in_tp(expert_output) + # token combining + expert_output = all_to_all(self.ep_group, expert_output) + # Re-shape back: gecm -> ecm + expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, + -1, model_dim) + expert_output = self._after_combine_a2a_in_tp(expert_output) + return expert_output + + # TODO: remove tutel code + def _top1_execution(self, input): + d_model = input[0].shape[-1] + reshaped_input = input[0].reshape(-1, d_model) + if self.use_tutel: + l_aux, C, E, indices_, locations_, gates_ = self.gate(reshaped_input, input[1], True) + S, M = reshaped_input.size(0), reshaped_input.size(1) + if not hasattr(self, '_tutel_dispatcher'): + self._tutel_dispatcher = tutel_moe.fast_dispatcher( + E, C, M, dispatch_dtype=reshaped_input.dtype) + self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C) + dispatched_input = self._tutel_dispatcher.encode(reshaped_input) + # Reshape tutel's output from [e*c,m] to [e,c,m] + dispatched_input = dispatched_input.reshape(self.ep_size * self.num_local_experts, + -1, d_model) + else: + l_aux, gating, indices, mask = self.gate(reshaped_input, input[1]) + masked_reshaped_input = reshaped_input * (mask.sum(axis=1).unsqueeze(1)) + dispatched_input = masked_reshaped_input.index_select( + dim=0, index=indices.view(-1)).reshape(self.ep_size * self.num_local_experts, -1, d_model) + + expert_output = self.dispatch_expert_combine(dispatched_input, d_model) + if self.use_tutel: + combined_output = self._tutel_dispatcher.decode(expert_output.view(E * C, M)) + else: + combined_output = torch.einsum("ec,ecm->ecm", gating.type_as(input[0]), expert_output) + combined_output = torch.scatter_add( + torch.zeros_like(reshaped_input), 0, + indices.view(-1, 1).expand(-1, reshaped_input.shape[1]), + combined_output.reshape(-1, d_model)) + acts = combined_output.reshape(input[0].shape) + # Use an autograd function to activate the backward computation for l_aux + acts = AuxLossBackwardHook.apply(acts, l_aux) + return acts + + def _top2_execution(self, input): + d_model = input[0].shape[-1] + reshaped_input = input[0].reshape(-1, d_model) + + l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1]) + dispatched_input = einsum("sec,sm->ecm", + dispatch_mask.type_as(input[0]), + reshaped_input) + + expert_output = self.dispatch_expert_combine(dispatched_input, d_model) + combined_output = einsum("sec,ecm->sm", + combine_weights.type_as(input[0]), + expert_output) + + acts = combined_output.reshape(input[0].shape) + # Use an autograd function to activate the backward computation for l_aux + acts = AuxLossBackwardHook.apply(acts, l_aux) + return acts + + def _ec_execution(self, input): + d_model = input[0].shape[-1] + reshaped_input = input[0].reshape(-1, d_model) + + l_aux, gating, indices = self.gate(reshaped_input, input[1]) + dispatched_input = reshaped_input.index_select(dim=0, index=indices.view(-1)).reshape( + self.ep_size * self.num_local_experts, -1, d_model) + + expert_output = self.dispatch_expert_combine(dispatched_input, d_model) + + combined_output = torch.einsum("ec,ecm->ecm", gating.type_as(input[0]), expert_output) + combined_output = torch.scatter_add( + torch.zeros_like(reshaped_input), 0, + indices.view(-1, 1).expand(-1, reshaped_input.shape[1]), + combined_output.reshape(-1, d_model)) + acts = combined_output.reshape(input[0].shape) + return acts + + def _before_dispatch_a2a_in_tp(self, dispatched_input): + args = get_args() + if args.tensor_model_parallel_size <= 1: + return dispatched_input + if self.expert_tensor_parallelism: + # Expert Tensor Parallel + # No operation in the forward pass and all-reduce in the backward pass + dispatched_input = copy_to_tensor_model_parallel_region(dispatched_input) + if args.moe_input_feature_slicing: + dispatched_input = scatter_to_tensor_model_parallel_region(dispatched_input) + else: + # EP w/o SP + # Drop duplicate tokens to ensure correctness. + # Better enable SP with TP+EP as it no extra + # communications, less memory footprint and better performance. + if not args.sequence_parallel: + dispatched_input = scatter_tokens_to_tensor_parallel_region(dispatched_input) + + return dispatched_input + + def _after_dispatch_a2a_in_tp(self, dispatched_input): + args = get_args() + if args.tensor_model_parallel_size <= 1: + return dispatched_input + if self.expert_tensor_parallelism and args.moe_input_feature_slicing: + dispatched_input = gather_from_tensor_model_parallel_region(dispatched_input) + return dispatched_input + + def _before_combine_a2a_in_tp(self, expert_output): + args = get_args() + if args.tensor_model_parallel_size <= 1: + return expert_output + # if self.expert_tensor_parallelism and args.moe_input_feature_slicing: + # expert_output = scatter_to_tensor_model_parallel_region(expert_output) + return expert_output + + def _after_combine_a2a_in_tp(self, expert_output): + args = get_args() + if args.tensor_model_parallel_size <= 1: + return expert_output + # if self.expert_tensor_parallelism and args.moe_input_feature_slicing: + # expert_output = gather_from_tensor_model_parallel_region(expert_output) + if not self.expert_tensor_parallelism and not args.sequence_parallel: + # The dropped duplicate tokens need to be gathered on each + # tensor parallel rank again for the tensor-parallel + # non-expert of the next layer. + expert_output = gather_tokens_from_tensor_parallel_region(expert_output) + return expert_output diff --git a/megatron_patch/model/mixtral/transformer.py b/megatron_patch/model/mixtral/transformer.py new file mode 100644 index 00000000..0f508816 --- /dev/null +++ b/megatron_patch/model/mixtral/transformer.py @@ -0,0 +1,1750 @@ +# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from contextlib import nullcontext +import math +import numpy as np +import torch +import torch.nn.functional as F +from typing import Optional + +from megatron import get_timers, get_args, get_retro_args, core, get_num_microbatches +from megatron.model.module import MegatronModule +from megatron.core import mpu, tensor_parallel +from megatron.core.enums import ModelType +from megatron.model.enums import AttnMaskType, LayerType, AttnType +from megatron.model.fused_softmax import FusedScaleMaskSoftmax +from megatron.model.fused_bias_gelu import bias_gelu_impl +from megatron.core.models.common.rotary_pos_embedding import apply_rotary_pos_emb +from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu, get_norm + +from megatron_patch.model.mistral.rotary_pos_embedding import RotaryEmbedding +from megatron_patch.model.mistral.rotary_pos_embedding import apply_rotary_pos_emb as apply_mistral_rotary_pos_emb +from .layer import MoE +from .moe_parallel_linear import ColumnParallelLinear, RowParallelLinear + + +try: + from einops import rearrange +except ImportError: + rearrange = None + +try: + from flash_attn.flash_attn_interface import flash_attn_unpadded_func +except ImportError: + try: + from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func + except ImportError: + flash_attn_unpadded_func = None + +""" We use the following notation throughout this file: + h: hidden size + n: number of attention heads + p: number of model parallel partitions + np: n/p + hp: h/p + hn: h/n + b: batch size + s: sequence length + l: number of layers + Transformer takes input of size [s, b, h] and returns a + tensor of the same size. We use the following arguments: + hyperparameters: transformer hyperparameters +""" + +class DropPath(MegatronModule): + """Drop paths (Stochastic Depth) per sample + (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=0.): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_state): + if self.drop_prob == 0. or not self.training: + return hidden_state + keep_prob = 1 - self.drop_prob + # work with diff dim tensors, not just 2D ConvNets + # hidden_state: [s, b, h] + shape = (1,) + (hidden_state.shape[1],) + (1,) * (hidden_state.ndim - 2) + random_tensor = keep_prob + \ + torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device) + random_tensor.floor_() # binarize + output = hidden_state.div(keep_prob) * random_tensor + return output + +class ParallelMLP(MegatronModule): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__(self, config, is_expert=False): + super(ParallelMLP, self).__init__() + args = get_args() + + self.add_bias = config.add_bias_linear + + ffn_hidden_size = config.ffn_hidden_size + if config.gated_linear_unit: + ffn_hidden_size *= 2 + + # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf + self.dense_h_to_4h = ColumnParallelLinear( + config.hidden_size, + ffn_hidden_size, + config=config, + init_method=config.init_method, + bias=self.add_bias, + gather_output=False, + skip_bias_add=True, + is_expert=is_expert) + + self.bias_gelu_fusion = False + self.activation_func = None + self.swiglu = args.swiglu + + if args.openai_gelu: + self.activation_func = openai_gelu + elif args.onnx_safe: + self.activation_func = erf_gelu + elif args.swiglu: + def swiglu(x): + x = torch.chunk(x, 2, dim=-1) + return F.silu(x[0]) * x[1] + self.activation_func = swiglu + elif args.squared_relu: + def squared_relu(x): + return torch.pow(F.relu(x), 2) + self.activation_func = squared_relu + else: + self.bias_gelu_fusion = args.bias_gelu_fusion + self.activation_func = F.gelu + + self.dense_4h_to_h = RowParallelLinear( + config.ffn_hidden_size, + config.hidden_size, + config=config, + init_method=config.output_layer_init_method, + bias=self.add_bias, + input_is_parallel=True, + skip_bias_add=True, + is_expert=is_expert, + ) + + def forward(self, hidden_states): + # [s, b, 4hp] + intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states) + + if self.bias_gelu_fusion: + assert self.add_bias is True + assert self.activation_func == F.gelu + intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel) + else: + if bias_parallel is not None: + intermediate_parallel = intermediate_parallel + bias_parallel + intermediate_parallel = self.activation_func(intermediate_parallel) + + # [s, b, h] + output, output_bias = self.dense_4h_to_h(intermediate_parallel) + return output, output_bias + +class CoreAttention(MegatronModule): + + def __init__(self, layer_number, config, + attn_mask_type=AttnMaskType.padding): + super(CoreAttention, self).__init__() + self.fp16 = config.fp16 + self.bf16 = config.bf16 + + self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 + if self.apply_query_key_layer_scaling: + self.attention_softmax_in_fp32 = True + self.layer_number = max(1, layer_number) + self.attn_mask_type = attn_mask_type + self.sequence_parallel = config.sequence_parallel + + projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + world_size = mpu.get_tensor_model_parallel_world_size() + self.hidden_size_per_partition = core.utils.divide(projection_size, + world_size) + self.hidden_size_per_attention_head = core.utils.divide( + projection_size, config.num_attention_heads) + self.num_attention_heads_per_partition = core.utils.divide( + config.num_attention_heads, world_size) + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.apply_query_key_layer_scaling: + coeff = self.layer_number + self.norm_factor *= coeff + + self.scale_mask_softmax = FusedScaleMaskSoftmax( + self.fp16, self.bf16, + self.attn_mask_type, + config.masked_softmax_fusion, + attention_mask_func, + self.attention_softmax_in_fp32, + coeff) + + # Dropout. Note that for a single iteration, this layer will generate + # different outputs on different number of parallel partitions but + # on average it should not be partition dependent. + self.attention_dropout = torch.nn.Dropout(config.attention_dropout) + + def forward(self, query_layer, key_layer, + value_layer, attention_mask): + + # =================================== + # Raw attention scores. [b, np, s, s] + # =================================== + + # [b, np, sq, sk] + output_size = (query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0)) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view(output_size[2], + output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[3], + output_size[0] * output_size[1], -1) + + # preallocting input tensor: [b * np, sq, sk] + matmul_input_buffer = mpu.get_global_memory_buffer().get_tensor( + (output_size[0]*output_size[1], output_size[2], output_size[3]), + query_layer.dtype, "mpu") + + # Raw attention scores. [b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, alpha=(1.0/self.norm_factor)) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # =========================== + # Attention probs and dropout + # =========================== + + # attention scores and attention mask [b, np, sq, sk] + attention_mask = attention_mask.to(torch.bool) + attention_probs = self.scale_mask_softmax(attention_scores, + attention_mask) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + if not self.sequence_parallel: + with tensor_parallel.get_cuda_rng_tracker().fork(): + attention_probs = self.attention_dropout(attention_probs) + else: + attention_probs = self.attention_dropout(attention_probs) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(1), + value_layer.size(2), + query_layer.size(0), + value_layer.size(3)) + + # change view [sk, b * np, hn] + value_layer = value_layer.view(value_layer.size(0), + output_size[0] * output_size[1], -1) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], + output_size[2], -1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer + + +class FlashSelfAttention(torch.nn.Module): + """Implement the scaled dot product attention with softmax. + Arguments + --------- + softmax_scale: The temperature to use for the softmax attention. + (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.0) + """ + def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, + device=None, dtype=None): + super().__init__() + assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, ' + 'e.g., with pip install flash-attn') + assert rearrange is not None, 'Please install einops first, e.g., with pip install einops' + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout_p = attention_dropout + + def forward(self, q, k, v): + """Implements the multihead softmax attention. + Arguments + --------- + q, k, v: The tensor containing the query, key, and value. (B, S, H, D) + """ + + assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q,k,v))) + assert all((i.is_cuda for i in (q,k,v))) + + batch_size, seqlen_q = q.shape[0], q.shape[1] + seqlen_k = k.shape[1] + + q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]] + cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, + device=q.device) + + if self.training: + # during training q,k,v always have same seqlen + assert seqlen_k == seqlen_q + + is_causal = self.causal + cu_seqlens_k = cu_seqlens_q + dropout_p = self.dropout_p + else: + # turn off FA causal mask after first inference autoregressive iteration + # only on first autoregressive step q,k,v have same seqlen + is_causal = seqlen_q == seqlen_k + cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, + device=q.device) + dropout_p = 0 + + output = flash_attn_unpadded_func( + q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k, + dropout_p, + softmax_scale=self.softmax_scale, causal=is_causal + ) + + output = rearrange(output, '(b s) ... -> b s ...', b=batch_size) + return output + + +class ParallelAttention(MegatronModule): + """Parallel self-attention layer abstract class. + + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__(self, config, layer_number, + attention_type=AttnType.self_attn, + attn_mask_type=AttnMaskType.padding): + super(ParallelAttention, self).__init__() + args = get_args() + self.layer_number = max(1, layer_number) + self.attention_type = attention_type + self.attn_mask_type = attn_mask_type + self.params_dtype = config.params_dtype + self.sequence_parallel = config.sequence_parallel + + self.group_query_attention = args.group_query_attention + self.num_query_groups = args.num_query_groups + + query_projection_size = config.kv_channels * config.num_attention_heads + if self.group_query_attention: + kv_projection_size = args.kv_channels * args.num_query_groups + else: + kv_projection_size = args.kv_channels * args.num_attention_heads + + self.use_flash_attn = args.use_flash_attn \ + and attention_type == AttnType.self_attn \ + and self.attn_mask_type == AttnMaskType.causal + if self.use_flash_attn: + if flash_attn_unpadded_func is None: + raise ImportError('FlashAttention is not installed, please install with ' + 'pip install flash-attn') + assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports ' + 'self-attention for now') + assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only ' + 'supports causal mask for now') + if rearrange is None: + raise ImportError('einops is not installed, please install with pip install einops') + + # Per attention head and per partition values. + world_size = mpu.get_tensor_model_parallel_world_size() + self.hidden_size_per_attention_head = core.utils.divide( + query_projection_size, config.num_attention_heads) + self.num_attention_heads_per_partition = core.utils.divide( + config.num_attention_heads, world_size) + + if self.group_query_attention: + if args.num_query_groups % world_size != 0: + raise NotImplementedError('Currently the num_query_groups should be ' + 'a multiple of the tensor parallel size') + self.num_query_groups_per_partition = core.utils.divide( + args.num_query_groups, world_size) + else: + self.num_query_groups_per_partition = self.num_attention_heads_per_partition + + # Strided linear layer. + if attention_type == AttnType.self_attn: + self.query_key_value = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + query_projection_size + 2 * kv_projection_size, + config=config, + init_method=config.init_method, + bias=args.add_bias_linear, + gather_output=False) + + else: + assert attention_type == AttnType.cross_attn + + if self.group_query_attention: + raise NotImplementedError("Grouped query attention not implemented for cross-attention.") + assert query_projection_size == kv_projection_size + + self.query = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + query_projection_size, + config=config, + init_method=config.init_method, + bias=config.add_bias_linear, + gather_output=False) + + self.key_value = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + 2 * kv_projection_size, + config=config, + init_method=config.init_method, + bias=config.add_bias_linear, + gather_output=False) + + self.core_attention = CoreAttention(self.layer_number, config, + self.attn_mask_type) + self.checkpoint_core_attention = config.recompute_granularity == 'selective' + + if self.use_flash_attn: + self.core_attention_flash = FlashSelfAttention( + causal=True, attention_dropout=config.attention_dropout + ) + + # Output. + self.dense = tensor_parallel.RowParallelLinear( + query_projection_size, + config.hidden_size, + config=config, + init_method=config.output_layer_init_method, + bias=args.add_bias_linear, + input_is_parallel=True, + skip_bias_add=True) + + if args.use_mistral_rotary_position_embeddings: + self.use_mistral_rotary_position_embeddings = True + self.seq_length = args.seq_length + rotary_dim = args.hidden_size // args.num_attention_heads \ + if args.kv_channels is None else args.kv_channels + + if args.rotary_percent < 1.0: + rotary_dim = int(rotary_dim * args.rotary_percent) + + # partial rotary embeddings, which is better than full rotary + # Wang and Komatsuzaki et al + # https://github.com/kingoflolz/mesh-transformer-jax/ + self.rotary_emb = RotaryEmbedding( + rotary_dim, + args.max_position_embeddings + ) + else: + self.use_mistral_rotary_position_embeddings = False + + def _checkpointed_attention_forward(self, query_layer, key_layer, + value_layer, attention_mask, + rotary_pos_emb=None): + """Forward method with activation checkpointing.""" + def custom_forward(*inputs): + query_layer = inputs[0] + key_layer = inputs[1] + value_layer = inputs[2] + attention_mask = inputs[3] + output_ = self.core_attention(query_layer, key_layer, + value_layer, attention_mask) + return output_ + + q_pos_emb, k_pos_emb = (None, None) if rotary_pos_emb is None \ + else rotary_pos_emb + + hidden_states = tensor_parallel.checkpoint( + custom_forward, + False, query_layer, key_layer, value_layer, attention_mask, + q_pos_emb, k_pos_emb) + + return hidden_states + + def _allocate_memory(self, inference_max_sequence_len, batch_size, num_attention_heads): + return torch.empty( + inference_max_sequence_len, + batch_size, + num_attention_heads, + self.hidden_size_per_attention_head, + dtype=self.params_dtype, + device=torch.cuda.current_device()) + + def forward(self, hidden_states, attention_mask, + encoder_output=None, inference_params=None, + rotary_pos_emb=None, position_ids=None): + # hidden_states: [sq, b, h] + + # ================================================= + # Pre-allocate memory for key-values for inference. + # ================================================= + is_first_step = False + if inference_params: + if self.layer_number not in inference_params.key_value_memory_dict: + inf_max_seq_len = inference_params.max_sequence_length + inf_max_batch_size = inference_params.max_batch_size + inference_key_memory = self._allocate_memory( + inf_max_seq_len, inf_max_batch_size, + self.num_query_groups_per_partition) + inference_value_memory = self._allocate_memory( + inf_max_seq_len, inf_max_batch_size, + self.num_query_groups_per_partition) + + inference_params.key_value_memory_dict[self.layer_number] = ( + inference_key_memory, inference_value_memory) + is_first_step = True + else: + inference_key_memory, inference_value_memory = \ + inference_params.key_value_memory_dict[self.layer_number] + + # ===================== + # Query, Key, and Value + # ===================== + if self.attention_type == AttnType.self_attn: + + # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)] + mixed_x_layer, _ = self.query_key_value(hidden_states) + + # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn] + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_query_groups_per_partition, + ( + (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2) + * self.hidden_size_per_attention_head + ), + ) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] + (query_layer, + key_layer, + value_layer) = torch.split( + mixed_x_layer, + [ + ( + self.num_attention_heads_per_partition // self.num_query_groups_per_partition + * self.hidden_size_per_attention_head + ), + self.hidden_size_per_attention_head, + self.hidden_size_per_attention_head + ], + dim=3) + # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] - + query_layer = query_layer.contiguous().view(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head) + else: + # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] + mixed_kv_layer, _ = self.key_value(encoder_output) + + # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] + new_tensor_shape = mixed_kv_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + 2 * self.hidden_size_per_attention_head) + mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) + + # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] + (key_layer, + value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2) + + # Attention head [sq, b, h] --> [sq, b, hp] + query_layer, _ = self.query(hidden_states) + # [sq, b, hp] --> [sq, b, np, hn] + new_tensor_shape = query_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + query_layer = query_layer.view(*new_tensor_shape) + + # ================================== + # Adjust key and value for inference + # ================================== + + # duplicate the pos_emb for self attention + if rotary_pos_emb is not None: + if isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = rotary_pos_emb + else: + rotary_pos_emb = ((rotary_pos_emb,) * 2) + + if inference_params: + if self.use_mistral_rotary_position_embeddings: + kv_seq_len = key_layer.shape[0] + kv_seq_len += inference_params.sequence_len_offset + value_layer = value_layer.transpose(0, 1).transpose(1, 2) + query_layer = query_layer.transpose(0, 1).transpose(1, 2) + key_layer = key_layer.transpose(0, 1).transpose(1, 2) + cos, sin = self.rotary_emb(value_layer, kv_seq_len) + query_layer, key_layer = apply_mistral_rotary_pos_emb( + query_layer, key_layer, cos, sin, position_ids) + + value_layer = value_layer.transpose(1, 2).transpose(0, 1) + query_layer = query_layer.transpose(1, 2).transpose(0, 1) + key_layer = key_layer.transpose(1, 2).transpose(0, 1) + + batch_start = inference_params.batch_size_offset + batch_end = batch_start + key_layer.size(1) + assert batch_end <= inference_key_memory.size(1) + sequence_start = inference_params.sequence_len_offset + sequence_end = sequence_start + key_layer.size(0) + assert sequence_end <= inference_key_memory.size(0) + # Copy key and values. + inference_key_memory[sequence_start:sequence_end, + batch_start:batch_end, ...] = key_layer + inference_value_memory[sequence_start:sequence_end, + batch_start:batch_end, ...] = value_layer + key_layer = inference_key_memory[ + :sequence_end, batch_start:batch_end, ...] + value_layer = inference_value_memory[ + :sequence_end, batch_start:batch_end, ...] + + + # adjust the key rotary positional embedding + if rotary_pos_emb is not None: + q_pos_emb, k_pos_emb = rotary_pos_emb + # need to cross check this condition during inference + # if not set_inference_key_value_memory: + if not is_first_step: + # In inference, we compute one token at a time. + # Select the correct positional embedding + # (only the last token in the sequence) + q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end] + else: + # In the first forward pass of inference, + # we use the entire provided prefix. + # q_pos_emb here has the rope embeddings of the entire + # prefix + to-be-generated output so + # we slice to just the prefix. + q_pos_emb = q_pos_emb[:sequence_end, :, :, :] + k_pos_emb = k_pos_emb[:sequence_end, :, :, :] + rotary_pos_emb = (q_pos_emb, k_pos_emb) + + # ================================== + # core attention computation + # ================================== + # expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn] + key_layer = key_layer.repeat_interleave( + self.num_attention_heads_per_partition // self.num_query_groups_per_partition, + dim = 2 + ) + value_layer = value_layer.repeat_interleave( + self.num_attention_heads_per_partition // self.num_query_groups_per_partition, + dim = 2 + ) + + # apply relative positional encoding (rotary embedding) + if self.use_mistral_rotary_position_embeddings: + kv_seq_len = key_layer.shape[0] + value_layer = value_layer.transpose(0, 1).transpose(1, 2) + query_layer = query_layer.transpose(0, 1).transpose(1, 2) + key_layer = key_layer.transpose(0, 1).transpose(1, 2) + cos, sin = self.rotary_emb(value_layer, kv_seq_len) + query_layer, key_layer = apply_mistral_rotary_pos_emb( + query_layer, key_layer, cos, sin, position_ids) + + value_layer = value_layer.transpose(1, 2).transpose(0, 1) + query_layer = query_layer.transpose(1, 2).transpose(0, 1) + key_layer = key_layer.transpose(1, 2).transpose(0, 1) + else: + # apply relative positional encoding (rotary embedding) + if rotary_pos_emb is not None: + q_pos_emb, k_pos_emb = rotary_pos_emb + query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb) + key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb) + # TODO, can apply positional embedding to value_layer so it has + # absolute positional embedding. + # otherwise, only relative positional embedding takes effect + # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb) + + if not self.use_flash_attn: + if self.checkpoint_core_attention: + context_layer = self._checkpointed_attention_forward( + query_layer, key_layer, value_layer, attention_mask) + else: + context_layer = self.core_attention( + query_layer, key_layer, value_layer, attention_mask) + else: + q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous() + for x in (query_layer, key_layer, value_layer)] + if not self.sequence_parallel: + with tensor_parallel.get_cuda_rng_tracker().fork(): + context_layer = self.core_attention_flash(q, k, v) + else: + context_layer = self.core_attention_flash(q, k, v) + context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous() + + # ================= + # Output. [sq, b, h] + # ================= + + output, bias = self.dense(context_layer) + + return output, bias + + +def bias_dropout_add(x, bias, residual, prob, training): + # type: (Tensor, Optional[Tensor], Tensor, float, bool) -> Tensor + if bias is not None: + x = x + bias + out = torch.nn.functional.dropout(x, p=prob, training=training) + out = residual + out + return out + + +def get_bias_dropout_add(training): + def _bias_dropout_add(x, bias, residual, prob): + return bias_dropout_add(x, bias, residual, prob, training) + return _bias_dropout_add + + +@torch.jit.script +def bias_dropout_add_fused_train(x: torch.Tensor, + bias: Optional[torch.Tensor], + residual: torch.Tensor, + prob: float) -> torch.Tensor: + return bias_dropout_add(x, bias, residual, prob, True) + + +@torch.jit.script +def bias_dropout_add_fused_inference(x: torch.Tensor, + bias: Optional[torch.Tensor], + residual: torch.Tensor, + prob: float) -> torch.Tensor: + return bias_dropout_add(x, bias, residual, prob, False) + + +class ParallelTransformerLayer(MegatronModule): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__(self, config, + layer_number, layer_type=LayerType.encoder, + self_attn_mask_type=AttnMaskType.padding, + drop_path_rate=0., num_experts=1): + args = get_args() + + super(ParallelTransformerLayer, self).__init__() + self.layer_number = layer_number + self.layer_type = layer_type + + self.apply_residual_connection_post_norm \ + = config.apply_residual_connection_post_layernorm + + self.bf16 = config.bf16 + self.fp32_residual_connection = config.fp32_residual_connection + + # Normalize the input data. + self.input_norm = get_norm(config) + + # Self attention. + self.self_attention = ParallelAttention( + config, + layer_number, + attention_type=AttnType.self_attn, + attn_mask_type=self_attn_mask_type) + self.hidden_dropout = config.hidden_dropout + self.bias_dropout_fusion = config.bias_dropout_fusion + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None + + # Normalize the attention output + self.post_attention_norm = get_norm(config) + + # Cross attention. + if self.layer_type in (LayerType.decoder, + LayerType.retro_decoder, + LayerType.retro_decoder_with_retriever, + LayerType.retro_encoder): + self.inter_attention = ParallelAttention( + config, + layer_number, + attention_type=AttnType.cross_attn) + # Normalize the attention output. + self.post_inter_attention_norm = get_norm(config) + + # MLP + if num_experts == 1: + self.mlp = ParallelMLP(config) + else: + expert_tensor_parallelism = args.expert_tensor_parallelism + moe_layer_index = (layer_number-1) // args.expert_interval + if args.rank == 0: + print('Experts set to %s, expert parallel size set to %d' + % (str(args.num_experts), args.moe_expert_parallel_size)) + self.mlp = MoE(args.hidden_size, + ParallelMLP(config, is_expert=True), + num_experts=args.num_experts, + ep_size=args.moe_expert_parallel_size, + k=args.moe_topk, + use_residual=False, + capacity_factor=args.moe_train_capacity_factor, + eval_capacity_factor=args.moe_eval_capacity_factor, + min_capacity=args.moe_min_capacity, + drop_tokens=args.moe_token_dropping, + use_tutel=args.use_tutel, + expert_tensor_parallelism=expert_tensor_parallelism, + moe_layer_index=moe_layer_index) + + # Set bias+dropout+add fusion grad_enable execution handler. + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10) + self.bias_dropout_add_exec_handler = \ + nullcontext if use_nvfuser else torch.enable_grad + + if args.retro_add_retriever: + retro_args = get_retro_args() + self.retro_num_neighbors = args.retro_num_neighbors + self.retro_chunk_length = retro_args.retro_gpt_chunk_length + self.retro_retrieved_length = retro_args.retro_gpt_retrieved_length + + # Retriever (bi-directional transformer with cross attention) + if layer_type == LayerType.retro_decoder_with_retriever: + self.retriever = ParallelTransformer( + config=config, + model_type=ModelType.retro_encoder, + self_attn_mask_type=AttnMaskType.padding, + pre_process=True, + post_process=False, + ) + self._retriever_key = 'retriever' + else: + self.retriever = None + + def default_decoder_cross_attention(self, + encoder_output, + enc_dec_attn_mask, + norm_input, + norm_output, + bias_dropout_add_func): + '''Cross attention for a standard encoder-decoder model.''' + + # Attention. + attention_output, attention_bias = \ + self.inter_attention(norm_output, + enc_dec_attn_mask, + encoder_output=encoder_output) + + # Residual connection. + if self.apply_residual_connection_post_norm: + residual = norm_output + else: + residual = norm_input + + if attention_bias is not None: + attention_bias = attention_bias.expand_as(residual) + + # Bias-dropout-add. + with self.bias_dropout_add_exec_handler(): + norm_input = bias_dropout_add_func( + attention_output, + attention_bias, + residual, + self.hidden_dropout) + + # Normalize. + norm_output = self.post_inter_attention_norm(norm_input) + + return norm_input, norm_output + + def retro_encoder_cross_attention(self, + retriever_output, + norm_input, + norm_output, + bias_dropout_add_func): + """Cross attention for Retro encoder. + + Notation: + ns : Sequence length. + bs : Batch size. + d : Hidden size. + l : Number of chunks per sample (i.e., seq_length/chunk_length). + k : Number of neighbors. + r : Number of retrieved tokens (neighbors + continuation). + """ + + ns, bs, d = norm_output.shape # [r, bs * l * k, d] + + # Divide sequence dimension into chunks. + chunked_outputs = norm_output.reshape(self.retro_retrieved_length, + -1, + self.retro_num_neighbors, + d) + chunked_outputs_before_norm = \ + norm_input.reshape(self.retro_retrieved_length, -1, + self.retro_num_neighbors, d) # [r, bs*l, k, d] + + # Per-chunk attention. + norm_inputs = [] + norm_outputs = [] + for k in range(self.retro_num_neighbors): + + # Attention. + chunked_output = chunked_outputs[:,:,k].contiguous() + attention_output, attention_bias = \ + self.inter_attention( + chunked_output, # Q (neighbor embedding) + None, + encoder_output=retriever_output) # K, V (hidden act) + + # Residual connection. + if self.apply_residual_connection_post_norm: + residual = chunked_output + else: + residual = chunked_outputs_before_norm[:,:,k] + + # Re-enable torch grad to enable fused optimization. + with torch.enable_grad(): + norm_input = bias_dropout_add_func( + attention_output, + None if attention_bias is None else attention_bias.expand_as(residual), + residual, + self.hidden_dropout) + norm_inputs.append(norm_input) + + # Layer norm. + norm_output = self.post_inter_attention_norm(norm_input) + norm_outputs.append(norm_output) + + # Concatenate layer norms. + # norm_input : [r, k * bs * l, d] + # norm_output : [r, k * bs * l, d] + norm_input = torch.stack(norm_inputs, dim=1).reshape(ns, bs, d) + norm_output = torch.stack(norm_outputs, dim=1).reshape(ns, bs, d) + + return norm_input, norm_output + + def retro_decoder_cross_attention(self, + retriever_input, + retriever_output, + retriever_attn_mask, + norm_input, + norm_output, + inference_params, + bias_dropout_add_func): + """Cross attention for Retro decoder. + + Notation: + ns : Sequence length. + bs : Batch size. + d : Hidden size. + l : Number of chunks per sample (i.e., seq_length/chunk_length). + m : Number of tokens per chunk. + k : Number of neighbors. + r : Number of retrieved tokens (neighbors + continuation). + """ + + ns, bs, d = norm_output.shape + l = int(np.ceil(ns / self.retro_chunk_length)) + + # Retrieve neighbors. + if self.layer_type == LayerType.retro_decoder_with_retriever: + first_ns = ns % self.retro_chunk_length + if first_ns > 0: + raise Exception("test this case.") + first_chunk, rest_chunk = \ + norm_output[:first_ns], norm_output[first_ns:] + first_chunk = torch.nn.functional.pad( + first_chunk, + (0, 0, 0, 0, 0, self.retro_chunk_length - first_ns), + 'constant', + 0) + chunked_output = \ + torch.cat((first_chunk, rest_chunk), dim=0) # [l * m, bs, d] + else: + chunked_output = norm_output # [l * m, bs, d] + chunked_output = chunked_output \ + .reshape(l, self.retro_chunk_length, bs, d) \ + .permute(1, 2, 0, 3) \ + .reshape(self.retro_chunk_length, bs * l, d) \ + .contiguous() + + # Get Encoder Output + retriever_output = self.retriever( + hidden_states=retriever_input, + attention_mask=retriever_attn_mask, + retriever_output=chunked_output, + retriever_attn_mask=retriever_attn_mask, + inference_params=inference_params) # [r, k * bs * l , d] + retriever_output = retriever_output.reshape( + self.retro_retrieved_length * self.retro_num_neighbors, bs * l, d) # [r * k, bs * l, d] + + # Chunks. + pad = (ns - 1) % self.retro_chunk_length + attending_chunks = norm_output[pad:] + padded_chunks = torch.nn.functional.pad( + attending_chunks, + (0, 0, 0, 0, 0, self.retro_chunk_length - 1), + 'constant', 0) + padded_chunked_output = padded_chunks \ + .reshape(l, self.retro_chunk_length, bs, d) \ + .permute(1, 2, 0, 3) + padded_chunked_output = padded_chunked_output.reshape( + self.retro_chunk_length, bs * l, d).contiguous() + + # Encoder output. + attention_output, attention_bias = \ + self.inter_attention(padded_chunked_output, + None, + encoder_output=retriever_output) + + # Residual connection. + if self.apply_residual_connection_post_norm: + residual = norm_output + else: + residual = norm_input + + # Re-enable torch grad to enable fused optimization. + with torch.enable_grad(): + norm_input = bias_dropout_add_func( + attention_output, + None if attention_bias is None else attention_bias.expand_as(attention_output), + torch.zeros_like(attention_output), + self.hidden_dropout) + norm_input = norm_input \ + .reshape(self.retro_chunk_length, bs, l, d) \ + .permute(2, 0, 1, 3) # [l, m, bs, d] + norm_input = norm_input.reshape(self.retro_chunk_length * l, bs, d) + norm_input = torch.nn.functional.pad( + norm_input, + (0, 0, 0, 0, pad, 0), + 'constant', 0)[:ns] # [ns, b, d] + norm_input = norm_input + residual + + # Layer norm post the decoder attention + norm_output = self.post_inter_attention_norm(norm_input) + + return retriever_output, norm_input, norm_output + + def forward(self, hidden_states, attention_mask, + encoder_output=None, enc_dec_attn_mask=None, + retriever_input=None, + retriever_output=None, + retriever_attn_mask=None, + inference_params=None, + rotary_pos_emb=None, + position_ids=None): + # hidden_states: [s, b, h] + + # Layer norm at the beginning of the transformer layer. + norm_output = self.input_norm(hidden_states) + + # Self attention. + attention_output, attention_bias = \ + self.self_attention( + norm_output, + attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + position_ids=position_ids + ) + + # Residual connection. + if self.apply_residual_connection_post_norm: + residual = norm_output + else: + residual = hidden_states + + if self.drop_path is None: + # jit scripting for a nn.module (with dropout) is not + # trigerring the fusion kernel. For now, we use two + # different nn.functional routines to account for varying + # dropout semantics during training and inference phases. + if self.bias_dropout_fusion: + if self.training: + bias_dropout_add_func = bias_dropout_add_fused_train + else: + bias_dropout_add_func = bias_dropout_add_fused_inference + else: + bias_dropout_add_func = get_bias_dropout_add(self.training) + + if attention_bias is not None: + attention_bias = attention_bias.expand_as(residual) + with self.bias_dropout_add_exec_handler(): + norm_input = bias_dropout_add_func( + attention_output, + attention_bias, + residual, + self.hidden_dropout) + else: + out = torch.nn.functional.dropout(attention_output + attention_bias, + p=self.hidden_dropout, + training=self.training) + norm_input = residual + self.drop_path(out) + + # Layer norm post the self attention. + norm_output = self.post_attention_norm(norm_input) + + # Cross attention. + if self.layer_type == LayerType.encoder: + pass + elif self.layer_type == LayerType.decoder: + norm_input, norm_output = \ + self.default_decoder_cross_attention( + encoder_output, + enc_dec_attn_mask, + norm_input, + norm_output, + bias_dropout_add_func) + elif self.layer_type == LayerType.retro_encoder: + norm_input, norm_output = \ + self.retro_encoder_cross_attention( + retriever_output, + norm_input, + norm_output, + bias_dropout_add_func) + elif self.layer_type in (LayerType.retro_decoder, + LayerType.retro_decoder_with_retriever): + retriever_output, norm_input, norm_output = \ + self.retro_decoder_cross_attention( + retriever_input, + retriever_output, + retriever_attn_mask, + norm_input, + norm_output, + inference_params, + bias_dropout_add_func) + else: + raise Exception("Unsupported layer type, '%s'." % + self.layer_type.name) + + # MLP. + mlp_output, mlp_bias = self.mlp(norm_output) + #import pdb;pdb.set_trace() + # Second residual connection. + if self.apply_residual_connection_post_norm: + residual = norm_output + else: + residual = norm_input + + if self.drop_path is None: + if mlp_bias is not None: + mlp_bias = mlp_bias.expand_as(residual) + with self.bias_dropout_add_exec_handler(): + output = bias_dropout_add_func( + mlp_output, + mlp_bias, + residual, + self.hidden_dropout) + + # Jit compiled function creates 'view' tensor. This tensor + # potentially gets saved in the MPU checkpoint function context, + # which rejects view tensors. While making a viewless tensor here + # won't result in memory savings (like the data loader, or + # p2p_communication), it serves to document the origin of this + # 'view' tensor. + output = core.utils.make_viewless_tensor(inp = output, + requires_grad = output.requires_grad, + keep_graph = True) + + else: + if mlp_bias is not None: + mlp_output = mlp_output + mlp_bias + out = torch.nn.functional.dropout(mlp_output, + p=self.hidden_dropout, + training=self.training) + output = residual + self.drop_path(out) + + if self.layer_type == LayerType.retro_decoder_with_retriever: + return output, retriever_output + else: + return output + + +class NoopTransformerLayer(MegatronModule): + """A single 'no-op' transformer layer. + + The sole purpose of this layer is for when a standalone embedding layer + is used (i.e., args.standalone_embedding_stage == True). In this case, + zero transformer layers are assigned when pipeline rank == 0. Additionally, + when virtual pipeline rank >= 1, zero total model parameters are created + (virtual rank 0 contains the input embedding). This results in the model's + input and output tensors being the same, which causes an error when + performing certain memory optimiations on the output tensor (e.g., + deallocating it). Thus, this layer disconnects the input from the output + via a clone. Since ranks containing a no-op layer are generally under- + utilized (both compute and memory), there's no worry of any performance + degredation. + """ + + def __init__(self, layer_number): + super().__init__() + self.layer_number = layer_number + + def forward(self, hidden_states, attention_mask, + encoder_output=None, enc_dec_attn_mask=None, + inference_params=None): + return hidden_states.clone() + + +def _get_num_layers(args, model_type, is_decoder=False): + """Compute the number of transformer layers resident on the current rank.""" + is_encoder_and_decoder_model = (model_type == ModelType.encoder_and_decoder) + if model_type == ModelType.retro_encoder: + num_layers = args.retro_encoder_layers + elif mpu.get_pipeline_model_parallel_world_size() > 1: + if is_encoder_and_decoder_model: + assert args.pipeline_model_parallel_split_rank is not None + + # When a standalone embedding stage is used, a rank is taken from + # the encoder's ranks, to be used for the encoder's embedding + # layer. This way, the rank referenced by the 'split rank' remains + # the same whether or not a standalone embedding stage is used. + num_ranks_in_encoder = ( + args.pipeline_model_parallel_split_rank - 1 + if args.standalone_embedding_stage else + args.pipeline_model_parallel_split_rank + ) + num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder + assert args.encoder_num_layers % num_ranks_in_encoder == 0, \ + 'encoder_num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.encoder_num_layers, num_ranks_in_encoder) + assert args.decoder_num_layers % num_ranks_in_decoder == 0, \ + 'decoder_num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.decoder_num_layers, num_ranks_in_decoder) + if mpu.is_pipeline_stage_before_split(): + num_layers = ( + 0 + if args.standalone_embedding_stage + and mpu.get_pipeline_model_parallel_rank() == 0 else + args.encoder_num_layers // num_ranks_in_encoder + ) + else: + num_layers = args.decoder_num_layers // num_ranks_in_decoder + else: + assert args.num_layers == args.encoder_num_layers + assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \ + 'num_layers must be divisible by transformer_pipeline_model_parallel_size' + + # When a standalone embedding stage is used, all transformer layers + # are divided among pipeline rank >= 1, while on pipeline rank 0, + # ranks either contain the input embedding layer (virtual pp rank 0), + # or no layers at all (virtual pp rank >= 1). + num_layers = ( + 0 + if args.standalone_embedding_stage + and mpu.get_pipeline_model_parallel_rank() == 0 else + args.num_layers // args.transformer_pipeline_model_parallel_size + ) + else: + if not is_decoder: + num_layers = args.encoder_num_layers + else: + num_layers = args.decoder_num_layers + return num_layers + + +def _get_layer_type(model_type, default_layer_type, retro_layer_numbers, + layer_number): + args = get_args() + if args.retro_add_retriever and layer_number in retro_layer_numbers: + if model_type == ModelType.retro_decoder: + return LayerType.retro_decoder_with_retriever \ + if layer_number == retro_layer_numbers[0] \ + else LayerType.retro_decoder + elif model_type == ModelType.retro_encoder: + return LayerType.retro_encoder + else: + raise Exception("Unsupported model type, '%s'." % model_type) + else: + return default_layer_type + + +class ParallelTransformer(MegatronModule): + """Transformer class.""" + + def __init__(self, config, + model_type, layer_type=LayerType.encoder, + self_attn_mask_type=AttnMaskType.padding, + post_norm=True, + pre_process=True, + post_process=True, + drop_path_rate=0.0): + super(ParallelTransformer, self).__init__() + args = get_args() + + self.layer_type = layer_type + self.model_type = model_type + self.bf16 = config.bf16 + self.fp32_residual_connection = config.fp32_residual_connection + self.post_norm = post_norm + self.pre_process = pre_process + self.post_process = post_process + self.input_tensor = None + self.drop_path_rate = drop_path_rate + self.transformer_impl = args.transformer_impl + self.retro_add_retriever = args.retro_add_retriever + + # Store activation checkpoiting flag. + self.recompute_granularity = config.recompute_granularity + self.recompute_method = config.recompute_method + self.recompute_num_layers = config.recompute_num_layers + self.distribute_saved_activations = \ + config.distribute_saved_activations and not config.sequence_parallel + + self.sequence_parallel = config.sequence_parallel + + # Transformer Engine Init. + self.transformer_engine_v_0_10 = False + self.transformer_engine_v_0_11 = False + self.transformer_engine_v_0_8 = False + if self.transformer_impl == 'transformer_engine': + global transformer_engine + import transformer_engine + from importlib.metadata import version + from pkg_resources import packaging + + te_version = packaging.version.Version(version("transformer-engine")) + if te_version >= packaging.version.Version("0.8.0"): + self.transformer_engine_v_0_8 = True + if te_version >= packaging.version.Version("0.10.0"): + self.transformer_engine_v_0_10 = True + if te_version >= packaging.version.Version("0.11.0"): + self.transformer_engine_v_0_11 = True + + del version, packaging + + assert not args.squared_relu, "TransformerEngine does not support squared relu activation." + + self.use_fp8 = args.fp8 is not None + self.fp8_recipe = None + self.fp8_group = None + if self.use_fp8: + assert args.transformer_impl == 'transformer_engine', \ + 'transformer-engine required for fp8 training and inference' + self.fp8_group = mpu.get_amax_reduction_group() + if args.fp8 == "e4m3": + fp8_format = transformer_engine.common.recipe.Format.E4M3 + elif args.fp8 == "hybrid": + fp8_format = transformer_engine.common.recipe.Format.HYBRID + else: + raise ValueError("The DelayedScaling recipe only supports E4M3 and HYBRID formats.") + self.fp8_recipe = transformer_engine.common.recipe.DelayedScaling( + margin=args.fp8_margin, + interval=args.fp8_interval, + fp8_format=fp8_format, + amax_history_len=args.fp8_amax_history_len, + amax_compute_algo=args.fp8_amax_compute_algo, + override_linear_precision=(False, False, not args.fp8_wgrad), + ) + + self.num_microbatches_in_previous_step = -1 + self.microbatch_count = 0 + self.checkpoint_core_attention = config.recompute_granularity == 'selective' + + # Number of layers. + self.num_layers = _get_num_layers(args, model_type, + layer_type==LayerType.decoder) + + self.drop_path_rates = [ + rate.item() for rate in + torch.linspace(0, self.drop_path_rate, config.num_layers)] + + self.retro_layer_numbers = None + if model_type == ModelType.retro_decoder: + retro_layer_start = 6 if config.num_layers <= 15 else 9 + self.retro_layer_numbers = \ + np.arange(retro_layer_start, args.num_layers + 1, 3).tolist() + if model_type == ModelType.retro_encoder: + self.retro_layer_numbers = [1] + + # Transformer layers. + if args.retro_add_retriever: + assert self.recompute_granularity != 'full', \ + "Full recompute not supported for Retro." + assert args.transformer_impl == 'local', \ + "Transformer engine does not support Retro layers." + def build_layer(layer_number, n_e): + if args.transformer_impl == 'local': + current_layer_type = _get_layer_type( + model_type, layer_type, self.retro_layer_numbers, + layer_number) + return ParallelTransformerLayer( + config, + layer_number, + layer_type=current_layer_type, + self_attn_mask_type=self_attn_mask_type, + drop_path_rate=self.drop_path_rates[layer_number - 1], + num_experts=n_e) + else: + # This argument is only available from TE v0.10 onwards. + extra_transformer_engine_kwargs = {} + if self.transformer_engine_v_0_8: + extra_transformer_engine_kwargs["bias"] = args.add_bias_linear + if self.transformer_engine_v_0_10: + extra_transformer_engine_kwargs["activation"] = "swiglu" if args.swiglu else "gelu" + if self.transformer_engine_v_0_11: + extra_transformer_engine_kwargs["normalization"] = args.normalization + return transformer_engine.pytorch.TransformerLayer( + config.hidden_size, + config.ffn_hidden_size, + config.num_attention_heads, + layernorm_epsilon=config.layernorm_epsilon, + hidden_dropout=config.hidden_dropout, + attention_dropout=config.attention_dropout, + init_method=config.init_method, + output_layer_init_method=config.output_layer_init_method, + layer_number=layer_number, + kv_channels=config.kv_channels, + self_attn_mask_type=self_attn_mask_type.name, + tp_group=mpu.get_tensor_model_parallel_group(), + get_rng_state_tracker=tensor_parallel.get_cuda_rng_tracker, + fuse_wgrad_accumulation=config.gradient_accumulation_fusion, + apply_query_key_layer_scaling=config.apply_query_key_layer_scaling, + attention_softmax_in_fp32=config.attention_softmax_in_fp32, + seq_length=args.seq_length, + micro_batch_size=args.micro_batch_size, + sequence_parallel=config.sequence_parallel, + params_dtype=config.params_dtype, + apply_residual_connection_post_layernorm=config.apply_residual_connection_post_layernorm, + output_layernorm=False, + layer_type="encoder", + drop_path_rate=self.drop_path_rates[layer_number - 1], + set_parallel_mode=True, + fuse_qkv_params=True, + **extra_transformer_engine_kwargs) + + if config.virtual_pipeline_model_parallel_size is not None: + assert config.num_layers % config.virtual_pipeline_model_parallel_size == 0, \ + 'num_layers_per_stage must be divisible by ' \ + 'virtual_pipeline_model_parallel_size' + assert args.model_type != ModelType.encoder_and_decoder + # Number of layers in each model chunk is the number of layers in the stage, + # divided by the number of model chunks in a stage. + self.num_layers = self.num_layers // config.virtual_pipeline_model_parallel_size + # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of + # layers to stages like (each list is a model chunk): + # Stage 0: [0] [2] [4] [6] + # Stage 1: [1] [3] [5] [7] + # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of + # layers to stages like (each list is a model chunk): + # Stage 0: [0, 1] [4, 5] + # Stage 1: [2, 3] [6, 7] + offset = mpu.get_virtual_pipeline_model_parallel_rank() * ( + config.num_layers // config.virtual_pipeline_model_parallel_size) + \ + (mpu.get_pipeline_model_parallel_rank() * self.num_layers) + else: + # Each stage gets a contiguous set of layers. + if args.model_type == ModelType.encoder_and_decoder and \ + mpu.get_pipeline_model_parallel_world_size() > 1: + pipeline_rank = mpu.get_pipeline_model_parallel_rank() + if layer_type == LayerType.encoder: + offset = pipeline_rank * self.num_layers + else: + num_ranks_in_enc = args.pipeline_model_parallel_split_rank + offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers + else: + offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers + + self.layers = [] + if self.num_layers == 0: + # When a standalone embedding stage is used (e.g., + # args.standalone_embedding_stage == True), virtual pipeline ranks + # on pipeline rank 0 will have zero transformer layers assigned to + # them. This results in the model's input and output tensors to be + # the same, which will cause failure for certain output tensor + # optimizations (e.g., pipeline output deallocation). To remedy + # this, we assign a 'no-op' layer on these ranks, which will + # disconnect the input tensor from the output tensor. + self.num_layers = 1 + self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ]) + else: + num_experts = [args.num_experts] * (args.num_layers // args.expert_interval) + for i in range(self.num_layers): + layer_num = i + 1 + offset + if layer_num % args.expert_interval == 0: + n_e = num_experts[(layer_num-1) // args.expert_interval] + else: + n_e = 1 + self.layers.append(build_layer(layer_num, n_e)) + self.layers = torch.nn.ModuleList(self.layers) + + """ + self.layers = torch.nn.ModuleList( + [build_layer(i + 1 + offset) for i in range(self.num_layers)]) + """ + + # Update dropout rate for Retro encoder. + if model_type == ModelType.retro_encoder: + for layer in self.layers: + if layer.self_attention.use_flash_attn: + layer.self_attention.core_attention_flash.dropout_p = \ + torch.nn.Dropout(args.retro_encoder_attention_dropout) + else: + layer.self_attention.core_attention.attention_dropout.p =\ + args.retro_encoder_attention_dropout + layer.hidden_dropout = args.retro_encoder_hidden_dropout + + if self.post_process and self.post_norm: + # Final layer norm before output. + self.final_norm = get_norm(config) + + def _get_layer(self, layer_number): + return self.layers[layer_number] + + def _checkpointed_forward(self, hidden_states, attention_mask, + encoder_output, enc_dec_attn_mask, + rotary_pos_emb, position_ids, is_first_microbatch): + """Forward method with activation checkpointing.""" + def custom(start, end): + def custom_forward(*args, **kwargs): + x_, *args = args + for index in range(start, end): + layer = self._get_layer(index) + x_ = layer(x_, *args, **kwargs) + return x_ + return custom_forward + + te_forward_kwargs = {} + if self.transformer_impl == 'transformer_engine': + te_forward_kwargs['is_first_microbatch'] = is_first_microbatch + if self.transformer_engine_v_0_10: + te_forward_kwargs['rotary_pos_emb'] = rotary_pos_emb + + if self.recompute_method == 'uniform': + # Uniformly divide the total number of Transformer layers and + # checkpoint the input activation of each divided chunk. + # A method to further reduce memory usage reducing checkpoints. + l = 0 + while l < self.num_layers: + if self.transformer_impl == 'transformer_engine': + hidden_states = transformer_engine.pytorch.checkpoint( + custom(l, l + self.recompute_num_layers), + self.distribute_saved_activations, + tensor_parallel.get_cuda_rng_tracker, + mpu.get_tensor_model_parallel_group(), + hidden_states, attention_mask, encoder_output, + enc_dec_attn_mask, **te_forward_kwargs) + else: + hidden_states = tensor_parallel.checkpoint( + custom(l, l + self.recompute_num_layers), + self.distribute_saved_activations, + hidden_states, attention_mask, + encoder_output, enc_dec_attn_mask, + None, None, None, None, rotary_pos_emb, position_ids) + + l += self.recompute_num_layers + + elif self.recompute_method == 'block': + # Checkpoint the input activation of only a set number of individual + # Transformer layers and skip the rest. + # A method fully use the device memory removing redundant re-computation. + for l in range(self.num_layers): + if l < self.recompute_num_layers: + if self.transformer_impl == 'transformer_engine': + hidden_states = transformer_engine.pytorch.checkpoint( + custom(l, l + 1), + self.distribute_saved_activations, + tensor_parallel.get_cuda_rng_tracker, + mpu.get_tensor_model_parallel_group(), + hidden_states, attention_mask, encoder_output, + enc_dec_attn_mask, **te_forward_kwargs) + else: + hidden_states = tensor_parallel.checkpoint( + custom(l, l + 1), + self.distribute_saved_activations, + hidden_states, attention_mask, + encoder_output, enc_dec_attn_mask, + None, None, None, None, rotary_pos_emb, position_ids) + else: + if self.transformer_impl == 'transformer_engine': + hidden_states = custom(l, l + 1)( + hidden_states, attention_mask, encoder_output, + enc_dec_attn_mask, **te_forward_kwargs) + else: + hidden_states = custom(l, l + 1)( + hidden_states, attention_mask, + encoder_output, enc_dec_attn_mask, + None, None, None, None, rotary_pos_emb, position_ids) + else: + raise ValueError("Invalid activation recompute method.") + + return hidden_states + + def set_input_tensor(self, input_tensor): + """Set input tensor to be used instead of forward()'s input. + + When doing pipeline parallelism the input from the previous + stage comes from communication, not from the input, so the + model's forward_step_func won't have it. This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + self.input_tensor = input_tensor + + def forward(self, hidden_states, attention_mask, + encoder_output=None, enc_dec_attn_mask=None, + retriever_input=None, + retriever_output=None, + retriever_attn_mask=None, + inference_params=None, + rotary_pos_emb=None, + position_ids=None): + # hidden_states: [s, b, h] + + # Checks. + if inference_params: + assert self.recompute_granularity is None, \ + 'inference does not work with activation checkpointing' + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Viewless tensor. + # - We only need to create a viewless tensor in the case of micro batch + # size (mbs) == 1, since in this case, 'hidden_states.transpose()' + # above creates a view tensor, and '.contiguous()' is a pass-through. + # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating + # the need to make it viewless. + # + # However, we don't explicitly check mbs == 1 here because + # make_viewless_tensor() has negligible overhead when its input + # is already viewless. + # + # - For the 'else' case above, calling make_viewless_tensor() here is + # likely redundant, since p2p_communication.py (likely originator) + # already creates viewless tensors. That said, make_viewless_tensor() + # is called here to be future-proof and corner-case-proof. + hidden_states = core.utils.make_viewless_tensor( + hidden_states, + requires_grad=True, + keep_graph=True, + ) + + # RNG context. + if self.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + # Forward layers. + with rng_context: + # The fp8_autocast context manager is a no-op when enabled=True + # The if...else serves to short circuit name resolution for fp8_autocast + with transformer_engine.pytorch.fp8_autocast( + enabled=self.use_fp8, + fp8_recipe=self.fp8_recipe, + fp8_group=self.fp8_group + ) if self.use_fp8 else nullcontext(): + # Determine if the current iteration is first microbatch + if self.num_microbatches_in_previous_step != get_num_microbatches(): + self.microbatch_count = 0 # Reset count on new batch size rampup interval + self.num_microbatches_in_previous_step = get_num_microbatches() + is_first_microbatch = self.microbatch_count % get_num_microbatches() == 0 + + # Forward pass. + if self.recompute_granularity == 'full': + hidden_states = self._checkpointed_forward(hidden_states, + attention_mask, + encoder_output, + enc_dec_attn_mask, + rotary_pos_emb, + position_ids, + is_first_microbatch) + else: + forward_kwargs = { + 'encoder_output': encoder_output, + 'enc_dec_attn_mask': enc_dec_attn_mask, + 'inference_params': inference_params, + } + + if self.transformer_impl == 'transformer_engine': + forward_kwargs['is_first_microbatch'] = is_first_microbatch + forward_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention + if self.transformer_engine_v_0_10: + forward_kwargs['rotary_pos_emb'] = rotary_pos_emb + else: + forward_kwargs['rotary_pos_emb'] = rotary_pos_emb + forward_kwargs['position_ids'] = position_ids + forward_kwargs['retriever_input'] = retriever_input + forward_kwargs['retriever_output'] = retriever_output + forward_kwargs['retriever_attn_mask'] = retriever_attn_mask + + for index in range(self.num_layers): + layer = self._get_layer(index) + + hidden_states = layer( + hidden_states, + attention_mask, + **forward_kwargs) + + # First Retro decoder layer returns both hidden_states + # and retriever_output. Make retriever_output available + # to subsequence Retro layers. + if isinstance(hidden_states, tuple): + assert len(hidden_states) == 2 + hidden_states, retriever_output = hidden_states + forward_kwargs["retriever_output"] = retriever_output + + # Skip counter update for eval and activation checkpointing + if torch.is_grad_enabled() and self.training: + self.microbatch_count += 1 + + # Final layer norm. + if self.post_process and self.post_norm: + hidden_states = self.final_norm(hidden_states) + + return hidden_states + + def load_state_dict(self, state_dict, strict=True): + """Customize load.""" + + # Handle renaming layernorm -> norm in component names + args = get_args() + state_dict_ = {} + for key in state_dict.keys(): + if args.transformer_impl != "transformer_engine": + newkey = key.replace("layernorm", "norm") + state_dict_[newkey] = state_dict[key] + else: + state_dict_[key] = state_dict[key] + + super().load_state_dict(state_dict_, False) diff --git a/megatron_patch/model/qwen_vl/__init__.py b/megatron_patch/model/qwen_vl/__init__.py new file mode 100644 index 00000000..1f6175dc --- /dev/null +++ b/megatron_patch/model/qwen_vl/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/megatron_patch/model/qwen_vl/gpt_model.py b/megatron_patch/model/qwen_vl/gpt_model.py new file mode 100644 index 00000000..7981be7d --- /dev/null +++ b/megatron_patch/model/qwen_vl/gpt_model.py @@ -0,0 +1,133 @@ +# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from megatron import get_args +from megatron.core import tensor_parallel +from megatron.model.module import MegatronModule + +from megatron.model.enums import AttnMaskType +from .language_model import parallel_lm_logits +from .language_model import get_language_model + + +def post_language_model_processing(lm_output, labels, logit_weights, + parallel_output, + fp16_lm_cross_entropy): + + # Output. Format [s b h] + output = parallel_lm_logits( + lm_output, + logit_weights, + parallel_output) + + if labels is None: + # [s b h] => [b s h] + return output.transpose(0,1).contiguous() + else: + # [b s] => [s b] + labels = labels.transpose(0,1).contiguous() + if fp16_lm_cross_entropy: + assert output.dtype == torch.half + loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels) + else: + loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels) + + # [s b] => [b, s] + loss = loss.transpose(0,1).contiguous() + return loss + + +class GPTModel(MegatronModule): + """GPT-2 Language model.""" + + def __init__(self, + config, + num_tokentypes=0, + parallel_output=True, + pre_process=True, + post_process=True): + args = get_args() + super().__init__(config=config, share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights) + + self.parallel_output = parallel_output + self.pre_process = pre_process + self.post_process = post_process + self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy + self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights + + self.language_model, self._language_model_key = get_language_model( + config=config, + num_tokentypes=num_tokentypes, + add_pooler=False, + encoder_attn_mask_type=AttnMaskType.causal, + pre_process=self.pre_process, + post_process=self.post_process) + + if not args.untie_embeddings_and_output_weights: + self.initialize_word_embeddings() + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + self.language_model.set_input_tensor(input_tensor) + + def forward(self, input_ids, position_ids, attention_mask, + retriever_input_ids=None, + retriever_position_ids=None, + retriever_attn_mask=None, + labels=None, images=None, inference_params=None): + + lm_output = self.language_model( + input_ids, + position_ids, + attention_mask, + images=images, + retriever_input_ids=retriever_input_ids, + retriever_position_ids=retriever_position_ids, + retriever_attn_mask=retriever_attn_mask, + inference_params=inference_params) + + if self.post_process: + return post_language_model_processing( + lm_output, labels, + self.language_model.output_layer.weight if self.untie_embeddings_and_output_weights else self.shared_embedding_or_output_weight(), + self.parallel_output, + self.fp16_lm_cross_entropy) + else: + return lm_output + + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + + state_dict_ = {} + state_dict_[self._language_model_key] \ + = self.language_model.state_dict_for_save_checkpoint( + prefix=prefix, keep_vars=keep_vars) + # Save word_embeddings. + if self.post_process and not self.pre_process and not self.untie_embeddings_and_output_weights: + state_dict_[self._word_embeddings_for_head_key] \ + = self.word_embeddings.state_dict(prefix=prefix, + keep_vars=keep_vars) + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + # Load word_embeddings. + if self.post_process and not self.pre_process and not self.untie_embeddings_and_output_weights: + self.word_embeddings.load_state_dict( + state_dict[self._word_embeddings_for_head_key], strict=strict) + if self._language_model_key in state_dict: + state_dict = state_dict[self._language_model_key] + self.language_model.load_state_dict(state_dict, strict=strict) diff --git a/megatron_patch/model/qwen_vl/language_model.py b/megatron_patch/model/qwen_vl/language_model.py new file mode 100644 index 00000000..ee50ba39 --- /dev/null +++ b/megatron_patch/model/qwen_vl/language_model.py @@ -0,0 +1,680 @@ +# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from megatron import get_args +from megatron.core import mpu, tensor_parallel +from megatron.core.enums import ModelType +from megatron.model.enums import AttnMaskType +from megatron.model.module import MegatronModule +from megatron.model.utils import get_linear_layer +from megatron.model.utils import init_method_normal +from megatron.model.utils import scaled_init_method_normal +from megatron.core.models.common.rotary_pos_embedding import RotaryEmbedding + +from megatron_patch.model.mistral.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from .transformer import ParallelTransformer +from .visual import VisionTransformer + +def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, + bias=None): + """LM logits using word embedding weights.""" + args = get_args() + # Parallel logits. + if args.async_tensor_model_parallel_allreduce or\ + args.sequence_parallel: + input_parallel = input_ + model_parallel = mpu.get_tensor_model_parallel_world_size() > 1 + async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \ + model_parallel and not args.sequence_parallel + else: + input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_) + async_grad_allreduce = False + + # Matrix multiply. + logits_parallel = tensor_parallel.linear_with_grad_accumulation_and_async_allreduce( + input=input_parallel, + weight=word_embeddings_weight, + bias=bias, + gradient_accumulation_fusion=args.gradient_accumulation_fusion, + async_grad_allreduce=async_grad_allreduce, + sequence_parallel=args.sequence_parallel) + # Gather if needed. + + if parallel_output: + return logits_parallel + + return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel) + + +def get_language_model(config, num_tokentypes, add_pooler, + encoder_attn_mask_type, + add_encoder=True, + add_decoder=False, + decoder_attn_mask_type=AttnMaskType.causal, + pre_process=True, post_process=True): + """Build language model and return along with the key to save.""" + args = get_args() + if config.init_method is None: + config.init_method = init_method_normal(config.init_method_std) + + if config.output_layer_init_method is None: + config.output_layer_init_method = scaled_init_method_normal(config.init_method_std, + config.num_layers) + + # Language model. + language_model = TransformerLanguageModel( + config, + encoder_attn_mask_type, + num_tokentypes=num_tokentypes, + add_encoder=add_encoder, + add_decoder=add_decoder, + decoder_attn_mask_type=decoder_attn_mask_type, + add_pooler=add_pooler, + pre_process=pre_process, + post_process=post_process + ) + # key used for checkpoints. + language_model_key = 'language_model' + + return language_model, language_model_key + + +class Pooler(MegatronModule): + """Pooler layer. + + Pool hidden states of a specific token (for example start of the + sequence) and add a linear transformation followed by a tanh. + + Arguments: + hidden_size: hidden size + init_method: weight initialization method for the linear layer. + bias is set to zero. + """ + + def __init__(self, hidden_size, init_method): + super(Pooler, self).__init__() + args = get_args() + self.dense = get_linear_layer(hidden_size, hidden_size, init_method) + self.sequence_parallel = args.sequence_parallel + + + def forward(self, hidden_states, sequence_index=0): + # hidden_states: [s, b, h] + # sequence_index: index of the token to pool. + + # gather data along sequence dimensions + # same pooler is run on all tensor parallel nodes + if self.sequence_parallel: + hidden_states = tensor_parallel.gather_from_sequence_parallel_region( + hidden_states, + tensor_parallel_output_grad=False) + + pooled = hidden_states[sequence_index, :, :] + pooled = self.dense(pooled) + pooled = torch.tanh(pooled) + return pooled + + +class Embedding(MegatronModule): + """Language model embeddings. + + Arguments: + hidden_size: hidden size + vocab_size: vocabulary size + max_sequence_length: maximum size of sequence. This + is used for positional embedding + embedding_dropout_prob: dropout probability for embeddings + init_method: weight initialization method + num_tokentypes: size of the token-type embeddings. 0 value + will ignore this embedding + """ + + def __init__(self, + hidden_size, + vocab_size, + max_sequence_length, + embedding_dropout_prob, + config, + num_tokentypes=0): + super(Embedding, self).__init__() + + self.hidden_size = hidden_size + self.init_method = config.init_method + self.num_tokentypes = num_tokentypes + + args = get_args() + + # Word embeddings (parallel). + self.params_dtype = args.params_dtype + self.word_embeddings = tensor_parallel.VocabParallelEmbedding( + vocab_size, self.hidden_size, config=config, init_method=config.init_method) + self._word_embeddings_key = 'word_embeddings' + + # Position embedding (serial). + self.add_position_embedding = args.position_embedding_type == 'learned_absolute' + if self.add_position_embedding: + self.position_embeddings = torch.nn.Embedding( + max_sequence_length, self.hidden_size) + self._position_embeddings_key = 'position_embeddings' + # Initialize the position embeddings. + if args.perform_initialization: + self.init_method(self.position_embeddings.weight) + + # Token type embedding. + # Add this as an optional field that can be added through + # method call so we can load a pretrain model without + # token types and add them as needed. + self._tokentype_embeddings_key = 'tokentype_embeddings' + if self.num_tokentypes > 0: + self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, + self.hidden_size) + # Initialize the token-type embeddings. + if args.perform_initialization: + self.init_method(self.tokentype_embeddings.weight) + else: + self.tokentype_embeddings = None + + self.fp32_residual_connection = args.fp32_residual_connection + self.sequence_parallel = args.sequence_parallel + # Embeddings dropout + self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob) + + def zero_parameters(self): + """Zero out all parameters in embedding.""" + self.word_embeddings.weight.data.fill_(0) + self.word_embeddings.weight.shared = True + if self.add_position_embedding: + self.position_embeddings.weight.data.fill_(0) + self.position_embeddings.weight.shared = True + if self.num_tokentypes > 0: + self.tokentype_embeddings.weight.data.fill_(0) + self.tokentype_embeddings.weight.shared = True + + def add_tokentype_embeddings(self, num_tokentypes): + """Add token-type embedding. This function is provided so we can add + token-type embeddings in case the pretrained model does not have it. + This allows us to load the model normally and then add this embedding. + """ + if self.tokentype_embeddings is not None: + raise Exception('tokentype embeddings is already initialized') + if torch.distributed.get_rank() == 0: + print('adding embedding for {} tokentypes'.format(num_tokentypes), + flush=True) + self.num_tokentypes = num_tokentypes + self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, + self.hidden_size) + # Initialize the token-type embeddings. + args = get_args() + self.init_method(self.tokentype_embeddings.weight) + + def forward(self, input_ids, position_ids, tokentype_ids=None): + # Embeddings. + words_embeddings = self.word_embeddings(input_ids) + if self.add_position_embedding: + position_embeddings = self.position_embeddings(position_ids) + embeddings = words_embeddings + position_embeddings + else: + embeddings = words_embeddings + + if tokentype_ids is not None: + assert self.tokentype_embeddings is not None + embeddings = embeddings + self.tokentype_embeddings(tokentype_ids) + else: + assert self.tokentype_embeddings is None + + # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. + embeddings = embeddings.transpose(0, 1).contiguous() + + # If the input flag for fp32 residual connection is set, convert for float. + if self.fp32_residual_connection: + embeddings = embeddings.float() + + # Dropout. + if self.sequence_parallel: + embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings) + with tensor_parallel.get_cuda_rng_tracker().fork(): + embeddings = self.embedding_dropout(embeddings) + else: + embeddings = self.embedding_dropout(embeddings) + + return embeddings + + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + """For easy load.""" + + state_dict_ = {} + state_dict_[self._word_embeddings_key] \ + = self.word_embeddings.state_dict(prefix=prefix, + keep_vars=keep_vars) + if self.add_position_embedding: + state_dict_[self._position_embeddings_key] \ + = self.position_embeddings.state_dict(prefix=prefix, + keep_vars=keep_vars) + if self.num_tokentypes > 0: + state_dict_[self._tokentype_embeddings_key] \ + = self.tokentype_embeddings.state_dict(prefix=prefix, + keep_vars=keep_vars) + + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + # Word embedding. + if self._word_embeddings_key in state_dict: + state_dict_ = state_dict[self._word_embeddings_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'word_embeddings' in key: + state_dict_[key.split('word_embeddings.')[1]] \ + = state_dict[key] + self.word_embeddings.load_state_dict(state_dict_, strict=strict) + + # Position embedding. + if self.add_position_embedding: + if self._position_embeddings_key in state_dict: + state_dict_ = state_dict[self._position_embeddings_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'position_embeddings' in key: + state_dict_[key.split('position_embeddings.')[1]] \ + = state_dict[key] + self.position_embeddings.load_state_dict(state_dict_, strict=strict) + + # Tokentype embedding. + if self.num_tokentypes > 0: + state_dict_ = {} + if self._tokentype_embeddings_key in state_dict: + state_dict_ = state_dict[self._tokentype_embeddings_key] + else: + # for backward compatibility. + for key in state_dict.keys(): + if 'tokentype_embeddings' in key: + state_dict_[key.split('tokentype_embeddings.')[1]] \ + = state_dict[key] + if len(state_dict_.keys()) > 0: + self.tokentype_embeddings.load_state_dict(state_dict_, + strict=strict) + else: + print('***WARNING*** expected tokentype embeddings in the ' + 'checkpoint but could not find it', flush=True) + + +class TransformerLanguageModel(MegatronModule): + """Transformer language model. + + Arguments: + transformer_hparams: transformer hyperparameters + vocab_size: vocabulary size + max_sequence_length: maximum size of sequence. This + is used for positional embedding + embedding_dropout_prob: dropout probability for embeddings + num_tokentypes: size of the token-type embeddings. 0 value + will ignore this embedding + """ + + def __init__(self, + config, + encoder_attn_mask_type, + num_tokentypes=0, + add_encoder=True, + add_decoder=False, + decoder_attn_mask_type=AttnMaskType.causal, + add_pooler=False, + pre_process=True, + post_process=True): + self.args = get_args() + # TODO: passing share_embeddings_and_output_weights=False will not work correctly for T5 and embeddings will not be synced. Fix later for T5. + if self.args.untie_embeddings_and_output_weights: assert not add_decoder + super(TransformerLanguageModel, self).__init__(share_embeddings_and_output_weights=not self.args.untie_embeddings_and_output_weights) + + self.pre_process = pre_process + self.post_process = post_process + self.hidden_size = config.hidden_size + self.num_tokentypes = num_tokentypes + self.init_method = config.init_method + self.add_encoder = add_encoder + self.encoder_attn_mask_type = encoder_attn_mask_type + self.add_decoder = add_decoder + self.decoder_attn_mask_type = decoder_attn_mask_type + self.add_pooler = add_pooler + self.encoder_hidden_state = None + self.add_retriever = self.args.retro_add_retriever + self.untie_embeddings_and_output_weights = self.args.untie_embeddings_and_output_weights + + self.visual_config = {"heads": 16, "image_size": 448, "image_start_id": 151857, + "layers": 1, "mlp_ratio": 4.9231, "output_dim": 2048, "patch_size": 14, "width": 1664} + + self.visual = VisionTransformer(**self.visual_config) + + # Embeddings. + if self.pre_process: + self.embedding = Embedding(self.hidden_size, + self.args.padded_vocab_size, + self.args.max_position_embeddings, + self.args.hidden_dropout, + config, + self.num_tokentypes) + self._embedding_key = 'embedding' + + if self.args.freeze_llm: + for param in self.embedding.parameters(): + param.requires_grad = False + + # Rotary positional embeddings + if self.args.use_rotary_position_embeddings: + self.seq_length = self.args.seq_length + rotary_dim = self.args.hidden_size // self.args.num_attention_heads \ + if self.args.kv_channels is None else self.args.kv_channels + + if self.args.rotary_percent < 1.0: + rotary_dim = int(rotary_dim * self.args.rotary_percent) + + # partial rotary embeddings, which is better than full rotary + # Wang and Komatsuzaki et al + # https://github.com/kingoflolz/mesh-transformer-jax/ + self.rotary_pos_emb = RotaryEmbedding( + rotary_dim, + seq_len_interpolation_factor=self.args.rotary_seq_len_interpolation_factor + ) + self.use_rotary_position_embeddings = True + elif self.args.use_llama2_rotary_position_embeddings: + self.use_rotary_position_embeddings = False + + + if self.add_encoder: + self.encoder = ParallelTransformer( + config, + model_type=self.args.model_type if not self.args.retro_add_retriever \ + else ModelType.retro_decoder, + self_attn_mask_type=self.encoder_attn_mask_type, + pre_process=self.pre_process, + post_process=self.post_process, + ) + self._encoder_key = 'encoder' + + if self.args.freeze_llm: + for param in self.encoder.parameters(): + param.requires_grad = False + + if self.post_process: + if self.untie_embeddings_and_output_weights: + self.output_layer = tensor_parallel.ColumnParallelLinear( + self.args.hidden_size, + self.args.padded_vocab_size, + config=config, + init_method=self.init_method, + bias=False) # Setting bias to False always to keep it consistent with embedding tying that also does not have a bias. + self._output_layer_key = 'output_layer' + + if self.args.freeze_llm: + for param in self.output_layer.parameters(): + param.requires_grad = False + + def encode_images(self, images): + image_features = self.vision_tower(images) + image_features = self.mm_projector(image_features) + return image_features + + def set_input_tensor(self, input_tensor): + """ See megatron.model.transformer.set_input_tensor()""" + + # This is usually handled in schedules.py but some inference code still + # gives us non-lists or None + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + + if self.add_encoder and self.add_decoder: + assert len(input_tensor) == 1, \ + 'input_tensor should only be length 1 for stage with both encoder and decoder' + self.encoder.set_input_tensor(input_tensor[0]) + elif self.add_encoder: + assert len(input_tensor) == 1, \ + 'input_tensor should only be length 1 for stage with only encoder' + self.encoder.set_input_tensor(input_tensor[0]) + elif self.add_decoder: + if len(input_tensor) == 2: + self.decoder.set_input_tensor(input_tensor[0]) + self.encoder_hidden_state = input_tensor[1] + elif len(input_tensor) == 1: + self.decoder.set_input_tensor(None) + self.encoder_hidden_state = input_tensor[0] + else: + raise Exception('input_tensor must have either length 1 or 2') + else: + raise Exception('Stage must have at least either encoder or decoder') + + def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask, + dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None, + retriever_input_ids=None, + retriever_position_ids=None, + retriever_attn_mask=None, + enc_dec_attn_mask=None, tokentype_ids=None, + inference_params=None, + pooling_sequence_index=0, + enc_hidden_states=None, output_enc_hidden=False, images=None): + + if torch.any(enc_input_ids == self.visual_config['image_start_id']): + bos_pos = torch.where(enc_input_ids == self.visual_config['image_start_id']) + eos_pos = torch.where(enc_input_ids == self.visual_config['image_start_id'] + 1) + assert (bos_pos[0] == eos_pos[0]).all() + img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1) + images = [] + for i, a, b in img_pos: + image = enc_input_ids[i][a + 1: b - 1].tolist() + image = image[: image.index(self.visual_config['image_start_id'] + 2)] + images.append(bytes(image).decode('utf-8')) + + images = self.visual.encode(images) + fake_images = None + else: + fake_images=torch.zeros(1, 3, 224, 224).to( + dtype=self.visual.conv1.weight.dtype, device=self.visual.conv1.weight.device) + images = self.visual(fake_images) + + + # Encoder embedding. + if self.pre_process: + encoder_input = self.embedding(enc_input_ids, enc_position_ids, + tokentype_ids=tokentype_ids) + else: + encoder_input = None + + encoder_input = encoder_input.permute(1, 0, 2) + if fake_images is not None: + encoder_input = encoder_input + images.mean()*0 + elif images is not None: + for idx, (i, a, b) in enumerate(img_pos): + encoder_input[i][a + 1 : b] = images[idx] + + encoder_input = encoder_input.permute(1, 0, 2) + batch_size = enc_input_ids.shape[0] + seq_length = enc_input_ids.shape[1] + enc_attn_mask = _prepare_4d_causal_attention_mask( + enc_attn_mask, + (batch_size, seq_length), + encoder_input, + 0, + sliding_window=None, + ) + + # Rotary positional embeddings + rotary_pos_emb = None + if self.use_rotary_position_embeddings: + if inference_params is not None: + rotary_pos_emb = \ + self.rotary_pos_emb(inference_params.max_sequence_length) + else: + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + + + if enc_position_ids is None: + past_key_values_length = 0 + seq_length = self.seq_length + device = enc_input_ids.device\ + if enc_input_ids is not None else encoder_input.device + position_ids = torch.arange(past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device) + enc_position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + + # Run encoder. + if enc_hidden_states is None: + if self.encoder is not None: + encoder_output = self.encoder( + encoder_input, + enc_attn_mask, + retriever_input=None, + retriever_attn_mask=retriever_attn_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + position_ids=enc_position_ids + ) + else: + encoder_output = self.encoder_hidden_state + else: + encoder_output = enc_hidden_states.to(encoder_input.dtype) + + if self.post_process: + if self.add_pooler: + pooled_output = self.pooler(encoder_output, + pooling_sequence_index) + + # output_enc_hidden refers to when we just need the encoder's + # output. For example, it is helpful to compute + # similarity between two sequences by average pooling + if not self.add_decoder or output_enc_hidden: + if self.add_pooler and self.post_process: + return encoder_output, pooled_output + else: + return encoder_output + + # Decoder embedding. + if self.pre_process: + decoder_input = self.embedding(dec_input_ids, + dec_position_ids) + else: + decoder_input = None + + # Run decoder. + decoder_output = self.decoder( + decoder_input, + dec_attn_mask, + encoder_output=encoder_output, + enc_dec_attn_mask=enc_dec_attn_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb) + + if self.add_pooler and self.post_process: + return decoder_output, encoder_output, pooled_output + else: + return decoder_output, encoder_output + + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + """For easy load.""" + + state_dict_ = {} + if self.pre_process: + state_dict_[self._embedding_key] \ + = self.embedding.state_dict_for_save_checkpoint(prefix=prefix, + keep_vars=keep_vars) + if self.add_encoder: + state_dict_[self._encoder_key] \ + = self.encoder.state_dict_for_save_checkpoint(prefix=prefix, + keep_vars=keep_vars) + if self.post_process: + if self.add_pooler: + state_dict_[self._pooler_key] \ + = self.pooler.state_dict_for_save_checkpoint(prefix=prefix, + keep_vars=keep_vars) + if self.untie_embeddings_and_output_weights: + state_dict_[self._output_layer_key] \ + = self.output_layer.state_dict(prefix=prefix, keep_vars=keep_vars) + + if self.add_decoder: + state_dict_[self._decoder_key] \ + = self.decoder.state_dict_for_save_checkpoint(prefix=prefix, + keep_vars=keep_vars) + + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + args = get_args() + # Embedding. + if self.pre_process: + if self._embedding_key in state_dict: + state_dict_ = state_dict[self._embedding_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if '_embeddings' in key: + state_dict_[key] = state_dict[key] + self.embedding.load_state_dict(state_dict_, strict=strict) + + # Encoder. + if self.add_encoder: + if self._encoder_key in state_dict: + state_dict_ = state_dict[self._encoder_key] + # For backward compatibility. + elif 'transformer' in state_dict: + state_dict_ = state_dict['transformer'] + else: + # For backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'transformer.' in key: + state_dict_[key.split('transformer.')[1]] = state_dict[key] + + # For backward compatibility. + state_dict_self_attention = {} + for key in state_dict_.keys(): + if '.attention.' in key: + state_dict_self_attention[key.replace(".attention.", + ".self_attention.")] = state_dict_[key] + else: + state_dict_self_attention[key] = state_dict_[key] + state_dict_ = state_dict_self_attention + + if args.transformer_impl == "transformer_engine": + self.encoder.load_state_dict(state_dict_, strict=False) + else: + self.encoder.load_state_dict(state_dict_, strict=strict) + + # Pooler. + if self.post_process: + if self.add_pooler: + assert 'pooler' in state_dict, \ + 'could not find data for pooler in the checkpoint' + self.pooler.load_state_dict(state_dict[self._pooler_key], + strict=strict) + if self.untie_embeddings_and_output_weights: + assert 'output_layer' in state_dict, \ + 'could not find data for output_layer in the checkpoint' + self.output_layer.load_state_dict(state_dict[self._output_layer_key], + strict=strict) + # Decoder. + if self.add_decoder: + assert 'decoder' in state_dict, \ + 'could not find data for pooler in the checkpoint' + self.decoder.load_state_dict(state_dict[self._decoder_key], + strict=strict) diff --git a/megatron_patch/model/qwen_vl/transformer.py b/megatron_patch/model/qwen_vl/transformer.py new file mode 100644 index 00000000..99ac9d34 --- /dev/null +++ b/megatron_patch/model/qwen_vl/transformer.py @@ -0,0 +1,1870 @@ +# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from contextlib import nullcontext +import math +import numpy as np +import torch +import torch.nn.functional as F +from typing import Optional + +from megatron import get_timers, get_args, get_retro_args, core, get_num_microbatches +from megatron.model.module import MegatronModule +from megatron.core import mpu, tensor_parallel +from megatron.core.enums import ModelType +from megatron.model.enums import AttnMaskType, LayerType, AttnType +from megatron.model.fused_softmax import FusedScaleMaskSoftmax +from megatron.model.fused_bias_gelu import bias_gelu_impl +from megatron.core.models.common.rotary_pos_embedding import apply_rotary_pos_emb +from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu, get_norm +from megatron.core.tensor_parallel import gather_from_sequence_parallel_region_to_moe, reduce_scatter_to_sequence_parallel_region_from_moe +from megatron.core.parallel_state import get_tensor_model_parallel_group, get_tensor_and_data_parallel_group + + +from megatron_patch.model.llava.rotary_pos_embedding import RotaryEmbedding +from megatron_patch.model.llava.rotary_pos_embedding import apply_rotary_pos_emb as apply_llama2_rotary_pos_emb + + +try: + from einops import rearrange +except ImportError: + rearrange = None + +try: + from flash_attn.flash_attn_interface import flash_attn_unpadded_func +except ImportError: + try: + from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func + except ImportError: + flash_attn_unpadded_func = None + +""" We use the following notation throughout this file: + h: hidden size + n: number of attention heads + p: number of model parallel partitions + np: n/p + hp: h/p + hn: h/n + b: batch size + s: sequence length + l: number of layers + Transformer takes input of size [s, b, h] and returns a + tensor of the same size. We use the following arguments: + hyperparameters: transformer hyperparameters +""" + +class DropPath(MegatronModule): + """Drop paths (Stochastic Depth) per sample + (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=0.): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_state): + if self.drop_prob == 0. or not self.training: + return hidden_state + keep_prob = 1 - self.drop_prob + # work with diff dim tensors, not just 2D ConvNets + # hidden_state: [s, b, h] + shape = (1,) + (hidden_state.shape[1],) + (1,) * (hidden_state.ndim - 2) + random_tensor = keep_prob + \ + torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device) + random_tensor.floor_() # binarize + output = hidden_state.div(keep_prob) * random_tensor + return output + +class ParallelMLP(MegatronModule): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__(self, config, is_expert=False): + super(ParallelMLP, self).__init__() + args = get_args() + + self.add_bias = config.add_bias_linear + + ffn_hidden_size = config.ffn_hidden_size + if config.gated_linear_unit: + ffn_hidden_size *= 2 + + # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf + self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + ffn_hidden_size, + config=config, + init_method=config.init_method, + bias=self.add_bias, + gather_output=False, + skip_bias_add=True, + is_expert=is_expert, + ) + + self.bias_gelu_fusion = False + self.activation_func = None + self.swiglu = args.swiglu + + if args.openai_gelu: + self.activation_func = openai_gelu + elif args.onnx_safe: + self.activation_func = erf_gelu + elif args.swiglu: + def swiglu(x): + x = torch.chunk(x, 2, dim=-1) + return F.silu(x[0]) * x[1] + self.activation_func = swiglu + elif args.squared_relu: + def squared_relu(x): + return torch.pow(F.relu(x), 2) + self.activation_func = squared_relu + else: + self.bias_gelu_fusion = args.bias_gelu_fusion + self.activation_func = F.gelu + + # Project back to h. + self.dense_4h_to_h = tensor_parallel.RowParallelLinear( + config.ffn_hidden_size, + config.hidden_size, + config=config, + init_method=config.output_layer_init_method, + bias=self.add_bias, + input_is_parallel=True, + skip_bias_add=True, + is_expert=is_expert, + ) + + def forward(self, hidden_states): + + # [s, b, 4hp] + intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states) + + if self.bias_gelu_fusion: + assert self.add_bias is True + assert self.activation_func == F.gelu + intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel) + else: + if bias_parallel is not None: + intermediate_parallel = intermediate_parallel + bias_parallel + intermediate_parallel = self.activation_func(intermediate_parallel) + + # [s, b, h] + output, output_bias = self.dense_4h_to_h(intermediate_parallel) + return output, output_bias + +def sinkhorn(cost, tol=0.0001): + cost = torch.exp(cost) + d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype) + d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype) + + eps = 0.00000001 + error = 1e9 + d1_old = d1 + while error > tol: + d0 = (1/d0.size(0))*1/(torch.sum(d1*cost,1) + eps) + d1 = (1/d1.size(0))*1/(torch.sum(d0.unsqueeze(1)*cost,0)+eps) + error = torch.mean(torch.abs(d1_old-d1)) + d1_old = d1 + return d1*cost*d0.unsqueeze(1) + +class SwitchMLP(MegatronModule): + """ + Routes input to one of N MLP "experts" + """ + def __init__(self, config): + super(SwitchMLP, self).__init__() + args = get_args() + self.router = torch.nn.Linear(args.hidden_size, args.num_experts) + self.expert_parallel = config.expert_parallel + self.sequence_parallel = config.sequence_parallel + self.add_bias = config.add_bias_linear + + if self.expert_parallel: + assert args.num_experts % mpu.get_data_parallel_world_size() == 0 + self.num_local_experts = args.num_experts // mpu.get_data_parallel_world_size() + local_expert_indices_offset = mpu.get_data_parallel_rank() * self.num_local_experts + self.local_expert_indices = [local_expert_indices_offset + i for i in range(self.num_local_experts)] + else: + self.num_local_experts = args.num_experts + self.local_expert_indices = [i for i in range(self.num_local_experts)] + + self.local_experts = torch.nn.ModuleList() + for i in range(self.num_local_experts): + self.local_experts.append(ParallelMLP(config, is_expert=True)) + + def gather_indices(self, local_indices): + """ Gather tensors and concatinate along the first dimension.""" + if self.expert_parallel: + group = get_tensor_and_data_parallel_group() + else: + group = get_tensor_model_parallel_group() + world_size = torch.distributed.get_world_size(group=group) + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return local_indices + + dim_size = list(local_indices.size()) + dim_size[0] = dim_size[0] * world_size + + # TODO pre allocate memory + output = torch.empty(dim_size, dtype=local_indices.dtype, + device=torch.cuda.current_device()) + torch.distributed._all_gather_base( + output, local_indices.contiguous(), group=group + ) + return output + + def forward(self, hidden_states): + # hidden_states: [b, s, h] + args = get_args() + s = hidden_states.size(0) + b = hidden_states.size(1) + h = hidden_states.size(2) + route = self.router(hidden_states).view(-1, args.num_experts) + + # TODO (rprenger) Right now we're just using the sinkhorn algorithm + # for load balancing. There should be an option to do no load balancing + # and the algorithm and parametets should be further tested + if self.training: + with torch.no_grad(): + sinkroute = sinkhorn(route.detach().to(dtype=torch.float32)) + _, max_ind = torch.max(sinkroute, dim=1) + route = torch.sigmoid(route) + max_prob = route[torch.arange(route.size(0)), max_ind] + else: + route = torch.sigmoid(route) + max_prob, max_ind = torch.max(route, dim=1) + + max_prob = torch.unsqueeze(max_prob, 1) + hidden_states = hidden_states.view(-1, hidden_states.size(2)) + + # TODO (rprenger) TODO this could be made easier to read + # Converting [s, b, h] to [s*b, h]. + # Each vector could be routed differently + if self.sequence_parallel or self.expert_parallel: + global_hidden_states = \ + gather_from_sequence_parallel_region_to_moe( + hidden_states, + expert_parallel=self.expert_parallel + ) + global_indices = self.gather_indices(max_ind) + else: + global_hidden_states = hidden_states + global_indices = max_ind + + output_total = torch.zeros_like(global_hidden_states) + if self.add_bias: + output_bias_total = torch.zeros_like(global_hidden_states) + + for expert_num, expert in enumerate(self.local_experts): + local_expert_index = self.local_expert_indices[expert_num] + local_indices = (global_indices == local_expert_index).nonzero() + hidden = global_hidden_states[local_indices, :] + output, output_bias = expert(hidden) + output_total[local_indices, :] = output + if self.add_bias: + output_bias = output_bias.expand_as(output) + output_bias_total[local_indices, :] = output_bias + + if self.sequence_parallel or self.expert_parallel: + output_total = \ + reduce_scatter_to_sequence_parallel_region_from_moe( + output_total, + expert_parallel=self.expert_parallel + ) + if self.add_bias: + output_bias_total = \ + reduce_scatter_to_sequence_parallel_region_from_moe( + output_bias_total, + expert_parallel=self.expert_parallel) + + # bias is duplicated across tensor parallelism ranks; + # reduce scatter reduces bias across tensor parallel_ranks + output_bias_total = \ + output_bias_total/mpu.get_tensor_model_parallel_world_size() + + output_total = output_total*max_prob + output_total = output_total.view(s, b, h) + if self.add_bias: + output_bias_total = output_bias_total*max_prob + output_bias_total = output_bias_total.view(s, b, h) + else: + output_bias_total = None + + return output_total, output_bias_total + + +class CoreAttention(MegatronModule): + + def __init__(self, layer_number, config, + attn_mask_type=AttnMaskType.padding): + super(CoreAttention, self).__init__() + self.fp16 = config.fp16 + self.bf16 = config.bf16 + + self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 + if self.apply_query_key_layer_scaling: + self.attention_softmax_in_fp32 = True + self.layer_number = max(1, layer_number) + self.attn_mask_type = attn_mask_type + self.sequence_parallel = config.sequence_parallel + + projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + world_size = mpu.get_tensor_model_parallel_world_size() + self.hidden_size_per_partition = core.utils.divide(projection_size, + world_size) + self.hidden_size_per_attention_head = core.utils.divide( + projection_size, config.num_attention_heads) + self.num_attention_heads_per_partition = core.utils.divide( + config.num_attention_heads, world_size) + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.apply_query_key_layer_scaling: + coeff = self.layer_number + self.norm_factor *= coeff + + self.scale_mask_softmax = FusedScaleMaskSoftmax( + self.fp16, self.bf16, + self.attn_mask_type, + config.masked_softmax_fusion, + attention_mask_func, + self.attention_softmax_in_fp32, + coeff) + + # Dropout. Note that for a single iteration, this layer will generate + # different outputs on different number of parallel partitions but + # on average it should not be partition dependent. + self.attention_dropout = torch.nn.Dropout(config.attention_dropout) + + def forward(self, query_layer, key_layer, + value_layer, attention_mask): + + # =================================== + # Raw attention scores. [b, np, s, s] + # =================================== + + # [b, np, sq, sk] + output_size = (query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0)) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.reshape(output_size[2], + output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[3], + output_size[0] * output_size[1], -1) + + # preallocting input tensor: [b * np, sq, sk] + matmul_input_buffer = mpu.get_global_memory_buffer().get_tensor( + (output_size[0]*output_size[1], output_size[2], output_size[3]), + query_layer.dtype, "mpu") + + # Raw attention scores. [b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, alpha=(1.0/self.norm_factor)) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # =========================== + # Attention probs and dropout + # =========================== + + # attention scores and attention mask [b, np, sq, sk] + attention_mask = attention_mask.to(torch.bool) + attention_probs = self.scale_mask_softmax(attention_scores, + attention_mask) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + if not self.sequence_parallel: + with tensor_parallel.get_cuda_rng_tracker().fork(): + attention_probs = self.attention_dropout(attention_probs) + else: + attention_probs = self.attention_dropout(attention_probs) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(1), + value_layer.size(2), + query_layer.size(0), + value_layer.size(3)) + + # change view [sk, b * np, hn] + value_layer = value_layer.view(value_layer.size(0), + output_size[0] * output_size[1], -1) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], + output_size[2], -1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer + + +class FlashSelfAttention(torch.nn.Module): + """Implement the scaled dot product attention with softmax. + Arguments + --------- + softmax_scale: The temperature to use for the softmax attention. + (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.0) + """ + def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, + device=None, dtype=None): + super().__init__() + assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, ' + 'e.g., with pip install flash-attn') + assert rearrange is not None, 'Please install einops first, e.g., with pip install einops' + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout_p = attention_dropout + + def forward(self, q, k, v): + """Implements the multihead softmax attention. + Arguments + --------- + q, k, v: The tensor containing the query, key, and value. (B, S, H, D) + """ + + assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q,k,v))) + assert all((i.is_cuda for i in (q,k,v))) + + batch_size, seqlen_q = q.shape[0], q.shape[1] + seqlen_k = k.shape[1] + + q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]] + cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, + device=q.device) + + if self.training: + # during training q,k,v always have same seqlen + assert seqlen_k == seqlen_q + + is_causal = self.causal + cu_seqlens_k = cu_seqlens_q + dropout_p = self.dropout_p + else: + # turn off FA causal mask after first inference autoregressive iteration + # only on first autoregressive step q,k,v have same seqlen + is_causal = seqlen_q == seqlen_k + cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, + device=q.device) + dropout_p = 0 + + output = flash_attn_unpadded_func( + q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k, + dropout_p, + softmax_scale=self.softmax_scale, causal=is_causal + ) + + output = rearrange(output, '(b s) ... -> b s ...', b=batch_size) + return output + + +class ParallelAttention(MegatronModule): + """Parallel self-attention layer abstract class. + + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__(self, config, layer_number, + attention_type=AttnType.self_attn, + attn_mask_type=AttnMaskType.padding): + super(ParallelAttention, self).__init__() + args = get_args() + self.layer_number = max(1, layer_number) + self.attention_type = attention_type + self.attn_mask_type = attn_mask_type + self.params_dtype = config.params_dtype + self.sequence_parallel = config.sequence_parallel + + self.group_query_attention = args.group_query_attention + self.num_query_groups = args.num_query_groups + + query_projection_size = config.kv_channels * config.num_attention_heads + if self.group_query_attention: + kv_projection_size = args.kv_channels * args.num_query_groups + else: + kv_projection_size = args.kv_channels * args.num_attention_heads + + self.use_flash_attn = args.use_flash_attn \ + and attention_type == AttnType.self_attn \ + and self.attn_mask_type == AttnMaskType.causal + if self.use_flash_attn: + if flash_attn_unpadded_func is None: + raise ImportError('FlashAttention is not installed, please install with ' + 'pip install flash-attn') + assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports ' + 'self-attention for now') + assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only ' + 'supports causal mask for now') + if rearrange is None: + raise ImportError('einops is not installed, please install with pip install einops') + + # Per attention head and per partition values. + world_size = mpu.get_tensor_model_parallel_world_size() + self.hidden_size_per_attention_head = core.utils.divide( + query_projection_size, config.num_attention_heads) + self.num_attention_heads_per_partition = core.utils.divide( + config.num_attention_heads, world_size) + + if self.group_query_attention: + if args.num_query_groups % world_size != 0: + raise NotImplementedError('Currently the num_query_groups should be ' + 'a multiple of the tensor parallel size') + self.num_query_groups_per_partition = core.utils.divide( + args.num_query_groups, world_size) + else: + self.num_query_groups_per_partition = self.num_attention_heads_per_partition + + # Strided linear layer. + if attention_type == AttnType.self_attn: + self.query_key_value = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + query_projection_size + 2 * kv_projection_size, + config=config, + init_method=config.init_method, + bias=args.add_bias_linear, + gather_output=False) + else: + assert attention_type == AttnType.cross_attn + + if self.group_query_attention: + raise NotImplementedError("Grouped query attention not implemented for cross-attention.") + assert query_projection_size == kv_projection_size + + self.query = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + query_projection_size, + config=config, + init_method=config.init_method, + bias=config.add_bias_linear, + gather_output=False) + + self.key_value = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + 2 * kv_projection_size, + config=config, + init_method=config.init_method, + bias=config.add_bias_linear, + gather_output=False) + + self.core_attention = CoreAttention(self.layer_number, config, + self.attn_mask_type) + self.checkpoint_core_attention = config.recompute_granularity == 'selective' + + if self.use_flash_attn: + self.core_attention_flash = FlashSelfAttention( + causal=True, attention_dropout=config.attention_dropout + ) + + # Output. + self.dense = tensor_parallel.RowParallelLinear( + query_projection_size, + config.hidden_size, + config=config, + init_method=config.output_layer_init_method, + bias=args.add_bias_linear, + input_is_parallel=True, + skip_bias_add=True) + + if args.use_llama2_rotary_position_embeddings: + self.use_llama2_rotary_position_embeddings = True + self.seq_length = args.seq_length + rotary_dim = args.hidden_size // args.num_attention_heads \ + if args.kv_channels is None else args.kv_channels + + if args.rotary_percent < 1.0: + rotary_dim = int(rotary_dim * args.rotary_percent) + + # partial rotary embeddings, which is better than full rotary + # Wang and Komatsuzaki et al + # https://github.com/kingoflolz/mesh-transformer-jax/ + self.rotary_emb = RotaryEmbedding( + rotary_dim, + args.max_position_embeddings + ) + else: + self.use_llama2_rotary_position_embeddings = False + + def _checkpointed_attention_forward(self, query_layer, key_layer, + value_layer, attention_mask, + rotary_pos_emb=None): + """Forward method with activation checkpointing.""" + def custom_forward(*inputs): + query_layer = inputs[0] + key_layer = inputs[1] + value_layer = inputs[2] + attention_mask = inputs[3] + output_ = self.core_attention(query_layer, key_layer, + value_layer, attention_mask) + return output_ + + q_pos_emb, k_pos_emb = (None, None) if rotary_pos_emb is None \ + else rotary_pos_emb + + hidden_states = tensor_parallel.checkpoint( + custom_forward, + False, query_layer, key_layer, value_layer, attention_mask, + q_pos_emb, k_pos_emb) + + return hidden_states + + def _allocate_memory(self, inference_max_sequence_len, batch_size, num_attention_heads): + return torch.empty( + inference_max_sequence_len, + batch_size, + num_attention_heads, + self.hidden_size_per_attention_head, + dtype=self.params_dtype, + device=torch.cuda.current_device()) + + def forward(self, hidden_states, attention_mask, + encoder_output=None, inference_params=None, + rotary_pos_emb=None, position_ids=None): + # hidden_states: [sq, b, h] + + # ================================================= + # Pre-allocate memory for key-values for inference. + # ================================================= + is_first_step = False + if inference_params: + if self.layer_number not in inference_params.key_value_memory_dict: + inf_max_seq_len = inference_params.max_sequence_length + inf_max_batch_size = inference_params.max_batch_size + inference_key_memory = self._allocate_memory( + inf_max_seq_len, inf_max_batch_size, + self.num_query_groups_per_partition) + inference_value_memory = self._allocate_memory( + inf_max_seq_len, inf_max_batch_size, + self.num_query_groups_per_partition) + + inference_params.key_value_memory_dict[self.layer_number] = ( + inference_key_memory, inference_value_memory) + is_first_step = True + else: + inference_key_memory, inference_value_memory = \ + inference_params.key_value_memory_dict[self.layer_number] + + # ===================== + # Query, Key, and Value + # ===================== + if self.attention_type == AttnType.self_attn: + + # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)] + mixed_x_layer, _ = self.query_key_value(hidden_states) + + # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn] + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_query_groups_per_partition, + ( + (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2) + * self.hidden_size_per_attention_head + ), + ) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] + (query_layer, + key_layer, + value_layer) = torch.split( + mixed_x_layer, + [ + ( + self.num_attention_heads_per_partition // self.num_query_groups_per_partition + * self.hidden_size_per_attention_head + ), + self.hidden_size_per_attention_head, + self.hidden_size_per_attention_head + ], + dim=3) + + # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] - + query_layer = query_layer.view(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head) + else: + # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] + mixed_kv_layer, _ = self.key_value(encoder_output) + + # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] + new_tensor_shape = mixed_kv_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + 2 * self.hidden_size_per_attention_head) + mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) + + # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] + (key_layer, + value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2) + + # Attention head [sq, b, h] --> [sq, b, hp] + query_layer, _ = self.query(hidden_states) + # [sq, b, hp] --> [sq, b, np, hn] + new_tensor_shape = query_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + query_layer = query_layer.view(*new_tensor_shape) + + # ================================== + # Adjust key and value for inference + # ================================== + + # duplicate the pos_emb for self attention + if rotary_pos_emb is not None: + if isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = rotary_pos_emb + else: + rotary_pos_emb = ((rotary_pos_emb,) * 2) + + if inference_params: + if self.use_llama2_rotary_position_embeddings: + kv_seq_len = key_layer.shape[0] + kv_seq_len += inference_params.sequence_len_offset + value_layer = value_layer.transpose(0, 1).transpose(1, 2) + query_layer = query_layer.transpose(0, 1).transpose(1, 2) + key_layer = key_layer.transpose(0, 1).transpose(1, 2) + cos, sin = self.rotary_emb(value_layer, kv_seq_len) + query_layer, key_layer = apply_llama2_rotary_pos_emb( + query_layer, key_layer, cos, sin, position_ids) + + value_layer = value_layer.transpose(1, 2).transpose(0, 1) + query_layer = query_layer.transpose(1, 2).transpose(0, 1) + key_layer = key_layer.transpose(1, 2).transpose(0, 1) + + batch_start = inference_params.batch_size_offset + batch_end = batch_start + key_layer.size(1) + assert batch_end <= inference_key_memory.size(1) + sequence_start = inference_params.sequence_len_offset + sequence_end = sequence_start + key_layer.size(0) + assert sequence_end <= inference_key_memory.size(0) + # Copy key and values. + inference_key_memory[sequence_start:sequence_end, + batch_start:batch_end, ...] = key_layer + inference_value_memory[sequence_start:sequence_end, + batch_start:batch_end, ...] = value_layer + key_layer = inference_key_memory[ + :sequence_end, batch_start:batch_end, ...] + value_layer = inference_value_memory[ + :sequence_end, batch_start:batch_end, ...] + + + # adjust the key rotary positional embedding + if rotary_pos_emb is not None: + q_pos_emb, k_pos_emb = rotary_pos_emb + # need to cross check this condition during inference + # if not set_inference_key_value_memory: + if not is_first_step: + # In inference, we compute one token at a time. + # Select the correct positional embedding + # (only the last token in the sequence) + q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end] + else: + # In the first forward pass of inference, + # we use the entire provided prefix. + # q_pos_emb here has the rope embeddings of the entire + # prefix + to-be-generated output so + # we slice to just the prefix. + q_pos_emb = q_pos_emb[:sequence_end, :, :, :] + k_pos_emb = k_pos_emb[:sequence_end, :, :, :] + rotary_pos_emb = (q_pos_emb, k_pos_emb) + + # ================================== + # core attention computation + # ================================== + + # expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn] + key_layer = key_layer.repeat_interleave( + self.num_attention_heads_per_partition // self.num_query_groups_per_partition, + dim = 2 + ) + value_layer = value_layer.repeat_interleave( + self.num_attention_heads_per_partition // self.num_query_groups_per_partition, + dim = 2 + ) + + # apply relative positional encoding (rotary embedding) + if self.use_llama2_rotary_position_embeddings: + kv_seq_len = key_layer.shape[0] + value_layer = value_layer.transpose(0, 1).transpose(1, 2) + query_layer = query_layer.transpose(0, 1).transpose(1, 2) + key_layer = key_layer.transpose(0, 1).transpose(1, 2) + cos, sin = self.rotary_emb(value_layer, kv_seq_len) + query_layer, key_layer = apply_llama2_rotary_pos_emb( + query_layer, key_layer, cos, sin, position_ids) + + value_layer = value_layer.transpose(1, 2).transpose(0, 1) + query_layer = query_layer.transpose(1, 2).transpose(0, 1) + key_layer = key_layer.transpose(1, 2).transpose(0, 1) + else: + # apply relative positional encoding (rotary embedding) + if rotary_pos_emb is not None: + q_pos_emb, k_pos_emb = rotary_pos_emb + query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb) + key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb) + # TODO, can apply positional embedding to value_layer so it has + # absolute positional embedding. + # otherwise, only relative positional embedding takes effect + # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb) + + if not self.use_flash_attn: + if self.checkpoint_core_attention: + context_layer = self._checkpointed_attention_forward( + query_layer, key_layer, value_layer, attention_mask) + else: + context_layer = self.core_attention( + query_layer, key_layer, value_layer, attention_mask) + else: + q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous() + for x in (query_layer, key_layer, value_layer)] + if not self.sequence_parallel: + with tensor_parallel.get_cuda_rng_tracker().fork(): + context_layer = self.core_attention_flash(q, k, v) + else: + context_layer = self.core_attention_flash(q, k, v) + context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous() + + # ================= + # Output. [sq, b, h] + # ================= + + output, bias = self.dense(context_layer) + + return output, bias + + +def bias_dropout_add(x, bias, residual, prob, training): + # type: (Tensor, Optional[Tensor], Tensor, float, bool) -> Tensor + if bias is not None: + x = x + bias + out = torch.nn.functional.dropout(x, p=prob, training=training) + out = residual + out + return out + + +def get_bias_dropout_add(training): + def _bias_dropout_add(x, bias, residual, prob): + return bias_dropout_add(x, bias, residual, prob, training) + return _bias_dropout_add + + +@torch.jit.script +def bias_dropout_add_fused_train(x: torch.Tensor, + bias: Optional[torch.Tensor], + residual: torch.Tensor, + prob: float) -> torch.Tensor: + return bias_dropout_add(x, bias, residual, prob, True) + + +@torch.jit.script +def bias_dropout_add_fused_inference(x: torch.Tensor, + bias: Optional[torch.Tensor], + residual: torch.Tensor, + prob: float) -> torch.Tensor: + return bias_dropout_add(x, bias, residual, prob, False) + + +class ParallelTransformerLayer(MegatronModule): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__(self, config, + layer_number, layer_type=LayerType.encoder, + self_attn_mask_type=AttnMaskType.padding, + drop_path_rate=0.): + # retriever=None): + args = get_args() + + super(ParallelTransformerLayer, self).__init__() + self.layer_number = layer_number + self.layer_type = layer_type + + self.apply_residual_connection_post_norm \ + = config.apply_residual_connection_post_layernorm + + self.bf16 = config.bf16 + self.fp32_residual_connection = config.fp32_residual_connection + + # Normalize the input data. + self.input_norm = get_norm(config) + + # Self attention. + self.self_attention = ParallelAttention( + config, + layer_number, + attention_type=AttnType.self_attn, + attn_mask_type=self_attn_mask_type) + self.hidden_dropout = config.hidden_dropout + self.bias_dropout_fusion = config.bias_dropout_fusion + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None + + # Normalize the attention output + self.post_attention_norm = get_norm(config) + + # Cross attention. + if self.layer_type in (LayerType.decoder, + LayerType.retro_decoder, + LayerType.retro_decoder_with_retriever, + LayerType.retro_encoder): + self.inter_attention = ParallelAttention( + config, + layer_number, + attention_type=AttnType.cross_attn) + # Normalize the attention output. + self.post_inter_attention_norm = get_norm(config) + + # MLP + if args.num_experts is not None: + self.mlp = SwitchMLP(config) + else: + self.mlp = ParallelMLP(config) + + # Set bias+dropout+add fusion grad_enable execution handler. + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10) + self.bias_dropout_add_exec_handler = \ + nullcontext if use_nvfuser else torch.enable_grad + + if args.retro_add_retriever: + retro_args = get_retro_args() + self.retro_num_neighbors = args.retro_num_neighbors + self.retro_chunk_length = retro_args.retro_gpt_chunk_length + self.retro_retrieved_length = retro_args.retro_gpt_retrieved_length + + # Retriever (bi-directional transformer with cross attention) + if layer_type == LayerType.retro_decoder_with_retriever: + self.retriever = ParallelTransformer( + config=config, + model_type=ModelType.retro_encoder, + self_attn_mask_type=AttnMaskType.padding, + pre_process=True, + post_process=False, + ) + self._retriever_key = 'retriever' + else: + self.retriever = None + + def default_decoder_cross_attention(self, + encoder_output, + enc_dec_attn_mask, + norm_input, + norm_output, + bias_dropout_add_func): + '''Cross attention for a standard encoder-decoder model.''' + + # Attention. + attention_output, attention_bias = \ + self.inter_attention(norm_output, + enc_dec_attn_mask, + encoder_output=encoder_output) + + # Residual connection. + if self.apply_residual_connection_post_norm: + residual = norm_output + else: + residual = norm_input + + if attention_bias is not None: + attention_bias = attention_bias.expand_as(residual) + + # Bias-dropout-add. + with self.bias_dropout_add_exec_handler(): + norm_input = bias_dropout_add_func( + attention_output, + attention_bias, + residual, + self.hidden_dropout) + + # Normalize. + norm_output = self.post_inter_attention_norm(norm_input) + + return norm_input, norm_output + + def retro_encoder_cross_attention(self, + retriever_output, + norm_input, + norm_output, + bias_dropout_add_func): + """Cross attention for Retro encoder. + + Notation: + ns : Sequence length. + bs : Batch size. + d : Hidden size. + l : Number of chunks per sample (i.e., seq_length/chunk_length). + k : Number of neighbors. + r : Number of retrieved tokens (neighbors + continuation). + """ + + ns, bs, d = norm_output.shape # [r, bs * l * k, d] + + # Divide sequence dimension into chunks. + chunked_outputs = norm_output.reshape(self.retro_retrieved_length, + -1, + self.retro_num_neighbors, + d) + chunked_outputs_before_norm = \ + norm_input.reshape(self.retro_retrieved_length, -1, + self.retro_num_neighbors, d) # [r, bs*l, k, d] + + # Per-chunk attention. + norm_inputs = [] + norm_outputs = [] + for k in range(self.retro_num_neighbors): + + # Attention. + chunked_output = chunked_outputs[:,:,k].contiguous() + attention_output, attention_bias = \ + self.inter_attention( + chunked_output, # Q (neighbor embedding) + None, + encoder_output=retriever_output) # K, V (hidden act) + + # Residual connection. + if self.apply_residual_connection_post_norm: + residual = chunked_output + else: + residual = chunked_outputs_before_norm[:,:,k] + + # Re-enable torch grad to enable fused optimization. + with torch.enable_grad(): + norm_input = bias_dropout_add_func( + attention_output, + None if attention_bias is None else attention_bias.expand_as(residual), + residual, + self.hidden_dropout) + norm_inputs.append(norm_input) + + # Layer norm. + norm_output = self.post_inter_attention_norm(norm_input) + norm_outputs.append(norm_output) + + # Concatenate layer norms. + # norm_input : [r, k * bs * l, d] + # norm_output : [r, k * bs * l, d] + norm_input = torch.stack(norm_inputs, dim=1).reshape(ns, bs, d) + norm_output = torch.stack(norm_outputs, dim=1).reshape(ns, bs, d) + + return norm_input, norm_output + + def retro_decoder_cross_attention(self, + retriever_input, + retriever_output, + retriever_attn_mask, + norm_input, + norm_output, + inference_params, + bias_dropout_add_func): + """Cross attention for Retro decoder. + + Notation: + ns : Sequence length. + bs : Batch size. + d : Hidden size. + l : Number of chunks per sample (i.e., seq_length/chunk_length). + m : Number of tokens per chunk. + k : Number of neighbors. + r : Number of retrieved tokens (neighbors + continuation). + """ + + ns, bs, d = norm_output.shape + l = int(np.ceil(ns / self.retro_chunk_length)) + + # Retrieve neighbors. + if self.layer_type == LayerType.retro_decoder_with_retriever: + first_ns = ns % self.retro_chunk_length + if first_ns > 0: + raise Exception("test this case.") + first_chunk, rest_chunk = \ + norm_output[:first_ns], norm_output[first_ns:] + first_chunk = torch.nn.functional.pad( + first_chunk, + (0, 0, 0, 0, 0, self.retro_chunk_length - first_ns), + 'constant', + 0) + chunked_output = \ + torch.cat((first_chunk, rest_chunk), dim=0) # [l * m, bs, d] + else: + chunked_output = norm_output # [l * m, bs, d] + chunked_output = chunked_output \ + .reshape(l, self.retro_chunk_length, bs, d) \ + .permute(1, 2, 0, 3) \ + .reshape(self.retro_chunk_length, bs * l, d) \ + .contiguous() + + # Get Encoder Output + retriever_output = self.retriever( + hidden_states=retriever_input, + attention_mask=retriever_attn_mask, + retriever_output=chunked_output, + retriever_attn_mask=retriever_attn_mask, + inference_params=inference_params) # [r, k * bs * l , d] + retriever_output = retriever_output.reshape( + self.retro_retrieved_length * self.retro_num_neighbors, bs * l, d) # [r * k, bs * l, d] + + # Chunks. + pad = (ns - 1) % self.retro_chunk_length + attending_chunks = norm_output[pad:] + padded_chunks = torch.nn.functional.pad( + attending_chunks, + (0, 0, 0, 0, 0, self.retro_chunk_length - 1), + 'constant', 0) + padded_chunked_output = padded_chunks \ + .reshape(l, self.retro_chunk_length, bs, d) \ + .permute(1, 2, 0, 3) + padded_chunked_output = padded_chunked_output.reshape( + self.retro_chunk_length, bs * l, d).contiguous() + + # Encoder output. + attention_output, attention_bias = \ + self.inter_attention(padded_chunked_output, + None, + encoder_output=retriever_output) + + # Residual connection. + if self.apply_residual_connection_post_norm: + residual = norm_output + else: + residual = norm_input + + # Re-enable torch grad to enable fused optimization. + with torch.enable_grad(): + norm_input = bias_dropout_add_func( + attention_output, + None if attention_bias is None else attention_bias.expand_as(attention_output), + torch.zeros_like(attention_output), + self.hidden_dropout) + norm_input = norm_input \ + .reshape(self.retro_chunk_length, bs, l, d) \ + .permute(2, 0, 1, 3) # [l, m, bs, d] + norm_input = norm_input.reshape(self.retro_chunk_length * l, bs, d) + norm_input = torch.nn.functional.pad( + norm_input, + (0, 0, 0, 0, pad, 0), + 'constant', 0)[:ns] # [ns, b, d] + norm_input = norm_input + residual + + # Layer norm post the decoder attention + norm_output = self.post_inter_attention_norm(norm_input) + + return retriever_output, norm_input, norm_output + + def forward(self, hidden_states, attention_mask, + encoder_output=None, enc_dec_attn_mask=None, + retriever_input=None, + retriever_output=None, + retriever_attn_mask=None, + inference_params=None, + rotary_pos_emb=None, + position_ids=None): + # hidden_states: [s, b, h] + + # Layer norm at the beginning of the transformer layer. + norm_output = self.input_norm(hidden_states) + + # Self attention. + attention_output, attention_bias = \ + self.self_attention( + norm_output, + attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + position_ids=position_ids + ) + + # Residual connection. + if self.apply_residual_connection_post_norm: + residual = norm_output + else: + residual = hidden_states + + if self.drop_path is None: + # jit scripting for a nn.module (with dropout) is not + # trigerring the fusion kernel. For now, we use two + # different nn.functional routines to account for varying + # dropout semantics during training and inference phases. + if self.bias_dropout_fusion: + if self.training: + bias_dropout_add_func = bias_dropout_add_fused_train + else: + bias_dropout_add_func = bias_dropout_add_fused_inference + else: + bias_dropout_add_func = get_bias_dropout_add(self.training) + + if attention_bias is not None: + attention_bias = attention_bias.expand_as(residual) + with self.bias_dropout_add_exec_handler(): + norm_input = bias_dropout_add_func( + attention_output, + attention_bias, + residual, + self.hidden_dropout) + else: + out = torch.nn.functional.dropout(attention_output + attention_bias, + p=self.hidden_dropout, + training=self.training) + norm_input = residual + self.drop_path(out) + + # Layer norm post the self attention. + norm_output = self.post_attention_norm(norm_input) + + # Cross attention. + if self.layer_type == LayerType.encoder: + pass + elif self.layer_type == LayerType.decoder: + norm_input, norm_output = \ + self.default_decoder_cross_attention( + encoder_output, + enc_dec_attn_mask, + norm_input, + norm_output, + bias_dropout_add_func) + elif self.layer_type == LayerType.retro_encoder: + norm_input, norm_output = \ + self.retro_encoder_cross_attention( + retriever_output, + norm_input, + norm_output, + bias_dropout_add_func) + elif self.layer_type in (LayerType.retro_decoder, + LayerType.retro_decoder_with_retriever): + retriever_output, norm_input, norm_output = \ + self.retro_decoder_cross_attention( + retriever_input, + retriever_output, + retriever_attn_mask, + norm_input, + norm_output, + inference_params, + bias_dropout_add_func) + else: + raise Exception("Unsupported layer type, '%s'." % + self.layer_type.name) + + # MLP. + mlp_output, mlp_bias = self.mlp(norm_output) + + # Second residual connection. + if self.apply_residual_connection_post_norm: + residual = norm_output + else: + residual = norm_input + + if self.drop_path is None: + if mlp_bias is not None: + mlp_bias = mlp_bias.expand_as(residual) + with self.bias_dropout_add_exec_handler(): + output = bias_dropout_add_func( + mlp_output, + mlp_bias, + residual, + self.hidden_dropout) + + # Jit compiled function creates 'view' tensor. This tensor + # potentially gets saved in the MPU checkpoint function context, + # which rejects view tensors. While making a viewless tensor here + # won't result in memory savings (like the data loader, or + # p2p_communication), it serves to document the origin of this + # 'view' tensor. + output = core.utils.make_viewless_tensor(inp = output, + requires_grad = output.requires_grad, + keep_graph = True) + + else: + if mlp_bias is not None: + mlp_output = mlp_output + mlp_bias + out = torch.nn.functional.dropout(mlp_output, + p=self.hidden_dropout, + training=self.training) + output = residual + self.drop_path(out) + + if self.layer_type == LayerType.retro_decoder_with_retriever: + return output, retriever_output + else: + return output + + +class NoopTransformerLayer(MegatronModule): + """A single 'no-op' transformer layer. + + The sole purpose of this layer is for when a standalone embedding layer + is used (i.e., args.standalone_embedding_stage == True). In this case, + zero transformer layers are assigned when pipeline rank == 0. Additionally, + when virtual pipeline rank >= 1, zero total model parameters are created + (virtual rank 0 contains the input embedding). This results in the model's + input and output tensors being the same, which causes an error when + performing certain memory optimiations on the output tensor (e.g., + deallocating it). Thus, this layer disconnects the input from the output + via a clone. Since ranks containing a no-op layer are generally under- + utilized (both compute and memory), there's no worry of any performance + degredation. + """ + + def __init__(self, layer_number): + super().__init__() + self.layer_number = layer_number + + def forward(self, hidden_states, attention_mask, + encoder_output=None, enc_dec_attn_mask=None, + inference_params=None): + return hidden_states.clone() + + +def _get_num_layers(args, model_type, is_decoder=False): + """Compute the number of transformer layers resident on the current rank.""" + is_encoder_and_decoder_model = (model_type == ModelType.encoder_and_decoder) + if model_type == ModelType.retro_encoder: + num_layers = args.retro_encoder_layers + elif mpu.get_pipeline_model_parallel_world_size() > 1: + if is_encoder_and_decoder_model: + assert args.pipeline_model_parallel_split_rank is not None + + # When a standalone embedding stage is used, a rank is taken from + # the encoder's ranks, to be used for the encoder's embedding + # layer. This way, the rank referenced by the 'split rank' remains + # the same whether or not a standalone embedding stage is used. + num_ranks_in_encoder = ( + args.pipeline_model_parallel_split_rank - 1 + if args.standalone_embedding_stage else + args.pipeline_model_parallel_split_rank + ) + num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder + assert args.encoder_num_layers % num_ranks_in_encoder == 0, \ + 'encoder_num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.encoder_num_layers, num_ranks_in_encoder) + assert args.decoder_num_layers % num_ranks_in_decoder == 0, \ + 'decoder_num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.decoder_num_layers, num_ranks_in_decoder) + if mpu.is_pipeline_stage_before_split(): + num_layers = ( + 0 + if args.standalone_embedding_stage + and mpu.get_pipeline_model_parallel_rank() == 0 else + args.encoder_num_layers // num_ranks_in_encoder + ) + else: + num_layers = args.decoder_num_layers // num_ranks_in_decoder + else: + assert args.num_layers == args.encoder_num_layers + assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \ + 'num_layers must be divisible by transformer_pipeline_model_parallel_size' + + # When a standalone embedding stage is used, all transformer layers + # are divided among pipeline rank >= 1, while on pipeline rank 0, + # ranks either contain the input embedding layer (virtual pp rank 0), + # or no layers at all (virtual pp rank >= 1). + num_layers = ( + 0 + if args.standalone_embedding_stage + and mpu.get_pipeline_model_parallel_rank() == 0 else + args.num_layers // args.transformer_pipeline_model_parallel_size + ) + else: + if not is_decoder: + num_layers = args.encoder_num_layers + else: + num_layers = args.decoder_num_layers + return num_layers + + +def _get_layer_type(model_type, default_layer_type, retro_layer_numbers, + layer_number): + args = get_args() + if args.retro_add_retriever and layer_number in retro_layer_numbers: + if model_type == ModelType.retro_decoder: + return LayerType.retro_decoder_with_retriever \ + if layer_number == retro_layer_numbers[0] \ + else LayerType.retro_decoder + elif model_type == ModelType.retro_encoder: + return LayerType.retro_encoder + else: + raise Exception("Unsupported model type, '%s'." % model_type) + else: + return default_layer_type + + +class ParallelTransformer(MegatronModule): + """Transformer class.""" + + def __init__(self, config, + model_type, layer_type=LayerType.encoder, + self_attn_mask_type=AttnMaskType.padding, + post_norm=True, + pre_process=True, + post_process=True, + drop_path_rate=0.0): + super(ParallelTransformer, self).__init__() + args = get_args() + + self.layer_type = layer_type + self.model_type = model_type + self.bf16 = config.bf16 + self.fp32_residual_connection = config.fp32_residual_connection + self.post_norm = post_norm + self.pre_process = pre_process + self.post_process = post_process + self.input_tensor = None + self.drop_path_rate = drop_path_rate + self.transformer_impl = args.transformer_impl + self.retro_add_retriever = args.retro_add_retriever + + # Store activation checkpoiting flag. + self.recompute_granularity = config.recompute_granularity + self.recompute_method = config.recompute_method + self.recompute_num_layers = config.recompute_num_layers + self.distribute_saved_activations = \ + config.distribute_saved_activations and not config.sequence_parallel + + self.sequence_parallel = config.sequence_parallel + + # Transformer Engine Init. + self.transformer_engine_v_0_10 = False + self.transformer_engine_v_0_11 = False + self.transformer_engine_v_0_8 = False + if self.transformer_impl == 'transformer_engine': + global transformer_engine + import transformer_engine + from importlib.metadata import version + from pkg_resources import packaging + + te_version = packaging.version.Version(version("transformer-engine")) + if te_version >= packaging.version.Version("0.8.0"): + self.transformer_engine_v_0_8 = True + if te_version >= packaging.version.Version("0.10.0"): + self.transformer_engine_v_0_10 = True + if te_version >= packaging.version.Version("0.11.0"): + self.transformer_engine_v_0_11 = True + + del version, packaging + + assert not args.squared_relu, "TransformerEngine does not support squared relu activation." + + self.use_fp8 = args.fp8 is not None + self.fp8_recipe = None + self.fp8_group = None + if self.use_fp8: + assert args.transformer_impl == 'transformer_engine', \ + 'transformer-engine required for fp8 training and inference' + self.fp8_group = mpu.get_amax_reduction_group() + if args.fp8 == "e4m3": + fp8_format = transformer_engine.common.recipe.Format.E4M3 + elif args.fp8 == "hybrid": + fp8_format = transformer_engine.common.recipe.Format.HYBRID + else: + raise ValueError("The DelayedScaling recipe only supports E4M3 and HYBRID formats.") + self.fp8_recipe = transformer_engine.common.recipe.DelayedScaling( + margin=args.fp8_margin, + interval=args.fp8_interval, + fp8_format=fp8_format, + amax_history_len=args.fp8_amax_history_len, + amax_compute_algo=args.fp8_amax_compute_algo, + override_linear_precision=(False, False, not args.fp8_wgrad), + ) + + self.num_microbatches_in_previous_step = -1 + self.microbatch_count = 0 + self.checkpoint_core_attention = config.recompute_granularity == 'selective' + + # Number of layers. + self.num_layers = _get_num_layers(args, model_type, + layer_type==LayerType.decoder) + + self.drop_path_rates = [ + rate.item() for rate in + torch.linspace(0, self.drop_path_rate, config.num_layers)] + + self.retro_layer_numbers = None + if model_type == ModelType.retro_decoder: + retro_layer_start = 6 if config.num_layers <= 15 else 9 + self.retro_layer_numbers = \ + np.arange(retro_layer_start, args.num_layers + 1, 3).tolist() + if model_type == ModelType.retro_encoder: + self.retro_layer_numbers = [1] + + # Transformer layers. + if args.retro_add_retriever: + assert self.recompute_granularity != 'full', \ + "Full recompute not supported for Retro." + assert args.transformer_impl == 'local', \ + "Transformer engine does not support Retro layers." + def build_layer(layer_number): + if args.transformer_impl == 'local': + current_layer_type = _get_layer_type( + model_type, layer_type, self.retro_layer_numbers, + layer_number) + return ParallelTransformerLayer( + config, + layer_number, + layer_type=current_layer_type, + self_attn_mask_type=self_attn_mask_type, + drop_path_rate=self.drop_path_rates[layer_number - 1]) + else: + # This argument is only available from TE v0.10 onwards. + extra_transformer_engine_kwargs = {} + if self.transformer_engine_v_0_8: + extra_transformer_engine_kwargs["bias"] = args.add_bias_linear + if self.transformer_engine_v_0_10: + extra_transformer_engine_kwargs["activation"] = "swiglu" if args.swiglu else "gelu" + if self.transformer_engine_v_0_11: + extra_transformer_engine_kwargs["normalization"] = args.normalization + return transformer_engine.pytorch.TransformerLayer( + config.hidden_size, + config.ffn_hidden_size, + config.num_attention_heads, + layernorm_epsilon=config.layernorm_epsilon, + hidden_dropout=config.hidden_dropout, + attention_dropout=config.attention_dropout, + init_method=config.init_method, + output_layer_init_method=config.output_layer_init_method, + layer_number=layer_number, + kv_channels=config.kv_channels, + self_attn_mask_type=self_attn_mask_type.name, + tp_group=mpu.get_tensor_model_parallel_group(), + get_rng_state_tracker=tensor_parallel.get_cuda_rng_tracker, + fuse_wgrad_accumulation=config.gradient_accumulation_fusion, + apply_query_key_layer_scaling=config.apply_query_key_layer_scaling, + attention_softmax_in_fp32=config.attention_softmax_in_fp32, + seq_length=args.seq_length, + micro_batch_size=args.micro_batch_size, + sequence_parallel=config.sequence_parallel, + params_dtype=config.params_dtype, + apply_residual_connection_post_layernorm=config.apply_residual_connection_post_layernorm, + output_layernorm=False, + layer_type="encoder", + drop_path_rate=self.drop_path_rates[layer_number - 1], + set_parallel_mode=True, + fuse_qkv_params=True, + **extra_transformer_engine_kwargs) + + if config.virtual_pipeline_model_parallel_size is not None: + assert config.num_layers % config.virtual_pipeline_model_parallel_size == 0, \ + 'num_layers_per_stage must be divisible by ' \ + 'virtual_pipeline_model_parallel_size' + assert args.model_type != ModelType.encoder_and_decoder + # Number of layers in each model chunk is the number of layers in the stage, + # divided by the number of model chunks in a stage. + self.num_layers = self.num_layers // config.virtual_pipeline_model_parallel_size + # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of + # layers to stages like (each list is a model chunk): + # Stage 0: [0] [2] [4] [6] + # Stage 1: [1] [3] [5] [7] + # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of + # layers to stages like (each list is a model chunk): + # Stage 0: [0, 1] [4, 5] + # Stage 1: [2, 3] [6, 7] + offset = mpu.get_virtual_pipeline_model_parallel_rank() * ( + config.num_layers // config.virtual_pipeline_model_parallel_size) + \ + (mpu.get_pipeline_model_parallel_rank() * self.num_layers) + else: + # Each stage gets a contiguous set of layers. + if args.model_type == ModelType.encoder_and_decoder and \ + mpu.get_pipeline_model_parallel_world_size() > 1: + pipeline_rank = mpu.get_pipeline_model_parallel_rank() + if layer_type == LayerType.encoder: + offset = pipeline_rank * self.num_layers + else: + num_ranks_in_enc = args.pipeline_model_parallel_split_rank + offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers + else: + offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers + + if self.num_layers == 0: + # When a standalone embedding stage is used (e.g., + # args.standalone_embedding_stage == True), virtual pipeline ranks + # on pipeline rank 0 will have zero transformer layers assigned to + # them. This results in the model's input and output tensors to be + # the same, which will cause failure for certain output tensor + # optimizations (e.g., pipeline output deallocation). To remedy + # this, we assign a 'no-op' layer on these ranks, which will + # disconnect the input tensor from the output tensor. + self.num_layers = 1 + self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ]) + else: + self.layers = torch.nn.ModuleList( + [build_layer(i + 1 + offset) for i in range(self.num_layers)]) + + # Update dropout rate for Retro encoder. + if model_type == ModelType.retro_encoder: + for layer in self.layers: + if layer.self_attention.use_flash_attn: + layer.self_attention.core_attention_flash.dropout_p = \ + torch.nn.Dropout(args.retro_encoder_attention_dropout) + else: + layer.self_attention.core_attention.attention_dropout.p =\ + args.retro_encoder_attention_dropout + layer.hidden_dropout = args.retro_encoder_hidden_dropout + + if self.post_process and self.post_norm: + # Final layer norm before output. + self.final_norm = get_norm(config) + + def _get_layer(self, layer_number): + return self.layers[layer_number] + + def _checkpointed_forward(self, hidden_states, attention_mask, + encoder_output, enc_dec_attn_mask, + rotary_pos_emb, position_ids, is_first_microbatch): + """Forward method with activation checkpointing.""" + def custom(start, end): + def custom_forward(*args, **kwargs): + x_, *args = args + for index in range(start, end): + layer = self._get_layer(index) + x_ = layer(x_, *args, **kwargs) + return x_ + return custom_forward + + te_forward_kwargs = {} + if self.transformer_impl == 'transformer_engine': + te_forward_kwargs['is_first_microbatch'] = is_first_microbatch + if self.transformer_engine_v_0_10: + te_forward_kwargs['rotary_pos_emb'] = rotary_pos_emb + + if self.recompute_method == 'uniform': + # Uniformly divide the total number of Transformer layers and + # checkpoint the input activation of each divided chunk. + # A method to further reduce memory usage reducing checkpoints. + l = 0 + while l < self.num_layers: + if self.transformer_impl == 'transformer_engine': + hidden_states = transformer_engine.pytorch.checkpoint( + custom(l, l + self.recompute_num_layers), + self.distribute_saved_activations, + tensor_parallel.get_cuda_rng_tracker, + mpu.get_tensor_model_parallel_group(), + hidden_states, attention_mask, encoder_output, + enc_dec_attn_mask, **te_forward_kwargs) + else: + hidden_states = tensor_parallel.checkpoint( + custom(l, l + self.recompute_num_layers), + self.distribute_saved_activations, + hidden_states, attention_mask, + encoder_output, enc_dec_attn_mask, + None, None, None, None, rotary_pos_emb, position_ids) + + l += self.recompute_num_layers + + elif self.recompute_method == 'block': + # Checkpoint the input activation of only a set number of individual + # Transformer layers and skip the rest. + # A method fully use the device memory removing redundant re-computation. + for l in range(self.num_layers): + if l < self.recompute_num_layers: + if self.transformer_impl == 'transformer_engine': + hidden_states = transformer_engine.pytorch.checkpoint( + custom(l, l + 1), + self.distribute_saved_activations, + tensor_parallel.get_cuda_rng_tracker, + mpu.get_tensor_model_parallel_group(), + hidden_states, attention_mask, encoder_output, + enc_dec_attn_mask, **te_forward_kwargs) + else: + hidden_states = tensor_parallel.checkpoint( + custom(l, l + 1), + self.distribute_saved_activations, + hidden_states, attention_mask, + encoder_output, enc_dec_attn_mask, + None, None, None, None, rotary_pos_emb, position_ids) + else: + if self.transformer_impl == 'transformer_engine': + hidden_states = custom(l, l + 1)( + hidden_states, attention_mask, encoder_output, + enc_dec_attn_mask, **te_forward_kwargs) + else: + hidden_states = custom(l, l + 1)( + hidden_states, attention_mask, + encoder_output, enc_dec_attn_mask, + None, None, None, None, rotary_pos_emb, position_ids) + else: + raise ValueError("Invalid activation recompute method.") + + return hidden_states + + def set_input_tensor(self, input_tensor): + """Set input tensor to be used instead of forward()'s input. + + When doing pipeline parallelism the input from the previous + stage comes from communication, not from the input, so the + model's forward_step_func won't have it. This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + self.input_tensor = input_tensor + + def forward(self, hidden_states, attention_mask, + encoder_output=None, enc_dec_attn_mask=None, + retriever_input=None, + retriever_output=None, + retriever_attn_mask=None, + inference_params=None, + rotary_pos_emb=None, + position_ids=None): + # hidden_states: [s, b, h] + + # Checks. + if inference_params: + assert self.recompute_granularity is None, \ + 'inference does not work with activation checkpointing' + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Viewless tensor. + # - We only need to create a viewless tensor in the case of micro batch + # size (mbs) == 1, since in this case, 'hidden_states.transpose()' + # above creates a view tensor, and '.contiguous()' is a pass-through. + # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating + # the need to make it viewless. + # + # However, we don't explicitly check mbs == 1 here because + # make_viewless_tensor() has negligible overhead when its input + # is already viewless. + # + # - For the 'else' case above, calling make_viewless_tensor() here is + # likely redundant, since p2p_communication.py (likely originator) + # already creates viewless tensors. That said, make_viewless_tensor() + # is called here to be future-proof and corner-case-proof. + hidden_states = core.utils.make_viewless_tensor( + hidden_states, + requires_grad=True, + keep_graph=True, + ) + + # RNG context. + if self.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + # Forward layers. + with rng_context: + # The fp8_autocast context manager is a no-op when enabled=True + # The if...else serves to short circuit name resolution for fp8_autocast + with transformer_engine.pytorch.fp8_autocast( + enabled=self.use_fp8, + fp8_recipe=self.fp8_recipe, + fp8_group=self.fp8_group + ) if self.use_fp8 else nullcontext(): + # Determine if the current iteration is first microbatch + if self.num_microbatches_in_previous_step != get_num_microbatches(): + self.microbatch_count = 0 # Reset count on new batch size rampup interval + self.num_microbatches_in_previous_step = get_num_microbatches() + is_first_microbatch = self.microbatch_count % get_num_microbatches() == 0 + + # Forward pass. + if self.recompute_granularity == 'full': + hidden_states = self._checkpointed_forward(hidden_states, + attention_mask, + encoder_output, + enc_dec_attn_mask, + rotary_pos_emb, + position_ids, + is_first_microbatch) + else: + forward_kwargs = { + 'encoder_output': encoder_output, + 'enc_dec_attn_mask': enc_dec_attn_mask, + 'inference_params': inference_params, + } + + if self.transformer_impl == 'transformer_engine': + forward_kwargs['is_first_microbatch'] = is_first_microbatch + forward_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention + if self.transformer_engine_v_0_10: + forward_kwargs['rotary_pos_emb'] = rotary_pos_emb + else: + forward_kwargs['rotary_pos_emb'] = rotary_pos_emb + forward_kwargs['position_ids'] = position_ids + forward_kwargs['retriever_input'] = retriever_input + forward_kwargs['retriever_output'] = retriever_output + forward_kwargs['retriever_attn_mask'] = retriever_attn_mask + + for index in range(self.num_layers): + layer = self._get_layer(index) + + hidden_states = layer( + hidden_states, + attention_mask, + **forward_kwargs) + + # First Retro decoder layer returns both hidden_states + # and retriever_output. Make retriever_output available + # to subsequence Retro layers. + if isinstance(hidden_states, tuple): + assert len(hidden_states) == 2 + hidden_states, retriever_output = hidden_states + forward_kwargs["retriever_output"] = retriever_output + + # Skip counter update for eval and activation checkpointing + if torch.is_grad_enabled() and self.training: + self.microbatch_count += 1 + + # Final layer norm. + if self.post_process and self.post_norm: + hidden_states = self.final_norm(hidden_states) + + return hidden_states + + def load_state_dict(self, state_dict, strict=True): + """Customize load.""" + + # Handle renaming layernorm -> norm in component names + args = get_args() + state_dict_ = {} + for key in state_dict.keys(): + if args.transformer_impl != "transformer_engine": + newkey = key.replace("layernorm", "norm") + state_dict_[newkey] = state_dict[key] + else: + state_dict_[key] = state_dict[key] + + if args.use_llama2_rotary_position_embeddings: + super().load_state_dict(state_dict_, strict) + else: + super().load_state_dict(state_dict_, False) diff --git a/megatron_patch/model/qwen_vl/visual.py b/megatron_patch/model/qwen_vl/visual.py new file mode 100644 index 00000000..83db40e5 --- /dev/null +++ b/megatron_patch/model/qwen_vl/visual.py @@ -0,0 +1,425 @@ +# Copyright (c) Alibaba Cloud. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from collections import OrderedDict +import math +import requests +from functools import partial +from PIL import Image +from typing import Callable, Optional, List +import numpy as np + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.init import trunc_normal_ +from torchvision import transforms +from torchvision.transforms import InterpolationMode + + +def get_abs_pos(abs_pos, tgt_size): + # abs_pos: L, C + # tgt_size: M + # return: M, C + src_size = int(math.sqrt(abs_pos.size(0))) + tgt_size = int(math.sqrt(tgt_size)) + dtype = abs_pos.dtype + + if src_size != tgt_size: + return F.interpolate( + abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), + size=(tgt_size, tgt_size), + mode="bicubic", + align_corners=False, + ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype) + else: + return abs_pos + +# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +class Resampler(nn.Module): + """ + A 2D perceiver-resampler network with one cross attention layers by + (grid_size**2) learnable queries and 2d sincos pos_emb + Outputs: + A tensor with the shape of (grid_size**2, embed_dim) + """ + def __init__( + self, + grid_size, + embed_dim, + num_heads, + kv_dim=None, + norm_layer=nn.LayerNorm + ): + super().__init__() + self.num_queries = grid_size ** 2 + self.embed_dim = embed_dim + self.num_heads = num_heads + + self.pos_embed = nn.Parameter( + torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).float() + ).requires_grad_(False) + + self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim)) + trunc_normal_(self.query, std=.02) + + if kv_dim is not None and kv_dim != embed_dim: + self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False) + else: + self.kv_proj = nn.Identity() + + self.attn = nn.MultiheadAttention(embed_dim, num_heads) + self.ln_q = norm_layer(embed_dim) + self.ln_kv = norm_layer(embed_dim) + + # self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x, attn_mask=None): + + pos_embed = get_abs_pos(self.pos_embed, x.size(1)) + + x = self.kv_proj(x) + x = self.ln_kv(x).permute(1, 0, 2) + + N = x.shape[1] + q = self.ln_q(self.query) + out = self.attn( + self._repeat(q, N) + self.pos_embed.unsqueeze(1), + x + pos_embed.unsqueeze(1), + x, + attn_mask=attn_mask)[0] + return out.permute(1, 0, 2) + + def _repeat(self, query, N: int): + return query.unsqueeze(1).repeat(1, N, 1) + + +class VisualAttention(nn.Module): + """self-attention layer class. + + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__(self, embed_dim, num_heads, + bias=True, kdim=None, vdim=None): + super(VisualAttention, self).__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + + # Per attention head and per partition values. + assert embed_dim % num_heads == 0 + self.hidden_size_per_attention_head = embed_dim // num_heads + self.num_attention_heads_per_partition = num_heads + self.hidden_size_per_partition = embed_dim + + # Strided linear layer. + assert self._qkv_same_embed_dim, 'Only Support SelfAttention Currently' + self.in_proj = nn.Linear(embed_dim, 3 * embed_dim) + self.out_proj = nn.Linear(embed_dim, embed_dim) + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + + def forward(self, query, key, value, attn_mask = None): + # query/key/value: [sq, b, h] + sq, b, _ = query.size() + + assert torch.allclose(query, key), 'Only Support Self-Attention Currently' + sk = sq + mixed_x_layer = self.in_proj(query) + + # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] + new_tensor_shape = mixed_x_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + query_layer, key_layer, value_layer = mixed_x_layer.split( + self.hidden_size_per_attention_head, dim=-1) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view(sq, + b * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head).transpose(0, 1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(sk, + b * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head).transpose(0, 1) + + q_scaled = query_layer / self.norm_factor + if attn_mask is not None: + attention_probs = torch.baddbmm(attn_mask, q_scaled, key_layer.transpose(-2, -1)) + else: + attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1)) + attention_probs = attention_probs.softmax(dim=-1) + + value_layer = value_layer.view(sk, + b * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head).transpose(0, 1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(b, + self.num_attention_heads_per_partition, + sq, self.hidden_size_per_attention_head) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + output = self.out_proj(context_layer) + + return output + + +class VisualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + is_cross_attention: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + if is_cross_attention: + self.ln_1_kv = norm_layer(d_model) + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.attn = VisualAttention(d_model, n_head) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + + def attention( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = k_x if k_x is not None else q_x + v_x = v_x if v_x is not None else q_x + + attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None + return self.attn(q_x, k_x, v_x, attn_mask=attn_mask) + + def forward( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None + v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None + + x = q_x + self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask) + x = x + self.mlp(self.ln_2(x)) + return x + + +class TransformerBlock(nn.Module): + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + ): + super().__init__() + self.width = width + self.layers = layers + + self.resblocks = nn.ModuleList([ + VisualAttentionBlock( + width, heads, mlp_ratio, act_layer=act_layer, norm_layer=norm_layer) + for _ in range(layers) + ]) + + def get_cast_dtype(self) -> torch.dtype: + return self.resblocks[0].mlp.c_fc.weight.dtype + + def get_cast_device(self) -> torch.device: + return self.resblocks[0].mlp.c_fc.weight.device + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + for r in self.resblocks: + x = r(x, attn_mask=attn_mask) + return x + + +class VisionTransformer(nn.Module): + + def __init__( + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + n_queries: int = 256, + output_dim: int = 512, + **kwargs + ): + super().__init__() + image_height, image_width = self.image_size = (image_size, image_size) + patch_height, patch_width = self.patch_size = (patch_size, patch_size) + self.grid_size = (image_height // patch_height, image_width // patch_width) + self.output_dim = output_dim + + mean = (0.48145466, 0.4578275, 0.40821073) + std = (0.26862954, 0.26130258, 0.27577711) + self.image_transform = transforms.Compose([ + transforms.Resize( + (image_size, image_size), + interpolation=InterpolationMode.BICUBIC + ), + transforms.ToTensor(), + transforms.Normalize(mean=mean, std=std), + ]) + + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + # class embeddings and positional embeddings + scale = width ** -0.5 + self.positional_embedding = nn.Parameter(scale * torch.randn(256, width)) + + norm_layer = partial(nn.LayerNorm, eps=1e-6) + act_layer = nn.GELU + + self.ln_pre = norm_layer(width) + self.transformer = TransformerBlock( + width, + layers, + heads, + mlp_ratio, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + self.attn_pool = Resampler( + grid_size=int(math.sqrt(n_queries)), + embed_dim=output_dim, + num_heads=output_dim // 128, + kv_dim=width, + norm_layer=norm_layer, + ) + self.ln_post = norm_layer(output_dim) + self.proj = nn.Parameter((output_dim** -0.5) * torch.randn(output_dim, output_dim)) + + def forward(self, x: torch.Tensor): + x = x.to( + dtype=self.transformer.get_cast_dtype(), + device=self.transformer.get_cast_device(), + ) + # to patches + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + x = x + get_abs_pos(self.positional_embedding, x.size(1)) + + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.attn_pool(x) + x = self.ln_post(x) + x = x @ self.proj + + return x + + def encode(self, image_paths: List[str]): + images = [] + for image_path in image_paths: + if image_path.startswith("http://") or image_path.startswith("https://"): + image = Image.open(requests.get(image_path, stream=True).raw) + else: + image = Image.open(image_path) + image = image.convert("RGB") + images.append(self.image_transform(image)) + images = torch.stack(images, dim=0) + return self(images) diff --git a/megatron_patch/optimizer/__init__.py b/megatron_patch/optimizer/__init__.py new file mode 100644 index 00000000..34782079 --- /dev/null +++ b/megatron_patch/optimizer/__init__.py @@ -0,0 +1,93 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +from apex.optimizers import FusedAdam as Adam +from apex.optimizers import FusedSGD as SGD + +from megatron import get_args +from megatron.optimizer.grad_scaler import ConstantGradScaler, DynamicGradScaler +from megatron.optimizer.optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer +from megatron.optimizer import get_param_groups + +from .distrib_optimizer import DistributedOptimizer + + +def get_megatron_optimizer(model, + no_weight_decay_cond=None, + scale_lr_cond=None, + lr_mult=1.0): + args = get_args() + + # Base optimizer. + param_groups = get_param_groups(model, + no_weight_decay_cond, + scale_lr_cond, + lr_mult) + + if args.optimizer == 'adam': + optimizer = Adam(param_groups, + lr=args.lr, + weight_decay=args.weight_decay, + betas=(args.adam_beta1, args.adam_beta2), + eps=args.adam_eps) + elif args.optimizer == 'sgd': + optimizer = SGD(param_groups, + lr=args.lr, + weight_decay=args.weight_decay, + momentum=args.sgd_momentum) + else: + raise Exception('{} optimizer is not supported.'.format( + args.optimizer)) + + # Determine whether the params have main-grad field. + params_have_main_grad = True + + # Mixed precision optimizer. + # - Note: both the Float16Optimizer and the DistributedOptimizer inherit + # from the MixedPrecisionOptimizer, which manages any optimizer where + # the model params and main params are distinct. + if args.fp16 or args.bf16 or args.use_distributed_optimizer: + + # Grad scaler: + # if loss-scale is provided, instantiate the constant scaler. + # if we are using fp16 and loss-scale is not present, use a + # dynamic scaler. + # otherwise we are running in bf16 with no loss-scale so + # leave it as None. + grad_scaler = None + + # Constant loss scale. + if args.loss_scale: + grad_scaler = ConstantGradScaler(args.loss_scale) + + # Dynamic loss scale. + else: + if args.fp16: + grad_scaler = DynamicGradScaler( + initial_scale=args.initial_loss_scale, + min_scale=args.min_loss_scale, + growth_factor=2.0, + backoff_factor=0.5, + growth_interval=args.loss_scale_window, + hysteresis=args.hysteresis) + + # Megatron optimizer. + opt_ty = DistributedOptimizer \ + if args.use_distributed_optimizer else \ + Float16OptimizerWithFloat16Params + return opt_ty(optimizer, + args.clip_grad, + args.log_num_zeros_in_grad, + args.check_for_nan_in_loss_and_grad, + params_have_main_grad, + args.fp16, + args.bf16, + args.params_dtype, + grad_scaler, + model) + + # FP32. + return FP32Optimizer(optimizer, args.clip_grad, + args.log_num_zeros_in_grad, + args.check_for_nan_in_loss_and_grad, + params_have_main_grad, + model) diff --git a/megatron_patch/optimizer/distrib_optimizer.py b/megatron_patch/optimizer/distrib_optimizer.py new file mode 100644 index 00000000..25a6a4c1 --- /dev/null +++ b/megatron_patch/optimizer/distrib_optimizer.py @@ -0,0 +1,707 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""Megatron distributed optimizer.""" + +import torch + +from megatron import print_rank_0 +from megatron.core import mpu, tensor_parallel +from megatron.optimizer.distrib_optimizer import Range + +from megatron.optimizer.optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper +from megatron_patch.distributed import is_moe_param + + +def shard_buffer(buffer): + """ + Shard buffer into dp_size chunks of equal size. + """ + data_parallel_world_size = mpu.get_data_parallel_world_size() + assert buffer.numel() % data_parallel_world_size == 0, "{}, {}".format(buffer.numel(), data_parallel_world_size) + shard_size = buffer.numel() // data_parallel_world_size + sharded_buffer = [buffer[(r*shard_size):((r+1)*shard_size)] + for r in range(data_parallel_world_size)] + return sharded_buffer + +class DistributedOptimizer(MixedPrecisionOptimizer): + """Distributed optimizer, for all data types (fp16, bf16, and fp32). + + Arguments: + optimizer: base optimizer such as Adam or SGD + clip_grad: clip gradeints with this global L2 norm. Note + that clipping is ignored if clip_grad == 0 + log_num_zeros_in_grad: return number of zeros in the gradients. + params_have_main_grad: flag indicating if parameters have + a `main_grad` field. If this is set, we are assuming + that the model parameters are store in the `main_grad` + field instead of the typical `grad` field. This happens + for the DDP cases where there is a continuous buffer + holding the gradients. For example for bfloat16, we want + to do gradient accumulation and all-reduces in float32 + and as a result we store those gradients in the main_grad. + Note that main grad is not necessarily in float32. + use_contiguous_buffers_in_local_ddp: if true, the local DDP model + is using a contiguous buffer to hold the model grads. + fp16: if true, the model is running in fp16. + bf16: if true, the model is running in bfloat16. + grad_scaler: used for scaling gradients. Note that this can be + None. This case happens when `bf16 = True` and we don't + use any loss scale. Note that for `bf16 = True`, we can have + a constnat gradient scaler. Also for `bf16 = False`, we + always require a grad scaler. + models: list of models (i.e., the virtual pipelining models). This + is used by the distributed optimizer for mapping parameters. + """ + + @classmethod + def build_model_gbuf_param_range_map(cls, model, dtype, gbuf_world_range, bucket_offset): + """ + Build mapping from param reference to grad buffer shard ranges. + + This method builds a mapping from parameter references to grad + buffer shard ranges, specific to each data-parallel (DP) rank's + set of 'owned' parameters. Each grad buffer (padded to be an even + multiple of DP-world-size) is conceptually divided into DP-world-size + contiguous regions, where each DP rank 'owns' a contiguous regions. + Ownership in this sense means DP rank is responsible for reducing + the relevant subset of grads, and updating the relevant subset of + params. + + This conceptual partitioning of the grad buffer does NOT respect + parameter boundaries, and as such it is assumed that each created + range references a shard (or subset) of the full parameter. It is + easiest to think of each DP rank as operating (i.e., reducing, + gathering) purely on views into the grad buffer, for all model-to- + main & main-to-model operations. + + This method creates four ranges: + - The param's range within the entire grad buffer (i.e., world index). + - The param's range within the relevant grad bucket's buffer. + - The param's range within the DP rank's local view of the grad buffer. + - The param's range within itself (i.e., its shard). + """ + + # Param range map. + param_world_index_map = model.grad_buffer_param_index_map[dtype] + param_range_map = {} + for param, param_world_indexes in param_world_index_map.items(): + + # Param range. + param_world_start, param_world_end, _ = param_world_indexes + param_local_start = max( + 0, + param_world_start - gbuf_world_range.start) + param_local_end = min( + gbuf_world_range.size, + param_world_end - gbuf_world_range.start) + + # Add param, if within local gbuf range. + if param_local_end > param_local_start: + param_local_range = Range(param_local_start, param_local_end) + param_world_range = param_local_range.normalize( + param_local_start + gbuf_world_range.start) + param_world_range_in_bucket = Range(param_world_range.start-bucket_offset, + param_world_range.end-bucket_offset) + sub_param_start = max(0, gbuf_world_range.start-param_world_start) + sub_param_range = param_local_range.normalize(sub_param_start) + param_range_map[param] = { + "gbuf_world" : param_world_range, + "gbuf_world_in_bucket": param_world_range_in_bucket, + "gbuf_local" : param_local_range, + "param" : sub_param_range, + } + + return param_range_map + + + @classmethod + def build_model_gbuf_range(cls, model, dtype, bucket_index): + """ + Build mapping between params and their grad buffers. + + This method does the initial setup for the method above. This setup + includes determining the shard ranges into the DDP's grad buffer for + each data-parallel (DP) rank. Each DP rank keeps range info for + all other DP ranks, for the purpose of creating args for + reduce-scatter and all-gather. + """ + + data_parallel_rank = mpu.get_data_parallel_rank() + data_parallel_world_size = mpu.get_data_parallel_world_size() + + bucket = model.grad_buffers[dtype].buckets[bucket_index] + bucket_buffer = bucket.data + gbuf_size = bucket_buffer.numel() + assert gbuf_size % data_parallel_world_size == 0, \ + f"Each bucket's buffer size should be divisible by {data_parallel_world_size}" + max_gbuf_range_size = gbuf_size // data_parallel_world_size + + # All world ranges (i.e., across all data parallel ranks). + gbuf_world_all_ranges = [] + for r in range(data_parallel_world_size): + # Compute start of chunk in this bucket. + gbuf_world_start = r * max_gbuf_range_size + gbuf_world_end = min(gbuf_size, gbuf_world_start+max_gbuf_range_size) + # Add bucket's offset in grad buffer. + gbuf_world_range = Range(gbuf_world_start + bucket.offset, + gbuf_world_end + bucket.offset) + gbuf_world_all_ranges.append(gbuf_world_range) + + # Local DP's ranges. + gbuf_world_range = gbuf_world_all_ranges[data_parallel_rank] + + # Get each param's ranges. + param_range_map = cls.build_model_gbuf_param_range_map(model, + dtype, + gbuf_world_range, + bucket.offset) + + # Group into dict. + data = { + "param_map" : param_range_map, + } + + return data + @classmethod + def build_model_gbuf_range_map(cls, model): + """ + Create param-to-grad-buffer mappings, for grad buffer data types + within a specific virtual model. + """ + # Iterate through all buckets to construct param ranges that this rank "owns" + # (the dp_rank'th shard of each bucket, where each shard is 1/dp_world_size + # of the bucket). + return { + dtype : [cls.build_model_gbuf_range(model, dtype, bucket_index) + for bucket_index in range(len(model.grad_buffers[dtype].buckets))] + for dtype in model.grad_buffers + } + + @classmethod + def build_model_param_gbuf_map(cls, model_gbuf_ranges): + """ + Create a reverse of the model_gbuf_ranges, for referencing in + opposite direction. + """ + param_gbuf_map = {} + for model_index, model_gbuf_range_map in enumerate(model_gbuf_ranges): + for dtype, gbuf_range_map_for_all_buckets in model_gbuf_range_map.items(): + for bucket_index, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets): + for param, _ in gbuf_range_map["param_map"].items(): + assert param not in param_gbuf_map, \ + "Param should not be in param_gbuf_map; each param only belongs to a single bucket" + param_gbuf_map[param] = (model_index, dtype, bucket_index) + return param_gbuf_map + + @classmethod + def build_optimizer_group_ranges(cls, param_groups, model_gbuf_ranges, model_expert_params_list): + """ + Create optimizer groups. + + Given the set of parameter shard ranges that are owned by the current + data-parallel (DP) rank, gather the set of parameters that will be + used (in the method below) to create the current DP's optimizer + groups. + """ + + num_groups = len(param_groups) + + # Param group map. + # World param group map. + # - Store a mapping of for all parameters + # across all DP ranks. This is necessary because it is our first + # cross reference between the DDP mappings and the optimizer group + # parameters. This mapping only for use in the next step of building + # the local mapping over this DP rank's parameters. + world_param_group_map = {} + for group_index, group in enumerate(param_groups): + for param in group["params"]: + assert param.requires_grad + world_param_group_map[param] = group_index + + # Optimizer group ranges & param-group mapping. + # - Build a mapping from groups to their contained parameters, and also + # from parameters to their containing group index and order within + # the group. The group index and order are particularly important for + # saving and loading checkpoints. + local_param_group_map = {} + group_ranges = [ {"params": []} for _ in param_groups ] + for model_gbuf_range_map in model_gbuf_ranges: + for dtype, gbuf_range_map_for_all_buckets in model_gbuf_range_map.items(): + for gbuf_range_map in gbuf_range_map_for_all_buckets: + for param in gbuf_range_map["param_map"]: + group_index = world_param_group_map[param] + group_range = group_ranges[group_index] + group_range["params"].append(param) + local_param_group_map[param] = \ + (group_index, len(group_range["params"]) - 1) + + # Add expert params into optimizer group + for expert_params in model_expert_params_list: + for expert_param in expert_params: + group_range["params"].append(expert_param) + + # Squeeze zero-size group ranges. + for group_index, group_range in enumerate(group_ranges): + group_range["orig_group"] = param_groups[group_index] + group_range["orig_group_idx"] = param_groups[group_index] + + return local_param_group_map, group_ranges + + @classmethod + def build_model_and_main_param_groups(cls, + model_gbuf_ranges, + param_gbuf_map, + opt_group_ranges): + """ + Create main parameter groups needed for the optimizer step. + + These groups encompass both: 1) groups used by this class, for + reducing/gather, and 2) groups used by the inner optimizer for the + parameter update. Given that the conceptual grad buffer partitioning + (created in earlier method) doesn't respect parameter boundaries, + the optimizer operates on shards of the model parameters, rather than + the full parameters. + """ + + # Parameter groups: + # model_float16_groups: original float16 parameters + # model_fp32_groups: original fp32 parameters + # shard_float16_groups: shards of original float16 parameters + # shard_fp32_groups: shards of original fp32 parameters + # shard_fp32_from_float16_groups: fp32 copy of float16 parameters + model_float16_groups = [] + model_fp32_groups = [] + shard_float16_groups = [] + shard_fp32_groups = [] + shard_fp32_from_float16_groups = [] + + # Allocate (or slice) each group's param shard. + for group_index, group_range in enumerate(opt_group_ranges): + + # Params of this group. + model_float16_params_this_group = [] + model_fp32_params_this_group = [] + shard_float16_params_this_group = [] + shard_fp32_params_this_group = [] + shard_fp32_from_float16_params_this_group = [] + model_float16_groups.append(model_float16_params_this_group) + model_fp32_groups.append(model_fp32_params_this_group) + shard_float16_groups.append(shard_float16_params_this_group) + shard_fp32_groups.append(shard_fp32_params_this_group) + shard_fp32_from_float16_groups.append( + shard_fp32_from_float16_params_this_group) + + for model_param in group_range["params"]: + + assert model_param.requires_grad + + if is_moe_param(model_param): + # Each DP rank holds different experts with whole shape. + param_range = Range(0, model_param.numel()) + else: + model_index, dtype, bucket_index = param_gbuf_map[model_param] + gbuf_range = model_gbuf_ranges[model_index][dtype][bucket_index] + param_range = gbuf_range["param_map"][model_param]["param"] + + # fp16, bf16 params. + if model_param.type() in ['torch.cuda.HalfTensor', + 'torch.cuda.BFloat16Tensor']: + + # Clone model -> main. + shard_model_param = model_param.detach().view(-1) \ + [param_range.start:param_range.end] + shard_main_param = shard_model_param.clone().float() + tensor_parallel.copy_tensor_model_parallel_attributes( + shard_model_param, model_param) + tensor_parallel.copy_tensor_model_parallel_attributes( + shard_main_param, model_param) + if hasattr(model_param, 'shared'): + shard_model_param.shared = model_param.shared + shard_main_param.shared = model_param.shared + + # Add to group. + model_float16_params_this_group.append(model_param) + shard_float16_params_this_group.append(shard_model_param) + shard_fp32_from_float16_params_this_group.append(shard_main_param) + + # fp32 params. + elif model_param.type() == 'torch.cuda.FloatTensor': + shard_model_param = model_param.view(-1) \ + [param_range.start:param_range.end] + model_fp32_params_this_group.append(model_param) + shard_fp32_params_this_group.append(shard_model_param) + tensor_parallel.copy_tensor_model_parallel_attributes( + shard_model_param, model_param) + if hasattr(model_param, 'shared'): + shard_model_param.shared = model_param.shared + + else: + raise TypeError('Wrapped parameters must be one of ' + 'torch.cuda.FloatTensor, ' + 'torch.cuda.HalfTensor, or ' + 'torch.cuda.BFloat16Tensor. ' + 'Received {}'.format(model_param.type())) + + # Update optimizer's params. + group_range["orig_group"]["params"] = [ + *shard_fp32_params_this_group, + *shard_fp32_from_float16_params_this_group, + ] + + return ( + model_float16_groups, + model_fp32_groups, + shard_float16_groups, + shard_fp32_groups, + shard_fp32_from_float16_groups, + ) + + def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad, + params_have_main_grad, use_contiguous_buffers_in_local_ddp, + fp16, bf16, params_dtype, grad_scaler, models): + """ + See top of class definition for argument descriptions. + + The steps in this method create the core mapping between DDP grad + buffers, parameters, and parameter shard ranges, that is needed for + converting between model param indexes and main parameter shard + indexes. This method also updates the optimizer parameter groups + with the newly created shards. + + For MoE params which are sharded across each DP rank by default, each + DP ranks holds full shape of each expert tensor. + + """ + + super().__init__( + optimizer, clip_grad, log_num_zeros_in_grad, + params_have_main_grad, use_contiguous_buffers_in_local_ddp, + fp16, bf16, params_dtype, grad_scaler, models) + + # Verify that contiguous buffers are being used. + # - Note: this should already be checked in arguments.py. + assert use_contiguous_buffers_in_local_ddp + + # Model grad buffer ranges. + self.model_gbuf_ranges = [] + self.bucket_sizes = [] + for model_index, model in enumerate(self.models): + self.bucket_sizes.append(model.bucket_size) + self.model_gbuf_ranges.append(self.build_model_gbuf_range_map(model)) + self.model_param_gbuf_map = \ + self.build_model_param_gbuf_map(self.model_gbuf_ranges) + + self.model_expert_params_list = [] + for _, model in enumerate(self.models): + self.model_expert_params_list.append(model.expert_params) + + # Optimizer ranges. + self.model_param_group_index_map, self.opt_group_ranges = \ + self.build_optimizer_group_ranges(self.optimizer.param_groups, + self.model_gbuf_ranges, + self.model_expert_params_list) + + # Allocate main param shards. + ( + self.model_float16_groups, + self.model_fp32_groups, + self.shard_float16_groups, + self.shard_fp32_groups, + self.shard_fp32_from_float16_groups, + ) = self.build_model_and_main_param_groups(self.model_gbuf_ranges, + self.model_param_gbuf_map, + self.opt_group_ranges) + + # Initialize param buffers. + # - These are views on the DDP model's grad buffers, that share + # storage & have their own dtype. This is safe because the param + # dtype size is always <= grad dtype size. + self.param_buffers = [] + for model_index, model in enumerate(self.models): + current_param_buffers = {} + for dtype, grad_buffer in model.grad_buffers.items(): + current_param_buffers[dtype] = [] + for bucket in grad_buffer.buckets: + + # Handle older/newer method for getting untyped storage. + try: + storage = bucket.data.storage()._untyped() + except: + storage = bucket.data.storage().untyped() + + # Typed param buffer. + param_buffer = torch.tensor( + storage, + dtype = params_dtype, + device = bucket.data.device) + # .storage() ignores views / slices, so param_buffer now points to the start + # of the grad_buffer instead of to the start of each bucket. As a result, + # add bucket.offset to make sure param_buffers don't point to the same region + # of memory. + param_buffer = param_buffer[bucket.offset:bucket.offset+bucket.data.numel()] + current_param_buffers[dtype].append(param_buffer) + self.param_buffers.append(current_param_buffers) + + # Update optimizer groups. + # - Also, leverage state_dict() and load_state_dict() to + # recast preexisting per-param state tensors. + self.optimizer.param_groups = \ + [ g["orig_group"] for g in self.opt_group_ranges ] + self.optimizer.load_state_dict(self.optimizer.state_dict()) + + + def get_model_param_range_map(self, param): + """ + Given a model param, get the index sub-range of the param that this + data-parallel rank owns. + """ + if is_moe_param(param): + # Each DP rank holds the *WHOLE* size of each expert's param. + param_range_map = {"param" : Range(0, param.numel())} + else: + model_index, dtype, bucket_index = self.model_param_gbuf_map[param] + gbuf_range_map = self.model_gbuf_ranges[model_index][dtype][bucket_index] + param_range_map = gbuf_range_map["param_map"][param] + return param_range_map + + + def get_model_parallel_group(self): + """ + With the distributed optimizer, the model parallel group is the + entire world. + """ + return None + + + def state_dict(self): + """ + The state dict must contain the fp32-from-float16 shards. + """ + state_dict = {} + state_dict['optimizer'] = self.optimizer.state_dict() + if self.grad_scaler: + state_dict['grad_scaler'] = self.grad_scaler.state_dict() + state_dict['shard_fp32_from_float16_groups'] = \ + self.shard_fp32_from_float16_groups + return state_dict + + + def load_state_dict(self, state_dict): + """ + Load the state dict. + """ + + # Optimizer. + optimizer_key = 'optimizer' + if optimizer_key not in state_dict: + optimizer_key = 'optimizer_state_dict' + print_rank_0('***WARNING*** loading optimizer from ' + 'an old checkpoint ...') + self.optimizer.load_state_dict(state_dict[optimizer_key]) + + # Grad scaler. + if 'grad_scaler' not in state_dict: + if self.fp16: + print_rank_0('***WARNING*** found an old checkpoint, will not ' + 'load grad scaler ...') + else: + if self.grad_scaler: + self.grad_scaler.load_state_dict(state_dict['grad_scaler']) + else: + print_rank_0('***WARNING*** fould the grad scaler in the ' + 'checkpoint but it is None in the class. ' + 'Skipping loading grad scaler ...') + + # Copy data for the main params. + for current_group, saved_group in zip( + self.shard_fp32_from_float16_groups, + state_dict["shard_fp32_from_float16_groups"]): + for current_param, saved_param in zip(current_group, saved_group): + current_param.data.copy_(saved_param.data) + + + def zero_grad(self, set_to_none=True): + """ + Zero grads. + + We only need to zero the model related parameters, i.e., + model_float16_groups & model_fp32_groups. We additionally zero + the remaining groups as a memory optimization to reduce + fragmentation; in the case of set_to_none==True, the space + used by this field can be safely deallocated at this point. + """ + for groups in ( + self.model_float16_groups, + self.model_fp32_groups, + self.shard_float16_groups, # grad empty/unused here? + self.shard_fp32_groups, # throws grad-access warning + self.shard_fp32_from_float16_groups): + for group in groups: + _zero_grad_group_helper(group, set_to_none) + + @staticmethod + def get_model_buffer_dp_views(model_buffers): + # Buffer views. + view_items = [] + for model_index, buffers in enumerate(model_buffers): + for dtype, buf_for_all_buckets in buffers.items(): + for bucket_index, buf in enumerate(buf_for_all_buckets): + buf_views = shard_buffer(buf) + view_items.append((model_index, dtype, bucket_index, buf, buf_views)) + return view_items + + def get_model_grad_buffer_dp_views(self): + return self.get_model_buffer_dp_views([ + {dtype : mem_buffer.data} + for model in self.models + for dtype, mem_buffer in model.grad_buffers.items()]) + + + def get_model_param_buffer_dp_views(self): + return self.get_model_buffer_dp_views(self.param_buffers) + + def gather_model_params(self, args, timers): + """ + All-gather updated model params. + + The DDP's param buffer is used for the all-gather, and thus no + tensors are dynamically allocated. After the all-gather, the params + can be copied from the param buffer to the param. + """ + + timers('params-all-gather', log_level=1).start( + barrier=args.barrier_with_L1_time) + + data_parallel_rank = mpu.get_data_parallel_rank() + data_parallel_group = mpu.get_data_parallel_group() + + # All-gather updated main params. + # - All param buffer views are guaranteed to have the same num elements + # across all data parallel ranks, due to grad buffer padding that is + # done in distributed.py, and extended to the param buffers. Thus, + # all sub-views will have consistent start/end indexes across data + # parallel ranks. + pbuf_view_items = self.get_model_param_buffer_dp_views() + for (_, _, _, pbuf, pbuf_views) in pbuf_view_items: + torch.distributed._all_gather_base( + pbuf, + pbuf_views[data_parallel_rank], + group = data_parallel_group, + ) + + # Copy from param buffer to each param. + for model_id, model in enumerate(self.models): + # For non-MoE parameters. + for dtype, param_map in model.grad_buffer_param_index_map.items(): + for param, (buf_start, buf_end, bucket_index) in param_map.items(): + bucket_offset = model.grad_buffers[dtype].buckets[bucket_index].offset + param_buf = self.param_buffers[model_id][dtype][bucket_index] + # buf_start and buf_end store position of this parameter in the full grad_buffer, + # so need to adjust these indices (by subtracting out bucket_offset) since we + # have independent param_bufs for each bucket. + param_buf_shard = param_buf[buf_start-bucket_offset:buf_end-bucket_offset] + assert param.data.nelement() == param_buf_shard.nelement() + param.view(-1).detach().copy_(param_buf_shard) + # For MoE parameters. + for expert_param in model.expert_params: + expert_param.detach().copy_(expert_param.main_grad) + timers('params-all-gather').stop() + + + def _collect_main_grad_data_for_unscaling(self): + """ + Note: this should be equivalent to the float-16 optimizer's method, + but writtent differently, so the two should be combined. + """ + return [ + param.grad.data + for group in self.optimizer.param_groups + for param in group["params"] + ] + + + def _get_model_and_main_params_data_float16(self): + """ + Get aligned list of model and main params. + """ + model_data = [] + main_data = [] + for model_group, main_group in zip(self.shard_float16_groups, + self.shard_fp32_from_float16_groups): + for model_param, main_param in zip(model_group, main_group): + model_data.append(model_param.data) + main_data.append(main_param.data) + return model_data, main_data + + + def _copy_model_grads_to_main_grads(self): + """ + Copy model grads to main grads. + + Since this step follows a reduce-scatter through the DDP's grad + buffer, this method is responsible for copying the updated grads + from the grad buffer to the main shard's grad field. + """ + + # Utility method for copying group grads. + def copy_group_grads(model_groups, shard_main_groups): + for model_group, shard_main_group in zip(model_groups, + shard_main_groups): + for model_param, shard_main_param in zip(model_group, + shard_main_group): + + param_range_map = self.get_model_param_range_map(model_param) + param_range = param_range_map["param"] + assert param_range.size == shard_main_param.nelement() + + model_grad = model_param.main_grad + shard_model_grad = model_grad.view(-1) \ + [param_range.start:param_range.end] + shard_main_param.grad = shard_model_grad.float() + + # Copy model groups to shard groups. + copy_group_grads(self.model_float16_groups, + self.shard_fp32_from_float16_groups) + copy_group_grads(self.model_fp32_groups, + self.shard_fp32_groups) + + + def _copy_main_params_to_model_params(self): + """ + Copy main params to model params. + + Since this step is followed by an all-gather through the DDP's grad + buffer, this method is responsible for copying the updated params + from the main shards into the correct position in the grad buffer. + """ + + # Utility method for copying group params. + def copy_group_params(shard_main_groups, model_groups): + for shard_main_group, model_group in zip(shard_main_groups, + model_groups): + for shard_main_param, model_param in zip(shard_main_group, + model_group): + + if is_moe_param(model_param): + assert model_param.numel() == shard_main_param.numel() + model_param.main_grad.data.view(-1).copy_(shard_main_param) + else: + param_range_map = self.get_model_param_range_map(model_param) + world_range = param_range_map["gbuf_world"] + + assert world_range.size == shard_main_param.nelement() + + model_id, dtype, bucket_id = self.model_param_gbuf_map[model_param] + model_param_buffer = self.param_buffers[model_id][dtype][bucket_id] + + shard_model_param = model_param_buffer.view(-1) \ + [world_range.start:world_range.end] + + shard_model_param.data.copy_(shard_main_param) + + # Copy shard groups to model groups. + copy_group_params(self.shard_fp32_from_float16_groups, + self.model_float16_groups) + copy_group_params(self.shard_fp32_groups, + self.model_fp32_groups) diff --git a/megatron_patch/tokenizer/__init__.py b/megatron_patch/tokenizer/__init__.py index b9b5981b..7dbb2b9b 100644 --- a/megatron_patch/tokenizer/__init__.py +++ b/megatron_patch/tokenizer/__init__.py @@ -187,11 +187,22 @@ def build_tokenizer(args): padding_side='right', use_fast=False, ) - #tokenizer.pad_token_id = tokenizer.eod_id tokenizer.pad_token_id = tokenizer.pad_id tokenizer.eos_token_id = tokenizer.eod_id args.padded_vocab_size = tokenizer.vocab_size + args.extra_vocab_size + elif args.patch_tokenizer_type == 'QwenVLTokenizer': + from .tokenization_qwen_vl import QWenTokenizer + tokenizer = QWenTokenizer.from_pretrained( + args.load, + model_max_length=args.seq_length, + padding_side="right", + use_fast=False, + trust_remote_code=False, + ) + tokenizer.pad_token_id = tokenizer.eod_id + args.padded_vocab_size = tokenizer.vocab_size + args.extra_vocab_size + elif args.patch_tokenizer_type == 'YiTokenizer': from .tokenization_yi import YiTokenizer if args.load is None: @@ -213,12 +224,13 @@ def build_tokenizer(args): args.padded_vocab_size = tokenizer.vocab_size + args.extra_vocab_size elif args.patch_tokenizer_type == 'MistralTokenizer': - from .tokenization_mistral import MistralTokenizer - tokenizer = MistralTokenizer(os.path.join(args.load, "tokenizer.model")) - tokenizer.pad_token_id = tokenizer.pad_id - tokenizer.eos_token_id = tokenizer.eos_id - tokenizer.eos_token = tokenizer.decode(tokenizer.eos_id) - args.padded_vocab_size = tokenizer.n_words + args.extra_vocab_size + print_rank_0('Using Mistral tokenizer.') + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(args.load, + padding_side='right', + use_fast=False,) + tokenizer.pad_token_id = 0 + args.padded_vocab_size = tokenizer.vocab_size + args.extra_vocab_size elif args.patch_tokenizer_type == 'BloomTokenizerFromCustom': print_rank_0('Using Customized Bloom tokenizer.') diff --git a/megatron_patch/tokenizer/tokenization_mistral.py b/megatron_patch/tokenizer/tokenization_mistral.py deleted file mode 100644 index 835d48a7..00000000 --- a/megatron_patch/tokenizer/tokenization_mistral.py +++ /dev/null @@ -1,36 +0,0 @@ -from pathlib import Path -from sentencepiece import SentencePieceProcessor -from typing import List - - -class MistralTokenizer: - def __init__(self, model_path: str): - assert Path(model_path).exists(), model_path - self._model = SentencePieceProcessor(model_file=model_path) - assert self._model.vocab_size() == self._model.get_piece_size() - - @property - def n_words(self) -> int: - return self._model.vocab_size() - - @property - def bos_id(self) -> int: - return self._model.bos_id() - - @property - def eos_id(self) -> int: - return self._model.eos_id() - - @property - def pad_id(self) -> int: - return self._model.pad_id() - - def encode(self, s: str, bos: bool = True) -> List[int]: - assert isinstance(s, str) - t = self._model.encode(s) - if bos: - t = [self.bos_id, *t] - return t - - def decode(self, t: List[int]) -> str: - return self._model.decode(t) diff --git a/megatron_patch/tokenizer/tokenization_qwen_vl.py b/megatron_patch/tokenizer/tokenization_qwen_vl.py new file mode 100644 index 00000000..ece277b6 --- /dev/null +++ b/megatron_patch/tokenizer/tokenization_qwen_vl.py @@ -0,0 +1,587 @@ +# Copyright (c) Alibaba Cloud. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Tokenization classes for QWen.""" + +import base64 +import logging +import os +import requests +import unicodedata +from typing import Collection, Dict, List, Set, Tuple, Union, Any, Callable, Optional + +import tiktoken +from PIL import Image +from transformers import PreTrainedTokenizer, AddedToken +from transformers.utils import try_to_load_from_cache + +import matplotlib.colors as mcolors +from matplotlib.font_manager import FontProperties + +logger = logging.getLogger(__name__) + + +VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken", "ttf": "SimSun.ttf"} +FONT_PATH = try_to_load_from_cache("Qwen/Qwen-VL-Chat", "SimSun.ttf") +if FONT_PATH is None: + if not os.path.exists("SimSun.ttf"): + ttf = requests.get("https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/SimSun.ttf") + open("SimSun.ttf", "wb").write(ttf.content) + FONT_PATH = "SimSun.ttf" + +PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" +ENDOFTEXT = "<|endoftext|>" +IMSTART = "<|im_start|>" +IMEND = "<|im_end|>" +# as the default behavior is changed to allow special tokens in +# regular texts, the surface forms of special tokens need to be +# as different as possible to minimize the impact +EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205))) +SPECIAL_TOKENS = ( + ENDOFTEXT, + IMSTART, + IMEND, +) + EXTRAS +IMG_TOKEN_SPAN = 256 + + +def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]: + with open(tiktoken_bpe_file, "rb") as f: + contents = f.read() + return { + base64.b64decode(token): int(rank) + for token, rank in (line.split() for line in contents.splitlines() if line) + } + +def _list_find( + input_list: List[Any], + candidates: Tuple[Any], + start: int = 0, +): + for i in range(start, len(input_list)): + if input_list[i] in candidates: + return i + return -1 + +def _replace_closed_tag( + input_tokens: List[Any], + start_tags: Union[Any, Tuple[Any]], + end_tags: Union[Any, Tuple[Any]], + inclusive_replace_func: Callable, + exclusive_replace_func: Callable = lambda x: x, +): + if isinstance(start_tags, (str, int)): + start_tags = (start_tags,) + if isinstance(end_tags, (str, int)): + end_tags = (end_tags,) + assert len(start_tags) == len(end_tags) + + output_tokens = [] + end = 0 + while True: + start = _list_find(input_tokens, start_tags, end) + if start == -1: + break + output_tokens.extend(exclusive_replace_func(input_tokens[end : start])) + tag_idx = start_tags.index(input_tokens[start]) + end = _list_find(input_tokens, (end_tags[tag_idx],), start) + if end == -1: + raise ValueError("Unclosed image token") + output_tokens.extend(inclusive_replace_func(input_tokens[start : end + 1])) + end += 1 + output_tokens.extend(exclusive_replace_func(input_tokens[end : ])) + return output_tokens + +class QWenTokenizer(PreTrainedTokenizer): + """QWen tokenizer.""" + + vocab_files_names = VOCAB_FILES_NAMES + + def __init__( + self, + vocab_file, + errors="replace", + image_start_tag='', + image_end_tag='', + image_pad_tag='', + ref_start_tag='', + ref_end_tag='', + box_start_tag='', + box_end_tag='', + quad_start_tag='', + quad_end_tag='', + **kwargs, + ): + super().__init__(**kwargs) + self.image_start_tag = image_start_tag + self.image_end_tag = image_end_tag + self.image_pad_tag = image_pad_tag + self.ref_start_tag = ref_start_tag + self.ref_end_tag = ref_end_tag + self.box_start_tag = box_start_tag + self.box_end_tag = box_end_tag + self.quad_start_tag = quad_start_tag + self.quad_end_tag = quad_end_tag + self.IMAGE_ST = ( + ref_start_tag, ref_end_tag, + box_start_tag, box_end_tag, + quad_start_tag, quad_end_tag, + image_start_tag, image_end_tag, + image_pad_tag + ) + + self.errors = errors # how to handle errors in decoding + + self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: dict[bytes, int] + self.special_tokens = { + token: index + for index, token in enumerate( + SPECIAL_TOKENS + self.IMAGE_ST, start=len(self.mergeable_ranks) + ) + } + self.img_start_id = self.special_tokens[self.image_start_tag] + self.img_end_id = self.special_tokens[self.image_end_tag] + self.img_pad_id = self.special_tokens[self.image_pad_tag] + self.ref_start_id = self.special_tokens[self.ref_start_tag] + self.ref_end_id = self.special_tokens[self.ref_end_tag] + self.box_start_id = self.special_tokens[self.box_start_tag] + self.box_end_id = self.special_tokens[self.box_end_tag] + self.quad_start_id = self.special_tokens[self.quad_start_tag] + self.quad_end_id = self.special_tokens[self.quad_end_tag] + + enc = tiktoken.Encoding( + "Qwen", + pat_str=PAT_STR, + mergeable_ranks=self.mergeable_ranks, + special_tokens=self.special_tokens, + ) + assert ( + len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab + ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding" + + self.decoder = { + v: k for k, v in self.mergeable_ranks.items() + } # type: dict[int, bytes|str] + self.decoder.update({v: k for k, v in self.special_tokens.items()}) + + self.tokenizer = enc # type: tiktoken.Encoding + + self.eod_id = self.tokenizer.eot_token + self.im_start_id = self.special_tokens[IMSTART] + self.im_end_id = self.special_tokens[IMEND] + + def __getstate__(self): + # for pickle lovers + state = self.__dict__.copy() + del state['tokenizer'] + return state + + def __setstate__(self, state): + # tokenizer is not python native; don't pass it; rebuild it + self.__dict__.update(state) + enc = tiktoken.Encoding( + "Qwen", + pat_str=PAT_STR, + mergeable_ranks=self.mergeable_ranks, + special_tokens=self.special_tokens, + ) + self.tokenizer = enc + + + def __len__(self) -> int: + return self.tokenizer.n_vocab + + def get_vocab(self) -> Dict[bytes, int]: + return self.mergeable_ranks + + def convert_tokens_to_ids( + self, tokens: Union[bytes, str, List[Union[bytes, str]]] + ) -> List[int]: + ids = [] + if isinstance(tokens, (str, bytes)): + if tokens in self.special_tokens: + return self.special_tokens[tokens] + else: + return self.mergeable_ranks.get(tokens) + for token in tokens: + if token in self.special_tokens: + ids.append(self.special_tokens[token]) + else: + ids.append(self.mergeable_ranks.get(token)) + return ids + + def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int: + if not special_tokens and new_tokens: + raise ValueError('Adding regular tokens is not supported') + for token in new_tokens: + surface_form = token.content if isinstance(token, AddedToken) else token + if surface_form not in SPECIAL_TOKENS + self.IMAGE_ST: + raise ValueError('Adding unknown special tokens is not supported') + return 0 + + def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]: + """ + Save only the vocabulary of the tokenizer (vocabulary). + + Returns: + `Tuple(str)`: Paths to the files saved. + """ + file_path = os.path.join(save_directory, "qwen.tiktoken") + with open(file_path, "w", encoding="utf8") as w: + for k, v in self.mergeable_ranks.items(): + line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n" + w.write(line) + return (file_path,) + + def tokenize( + self, + text: str, + allowed_special: Union[Set, str] = "all", + disallowed_special: Union[Collection, str] = (), + **kwargs, + ) -> List[Union[bytes, str]]: + """ + Converts a string in a sequence of tokens. + + Args: + text (`str`): + The sequence to be encoded. + allowed_special (`Literal["all"]` or `set`): + The surface forms of the tokens to be encoded as special tokens in regular texts. + Default to "all". + disallowed_special (`Literal["all"]` or `Collection`): + The surface forms of the tokens that should not be in regular texts and trigger errors. + Default to an empty tuple. + + kwargs (additional keyword arguments, *optional*): + Will be passed to the underlying model specific encode method. + + Returns: + `List[bytes|str]`: The list of tokens. + """ + tokens = [] + text = unicodedata.normalize("NFC", text) + + # this implementation takes a detour: text -> token id -> token surface forms + for t in self.tokenizer.encode( + text, allowed_special=allowed_special, disallowed_special=disallowed_special + ): + tokens.append(self.decoder[t]) + + def _encode_imgurl(img_tokens): + assert img_tokens[0] == self.image_start_tag and img_tokens[-1] == self.image_end_tag + img_tokens = img_tokens[1:-1] + img_url = b''.join(img_tokens) + out_img_tokens = list(map(self.decoder.get, img_url)) + if len(out_img_tokens) > IMG_TOKEN_SPAN: + raise ValueError("The content in {}..{} is too long".format( + self.image_start_tag, self.image_end_tag)) + out_img_tokens.extend([self.image_pad_tag] * (IMG_TOKEN_SPAN - len(out_img_tokens))) + out_img_tokens = [self.image_start_tag] + out_img_tokens + [self.image_end_tag] + return out_img_tokens + + return _replace_closed_tag(tokens, self.image_start_tag, self.image_end_tag, _encode_imgurl) + + def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str: + """ + Converts a sequence of tokens in a single string. + """ + text = "" + temp = b"" + for t in tokens: + if isinstance(t, str): + if temp: + text += temp.decode("utf-8", errors=self.errors) + temp = b"" + text += t + elif isinstance(t, bytes): + temp += t + else: + raise TypeError("token should only be of type types or str") + if temp: + text += temp.decode("utf-8", errors=self.errors) + return text + + @property + def vocab_size(self): + return self.tokenizer.n_vocab + + def _convert_id_to_token(self, index: int) -> Union[bytes, str]: + """Converts an id to a token, special tokens included""" + if index in self.decoder: + return self.decoder[index] + raise ValueError("unknown ids") + + def _convert_token_to_id(self, token: Union[bytes, str]) -> int: + """Converts a token to an id using the vocab, special tokens included""" + if token in self.special_tokens: + return self.special_tokens[token] + if token in self.mergeable_ranks: + return self.mergeable_ranks[token] + raise ValueError("unknown token") + + def _tokenize(self, text: str, **kwargs): + """ + Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based + vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces). + + Do NOT take care of added tokens. + """ + raise NotImplementedError + + def _decode( + self, + token_ids: Union[int, List[int]], + skip_special_tokens: bool = False, + errors: str = None, + **kwargs, + ) -> str: + if isinstance(token_ids, int): + token_ids = [token_ids] + + def _decode_imgurl(img_token_ids): + assert img_token_ids[0] == self.img_start_id and img_token_ids[-1] == self.img_end_id + img_token_ids = img_token_ids[1:-1] + img_token_ids = img_token_ids[ : img_token_ids.index(self.img_pad_id)] + img_url = bytes(img_token_ids).decode('utf-8') + return [self.img_start_id] + self.tokenizer.encode(img_url) + [self.img_end_id] + + token_ids = _replace_closed_tag(token_ids, self.img_start_id, self.img_end_id, _decode_imgurl) + + if skip_special_tokens: + token_ids = [i for i in token_ids if i < self.eod_id] + return self.tokenizer.decode(token_ids, errors=errors or self.errors) + + def to_list_format(self, text: str): + text = unicodedata.normalize("NFC", text) + token_ids = self.tokenizer.encode( + text, allowed_special=set(self.IMAGE_ST + (ENDOFTEXT,))) + + def _encode_vl_info(tokens): + if len(tokens) == 0: + return [] + if tokens[0] == self.img_start_id and tokens[-1] == self.img_end_id: + key = 'image' + elif tokens[0] == self.ref_start_id and tokens[-1] == self.ref_end_id: + key = 'ref' + elif tokens[0] == self.box_start_id and tokens[-1] == self.box_end_id: + key = 'box' + elif tokens[0] == self.quad_start_id and tokens[-1] == self.quad_end_id: + key = 'quad' + else: + _tobytes = lambda x: x.encode('utf-8') if isinstance(x, str) else x + return [{'text': b''.join(map(_tobytes, map(self.decoder.get, tokens))).decode('utf-8')}] + _tobytes = lambda x: x.encode('utf-8') if isinstance(x, str) else x + val = b''.join(map(_tobytes, map(self.decoder.get, tokens[1:-1]))).decode('utf-8') + return [{key: val}] + + return _replace_closed_tag( + token_ids, + (self.img_start_id, self.ref_start_id, self.box_start_id, self.quad_start_id), + (self.img_end_id, self.ref_end_id, self.box_end_id, self.quad_end_id), + _encode_vl_info, + _encode_vl_info, + ) + + def from_list_format(self, list_format: List[Dict]): + text = '' + num_images = 0 + for ele in list_format: + if 'image' in ele: + num_images += 1 + text += f'Picture {num_images}: ' + text += self.image_start_tag + ele['image'] + self.image_end_tag + text += '\n' + elif 'text' in ele: + text += ele['text'] + elif 'box' in ele: + if 'ref' in ele: + text += self.ref_start_tag + ele['ref'] + self.ref_end_tag + for box in ele['box']: + text += self.box_start_tag + '(%d,%d),(%d,%d)' % (box[0], box[1], box[2], box[3]) + self.box_end_tag + else: + raise ValueError("Unsupport element: " + str(ele)) + return text + + def _fetch_latest_picture(self, response, history): + if history is None: + history = [] + _history = history + [(response, None)] + for q, r in _history[::-1]: + for ele in self.to_list_format(q)[::-1]: + if 'image' in ele: + return ele['image'] + return None + + def _fetch_all_box_with_ref(self, text): + list_format = self.to_list_format(text) + output = [] + for i, ele in enumerate(list_format): + if 'box' in ele: + bbox = tuple(map(int, ele['box'].replace('(', '').replace(')', '').split(','))) + assert len(bbox) == 4 + output.append({'box': bbox}) + if i > 0 and 'ref' in list_format[i-1]: + output[-1]['ref'] = list_format[i-1]['ref'].strip() + return output + + def draw_bbox_on_latest_picture( + self, + response, + history=None, + ) -> Optional[Image.Image]: + image = self._fetch_latest_picture(response, history) + if image is None: + return None + if image.startswith("http://") or image.startswith("https://"): + image = Image.open(requests.get(image, stream=True).raw).convert("RGB") + h, w = image.height, image.width + else: + image = np.asarray(Image.open(image).convert("RGB")) + h, w = image.shape[0], image.shape[1] + visualizer = Visualizer(image) + + boxes = self._fetch_all_box_with_ref(response) + if not boxes: + return None + color = random.choice([_ for _ in mcolors.TABLEAU_COLORS.keys()]) # init color + for box in boxes: + if 'ref' in box: # random new color for new refexps + color = random.choice([_ for _ in mcolors.TABLEAU_COLORS.keys()]) + x1, y1, x2, y2 = box['box'] + x1, y1, x2, y2 = (int(x1 / 1000 * w), int(y1 / 1000 * h), int(x2 / 1000 * w), int(y2 / 1000 * h)) + visualizer.draw_box((x1, y1, x2, y2), alpha=1, edge_color=color) + if 'ref' in box: + visualizer.draw_text(box['ref'], (x1, y1), color=color, horizontal_alignment="left") + return visualizer.output + + +import colorsys +import logging +import math +import numpy as np +import matplotlib as mpl +import matplotlib.colors as mplc +import matplotlib.figure as mplfigure +import torch +from matplotlib.backends.backend_agg import FigureCanvasAgg +from PIL import Image +import random + +logger = logging.getLogger(__name__) + + +class VisImage: + def __init__(self, img, scale=1.0): + self.img = img + self.scale = scale + self.width, self.height = img.shape[1], img.shape[0] + self._setup_figure(img) + + def _setup_figure(self, img): + fig = mplfigure.Figure(frameon=False) + self.dpi = fig.get_dpi() + # add a small 1e-2 to avoid precision lost due to matplotlib's truncation + # (https://github.com/matplotlib/matplotlib/issues/15363) + fig.set_size_inches( + (self.width * self.scale + 1e-2) / self.dpi, + (self.height * self.scale + 1e-2) / self.dpi, + ) + self.canvas = FigureCanvasAgg(fig) + # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig) + ax = fig.add_axes([0.0, 0.0, 1.0, 1.0]) + ax.axis("off") + self.fig = fig + self.ax = ax + self.reset_image(img) + + def reset_image(self, img): + img = img.astype("uint8") + self.ax.imshow(img, extent=(0, self.width, self.height, 0), interpolation="nearest") + + def save(self, filepath): + self.fig.savefig(filepath) + + def get_image(self): + canvas = self.canvas + s, (width, height) = canvas.print_to_buffer() + + buffer = np.frombuffer(s, dtype="uint8") + + img_rgba = buffer.reshape(height, width, 4) + rgb, alpha = np.split(img_rgba, [3], axis=2) + return rgb.astype("uint8") + + +class Visualizer: + def __init__(self, img_rgb, metadata=None, scale=1.0): + self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8) + self.font_path = FONT_PATH + self.output = VisImage(self.img, scale=scale) + self.cpu_device = torch.device("cpu") + + # too small texts are useless, therefore clamp to 14 + self._default_font_size = max( + np.sqrt(self.output.height * self.output.width) // 30, 15 // scale + ) + + def draw_text( + self, + text, + position, + *, + font_size=None, + color="g", + horizontal_alignment="center", + rotation=0, + ): + if not font_size: + font_size = self._default_font_size + + # since the text background is dark, we don't want the text to be dark + color = np.maximum(list(mplc.to_rgb(color)), 0.2) + color[np.argmax(color)] = max(0.8, np.max(color)) + + x, y = position + self.output.ax.text( + x, + y, + text, + size=font_size * self.output.scale, + fontproperties=FontProperties(fname=self.font_path), + bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"}, + verticalalignment="top", + horizontalalignment=horizontal_alignment, + color=color, + zorder=10, + rotation=rotation, + ) + return self.output + + def draw_box(self, box_coord, alpha=0.5, edge_color="g", line_style="-"): + + x0, y0, x1, y1 = box_coord + width = x1 - x0 + height = y1 - y0 + + linewidth = max(self._default_font_size / 4, 1) + + self.output.ax.add_patch( + mpl.patches.Rectangle( + (x0, y0), + width, + height, + fill=False, + edgecolor=edge_color, + linewidth=linewidth * self.output.scale, + alpha=alpha, + linestyle=line_style, + ) + ) + return self.output + + def get_output(self): + + return self.output diff --git a/megatron_patch/training.py b/megatron_patch/training.py index 64afb3d8..f90c6502 100644 --- a/megatron_patch/training.py +++ b/megatron_patch/training.py @@ -25,9 +25,8 @@ from megatron.core import mpu, tensor_parallel from megatron.initialize import (initialize_megatron, set_jit_fusion_options, write_args_to_tensorboard) -from megatron.model import DistributedDataParallel as DDP + from megatron.model import Float16Module -from megatron.optimizer import get_megatron_optimizer from megatron.training import (build_train_valid_test_data_iterators, get_optimizer_param_scheduler, print_datetime, save_checkpoint_and_time) @@ -51,7 +50,8 @@ def pretrain(train_valid_test_dataset_provider, forward_step_func, process_non_loss_data_func=None, extra_args_provider=None, - args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}): + args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, + moe=False): """Main training program. Refer to https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/training.py @@ -82,10 +82,13 @@ def pretrain(train_valid_test_dataset_provider, args_defaults: a dictionary from argument-name to argument-value. It to set already parse arguments. """ - - # Initalize and get arguments, timers, and Tensorboard writer. - initialize_megatron(extra_args_provider=extra_args_provider, - args_defaults=args_defaults) + if not moe: + # Initalize and get arguments, timers, and Tensorboard writer. + initialize_megatron(extra_args_provider=extra_args_provider, + args_defaults=args_defaults) + else: + from megatron_patch.initialize import initialize_megatron + initialize_megatron(extra_args_provider=extra_args_provider) # Set pytorch JIT layer fusion options and warmup JIT functions. set_jit_fusion_options() @@ -293,15 +296,20 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap sum([sum([p.nelement() for p in model_module.parameters() if p.requires_grad == True]) for model_module in model])), flush=True) - # GPU allocation. - for model_module in model: - model_module.cuda(torch.cuda.current_device()) + if args.transformer_type == "megatron": + # GPU allocation. + for model_module in model: + model_module.cuda(torch.cuda.current_device()) # Fp16 conversion. if args.fp16 or args.bf16: model = [Float16Module(model_module, args) for model_module in model] if wrap_with_ddp: + if not args.moe: + from megatron.model import DistributedDataParallel as DDP + else: + from megatron_patch.distributed import DistributedDataParallel as DDP model = [DDP(model_module, data_parallel_group=mpu.get_data_parallel_group(), accumulate_allreduce_grads_in_fp32=args.accumulate_allreduce_grads_in_fp32, @@ -330,8 +338,15 @@ def setup_model_and_optimizer(model_provider_func, if args.load is not None and args.no_load_optim: load_checkpoint(model, None, None) - optimizer = get_megatron_optimizer(model, no_wd_decay_cond, scale_lr_cond, - lr_mult) + if not args.moe: + from megatron.optimizer import get_megatron_optimizer + optimizer = get_megatron_optimizer(model, no_wd_decay_cond, scale_lr_cond, + lr_mult) + else: + from megatron.optimizer import get_megatron_optimizer + #from megatron_patch.optimizer import get_megatron_optimizer + optimizer = get_megatron_optimizer(model, no_wd_decay_cond, scale_lr_cond, + lr_mult) opt_param_scheduler = get_optimizer_param_scheduler(optimizer) if args.load is not None: @@ -634,6 +649,12 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, config.grad_scale_func = optimizer.scale_loss config.timers = timers # TODO: Remove this once we move DDP to Core. + + if not args.moe: + from megatron.model import DistributedDataParallel as DDP + else: + from megatron_patch.distributed import DistributedDataParallel as DDP + if len(model) == 1 and isinstance(model[0], DDP) and \ args.overlap_grad_reduce: assert config.no_sync_func is None, \ diff --git a/toolkits/model_checkpoints_convertor/mixtral/checkpoint_reshaping_and_interoperability.py b/toolkits/model_checkpoints_convertor/mixtral/checkpoint_reshaping_and_interoperability.py new file mode 100644 index 00000000..ba13498b --- /dev/null +++ b/toolkits/model_checkpoints_convertor/mixtral/checkpoint_reshaping_and_interoperability.py @@ -0,0 +1,739 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import random +import json +import os +import re +import sys +import types +import numpy as np +import torch +seed = 1234 +random.seed(seed) +np.random.seed(seed) +torch.manual_seed(seed) + +from transformers import AutoTokenizer, GPT2Config, LlamaConfig +from transformers.modeling_utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME, shard_checkpoint + + +def add_checkpointing_args(parser): + parser.add_argument("--megatron-path", type=str, default=None, help="Base directory of Megatron repository") + parser.add_argument( + "--convert_checkpoint_from_megatron_to_transformers", + action="store_true", + help=( + "If True, convert a Megatron checkpoint to a Transformers checkpoint. " + "If False, convert a Transformers checkpoint to a Megatron checkpoint." + ), + ) + parser.add_argument( + "--load_path", + type=str, + required=True, + help="Path to the checkpoint to convert.", + ) + parser.add_argument( + "--save_path", + type=str, + required=True, + help="Path to the converted checkpoint.", + ) + parser.add_argument( + "--model_name", + type=str, + required=True, + help="model name", + ) + parser.add_argument("--print-checkpoint-structure", action="store_true") + return parser + + +def add_megatron_checkpoint_args(parser): + parser.add_argument( + "--target_tensor_model_parallel_size", + type=int, + default=1, + help=( + "The tensor model parallel size of the converted checkpoint. " + "Only used when converting a Transformers checkpoint to a Megatron checkpoint." + ), + ) + parser.add_argument( + "--target_pipeline_model_parallel_size", + type=int, + default=1, + help=( + "The pipeline model parallel size of the converted checkpoint. " + "Only used when converting a Transformers checkpoint to a Megatron checkpoint." + ), + ) + parser.add_argument( + "--target_data_parallel_size", + type=int, + default=1, + help=( + "The data parallel size of the converted checkpoint. " + "Only used when converting a Transformers checkpoint to a Megatron checkpoint." + ), + ) + parser.add_argument( + "--target_params_dtype", + type=str, + default="fp32", + help=( + "The dtype of the converted checkpoint. " + "Only used when converting a Transformers checkpoint to a Megatron checkpoint." + ), + ) + parser.add_argument( + "--make_vocab_size_divisible_by", + type=int, + default=128, + help=( + "Pad the vocab size to be divisible by this value. " + "This is added for computational efficieny reasons. " + "Only used when converting a Transformers checkpoint to a Megatron checkpoint." + ), + ) + + parser.add_argument( + '--extra_num_vocabs', + type=int, + default=0, + ) + + parser.add_argument( + "--use_distributed_optimizer", + action="store_true", + help=( + "If True, use the distributed optimizer. " + "Only used when converting a Transformers checkpoint to a Megatron checkpoint." + ), + ) + return parser + + +def add_transformers_checkpoint_args(parser): + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help=( + "The name of the pre-trained tokenizer to save. " + "If not None, the tokenizer will be saved. " + "Only used when converting a Megatron checkpoint to a Transformers checkpoint." + ), + ) + parser.add_argument( + "--max_shard_size", + type=str, + default="10GB", + help=( + "The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size " + "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`). " + "Only used when converting a Megatron checkpoint to a Transformers checkpoint." + ), + ) + + return parser + +transformers_to_megatron = { + "self_attn.dense": "self_attention.dense", + "mlp.megatron_moe.gate.wg":"mlp.megatron_moe.gate.wg", + "mlp.megatron_moe.experts.megatron_experts.0.dense_h_to_4h_1":"mlp.megatron_moe.experts.megatron_experts.0.dense_h_to_4h_1", + "mlp.megatron_moe.experts.megatron_experts.1.dense_h_to_4h_1":"mlp.megatron_moe.experts.megatron_experts.1.dense_h_to_4h_1", + "mlp.megatron_moe.experts.megatron_experts.2.dense_h_to_4h_1":"mlp.megatron_moe.experts.megatron_experts.2.dense_h_to_4h_1", + "mlp.megatron_moe.experts.megatron_experts.3.dense_h_to_4h_1":"mlp.megatron_moe.experts.megatron_experts.3.dense_h_to_4h_1", + "mlp.megatron_moe.experts.megatron_experts.4.dense_h_to_4h_1":"mlp.megatron_moe.experts.megatron_experts.4.dense_h_to_4h_1", + "mlp.megatron_moe.experts.megatron_experts.5.dense_h_to_4h_1":"mlp.megatron_moe.experts.megatron_experts.5.dense_h_to_4h_1", + "mlp.megatron_moe.experts.megatron_experts.6.dense_h_to_4h_1":"mlp.megatron_moe.experts.megatron_experts.6.dense_h_to_4h_1", + "mlp.megatron_moe.experts.megatron_experts.7.dense_h_to_4h_1":"mlp.megatron_moe.experts.megatron_experts.7.dense_h_to_4h_1", + "mlp.megatron_moe.experts.megatron_experts.0.dense_h_to_4h_2":"mlp.megatron_moe.experts.megatron_experts.0.dense_h_to_4h_2", + "mlp.megatron_moe.experts.megatron_experts.1.dense_h_to_4h_2":"mlp.megatron_moe.experts.megatron_experts.1.dense_h_to_4h_2", + "mlp.megatron_moe.experts.megatron_experts.2.dense_h_to_4h_2":"mlp.megatron_moe.experts.megatron_experts.2.dense_h_to_4h_2", + "mlp.megatron_moe.experts.megatron_experts.3.dense_h_to_4h_2":"mlp.megatron_moe.experts.megatron_experts.3.dense_h_to_4h_2", + "mlp.megatron_moe.experts.megatron_experts.4.dense_h_to_4h_2":"mlp.megatron_moe.experts.megatron_experts.4.dense_h_to_4h_2", + "mlp.megatron_moe.experts.megatron_experts.5.dense_h_to_4h_2":"mlp.megatron_moe.experts.megatron_experts.5.dense_h_to_4h_2", + "mlp.megatron_moe.experts.megatron_experts.6.dense_h_to_4h_2":"mlp.megatron_moe.experts.megatron_experts.6.dense_h_to_4h_2", + "mlp.megatron_moe.experts.megatron_experts.7.dense_h_to_4h_2":"mlp.megatron_moe.experts.megatron_experts.7.dense_h_to_4h_2", + "mlp.megatron_moe.experts.megatron_experts.0.dense_4h_to_h":"mlp.megatron_moe.experts.megatron_experts.0.dense_4h_to_h", + "mlp.megatron_moe.experts.megatron_experts.1.dense_4h_to_h":"mlp.megatron_moe.experts.megatron_experts.1.dense_4h_to_h", + "mlp.megatron_moe.experts.megatron_experts.2.dense_4h_to_h":"mlp.megatron_moe.experts.megatron_experts.2.dense_4h_to_h", + "mlp.megatron_moe.experts.megatron_experts.3.dense_4h_to_h":"mlp.megatron_moe.experts.megatron_experts.3.dense_4h_to_h", + "mlp.megatron_moe.experts.megatron_experts.4.dense_4h_to_h":"mlp.megatron_moe.experts.megatron_experts.4.dense_4h_to_h", + "mlp.megatron_moe.experts.megatron_experts.5.dense_4h_to_h":"mlp.megatron_moe.experts.megatron_experts.5.dense_4h_to_h", + "mlp.megatron_moe.experts.megatron_experts.6.dense_4h_to_h":"mlp.megatron_moe.experts.megatron_experts.6.dense_4h_to_h", + "mlp.megatron_moe.experts.megatron_experts.7.dense_4h_to_h":"mlp.megatron_moe.experts.megatron_experts.7.dense_4h_to_h" +} + +tensor_parallel_params = [ + # megatron-lm layers to merge across tp ranks + "self_attn.query.weight", + "self_attn.key_value.weight", + "self_attn.dense.weight", + "mlp.megatron_moe.experts.megatron_experts.0.dense_h_to_4h_1", + "mlp.megatron_moe.experts.megatron_experts.1.dense_h_to_4h_1", + "mlp.megatron_moe.experts.megatron_experts.2.dense_h_to_4h_1", + "mlp.megatron_moe.experts.megatron_experts.3.dense_h_to_4h_1", + "mlp.megatron_moe.experts.megatron_experts.4.dense_h_to_4h_1", + "mlp.megatron_moe.experts.megatron_experts.5.dense_h_to_4h_1", + "mlp.megatron_moe.experts.megatron_experts.6.dense_h_to_4h_1", + "mlp.megatron_moe.experts.megatron_experts.7.dense_h_to_4h_1", + "mlp.megatron_moe.experts.megatron_experts.0.dense_h_to_4h_2", + "mlp.megatron_moe.experts.megatron_experts.1.dense_h_to_4h_2", + "mlp.megatron_moe.experts.megatron_experts.2.dense_h_to_4h_2", + "mlp.megatron_moe.experts.megatron_experts.3.dense_h_to_4h_2", + "mlp.megatron_moe.experts.megatron_experts.4.dense_h_to_4h_2", + "mlp.megatron_moe.experts.megatron_experts.5.dense_h_to_4h_2", + "mlp.megatron_moe.experts.megatron_experts.6.dense_h_to_4h_2", + "mlp.megatron_moe.experts.megatron_experts.7.dense_h_to_4h_2", + "mlp.megatron_moe.experts.megatron_experts.0.dense_4h_to_h", + "mlp.megatron_moe.experts.megatron_experts.1.dense_4h_to_h", + "mlp.megatron_moe.experts.megatron_experts.2.dense_4h_to_h", + "mlp.megatron_moe.experts.megatron_experts.3.dense_4h_to_h", + "mlp.megatron_moe.experts.megatron_experts.4.dense_4h_to_h", + "mlp.megatron_moe.experts.megatron_experts.5.dense_4h_to_h", + "mlp.megatron_moe.experts.megatron_experts.6.dense_4h_to_h", + "mlp.megatron_moe.experts.megatron_experts.7.dense_4h_to_h" +] + +def recursive_print(name, val, spaces=0): + """ + Recursively print the structure of a checkpoint. This function is taken from `convert_megatron_gpt2_checkpoint.py` + Args: + name (str): the name of the current tensor parameter + val (Tuple(int)): the shape of the current tensor parameter + spaces (int): the number of spaces to print before the output for a nested structure + """ + # Format the message. + if name is None: + msg = None + else: + fmt = "." * max(0, spaces - 2) + "# {:" + str(50 - spaces) + "s}" + msg = fmt.format(name) + + # Print and recurse (if needed). + if isinstance(val, dict): + if msg is not None: + print(msg) + for k in val.keys(): + recursive_print(k, val[k], spaces + 2) + elif isinstance(val, torch.Tensor): + print(msg, ":", val.size()) + else: + print(msg, ":", val) + + +def megatron_to_transformers_fix_query_key_value_ordering( + param, checkpoint_version, num_splits, num_heads, hidden_size +): + """ + Permutes layout of param tensor to [num_splits * num_heads * hidden_size, :] for compatibility with later versions + of NVIDIA Megatron-LM. The inverse operation is performed inside Megatron-LM to read checkpoints: + https://github.com/NVIDIA/Megatron-LM/blob/v2.4/megatron/checkpointing.py#L209 If param is the weight tensor of the + self-attention block, the returned tensor will have to be transposed one more time to be read by HuggingFace GPT2. + This function is taken from `convert_megatron_gpt2_checkpoint.py` + Args: + param (torch.Tensor): the tensor to permute + checkpoint_version (int): the version of the checkpoint. + num_splits (int): the number of projections, usually 3 for (Query, Key, Value) + num_heads (int): the number of attention heads + hidden_size (int): the hidden size per head + """ + + input_shape = param.size() + if checkpoint_version == 1.0: + # version 1.0 stores [num_heads * hidden_size * num_splits, :] + saved_shape = (num_heads, hidden_size, num_splits) + input_shape[1:] + param = param.view(*saved_shape) + param = param.transpose(0, 2) + param = param.transpose(1, 2).contiguous() + elif checkpoint_version >= 2.0: + # other versions store [num_heads * num_splits * hidden_size, :] + saved_shape = (num_heads, num_splits, hidden_size) + input_shape[1:] + param = param.view(*saved_shape) + param = param.transpose(0, 1).contiguous() + param = param.view(*input_shape) + return param + + +def transformers_to_megatron_fix_query_key_value_ordering( + param, checkpoint_version, num_splits, num_heads, hidden_size +): + """ + Permutes layout of param tensor to the one compatible with respective NVIDIA Megatron-LM chekpoint versions. Input + is [num_splits * num_heads * hidden_size, :] and output is [num_heads * hidden_size * num_splits, :] for version + 1.0 and [num_heads * num_splits * hidden_size, :] for version 2.0 and later. If param is the weight tensor of the + self-attention block, the param needs to be already transposed before calling this function. + Args: + param (torch.Tensor): the tensor to permute + checkpoint_version (int): the version of the checkpoint. + num_splits (int): the number of projections, usually 3 for (Query, Key, Value) + num_heads (int): the number of attention heads + hidden_size (int): the hidden size per head + """ + + # Input is [num_splits * num_heads * hidden_size, :] + input_shape = param.size() + if checkpoint_version == 1.0: + # version 1.0 stores [num_heads * hidden_size * num_splits, :] + current_shape = (num_splits, num_heads, hidden_size) + input_shape[1:] + param = param.view(*current_shape) + param = param.transpose(0, 2) + param = param.transpose(1, 2).contiguous() + elif checkpoint_version >= 2.0: + # other versions store [num_heads * num_splits * hidden_size, :] + current_shape = (num_splits, num_heads, hidden_size) + input_shape[1:] + param = param.view(*current_shape) + param = param.transpose(0, 1).contiguous() + param = param.view(*input_shape) + return param + + +def merge_transformers_sharded_states_7b(path, num_checkpoints): + """ + Merge sharded checkpoints from transformers into a single checkpoint. + Args: + path (str): the path to the sharded checkpoints + num_checkpoints (int): the number of checkpoints to merge + """ + state_dict = {} + for i in range(0, num_checkpoints): + checkpoint_path = os.path.join(path, f"consolidated.{i:02d}.pt") + print(checkpoint_path) + current_chunk = torch.load(checkpoint_path, map_location="cpu") + state_dict.update(current_chunk) + return state_dict + +def get_megatron_sharded_states(args, tp_size, pp_size, pp_rank): + """ + Get sharded checkpoints from NVIDIA Megatron-LM checkpoint based on the provided tensor parallel size, pipeline + parallel size and pipeline parallel rank. + Args: + args (argparse.Namespace): the arguments to the script + tp_size (int): the tensor parallel size + pp_size (int): the pipeline parallel size + pp_rank (int): the pipeline parallel rank + """ + tp_state_dicts = [] + for i in range(tp_size): + sub_dir_name = f"mp_rank_{i:02d}" if pp_size == 1 else f"mp_rank_{i:02d}_{pp_rank:03d}" + checkpoint_name = os.listdir(os.path.join(args.load_path, sub_dir_name))[0] + checkpoint_path = os.path.join(args.load_path, sub_dir_name, checkpoint_name) + state_dict = torch.load(checkpoint_path, map_location="cpu") + tp_state_dicts.append(state_dict) + return tp_state_dicts + + +def get_element_from_dict_by_path(d, path): + """ + Get element from dictionary by path. If element is not present, recursively add empty dictionaries. + Args: + d (dict): the dictionary to get the element from + path (list): the path to the element which is delimited by "." + """ + path = path.split(".") + for k in path: + if k not in d: + d[k] = {} + d = d[k] + return d + +def _init_embedding_weights(module): + std = 0.02 + module.weight.data.normal_(mean=0.0, std=std) + + +def convert_checkpoint_from_transformers_to_megatron(args): + """ + Convert a checkpoint from HuggingFace Transformers to Megatron-LM. This allows converted checkpoints with variable + tensor parallelism and pipeline parallelism sizes. It takes as input a checkpoint from HuggingFace Transformers + which can have multiple shards. + Args: + args (argparse.Namespace): the arguments to the script + """ + os.makedirs(args.save_path, exist_ok=True) + # Search in directory above this + sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) + if args.megatron_path is not None: + sys.path.insert(0, args.megatron_path) + + try: + from megatron.tokenizer.tokenizer import _vocab_size_with_padding + except ModuleNotFoundError: + print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.") + exit(1) + + # load the transformers model state dict and config + if args.model_name == "mixtral-8x7b": + from transformers import AutoModelForCausalLM + state_dict = AutoModelForCausalLM.from_pretrained(args.load_path).state_dict() + else: + raise ValueError("model name is not supported") + + config = GPT2Config.from_pretrained(args.load_path) + internal_state_dict = {} + for layer_id in range(config.num_hidden_layers): + q_weight = state_dict['model.layers.'+str(layer_id)+'.self_attn.q_proj.weight'] + k_weight = state_dict['model.layers.' + str(layer_id) + '.self_attn.k_proj.weight'] + v_weight = state_dict['model.layers.' + str(layer_id) + '.self_attn.v_proj.weight'] + + internal_state_dict['transformer.layers.'+str(layer_id)+'.self_attn.query.weight'] = q_weight + internal_state_dict['transformer.layers.'+str(layer_id)+'.self_attn.key_value.weight'] = torch.cat((k_weight, v_weight)) + + internal_state_dict['transformer.layers.' + str(layer_id) + '.self_attn.dense.weight'] =\ + state_dict['model.layers.' + str(layer_id) + '.self_attn.o_proj.weight'] + + internal_state_dict['transformer.layers.' + str(layer_id) + '.mlp.megatron_moe.gate.wg.weight'] = state_dict[ + 'model.layers.' + str(layer_id) + '.block_sparse_moe.gate.weight'] + + for expert_id in range(config.num_local_experts): + + internal_state_dict['transformer.layers.' + str(layer_id) + '.mlp.megatron_moe.experts.megatron_experts.' + str(expert_id)+'.dense_h_to_4h_1.weight'] = \ + state_dict['model.layers.' + str(layer_id) + '.block_sparse_moe.experts.' + str(expert_id) + '.w1.weight'] + + internal_state_dict['transformer.layers.' + str(layer_id) + '.mlp.megatron_moe.experts.megatron_experts.' + str(expert_id)+'.dense_h_to_4h_2.weight'] = \ + state_dict['model.layers.' + str(layer_id) + '.block_sparse_moe.experts.' + str(expert_id) + '.w3.weight'] + + internal_state_dict['transformer.layers.' + str(layer_id) + '.mlp.megatron_moe.experts.megatron_experts.' + str(expert_id)+'.dense_4h_to_h.weight'] = state_dict[ + 'model.layers.' + str(layer_id) + '.block_sparse_moe.experts.'+str(expert_id) +'.w2.weight'] + + internal_state_dict['transformer.layers.' + str(layer_id) + '.input_layernorm.weight'] = state_dict[ + 'model.layers.' + str(layer_id) + '.input_layernorm.weight'] + + internal_state_dict['transformer.layers.' + str(layer_id) + '.post_attention_layernorm.weight'] = state_dict[ + 'model.layers.' + str(layer_id) + '.post_attention_layernorm.weight'] + + internal_state_dict["transformer.word_embeddings.weight"] = state_dict['model.embed_tokens.weight'] + internal_state_dict["transformer.final_layernorm.weight"] = state_dict['model.norm.weight'] + internal_state_dict["transformer.lm_head.weight"] = state_dict['lm_head.weight'] + state_dict = internal_state_dict + + # Saving config and tokenzier files + os.system("cp -rf "+args.load_path+"/*.json "+args.save_path) + os.system("cp -rf " + args.load_path + "/tokenizer* " + args.save_path) + + # Saving the tracker file + tracker_filepath = os.path.join(args.save_path, "latest_checkpointed_iteration.txt") + with open(tracker_filepath, "w") as f: + f.write("release") + + # create `release` dir in args.load_path + release_dir = os.path.join(args.save_path, "release") + os.makedirs(release_dir, exist_ok=True) + + # megatron args + megatron_args = { + "orig_vocab_size": config.vocab_size, + "hidden_size": config.hidden_size, + "num_layers": config.num_hidden_layers, + "num_attention_heads": config.num_attention_heads, + "tensor_model_parallel_size": args.target_tensor_model_parallel_size, + "pipeline_model_parallel_size": args.target_pipeline_model_parallel_size, + "data_parallel_size": args.target_data_parallel_size, + "make_vocab_size_divisible_by": args.make_vocab_size_divisible_by, + "rank": 0, + "tokenizer_type": "GPT2BPETokenizer", + } + + margs = types.SimpleNamespace() + for k, v in megatron_args.items(): + setattr(margs, k, v) + + # params dtype + if args.target_params_dtype == "fp16": + dtype = torch.float16 + elif args.target_params_dtype == "bf16": + dtype = torch.bfloat16 + else: + dtype = torch.float32 + setattr(margs, "params_dtype", dtype) + + # Convert. + print("Converting") + output_state_dict = [] + for i in range(args.target_tensor_model_parallel_size): + output_state_dict.append({}) + + num_query_group = 8 + output_group_state_dict = [] + for i in range(num_query_group): + output_group_state_dict.append({}) + + # Embedding layer + print("converting embedding layer") + word_embedding = state_dict["transformer.word_embeddings.weight"].to(dtype) + lm_head = state_dict["transformer.lm_head.weight"].to(dtype) + orig_vocab_size = config.vocab_size + #padded_vocab_size = _vocab_size_with_padding(orig_vocab_size, margs) + padded_vocab_size = orig_vocab_size + setattr(margs, "padded_vocab_size", padded_vocab_size) + # Cut out extra padding we don't need + if args.extra_num_vocabs == 0: + full_word_embed = word_embedding + full_lm_head = lm_head + else: + new_embeddings = torch.nn.Embedding(args.extra_num_vocabs, word_embedding.shape[1]) + # initialize all new embeddings (in particular added tokens) + _init_embedding_weights(new_embeddings) + full_word_embed = torch.cat([word_embedding, new_embeddings.weight]) + full_lm_head = torch.cat([lm_head, new_embeddings.weight]) + + # Split into new tensor model parallel sizes + out_word_embed = torch.chunk(full_word_embed, args.target_tensor_model_parallel_size, dim=0) + for i in range(args.target_tensor_model_parallel_size): + word_emb_dict = get_element_from_dict_by_path( + output_state_dict[i], "model.language_model.embedding.word_embeddings" + ) + word_emb_dict["weight"] = out_word_embed[i] + + out_lm_head = torch.chunk(full_lm_head, args.target_tensor_model_parallel_size, dim=0) + for i in range(args.target_tensor_model_parallel_size): + lm_head_dict = get_element_from_dict_by_path( + output_state_dict[i], "model.lm_head" + ) + lm_head_dict["weight"] = out_lm_head[i] + + # Transformer layers + print("converting transformer layers") + if config.num_hidden_layers % args.target_pipeline_model_parallel_size != 0: + raise ValueError( + f"Number of layers ({config.num_hidden_layers}) must be divisible by number of pipeline parallelism" + f" ({args.target_pipeline_model_parallel_size})" + ) + num_layers = config.num_hidden_layers // args.target_pipeline_model_parallel_size + + layer_re = re.compile("transformer.layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)") + # The number of heads. + heads = config.num_attention_heads + # The hidden_size per head. + hidden_size = 4096 + num_groups = 8 + head_dim = 128 + num_heads = 32 + hidden_size_per_head = config.hidden_size // config.num_attention_heads + for pp_rank in range(args.target_pipeline_model_parallel_size): + layer_offset = pp_rank * num_layers + if pp_rank > 0: + output_state_dict = [] + for i in range(args.target_tensor_model_parallel_size): + output_state_dict.append({}) + + output_group_state_dict = [] + for i in range(num_query_group): + output_group_state_dict.append({}) + + for layer in range(num_layers): + pp_layer_id = layer + layer_offset + layers_to_copy = [ + layer_name + for layer_name in state_dict.keys() + if layer_name.startswith(f"transformer.layers.{pp_layer_id}.") + ] + + for layer_name in layers_to_copy: + m = layer_re.match(layer_name) + # Stop if that's not a layer + if m is None: + break + + # The index of the layer. + _ = int(m.group(1)) + # The name of the operation. + op_name = m.group(2) + + # Is it a weight or a bias? + weight = m.group(3) + + params = state_dict[layer_name].to(dtype) + # handle layernorm + if op_name.startswith("input_layernorm") or op_name.startswith("post_attention_layernorm"): + out_name = "input_layernorm" if op_name.endswith("input_layernorm") else "post_attention_layernorm" + layer_name = f"layers.{layer}.{out_name}.{weight}" + + elif op_name.startswith("self_attn.rotary_emb"): + layer_name = f"layers.{layer}.self_attention.rotary_emb.inv_freq" + + # handle attention K, V, Q weights + elif op_name.startswith("self_attn.query") and weight == "weight": + # transformers stores D X (3*D) but Megatron-LM expects (3*D) X D. + params = transformers_to_megatron_fix_query_key_value_ordering( + params, + 3.0, + 1, + heads, + hidden_size_per_head, + ) + layer_name = f"layers.{layer}.self_attention.query.{weight}" + + elif op_name.startswith("self_attn.key_value") and weight == "weight": + # transformers stores D X (3*D) but Megatron-LM expects (3*D) X D. + params = transformers_to_megatron_fix_query_key_value_ordering( + params, + 3.0, + 2, + 8, + hidden_size_per_head, + ) + layer_name = f"layers.{layer}.self_attention.key_value.{weight}" + + # handle attention and mlp weights + elif weight == "weight": + out_name = transformers_to_megatron.get(op_name, None) + if out_name is None: + continue + layer_name = f"layers.{layer}.{out_name}.{weight}" + + # skip + else: + continue + + if op_name + "." + weight in tensor_parallel_params: + dim = 1 if op_name in ["self_attn.dense", + "mlp.megatron_moe.experts.megatron_experts.0.dense_4h_to_h", + "mlp.megatron_moe.experts.megatron_experts.1.dense_4h_to_h", + "mlp.megatron_moe.experts.megatron_experts.2.dense_4h_to_h", + "mlp.megatron_moe.experts.megatron_experts.3.dense_4h_to_h", + "mlp.megatron_moe.experts.megatron_experts.4.dense_4h_to_h", + "mlp.megatron_moe.experts.megatron_experts.5.dense_4h_to_h", + "mlp.megatron_moe.experts.megatron_experts.6.dense_4h_to_h", + "mlp.megatron_moe.experts.megatron_experts.7.dense_4h_to_h" + ] else 0 + params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=dim) + + for i in range(args.target_tensor_model_parallel_size): + params_dict = get_element_from_dict_by_path(output_state_dict[i], "model.language_model.encoder") + params_dict[layer_name] = ( + params[i].clone() if (op_name + "." + weight in tensor_parallel_params) else params.clone() + ) + + + for i in range(args.target_tensor_model_parallel_size): + + params_dict = get_element_from_dict_by_path(output_state_dict[i], + "model.language_model.encoder") + for expert_id in range(config.num_local_experts): + + dense_h_to_4h_1_name = 'mlp.megatron_moe.experts.megatron_experts.' + str(expert_id)+'.dense_h_to_4h_1.weight' + dense_h_to_4h_1_layer_name = f"layers.{layer}.{dense_h_to_4h_1_name}" + dense_h_to_4h_1_weight = params_dict[dense_h_to_4h_1_layer_name] + + dense_h_to_4h_2_name = 'mlp.megatron_moe.experts.megatron_experts.' + str(expert_id)+'.dense_h_to_4h_2.weight' + dense_h_to_4h_2_layer_name = f"layers.{layer}.{dense_h_to_4h_2_name}" + dense_h_to_4h_2_weight = params_dict[dense_h_to_4h_2_layer_name] + + dense_h_to_4h_name = 'mlp.megatron_moe.experts.megatron_experts.' + str(expert_id)+'.dense_h_to_4h.weight' + dense_h_to_4h_layer_name = f"layers.{layer}.{dense_h_to_4h_name}" + + params_dict[dense_h_to_4h_layer_name] = torch.cat( + [dense_h_to_4h_1_weight, dense_h_to_4h_2_weight], dim=0) + + del params_dict[dense_h_to_4h_1_layer_name] + del params_dict[dense_h_to_4h_2_layer_name] + + query_name = 'self_attention.query.weight' + query_layer_name = f"layers.{layer}.{query_name}" + query_weight = params_dict[query_layer_name] + + kv_name = 'self_attention.key_value.weight' + kv_layer_name = f"layers.{layer}.{kv_name}" + kv_weight = params_dict[kv_layer_name] + + qkv_name = 'self_attention.query_key_value.weight' + qkv_layer_name = f"layers.{layer}.{qkv_name}" + + # torch.Size([32 128, 4096]) + group_query_weight = query_weight.view(num_groups // args.target_tensor_model_parallel_size, num_heads // num_groups * head_dim, hidden_size) + # torch.Size(8, 256, 4096]) + group_kv_weight = kv_weight.view(num_groups // args.target_tensor_model_parallel_size, 2 * head_dim, hidden_size) + + group_qkv_weight = torch.cat([group_query_weight, group_kv_weight], dim=1) + params_dict[qkv_layer_name] = group_qkv_weight.view(-1, hidden_size) + + del params_dict[query_layer_name] + del params_dict[kv_layer_name] + + if pp_rank == args.target_pipeline_model_parallel_size - 1: + # handle final layernorm + for weight_or_bias in ["weight"]: + params = state_dict[f"transformer.final_layernorm.{weight_or_bias}"].to(dtype) + layer_name = f"final_layernorm.{weight_or_bias}" + for i in range(args.target_tensor_model_parallel_size): + params_dict = get_element_from_dict_by_path(output_state_dict[i], "model.language_model.encoder") + params_dict[layer_name] = params.clone() + + # add the LM head + for i in range(args.target_tensor_model_parallel_size): + params_dict = get_element_from_dict_by_path(output_state_dict[i], "model.word_embeddings_for_head") + params_dict["weight"] = out_word_embed[i].clone() + + # add the LM head + for i in range(args.target_tensor_model_parallel_size): + params_dict = get_element_from_dict_by_path(output_state_dict[i], "model.language_model.output_layer") + params_dict["weight"] = out_lm_head[i].clone() + + + # saving the state dict as per the tp_rank and pp_rank + for tp_rank in range(args.target_tensor_model_parallel_size): + output_state_dict[tp_rank]["checkpoint_version"] = 3.0 + output_state_dict[tp_rank]["args"] = margs + checkpoint_dir = ( + f"mp_rank_{tp_rank:02d}" + if args.target_pipeline_model_parallel_size == 1 + else f"mp_rank_{tp_rank:02d}_{pp_rank:03d}" + ) + + for layer_id in range(config.num_hidden_layers): + for expert_id in range(config.num_local_experts): + + moe_state_dict = {} + + moe_dense_h_to_4h_path = "layers." + str(layer_id)+".mlp.megatron_moe.experts.megatron_experts." + str(expert_id) +".dense_h_to_4h.weight" + + moe_dense_4h_to_h_path = "layers." + str(layer_id)+".mlp.megatron_moe.experts.megatron_experts." + str(expert_id) +".dense_4h_to_h.weight" + + moe_state_dict["module.module.language_model.encoder."+moe_dense_h_to_4h_path] = output_state_dict[tp_rank]['model']['language_model']['encoder'][moe_dense_h_to_4h_path] + moe_state_dict["module.module.language_model.encoder." + moe_dense_4h_to_h_path] = output_state_dict[tp_rank]['model']['language_model']['encoder'][moe_dense_4h_to_h_path] + output_state_dict[tp_rank]['model']['language_model']['encoder'].pop(moe_dense_h_to_4h_path) + output_state_dict[tp_rank]['model']['language_model']['encoder'].pop(moe_dense_4h_to_h_path) + + moe_checkpoint_path = "layer_" + str(layer_id)+"_expert_" + str(expert_id) + "_" +checkpoint_dir+"_model_states.pt" + torch.save(moe_state_dict, os.path.join(release_dir, moe_checkpoint_path)) + + checkpoint_name = "model_rng.pt" + checkpoint_dir = os.path.join(release_dir, checkpoint_dir) + os.makedirs(checkpoint_dir, exist_ok=True) + checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name) + if args.print_checkpoint_structure: + print( + f"Checkpoint structure of model state dict shard belonging to TP rank {tp_rank} and PP rank" + f" {pp_rank}:" + ) + recursive_print(None, output_state_dict[tp_rank]) + torch.save(output_state_dict[tp_rank], checkpoint_path) + +def main(): + parser = argparse.ArgumentParser() + parser = add_checkpointing_args(parser) + parser = add_megatron_checkpoint_args(parser) + parser = add_transformers_checkpoint_args(parser) + args = parser.parse_args() + convert_checkpoint_from_transformers_to_megatron(args) + + +if __name__ == "__main__": + main() diff --git a/toolkits/model_checkpoints_convertor/mixtral/model_convertor.sh b/toolkits/model_checkpoints_convertor/mixtral/model_convertor.sh new file mode 100644 index 00000000..2f0acf1c --- /dev/null +++ b/toolkits/model_checkpoints_convertor/mixtral/model_convertor.sh @@ -0,0 +1,40 @@ +#!/bin/bash +# megatron to transformers: You need to copy the tokenizer files into the save_path +# bash model_convertor.sh ../../Megatron-LM/ ../../llama-hf2mg-test-2-2/release/ ../../llama_mg2hf 1 1 llama-7b 1 true +# transformers to megatron +# bash model_convertor.sh ../../Megatron-LM/ ../../llama-7b-hf ../../llama-hf2mg 1 1 llama-7b 1 false +set -e +START_TIME=$SECONDS + +MEGATRON_PATH=$1 +SOURCE_CKPT_PATH=$2 +TARGET_CKPT_PATH=$3 +TP=$4 +PP=$5 +MN=$6 #mixtral-8x7b +EXTRA_VOCAB_SIZE=$7 +mg2hf=$8 + +if [ $mg2hf = true ]; then + do_options=" + --convert_checkpoint_from_megatron_to_transformers + " +elif [ $mg2hf = false ]; then + do_options="" +fi + +export PYTHONPATH=${MEGATRON_PATH}:$PYTHONPATH + +python checkpoint_reshaping_and_interoperability.py \ +--load_path ${SOURCE_CKPT_PATH} \ +--save_path ${TARGET_CKPT_PATH} \ +--target_params_dtype fp16 \ +--megatron-path ${MEGATRON_PATH} \ +--target_tensor_model_parallel_size ${TP} \ +--target_pipeline_model_parallel_size ${PP} \ +--model_name ${MN} \ +--extra_num_vocabs ${EXTRA_VOCAB_SIZE} \ +${do_options} + +ELAPSED_TIME=$(($SECONDS - $START_TIME)) +echo "$(($ELAPSED_TIME/60)) min $(($ELAPSED_TIME%60)) sec"