This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Merge pull request #45 from hpcaitech/feature/variable_len
correctness for tp only
MaruyamaAya authored May 6, 2022
2 parents 238c811 + c2dc270 commit 8e65da0
Showing 3 changed files with 49 additions and 30 deletions.
energon/engine/gpt_pipeline_wrapper.py (45 changes: 34 additions & 11 deletions)

@@ -23,31 +23,29 @@ def __init__(self,
         # TODO (dujiangsu): to make sample capability for different types. Iteration, Tensor, and others.
         self.model = model
         self.dtype = dtype
-
-        # get the hidden_size
-        input_ids = torch.randint(1, 10, (max_batch_size, 512), dtype=torch.int64).cuda()
-        attention_mask = torch.randint(0, 1, (max_batch_size, 1, 512), dtype=torch.int64).cuda()
-        hidden_states = None
-        self.sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask)

         self.tensor_dim = 0
         self.hidden_size = 0
         self.max_batch_size = max_batch_size

         if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
-            self._init_tensor_meta()
+            input_ids = torch.randint(1, 10, (max_batch_size, 512), dtype=torch.int64).cuda()
+            attention_mask = torch.randint(0, 1, (max_batch_size, 1, 512), dtype=torch.int64).cuda()
+            hidden_states = None
+            sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask)
+            self._init_tensor_meta(sample)

         self.pipe_msg_queue = PipelineMsgDict()
         self.lock = threading.Lock()
         self.key = CircleInt()

-    def _init_tensor_meta(self):
+    def _init_tensor_meta(self, sample):

         with torch.inference_mode():
             recv_tensor_shape = None
             if gpc.is_first_rank(ParallelMode.PIPELINE):
-                output = self.model(hidden_states = None, input_ids=self.sample['input_ids'], attention_mask = self.sample['attention_mask']) # ([32, 512, 1600])
+                output = self.model(hidden_states = None, input_ids=sample['input_ids'], attention_mask = sample['attention_mask']) # ([32, 512, 1600])
                 send_tensor_meta(output)
                 send_forward(output)
                 self.tensor_dim = output.dim()
@@ -62,10 +60,35 @@ def _init_tensor_meta(self):
                 input_tensor = recv_forward(recv_tensor_shape, dtype=self.dtype) # only a tensor now
                 self.tensor_dim = input_tensor.dim()
                 self.hidden_size = input_tensor.size()[-1]
-                output = self.model(hidden_states = None, input_ids=input_tensor, attention_mask = self.sample['attention_mask'])
+                output = self.model(hidden_states = None, input_ids=input_tensor, attention_mask = sample['attention_mask'])
                 send_tensor_meta(output)
                 send_forward(output)

+    def run(self, key, inputs):
+        if gpc.is_initialized(ParallelMode.PIPELINE):
+            return self.run_with_pp(key, inputs)
+        else:
+            return self.run_without_pp(key, inputs)
+
+    def run_without_pp(self, key, inputs):
+        pipe_meta = None
+        self.pipe_msg_queue.enqueue(key, inputs, pipe_meta)
+
+        self.lock.acquire()
+
+        cur_key = self.key.val
+        sample, pipe_meta = self.pipe_msg_queue.top(cur_key)
+        self.key.addOne()
+        output = self.model(hidden_states = None,
+                            input_ids = sample['input_ids'],
+                            attention_mask = sample['attention_mask'])
+        self.lock.release()
+
+        return output, cur_key
+
     '''
     hidden_size : ([32, 512, 1600])
     For different model type, fill_meta_tensor is different
@@ -77,7 +100,7 @@ def fill_meta_tensor(self, inputs, pipe_meta):
         pipe_meta.get_meta_tensor()[3] = self.hidden_size
         pipe_meta.update_meta()

-    def run(self, key, inputs):
+    def run_with_pp(self, key, inputs):
         pipe_meta = PipelineMeta(self.tensor_dim, self.max_batch_size)
         self.fill_meta_tensor(inputs, pipe_meta)
         self.pipe_msg_queue.enqueue(key, inputs, pipe_meta)
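A note on the new run_without_pp path: a request is first pushed into the message queue under its key, and a lock plus a circulating counter then serialize execution, so each forward pass serves exactly one queued request and results can be matched back to callers by key. Below is a minimal self-contained sketch of that pattern; ToyWrapper, _ToyMsgQueue, and _ToyCircleInt are illustrative stand-ins, not energon's actual classes.

import threading

class _ToyMsgQueue:
    # Stand-in for PipelineMsgDict: maps a request key to (inputs, meta).
    def __init__(self):
        self._store = {}

    def enqueue(self, key, inputs, meta):
        self._store[key] = (inputs, meta)

    def top(self, key):
        return self._store.pop(key)

class _ToyCircleInt:
    # Stand-in for CircleInt: hands out the next key to serve.
    def __init__(self):
        self.val = 0

    def addOne(self):
        self.val += 1

class ToyWrapper:
    def __init__(self, model):
        self.model = model
        self.pipe_msg_queue = _ToyMsgQueue()
        self.lock = threading.Lock()
        self.key = _ToyCircleInt()

    def run_without_pp(self, key, inputs):
        # Enqueue first, then serve strictly in key order under the lock,
        # so concurrent callers cannot interleave inside the forward pass.
        self.pipe_msg_queue.enqueue(key, inputs, None)
        with self.lock:
            cur_key = self.key.val
            sample, _ = self.pipe_msg_queue.top(cur_key)
            self.key.addOne()
            output = self.model(sample)
        return output, cur_key

wrapper = ToyWrapper(model=lambda s: s['x'] * 2)
print(wrapper.run_without_pp(0, {'x': 21}))  # -> (42, 0)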
energon/engine/rpc_worker.py (27 changes: 11 additions & 16 deletions)

@@ -48,8 +48,7 @@ def __init__(self,
         torch.cuda.set_device(f'cuda:{gpc.get_local_rank(ParallelMode.GLOBAL)}')
         self._init_self()

-        if gpc.is_initialized(ParallelMode.PIPELINE):
-            self.return_dict = ReturnDict()
+        self.return_dict = ReturnDict()

     def _init_self(self):
         print("[INFO] init model in rank {}".format(self.rank))
@@ -62,8 +61,7 @@ def _init_self(self):
         self.model.eval()

-        if gpc.is_initialized(ParallelMode.PIPELINE):
-            self.model = self.pipe_wrapper(model = self.model, max_batch_size = self.max_batch_size, dtype=self.dtype)
+        self.model = self.pipe_wrapper(model = self.model, max_batch_size = self.max_batch_size, dtype=self.dtype)

     def run(self, key, inputs):
         # print("key: {}".format(key), flush=True)
@@ -72,16 +70,13 @@
         for k, v in inputs.items():
             if v is not None:
                 inputs[k] = v.cuda() #non_blocking=True

-        if gpc.is_initialized(ParallelMode.PIPELINE):
-            if gpc.is_last_rank(ParallelMode.PIPELINE):
-                output, cur_key = self.model.run(key, inputs)
-                self.return_dict.enqueue(cur_key, output.cpu())
-                return self.return_dict.top(key)
-            else:
-                self.model.run(key, inputs)
-                return None
-        else:
-            output = self.model(**inputs)
-            return output
+        if (gpc.is_initialized(ParallelMode.PIPELINE)) and (not gpc.is_last_rank(ParallelMode.PIPELINE)):
+            self.model.run(key, inputs)
+            return None
+        else:
+            output, cur_key = self.model.run(key, inputs)
+            self.return_dict.enqueue(cur_key, output.cpu())
+            return self.return_dict.top(key)

         return None
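With every rank now owning a ReturnDict, the run path collapses to a single branch: non-final pipeline stages forward the work and return None, while the final stage (or the only rank, when pipeline parallelism is off) enqueues its output and serves it back by key. As a rough mental model of ReturnDict — an assumption for illustration, not energon's actual implementation — it behaves like a keyed blocking mailbox:

import threading

class ToyReturnDict:
    # Illustrative stand-in: enqueue() publishes a result under its key;
    # top() blocks until that key's result is available, then removes it.
    def __init__(self):
        self._results = {}
        self._cond = threading.Condition()

    def enqueue(self, key, value):
        with self._cond:
            self._results[key] = value
            self._cond.notify_all()

    def top(self, key):
        with self._cond:
            self._cond.wait_for(lambda: key in self._results)
            return self._results.pop(key)

rd = ToyReturnDict()
rd.enqueue(0, "output-for-request-0")
print(rd.top(0))  # -> output-for-request-0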
examples/gpt/gpt_config.py (7 changes: 4 additions & 3 deletions)

@@ -1,5 +1,5 @@
 from gpt import gpt2_small, gpt2_medium, gpt2_large, gpt2_xl, gpt2_8B, gpt3
-from gpt_server import launch_engine
+from gpt_batch_server import launch_engine

 model_class = gpt2_large
 model_type = "gpt"
@@ -10,6 +10,7 @@
 port = 29400
 half = True
 server_host = "127.0.0.1"
-server_port = 8010
+server_port = 8020
 log_level = "info"
-backend = "nccl"
+backend = "nccl"
+tokenizer_path = "/home/lcdjs/hf_gpt2"
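The example config is a plain Python module whose top-level names act as settings (note the import now points at gpt_batch_server, the port moves from 8010 to 8020, and a tokenizer_path is added). A hypothetical sketch of how such a module can be consumed — the loader shown here is illustrative, not energon's actual mechanism:

import importlib

# Assumes examples/gpt is on sys.path so gpt_config.py imports as a module.
config = importlib.import_module("gpt_config")
print(config.model_type, config.server_host, config.server_port)  # gpt 127.0.0.1 8020
print(config.tokenizer_path)                                      # /home/lcdjs/hf_gpt2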
