[Inference] Dynamic Batching Inference, online and offline (#4953)
* [inference] Dynamic Batching for Single and Multiple GPUs (#4831)
  * finish batch manager and tests
  * fix dynamic batching and add LLaMA inference
  * support generating sequences of different lengths
  * remove debug prints and fix bugs
  Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
* [inference] Async dynamic batching (#4894)
  * finish input and output logic
  * add generate and test forward
* [inference] Re-push async dynamic batching (#4901)
  * adapt to the Ray server
  * finish async implementation and tests
  Co-authored-by: yuehuayingxueluo <867460659@qq.com>
* Revert "[inference] Re-push async dynamic batching (#4901)" (#4905)
  This reverts commit fbf3c09.
* Revert "[inference] Async dynamic batching (#4894)" (#4909)
  This reverts commit fced140.
* [infer] Add Ray Distributed Environment Init Scripts (#4911)
  * support DynamicBatchManager base functions
  * revert the _set_tokenizer version
  * add driver async generate and async tests
  * fix bugs in test_ray_dist.py
  * add get_tokenizer.py and fix code style
  * fix "No module named 'pydantic'" and other CI test failures
  * support dynamic batching for the BLOOM model and an is_running function
* [Inference] Test for new async engine (#4935)
  * add inference engine, new manager, and engine tests
  * change step logic and add license
  Co-authored-by: yuehuayingxueluo <867460659@qq.com>
* add assertion for config (#4947)
* [Inference] Finish dynamic batching offline test (#4948)
  * fix quantization, defaults, and assorted bugs
  * reset params
---------
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Cuiqing Li <lixx3527@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
1 parent 4e4a10c · commit cf579ff · 30 changed files with 2,005 additions and 92 deletions.
@@ -0,0 +1,133 @@
import asyncio

from colossalai.inference.dynamic_batching.ray_dist_init import Driver

from .dynamic_batching.io_struct import RequestOutput
from .dynamic_batching.sampling_params import SamplingParams


class RequestTracker:
    """
    A class that tracks all in-flight requests; an abstraction for the async engine.
    """

    def __init__(self) -> None:
        self._requests: asyncio.Queue[str] = asyncio.Queue()
        self._finished_requests: asyncio.Queue[RequestOutput] = asyncio.Queue()
        self.new_requests_event = None

    def __contains__(self, item):
        # asyncio.Queue does not support membership tests directly;
        # check its underlying deque (a CPython implementation detail).
        return item in self._requests._queue

    def init_event(self):
        self.new_requests_event = asyncio.Event()

    def add_request(self, request_id: str):
        """Add a request to be sent to the engine on the next background
        loop iteration."""
        self._requests.put_nowait(request_id)
        self.new_requests_event.set()  # NOTE: we may find a better way to clear this event

    def add_stop(self):
        """
        Add a StopIteration flag to stop the async generator.
        """
        self._finished_requests.put_nowait(StopIteration)
        self.new_requests_event.clear()

    def process_request_output(self, request_output: RequestOutput) -> None:
        """Process a request output from the engine."""
        self._finished_requests.put_nowait(request_output)

    async def wait_for_new_requests(self):
        await self.new_requests_event.wait()

    def __aiter__(self):
        return self

    async def __anext__(self) -> RequestOutput:
        result = await self._finished_requests.get()
        if result is StopIteration:
            raise StopAsyncIteration
        return result


class Async_Engine:
    """
    Use an engine to launch a Ray Driver --> Ray Worker --> Async_Manager.
    Background loop: runs inference on the requests in the waiting list (listens for work).
    Request tracker: manages incoming requests and stores finished ones.
    generate: the exposed function for adding new inputs and returning finished outputs.
    """

    def __init__(
        self,
        router_config,
        engine_config,
        start_engine_loop: bool = True,
    ) -> None:
        self.driver = Driver(router_config=router_config, engine_config=engine_config)
        self.background_loop = None
        self.start_engine_loop = start_engine_loop
        self._request_tracker = RequestTracker()

    def _step(self):
        """
        Logic for handling requests
        """
        request_outputs = self.driver.step()
        if request_outputs is not None:
            for request_output in request_outputs:
                self._request_tracker.process_request_output(request_output)
            self._request_tracker.add_stop()

    def abort_request(self, request_id: str):
        self.driver.abort(request_id)

    def _has_requests_in_progress(self):
        return self.driver.is_running()

    async def run_loop_fwd(self):
        while True:
            # Re-check on every iteration; checking only once before the loop would go stale.
            if not self._has_requests_in_progress():
                await self._request_tracker.wait_for_new_requests()
            self._step()
            await asyncio.sleep(0)

    @property
    def is_running(self):
        return self.background_loop is not None and not self.background_loop.done()

    def start_background_loop(self):
        if self.is_running:
            raise RuntimeError("Background loop is already running.")

        self._request_tracker.init_event()

        self.background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_loop_fwd())
        self.background_loop = asyncio.shield(self.background_loop_unshielded)

    async def add_request(self, request_id: str, prompt: str, sampling_params: SamplingParams):
        self.driver.add_input(request_id, prompt, sampling_params)
        self._request_tracker.add_request(request_id)

    async def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams):
        """
        The only exposed function: add a new request and return an async generator
        that yields results as they finish.
        """
        try:
            if not self.is_running:
                self.start_background_loop()

            await self.add_request(request_id, prompt, sampling_params)

            async for request_output in self._request_tracker:
                yield request_output

        except (Exception, asyncio.CancelledError) as e:
            # If there is an exception or the coroutine is cancelled, abort the request.
            self.abort_request(request_id)
            raise e
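
To make the intended call pattern concrete, here is a minimal usage sketch, not part of the commit: it drives `generate` as an async generator. The module paths are inferred from the imports in the diff above, `router_config` and `engine_config` are assumed to be prepared elsewhere for the target deployment, and `SamplingParams()` is used with its defaults.

import asyncio

from colossalai.inference.async_engine import Async_Engine  # assumed module path
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams


async def main(router_config, engine_config):
    # The first call to generate() starts the background loop, enqueues the prompt,
    # and yields RequestOutput objects until the tracker signals completion.
    engine = Async_Engine(router_config=router_config, engine_config=engine_config)
    async for request_output in engine.generate("req-0", "Hello, my name is", SamplingParams()):
        print(request_output)


# asyncio.run(main(router_config, engine_config))  # configs assumed to exist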
@@ -0,0 +1,151 @@
from typing import List, Optional

from .dynamic_batching.io_struct import Batch, Req, RequestOutput
from .manager import DynamicBatchManager
from .tensor_parallel import TPInferEngine


class Async_DynamicBatchManager(DynamicBatchManager):
    def __init__(
        self,
        tp_engine: TPInferEngine,
        max_total_token_num: int,
        batch_max_tokens: int,
        model: str,
        tokenizer=None,
        eos_id=None,
        log_stats=True,
        log_stats_interval=10,
        running_batch: Optional[Batch] = None,
        waiting_req_list: Optional[List] = None,
    ):
        """
        Args:
            tp_engine: the TP engine that the dynamic batch manager holds, defined before the manager
            max_total_token_num: max_total_token_num for the memory manager, defaults to max batch size * (max input len + max output len)
            batch_max_tokens: max tokens of one batch, defaults to (max input len + max output len) * num_requests
            model: the model weight dir path; the app will load config, weights, and tokenizer from this dir
            tokenizer: the tokenizer instance to use
            eos_id: the end-of-sequence token id
            log_stats: whether to log stats
            log_stats_interval: interval (in steps) between stats logs
            running_batch: the currently running batch
            waiting_req_list: list of waiting requests, initialized before the dynamic batch manager
        """
        # Avoid a mutable default argument: create a fresh list per instance.
        waiting_req_list = waiting_req_list if waiting_req_list is not None else []
        super().__init__(
            tp_engine,
            max_total_token_num,
            batch_max_tokens,
            model,
            tokenizer,
            eos_id,
            log_stats,
            log_stats_interval,
            running_batch,
            waiting_req_list,
        )

    def _step(self):
        """
        Logic for handling requests
        """
        has_new_finished = False
        if self.running_batch is None:
            new_batch = self.req_queue.generate_new_batch(self.running_batch)
            if new_batch is not None:
                self.stats_tool.count_prompt_tokens(new_batch)
                self.running_batch = new_batch
                has_new_finished, outputs = self._prefill_batch(self.running_batch)
                self._filter_runing_batch()
                self.has_wait_tokens = 0

        else:
            if self.has_wait_tokens < self.max_wait_tokens:
                self.stats_tool.count_output_tokens(self.running_batch)
                has_new_finished, outputs = self._decode_batch(self.running_batch)
                self._filter_runing_batch()
                self.has_wait_tokens += 1

            else:
                new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
                if new_mini_batch is not None:
                    self.stats_tool.count_prompt_tokens(new_mini_batch)
                    has_new_finished, outputs = self._prefill_batch(new_mini_batch)
                    if not new_mini_batch.is_clear():
                        self._merge_batch(self.running_batch, new_mini_batch)
                        self.running_batch.merge(new_mini_batch)
                    self.has_wait_tokens = 0

                else:
                    self.stats_tool.count_output_tokens(self.running_batch)
                    has_new_finished, outputs = self._decode_batch(self.running_batch)
                    self._filter_runing_batch()
                    self.has_wait_tokens += 1

        if has_new_finished:
            return outputs
        return None

    def _prefill_batch(self, batch):
        """
        Every batch, whether a brand-new batch or a mini batch, must be prefilled first.
        """
        self._init_batch(batch)

        # TODO: figure out whether the cache and batch id are needed
        ans = self.engine._prefill_batch(batch.batch_id)
        req_to_out_token_id = ans
        self._add_token_id_to_req(batch, req_to_out_token_id)
        has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len)
        # delete finished reqs
        outputs = self._handle_finish_req(batch, has_new_finished_req)
        return has_new_finished_req, outputs

    def _decode_batch(self, batch: Batch):
        """
        Decoding process
        """
        ans = self.engine._decode_batch(batch.batch_id)
        req_to_out_token_id = ans
        self._add_token_id_to_req(batch, req_to_out_token_id)
        has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len)
        outputs = self._handle_finish_req(batch, has_new_finished_req)
        return has_new_finished_req, outputs

    def _handle_finish_req(self, batch: Batch, has_new_finished_req):
        if has_new_finished_req:
            finished_reqs = batch.filter_finished()
            if batch.is_clear():
                self._remove_batch(batch)
            else:
                self._filter_batch(batch)
            return self._output_process(finished_reqs)
        return None

    def _output_process(self, finished_reqs: List[Req]):
        """
        Decode the finished requests into RequestOutput objects.
        """
        outputs = []
        for req in finished_reqs:
            output = self.tokenizer.decode(req.output_ids)
            outputs.append(RequestOutput(req.request_id, req.prompts, req.prompt_ids, output))
        return outputs


def start_dynamic_batching(args, tp_engine, waiting_req_list):
    batch_manager = Async_DynamicBatchManager(
        tp_engine=tp_engine,
        max_total_token_num=args.max_total_token_num,
        batch_max_tokens=args.batch_max_tokens,
        eos_id=args.eos_id,
        model=args.model,
        log_stats=not args.disable_log_stats,
        log_stats_interval=args.log_stats_interval,
        waiting_req_list=waiting_req_list,
    )
    return batch_manager
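
Because `_step` encodes the core dynamic batching policy, a self-contained sketch of that policy may help: keep decoding the running batch for up to `max_wait_tokens` iterations, then let waiting requests be prefilled as a mini batch and merged in so they are not starved. The sketch below is illustrative only; plain lists stand in for `Batch`, a deque replaces `req_queue`, and `MAX_WAIT_TOKENS` is a hypothetical constant.

from collections import deque

MAX_WAIT_TOKENS = 10  # hypothetical stand-in for DynamicBatchManager.max_wait_tokens


def step(running_batch, waiting_queue: deque, has_wait_tokens: int):
    """One scheduling decision, mirroring the shape of Async_DynamicBatchManager._step."""
    if running_batch is None:
        # No running batch: prefill a new batch from the waiting queue, if any.
        if waiting_queue:
            running_batch = list(waiting_queue)
            waiting_queue.clear()
            print(f"prefill new batch of {len(running_batch)} requests")
        return running_batch, 0

    if has_wait_tokens < MAX_WAIT_TOKENS:
        # The running batch has not waited long enough: keep decoding it.
        print("decode running batch")
        return running_batch, has_wait_tokens + 1

    if waiting_queue:
        # The running batch has waited: prefill newcomers as a mini batch and merge them,
        # so fresh requests are not starved by a long-running decode.
        mini_batch = list(waiting_queue)
        waiting_queue.clear()
        print(f"prefill mini batch of {len(mini_batch)} requests, then merge")
        return running_batch + mini_batch, 0

    print("nothing waiting; decode running batch")
    return running_batch, has_wait_tokens + 1


# Example: two requests arrive while one batch has exhausted its decode budget.
batch, wait = step(["req0"], deque(["req1", "req2"]), has_wait_tokens=MAX_WAIT_TOKENS)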
Empty file.
@@ -0,0 +1,40 @@
""" | ||
Motivated by VllM (https://github.com/vllm-project/vllm), This module is trying to resolve the tokenizer issue. | ||
license: MIT, see LICENSE for more details. | ||
""" | ||
|
||
from transformers import AutoTokenizer | ||
|
||
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" | ||
|
||
|
||
def get_tokenizer( | ||
tokenizer=None, | ||
tokenizer_name: str = "", | ||
trust_remote_code: bool = False, | ||
use_fast: bool = True, | ||
): | ||
if tokenizer is not None: | ||
tokenizer = tokenizer | ||
else: | ||
if "llama" in tokenizer_name.lower() and use_fast == True: | ||
print( | ||
"For some LLaMA-based models, initializing the fast tokenizer may " | ||
"take a long time. To eliminate the initialization time, consider " | ||
f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " | ||
"tokenizer. This is done automatically in Colossalai." | ||
) | ||
|
||
tokenizer_name = _FAST_LLAMA_TOKENIZER | ||
|
||
try: | ||
tokenizer = AutoTokenizer.from_pretrained( | ||
tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code | ||
) | ||
except TypeError: | ||
use_fast = False | ||
tokenizer = AutoTokenizer.from_pretrained( | ||
tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code | ||
) | ||
return tokenizer |
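
A brief usage sketch, again not part of the commit: the import path below is an assumption based on where the commit history says `get_tokenizer.py` was added (the dynamic batching package).

from colossalai.inference.dynamic_batching.get_tokenizer import get_tokenizer  # assumed path

tokenizer = get_tokenizer(tokenizer_name="hf-internal-testing/llama-tokenizer", use_fast=True)
print(tokenizer("Hello world")["input_ids"])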