Commit

fix

fix resize embedding

flybird11111 committed Mar 17, 2024
1 parent a980e70 commit 64379ab
Showing 5 changed files with 59 additions and 50 deletions.
43 changes: 3 additions & 40 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -189,7 +189,7 @@ def unwrap(self):
return module


- def get_param_info(optim: Optimizer, model: torch.nn.Module):
+ def get_param_info(optim: Optimizer):
# Get a backup of necessary information of parameters for future use, which includes:
# 1. A complete param_group, with params in the form of param_id
# 2. A mapping from param address (obtained using id(param)) to integer param_id
@@ -204,8 +204,6 @@ def get_param_info(optim: Optimizer, model: torch.nn.Module):
"param2id": {},
"id2param": {},
"param2shape": {},
"old_input_embedding_param_id": None,
"old_output_embedding_param_id": None,
}
start_index = 0
for group in optim.param_groups:
@@ -222,13 +220,6 @@ def get_param_info(optim: Optimizer, model: torch.nn.Module):
param_info["param_groups"].append(packed_group)
start_index += len(group["params"])

- input_embedding = model.get_input_embeddings()
- if input_embedding is not None:
- param_info["old_input_embedding_param_id"] = id(input_embedding.weight)
- output_embedding = model.get_output_embeddings()
- if output_embedding is not None:
- param_info["old_output_embedding_param_id"] = id(output_embedding.weight)

return param_info
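
For context on this change, here is a minimal sketch (a reconstruction for illustration, not the plugin's exact code) of how a bookkeeping dict with the keys shown above can be built from the optimizer alone, which is consistent with dropping the model argument now that embedding resizing no longer replaces parameter objects.

# Minimal sketch: param bookkeeping derived purely from an Optimizer.
# Only the keys shown in the diff are mirrored; everything else is illustrative.
import torch

def build_param_info(optim: torch.optim.Optimizer) -> dict:
    info = {"param2id": {}, "id2param": {}, "param2shape": {}}
    start_index = 0
    for group in optim.param_groups:
        for i, param in enumerate(group["params"]):
            param_id = start_index + i
            info["param2id"][id(param)] = param_id        # address -> integer id
            info["id2param"][param_id] = id(param)        # integer id -> address
            info["param2shape"][id(param)] = param.shape  # original shape backup
        start_index += len(group["params"])
    return info
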


@@ -1081,7 +1072,7 @@ def __init__(
overlap_communication=overlap_communication,
cpu_offload=cpu_offload,
partition_grad=(self.zero_stage == 2),
- # forced_dtype=PRECISION_TORCH_TYPE[precision],
+ forced_dtype=PRECISION_TORCH_TYPE[precision],
)

self.max_norm = max_norm
@@ -1090,32 +1081,6 @@ def __del__(self):
"""Destroy the process groups in ProcessGroupMesh"""
self.pg_mesh.destroy_mesh_process_groups()

- def set_resized_embedding_to_optimizer(self, model, optimizer, param_info):
- old_input_embedding_param_id = param_info["old_input_embedding_param_id"]
- if old_input_embedding_param_id is not None:
- for param_group in optimizer.param_groups:
- group_params = param_group["params"]
- new_params = []
- for param in group_params:
- if id(param) == old_input_embedding_param_id:
- new_input_embeddings = model.module.get_input_embeddings()
- new_params.append(new_input_embeddings.weight)
- else:
- new_params.append(param)
- param_group["params"] = new_params
- old_output_embedding_param_id = param_info["old_output_embedding_param_id"]
- if old_output_embedding_param_id is not None:
- for param_group in optimizer.param_groups:
- group_params = param_group["params"]
- new_params = []
- for param in group_params:
- if id(param) == old_output_embedding_param_id:
- new_output_embeddings = model.module.get_output_embeddings()
- new_params.append(new_output_embeddings.weight)
- else:
- new_params.append(param)
- param_group["params"] = new_params

@property
def enable_pipeline_parallelism(self) -> bool:
return self.pp_size > 1
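
Why the remapping removed above can be dropped (an inference from this diff, not a claim made by the commit message): the new resize helpers added to base_policy.py overwrite embedding.weight.data in place instead of creating a new parameter, so the parameter objects already registered in the optimizer stay valid. A tiny sketch of that property, with illustrative shapes:

# Mutating .data keeps the Parameter object, and thus optimizer references, intact.
import torch

param = torch.nn.Parameter(torch.zeros(4, 8))
optim = torch.optim.SGD([param], lr=0.1)
param.data = torch.empty(6, 8)  # resize in place, as the new helpers do
assert optim.param_groups[0]["params"][0] is param  # no remapping required
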
@@ -1146,7 +1111,7 @@ def configure(
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
- param_info = get_param_info(optimizer, model)
+ param_info = get_param_info(optimizer)
if not isinstance(model, ModelWrapper):
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
model = HybridParallelModule(
@@ -1159,8 +1124,6 @@
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
)

- self.set_resized_embedding_to_optimizer(model, optimizer, param_info)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.zero_stage == 0:
if self.precision in ["fp16", "bf16"]:
45 changes: 45 additions & 0 deletions colossalai/shardformer/policies/base_policy.py
@@ -5,9 +5,11 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
+ import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Module
+ from colossalai.lazy.lazy_init import LazyInitContext

from colossalai.pipeline.stage_manager import PipelineStageManager

@@ -243,3 +245,46 @@ def get_stage_index(
stage_indices.append([start_idx, end_idx])

return stage_indices[0] if num_model_chunks == 1 else stage_indices


+ def resize_token_embeddings(self, model, new_num_tokens):
+ input_embeddings = self.model.get_input_embeddings()
+ if input_embeddings is not None:
+ self._resize_token_embeddings(model, input_embeddings, new_num_tokens)
+ output_embeddings = self.model.get_output_embeddings()
+ if output_embeddings is not None:
+ self._resize_lm_head(model, output_embeddings, new_num_tokens)

+ def _resize_token_embeddings(self, model, embedding, new_num_tokens):
+ LazyInitContext.materialize(embedding)
+ old_num_tokens = embedding.num_embeddings
+ input_embedding_dim = embedding.embedding_dim
+ old_weight_data = embedding.weight.data
+ embedding.num_embeddings = new_num_tokens
+ if embedding.padding_idx is not None and embedding.padding_idx > new_num_tokens:
+ embedding.padding_idx = embedding.padding_idx - (old_num_tokens - new_num_tokens)
+ factory_kwargs = {'device': embedding.weight.device, 'dtype': embedding.weight.dtype}
+ embedding.weight.data = torch.empty((new_num_tokens, input_embedding_dim), **factory_kwargs)
+ embedding.reset_parameters()
+ model._init_weights(embedding)
+ # Copy token embeddings from the previous weights
+ num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
+ embedding.weight.data[:num_tokens_to_copy, :] = old_weight_data[:num_tokens_to_copy, :]

+ def _resize_lm_head(self, model, lm_head, new_num_tokens):
+ LazyInitContext.materialize(lm_head)
+ old_num_tokens, lm_head_dim = lm_head.weight.size()
+ old_weight_data = lm_head.weight.data
+ old_bias_data = lm_head.bias.data if lm_head.bias is not None else None
+ lm_head.out_features = new_num_tokens
+ factory_kwargs = {'device': lm_head.weight.device, 'dtype': lm_head.weight.dtype}
+ lm_head.weight.data = torch.empty((new_num_tokens, lm_head_dim), **factory_kwargs)
+ if lm_head.bias is not None:
+ lm_head.bias.data = torch.empty(new_num_tokens, **factory_kwargs)
+ lm_head.reset_parameters()
+ model._init_weights(lm_head)
+ # Copy lm_head weights from the previous weights
+ num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
+ lm_head.weight.data[:num_tokens_to_copy, :] = old_weight_data[:num_tokens_to_copy, :]
+ if lm_head.bias is not None:
+ lm_head.bias.data[:num_tokens_to_copy] = old_bias_data[:num_tokens_to_copy]
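
A simplified, standalone illustration of the in-place resize pattern used by the helpers above (the toy function below is an assumption for demonstration; the real helpers additionally handle padding_idx, the lm_head bias, lazy materialization, and model._init_weights):

# Resize a plain nn.Embedding in place, preserving the original rows.
import torch
import torch.nn as nn

def resize_embedding_inplace(embedding: nn.Embedding, new_num_tokens: int) -> None:
    old_num_tokens, dim = embedding.weight.shape
    old_weight = embedding.weight.data
    embedding.num_embeddings = new_num_tokens
    embedding.weight.data = torch.empty(
        new_num_tokens, dim, device=old_weight.device, dtype=old_weight.dtype
    )
    nn.init.normal_(embedding.weight.data)       # stand-in for model._init_weights
    num_to_copy = min(old_num_tokens, new_num_tokens)
    embedding.weight.data[:num_to_copy] = old_weight[:num_to_copy]

emb = nn.Embedding(50257, 768)                   # illustrative GPT-2-like vocab
resize_embedding_inplace(emb, 50304)             # pad up to a chosen multiple
assert emb.weight.shape == (50304, 768)
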
5 changes: 3 additions & 2 deletions colossalai/shardformer/policies/gpt2.py
@@ -1,4 +1,5 @@
from functools import partial
+ import math
from typing import Callable, Dict, List

from torch import Tensor, nn
@@ -36,10 +37,10 @@ def preprocess(self):
multiple = self.shard_config.make_vocab_size_divisible_by
if self.shard_config.enable_tensor_parallelism:
world_size = self.shard_config.tensor_parallel_size
- multiple = multiple * world_size
+ multiple = multiple * world_size // (math.gcd(multiple, world_size))
if vocab_size % multiple != 0:
new_vocab_size = vocab_size + multiple - vocab_size % multiple
- self.model.resize_token_embeddings(new_vocab_size)
+ self.resize_token_embeddings(self.model, new_vocab_size)
return self.model

def module_policy(self):
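
For reference, the padding arithmetic introduced above rounds the vocabulary up to the least common multiple of make_vocab_size_divisible_by and the tensor-parallel world size; a worked example with illustrative numbers:

# multiple * world_size // gcd(multiple, world_size) == lcm(multiple, world_size)
import math

vocab_size = 50257   # GPT-2 vocab size
multiple = 8         # make_vocab_size_divisible_by (illustrative)
world_size = 4       # tensor_parallel_size (illustrative)

multiple = multiple * world_size // math.gcd(multiple, world_size)  # lcm = 8
if vocab_size % multiple != 0:
    vocab_size = vocab_size + multiple - vocab_size % multiple
print(vocab_size)    # 50264, divisible by both 8 and 4
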
14 changes: 7 additions & 7 deletions colossalai/shardformer/policies/llama.py
@@ -1,4 +1,5 @@
import warnings
+ import math
from functools import partial
from typing import Callable, Dict, List, Union

@@ -23,15 +24,14 @@ def config_sanity_check(self):
pass

def preprocess(self):
+ vocab_size = self.model.config.vocab_size
+ multiple = self.shard_config.make_vocab_size_divisible_by
if self.shard_config.enable_tensor_parallelism:
- # Resize embedding
- vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size

- if vocab_size % world_size != 0:
- new_vocab_size = vocab_size + world_size - vocab_size % world_size
- self.model.resize_token_embeddings(new_vocab_size)

+ multiple = multiple * world_size // (math.gcd(multiple, world_size))
+ if vocab_size % multiple != 0:
+ new_vocab_size = vocab_size + multiple - vocab_size % multiple
+ self.resize_token_embeddings(self.model, new_vocab_size)
return self.model

def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
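
The behavioral change for the LLaMA policy, as read from this hunk: the old code padded the vocabulary only to a multiple of the tensor-parallel world size, while the new code also honors make_vocab_size_divisible_by through the lcm. A comparison with hypothetical values:

# Old vs. new vocabulary padding for the LLaMA policy (illustrative numbers).
import math

vocab_size, world_size, divisible_by = 32000, 3, 64

# old behavior: pad to world_size only
old_padded = vocab_size + (world_size - vocab_size % world_size) % world_size

# new behavior: pad to lcm(make_vocab_size_divisible_by, world_size)
multiple = divisible_by * world_size // math.gcd(divisible_by, world_size)
new_padded = vocab_size + (multiple - vocab_size % multiple) % multiple

print(old_padded, new_padded)  # 32001 vs. 32064
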
2 changes: 1 addition & 1 deletion tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -43,7 +43,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if test_config["precision"] == "fp32":
- atol, rtol = 1e-4, 1e-3
+ atol, rtol = 2e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
col_layer_grads = get_grad_tensors_for_check(
