[Inference] Dynamic Batching Inference, online and offline (#4953)
* [inference] Dynamic Batching for Single and Multiple GPUs (#4831)

* finish batch manager

* 1

* first

* fix

* fix dynamic batching

* llama infer

* finish test

* support different lengths generating

* del prints

* del prints

* fix

* fix bug

---------

Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>

* [inference] Async dynamic batching  (#4894)

* finish input and output logic

* add generate

* test forward

* 1

* [inference]Re push async dynamic batching (#4901)

* adapt to ray server

* finish async

* finish test

* del test

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>

* Revert "[inference]Re push async dynamic batching (#4901)" (#4905)

This reverts commit fbf3c09.

* Revert "[inference] Async dynamic batching  (#4894)"

This reverts commit fced140.

* Revert "[inference] Async dynamic batching  (#4894)" (#4909)

This reverts commit fced140.

* Add Ray Distributed Environment Init Scripts

* support DynamicBatchManager base function

* revert _set_tokenizer version

* add driver async generate

* add async test

* fix bugs in test_ray_dist.py

* add get_tokenizer.py

* fix code style

* fix bugs about No module named 'pydantic' in ci test

* fix bugs in ci test

* fix bugs in ci test

* fix bugs in ci test

* [infer]Add Ray Distributed Environment Init Scripts (#4911)

* Revert "[inference] Async dynamic batching  (#4894)"

This reverts commit fced140.

* Add Ray Distributed Environment Init Scripts

* support DynamicBatchManager base function

* revert _set_tokenizer version

* add driver async generate

* add async test

* fix bugs in test_ray_dist.py

* add get_tokenizer.py

* fix code style

* fix bugs about No module named 'pydantic' in ci test

* fix bugs in ci test

* fix bugs in ci test

* fix bugs in ci test

* support dynamic batch for bloom model and is_running function

* [Inference]Test for new Async engine (#4935)

* infer engine

* infer engine

* test engine

* test engine

* new manager

* change step

* add

* test

* fix

* fix

* finish test

* finish test

* finish test

* finish test

* add license

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>

* add assertion for config (#4947)

* [Inference] Finish dynamic batching offline test (#4948)

* test

* fix test

* fix quant

* add default

* fix

* fix some bugs

* fix some bugs

* fix

* fix bug

* fix bugs

* reset param

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Cuiqing Li <lixx3527@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
3 people authored Oct 30, 2023
1 parent 4e4a10c commit cf579ff
Showing 30 changed files with 2,005 additions and 92 deletions.
133 changes: 133 additions & 0 deletions colossalai/inference/async_engine.py
@@ -0,0 +1,133 @@
import asyncio

from colossalai.inference.dynamic_batching.ray_dist_init import Driver

from .dynamic_batching.io_struct import RequestOutput
from .dynamic_batching.sampling_params import SamplingParams


class RequestTracker:
"""
A class for trace down all the requests, abstraction for async
"""

def __init__(self) -> None:
self._requests: asyncio.Queue[str] = asyncio.Queue()
self._finished_requests: asyncio.Queue[RequestOutput] = asyncio.Queue()
self.new_requests_event = None

    def __contains__(self, item):
        # asyncio.Queue has no membership test; check its backing deque instead.
        return item in self._requests._queue

def init_event(self):
self.new_requests_event = asyncio.Event()

def add_request(self, request_id: str):
"""Add a request to be sent to the engine on the next background
loop iteration."""
self._requests.put_nowait(request_id)
self.new_requests_event.set() # NOTE: we may find a better way to clear this event

def add_stop(self):
"""
        Add a StopIteration sentinel to stop the async generator.
"""
self._finished_requests.put_nowait(StopIteration)
self.new_requests_event.clear()

def process_request_output(self, request_output: RequestOutput) -> None:
"""Process a request output from the engine."""
self._finished_requests.put_nowait(request_output)

async def wait_for_new_requests(self):
await self.new_requests_event.wait()

def __aiter__(self):
return self

async def __anext__(self) -> RequestOutput:
result = await self._finished_requests.get()
if result is StopIteration:
raise StopAsyncIteration
return result


class Async_Engine:
    """
    Drives a Ray Driver --> Ray Worker --> Async_Manager pipeline from a single engine object.
    Background loop: runs inference on the requests in the waiting list as they arrive.
    Request tracker: manages incoming requests and collects finished outputs.
    generate: the exposed entry point that adds new input and yields finished results.
    """

def __init__(
self,
router_config,
engine_config,
start_engine_loop: bool = True,
) -> None:
self.driver = Driver(router_config=router_config, engine_config=engine_config)
self.background_loop = None
self.start_engine_loop = start_engine_loop
self._request_tracker = RequestTracker()

def _step(self):
"""
Logic for handling requests
"""
request_outputs = self.driver.step()
if request_outputs is not None:
for request_output in request_outputs:
self._request_tracker.process_request_output(request_output)
self._request_tracker.add_stop()

def abort_request(self, request_id: str):
self.driver.abort(request_id)

def _has_requests_in_progress(self):
return self.driver.is_running()

    async def run_loop_fwd(self):
        while True:
            # Re-check every iteration so the loop only blocks when the driver is idle.
            if not self._has_requests_in_progress():
                await self._request_tracker.wait_for_new_requests()
            self._step()
            await asyncio.sleep(0)

@property
def is_running(self):
return self.background_loop is not None and not self.background_loop.done()

def start_background_loop(self):
if self.is_running:
raise RuntimeError("Background loop is already running.")

self._request_tracker.init_event()

self.background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_loop_fwd())
self.background_loop = asyncio.shield(self.background_loop_unshielded)

async def add_request(self, request_id: str, prompt: str, sampling_params: SamplingParams):
self.driver.add_input(request_id, prompt, sampling_params)
self._request_tracker.add_request(request_id)

async def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams):
"""
The only exposed func, adding new request and return a async generator that yields the existing results.
"""
try:
if not self.is_running:
self.start_background_loop()

await self.add_request(request_id, prompt, sampling_params)

async for request_output in self._request_tracker:
yield request_output

except (Exception, asyncio.CancelledError) as e:
            # If an exception is raised or the coroutine is cancelled, abort the request.
self.abort_request(request_id)
raise e
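
A minimal usage sketch for the async (online) path above — not part of this commit. The router_config and engine_config contents come from the Ray driver setup in ray_dist_init and are left as caller-supplied placeholders here; default SamplingParams are assumed to be usable as-is.

# --- usage sketch (illustrative, not part of the commit) ---
import asyncio

from colossalai.inference.async_engine import Async_Engine
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams


async def run_one_request(router_config, engine_config):
    # Configs must match what ray_dist_init.Driver expects; they are not defined in this sketch.
    engine = Async_Engine(router_config=router_config, engine_config=engine_config)
    sampling_params = SamplingParams()  # assuming the defaults are acceptable

    # generate() starts the background loop on first use and yields finished outputs
    # until the tracker pushes its stop sentinel.
    async for request_output in engine.generate("req-0", "What is dynamic batching?", sampling_params):
        print(request_output)


# asyncio.run(run_one_request(my_router_config, my_engine_config))  # hypothetical configs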
151 changes: 151 additions & 0 deletions colossalai/inference/async_manager.py
@@ -0,0 +1,151 @@
from typing import List

from .dynamic_batching.io_struct import Batch, Req, RequestOutput
from .manager import DynamicBatchManager
from .tensor_parallel import TPInferEngine


class Async_DynamicBatchManager(DynamicBatchManager):
def __init__(
self,
tp_engine: TPInferEngine,
max_total_token_num: int,
batch_max_tokens: int,
model: str,
tokenizer=None,
eos_id=None,
log_stats=True,
log_stats_interval=10,
running_batch: Batch = None,
waiting_req_list: List = [],
):
"""
Args: tp_engine : The tp engine that dynamic batch manager hold, defined before dynamic batch manager
max_total_token_num : max_total_token_num for memory manager, default to: max batch size * (max input len + max output len)
batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests
running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine
eos_id : The end token of a seq
model: the model weight dir path, the app will load config, weights and tokenizer from this dir
log_stats : whether to log stats
log_stats_interval : log stats interval
running_batch : running batch
waiting_req_list : list of waiting requests, initialized before dynamic batch manager
"""
super().__init__(
tp_engine,
max_total_token_num,
batch_max_tokens,
model,
tokenizer,
eos_id,
log_stats,
log_stats_interval,
running_batch,
waiting_req_list,
)

def _step(self):
"""
Logic for handling requests
"""
has_new_finished = False
if self.running_batch is None:
new_batch = self.req_queue.generate_new_batch(self.running_batch)
if new_batch is not None:
self.stats_tool.count_prompt_tokens(new_batch)
self.running_batch = new_batch
has_new_finished, outputs = self._prefill_batch(self.running_batch)
self._filter_runing_batch()
self.has_wait_tokens = 0

else:
if self.has_wait_tokens < self.max_wait_tokens:
self.stats_tool.count_output_tokens(self.running_batch)
has_new_finished, outputs = self._decode_batch(self.running_batch)
self._filter_runing_batch()
self.has_wait_tokens += 1

else:
new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
if new_mini_batch is not None:
self.stats_tool.count_prompt_tokens(new_mini_batch)
has_new_finished, outputs = self._prefill_batch(new_mini_batch)
if not new_mini_batch.is_clear():
self._merge_batch(self.running_batch, new_mini_batch)
self.running_batch.merge(new_mini_batch)
self.has_wait_tokens = 0

else:
self.stats_tool.count_output_tokens(self.running_batch)
has_new_finished, outputs = self._decode_batch(self.running_batch)
self._filter_runing_batch()
self.has_wait_tokens += 1

if has_new_finished:
return outputs
return None

    def _prefill_batch(self, batch):
        """
        Every batch, whether a newly formed batch or a mini batch, must be prefilled first.
        """
        self._init_batch(batch)

        # TODO: figure out if cache and batch id are needed
        ans = self.engine._prefill_batch(batch.batch_id)
        req_to_out_token_id = ans
        self._add_token_id_to_req(batch, req_to_out_token_id)
        has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len)
        # handle finished requests (they are filtered out of the batch)
        outputs = self._handle_finish_req(batch, has_new_finished_req)
        return has_new_finished_req, outputs
def _decode_batch(self, batch: Batch):
"""
Decoding process
"""
ans = self.engine._decode_batch(batch.batch_id)
req_to_out_token_id = ans
self._add_token_id_to_req(batch, req_to_out_token_id)
has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len)
outputs = self._handle_finish_req(batch, has_new_finished_req)
return has_new_finished_req, outputs

def _handle_finish_req(self, batch: Batch, has_new_finished_req):
if has_new_finished_req:
finished_reqs = batch.filter_finished()
if batch.is_clear():
self._remove_batch(batch)
else:
self._filter_batch(batch)
return self._output_process(finished_reqs)
return None

def _output_process(self, finished_reqs: List[Req]):
"""
        Decode each finished request and wrap it into a RequestOutput.
"""
outputs = []
for req in finished_reqs:
output = self.tokenizer.decode(req.output_ids)
outputs.append(RequestOutput(req.request_id, req.prompts, req.prompt_ids, output))
return outputs


def start_dynamic_batching(args, tp_engine, waiting_req_list):
    batch_manager = Async_DynamicBatchManager(
        tp_engine=tp_engine,
        max_total_token_num=args.max_total_token_num,
        batch_max_tokens=args.batch_max_tokens,
        eos_id=args.eos_id,
        model=args.model,
        log_stats=not args.disable_log_stats,
        log_stats_interval=args.log_stats_interval,
        waiting_req_list=waiting_req_list,
    )

    return batch_manager
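
A rough sketch of the offline polling loop that could drive the manager above — illustrative only. The `args`, `tp_engine`, and `waiting_req_list` objects are assumed to be built elsewhere (for example by the launch scripts in this PR), and `is_running()` is assumed from the commit message's "is_running function".

# --- usage sketch (illustrative, not part of the commit) ---
from colossalai.inference.async_manager import start_dynamic_batching


def run_offline(args, tp_engine, waiting_req_list):
    batch_manager = start_dynamic_batching(args, tp_engine, waiting_req_list)

    # Each _step() either prefills a new batch or decodes the running one, and
    # returns RequestOutput objects once requests hit eos_id or max_output_len.
    while batch_manager.is_running():
        finished = batch_manager._step()
        if finished:
            for request_output in finished:
                print(request_output)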
Empty file.
40 changes: 40 additions & 0 deletions colossalai/inference/dynamic_batching/get_tokenizer.py
@@ -0,0 +1,40 @@
"""
Motivated by VllM (https://github.com/vllm-project/vllm), This module is trying to resolve the tokenizer issue.
license: MIT, see LICENSE for more details.
"""

from transformers import AutoTokenizer

_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"


def get_tokenizer(
    tokenizer=None,
    tokenizer_name: str = "",
    trust_remote_code: bool = False,
    use_fast: bool = True,
):
    if tokenizer is None:
        if "llama" in tokenizer_name.lower() and use_fast:
            print(
                "For some LLaMA-based models, initializing the fast tokenizer may "
                "take a long time. To eliminate the initialization time, consider "
                f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
                "tokenizer. This is done automatically in Colossalai."
            )

            tokenizer_name = _FAST_LLAMA_TOKENIZER

        try:
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
            )
        except TypeError:
            # Some tokenizers do not support the fast implementation; fall back to the slow one.
            use_fast = False
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
            )
    return tokenizer
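
A short usage sketch for the helper above (illustrative; the checkpoint is the fast-LLaMA tokenizer already named in this file, and network access is assumed for the download):

# --- usage sketch (illustrative, not part of the commit) ---
from colossalai.inference.dynamic_batching.get_tokenizer import get_tokenizer

# Loading by name: falls back to a slow tokenizer if the fast one raises TypeError.
tokenizer = get_tokenizer(tokenizer_name="hf-internal-testing/llama-tokenizer", use_fast=True)
print(tokenizer("Dynamic batching test").input_ids)

# Passing an already-constructed tokenizer returns it unchanged.
assert get_tokenizer(tokenizer=tokenizer) is tokenizer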