From 55859c00db7ba26f02695fe19e9808428ddd4bc0 Mon Sep 17 00:00:00 2001 From: ver217 Date: Fri, 19 Aug 2022 13:55:31 +0800 Subject: [PATCH] add opt example --- .gitignore | 2 + energonai/engine/engine.py | 10 +- energonai/model/model_factory.py | 162 +++++++++++++++------------ energonai/server/worker_server.py | 18 +-- energonai/utils/checkpointing_opt.py | 65 ++++++----- examples/opt/opt_config.py | 6 +- examples/opt/opt_server.py | 62 +++++----- 7 files changed, 175 insertions(+), 150 deletions(-) diff --git a/.gitignore b/.gitignore index b6e4761..f347387 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,5 @@ dmypy.json # Pyre type checker .pyre/ + +.vscode/ \ No newline at end of file diff --git a/energonai/engine/engine.py b/energonai/engine/engine.py index d0a4aa5..dc22cfe 100644 --- a/energonai/engine/engine.py +++ b/energonai/engine/engine.py @@ -19,7 +19,6 @@ from energonai.initialize import launch_from_multiprocess - logger = get_dist_logger('energonai') @@ -61,7 +60,7 @@ def __init__(self, # for TP, PP self.rrefs = [] self.auto_pp = auto_pp - + # for rpc self.WORKER_NAME = "wok{}" self._init_dist_rpc() @@ -79,9 +78,10 @@ def _init_dist_rpc(self): world_size=self.global_world_size, host=self.host, port=self.port) - rpc_backend_options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=16, rpc_timeout=6000) - # _transports=["uv"] TODO: potentially a bug - + rpc_backend_options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=16, + rpc_timeout=6000 + # _transports=["uv"] TODO: potentially a bug + ) rpc.init_rpc(self.WORKER_NAME.format(0), rank=0, world_size=self.global_world_size, diff --git a/energonai/model/model_factory.py b/energonai/model/model_factory.py index a1fbd77..ea57a53 100644 --- a/energonai/model/model_factory.py +++ b/energonai/model/model_factory.py @@ -15,14 +15,18 @@ from energonai.utils import is_using_pp, get_current_device from energonai.logging import get_dist_logger + def gelu_impl(x): """OpenAI's gelu implementation.""" return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x))) + + def select_top_k(predictions, k=5): - predicted_index = random.choice(predictions[0, -1, :].sort(descending=True)[1][:10]) #.item() + predicted_index = random.choice(predictions[0, -1, :].sort(descending=True)[1][:10]) # .item() return predicted_index + class PipelineModel(nn.Module): def __init__(self, vocab_size: int = 50257, @@ -37,14 +41,14 @@ def __init__(self, padding_idx: int = 0, dtype: dtype = torch.float16, bias: bool = True, - apply_post_layernorm:bool = False, + apply_post_layernorm: bool = False, first: bool = False, last: bool = False, - fused_qkv:bool = True, - checkpoint:str = None, - model_name:str = None, - topk:int = 5, - is_decoder:bool = True) -> None: + fused_qkv: bool = True, + checkpoint: str = None, + model_name: str = None, + topk: int = 5, + is_decoder: bool = True) -> None: super().__init__() self.hidden_size = hidden_size @@ -53,37 +57,37 @@ def __init__(self, self.max_seq_len = max_seq_len self.model_name = model_name self.topk = topk - + if first: self.embed = Embedding1D(hidden_size=hidden_size, - vocab_size=vocab_size, - max_seq_len=max_seq_len, - num_tokentypes = num_tokentypes, - padding_idx=padding_idx, - dtype=dtype) - + vocab_size=vocab_size, + max_seq_len=max_seq_len, + num_tokentypes=num_tokentypes, + padding_idx=padding_idx, + dtype=dtype) + self.blocks = nn.ModuleList() self.pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) if is_using_pp() else 0 for id_ in range(depth): 
self.blocks.add_module(f'{id_ + self.pp_rank * depth}', - Block1D(hidden_size=hidden_size, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - activation=activation, - layernorm_epsilon=layernorm_epsilon, - dtype=dtype, - bias=bias, - apply_post_layernorm=apply_post_layernorm, - max_seq_len=max_seq_len, - fused_qkv=fused_qkv, - is_decoder=is_decoder)) + Block1D(hidden_size=hidden_size, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + activation=activation, + layernorm_epsilon=layernorm_epsilon, + dtype=dtype, + bias=bias, + apply_post_layernorm=apply_post_layernorm, + max_seq_len=max_seq_len, + fused_qkv=fused_qkv, + is_decoder=is_decoder)) if last: - self.norm = LayerNorm1D(normalized_shape=hidden_size, eps=layernorm_epsilon) + self.norm = LayerNorm1D(normalized_shape=hidden_size, eps=layernorm_epsilon) self.head = LMHead1D(hidden_size=hidden_size, vocab_size=vocab_size, bias=False, dtype=dtype) - + def forward(self, hidden_states=None, input_ids=None, attention_mask=None, seq_lens=None): - batch_size = input_ids.shape[0] - + batch_size = input_ids.shape[0] + if self.first: hidden_states = self.embed(input_ids) @@ -94,20 +98,21 @@ def forward(self, hidden_states=None, input_ids=None, attention_mask=None, seq_l attention_mask = (1.0 - attention_mask) * -10000.0 for block in self.blocks: - hidden_states = block(hidden_states, attention_mask) # seq_lens - + hidden_states = block(hidden_states, attention_mask) # seq_lens + if self.last: hidden_states = self.head(self.norm(hidden_states)) hidden_states = select_top_k(hidden_states, k=self.topk) - + return hidden_states + def partition_uniform(num_items, pipeline_parallel_size): logger = get_dist_logger() assert num_items % pipeline_parallel_size == 0, \ "Layer length should be divided by the number of pipeline size, otherwise parameter method is recomended" - parts = [[] for _ in range(pipeline_parallel_size)] + parts = [[] for _ in range(pipeline_parallel_size)] base_idx = 0 chunk_size = num_items // pipeline_parallel_size @@ -122,7 +127,7 @@ def partition_uniform(num_items, pipeline_parallel_size): return parts -def create_pipeline_model(depth:int = 48, +def create_pipeline_model(depth: int = 48, layer_partitions=None, **model_kwargs): logger = get_dist_logger() @@ -150,9 +155,10 @@ def create_pipeline_model(depth:int = 48, numel = 0 for _, param in model.named_parameters(recurse=True): numel += param.numel() - logger.info(f'Rank{rank}/{pipeline_rank} model size = {numel * 2 / 1e9} GB') + logger.info(f'Rank{rank}/{pipeline_rank} model size = {numel * 2 / 1e9} GB!!!') if "checkpoint" in model_kwargs.keys() and "model_name" in model_kwargs.keys(): + start = time.time() assert os.path.exists(model_kwargs["checkpoint"]), "Checkpoint file not found" if model_kwargs["model_name"] == "hf_gpt2": from energonai.utils.checkpointing_hf_gpt2 import load_checkpoint @@ -160,95 +166,105 @@ def create_pipeline_model(depth:int = 48, if model_kwargs["model_name"] == "opt": from energonai.utils.checkpointing_opt import load_checkpoint load_checkpoint(model_kwargs["checkpoint"], model, **model_kwargs) + logger.info(f'Load time: {time.time() - start:.3f} s') return model + def hf_gpt2(**kwargs): - model_kwargs = dict(hidden_size=768, - depth=12, - max_seq_len = 1024, - num_heads=12, - fused_qkv=False, - model_name = "hf_gpt2", - is_decoder = True, + model_kwargs = dict(hidden_size=768, + depth=12, + max_seq_len=1024, + num_heads=12, + fused_qkv=False, + model_name="hf_gpt2", + is_decoder=True, **kwargs) return create_pipeline_model(**model_kwargs) + def 
gpt2_small(**kwargs): - model_kwargs = dict(hidden_size=768, depth=12, num_heads=12, is_decoder = True, **kwargs) + model_kwargs = dict(hidden_size=768, depth=12, num_heads=12, is_decoder=True, **kwargs) return create_pipeline_model(**model_kwargs) + def gpt2_large(**kwargs): - model_kwargs = dict(hidden_size=1536, depth=36, num_heads=12, is_decoder = True, **kwargs) + model_kwargs = dict(hidden_size=1536, depth=36, num_heads=12, is_decoder=True, **kwargs) return create_pipeline_model(**model_kwargs) + def gpt2_8B(**kwargs): - model_kwargs = dict(hidden_size=3072, depth=72, num_heads=24, is_decoder = True, **kwargs) + model_kwargs = dict(hidden_size=3072, depth=72, num_heads=24, is_decoder=True, **kwargs) return create_pipeline_model(**model_kwargs) + def gpt3(**kwargs): - model_kwargs = dict(hidden_size=12288, depth=12, num_heads=96, is_decoder = True, **kwargs) + model_kwargs = dict(hidden_size=12288, depth=12, num_heads=96, is_decoder=True, **kwargs) return create_pipeline_model(**model_kwargs) + def bert_small(**kwargs): - model_kwargs = dict(hidden_size=768, depth=12, num_heads=12, is_decoder = False, **kwargs) + model_kwargs = dict(hidden_size=768, depth=12, num_heads=12, is_decoder=False, **kwargs) return create_pipeline_model(**model_kwargs) def bert_large(**kwargs): - model_kwargs = dict(hidden_size=1024, depth=24, num_heads=16, is_decoder = False, **kwargs) + model_kwargs = dict(hidden_size=1024, depth=24, num_heads=16, is_decoder=False, **kwargs) return create_pipeline_model(**model_kwargs) + def bert_8B(**kwargs): - model_kwargs = dict(hidden_size=3072, depth=72, num_heads=24, is_decoder = False, **kwargs) + model_kwargs = dict(hidden_size=3072, depth=72, num_heads=24, is_decoder=False, **kwargs) return create_pipeline_model(**model_kwargs) def bert_175B(**kwargs): - model_kwargs = dict(hidden_size=12288, depth=96, num_heads=96, is_decoder = False, **kwargs) + model_kwargs = dict(hidden_size=12288, depth=96, num_heads=96, is_decoder=False, **kwargs) return create_pipeline_model(**model_kwargs) + def opt_125M(**kwargs): model_kwargs = dict(vocab_size=50272, - hidden_size=768, - depth=12, + hidden_size=768, + depth=12, max_seq_len=2050, - num_heads=12, - activation=nn.functional.relu, - is_decoder = True, - fused_qkv=False, - model_name = "opt", + num_heads=12, + activation=nn.functional.relu, + is_decoder=True, + fused_qkv=False, + model_name="opt", **kwargs) return create_pipeline_model(**model_kwargs) + def opt_30B(**kwargs): model_kwargs = dict(vocab_size=50272, - hidden_size=7168, - depth=48, + hidden_size=7168, + depth=48, max_seq_len=2050, - num_heads=56, - activation=nn.functional.relu, - is_decoder = True, - fused_qkv=False, - model_name = "opt", + num_heads=56, + activation=nn.functional.relu, + is_decoder=True, + fused_qkv=False, + model_name="opt", **kwargs) return create_pipeline_model(**model_kwargs) + def opt_66B(**kwargs): model_kwargs = dict(vocab_size=50272, - hidden_size=9216, - depth=64, + hidden_size=9216, + depth=64, max_seq_len=2050, - num_heads=72, - activation=nn.functional.relu, - is_decoder = True, - fused_qkv=False, - model_name = "opt", + num_heads=72, + activation=nn.functional.relu, + is_decoder=True, + fused_qkv=False, + model_name="opt", **kwargs) return create_pipeline_model(**model_kwargs) - # def opt_175B(**kwargs): # model_kwargs = dict(hidden_size=12288, depth=96, num_heads=96, activation=nn.functional.relu, is_decoder = True, **kwargs) -# return create_pipeline_model(**model_kwargs) \ No newline at end of file +# return 
create_pipeline_model(**model_kwargs) diff --git a/energonai/server/worker_server.py b/energonai/server/worker_server.py index 930304f..a81877c 100644 --- a/energonai/server/worker_server.py +++ b/energonai/server/worker_server.py @@ -7,13 +7,14 @@ logger = get_dist_logger('energonai') -app = FastAPI() +app = FastAPI() + + @app.get("/") def root(): return {"200"} - @app.get("/shutdown") async def shutdown(): rpc.shutdown() @@ -22,21 +23,20 @@ async def shutdown(): await server.shutdown() -def launch_worker(config_file, +def launch_worker(config_file, rank=0, local_rank=0, server_host="127.0.0.1", server_port=8005): mcfg.load_config(config_file) - + world_size = mcfg['tp_init_size'] * mcfg['pp_init_size'] - launch_from_multiprocess(mcfg['tp_init_size'], mcfg['pp_init_size'], mcfg['backend'], - mcfg['seed'], mcfg['verbose'], rank, local_rank, world_size, - mcfg['host'], mcfg['port']) + launch_from_multiprocess(mcfg['tp_init_size'], mcfg['pp_init_size'], mcfg['backend'], + mcfg['seed'], mcfg['verbose'], rank, local_rank, world_size, + mcfg['host'], mcfg['port']) - - WORKER_NAME = "wok{}" + WORKER_NAME = "wok{}" # _transports=["uv"] TODO: potentially a bug rpc_backend_options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=16, rpc_timeout=6000) rpc.init_rpc(WORKER_NAME.format(rank), rank=rank, world_size=world_size, rpc_backend_options=rpc_backend_options) diff --git a/energonai/utils/checkpointing_opt.py b/energonai/utils/checkpointing_opt.py index 97c9c4a..c0f9caa 100644 --- a/energonai/utils/checkpointing_opt.py +++ b/energonai/utils/checkpointing_opt.py @@ -15,6 +15,30 @@ except ImportError: _EXTRA_STATE_KEY_SUFFIX = '_extra_state' +import os +from multiprocessing import Pool +from time import time + + +def load_state_dict(path: str): + if os.path.isfile(path): + return torch.load(path) + assert os.path.isdir(path) + state_dict = {} + files = [] + for filename in os.listdir(path): + filepath = os.path.join(path, filename) + if os.path.isfile(filepath): + files.append(filepath) + threads = torch.get_num_threads() + print(f'load {len(files)} files using {threads} threads') + with Pool(threads) as pool: + state_dicts = pool.map(torch.load, files) + for sd in state_dicts: + state_dict.update(sd) + return state_dict + + __all__ = [ "partition_tensor_parallel_state_dict", "load_checkpoint", "gather_tensor_parallel_state_dict", "save_checkpoint" ] @@ -26,7 +50,7 @@ 'self_attn.q_proj': 'attn.query_', 'self_attn.k_proj': 'attn.key_', 'self_attn.v_proj': 'attn.value_', - 'self_attn.out_proj':'attn.dense', + 'self_attn.out_proj': 'attn.dense', 'self_attn_layer_norm': 'norm1', 'final_layer_norm': 'norm2', 'fc1': 'mlp.dense_1', @@ -189,25 +213,21 @@ def load_checkpoint(file, Raises: RuntimeError: Raise error if the model/optimizer cannot successfully be recuperated """ - state_dict = (torch.load(file, map_location=torch.device("cpu")) - if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None) - - # model states - if state_dict is not None: - model_state = state_dict.pop("model") + start = time() + if gpc.get_local_rank(ParallelMode.MODEL) == 0: + model_state = load_state_dict(file) model_state = processing_OPT(model_state) else: model_state = dict() - + dist.barrier() + print(f'Load file time: {time()-start:.3f} s') # pipeline if is_using_pp(): model_state = partition_pipeline_parallel_state_dict(model, model_state, **kwargs) if "prefix" in kwargs.keys(): if kwargs['prefix'] != '': model_state = remove_prefix(model_state, kwargs["prefix"]) - # print("Rank {}: {}".format(gpc.get_global_rank(), 
model_state.keys())) - # print("+"*30) - # print(model_state.keys()) + try: model.load_state_dict(model_state, strict=strict) except RuntimeError as e: @@ -224,21 +244,7 @@ def load_checkpoint(file, else: raise e - # broadcast the rest states - state_dict = broadcast_state_dict(state_dict, ParallelMode.MODEL) - - # # optimizer states - # if optimizer is not None and 'optimizer' in state_dict: - # optimizer.load_state_dict(state_dict['optimizer']) - - # # lr scheduler states - # if lr_scheduler is not None and 'lr_scheduler' in state_dict: - # lr_scheduler.load_state_dict(state_dict['lr_scheduler']) - - # last epoch - last_epoch = state_dict.pop("epoch", -1) - - return last_epoch + return -1 def save_checkpoint(file, @@ -294,6 +300,7 @@ def judge_t(key_): return True return False + def module_name_mapping(ori_name: str): # print(ori_name) if ori_name == 'decoder.embed_tokens.weight': @@ -310,6 +317,7 @@ def module_name_mapping(ori_name: str): res = res.replace(k_, name_map[k_]) return res + def processing_OPT(state_dict: OrderedDict): new_dict = OrderedDict() for k_ in state_dict.keys(): @@ -353,12 +361,9 @@ def processing_OPT(state_dict: OrderedDict): # print("="*100) # print(new_dict.keys()) # print("---------------------------") - return new_dict #{"model": new_dict, "epoch": 0} + return new_dict # {"model": new_dict, "epoch": 0} def id_map(matched): value = matched.group('value') return "blocks.{}.".format(value) - - - diff --git a/examples/opt/opt_config.py b/examples/opt/opt_config.py index 45fea1f..4d34434 100644 --- a/examples/opt/opt_config.py +++ b/examples/opt/opt_config.py @@ -8,10 +8,10 @@ port = 29402 half = True checkpoint = "/data/user/djs_model_checkpoint/opt_metaseq_125m/model/restored.pt" -#"/data/user/djs_model_checkpoint/opt-30B-singleton/opt_metaseq_30000m/model/restored.pt" +# "/data/user/djs_model_checkpoint/opt-30B-singleton/opt_metaseq_30000m/model/restored.pt" backend = "nccl" -# for parallel +# for parallel tp_init_size = 2 pp_init_size = 2 @@ -21,4 +21,4 @@ # server_host = "127.0.0.1" server_host = "0.0.0.0" server_port = 8020 -log_level = "info" \ No newline at end of file +log_level = "info" diff --git a/examples/opt/opt_server.py b/examples/opt/opt_server.py index 9694a56..ed6de91 100644 --- a/examples/opt/opt_server.py +++ b/examples/opt/opt_server.py @@ -8,12 +8,14 @@ from transformers import GPT2Tokenizer -app = FastAPI() # 创建 api 对象 +app = FastAPI() # 创建 api 对象 -@app.get("/") # 根路由 + +@app.get("/") # 根路由 def root(): return {"200"} + @app.get("/run/{request}") def run(request: str, max_seq_length: int): @@ -28,9 +30,9 @@ def run(request: str, max_seq_length: int): if '<|endoftext|>' in total_predicted_text: break input_token = tokenizer(total_predicted_text, return_tensors="pt") - + return {total_predicted_text} - + @app.get("/shutdown") async def shutdown(): @@ -41,42 +43,42 @@ async def shutdown(): def launch_engine(model_class, - model_type, - max_batch_size: int = 1, - tp_init_size: int = -1, - pp_init_size: int = -1, - host: str = "localhost", - port: int = 29500, - dtype = torch.float, - checkpoint: str = None, - tokenizer_path: str = None, - server_host = "localhost", - server_port = 8005, - log_level = "info" - ): - + model_type, + max_batch_size: int = 1, + tp_init_size: int = -1, + pp_init_size: int = -1, + host: str = "localhost", + port: int = 29500, + dtype=torch.float, + checkpoint: str = None, + tokenizer_path: str = None, + server_host="localhost", + server_port=8005, + log_level="info" + ): + # only for the generation task global tokenizer 
if(tokenizer_path): tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path) - + if checkpoint: model_config = {'dtype': dtype, 'checkpoint': checkpoint} else: model_config = {'dtype': dtype} - + global engine - engine = InferenceEngine(model_class, - model_config, - model_type, - max_batch_size = max_batch_size, - tp_init_size = tp_init_size, - pp_init_size = pp_init_size, - host = host, - port = port, - dtype = dtype) + engine = InferenceEngine(model_class, + model_config, + model_type, + max_batch_size=max_batch_size, + tp_init_size=tp_init_size, + pp_init_size=pp_init_size, + host=host, + port=port, + dtype=dtype) global server config = uvicorn.Config(app, host=server_host, port=server_port, log_level=log_level) server = uvicorn.Server(config=config) - server.run() \ No newline at end of file + server.run()
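
A note on the generation step: select_top_k() in energonai/model/model_factory.py draws a token uniformly at random from the ten highest-scoring logits and ignores its k argument. For comparison, a minimal top-k sampling sketch that honours k and weights the candidates by their softmax probabilities (an illustration only, not code from this patch):

    import torch


    def sample_top_k(logits: torch.Tensor, k: int = 5) -> int:
        # logits is assumed to have shape (batch, seq_len, vocab); sequence 0 and the
        # last position are used, mirroring predictions[0, -1, :] in select_top_k().
        values, indices = torch.topk(logits[0, -1, :], k)
        probs = torch.softmax(values.float(), dim=-1)  # renormalise over the k candidates
        choice = torch.multinomial(probs, num_samples=1)
        return indices[choice].item()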
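
The new load_state_dict() in energonai/utils/checkpointing_opt.py accepts either a single checkpoint file or a directory; for a directory it loads every file through a multiprocessing.Pool and merges the partial state dicts. A hypothetical helper for producing such a directory from a monolithic restored.pt (the helper name, shard count, and flat name-to-tensor layout are assumptions, not part of this patch):

    import os

    import torch


    def shard_checkpoint(src_path: str, dst_dir: str, num_shards: int = 8) -> None:
        # Assumes src_path holds a flat name -> tensor mapping. Round-robin
        # assignment by key is sufficient because load_state_dict() simply
        # update()s the per-file dicts back into one state dict.
        state_dict = torch.load(src_path, map_location="cpu")
        os.makedirs(dst_dir, exist_ok=True)
        shards = [{} for _ in range(num_shards)]
        for i, (name, tensor) in enumerate(state_dict.items()):
            shards[i % num_shards][name] = tensor
        for i, shard in enumerate(shards):
            torch.save(shard, os.path.join(dst_dir, f"shard_{i}.pt"))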
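
Once the example server is running, the routes added in examples/opt/opt_server.py can be exercised over HTTP: /run/{request} takes the prompt as a path parameter and max_seq_length as a query parameter, and /shutdown stops the engine and the uvicorn server. A minimal client sketch, assuming the defaults from examples/opt/opt_config.py (server_host = "0.0.0.0", server_port = 8020) and that the requests package is available:

    import urllib.parse

    import requests

    BASE_URL = "http://127.0.0.1:8020"


    def generate(prompt: str, max_seq_length: int = 64) -> str:
        # GET /run/{request}?max_seq_length=... as defined in opt_server.py
        url = f"{BASE_URL}/run/{urllib.parse.quote(prompt, safe='')}"
        resp = requests.get(url, params={"max_seq_length": max_seq_length})
        resp.raise_for_status()
        return resp.text


    if __name__ == "__main__":
        print(generate("Where do falling stars go", max_seq_length=32))
        # requests.get(f"{BASE_URL}/shutdown")  # stops the engine and the uvicorn server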