This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Lzm develop #33

Merged
merged 3 commits into from
Apr 21, 2022
6 changes: 3 additions & 3 deletions energon/engine/engine.py
@@ -40,7 +40,7 @@ def __init__(self,
checkpoint: load parameter.
"""
super().__init__()

self.model_class = model_class
self.model_config = model_config
self.model_type = model_type
@@ -59,15 +59,15 @@ def __init__(self,
# for TP
self.rrefs = []
# for rpc
self.WORKER_NAME = "wok{}"
self.WORKER_NAME = "wok{}"
self._init_dist_rpc()
self._init_model()
self.key = CircleInt()

def _init_dist_rpc(self):
r'''
Based on global_context, init the rpc connection.
'''
'''
launch_from_multiprocess(tp_size = self.tp_size, pp_size = self.pp_size, rank = self.rank, local_rank = self.rank, world_size = self.global_world_size, host = self.host, port = self.port)
rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
num_worker_threads=16)
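For context on the hunks above: the engine names its RPC endpoints with the "wok{}" template and builds TensorPipe backend options before connecting the processes. Below is a minimal sketch of that bootstrap using plain torch.distributed.rpc; init_worker_rpc and its defaults are illustrative only and are not part of this PR.

```python
import torch.distributed.rpc as rpc


def init_worker_rpc(rank: int, world_size: int, host: str = "localhost", port: int = 29500):
    """Register this process under the "wok{rank}" name used by the engine (sketch only)."""
    options = rpc.TensorPipeRpcBackendOptions(
        num_worker_threads=16,                 # matches the value set in _init_dist_rpc
        init_method=f"tcp://{host}:{port}",    # rendezvous address shared by all ranks
    )
    rpc.init_rpc(f"wok{rank}", rank=rank, world_size=world_size,
                 rpc_backend_options=options)


# every rank calls init_worker_rpc(...); rank 0 then issues remote calls and
# finally rpc.shutdown() once all RRefs have been consumed.
```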
3 changes: 1 addition & 2 deletions energon/nn/layer/parallel_1d/layers.py
@@ -298,7 +298,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
# Set up backprop all-reduce.
input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
# Matrix multiply.

# print("gathered Linear input: {}".format(input_parallel))
bias = self.bias if not self.skip_bias_add else None
output_parallel = F.linear(input_parallel, self.weight, bias)
if self.gather_output:
@@ -480,7 +480,6 @@ def forward(self, input_: Tensor) -> Tensor:
input_ = input_
else:
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)

output_parallel = F.linear(input_, self.weight)
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)

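The second hunk above belongs to the row-parallel 1D linear layer: the input is split along its last dimension, each rank multiplies by its weight shard, and the partial outputs are summed across the tensor-parallel group. A rough standalone sketch of that forward pass, written against plain torch.distributed rather than energon's ParallelMode helpers (the function name and arguments are illustrative):

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F


def row_parallel_forward(input_: torch.Tensor, weight_shard: torch.Tensor,
                         bias: torch.Tensor = None, input_is_split: bool = False):
    """Row-parallel linear: local matmul on a weight shard, then all-reduce the partial sums."""
    if not input_is_split:
        # take this rank's slice of the last dimension (split_forward_gather_backward analogue)
        rank, world_size = dist.get_rank(), dist.get_world_size()
        input_ = input_.chunk(world_size, dim=-1)[rank]
    partial = F.linear(input_, weight_shard)   # partial result on this rank
    dist.all_reduce(partial)                   # sum partials across the group (reduce_input analogue)
    if bias is not None:
        partial = partial + bias               # bias is added once, after the reduction
    return partial
```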
22 changes: 22 additions & 0 deletions energon/server/batch_manager.py
@@ -0,0 +1,22 @@
import queue
import threading
from threading import Event

from engine_server import InferenceEngine


class Batch_Manager():
    def __init__(self, engine: InferenceEngine, batch_acc: int = 20, max_wait_time: int = 5):
        self.engine = engine
        self.batch_acc = batch_acc
        self.max_wait_time = max_wait_time
        self.req_queue = queue.Queue()
        self.res_dict = dict()

    def fetch_inference_res(self, time_stamp: int):
        if time_stamp not in self.res_dict:
            return "Error: Inference may have failed"
        res = self.res_dict.pop(time_stamp)
        return res

    def append_req(self, time_stamp: int, input_str: str):
        # store the request together with its timestamp so a batching loop can pick it up
        self.req_queue.put({"time_stamp": time_stamp, "input_str": input_str})
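As written, Batch_Manager only enqueues requests and looks up finished results; the loop that drains the queue is not part of this file. A hedged sketch of what such a loop typically looks like is below. The processing thread, the request dict layout, and the way engine.run consumes a batch are assumptions, not code from this PR.

```python
import queue
import time
import threading


def processing_loop(manager):
    """Hypothetical worker: collect up to batch_acc requests (or wait at most
    max_wait_time seconds), run them as one batch, and publish the results."""
    while True:
        batch, deadline = [], time.time() + manager.max_wait_time
        while len(batch) < manager.batch_acc:
            remaining = deadline - time.time()
            if remaining <= 0:
                break
            try:
                batch.append(manager.req_queue.get(timeout=remaining))
            except queue.Empty:
                break
        if not batch:
            continue
        # the input format expected by engine.run is assumed here
        outputs = manager.engine.run([req["input_str"] for req in batch])
        for req, out in zip(batch, outputs):
            manager.res_dict[req["time_stamp"]] = out


# threading.Thread(target=processing_loop, args=(manager,), daemon=True).start()
```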
44 changes: 36 additions & 8 deletions energon/utils/checkpointing.py
@@ -201,8 +201,8 @@ def load_checkpoint(
if kwargs['prefix'] != '':
model_state = remove_prefix(model_state, kwargs["prefix"])
# print("Rank {}: {}".format(gpc.get_global_rank(), model_state))
print("+"*30)
print(model_state.keys())
# print("+"*30)
# print(model_state.keys())
try:
model.load_state_dict(model_state, strict=strict)
except RuntimeError as e:
@@ -277,13 +277,13 @@ def save_checkpoint(file,

# if lr_scheduler is not None:
# checkpoint['lr_scheduler'] = lr_scheduler.state_dict()
print("before save")
# print("before save")
torch.save(checkpoint, file, **kwargs)
print("after save")
# print("after save")


def judge_t(key_):
key_words = ['attn.query_key_value.weight', 'mlp.dense_1.weight', 'mlp.dense_2.weight']
key_words = ['attn.query_key_value.weight', 'mlp.dense_1.weight', 'mlp.dense_2.weight', 'attn.dense.weight']
for word_ in key_words:
if word_ in key_:
return True
@@ -296,13 +296,41 @@ def processing_HF_GPT(state_dict: OrderedDict):
new_k = module_name_mapping(k_)
if new_k == "":
continue


new_v = state_dict[k_]
if judge_t(new_k):
new_v = torch.transpose(new_v, 0, 1)
new_dict[new_k] = new_v
if "attn.query_key_value.weight" in new_k:
num_ = re.search(r"blocks\.blk_\d+?\.", new_k)
if num_:
prefix = num_.group()
else:
prefix = ''
# print("prefix: {}".format(prefix))
q_, k_, v_ = torch.chunk(new_v, 3, 0)
# new_dict[prefix + "attn.query_.weight"] = torch.transpose(q_, 0, 1)
# new_dict[prefix + "attn.key_.weight"] = torch.transpose(k_, 0, 1)
# new_dict[prefix + "attn.value_.weight"] = torch.transpose(v_, 0, 1)
new_dict[prefix + "attn.query_.weight"] = q_
new_dict[prefix + "attn.key_.weight"] = k_
new_dict[prefix + "attn.value_.weight"] = v_
elif "attn.query_key_value.bias" in new_k:
num_ = re.search(r"blocks\.blk_\d+?\.", new_k)
if num_:
prefix = num_.group()
else:
prefix = ''
print("prefix: {}".format(prefix))
q_, k_, v_ = torch.chunk(new_v, 3, 0)
new_dict[prefix + "attn.query_.bias"] = q_
new_dict[prefix + "attn.key_.bias"] = k_
new_dict[prefix + "attn.value_.bias"] = v_
else:
new_dict[new_k] = new_v
new_dict['head.dense.weight'] = new_dict['embed.word_embeddings.weight'].clone()
print("="*100)
print(new_dict.keys())
# print("="*100)
# print(new_dict.keys())
return {"model": new_dict, "epoch": 0}


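The new branch in processing_HF_GPT splits the fused query/key/value projection of each block into separate q/k/v tensors after the Hugging Face Conv1D weight has been transposed. A toy illustration of that split follows; the sizes are made up, and the Q-then-K-then-V order is the layout Hugging Face GPT-2 uses for its fused c_attn projection.

```python
import torch

hidden = 8                                     # toy size; GPT-2 small uses 768
# after the transpose applied via judge_t, the fused weight has shape (3 * hidden, hidden)
fused_qkv_weight = torch.randn(3 * hidden, hidden)
q_w, k_w, v_w = torch.chunk(fused_qkv_weight, 3, dim=0)
assert q_w.shape == (hidden, hidden)

# the fused bias is split the same way along its only dimension
fused_qkv_bias = torch.randn(3 * hidden)
q_b, k_b, v_b = torch.chunk(fused_qkv_bias, 3, dim=0)
```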
35 changes: 35 additions & 0 deletions example/HuggingFace_GPT2/HF.sh
@@ -0,0 +1,35 @@
#!/bin/bash

tp_size=2
pp_size=2
model=gpt2_small
world_size=`expr $tp_size \* $pp_size`
server_port_start=8005
host="localhost"
port=29499
export CUDA_VISIBLE_DEVICES=4,5,6,7  # export so the launched python workers inherit it

export PYTHONPATH=/home/lcdjs/ColossalAI-Inference/example

for ((i=1; i<${world_size}; i++))
do
server_port=`expr $server_port_start + $i`
python3 /home/lclzm/ColossalAI_Inference/energon/engine/server.py --port ${server_port} &
echo "process: ${i} launches"
done

sleep 3

for ((i=1; i<${world_size}; i++))
do
server_port=`expr $server_port_start + $i`
curl -X 'GET' \
"http://127.0.0.1:${server_port}/start/${tp_size}?pp_size=${pp_size}&backend=nccl&seed=1024&verbose=true&rank=${i}&local_rank=${i}&host=${host}&port=${port}" \
-H 'accept: application/json' &
echo "http://127.0.0.1:${server_port}/start/${tp_size}?pp_size=${pp_size}&backend=nccl&seed=1024&verbose=true&rank=${i}&local_rank=${i}&host=${host}&port=${port}"
echo "evoke process: ${i} init rpc"
done

python3 HF_GPT2_inference.py --fp16 --tensor_para_size=${tp_size} --pipe_para_size=${pp_size} --port=${port}
# python3 gpt_inference.py --fp16 --model_name=gpt2_large --tensor_para_size=1 --pipe_para_size=2 --port=29499
# tritonserver --model-repository /opt/tritonserver/host/python_backend/models
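The curl loop above sends one GET /start request per worker to hand it its rank and rendezvous address. When driving the workers from Python instead of bash, the same request can be issued with the requests library; start_worker below is illustrative only and not part of the PR.

```python
import requests


def start_worker(server_port: int, rank: int, tp_size: int = 2, pp_size: int = 2,
                 host: str = "localhost", rpc_port: int = 29499) -> int:
    """Issue the same GET /start call the shell script sends with curl (sketch only)."""
    url = f"http://127.0.0.1:{server_port}/start/{tp_size}"
    params = {
        "pp_size": pp_size, "backend": "nccl", "seed": 1024, "verbose": "true",
        "rank": rank, "local_rank": rank, "host": host, "port": rpc_port,
    }
    resp = requests.get(url, params=params, headers={"accept": "application/json"})
    return resp.status_code


# for rank in range(1, tp_size * pp_size):
#     start_worker(8005 + rank, rank)
```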
74 changes: 60 additions & 14 deletions example/HuggingFace_GPT2/HF_GPT2_inference.py
@@ -1,9 +1,12 @@
import random
import numpy as np
import argparse

import torch
import torch.multiprocessing as mp
import torch.nn as nn

from transformers import GPT2Tokenizer

from energon.context import ParallelMode
from energon.core import global_context as gpc

@@ -24,6 +27,11 @@
"gpt3": gpt3
}

def select_top_k(predictions, k=10):
    # sample the next token from the k highest-scoring candidates at the last position
    predicted_index = random.choice(
        predictions[0, -1, :].sort(descending=True)[1][:k]).item()
    return predicted_index


def build_gpt_model():
parser = argparse.ArgumentParser()
@@ -52,25 +60,63 @@ def build_gpt_model():
pp_init_size=args.pipe_para_size,
port=args.port,
dtype=dtype_)
input_ids = torch.randint(1, 10, (32, 40), dtype=torch.int64)
attention_mask = torch.randint(0, 1, (32, 1, 40), dtype=torch.int64)
tokenizer = GPT2Tokenizer.from_pretrained('./')
# tokenizer = GPT2Tokenizer(vocab_file="vocab.json", merges_file="merges.txt")
# test_input = ["MANY YEARS LATER as he faced the firing squad, Colonel Aureliano Buendía was to remember that"
# for _ in range(10)]
test_input = "I do not"
print(test_input)
input_token = tokenizer(test_input, return_tensors="pt")
# tokens_tensor = torch.tensor([input_token])
total_predicted_text = test_input
# print(input_ids['input_ids'])
# print(input_ids['attention_mask'])
# input_ids = input_token['input_ids']
# print(input_ids.shape)
# input_ids = torch.randint(1, 10, (32, 40), dtype=torch.int64)
# attention_mask = torch.randint(0, 1, (32, 1, 40), dtype=torch.int64)
hidden_states = None
sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask)
output = engine.run(input_token)
predictions = output.to_here()
predicted_index = select_top_k(predictions, k=1)
total_predicted_text += tokenizer.decode(predicted_index)
print(total_predicted_text)
# sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask)
# sample = dict(hidden_states=hidden_states, input_ids=input_ids)

output = engine.run(sample)
# print(output.to_here())
output.to_here()
# output = engine.run(input_token)
# output.to_here()
# print(tokenizer.decode(output.to_here()[0]))
timer = get_timers()
timer('time1').start()

for i in range(1, args.iteration):
input_ids = torch.randint(1, 10, (32, i % 20 + 2), dtype=torch.int64)
attention_mask = torch.randint(0, 1, (32, 1, i % 20 + 2), dtype=torch.int64)
hidden_states = None
sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask)
output = engine.run(sample)
output = engine.run(input_token)
predictions = output.to_here()
predicted_index = select_top_k(predictions, k=10)
total_predicted_text += tokenizer.decode(predicted_index)
# print(total_predicted_text)
if '<|endoftext|>' in total_predicted_text:
# stop generating once the end-of-text token appears
break

print(output.to_here())
input_token = tokenizer(total_predicted_text, return_tensors="pt")

# input_ids = tokenizer(test_input, return_tensors='pt')
# # input_ids = torch.randint(1, 10, (32, 40), dtype=torch.int64)
# # attention_mask = torch.randint(0, 1, (32, 1, 40), dtype=torch.int64)
# hidden_states = None
# # sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask)
# # sample = dict(hidden_states=hidden_states, input_ids=input_ids)
# # input_ids = torch.randint(1, 10, (32, i % 20 + 2), dtype=torch.int64)
# # attention_mask = torch.randint(0, 1, (32, 1, i % 20 + 2), dtype=torch.int64)
# # hidden_states = None
# # sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask)
# output = engine.run(input_ids)

# print(tokenizer.decode(output.to_here()))
print(total_predicted_text)
# print(output.to_here())
timer('time1').stop()

time1 = timer('time1').elapsed()
@@ -84,7 +130,7 @@ def build_gpt_model():
f'Time: {time1 / args.iteration}')

engine.clear()

#

if __name__ == "__main__":
build_gpt_model()
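The handle returned by engine.run is consumed with output.to_here(), the torch RPC pattern in which a remote reference (RRef) is created asynchronously and the caller blocks only when it needs the value. A minimal illustration of that pattern (the worker name and the stand-in function are hypothetical, and RPC must already be initialized as in the engine):

```python
import torch
import torch.distributed.rpc as rpc


def remote_forward(x):
    return x * 2   # stand-in for the model forward that runs on the worker


# on the calling rank, once rpc.init_rpc(...) has completed on every process:
# rref = rpc.remote("wok1", remote_forward, args=(torch.ones(2),))
# ... the call runs asynchronously on "wok1" ...
# result = rref.to_here()   # blocks until the remote value is copied back to the caller
```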
43 changes: 43 additions & 0 deletions example/HuggingFace_GPT2/save_hugging_face.py
@@ -0,0 +1,43 @@
import random
import logging

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel


def select_top_k(predictions, k=10):
    # sample the next token from the k highest-scoring candidates at the last position
    predicted_index = random.choice(
        predictions[0, -1, :].sort(descending=True)[1][:k]).item()
    return predicted_index


logging.basicConfig(level=logging.INFO)

# load the pretrained tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# encode the prompt with GPT2Tokenizer
text = "I do not think that apples are"
indexed_tokens = tokenizer.encode(text)
tokens_tensor = torch.tensor([indexed_tokens])

# load the pretrained GPT-2 model with its language-modeling head,
# so outputs[0] holds vocabulary logits rather than hidden states
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

total_predicted_text = text
n = 100  # number of generation steps
for _ in range(n):
    with torch.no_grad():
        outputs = model(tokens_tensor)
        predictions = outputs[0]

    predicted_index = select_top_k(predictions, k=10)
    predicted_text = tokenizer.decode(indexed_tokens + [predicted_index])
    total_predicted_text += tokenizer.decode(predicted_index)

    if '<|endoftext|>' in total_predicted_text:
        # stop generating once the end-of-text token appears
        break

    indexed_tokens += [predicted_index]
    tokens_tensor = torch.tensor([indexed_tokens])

print(total_predicted_text)
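select_top_k above draws uniformly among the k highest-scoring tokens. A common alternative, sketched here for comparison rather than taken from this PR, weights the draw by the softmax probability of each retained logit:

```python
import torch


def sample_top_k(logits: torch.Tensor, k: int = 10, temperature: float = 1.0) -> int:
    """Weighted top-k sampling over the logits of the last position."""
    top_values, top_indices = torch.topk(logits[0, -1, :] / temperature, k)
    probs = torch.softmax(top_values, dim=-1)
    choice = torch.multinomial(probs, num_samples=1).item()
    return int(top_indices[choice].item())
```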
4 changes: 2 additions & 2 deletions example/HuggingFace_GPT2/single_engine.py
@@ -24,7 +24,7 @@
model_config=model_config,
model_type='gpt',
max_batch_size=32,
tp_init_size=1,
pp_init_size=1,
tp_init_size=2,
pp_init_size=2,
port=29501,
dtype=dtype_)