diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index ed42694a166a..26927b20dd3a 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1160,7 +1160,7 @@ def configure( custom_policy=self.custom_policy, ) - self.set_resized_embedding_to_optimizer(model, optimizer, param_info) + # self.set_resized_embedding_to_optimizer(model, optimizer, param_info) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if self.zero_stage == 0: if self.precision in ["fp16", "bf16"]: diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 1d2b7a570681..bbd8c91af4b2 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -5,9 +5,11 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np +import torch import torch.nn as nn from torch import Tensor from torch.nn import Module +from colossalai.lazy.lazy_init import LazyInitContext from colossalai.pipeline.stage_manager import PipelineStageManager @@ -243,3 +245,46 @@ def get_stage_index( stage_indices.append([start_idx, end_idx]) return stage_indices[0] if num_model_chunks == 1 else stage_indices + + + def resize_token_embeddings(self, model, new_num_tokens): + input_embeddings = self.model.get_input_embeddings() + if input_embeddings is not None: + self._resize_token_embeddings(model, input_embeddings, new_num_tokens) + output_embedddings = self.model.get_output_embeddings() + if output_embedddings is not None: + self._resize_lm_head(model, output_embedddings, new_num_tokens) + + def _resize_token_embeddings(self, model, embedding, new_num_tokens): + LazyInitContext.materialize(embedding) + old_num_tokens = embedding.num_embeddings + input_embedding_dim = embedding.embedding_dim + old_weight_data = embedding.weight.data + embedding.num_embeddings = new_num_tokens + if embedding.padding_idx is not None and embedding.padding_idx > new_num_tokens: + embedding.padding_idx = embedding.padding_idx - (old_num_tokens-new_num_tokens) + factory_kwargs = {'device': embedding.weight.device, 'dtype': embedding.weight.dtype} + embedding.weight.data = torch.empty((new_num_tokens, input_embedding_dim), **factory_kwargs) + embedding.reset_parameters() + model._init_weights(embedding) + # Copy token embeddings from the previous weights + num_tokens_to_copy = min(old_num_tokens, new_num_tokens) + embedding.weight.data[:num_tokens_to_copy, :] = old_weight_data[:num_tokens_to_copy, :] + + def _resize_lm_head(self, model, lm_head, new_num_tokens): + LazyInitContext.materialize(lm_head) + old_num_tokens, lm_head_dim = (lm_head.weight.size()) + old_weight_data = lm_head.weight.data + old_bias_data = lm_head.bias.data if lm_head.bias is not None else None + lm_head.out_features = new_num_tokens + factory_kwargs = {'device': lm_head.weight.device, 'dtype': lm_head.weight.dtype} + lm_head.weight.data = torch.empty((new_num_tokens, lm_head_dim), **factory_kwargs) + if lm_head.bias is not None: + lm_head.bias.data = torch.empty(new_num_tokens, **factory_kwargs) + lm_head.reset_parameters() + model._init_weights(lm_head) + # Copy token embeddings from the previous weights + num_tokens_to_copy = min(old_num_tokens, new_num_tokens) + lm_head.weight.data[:num_tokens_to_copy, :] = old_weight_data[:num_tokens_to_copy, :] + if lm_head.bias is not None: + lm_head.bias.data[:num_tokens_to_copy] = old_bias_data[:num_tokens_to_copy] diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 9d3ebed2540c..0a4bd5bd48b4 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -1,4 +1,5 @@ from functools import partial +import math from typing import Callable, Dict, List from torch import Tensor, nn @@ -36,10 +37,10 @@ def preprocess(self): multiple = self.shard_config.make_vocab_size_divisible_by if self.shard_config.enable_tensor_parallelism: world_size = self.shard_config.tensor_parallel_size - multiple = multiple * world_size + multiple = multiple * world_size // (math.gcd(multiple, world_size)) if vocab_size % multiple != 0: new_vocab_size = vocab_size + multiple - vocab_size % multiple - self.model.resize_token_embeddings(new_vocab_size) + self.resize_token_embeddings(self.model, new_vocab_size) return self.model def module_policy(self): diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 42bf0825b045..fbe03d72b030 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -1,4 +1,5 @@ import warnings +import math from functools import partial from typing import Callable, Dict, List, Union @@ -23,15 +24,14 @@ def config_sanity_check(self): pass def preprocess(self): + vocab_size = self.model.config.vocab_size + multiple = self.shard_config.make_vocab_size_divisible_by if self.shard_config.enable_tensor_parallelism: - # Resize embedding - vocab_size = self.model.config.vocab_size world_size = self.shard_config.tensor_parallel_size - - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) - + multiple = multiple * world_size // (math.gcd(multiple, world_size)) + if vocab_size % multiple != 0: + new_vocab_size = vocab_size + multiple - vocab_size % multiple + self.resize_token_embeddings(self.model, new_vocab_size) return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index 24f9627c269c..e3d78d86be77 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -109,14 +109,14 @@ def date_gen_for_double_heads(): config_for_token_classification.num_labels = 2 # register the following models -model_zoo.register( - name="transformers_gpt", - model_fn=lambda: transformers.GPT2Model(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_gpt2_model, - model_attribute=ModelAttribute(has_control_flow=True), -) +# model_zoo.register( +# name="transformers_gpt", +# model_fn=lambda: transformers.GPT2Model(config), +# data_gen_fn=data_gen, +# output_transform_fn=output_transform_fn, +# loss_fn=loss_fn_for_gpt2_model, +# model_attribute=ModelAttribute(has_control_flow=True), +# ) model_zoo.register( name="transformers_gpt_lm", model_fn=lambda: transformers.GPT2LMHeadModel(config), @@ -125,35 +125,35 @@ def date_gen_for_double_heads(): loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True), ) -model_zoo.register( - name="transformers_gpt_double_heads", - model_fn=lambda: transformers.GPT2DoubleHeadsModel(config), - data_gen_fn=date_gen_for_double_heads, - output_transform_fn=output_transform_fn, - loss_fn=lambda x: x.loss + x.mc_loss, - model_attribute=ModelAttribute(has_control_flow=True), -) -model_zoo.register( - name="transformers_gpt_for_question_answering", - model_fn=lambda: transformers.GPT2ForQuestionAnswering(config), - data_gen_fn=data_gen_for_question_answering, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True), -) -model_zoo.register( - name="transformers_gpt_for_token_classification", - model_fn=lambda: transformers.GPT2ForTokenClassification(config_for_token_classification), - data_gen_fn=data_gen_for_token_classification, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True), -) -model_zoo.register( - name="transformers_gpt_for_sequence_classification", - model_fn=lambda: transformers.GPT2ForSequenceClassification(config_for_token_classification), - data_gen_fn=data_gen_for_sequence_classification, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True), -) +# model_zoo.register( +# name="transformers_gpt_double_heads", +# model_fn=lambda: transformers.GPT2DoubleHeadsModel(config), +# data_gen_fn=date_gen_for_double_heads, +# output_transform_fn=output_transform_fn, +# loss_fn=lambda x: x.loss + x.mc_loss, +# model_attribute=ModelAttribute(has_control_flow=True), +# ) +# model_zoo.register( +# name="transformers_gpt_for_question_answering", +# model_fn=lambda: transformers.GPT2ForQuestionAnswering(config), +# data_gen_fn=data_gen_for_question_answering, +# output_transform_fn=output_transform_fn, +# loss_fn=loss_fn, +# model_attribute=ModelAttribute(has_control_flow=True), +# ) +# model_zoo.register( +# name="transformers_gpt_for_token_classification", +# model_fn=lambda: transformers.GPT2ForTokenClassification(config_for_token_classification), +# data_gen_fn=data_gen_for_token_classification, +# output_transform_fn=output_transform_fn, +# loss_fn=loss_fn, +# model_attribute=ModelAttribute(has_control_flow=True), +# ) +# model_zoo.register( +# name="transformers_gpt_for_sequence_classification", +# model_fn=lambda: transformers.GPT2ForSequenceClassification(config_for_token_classification), +# data_gen_fn=data_gen_for_sequence_classification, +# output_transform_fn=output_transform_fn, +# loss_fn=loss_fn, +# model_attribute=ModelAttribute(has_control_flow=True), +# ) diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 3155420f1cf2..2195726f66d3 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -124,7 +124,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "tp_size": 4, "pp_size": 1, "enable_all_optimization": True, - "use_lazy_init": False, + "use_lazy_init": True, "precision": "fp32", }, {