Commit 9a5832a: fix

flybird11111 committed Mar 17, 2024
1 parent a980e70
Showing 6 changed files with 97 additions and 51 deletions.
2 changes: 1 addition & 1 deletion colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1160,7 +1160,7 @@ def configure(
custom_policy=self.custom_policy,
)

-self.set_resized_embedding_to_optimizer(model, optimizer, param_info)
+# self.set_resized_embedding_to_optimizer(model, optimizer, param_info)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.zero_stage == 0:
if self.precision in ["fp16", "bf16"]:
45 changes: 45 additions & 0 deletions colossalai/shardformer/policies/base_policy.py
@@ -5,9 +5,11 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
+import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Module
+from colossalai.lazy.lazy_init import LazyInitContext

from colossalai.pipeline.stage_manager import PipelineStageManager

@@ -243,3 +245,46 @@ def get_stage_index(
stage_indices.append([start_idx, end_idx])

return stage_indices[0] if num_model_chunks == 1 else stage_indices


+    def resize_token_embeddings(self, model, new_num_tokens):
+        input_embeddings = model.get_input_embeddings()
+        if input_embeddings is not None:
+            self._resize_token_embeddings(model, input_embeddings, new_num_tokens)
+        output_embeddings = model.get_output_embeddings()
+        if output_embeddings is not None:
+            self._resize_lm_head(model, output_embeddings, new_num_tokens)
+
+    def _resize_token_embeddings(self, model, embedding, new_num_tokens):
+        LazyInitContext.materialize(embedding)
+        old_num_tokens = embedding.num_embeddings
+        input_embedding_dim = embedding.embedding_dim
+        old_weight_data = embedding.weight.data
+        embedding.num_embeddings = new_num_tokens
+        if embedding.padding_idx is not None and embedding.padding_idx > new_num_tokens:
+            embedding.padding_idx = embedding.padding_idx - (old_num_tokens - new_num_tokens)
+        factory_kwargs = {'device': embedding.weight.device, 'dtype': embedding.weight.dtype}
+        embedding.weight.data = torch.empty((new_num_tokens, input_embedding_dim), **factory_kwargs)
+        embedding.reset_parameters()
+        model._init_weights(embedding)
+        # Copy token embeddings from the previous weights
+        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
+        embedding.weight.data[:num_tokens_to_copy, :] = old_weight_data[:num_tokens_to_copy, :]
+
+    def _resize_lm_head(self, model, lm_head, new_num_tokens):
+        LazyInitContext.materialize(lm_head)
+        old_num_tokens, lm_head_dim = lm_head.weight.size()
+        old_weight_data = lm_head.weight.data
+        old_bias_data = lm_head.bias.data if lm_head.bias is not None else None
+        lm_head.out_features = new_num_tokens
+        factory_kwargs = {'device': lm_head.weight.device, 'dtype': lm_head.weight.dtype}
+        lm_head.weight.data = torch.empty((new_num_tokens, lm_head_dim), **factory_kwargs)
+        if lm_head.bias is not None:
+            lm_head.bias.data = torch.empty(new_num_tokens, **factory_kwargs)
+        lm_head.reset_parameters()
+        model._init_weights(lm_head)
+        # Copy lm_head weights from the previous weights
+        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
+        lm_head.weight.data[:num_tokens_to_copy, :] = old_weight_data[:num_tokens_to_copy, :]
+        if lm_head.bias is not None:
+            lm_head.bias.data[:num_tokens_to_copy] = old_bias_data[:num_tokens_to_copy]
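
For orientation, the pattern these new helpers implement (allocate a larger weight tensor, re-run weight initialization, then copy the overlapping rows back) can be sketched on a plain nn.Embedding outside of Shardformer. The snippet below is a minimal illustration under that assumption, not code from this commit; resize_embedding is a hypothetical helper name.

import torch
import torch.nn as nn

def resize_embedding(embedding: nn.Embedding, new_num_tokens: int) -> nn.Embedding:
    # Build a freshly initialized embedding of the new size on the same device/dtype.
    old_num_tokens, dim = embedding.weight.shape
    resized = nn.Embedding(
        new_num_tokens, dim,
        device=embedding.weight.device, dtype=embedding.weight.dtype,
    )
    # Reuse the rows that already exist; only the new tail keeps its fresh init.
    num_to_copy = min(old_num_tokens, new_num_tokens)
    with torch.no_grad():
        resized.weight[:num_to_copy] = embedding.weight[:num_to_copy]
    return resized

emb = nn.Embedding(50257, 768)      # GPT-2 vocabulary size
emb = resize_embedding(emb, 50304)  # padded up to a multiple of 64
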
5 changes: 3 additions & 2 deletions colossalai/shardformer/policies/gpt2.py
@@ -1,4 +1,5 @@
from functools import partial
+import math
from typing import Callable, Dict, List

from torch import Tensor, nn
@@ -36,10 +37,10 @@ def preprocess(self):
        multiple = self.shard_config.make_vocab_size_divisible_by
        if self.shard_config.enable_tensor_parallelism:
            world_size = self.shard_config.tensor_parallel_size
-            multiple = multiple * world_size
+            multiple = multiple * world_size // (math.gcd(multiple, world_size))
        if vocab_size % multiple != 0:
            new_vocab_size = vocab_size + multiple - vocab_size % multiple
-            self.model.resize_token_embeddings(new_vocab_size)
+            self.resize_token_embeddings(self.model, new_vocab_size)
        return self.model

def module_policy(self):
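
The revised divisor is effectively lcm(make_vocab_size_divisible_by, tensor_parallel_size): multiplying by world_size and then dividing by their gcd avoids over-padding when the two share factors. A small standalone check of that arithmetic (padded_vocab_size is a hypothetical helper; the concrete values are illustrative only):

import math

def padded_vocab_size(vocab_size: int, multiple: int, world_size: int) -> int:
    # Same rounding rule as the new preprocess(): pad up to lcm(multiple, world_size).
    multiple = multiple * world_size // math.gcd(multiple, world_size)
    if vocab_size % multiple != 0:
        vocab_size += multiple - vocab_size % multiple
    return vocab_size

# GPT-2 vocabulary (50257) with multiple=64 and world_size=4:
print(padded_vocab_size(50257, 64, 4))  # 50304, divisible by both 64 and 4
# Without the gcd, the divisor would be 64 * 4 = 256 and the vocab would grow to 50432.
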
14 changes: 7 additions & 7 deletions colossalai/shardformer/policies/llama.py
@@ -1,4 +1,5 @@
import warnings
+import math
from functools import partial
from typing import Callable, Dict, List, Union

@@ -23,15 +24,14 @@ def config_sanity_check(self):
pass

    def preprocess(self):
+        vocab_size = self.model.config.vocab_size
+        multiple = self.shard_config.make_vocab_size_divisible_by
        if self.shard_config.enable_tensor_parallelism:
-            # Resize embedding
-            vocab_size = self.model.config.vocab_size
            world_size = self.shard_config.tensor_parallel_size
-
-            if vocab_size % world_size != 0:
-                new_vocab_size = vocab_size + world_size - vocab_size % world_size
-                self.model.resize_token_embeddings(new_vocab_size)
-
+            multiple = multiple * world_size // (math.gcd(multiple, world_size))
+        if vocab_size % multiple != 0:
+            new_vocab_size = vocab_size + multiple - vocab_size % multiple
+            self.resize_token_embeddings(self.model, new_vocab_size)
        return self.model

def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
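
Previously the LLaMA policy padded the vocabulary only to a multiple of the tensor-parallel world size; it now uses the same lcm-based rule as the GPT-2 policy, so make_vocab_size_divisible_by is respected as well. A rough before/after comparison (32000 is LLaMA's vocabulary size; multiple=64 and world_size=3 are example values chosen to make the difference visible):

import math

vocab_size, multiple, world_size = 32000, 64, 3

# Old behaviour: round up to a multiple of world_size only.
old = vocab_size + (-vocab_size) % world_size  # 32001, divisible by 3 but not by 64

# New behaviour: round up to a multiple of lcm(multiple, world_size).
divisor = multiple * world_size // math.gcd(multiple, world_size)  # lcm(64, 3) = 192
new = vocab_size + (-vocab_size) % divisor     # 32064, divisible by both 64 and 3
print(old, new)
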
80 changes: 40 additions & 40 deletions tests/kit/model_zoo/transformers/gpt.py
@@ -109,14 +109,14 @@ def date_gen_for_double_heads():
config_for_token_classification.num_labels = 2

# register the following models
-model_zoo.register(
-    name="transformers_gpt",
-    model_fn=lambda: transformers.GPT2Model(config),
-    data_gen_fn=data_gen,
-    output_transform_fn=output_transform_fn,
-    loss_fn=loss_fn_for_gpt2_model,
-    model_attribute=ModelAttribute(has_control_flow=True),
-)
+# model_zoo.register(
+#     name="transformers_gpt",
+#     model_fn=lambda: transformers.GPT2Model(config),
+#     data_gen_fn=data_gen,
+#     output_transform_fn=output_transform_fn,
+#     loss_fn=loss_fn_for_gpt2_model,
+#     model_attribute=ModelAttribute(has_control_flow=True),
+# )
model_zoo.register(
name="transformers_gpt_lm",
model_fn=lambda: transformers.GPT2LMHeadModel(config),
@@ -125,35 +125,35 @@ def date_gen_for_double_heads():
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
-model_zoo.register(
-    name="transformers_gpt_double_heads",
-    model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
-    data_gen_fn=date_gen_for_double_heads,
-    output_transform_fn=output_transform_fn,
-    loss_fn=lambda x: x.loss + x.mc_loss,
-    model_attribute=ModelAttribute(has_control_flow=True),
-)
-model_zoo.register(
-    name="transformers_gpt_for_question_answering",
-    model_fn=lambda: transformers.GPT2ForQuestionAnswering(config),
-    data_gen_fn=data_gen_for_question_answering,
-    output_transform_fn=output_transform_fn,
-    loss_fn=loss_fn,
-    model_attribute=ModelAttribute(has_control_flow=True),
-)
-model_zoo.register(
-    name="transformers_gpt_for_token_classification",
-    model_fn=lambda: transformers.GPT2ForTokenClassification(config_for_token_classification),
-    data_gen_fn=data_gen_for_token_classification,
-    output_transform_fn=output_transform_fn,
-    loss_fn=loss_fn,
-    model_attribute=ModelAttribute(has_control_flow=True),
-)
-model_zoo.register(
-    name="transformers_gpt_for_sequence_classification",
-    model_fn=lambda: transformers.GPT2ForSequenceClassification(config_for_token_classification),
-    data_gen_fn=data_gen_for_sequence_classification,
-    output_transform_fn=output_transform_fn,
-    loss_fn=loss_fn,
-    model_attribute=ModelAttribute(has_control_flow=True),
-)
+# model_zoo.register(
+#     name="transformers_gpt_double_heads",
+#     model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
+#     data_gen_fn=date_gen_for_double_heads,
+#     output_transform_fn=output_transform_fn,
+#     loss_fn=lambda x: x.loss + x.mc_loss,
+#     model_attribute=ModelAttribute(has_control_flow=True),
+# )
+# model_zoo.register(
+#     name="transformers_gpt_for_question_answering",
+#     model_fn=lambda: transformers.GPT2ForQuestionAnswering(config),
+#     data_gen_fn=data_gen_for_question_answering,
+#     output_transform_fn=output_transform_fn,
+#     loss_fn=loss_fn,
+#     model_attribute=ModelAttribute(has_control_flow=True),
+# )
+# model_zoo.register(
+#     name="transformers_gpt_for_token_classification",
+#     model_fn=lambda: transformers.GPT2ForTokenClassification(config_for_token_classification),
+#     data_gen_fn=data_gen_for_token_classification,
+#     output_transform_fn=output_transform_fn,
+#     loss_fn=loss_fn,
+#     model_attribute=ModelAttribute(has_control_flow=True),
+# )
+# model_zoo.register(
+#     name="transformers_gpt_for_sequence_classification",
+#     model_fn=lambda: transformers.GPT2ForSequenceClassification(config_for_token_classification),
+#     data_gen_fn=data_gen_for_sequence_classification,
+#     output_transform_fn=output_transform_fn,
+#     loss_fn=loss_fn,
+#     model_attribute=ModelAttribute(has_control_flow=True),
+# )
2 changes: 1 addition & 1 deletion tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -124,7 +124,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 4,
"pp_size": 1,
"enable_all_optimization": True,
-"use_lazy_init": False,
+"use_lazy_init": True,
"precision": "fp32",
},
{
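
Flipping this config to use_lazy_init: True exercises the new path: the resize helpers call LazyInitContext.materialize(...) on the embedding and lm_head before touching their weights, so lazily initialized GPT-2 models can now be resized during preprocessing. A rough sketch of that interaction, with simplified call sites that are assumptions rather than code from this test:

from colossalai.lazy.lazy_init import LazyInitContext
import transformers

with LazyInitContext():
    # Parameters are created lazily; no real weight storage is allocated yet.
    model = transformers.GPT2LMHeadModel(transformers.GPT2Config())

# The policy's resize helpers materialize only the modules they touch
# (here the input embedding) before resizing them in place.
LazyInitContext.materialize(model.get_input_embeddings())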
