
[opt] executor: update batch-making policy
feifeibear authored Aug 26, 2022
2 parents 2796eed + 19dffb8 commit 9e14ad5
Showing 2 changed files with 43 additions and 18 deletions.
55 changes: 39 additions & 16 deletions examples/opt/executor.py
@@ -1,18 +1,19 @@
 import asyncio
 import time
+import torch
 from threading import Thread
-from typing import Any, Dict, Deque, List
+from typing import Any, Dict, Deque
 from collections import namedtuple, deque
 from energonai.logging import get_dist_logger
 
-GenerationArgs = namedtuple('GenerationArgs', ['top_k', 'top_p', 'temperature', 'max_tokens'])
-SubmitEntry = namedtuple('SubmitEntry', ['text', 'args'])
+GenerationArgs = namedtuple('GenerationArgs', ['top_k', 'top_p', 'temperature'])
+SubmitEntry = namedtuple('SubmitEntry', ['inputs', 'args', 'decode_steps'])
 
 
 class Executor:
-    def __init__(self, engine, tokenizer, max_batch_size: int = 1) -> None:
+    def __init__(self, engine, pad_token_id: int = 0, max_batch_size: int = 1) -> None:
         self.engine = engine
-        self.tokenizer = tokenizer
+        self.pad_token_id = pad_token_id
         self.max_batch_size = max_batch_size
         self.running: bool = False
         self.thread = None
@@ -24,12 +25,13 @@ def _start(self) -> None:
         self.running = True
         while self.running:
             if len(self.submit_queue) > 0:
-                inputs, entry_ids = self._make_batch()
+                inputs, entry_ids, trunc_lens, decode_steps = self._make_batch()
                 start = time.time()
                 outputs = self.engine.run(inputs).to_here()
-                for entry_id, output in zip(entry_ids, outputs):
-                    self.ready_map[entry_id] = self.tokenizer.decode(output, skip_special_tokens=True)
-                self.logger.info(f'batch size: {len(entry_ids)}, time: {time.time()-start:.3f} s')
+                for entry_id, output, trunc_len in zip(entry_ids, outputs, trunc_lens):
+                    self.ready_map[entry_id] = output[:trunc_len]
+                self.logger.info(
+                    f'batch size: {len(entry_ids)}, decode steps: {decode_steps}, time: {time.time()-start:.3f} s')
 
     def _make_batch(self):
         entry = self.submit_queue.popleft()
@@ -39,24 +41,31 @@ def _make_batch(self):
                 break
             if self.submit_queue[0].args != entry.args:
                 break
+            if self.submit_queue[0].decode_steps > entry.decode_steps:
+                break
             batch.append(self.submit_queue.popleft())
-        batch_text = [e.text for e in batch]
-        inputs = self.tokenizer(batch_text, padding=True, return_tensors='pt')
+        inputs, max_len = self._left_padding([e.inputs for e in batch])
+        entry_ids = []
+        trunc_lens = []
+        for e in batch:
+            entry_ids.append(id(e))
+            trunc_lens.append(max_len + e.decode_steps)
         inputs['top_k'] = entry.args.top_k
         inputs['top_p'] = entry.args.top_p
         inputs['temperature'] = entry.args.temperature
-        inputs['max_tokens'] = entry.args.max_tokens
-        return inputs, [id(e) for e in batch]
+        inputs['max_tokens'] = max_len + entry.decode_steps
+        return inputs, entry_ids, trunc_lens, entry.decode_steps
 
     def start(self):
         self.thread = Thread(target=self._start)
         self.thread.start()
 
-    def submit(self, promt, max_tokens, top_k, top_p, temperature):
+    def submit(self, inputs, max_tokens, top_k, top_p, temperature):
         if not self.running:
             raise RuntimeError('executor is shutdown')
-        args = GenerationArgs(top_k, top_p, temperature, max_tokens)
-        entry = SubmitEntry(promt, args)
+        args = GenerationArgs(top_k, top_p, temperature)
+        decode_steps = max_tokens - len(inputs['input_ids'])
+        entry = SubmitEntry(inputs, args, decode_steps)
         self.submit_queue.append(entry)
         return id(entry)

@@ -71,3 +80,17 @@ async def wait(self, entry_id):
     def teardown(self):
         self.running = False
         self.thread.join()
+
+    def _left_padding(self, batch_inputs):
+        max_len = max(len(inputs['input_ids']) for inputs in batch_inputs)
+        outputs = {'input_ids': [], 'attention_mask': []}
+        for inputs in batch_inputs:
+            input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask']
+            padding_len = max_len - len(input_ids)
+            input_ids = [self.pad_token_id] * padding_len + input_ids
+            attention_mask = [0] * padding_len + attention_mask
+            outputs['input_ids'].append(input_ids)
+            outputs['attention_mask'].append(attention_mask)
+        for k in outputs:
+            outputs[k] = torch.tensor(outputs[k])
+        return outputs, max_len
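
To make the new batch policy concrete, here is a minimal standalone sketch of the left-padding step, assuming the requests are already tokenized. The token ids and the pad id below are made-up illustration values, not output of the OPT tokenizer.

import torch

pad_token_id = 1  # hypothetical pad id, only for this example

# Two already-tokenized requests with different prompt lengths (toy ids).
batch_inputs = [
    {'input_ids': [2, 100, 101, 102], 'attention_mask': [1, 1, 1, 1]},
    {'input_ids': [2, 200, 201], 'attention_mask': [1, 1, 1]},
]

max_len = max(len(x['input_ids']) for x in batch_inputs)
padded = {'input_ids': [], 'attention_mask': []}
for x in batch_inputs:
    pad = max_len - len(x['input_ids'])
    # Pad on the left so every prompt ends at the same position and the
    # generated tokens for all requests start right after it.
    padded['input_ids'].append([pad_token_id] * pad + x['input_ids'])
    padded['attention_mask'].append([0] * pad + x['attention_mask'])
padded = {k: torch.tensor(v) for k, v in padded.items()}

print(padded['input_ids'])
# tensor([[  2, 100, 101, 102],
#         [  1,   2, 200, 201]])

Entries whose decode_steps exceed the first entry's are left in the queue for a later batch, and each output is truncated to max_len + e.decode_steps, so shorter requests do not pick up the extra tokens generated for the longest one.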
6 changes: 4 additions & 2 deletions examples/opt/opt_server.py
@@ -24,8 +24,10 @@ class GenerationTaskReq(BaseModel):
 @app.post('/generation')
 async def generate(data: GenerationTaskReq, request: Request):
     logger.info(f'{request.client.host}:{request.client.port} - "{request.method} {request.url.path}" - {data}')
-    handle = executor.submit(data.prompt, data.max_tokens, data.top_k, data.top_p, data.temperature)
+    inputs = tokenizer(data.prompt)
+    handle = executor.submit(inputs, data.max_tokens, data.top_k, data.top_p, data.temperature)
     output = await executor.wait(handle)
+    output = tokenizer.decode(output, skip_special_tokens=True)
     return {'text': output}


@@ -83,7 +85,7 @@ def launch_engine(model_class,
                              port=port,
                              dtype=dtype)
     global executor
-    executor = Executor(engine, tokenizer, max_batch_size=16)
+    executor = Executor(engine, pad_token_id=tokenizer.pad_token_id, max_batch_size=16)
     executor.start()
 
     global server
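
For completeness, a client-side sketch of calling the updated /generation endpoint. The host, port, and prompt below are assumptions for illustration (they depend on how opt_server.py is launched), not values from this commit.

import requests

# Hypothetical server address; adjust to the host/port the server was started with.
url = 'http://127.0.0.1:7070/generation'
payload = {
    'prompt': 'Berlin is the capital of',
    'max_tokens': 64,
    'top_k': 50,
    'top_p': 0.9,
    'temperature': 0.7,
}
resp = requests.post(url, json=payload)
print(resp.json()['text'])  # the server now decodes the generated ids before returning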

