
Commit

[inference]Re push async dynamic batching (hpcaitech#4901)
* adapt to ray server

* finish async

* finish test

* del test

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>
CjhHa1 and isky-cd authored Oct 13, 2023
1 parent fced140 commit fbf3c09
Showing 4 changed files with 107 additions and 109 deletions.
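The core of this change converts the batching loop's synchronous generators into async generators: every `yield from gen()` becomes `async for item in agen(): yield item`. A minimal, self-contained sketch of that pattern follows, using hypothetical stand-in names rather than the real engine methods:

import asyncio

async def decode_batch(batch):
    # stand-in for an engine decode step that emits one result per request
    for token in batch:
        yield token

async def step(batch):
    # async generators cannot use `yield from`, so results are re-yielded explicitly
    async for token in decode_batch(batch):
        yield token

async def main():
    async for token in step(["hello", "world"]):
        print(token)

asyncio.run(main())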
15 changes: 10 additions & 5 deletions colossalai/inference/dynamic_batching/io_struct.py
@@ -4,7 +4,7 @@


class Req:
def __init__(self, request_id, prompt_ids, sample_params: SamplingParams):
def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompts: str = ""):
self.request_id = request_id
self.prompt_ids = prompt_ids
self.input_len = len(prompt_ids)
@@ -14,6 +14,7 @@ def __init__(self, request_id, prompt_ids, sample_params: SamplingParams):
self.output_metadata_list = []
self.has_generate_finished = False
self.aborted = False
self.prompts = prompts

def to_rpc_obj(self):
return {
@@ -36,7 +37,11 @@ def stop_sequences_matched(self):
if self.sample_params.stop_sequences is not None:
for stop_token_ids in self.sample_params.stop_sequences:
stop_len = len(stop_token_ids)
if stop_len > 0 and len(self.output_ids) >= stop_len and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)):
if (
stop_len > 0
and len(self.output_ids) >= stop_len
and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len))
):
return True
return False

@@ -102,7 +107,7 @@ def mark_finished_req(self, eos_id):
has_new_finish = True
return has_new_finish

def filter_finished(self)->List[Req]:
def filter_finished(self) -> List[Req]:
"""
Filter finished requests from the batch, the finished ones will be removed from 'reqs'.
"""
@@ -111,9 +116,9 @@ def filter_finished(self)->List[Req]:
finished_req = []
for req in self.reqs:
if not req.has_generate_finished:
unfinished_req.append(req)
unfinished_req.append(req)
else:
finished_req.append(req)
finished_req.append(req)
self.reqs = unfinished_req
self.id_to_reqs = {req.request_id: req for req in self.reqs}
return finished_req
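The reformatted condition in stop_sequences_matched above implements a suffix match: generation stops only when the last stop_len output tokens equal one of the configured stop sequences. A standalone sketch of the same check, with made-up token ids for illustration:

def suffix_matches(output_ids, stop_token_ids):
    # mirrors the check in Req.stop_sequences_matched: the stop sequence must
    # appear at the very end of the generated ids, not merely anywhere inside them
    stop_len = len(stop_token_ids)
    return (
        stop_len > 0
        and len(output_ids) >= stop_len
        and all(output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len))
    )

assert suffix_matches([5, 6, 7, 8], [7, 8]) is True   # stop sequence is the suffix
assert suffix_matches([5, 6, 7, 8], [6, 7]) is False  # present, but not at the end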
139 changes: 76 additions & 63 deletions colossalai/inference/manager.py
@@ -1,6 +1,7 @@
import time
from typing import List
import asyncio
from typing import List

from transformers import AutoTokenizer

from .dynamic_batching.infer_batch import InferBatch
from .dynamic_batching.io_struct import Batch, Req
@@ -9,16 +10,17 @@
from .dynamic_batching.stats import Stats
from .tensor_parallel import TPInferEngine

from transformers import AutoTokenizer
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"


class DynamicBatchManager:
def __init__(
self,
tp_engine: TPInferEngine,
max_total_token_num,
batch_max_tokens,
eos_id,
model,
log_stats=True,
log_stats_interval=10,
running_batch: Batch = None,
@@ -30,6 +32,7 @@ def __init__(
batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests
running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine
eos_id : The end token of a seq
model: the model weight dir path, the app will load config, weights and tokenizer from this dir
log_stats : whether to log stats
log_stats_interval : log stats interval
running_batch : running batch
@@ -45,32 +48,32 @@ def __init__(
self.eos_id = eos_id
self.has_wait_tokens = 0
self.max_wait_tokens = 10

self.model = model

self.stats_tool = Stats(log_stats, log_stats_interval)
self.mem_usage_interval = log_stats_interval * 2
self._set_tokenizer(tokenizer_name=self.model)

def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str):
async def add_req(self, request_id, prompt_ids: List[int], sampling_params: SamplingParams, prompts: str = ""):
"""
Add new request to req queue, during initialization all requests are held in waiting list.
"""
req = Req(request_id, prompt_ids, sampling_params)
req = Req(request_id, prompt_ids, sampling_params, prompts)
self.req_queue.append(req)
return

def add_input(self, request_id, sampling_params, input_ids):
async def add_input(self, request_id, sampling_params, prompts):
"""
Encode and Add new input to req queue. support one sequence input for now.
"""
prompt_ids = self.tokenizer.encode(input_ids)
prompt_ids = self.tokenizer.encode(prompts)
prompt_len = len(prompt_ids)
if prompt_len > self.engine.max_input_len:
raise ValueError(
f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}"
)
raise ValueError(f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}")
sampling_params.stop_sentences_to_token_ids(self.tokenizer)
self.add_req(prompt_ids, sampling_params, request_id)
self.add_req(request_id, prompt_ids, sampling_params, prompts)
return

def abort(self, request_id):
if self.running_batch is not None:
for req in self.running_batch.reqs:
@@ -88,10 +91,15 @@ async def loop_for_fwd(self):
The main loop for a dynamic batching process.
"""
counter_count = 0
#self.running_batch is not None or self.req_queue.waiting_req_list
# self.running_batch is not None or self.req_queue.waiting_req_list
while True:
async for item in self._step():
yield item
if self.running_batch is not None or self.req_queue.waiting_req_list:
async for result in self._step():
yield result
else:
# need to wait for new requests
await asyncio.sleep(0.1)
continue
counter_count += 1
if self.running_batch is not None:
if counter_count % self.mem_usage_interval == 0:
@@ -103,30 +111,33 @@ def _step(self):
)
self.stats_tool.print_stats()

if self.running_batch is None:
time.sleep(0.1) # 10ms

def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast:bool = True,):
def _set_tokenizer(
self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast: bool = True
):
if tokenizer is not None:
self.tokenizer = tokenizer
self.tokenizer = tokenizer
else:
if "llama" in tokenizer_name.lower() and use_fast == True:
print(
"For some LLaMA-based models, initializing the fast tokenizer may "
"take a long time. To eliminate the initialization time, consider "
f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
"tokenizer. This is done automatically in Colossalai.")

tokenizer_name = _FAST_LLAMA_TOKENIZER

try:
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code)
except TypeError as e:
"For some LLaMA-based models, initializing the fast tokenizer may "
"take a long time. To eliminate the initialization time, consider "
f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
"tokenizer. This is done automatically in Colossalai."
)

tokenizer_name = _FAST_LLAMA_TOKENIZER

try:
self.tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
)
except TypeError:
use_fast = False
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code)

self.tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
)

def _step(self):
async def _step(self):
"""
Logic for handling requests
"""
@@ -136,33 +147,36 @@ def _step(self):
if new_batch is not None:
self.stats_tool.count_prompt_tokens(new_batch)
self.running_batch = new_batch
yield from self._prefill_batch(self.running_batch)
async for item in self._prefill_batch(self.running_batch):
yield item
self._filter_runing_batch()
self.has_wait_tokens = 0
return

if self.has_wait_tokens < self.max_wait_tokens:
self.stats_tool.count_output_tokens(self.running_batch)
yield from self._decode_batch(self.running_batch)
self._decode_batch(self.running_batch)
self._filter_runing_batch()
self.has_wait_tokens += 1
return
else:
new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
if new_mini_batch is not None:
self.stats_tool.count_prompt_tokens(new_mini_batch)
yield from self._prefill_batch(new_mini_batch)
async for item in self._prefill_batch(new_mini_batch):
yield item
if not new_mini_batch.is_clear():
self._merge_batch(self.running_batch, new_mini_batch)
self.running_batch.merge(new_mini_batch)
self.has_wait_tokens = 0

else:
self.stats_tool.count_output_tokens(self.running_batch)
yield from self._decode_batch(self.running_batch)
async for item in self._decode_batch(self.running_batch):
yield item
self._filter_runing_batch()
self.has_wait_tokens += 1

return

def _init_batch(self, batch: Batch, dtype="fp16"):
@@ -187,7 +201,7 @@ def _init_batch(self, batch: Batch, dtype="fp16"):
)
self.engine.cache[batch_id] = batch_data

def _prefill_batch(self, batch):
async def _prefill_batch(self, batch):
"""
For all batches, no matter it is a new batch or a mini batch, we need to do prefill first.
"""
@@ -198,19 +212,20 @@ def _prefill_batch(self, batch):
req_to_out_token_id = ans
self._add_token_id_to_req(batch, req_to_out_token_id)
has_new_finished_req = batch.mark_finished_req(self.eos_id)
yield from self._handle_finish_req(batch, has_new_finished_req)

async for item in self._handle_finish_req(batch, has_new_finished_req):
yield item
# delete finished reqs

def _decode_batch(self, batch: Batch):
async def _decode_batch(self, batch: Batch):
"""
Decoding process
"""
ans = self.engine._decode_batch(batch.batch_id)
req_to_out_token_id = ans
self._add_token_id_to_req(batch, req_to_out_token_id)
has_new_finished_req = batch.mark_finished_req(self.eos_id)
yield from self._handle_finish_req(batch, has_new_finished_req)
async for item in self._handle_finish_req(batch, has_new_finished_req):
yield item

def _filter_batch(self, batch: Batch):
batch_id = batch.batch_id
@@ -240,15 +255,15 @@ def _remove_batch(self, batch):
batch.free_self()
del batch

def _handle_finish_req(self, batch: Batch, has_new_finished_req):
async def _handle_finish_req(self, batch: Batch, has_new_finished_req):
if has_new_finished_req:
finished_reqs=batch.filter_finished()
finished_reqs = batch.filter_finished()
if batch.is_clear():
self._remove_batch(batch)
else:
self._filter_batch(batch)
yield from self._output_process(finished_reqs)

async for item in self._output_process(finished_reqs):
yield item

def _filter_runing_batch(self):
if self.running_batch is not None and self.running_batch.is_clear():
@@ -267,18 +282,24 @@ async def _output_process(self, finished_reqs: List[Req]):
"""
for req in finished_reqs:
output = self.tokenizer.decode(req.output_ids)
yield output, req.request_id, req.output_metadata_list
yield req.prompts + output

def clean_up(self):
# this logic should be implemented in the future.
pass

async def generate(self,request_id,prompt_id,sampling_params):
async def generate(self, request_id, prompt_id, sampling_params):
"""
Generate the output of a request.
"""
self.add_input(request_id,prompt_id,sampling_params)


await self.add_input(request_id, prompt_id, sampling_params)


async def process_data(dbm):
async for data in dbm.loop_for_fwd():
print(data)


def start_dynamic_batching(args, tp_engine, waiting_req_list):
try:
@@ -287,21 +308,13 @@ def start_dynamic_batching(args, tp_engine, waiting_req_list):
max_total_token_num=args.max_total_token_num,
batch_max_tokens=args.batch_max_tokens,
eos_id=args.eos_id,
model=args.model,
log_stats=not args.disable_log_stats,
log_stats_interval=args.log_stats_interval,
waiting_req_list=waiting_req_list,
)

except Exception:
batch_manager.clean_up()
raise

batch_manager._set_tokenizer(tokenizer_name = tp_engine.model.__class__.__name__)
prod_task = asyncio.create_task(batch_manager.add_input(4,sampling_params=SamplingParams(),input_ids="hello world"))

asyncio.run(prod_task)

for item in batch_manager.loop_for_fwd():
print(item)
raise RuntimeError("Failed to start dynamic batching")

return batch_manager
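For reference, a hypothetical driver for the async API introduced above; the engine and argument setup is omitted, and the import paths below are assumptions based on this diff rather than verified entry points:

import asyncio

from colossalai.inference.dynamic_batching.sampling_params import SamplingParams  # assumed path
from colossalai.inference.manager import start_dynamic_batching                   # assumed path

async def serve(tp_engine, args):
    # build the manager the same way start_dynamic_batching does above
    batch_manager = start_dynamic_batching(args, tp_engine, waiting_req_list=[])
    # producer: add_input tokenizes the prompt and enqueues a Req
    await batch_manager.add_input(request_id=0, sampling_params=SamplingParams(), prompts="hello world")
    # consumer: loop_for_fwd is an async generator that yields prompt + decoded text
    # for each finished request (it runs until cancelled)
    async for output in batch_manager.loop_for_fwd():
        print(output)

# asyncio.run(serve(tp_engine, args))  # with a prepared TPInferEngine and args namespace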
33 changes: 0 additions & 33 deletions colossalai/inference/test_async.py

This file was deleted.
