This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Commit

Merge pull request #12 from hpcaitech/feature/triton
make the distributed program a single entrance
dujiangsu authored Mar 17, 2022
2 parents c0078ab + 0fdb3bf commit a8c687a
Showing 10 changed files with 451 additions and 165 deletions.
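Before the per-file diffs, a minimal self-contained sketch of the "single entrance" pattern this commit introduces: rank 0 acts as the driver, forks the remaining ranks itself with torch.multiprocessing, and all ranks meet in one torch.distributed.rpc group, so the user starts a single Python process instead of a multi-process launcher. Only plain PyTorch RPC calls are used here; the energon-specific helpers (launch_from_multiprocess, RPCWorker) appear in the diff below, and the worker count and port are illustrative assumptions.

import os
import torch.multiprocessing as mp
import torch.distributed.rpc as rpc

WORKER_NAME = "wok{}"  # same naming scheme the diff uses

def _worker(rank: int, world_size: int, host: str, port: int):
    # Each forked process joins the shared RPC group and then blocks in
    # shutdown() until the driver tears the group down.
    os.environ['MASTER_ADDR'] = host
    os.environ['MASTER_PORT'] = str(port)
    rpc.init_rpc(WORKER_NAME.format(rank), rank=rank, world_size=world_size)
    rpc.shutdown()

def main(world_size: int = 2, host: str = 'localhost', port: int = 29500):
    # Single entrance: this one script forks ranks 1..world_size-1 itself.
    processes = []
    for rank in range(1, world_size):
        p = mp.Process(target=_worker, args=(rank, world_size, host, port))
        p.start()
        processes.append(p)

    # The driver joins the same group as rank 0 and drives the workers.
    os.environ['MASTER_ADDR'] = host
    os.environ['MASTER_PORT'] = str(port)
    rpc.init_rpc(WORKER_NAME.format(0), rank=0, world_size=world_size)
    # ... issue rpc.remote() / rpc.rpc_sync() work to the workers here ...
    rpc.shutdown()
    for p in processes:
        p.join()

if __name__ == '__main__':
    main()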
2 changes: 1 addition & 1 deletion energon/engine/__init__.py
@@ -1 +1 @@
from .engine import InferenceEngine
from .engine import InferenceEngine, launch_rpc
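With the change above, the package re-exports the RPC launcher next to the engine, so a serving script needs a single import line; a one-line sketch, assuming only the re-export shown in this hunk:

# both the engine and the new RPC launcher now come from the same module
from energon.engine import InferenceEngine, launch_rpc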
226 changes: 100 additions & 126 deletions energon/engine/engine.py
@@ -1,25 +1,73 @@
import os
import time
import torch
from torch.nn import Module

import torch.multiprocessing as mp
from functools import partial
# pytorch rpc
import torch.distributed.rpc as rpc
from .rpc_utils import remote_cls_method, sync_cls_method, async_cls_method
from .rpc_worker import RPCWorker

# depend on colossalai
from energon.core import global_context as gpc
from energon.context import ParallelMode
from energon.initialize import launch_from_torch
from energon.initialize import launch_from_torch, launch_from_multiprocess

from energon.utils import ensure_directory_exists
from energon.logging import get_dist_logger
from energon.nn import PipelineCommWrapper


def process_func(tp_size: int = 1,
pp_size:int = 1,
backend: str = 'nccl',
seed: int = 1024,
verbose: bool = True,
rank: int = 0,
local_rank: int = 0,
world_size:int = 1,
host: str = 'localhost',
port: int = 29500):

os.environ['MASTER_ADDR'] = host
os.environ['MASTER_PORT'] = f'{port}'

launch_from_multiprocess(tp_size, pp_size, backend, seed, verbose, rank, local_rank, world_size, host, port)
WORKER_NAME = "wok{}"
rpc.init_rpc(WORKER_NAME.format(rank), rank=rank, world_size=world_size)
rpc.shutdown()

def launch_rpc(tp_size: int = 1,
pp_size:int = 1,
backend: str = 'nccl',
seed: int = 1024,
verbose: bool = True,
host: str = 'localhost',
port: int = 29500):

world_size = pp_size * tp_size

processes = []
for rank in range(world_size-1):
p = mp.Process(target=process_func, args=(tp_size, pp_size, backend, seed, verbose, rank+1, rank+1, world_size, 'localhost', 29500))
p.start()
processes.append(p)

return processes




class InferenceEngine(Module):
def __init__(self,
model_class,
model_config,
samples,
pp_init_size: int = -1,
max_batch_size: int = 1,
tp_init_size: int = -1,
pp_init_size: int = -1,
host: str = 'localhost',
port: int = 29500,
dtype=None,
checkpoint=None
):
@@ -31,138 +79,64 @@ def __init__(self,
"""
super().__init__()

self._model_class = model_class
self._model_config = model_config
self._samples = samples
self._pp_size = pp_init_size
self._tp_size = tp_init_size
self._dtype = dtype
self._checkpoint = checkpoint
self._model = None
self.model_class = model_class
self.model_config = model_config
self.dtype = dtype
self.checkpoint = checkpoint
self.max_batch_size = max_batch_size

# for gpc
self.rank = 0
self.global_world_size = pp_init_size * tp_init_size
self.host = host
self.port = port
self.processes = None
self.tp_size = tp_init_size
self.pp_size = pp_init_size

# for TP
self.rrefs = []

self._init_dist()
self._set_sample_device()
self._init_model()


if self._checkpoint:
# self._save_parameter()
self._load_parameter()
# for rpc
self.WORKER_NAME = "wok{}"

self._init_dist_rpc()
self._init_model()

def _init_dist(self):
launch_from_torch(tp_size = self._tp_size, pp_size = self._pp_size)

def _set_sample_device(self):
for k, v in self._samples.items():
if v is not None:
self._samples[k] = v.cuda()
def _init_dist_rpc(self):
r'''
Based on global_context, init the rpc connection.
'''
self.processes = launch_rpc(tp_size = self.tp_size, pp_size = self.pp_size, backend = 'nccl', seed = 1024, verbose = True, host = self.host, port = self.port)
os.environ['MASTER_ADDR'] = self.host
os.environ['MASTER_PORT'] = f'{self.port}'
launch_from_multiprocess(tp_size = self.tp_size, pp_size = self.pp_size, rank = self.rank, local_rank = self.rank, world_size = self.global_world_size, host = self.host, port = self.port)
rpc.init_rpc(self.WORKER_NAME.format(0), rank=0, world_size=self.global_world_size)


def _init_model(self):
"""
TODO(dujiangsu) support other dtype
"""
if self._dtype == torch.half:
model = self._model_class(**self._model_config).cuda().half()
else:
model = self._model_class(**self._model_config).cuda()
model.eval()
self._model = PipelineCommWrapper(model = model, sample = self._samples, dtype=self._dtype)

def _reinit_dist(self):
gpc.destroy_vice_groups()
config = dict(parallel = dict(pipeline=dict(size=self._pp_size),tensor=dict(size=self._tp_size, mode='1d')))
gpc.load_config(config)
gpc.init_parallel_groups()

logger = get_dist_logger()
logger.info(f'Switch is triggered and Distributed environment is re-initialized, '
f'pipeline parallel size: {gpc.pipeline_parallel_size}, '
f'tensor parallel size: {gpc.tensor_parallel_size}', ranks=[0])

def _reload_model(self):
del self._model
self._init_model()

def _get_ranks_name(self):
# tensor parallel
tp_local_rank = 0
if gpc.is_initialized(ParallelMode.TENSOR):
tp_local_rank = gpc.get_local_rank(ParallelMode.TENSOR)

# pipeline parallel
pp_local_rank = 0
if gpc.is_initialized(ParallelMode.PIPELINE):
pp_local_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

ranks_name = f'tp{tp_local_rank}-pp{pp_local_rank}'
return ranks_name

def dtype_convert(self):
"""
TODO(dujiangsu) support other dtype
"""
if self._dtype == torch.half:
self._model.half()
elif self._dtype == torch.float:
self._model.float()
for i in range(self.global_world_size):
print(f'[INFO] rank{self.rank} calls rank{i} to init.')
ob_info = rpc.get_worker_info(self.WORKER_NAME.format(i))
self.rrefs.append(rpc.remote(ob_info, RPCWorker, args=(self.model_class, self.model_config, self.dtype, self.checkpoint, self.max_batch_size)))


def _save_parameter(self):
"""
save checkpoint.
"""
ensure_directory_exists(self._checkpoint)

ranks_name = self._get_ranks_name()
ckpt_filename = f'{ranks_name}.pt'

checkpoint_path = os.path.join(self._checkpoint, ckpt_filename)
# print(self._model)
torch.save(self._model.state_dict(), checkpoint_path)
def run(self, inputs):

res_rref = 0
output = None
for rref in self.rrefs:
output = remote_cls_method(RPCWorker.run, rref, inputs)

return output


def _repartition_weights(self):
"""
TODO: The method can repartition weights among all devices based on new tp/pp strategy through communication.
"""

def _load_parameter(self):
"""
TODO(dujiangsu) use json file to describe the distributed checkpoint. Like the strategy of Megatron.
TODO(dujiangsu) based on the current tp/pp configuration, the func can re-partition the existing ckpt automatically.
use self.repartition_weights() to avoid communication between host and device.
"""
ensure_directory_exists(self._checkpoint)
ranks_name = self._get_ranks_name()
ckpt_filename = f'{ranks_name}.pt'
checkpoint_path = os.path.join(self._checkpoint, ckpt_filename)

self._model.load_state_dict(torch.load(checkpoint_path))
def clear(self):
rpc.shutdown()

def apply_new_parallel_strategy(self):
"""
TODO: Switch between different tp/pp.
"""
# decide new parallel strategy, re-create communication group.
# repartition models and communicate weights.
for p in self.processes:
p.join()

self.repartition_weights()

def switch(self, pp_size, tp_size):
"""
TP/PP switch trigger, triggered from remote.
"""
self._pp_size = pp_size
self._tp_size = tp_size
self._reinit_dist()
self._reload_model()

def run(self):
output = None
with torch.inference_mode():
output = self._model.run()
# if gpc.is_last_rank(ParallelMode.PIPELINE):
return output



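For orientation, a hedged driver-side sketch of how the reworked engine could be used end to end. The constructor keywords are read off the diff above (model_class, model_config, max_batch_size, tp_init_size, pp_init_size, host, port, dtype, checkpoint); ToyModel, its config dict, and the random input are illustrative assumptions, not part of the commit.

import torch
from torch import nn
from energon.engine import InferenceEngine

class ToyModel(nn.Module):
    # hypothetical stand-in for whatever model class the caller provides
    def __init__(self, hidden: int = 8):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, x):
        return self.proj(x)

engine = InferenceEngine(model_class=ToyModel,
                         model_config=dict(hidden=8),
                         max_batch_size=1,
                         tp_init_size=1,
                         pp_init_size=1,
                         host='localhost',
                         port=29500,
                         dtype=None,
                         checkpoint=None)

# run() forwards the inputs to every RPCWorker rref via remote_cls_method and
# returns the last worker's output; clear() shuts the RPC group down and joins
# the spawned worker processes.
output = engine.run(torch.randn(1, 8))
engine.clear()

Because the constructor itself forks the worker ranks through launch_rpc and joins the RPC group as rank 0, this script is the only process the user has to start.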