
Commit

fix
flybird11111 committed Mar 14, 2024
1 parent 653aa06 commit a980e70
Showing 1 changed file with 46 additions and 4 deletions.
50 changes: 46 additions & 4 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -189,7 +189,7 @@ def unwrap(self):
return module


def get_param_info(optim: Optimizer):
def get_param_info(optim: Optimizer, model: torch.nn.Module):
# Get a backup of necessary information of parameters for future use, which includes:
# 1. A complete param_group, with params in the form of param_id
# 2. A mapping from param address (obtained using id(param)) to integer param_id
@@ -199,7 +199,7 @@ def get_param_info(optim: Optimizer):

if optim is None:
return {}
param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}}
param_info = {
"param_groups": [],
"param2id": {},
"id2param": {},
"param2shape": {},
"old_input_embedding_param_id": None,
"old_output_embedding_param_id": None,
}
start_index = 0
for group in optim.param_groups:
packed_group = {k: v for k, v in group.items() if k != "params"}
@@ -215,6 +222,13 @@ def get_param_info(optim: Optimizer):
param_info["param_groups"].append(packed_group)
start_index += len(group["params"])

input_embedding = model.get_input_embeddings()
if input_embedding is not None:
param_info["old_input_embedding_param_id"] = id(input_embedding.weight)
output_embedding = model.get_output_embeddings()
if output_embedding is not None:
param_info["old_output_embedding_param_id"] = id(output_embedding.weight)

return param_info
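
The two new dictionary entries snapshot the identity (id()) of the input and output embedding weights as they exist before the model is wrapped by the plugin and its embeddings possibly replaced. A minimal sketch of the problem being solved, assuming a Hugging Face-style model purely for illustration (none of these names come from the commit):

# Illustration only: resizing token embeddings allocates a fresh weight tensor,
# so an optimizer built beforehand still references the old one. The ids recorded
# in param_info are what later let the plugin find that stale reference.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

old_id = id(model.get_input_embeddings().weight)
model.resize_token_embeddings(50304)  # hypothetical padded vocabulary size
new_id = id(model.get_input_embeddings().weight)
assert old_id != new_id  # the optimizer's param_groups still hold the tensor with old_id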


@@ -1067,7 +1081,7 @@ def __init__(
overlap_communication=overlap_communication,
cpu_offload=cpu_offload,
partition_grad=(self.zero_stage == 2),
forced_dtype=PRECISION_TORCH_TYPE[precision],
# forced_dtype=PRECISION_TORCH_TYPE[precision],
)

self.max_norm = max_norm
@@ -1076,6 +1090,32 @@ def __del__(self):
"""Destroy the process groups in ProcessGroupMesh"""
self.pg_mesh.destroy_mesh_process_groups()

def set_resized_embedding_to_optimizer(self, model, optimizer, param_info):
old_input_embedding_param_id = param_info["old_input_embedding_param_id"]
if old_input_embedding_param_id is not None:
for param_group in optimizer.param_groups:
group_params = param_group["params"]
new_params = []
for param in group_params:
if id(param) == old_input_embedding_param_id:
new_input_embeddings = model.module.get_input_embeddings()
new_params.append(new_input_embeddings.weight)
else:
new_params.append(param)
param_group["params"] = new_params
old_output_embedding_param_id = param_info["old_output_embedding_param_id"]
if old_output_embedding_param_id is not None:
for param_group in optimizer.param_groups:
group_params = param_group["params"]
new_params = []
for param in group_params:
if id(param) == old_output_embedding_param_id:
new_output_embeddings = model.module.get_output_embeddings()
new_params.append(new_output_embeddings.weight)
else:
new_params.append(param)
param_group["params"] = new_params

@property
def enable_pipeline_parallelism(self) -> bool:
return self.pp_size > 1
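
The method above walks every optimizer param group and substitutes the embedding weights that were replaced when the model was wrapped, matching stale entries by the ids recorded in get_param_info. A standalone sketch of the same swap-by-id technique (the helper name is illustrative, not part of the plugin):

import torch

def swap_param_by_id(optimizer: torch.optim.Optimizer, old_param_id: int, new_param: torch.nn.Parameter) -> None:
    # Substitute the parameter whose id() matches the snapshot taken before the
    # embedding was resized; all other parameters are kept untouched.
    for group in optimizer.param_groups:
        group["params"] = [new_param if id(p) == old_param_id else p for p in group["params"]]

Matching by id() rather than by value is the point: the resized embedding is a different tensor, usually with a different shape, so the identity of the old object is the only reliable link between the optimizer's stale entry and its replacement.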
@@ -1106,7 +1146,7 @@ def configure(
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
param_info = get_param_info(optimizer)
param_info = get_param_info(optimizer, model)
if not isinstance(model, ModelWrapper):
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
model = HybridParallelModule(
@@ -1119,6 +1159,8 @@ def configure(
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
)

self.set_resized_embedding_to_optimizer(model, optimizer, param_info)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.zero_stage == 0:
if self.precision in ["fp16", "bf16"]:
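Taken together, configure() now snapshots the embedding ids first (get_param_info), wraps the model, which may replace the embeddings, for example when the vocabulary is padded for tensor parallelism, and then re-points the optimizer at the new weights via set_resized_embedding_to_optimizer. A rough usage sketch from the user's side, assuming the standard Booster workflow; the model name, sizes, and hyperparameters are placeholders:

import torch
from transformers import AutoModelForCausalLM
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Assumes distributed state has already been initialized, e.g. via colossalai.launch_from_torch.
model = AutoModelForCausalLM.from_pretrained("gpt2")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # built against the un-wrapped model

plugin = HybridParallelPlugin(tp_size=2, pp_size=1, precision="fp16")
booster = Booster(plugin=plugin)

# boost() runs the plugin's configure(); with this commit, an embedding replaced
# during wrapping is re-registered in the optimizer's param groups.
model, optimizer, *_ = booster.boost(model, optimizer)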
