From f4a8664fe0a26bac686ecd235421e91f0149db39 Mon Sep 17 00:00:00 2001
From: maruyama
Date: Fri, 6 May 2022 14:26:20 +0800
Subject: [PATCH] Add comments to Batch Manager

---
 energon/server/batch_manager.py | 132 ++++++++++++++++++++++++--------
 requirements.txt                |  19 +++--
 2 files changed, 114 insertions(+), 37 deletions(-)

diff --git a/energon/server/batch_manager.py b/energon/server/batch_manager.py
index 4e90c8a..73b44a0 100644
--- a/energon/server/batch_manager.py
+++ b/energon/server/batch_manager.py
@@ -1,4 +1,9 @@
-import torch
+"""
+------------------------------------------
+Class Batch_Manager and the function for generating the cached cost table.
+This code adapts the batch wrapping algorithm of TurboTransformers.
+------------------------------------------
+"""
 import time
 from scipy import stats
 import numpy as np
@@ -7,28 +12,47 @@
 import random
 import redis
 import os
-from tqdm import tqdm, trange
+from tqdm import trange
 import threading
 from readerwriterlock import rwlock
+import logging
 
 
 def generate_cached_cost(engine, max_seq_len: int = 1024, max_batch_size: int = 16, step: int = 1,
                          repeat_round: int = 3):
-    def select_top_k(predictions, k=10):
-        predicted_index = random.choice(
-            predictions[0, -1, :].sort(descending=True)[1][:k]).item()
-        return predicted_index
+    """
+    Measure the running time for different sequence lengths and batch sizes on the current machine.
+    :param engine: InferenceEngine from energon.engine
+    :type engine: InferenceEngine
+    :param max_seq_len: The maximum sequence length that is measured.
+    :param max_batch_size: The maximum batch size that is measured.
+    :param step: The running time is measured once every 'step' sequence lengths.
+    :param repeat_round: Each batch is run 'repeat_round' times and the average time is taken.
+    """
 
-    print("fetching cached cost")
+    def select_top_k(temp_predictions, top_k: int = 10):
+        """
+        Pick out a word from the top k of the 50257-word vocabulary according to the probabilities given by
+        temp_predictions for each sequence in this batch.
+        :param temp_predictions: Transformer output tensor of size (batch size, sequence length, vocab size)
+            which contains the probabilities for each word in this batch.
+        :type temp_predictions: torch.Tensor
+        :param top_k: How many top words to choose from.
+        """
+        temp_predicted_index = random.choice(
+            temp_predictions[0, -1, :].sort(descending=True)[1][:top_k]).item()
+        return temp_predicted_index
+
+    logging.info("fetching cached cost")
     cached_name = "cached_cost_{}_{}_{}_{}.npy".format(max_seq_len, max_batch_size, step, repeat_round)
     if os.path.exists(cached_name):
-        print("loading cached cost from file")
+        logging.info("loading cached cost from file")
         cached_cost = np.load(cached_name).tolist()
     else:
-        print("generating new cached cost")
+        logging.info("generating new cached cost")
         cached_cost = [[0 for i in range(max_batch_size + 1)] for j in range(max_seq_len + 1)]
         input_text = ""
-        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+        tokenizer = GPT2Tokenizer.from_pretrained("./")
         for tmp_len in trange(1, max_seq_len + 1, step):
             input_text += "test "
             for tmp_batch in range(1, max_batch_size + 1):
@@ -39,18 +63,25 @@ def select_top_k(predictions, k=10):
                 output = engine.run(input_token)
                 predictions = output.to_here()
-                predicted_index = select_top_k(predictions, k=1)
-                total_predicted_text = tokenizer.decode(predicted_index)
+                predicted_index = select_top_k(predictions, top_k=1)
+                tokenizer.decode(predicted_index)
                 time_cost = (time.time() - start_time) / repeat_round
                 cached_cost[tmp_len][tmp_batch] = time_cost
                 for k in range(1, step):
                     cached_cost[tmp_len + k][tmp_batch] = time_cost
         np.save(cached_name, np.array(cached_cost))
-    print("cached cost loaded")
+    logging.info("cached cost loaded")
     return cached_cost
 
 
-class single_request():
+class single_request:
     def __init__(self, input_, time_stamp: float, input_str: str):
+        """
+        Class to store the related information for a single request.
+        :param input_: The output of GPT2Tokenizer, a dict containing input_ids and attention_mask
+        :param time_stamp: The time stamp at which the request was received. It is used as an index to
+            identify the request.
+        :param input_str: The input string of the request.
+        """
         self.input_ = input_
         self.text = input_str
         self.time_ = time_stamp
@@ -58,6 +89,9 @@ def __init__(self, input_, time_stamp: float, input_str: str):
 
 
 class Manager:
+    """
+    Base class of batch managers.
+    """
     def __init__(self):
         pass
 
@@ -66,19 +100,30 @@ def insert_req(self, time_stamp: float, input_ids, input_str: str):
 
 
 class Batch_Manager(Manager):
+    """
+    This batch manager maintains a queue of requests to be processed. The requests in the queue are
+    wrapped into batches according to their sequence lengths and the priority calculated with the equation
+    in cal_priority, and are then sent into the inference engine.
+    """
     def __init__(self, engine: InferenceEngine, cached_cost: list, init_mu: int = 512, init_theta: int = 180,
-                 max_batch_size: int = 32, lr: float = 0.01, max_seq_len=1024):
+                 max_batch_size: int = 32, lr: float = 0.01):
+        """
+        :param engine: The InferenceEngine from energon.engine
+        :param cached_cost: The output of the function generate_cached_cost
+        :param init_mu: Initial mean assumed for the incoming sequence lengths.
+        :param init_theta: Initial standard deviation assumed for the incoming sequence lengths.
+        :param max_batch_size: The maximum number of requests that can be wrapped into one batch.
+        :param lr: The learning rate used to update the mean and standard deviation assumed for the normal
+            distribution of sequence lengths.
+        """
         super().__init__()
         self.engine = engine
         self.max_batch_size = max_batch_size
         self.lr = lr
         self.mu = init_mu
         self.theta = init_theta
-        self.max_seq_len = max_seq_len
-        # self.normal_weight = self._init_normal_dist_weight()
         self.req_list = []
         self.req_list_lock = rwlock.RWLockFair()
-        self.read_lock = self.req_list_lock.gen_rlock()
         self.write_lock = self.req_list_lock.gen_wlock()
         self.cached_cost = cached_cost
         self.tokenizer = GPT2Tokenizer.from_pretrained('/home/lcdjs/hf_gpt2')
@@ -89,6 +134,9 @@ def __init__(self, engine: InferenceEngine, cached_cost: list, init_mu: int = 51
         self.main_thread.start()
 
     def insert_req(self, time_stamp: float, input_ids, input_str: str):
+        """
+        Build a single_request object from the input and insert it into the request queue.
+        """
         tmp_req = single_request(input_ids, time_stamp, input_str)
         self.write_lock.acquire()
         self.req_list.append(tmp_req)
@@ -96,6 +144,16 @@ def insert_req(self, time_stamp: float, input_ids, input_str: str):
         self.write_lock.release()
 
     def cal_priority(self, batch_list: list, cur_stamp: float):
+        """
+        Given a wrapped batch, calculate its priority, which decides which batch is handed to the inference
+        engine next. The equation is based on the sequence length, the batch size and the longest wait time
+        within the batch. We assume that request lengths follow a normal distribution, so a batch whose
+        length has a higher probability of appearing is allowed to wait a little longer for other requests
+        of similar length, in order to increase the batch size.
+        Batches with a larger batch size also gain higher priority.
+        To avoid starvation, an exponential function raises the priority of batches that have been waiting
+        for a long time.
+        """
         cur_len = batch_list[-1].seq_len
 
         earliest_timestamp = min([i.time_ for i in batch_list])
@@ -107,18 +165,18 @@
         priority = appear_possibility_weight * batch_size * np.exp(wait_time)
         return priority
 
-    # def _init_normal_dist_weight(self):
-    #     temp_weight_list = [0]
-    #     for i in range(1, self.max_seq_len):
-    #         temp_weight_list.append(stats.norm(self.mu, self.theta).cdf(i) -
-    #                                 stats.norm(self.mu, self.theta).cdf(i - 1))
-    #     return temp_weight_list
-
     def cal_norm_weight(self, seq_len):
+        """
+        Approximately estimate the probability of a certain sequence length under the assumed normal distribution.
+        """
         return stats.norm(self.mu, self.theta).cdf(seq_len) - \
                stats.norm(self.mu, self.theta).cdf(seq_len - 1)
 
     def update_norm(self, batch_: list):
+        """
+        Every time a batch is handed to the inference engine, update mu and theta of the assumed
+        distribution with the current batch and the pre-set learning rate.
+        """
         new_mu = np.mean([i.seq_len for i in batch_])
         delta_mu = new_mu - self.mu
         self.mu += self.lr * delta_mu
@@ -129,6 +187,11 @@
         return
 
     def wrap_batch(self):
+        """
+        Given the request list sorted by sequence length, use dynamic programming to calculate the best way
+        to wrap the requests into batches according to the cached cost.
+        The algorithm in this function comes from the TurboTransformers paper.
+        """
         self.write_lock.acquire()
         states = [0]
         start_idx_list = [0]
@@ -136,7 +199,6 @@
             j = i - 1
             start_idx = i - 1
             cur_length = self.req_list[i - 1].seq_len
-            # print(i, j, cur_length)
             min_cost = self.cached_cost[cur_length][1] + states[j]
             while j > max(0, i - self.max_batch_size):
                 tmp_cost = states[j - 1] + \
@@ -169,30 +231,38 @@
         return result_batch
 
     def processing_batch(self):
+        """
+        The background thread that continuously calls wrap_batch, feeds the batch into the inference
+        engine, and starts a new thread that waits for and publishes the inference result.
+        """
         while self.running_flag:
             if len(self.req_list) > 0:
                 target_batch = self.wrap_batch()
                 pad_len = target_batch[-1].seq_len
-                print("A batch with {} requests and length of {} packed".format(len(target_batch), pad_len))
+                logging.info("A batch with {} requests and length of {} packed".format(len(target_batch), pad_len))
                 input_text = [i.text for i in target_batch]
                 input_ids = self.tokenizer(input_text, padding="longest", return_tensors="pt")
-                print("input_ids shape: {}".format(input_ids['input_ids'].shape))
-                print("attention_mask shape: {}".format(input_ids['attention_mask'].shape))
-                # input_ids = self.tokenizer(input_text, return_tensors="pt")
-                # print(input_ids)
+                # print("input_ids shape: {}".format(input_ids['input_ids'].shape))
+                # print("attention_mask shape: {}".format(input_ids['attention_mask'].shape))
                 output = self.engine.run(input_ids)
                 pub_thread = threading.Thread(target=self.publish_result, args=(output, target_batch))
                 pub_thread.start()
 
     def publish_result(self, output, target_batch):
+        """
+        Background function, run in a separate thread, that waits for the inference result and publishes
+        it through the Redis publisher to the waiting requests.
+        :param output: The RPC reference to the inference result.
+        :param target_batch: The batch of requests that produced this output.
+        """
         def select_top_k(batch_id, predictions, k=10):
             predicted_index = random.choice(
                 predictions[batch_id, -1, :].sort(descending=True)[1][:k]).item()
            return predicted_index
 
+        # print("output: {}".format(output))
         predictions = output.to_here()
         # print("predictions: {}".format(predictions), flush=True)
-        # decode_list = self.tokenizer.decode(predictions)
         for i in range(len(target_batch)):
            # print(i, predictions.shape, target_batch)
            temp_st = target_batch[i].time_
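
The dynamic program that the new wrap_batch docstring refers to can be sketched on its own as follows. This is an illustration under stated assumptions, not part of the diff: wrap_batch_dp is a hypothetical name, it operates on a plain list of sequence lengths sorted in ascending order, and it omits the locking, the single_request objects and the cal_priority ranking that the real method also performs.

# Illustrative sketch only, not part of this patch.
from typing import List, Tuple


def wrap_batch_dp(seq_lens: List[int], cached_cost: List[List[float]],
                  max_batch_size: int) -> List[Tuple[int, int]]:
    # states[i]: minimal total cost of serving the first i requests.
    # split[i]: start index (0-based) of the last batch in that optimal plan.
    n = len(seq_lens)
    states = [0.0] * (n + 1)
    split = [0] * (n + 1)
    for i in range(1, n + 1):
        pad_len = seq_lens[i - 1]            # lengths are sorted, so request i-1 is the longest so far
        states[i] = states[i - 1] + cached_cost[pad_len][1]   # batch holding only request i-1
        split[i] = i - 1
        for j in range(i - 1, max(0, i - max_batch_size), -1):
            # Batch requests j-1 .. i-1 together: size i-j+1, padded to pad_len.
            cost = states[j - 1] + cached_cost[pad_len][i - j + 1]
            if cost < states[i]:
                states[i] = cost
                split[i] = j - 1
    # Walk the split points backwards to recover the batches as (start, end) index pairs.
    batches, end = [], n
    while end > 0:
        batches.append((split[end], end))
        end = split[end]
    return batches[::-1]

Here states[i] is the minimal total cost of serving the first i requests when every batch is padded to its longest sequence, which is exactly the quantity the cached_cost table produced by generate_cached_cost measures.
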
diff --git a/requirements.txt b/requirements.txt
index ba35f20..482a85f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,10 +1,17 @@
 torch>=1.8
-numpy
-tqdm
+numpy~=1.21.2
+tqdm~=4.64.0
 psutil
 packaging
-fastapi
+fastapi~=0.75.1
 uvicorn==0.14
-typer
-redis
-scipy
\ No newline at end of file
+typer~=0.4.0
+redis~=4.2.2
+scipy~=1.8.0
+energon~=0.0.1b0
+pytest~=7.1.1
+requests~=2.27.1
+click~=8.1.2
+transformers~=4.18.0
+readerwriterlock~=1.0.9
+setuptools~=58.0.4
\ No newline at end of file
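
For reference, a minimal sketch of how the pieces documented by this patch are expected to be wired together in a serving script. The helper names start_batch_manager and submit_request are hypothetical, and the InferenceEngine and tokenizer are assumed to be constructed elsewhere by the serving code; only generate_cached_cost, Batch_Manager and insert_req come from the module above.

# Hypothetical wiring sketch; helper names are illustrative and not part of this patch.
import time

from energon.server.batch_manager import Batch_Manager, generate_cached_cost


def start_batch_manager(engine):
    # Profile this machine once; generate_cached_cost saves the table to an .npy file,
    # so later runs simply reload it.
    cached_cost = generate_cached_cost(engine, max_seq_len=1024, max_batch_size=32,
                                       step=8, repeat_round=3)
    # The constructor starts the processing_batch background thread, which keeps
    # wrapping queued requests into batches and feeding them to the engine.
    return Batch_Manager(engine, cached_cost, init_mu=512, init_theta=180,
                         max_batch_size=32, lr=0.01)


def submit_request(manager, tokenizer, input_str: str) -> float:
    # The time stamp doubles as the request id; publish_result later uses it to
    # identify the request when publishing the generated text through Redis.
    time_stamp = time.time()
    input_ids = tokenizer(input_str, return_tensors="pt")
    manager.insert_req(time_stamp, input_ids, input_str)
    return time_stamp

Note that the defaults of the two functions differ (generate_cached_cost profiles batch sizes up to 16 while Batch_Manager wraps up to 32 requests), so passing a matching max_batch_size to both, as above, avoids indexing past the end of the cached table when large batches are wrapped.
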