diff --git a/examples/opt/executor.py b/examples/opt/executor.py
index 592a57b..3f3d23a 100644
--- a/examples/opt/executor.py
+++ b/examples/opt/executor.py
@@ -1,18 +1,19 @@
 import asyncio
 import time
+import torch
 from threading import Thread
-from typing import Any, Dict, Deque, List
+from typing import Any, Dict, Deque
 from collections import namedtuple, deque
 from energonai.logging import get_dist_logger
 
-GenerationArgs = namedtuple('GenerationArgs', ['top_k', 'top_p', 'temperature', 'max_tokens'])
-SubmitEntry = namedtuple('SubmitEntry', ['text', 'args'])
+GenerationArgs = namedtuple('GenerationArgs', ['top_k', 'top_p', 'temperature'])
+SubmitEntry = namedtuple('SubmitEntry', ['inputs', 'args', 'decode_steps'])
 
 
 class Executor:
-    def __init__(self, engine, tokenizer, max_batch_size: int = 1) -> None:
+    def __init__(self, engine, pad_token_id: int = 0, max_batch_size: int = 1) -> None:
         self.engine = engine
-        self.tokenizer = tokenizer
+        self.pad_token_id = pad_token_id
         self.max_batch_size = max_batch_size
         self.running: bool = False
         self.thread = None
@@ -24,12 +25,13 @@ def _start(self) -> None:
         self.running = True
         while self.running:
             if len(self.submit_queue) > 0:
-                inputs, entry_ids = self._make_batch()
+                inputs, entry_ids, trunc_lens, decode_steps = self._make_batch()
                 start = time.time()
                 outputs = self.engine.run(inputs).to_here()
-                for entry_id, output in zip(entry_ids, outputs):
-                    self.ready_map[entry_id] = self.tokenizer.decode(output, skip_special_tokens=True)
-                self.logger.info(f'batch size: {len(entry_ids)}, time: {time.time()-start:.3f} s')
+                for entry_id, output, trunc_len in zip(entry_ids, outputs, trunc_lens):
+                    self.ready_map[entry_id] = output[:trunc_len]
+                self.logger.info(
+                    f'batch size: {len(entry_ids)}, decode steps: {decode_steps}, time: {time.time()-start:.3f} s')
 
     def _make_batch(self):
         entry = self.submit_queue.popleft()
@@ -39,24 +41,31 @@ def _make_batch(self):
                 break
             if self.submit_queue[0].args != entry.args:
                 break
+            if self.submit_queue[0].decode_steps > entry.decode_steps:
+                break
             batch.append(self.submit_queue.popleft())
-        batch_text = [e.text for e in batch]
-        inputs = self.tokenizer(batch_text, padding=True, return_tensors='pt')
+        inputs, max_len = self._left_padding([e.inputs for e in batch])
+        entry_ids = []
+        trunc_lens = []
+        for e in batch:
+            entry_ids.append(id(e))
+            trunc_lens.append(max_len + e.decode_steps)
         inputs['top_k'] = entry.args.top_k
         inputs['top_p'] = entry.args.top_p
         inputs['temperature'] = entry.args.temperature
-        inputs['max_tokens'] = entry.args.max_tokens
-        return inputs, [id(e) for e in batch]
+        inputs['max_tokens'] = max_len + entry.decode_steps
+        return inputs, entry_ids, trunc_lens, entry.decode_steps
 
     def start(self):
         self.thread = Thread(target=self._start)
         self.thread.start()
 
-    def submit(self, promt, max_tokens, top_k, top_p, temperature):
+    def submit(self, inputs, max_tokens, top_k, top_p, temperature):
         if not self.running:
             raise RuntimeError('executor is shutdown')
-        args = GenerationArgs(top_k, top_p, temperature, max_tokens)
-        entry = SubmitEntry(promt, args)
+        args = GenerationArgs(top_k, top_p, temperature)
+        decode_steps = max_tokens - len(inputs['input_ids'])
+        entry = SubmitEntry(inputs, args, decode_steps)
         self.submit_queue.append(entry)
         return id(entry)
 
@@ -71,3 +80,17 @@ async def wait(self, entry_id):
     def teardown(self):
         self.running = False
         self.thread.join()
+
+    def _left_padding(self, batch_inputs):
+        max_len = max(len(inputs['input_ids']) for inputs in batch_inputs)
+        outputs = {'input_ids': [], 'attention_mask': []}
+        for inputs in batch_inputs:
+            input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask']
+            padding_len = max_len - len(input_ids)
+            input_ids = [self.pad_token_id] * padding_len + input_ids
+            attention_mask = [0] * padding_len + attention_mask
+            outputs['input_ids'].append(input_ids)
+            outputs['attention_mask'].append(attention_mask)
+        for k in outputs:
+            outputs[k] = torch.tensor(outputs[k])
+        return outputs, max_len
diff --git a/examples/opt/opt_server.py b/examples/opt/opt_server.py
index c58e4a5..7afe696 100644
--- a/examples/opt/opt_server.py
+++ b/examples/opt/opt_server.py
@@ -24,8 +24,10 @@ class GenerationTaskReq(BaseModel):
 
 
 @app.post('/generation')
 async def generate(data: GenerationTaskReq, request: Request):
     logger.info(f'{request.client.host}:{request.client.port} - "{request.method} {request.url.path}" - {data}')
-    handle = executor.submit(data.prompt, data.max_tokens, data.top_k, data.top_p, data.temperature)
+    inputs = tokenizer(data.prompt)
+    handle = executor.submit(inputs, data.max_tokens, data.top_k, data.top_p, data.temperature)
     output = await executor.wait(handle)
+    output = tokenizer.decode(output, skip_special_tokens=True)
     return {'text': output}
@@ -83,7 +85,7 @@ def launch_engine(model_class,
                              port=port,
                              dtype=dtype)
     global executor
-    executor = Executor(engine, tokenizer, max_batch_size=16)
+    executor = Executor(engine, pad_token_id=tokenizer.pad_token_id, max_batch_size=16)
     executor.start()
 
     global server
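
For context, here is a minimal standalone sketch of the left-padding scheme that the new _left_padding method implements. The token ids below are made up for illustration, and pad_token_id=1 is an assumption (it matches OPT's tokenizer, but any id shows the mechanics):

    # Sketch only: mirrors the logic of Executor._left_padding above,
    # outside the class so it can be run directly.
    import torch

    pad_token_id = 1  # assumed; OPT's tokenizer uses 1 as its pad id
    batch = [
        {'input_ids': [2, 100, 101], 'attention_mask': [1, 1, 1]},
        {'input_ids': [2, 200, 201, 202, 203], 'attention_mask': [1, 1, 1, 1, 1]},
    ]

    max_len = max(len(x['input_ids']) for x in batch)
    input_ids, attention_mask = [], []
    for x in batch:
        pad = max_len - len(x['input_ids'])
        # Pad on the left so every prompt ends at the same position and
        # generation for the whole batch starts from a common offset.
        input_ids.append([pad_token_id] * pad + x['input_ids'])
        attention_mask.append([0] * pad + x['attention_mask'])

    print(torch.tensor(input_ids))
    # tensor([[  1,   1,   2, 100, 101],
    #         [  2, 200, 201, 202, 203]])

Left-padding is what makes the rest of the patch work: since all prompts are right-aligned, one decode loop of decode_steps steps serves the whole batch, and each entry's result is recovered afterwards by truncating to its own budget, output[:max_len + decode_steps], as done in _start.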