This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

batch manager by ziming liu #44

Merged
merged 1 commit on May 5, 2022
216 changes: 198 additions & 18 deletions energon/server/batch_manager.py
@@ -1,22 +1,202 @@
import queue
from threading import Event
import torch
import time
from scipy import stats
import numpy as np
from energon.engine import InferenceEngine
from transformers import GPT2Tokenizer
import random
import redis
import os
from tqdm import tqdm, trange
import threading
from readerwriterlock import rwlock


def generate_cached_cost(engine, max_seq_len: int = 1024, max_batch_size: int = 16, step: int = 1,
repeat_round: int = 3):
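    # Profile the engine: measure the average latency of one generation step for every
    # (sequence length, batch size) combination up to the given limits, and cache the
    # resulting table in a local .npy file so the profiling pass only runs once per config.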
def select_top_k(predictions, k=10):
predicted_index = random.choice(
predictions[0, -1, :].sort(descending=True)[1][:k]).item()
return predicted_index

print("fetching cached cost")
cached_name = "cached_cost_{}_{}_{}_{}.npy".format(max_seq_len, max_batch_size, step, repeat_round)
if os.path.exists(cached_name):
print("loading cached cost from file")
cached_cost = np.load(cached_name).tolist()
else:
print("generating new cached cost")
cached_cost = [[0 for i in range(max_batch_size + 1)] for j in range(max_seq_len + 1)]
input_text = ""
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
for tmp_len in trange(1, max_seq_len + 1, step):
input_text += "test "
for tmp_batch in range(1, max_batch_size + 1):
batched_text = [input_text for _ in range(tmp_batch)]
start_time = time.time()
for k in range(repeat_round):
input_token = tokenizer(batched_text, return_tensors="pt")
output = engine.run(input_token)
predictions = output.to_here()
predicted_index = select_top_k(predictions, k=1)
total_predicted_text = tokenizer.decode(predicted_index)
time_cost = (time.time() - start_time) / repeat_round
cached_cost[tmp_len][tmp_batch] = time_cost
for k in range(1, step):
cached_cost[tmp_len + k][tmp_batch] = time_cost
np.save(cached_name, np.array(cached_cost))
print("cached cost loaded")
return cached_cost


class single_request():
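    # Lightweight record of one pending request: tokenized input, raw text, arrival
    # timestamp, and sequence length (used for sorting and padding decisions).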
def __init__(self, input_, time_stamp: float, input_str: str):
self.input_ = input_
self.text = input_str
self.time_ = time_stamp
self.seq_len = input_['input_ids'].shape[1]


class Manager:
def __init__(self):
pass

def insert_req(self, time_stamp: float, input_ids, input_str: str):
pass


class Batch_Manager(Manager):
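    # Dynamic batcher: incoming requests are kept sorted by sequence length, packed into
    # contiguous batches by a DP pass over the profiled cost table (wrap_batch), ranked by a
    # priority that combines waiting time with a fitted normal distribution over request
    # lengths (cal_priority), and results are published back to callers through Redis,
    # keyed by each request's timestamp.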
def __init__(self, engine: InferenceEngine, cached_cost: list, init_mu: int = 512, init_theta: int = 180,
max_batch_size: int = 32, lr: float = 0.01, max_seq_len=1024):
super().__init__()
self.engine = engine
self.max_batch_size = max_batch_size
self.lr = lr
self.mu = init_mu
self.theta = init_theta
self.max_seq_len = max_seq_len
# self.normal_weight = self._init_normal_dist_weight()
self.req_list = []
self.req_list_lock = rwlock.RWLockFair()
self.read_lock = self.req_list_lock.gen_rlock()
self.write_lock = self.req_list_lock.gen_wlock()
self.cached_cost = cached_cost
self.tokenizer = GPT2Tokenizer.from_pretrained('/home/lcdjs/hf_gpt2')
        self.tokenizer.pad_token = self.tokenizer.eos_token
self.running_flag = True
self.publisher = redis.StrictRedis('localhost', 6379, charset="utf-8", decode_responses=True)
self.main_thread = threading.Thread(target=self.processing_batch)
self.main_thread.start()

def insert_req(self, time_stamp: float, input_ids, input_str: str):
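        # Append under the write lock and keep the queue sorted by sequence length so that
        # contiguous slices form padding-efficient batches.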
tmp_req = single_request(input_ids, time_stamp, input_str)
self.write_lock.acquire()
self.req_list.append(tmp_req)
self.req_list.sort(key=lambda x: x.seq_len)
self.write_lock.release()

    def cal_priority(self, batch_list: list, cur_stamp: float):
        # Priority favours batches whose (longest) sequence length is rare under the fitted
        # normal distribution and whose earliest request has been waiting the longest.
        cur_len = batch_list[-1].seq_len
        earliest_timestamp = min([i.time_ for i in batch_list])

        wait_time = cur_stamp - earliest_timestamp
        batch_size = len(batch_list)
        appear_possibility_weight = 1.0 / self.cal_norm_weight(cur_len)

        # TODO adjust the equation
        priority = appear_possibility_weight * batch_size * np.exp(wait_time)
        return priority

# def _init_normal_dist_weight(self):
# temp_weight_list = [0]
# for i in range(1, self.max_seq_len):
# temp_weight_list.append(stats.norm(self.mu, self.theta).cdf(i) -
# stats.norm(self.mu, self.theta).cdf(i - 1))
# return temp_weight_list

def cal_norm_weight(self, seq_len):
return stats.norm(self.mu, self.theta).cdf(seq_len) - \
stats.norm(self.mu, self.theta).cdf(seq_len - 1)

def update_norm(self, batch_: list):
new_mu = np.mean([i.seq_len for i in batch_])
delta_mu = new_mu - self.mu
self.mu += self.lr * delta_mu
temp_batch = np.array([i.seq_len - self.mu for i in batch_])
new_theta = np.sqrt(np.mean(temp_batch ** 2))
delta_theta = new_theta - self.theta
self.theta += self.lr * delta_theta
return

def wrap_batch(self):
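        # Dynamic programming over the length-sorted request list: states[i] is the minimal
        # total cost of serving the first i requests when they are split into contiguous
        # batches, and start_idx_list[i] records where the last batch of that optimal split
        # begins. A batch's cost is the profiled latency of its longest request at the
        # corresponding batch size.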
self.write_lock.acquire()
states = [0]
start_idx_list = [0]
for i in range(1, len(self.req_list) + 1):
j = i - 1
start_idx = i - 1
cur_length = self.req_list[i - 1].seq_len
# print(i, j, cur_length)
min_cost = self.cached_cost[cur_length][1] + states[j]
while j > max(0, i - self.max_batch_size):
tmp_cost = states[j - 1] + \
self.cached_cost[cur_length][i - j + 1]
if tmp_cost < min_cost:
min_cost = tmp_cost
start_idx = j - 1
j -= 1
states.append(min_cost)
start_idx_list.append(start_idx)
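        # Walk the split backwards and pick the contiguous batch with the highest priority;
        # it is removed from the queue and returned for execution.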
i = len(self.req_list)
res_start = -1
res_end = -1
max_priority = -1
cur_timestamp = time.time()
while i > 0:
end_idx = i
start_idx = start_idx_list[i]
current_batch = self.req_list[start_idx: end_idx]
current_priority = self.cal_priority(current_batch, cur_timestamp)
if current_priority > max_priority:
max_priority = current_priority
res_start = start_idx
res_end = end_idx
i = start_idx - 1
result_batch = self.req_list[res_start: res_end]
del self.req_list[res_start: res_end]
self.update_norm(result_batch)
self.write_lock.release()
return result_batch

def processing_batch(self):
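        # Main serving loop: whenever requests are queued, pack the best batch, pad it to
        # its longest sequence, run inference, and hand the output to a publisher thread.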
while self.running_flag:
if len(self.req_list) > 0:
target_batch = self.wrap_batch()
pad_len = target_batch[-1].seq_len
print("A batch with {} requests and length of {} packed".format(len(target_batch), pad_len))
input_text = [i.text for i in target_batch]
input_ids = self.tokenizer(input_text, padding="longest", return_tensors="pt")
print("input_ids shape: {}".format(input_ids['input_ids'].shape))
print("attention_mask shape: {}".format(input_ids['attention_mask'].shape))
# input_ids = self.tokenizer(input_text, return_tensors="pt")
# print(input_ids)
output = self.engine.run(input_ids)
pub_thread = threading.Thread(target=self.publish_result, args=(output, target_batch))
pub_thread.start()

def publish_result(self, output, target_batch):
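        # Decode one next-token prediction per request (sampled from the top-k logits) and
        # publish it on the Redis channel named after the request's timestamp.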
def select_top_k(batch_id, predictions, k=10):
predicted_index = random.choice(
predictions[batch_id, -1, :].sort(descending=True)[1][:k]).item()
return predicted_index
# print("output: {}".format(output))
predictions = output.to_here()
# print("predictions: {}".format(predictions), flush=True)
# decode_list = self.tokenizer.decode(predictions)
for i in range(len(target_batch)):
# print(i, predictions.shape, target_batch)
temp_st = target_batch[i].time_
chosen_pred = select_top_k(i, predictions, k=5)
text_ = self.tokenizer.decode(chosen_pred)
print("text: {}".format(text_))
self.publisher.publish(str(temp_st), text_)
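
For reference, here is a minimal, self-contained sketch of the batch-packing recurrence used in wrap_batch above, run against a hand-made toy cost table instead of profiled latencies. The seq_lens, toy_cost values, and max_batch_size below are illustrative assumptions, not numbers from this PR.

# Toy walk-through of the wrap_batch recurrence: length-sorted requests are split into
# contiguous batches so that the summed per-batch cost is minimal. A batch's cost is the
# cost of its longest request at the corresponding batch size.
seq_lens = [3, 4, 4, 7]                     # request lengths, already sorted
max_batch_size = 3
toy_cost = [[0.0] * (max_batch_size + 1) for _ in range(max(seq_lens) + 1)]
for length in range(1, max(seq_lens) + 1):
    for bs in range(1, max_batch_size + 1):
        toy_cost[length][bs] = 0.1 * length + 0.05 * bs    # made-up latency model

states = [0.0]        # states[i]: minimal cost of serving the first i requests
start_idx_list = [0]  # start index of the last batch in that optimal split
for i in range(1, len(seq_lens) + 1):
    cur_length = seq_lens[i - 1]
    min_cost, start_idx = toy_cost[cur_length][1] + states[i - 1], i - 1
    j = i - 1
    while j > max(0, i - max_batch_size):
        tmp_cost = states[j - 1] + toy_cost[cur_length][i - j + 1]
        if tmp_cost < min_cost:
            min_cost, start_idx = tmp_cost, j - 1
        j -= 1
    states.append(min_cost)
    start_idx_list.append(start_idx)

print(states)          # cumulative optimal costs for each prefix of the request list
print(start_idx_list)  # where the last batch of each optimal prefix split begins
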
99 changes: 99 additions & 0 deletions examples/gpt/gpt_batch_server.py
@@ -0,0 +1,99 @@
import os
import time

import redis
import torch
import uvicorn
from transformers import GPT2Tokenizer
from fastapi import FastAPI
from fastapi import Response, Body
import torch.distributed.rpc as rpc
from energon.engine import InferenceEngine
from energon.server.batch_manager import Batch_Manager, generate_cached_cost, Manager

app = FastAPI()  # create the FastAPI application
# cached_cost = None
# batch_manager = Manager()
# server = None

@app.get("/") # 根路由
def root():
return {"200"}


@app.post("/model_with_padding")
def run(
input_str: str = Body(..., title="input_str", embed=True)
):
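    # Each request is tokenized, handed to the batch manager together with its arrival
    # timestamp, and the handler then blocks on a Redis channel named after that timestamp
    # until the generated text is published back by the manager.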

red = redis.StrictRedis('localhost', 6379, charset="utf-8", decode_responses=True)
sub = red.pubsub()
input_token = tokenizer(input_str, return_tensors="pt")
time_stamp = time.time()
batch_manager.insert_req(time_stamp, input_token, input_str)
sub.subscribe(str(time_stamp))
predictions = input_str
for message in sub.listen():
if message is not None and isinstance(message, dict):
print(message)
predictions = message.get('data')
if not isinstance(predictions, int):
break

return {predictions}


@app.get("/shutdown")
async def shutdown():
engine.clear()
server.should_exit = True
server.force_exit = True
await server.shutdown()


def launch_engine(model_name,
model_type,
max_batch_size: int = 1,
tp_init_size: int = -1,
pp_init_size: int = -1,
host: str = "localhost",
port: int = 29500,
dtype=torch.float,
checkpoint: str = None,
tokenizer_path: str = None,
server_host="localhost",
server_port=8005,
log_level="info"
):
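    # Build the model config, create the global tokenizer and inference engine, profile the
    # cached cost table, start the batch manager, and serve the FastAPI app with uvicorn.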
if checkpoint:
model_config = {'dtype': dtype, 'checkpoint': True, 'checkpoint_path': checkpoint}
else:
model_config = {'dtype': dtype}

global tokenizer
tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)

global engine
engine = InferenceEngine(model_name,
model_config,
model_type,
max_batch_size=max_batch_size,
tp_init_size=tp_init_size,
pp_init_size=pp_init_size,
host=host,
port=port,
dtype=dtype)

global cached_cost
cached_cost = generate_cached_cost(engine, max_seq_len=256, max_batch_size=4, step=4, repeat_round=2)

global batch_manager
batch_manager = Batch_Manager(engine, cached_cost, max_seq_len=256, max_batch_size=4)
print("batch manager initialized")

global server
config = uvicorn.Config(app, host=server_host, port=server_port, log_level=log_level)
server = uvicorn.Server(config=config)
print("running server")
server.run()
print("application started")
62 changes: 62 additions & 0 deletions tests/test_server/test_batch_server.py
@@ -0,0 +1,62 @@
import requests
import threading
import torch
import random
import os
import numpy as np
import time

latency = []


def setup_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)


def generate_test_dataset(text_num: int = 100, max_len: int = 1024):
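    # Reuse a cached test set if one exists; otherwise synthesise text_num prompts of random
    # length (in words) up to max_len and persist them for later runs.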
file_name = "test_set_{}_{}.txt".format(text_num, max_len)
if os.path.exists(file_name):
f = open(file_name)
res_text_list = f.readlines()
else:
tmp_str = "test "
len_list = torch.randint(low=1, high=max_len, size=(1, text_num))
res_text_list = [(tmp_str * len_list[0][i]) + "\n" for i in range(text_num)]
f = open(file_name, "w")
f.writelines(res_text_list)
res_text_list = [i.replace(" \n", "").replace('\n', '') for i in res_text_list]
return res_text_list


def send_request(input_: str, url_: str, port: str):
global latency
url_ = url_ + ":" + port + "/model_with_padding"
params = {"input_str": input_}
start_ = time.time()
response = requests.post(url=url_, json=params).text
latency.append(time.time() - start_)
print(response)


def test_batch():
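    # Fire req_num concurrent requests at the padded-batch endpoint with a small stagger,
    # then report the mean end-to-end latency.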
global latency
ip_ = "http://127.0.0.1"
port_ = "8010"
req_num = 50
seq_len = 64
# req_list = ["test " * 10 for _ in range(req_num)]
req_list = generate_test_dataset(req_num, seq_len)
for i in range(req_num):
time.sleep(0.005)
temp_thread = threading.Thread(target=send_request, args=(req_list[i], ip_, port_))
temp_thread.start()
time.sleep(20)
print(np.mean(latency))


if __name__ == "__main__":
test_batch()