Commit

fix
flybird11111 committed Apr 10, 2024
1 parent 5464925 commit ea8bcdc
Showing 2 changed files with 9 additions and 5 deletions.
1 change: 0 additions & 1 deletion colossalai/shardformer/layer/embedding.py
@@ -304,7 +304,6 @@ def __init__(
 
         # deal with tensor parallelism
         self.num_embeddings_per_partition = divide(self.num_embeddings, tensor_parallel_size)
-        self.num_embeddings = self.num_embeddings_per_partition
         self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
         self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
13 changes: 9 additions & 4 deletions colossalai/zero/gemini/gemini_ddp.py
@@ -518,12 +518,15 @@ def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
             p_mapping = param_to_save_data
         for name, param in self.name2param.items():
             if param is not None:
-                origin_shape = self.params_info["name2shape"][name]
                 if is_ddp_ignored(param):
                     # deal with ddp ignored parameters
                     destination[prefix + name] = param if keep_vars else param.detach()
                 else:
-                    destination[prefix + name] = p_mapping[param][: origin_shape[0], ...]
+                    if self.params_info is not None:
+                        origin_shape = self.params_info["name2shape"][name]
+                        destination[prefix + name] = p_mapping[param][: origin_shape[0], ...]
+                    else:
+                        destination[prefix + name] = p_mapping[param]
         del p_mapping
         del param_to_save_data
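
The hunk above makes the truncation to the recorded original shape conditional: a parameter is sliced back to its entry in params_info["name2shape"] only when that metadata exists, and saved unchanged otherwise. A minimal sketch of the guard with plain PyTorch tensors (the helper name and sample metadata are illustrative, not part of the ColossalAI API):

import torch

def trim_to_origin(name, gathered, params_info):
    # Mirror of the guarded branch above: only slice back to the recorded
    # first dimension when original-shape metadata is available.
    if params_info is not None:
        origin_shape = params_info["name2shape"][name]
        return gathered[: origin_shape[0], ...]
    return gathered

# An embedding weight padded from 10 to 12 rows, e.g. so the vocabulary divides
# evenly across tensor-parallel ranks.
padded = torch.zeros(12, 4)
info = {"name2shape": {"embed.weight": (10, 4)}}

assert trim_to_origin("embed.weight", padded, info).shape == (10, 4)  # trimmed back
assert trim_to_origin("embed.weight", padded, None).shape == (12, 4)  # saved as-is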

@@ -891,8 +894,10 @@ def state_dict_shard(
                         gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0))
                     gathered_param = gathered_param_buffer.pop(param_to_save)
 
-                    origin_shape = self.params_info["name2shape"][name]
-                    gathered_param = gathered_param[: origin_shape[0], ...]
+                    if self.params_info is not None:
+                        origin_shape = self.params_info["name2shape"][name]
+                        gathered_param = gathered_param[: origin_shape[0], ...]
+
                     block, block_size = sharder.append_param(prefix + name, gathered_param)
             if block is not None:
                 yield block, block_size
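
The second hunk applies the same guard while streaming state-dict shards. A short illustration of the failure mode it avoids when self.params_info was never populated (plain Python, no GeminiDDP involved; the key name is made up):

params_info = None  # assume no original shapes were recorded for this model

try:
    # Old, unguarded path: subscripting None raises before anything is saved.
    origin_shape = params_info["name2shape"]["embed.weight"]
except TypeError as exc:
    print(f"unguarded access fails: {exc}")

# Guarded path from this commit: skip the trim and append the gathered
# parameter unchanged.
if params_info is not None:
    print("trimming gathered parameter to its original shape")
else:
    print("no shape metadata recorded; appending gathered parameter unchanged")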
