
Commit b020672
Merge pull request #6 from hpcaitech/feature/pipeline
deal with no pipeline
dujiangsu authored Feb 18, 2022
2 parents 8d897a6 + abbe93f
Showing 2 changed files with 42 additions and 15 deletions.
energon/nn/wrapper/pipeline_wrapper.py (12 additions, 1 deletion)
@@ -70,7 +70,18 @@ def _init_tensor_meta(self):
        send_tensor_meta(output)
        send_forward(output)

-    def run(self):
+    def run(self):
+        if gpc.is_initialized(ParallelMode.PIPELINE):
+            self.pipeline_run()
+        else:
+            self.no_pipeline_run()
+
+
+    def no_pipeline_run(self):
+        output = self.model(**self.sample)
+        return output
+
+    def pipeline_run(self):
        if gpc.is_first_rank(ParallelMode.PIPELINE):
            output = self.model(**self.sample)
            send_forward(output)
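The new run() entry point works whether or not pipeline parallelism was initialized. A hypothetical call sketch (the wrapper class name and constructor below are assumptions for illustration; only run(), pipeline_run(), no_pipeline_run(), and gpc.is_initialized(ParallelMode.PIPELINE) come from this diff):

    # Hypothetical usage; the class name and constructor arguments are assumed, not from this commit.
    wrapper = PipelineWrapper(model=model, sample=sample)
    wrapper.run()  # dispatches to pipeline_run() when PIPELINE is initialized, otherwise to no_pipeline_run()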
model/gpt/model/pipeline_gpt1d.py (30 additions, 14 deletions)

@@ -3,7 +3,7 @@
import inspect

from .gpt1d import GPTTransformerLayer1D
-from .pipeline_gpt_wrapper import PipelineSharedModuleWrapper
+# from .pipeline_gpt_wrapper import PipelineSharedModuleWrapper

from energon.nn import HiddenParallelEmbedding, HiddenParallelGPTLMHead1D, VocabParallelEmbedding, VocabParallelGPTLMHead1D
from energon.context import ParallelMode
@@ -30,8 +30,8 @@ def partition_uniform(num_items, pipeline_parallel_size, num_chunks):
        "Layer length should be divided by the number of chunks, otherwise parameter method is recomended"

    logger = get_dist_logger()
-    parts = [[] for _ in range(pipeline_parallel_size)]
-    partition_items = num_items // num_chunks
+    parts = [[] for _ in range(pipeline_parallel_size)]  # 4
+    partition_items = num_items // num_chunks  # 96 // 2
    for idx in range(num_chunks):
        base_idx = idx * partition_items
        chunk_size = partition_items // pipeline_parallel_size
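The inline comments above (4 stages, 96 // 2 items per chunk) describe a uniform split of transformer layers across pipeline ranks. A minimal, self-contained sketch of that split follows; it is an illustration only, since the rest of partition_uniform is cut off in this hunk, and it reproduces the expected result noted later in this diff:

    def uniform_parts(num_items, pipeline_parallel_size, num_chunks):
        # Illustration: split num_items evenly across pipeline stages, chunk by chunk.
        parts = [[] for _ in range(pipeline_parallel_size)]
        partition_items = num_items // num_chunks
        chunk_size = partition_items // pipeline_parallel_size
        for idx in range(num_chunks):
            base_idx = idx * partition_items
            for p in range(pipeline_parallel_size):
                start = base_idx + p * chunk_size
                parts[p].append((start, start + chunk_size))
        return parts

    # 96 layers, 4 stages, 1 chunk -> [[(0, 24)], [(24, 48)], [(48, 72)], [(72, 96)]]
    print(uniform_parts(96, 4, 1))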
@@ -214,23 +214,36 @@ def _filter_kwargs(func, kwargs):

def _build_generic_gpt_pipeline_1d(module_cls, num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
    logger = get_dist_logger()
-    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
-    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+
+    pipeline_size = 0
+    pipeline_rank = 0
    rank = gpc.get_global_rank()
-    wrapper = PipelineSharedModuleWrapper([0, pipeline_size - 1])
-    parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank]
+
+    if gpc.is_initialized(ParallelMode.PIPELINE):
+        pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
+        pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+    else:
+        pipeline_size = 1
+        pipeline_rank = 0
+
+    parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank]  # [[(0, 24)], [(24, 48)], [(48, 72)], [(72, 96)]]
    models = []
+
    for start, end in parts:
        kwargs['num_layers'] = end - start
        kwargs['first'] = start == 0
        kwargs['last'] = end == num_layers
+
        logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers')
+
        chunk = module_cls(**_filter_kwargs(module_cls.__init__, kwargs)).to(device)
-        if start == 0:
-            wrapper.register_module(chunk.embedding.word_embeddings)
-        elif end == num_layers:
-            wrapper.register_module(chunk.head)
        models.append(chunk)
+
+        # if start == 0:
+        #     wrapper.register_module(chunk.embedding.word_embeddings)
+        # elif end == num_layers:
+        #     wrapper.register_module(chunk.head)
+
    if len(models) == 1:
        model = models[0]
    else:
@@ -239,7 +252,8 @@ def _build_generic_gpt_pipeline_1d(module_cls, num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
    numel = 0
    for _, param in model.named_parameters(recurse=True):
        numel += param.numel()
-    logger.info(f'Rank{rank}/{gpc.get_local_rank(ParallelMode.PIPELINE)} model size = {numel * 2 / 1e9} GB')
+    logger.info(f'Rank{rank}/{pipeline_rank} model size = {numel * 2 / 1e9} GB')
+
    return model


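The logged size uses numel * 2 / 1e9, i.e. it assumes 2 bytes per parameter (half-precision weights). A small helper sketch of the same estimate; the helper name is ours, not part of the commit:

    import torch

    def model_size_gb(model: torch.nn.Module, bytes_per_param: int = 2) -> float:
        # 2 bytes/param matches the fp16 assumption in the logging line above
        return sum(p.numel() for p in model.parameters()) * bytes_per_param / 1e9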
@@ -248,9 +262,11 @@ def _build_gpt_pipeline_1d(num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
    return _build_generic_gpt_pipeline_1d(model, num_layers, num_chunks, device, **kwargs)


# def _build_gpt_pipeline_hybrid(num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
#     return _build_generic_gpt_pipeline_1d(PipelineGPTHybrid, num_layers, num_chunks, device, **kwargs)

+# def _build_gpt_pipeline_1d(num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
+#     if gpc.is_initialized(ParallelMode.PIPELINE):
+#     else:
+#         model = PipelineGPT1D(kwargs, first=True,last=True)

def GPT2_small_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False):
    cfg = dict(hidden_size=768, num_attention_heads=12, checkpoint=checkpoint,
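A hypothetical call sketch using the signature visible above (the body of GPT2_small_pipeline_1D is cut off in this view, and building the model requires an initialized energon distributed context):

    import torch

    # Arguments are the defaults shown in the signature above.
    model = GPT2_small_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False)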
