This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

make the distributed program a single entrance #12

Merged 1 commit on Mar 17, 2022
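With this change the driver process becomes the single entrance to the distributed program: InferenceEngine runs as rank 0 itself, spawns the remaining tp_size * pp_size - 1 worker processes through torch.multiprocessing, and connects them over torch.distributed.rpc, so inference no longer has to be started with a per-rank launcher. The sketch below is a guess at the intended call pattern based on this diff, not code from the PR; MyModel, its config, and the inputs dict are placeholders, and the exact keyword list (for example whether samples is still accepted) should be read from the new __init__ signature further down.

# Hypothetical single-entrance usage after this PR (placeholders marked below).
import torch
from energon.engine import InferenceEngine

class MyModel(torch.nn.Module):                       # placeholder; stands in for a real model class
    def __init__(self, hidden_size=64):
        super().__init__()
        self.proj = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states):
        return self.proj(hidden_states)

engine = InferenceEngine(model_class=MyModel,
                         model_config=dict(hidden_size=64),   # placeholder kwargs for MyModel
                         max_batch_size=1,
                         tp_init_size=2,                      # tensor parallel degree
                         pp_init_size=2,                      # pipeline parallel degree
                         host='localhost',
                         port=29500,
                         dtype=torch.half,
                         checkpoint=None)

inputs = dict(hidden_states=torch.randn(1, 64).half().cuda())  # placeholder input batch
output = engine.run(inputs)   # forwarded to every RPCWorker over RPC
engine.clear()                # shuts down RPC; the diff also joins the spawned worker processes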
energon/engine/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -1 +1 @@
-from .engine import InferenceEngine
+from .engine import InferenceEngine, launch_rpc
energon/engine/engine.py (226 changes: 100 additions & 126 deletions)
@@ -1,25 +1,73 @@
import os
import time
import torch
from torch.nn import Module

import torch.multiprocessing as mp
from functools import partial
# pytorch rpc
import torch.distributed.rpc as rpc
from .rpc_utils import remote_cls_method, sync_cls_method, async_cls_method
from .rpc_worker import RPCWorker

# depend on colossalai
from energon.core import global_context as gpc
from energon.context import ParallelMode
from energon.initialize import launch_from_torch
from energon.initialize import launch_from_torch, launch_from_multiprocess

from energon.utils import ensure_directory_exists
from energon.logging import get_dist_logger
from energon.nn import PipelineCommWrapper


def process_func(tp_size: int = 1,
pp_size:int = 1,
backend: str = 'nccl',
seed: int = 1024,
verbose: bool = True,
rank: int = 0,
local_rank: int = 0,
world_size:int = 1,
host: str = 'localhost',
port: int = 29500):

os.environ['MASTER_ADDR'] = host
os.environ['MASTER_PORT'] = f'{port}'

launch_from_multiprocess(tp_size, pp_size, backend, seed, verbose, rank, local_rank, world_size, host, port)
WORKER_NAME = "wok{}"
rpc.init_rpc(WORKER_NAME.format(rank), rank=rank, world_size=world_size)
rpc.shutdown()

def launch_rpc(tp_size: int = 1,
pp_size:int = 1,
backend: str = 'nccl',
seed: int = 1024,
verbose: bool = True,
host: str = 'localhost',
port: int = 29500):

world_size = pp_size * tp_size

processes = []
for rank in range(world_size-1):
p = mp.Process(target=process_func, args=(tp_size, pp_size, backend, seed, verbose, rank+1, rank+1, world_size, 'localhost', 29500))
p.start()
processes.append(p)

return processes




class InferenceEngine(Module):
def __init__(self,
model_class,
model_config,
samples,
pp_init_size: int = -1,
max_batch_size: int = 1,
tp_init_size: int = -1,
pp_init_size: int = -1,
host: str = 'localhost',
port: int = 29500,
dtype=None,
checkpoint=None
):
@@ -31,138 +79,64 @@ def __init__(self,
"""
super().__init__()

self._model_class = model_class
self._model_config = model_config
self._samples = samples
self._pp_size = pp_init_size
self._tp_size = tp_init_size
self._dtype = dtype
self._checkpoint = checkpoint
self._model = None
self.model_class = model_class
self.model_config = model_config
self.dtype = dtype
self.checkpoint = checkpoint
self.max_batch_size = max_batch_size

# for gpc
self.rank = 0
self.global_world_size = pp_init_size * tp_init_size
self.host = host
self.port = port
self.processes = None
self.tp_size = tp_init_size
self.pp_size = pp_init_size

# for TP
self.rrefs = []

self._init_dist()
self._set_sample_device()
self._init_model()


if self._checkpoint:
# self._save_parameter()
self._load_parameter()
# for rpc
self.WORKER_NAME = "wok{}"

self._init_dist_rpc()
self._init_model()

def _init_dist(self):
launch_from_torch(tp_size = self._tp_size, pp_size = self._pp_size)

def _set_sample_device(self):
for k, v in self._samples.items():
if v is not None:
self._samples[k] = v.cuda()
def _init_dist_rpc(self):
r'''
Based on global_context, init the rpc connection.
'''
self.processes = launch_rpc(tp_size = self.tp_size, pp_size = self.pp_size, backend = 'nccl', seed = 1024, verbose = True, host = self.host, port = self.port)
os.environ['MASTER_ADDR'] = self.host
os.environ['MASTER_PORT'] = f'{self.port}'
launch_from_multiprocess(tp_size = self.tp_size, pp_size = self.pp_size, rank = self.rank, local_rank = self.rank, world_size = self.global_world_size, host = self.host, port = self.port)
rpc.init_rpc(self.WORKER_NAME.format(0), rank=0, world_size=self.global_world_size)


def _init_model(self):
"""
TODO(dujiangsu) support other dtype
"""
if self._dtype == torch.half:
model = self._model_class(**self._model_config).cuda().half()
else:
model = self._model_class(**self._model_config).cuda()
model.eval()
self._model = PipelineCommWrapper(model = model, sample = self._samples, dtype=self._dtype)

def _reinit_dist(self):
gpc.destroy_vice_groups()
config = dict(parallel = dict(pipeline=dict(size=self._pp_size),tensor=dict(size=self._tp_size, mode='1d')))
gpc.load_config(config)
gpc.init_parallel_groups()

logger = get_dist_logger()
logger.info(f'Switch is triggered and Distributed environment is re-initialized, '
f'pipeline parallel size: {gpc.pipeline_parallel_size}, '
f'tensor parallel size: {gpc.tensor_parallel_size}', ranks=[0])

def _reload_model(self):
del self._model
self._init_model()

def _get_ranks_name(self):
# tensor parallel
tp_local_rank = 0
if gpc.is_initialized(ParallelMode.TENSOR):
tp_local_rank = gpc.get_local_rank(ParallelMode.TENSOR)

# pipeline parallel
pp_local_rank = 0
if gpc.is_initialized(ParallelMode.PIPELINE):
pp_local_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

ranks_name = f'tp{tp_local_rank}-pp{pp_local_rank}'
return ranks_name

def dtype_convert(self):
"""
TODO(dujiangsu) support other dtype
"""
if self._dtype == torch.half:
self._model.half()
elif self._dtype == torch.float:
self._model.float()
for i in range(self.global_world_size):
print(f'[INFO] rank{self.rank} calls rank{i} to init.')
ob_info = rpc.get_worker_info(self.WORKER_NAME.format(i))
self.rrefs.append(rpc.remote(ob_info, RPCWorker, args=(self.model_class, self.model_config, self.dtype, self.checkpoint, self.max_batch_size)))


def _save_parameter(self):
"""
save checkpoint.
"""
ensure_directory_exists(self._checkpoint)

ranks_name = self._get_ranks_name()
ckpt_filename = f'{ranks_name}.pt'

checkpoint_path = os.path.join(self._checkpoint, ckpt_filename)
# print(self._model)
torch.save(self._model.state_dict(), checkpoint_path)
def run(self, inputs):

res_rref = 0
output = None
for rref in self.rrefs:
output = remote_cls_method(RPCWorker.run, rref, inputs)

return output


def _repartition_weights(self):
"""
TODO: The method can repartition weights among all devices based on new tp/pp strategy through communication.
"""

def _load_parameter(self):
"""
TODO(dujiangsu) use json file to describe the distributed checkpoint. Like the strategy of Megatron.
TODO(dujiangsu) based on the current tp/pp configuration, the func can re-partition the existing ckpt automatically.
use self.repartition_weights() to avoid communication between host and device.
"""
ensure_directory_exists(self._checkpoint)
ranks_name = self._get_ranks_name()
ckpt_filename = f'{ranks_name}.pt'
checkpoint_path = os.path.join(self._checkpoint, ckpt_filename)

self._model.load_state_dict(torch.load(checkpoint_path))
def clear(self):
rpc.shutdown()

def apply_new_parallel_strategy(self):
"""
TODO: Switch between different tp/pp.
"""
# decide new parallel strategy, re-create communication group.
# repartition models and communicate weights.
for p in self.processes:
p.join()

self.repartition_weights()

def switch(self, pp_size, tp_size):
"""
TP/PP switch trigger, triggered from remote.
"""
self._pp_size = pp_size
self._tp_size = tp_size
self._reinit_dist()
self._reload_model()

def run(self):
output = None
with torch.inference_mode():
output = self._model.run()
# if gpc.is_last_rank(ParallelMode.PIPELINE):
return output



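For reference, the new engine is built on the torch RPC remote-object pattern: rank 0 calls rpc.remote to construct an RPCWorker on every rank, keeps the returned RRefs, and drives inference by invoking RPCWorker.run through those RRefs (via the remote_cls_method helper imported from .rpc_utils). The toy below is not energon code; it is a self-contained, single-process sketch of the same torch.distributed.rpc calls, with Echo standing in for RPCWorker and rref.rpc_sync() standing in for the helper.

# Single-process illustration of the remote-object pattern used by InferenceEngine:
# rpc.remote() constructs an object on a worker and returns an RRef,
# and rref.rpc_sync() proxies method calls to that remote instance.
import os
import torch.distributed.rpc as rpc

class Echo:                                   # stand-in for RPCWorker
    def __init__(self, tag):
        self.tag = tag

    def run(self, inputs):                    # stand-in for RPCWorker.run
        return f'{self.tag}: {inputs}'

if __name__ == '__main__':
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29501')        # any free port
    rpc.init_rpc('wok0', rank=0, world_size=1)           # same worker naming scheme as the diff

    rref = rpc.remote('wok0', Echo, args=('worker0',))   # like the RPCWorker rrefs built in _init_model
    print(rref.rpc_sync().run('hello'))                  # prints "worker0: hello"

    rpc.shutdown()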