This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Merge pull request #10 from hpcaitech/feature/activation_reuse
Feature/activation reuse
dujiangsu authored Mar 7, 2022
2 parents 9d70023 + 7bcd026 commit 333a842
Showing 3 changed files with 49 additions and 42 deletions.
69 changes: 35 additions & 34 deletions energon/nn/wrapper/pipeline_wrapper.py
@@ -50,25 +50,25 @@ def _build_samples(self):
                 self.sample[param] = self.input_tensors
 
     def _init_tensor_meta(self):
-
-        if gpc.is_first_rank(ParallelMode.PIPELINE):
-            # print(*self.sample)
-            # print(type(self.sample[0]))
-            output = self.model(**self.sample)
-            send_tensor_meta(output)
-            send_forward(output)
-        elif gpc.is_last_rank(ParallelMode.PIPELINE):
-            self.recv_tensor_shape = recv_tensor_meta(self.recv_tensor_shape)
-            self.input_tensors = recv_forward(self.recv_tensor_shape, dtype=self.dtype)  # only a tensor now
-            self._build_samples()
-            output = self.model(**self.sample)
-        else:
-            self.recv_tensor_shape = recv_tensor_meta(self.recv_tensor_shape)
-            self.input_tensors = recv_forward(self.recv_tensor_shape, dtype=self.dtype)  # only a tensor now
-            self._build_samples()
-            output = self.model(**self.sample)
-            send_tensor_meta(output)
-            send_forward(output)
+        with torch.inference_mode():
+            if gpc.is_first_rank(ParallelMode.PIPELINE):
+                # print(*self.sample)
+                # print(type(self.sample[0]))
+                output = self.model(**self.sample)
+                send_tensor_meta(output)
+                send_forward(output)
+            elif gpc.is_last_rank(ParallelMode.PIPELINE):
+                self.recv_tensor_shape = recv_tensor_meta(self.recv_tensor_shape)
+                self.input_tensors = recv_forward(self.recv_tensor_shape, dtype=self.dtype)  # only a tensor now
+                self._build_samples()
+                output = self.model(**self.sample)
+            else:
+                self.recv_tensor_shape = recv_tensor_meta(self.recv_tensor_shape)
+                self.input_tensors = recv_forward(self.recv_tensor_shape, dtype=self.dtype)  # only a tensor now
+                self._build_samples()
+                output = self.model(**self.sample)
+                send_tensor_meta(output)
+                send_forward(output)
 
     def run(self):
         if gpc.is_initialized(ParallelMode.PIPELINE):
@@ -82,21 +82,22 @@ def no_pipeline_run(self):
         return output
 
     def pipeline_run(self):
-        if gpc.is_first_rank(ParallelMode.PIPELINE):
-            output = self.model(**self.sample)
-            send_forward(output)
-            return None
-        elif gpc.is_last_rank(ParallelMode.PIPELINE):
-            self.input_tensors = recv_forward(self.recv_tensor_shape, dtype=self.dtype)  # only a tensor now
-            self._build_samples()
-            output = self.model(**self.sample)
-            return output
-        else:
-            self.input_tensors = recv_forward(self.recv_tensor_shape, dtype=self.dtype)  # only a tensor now
-            self._build_samples()
-            output = self.model(**self.sample)
-            send_forward(output)
-            return None
+        with torch.inference_mode():
+            if gpc.is_first_rank(ParallelMode.PIPELINE):
+                output = self.model(**self.sample)
+                send_forward(output)
+                return None
+            elif gpc.is_last_rank(ParallelMode.PIPELINE):
+                self.input_tensors = recv_forward(self.recv_tensor_shape, dtype=self.dtype)  # only a tensor now
+                self._build_samples()
+                output = self.model(**self.sample)
+                return output
+            else:
+                self.input_tensors = recv_forward(self.recv_tensor_shape, dtype=self.dtype)  # only a tensor now
+                self._build_samples()
+                output = self.model(**self.sample)
+                send_forward(output)
+                return None
 
     # def _init_group(self):
     #     world_size = gpc.get_world_size(ParallelMode.GLOBAL)
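For context on the change above: both pipeline code paths now run under torch.inference_mode(). Unlike torch.no_grad(), inference_mode() skips all autograd bookkeeping (graph construction, activations saved for backward, even version counters), so memory that would otherwise be held for a backward pass can be freed and reused immediately, which is presumably the "activation reuse" of the branch name. A minimal stand-alone sketch of the behavior, using stock PyTorch and a generic Linear stand-in rather than the repo's wrapper:

# Minimal sketch of what inference_mode() buys (generic module, not the
# repo's wrapper class). Assumes only stock PyTorch.
import torch

model = torch.nn.Linear(512, 512)
x = torch.randn(1, 512)

with torch.inference_mode():
    y = model(x)

print(y.requires_grad)   # False: no autograd graph or saved activations
# y is an "inference tensor": feeding it into autograd-recording code later
# raises an error, so clone() it outside the block if gradients are needed.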
19 changes: 11 additions & 8 deletions model/gpt/evaluate.py
@@ -25,17 +25,20 @@ def main():
 
     dtype=torch.float
     if args.fp16:
+        # print("FP16")
         dtype=torch.half
 
     config = {'num_chunks':1, 'checkpoint':False, 'dtype':dtype, 'embed_split_hidden':False}
 
 
-    input_ids = torch.randint(1, 10, (1, 2048), dtype=torch.int64)
-    attention_mask = torch.randint(0, 1, (1, 1, 2048), dtype=torch.int64)
+    input_ids = torch.randint(1, 10, (1, 512), dtype=torch.int64)
+    attention_mask = torch.randint(0, 1, (1, 1, 512), dtype=torch.int64)
     hidden_states = None
     sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask)
 
-    engine = InferenceEngine(GPT3_pipeline_1D, config, sample, pp_init_size = args.pipe_para_size, tp_init_size = args.tensor_para_size, dtype = torch.half)
+    # print(MODEL_CLASSES[args.model_name])
+
+    engine = InferenceEngine(MODEL_CLASSES[args.model_name], config, sample, pp_init_size = args.pipe_para_size, tp_init_size = args.tensor_para_size, dtype = torch.half)
 
 
@@ -79,14 +79,14 @@ def main():
     latency_elapsed = timer('latency-time').elapsed()
 
     logger.info(f'Throughput, '
-                f'Pipeline Rank/ Tensor Rank: {pp}/{gpc.get_world_size(ParallelMode.PARALLEL_1D)},'
-                f'Time: {itr/evaluate_elapsed}')
+                f'Pipeline Rank/ Tensor Rank: {args.pipe_para_size}/{gpc.get_world_size(ParallelMode.PARALLEL_1D)},'
+                f'Time: {args.iteration/evaluate_elapsed}')
     logger.info(f'Latency, '
-                f'Pipeline Rank/ Tensor Rank: {pp}/{gpc.get_world_size(ParallelMode.PARALLEL_1D)},'
-                f'Time: {latency_elapsed/itr}')
+                f'Pipeline Rank/ Tensor Rank: {args.pipe_para_size}/{gpc.get_world_size(ParallelMode.PARALLEL_1D)},'
+                f'Time: {latency_elapsed/args.iteration}')
 
     logger.info(f'max memory allocated, '
-                f'Pipeline Rank/ Tensor Rank: {pp}/{gpc.get_world_size(ParallelMode.PARALLEL_1D)},'
+                f'Pipeline Rank/ Tensor Rank: {args.pipe_para_size}/{gpc.get_world_size(ParallelMode.PARALLEL_1D)},'
                 f'memory: {torch.cuda.max_memory_allocated()/1e9} GB')
 
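Two things change in evaluate.py: the hard-coded GPT3_pipeline_1D is replaced by a MODEL_CLASSES registry lookup keyed by --model_name, and the log lines now read the parallel degree and iteration count from the parsed arguments (args.pipe_para_size, args.iteration) rather than the bare names pp and itr. For readers reproducing the measurements without Energon's timer utility, a rough stand-alone sketch of the same throughput/latency/memory accounting (assumes CUDA; engine_forward is a hypothetical callable standing in for one engine forward pass):

# Rough stand-alone sketch of the same accounting; not Energon's timer API.
import time
import torch

def benchmark(engine_forward, iterations=10):
    engine_forward()              # warm-up run (excluded from timing)
    torch.cuda.synchronize()      # make sure queued kernels are not timed
    start = time.time()
    for _ in range(iterations):
        engine_forward()
    torch.cuda.synchronize()
    elapsed = time.time() - start
    print(f'Throughput: {iterations / elapsed:.2f} iter/s, '
          f'Latency: {elapsed / iterations * 1e3:.1f} ms/iter, '
          f'Max memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB')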
3 changes: 3 additions & 0 deletions model/gpt/run.sh
@@ -0,0 +1,3 @@
+#!/bin/bash
+
+torchrun --nproc_per_node=4 evaluate.py --fp16 --model_name=gpt2_exlarge --tensor_para_size=2 --pipe_para_size=2
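Note on the launch line: nproc_per_node=4 matches tensor_para_size × pipe_para_size = 2 × 2. If either degree is changed, the per-node process count should presumably be kept equal to their product (on a single node), so that every pipeline stage has a full set of tensor-parallel ranks.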
