[Inference] Test for new Async engine #4935

Merged
merged 17 commits on Oct 19, 2023
133 changes: 133 additions & 0 deletions colossalai/inference/async_engine.py
@@ -0,0 +1,133 @@
import asyncio

from colossalai.inference.dynamic_batching.ray_dist_init import Driver

from .dynamic_batching.io_struct import RequestOutput
from .dynamic_batching.sampling_params import SamplingParams


class RequestTracker:
"""
Tracks all in-flight requests; an async-friendly abstraction over the pending and finished queues.
"""

def __init__(self) -> None:
self._requests: asyncio.Queue[str] = asyncio.Queue()
self._finished_requests: asyncio.Queue[RequestOutput] = asyncio.Queue()
self.new_requests_event = None

def __contains__(self, item):
# asyncio.Queue has no __contains__ or __iter__; check its internal deque of pending request ids
return item in self._requests._queue

def init_event(self):
self.new_requests_event = asyncio.Event()

def add_request(self, request_id: str):
"""Add a request to be sent to the engine on the next background
loop iteration."""
self._requests.put_nowait(request_id)
self.new_requests_event.set() # NOTE: we may find a better way to clear this event

def add_stop(self):
"""
Put a StopIteration sentinel on the finished queue to terminate the async generator.
"""
self._finished_requests.put_nowait(StopIteration)
self.new_requests_event.clear()

def process_request_output(self, request_output: RequestOutput) -> None:
"""Process a request output from the engine."""
self._finished_requests.put_nowait(request_output)

async def wait_for_new_requests(self):
await self.new_requests_event.wait()

def __aiter__(self):
return self

async def __anext__(self) -> RequestOutput:
result = await self._finished_requests.get()
if result is StopIteration:
raise StopAsyncIteration
return result


class Async_Engine:

"""
Uses one engine object to launch the Ray Driver --> Ray Worker --> Async_Manager chain.
Background loop: repeatedly runs inference on the requests in the waiting list.
Request tracker: manages incoming requests and stores the finished ones.
generate(): the exposed entry point for adding new input and returning finished outputs.
"""

def __init__(
self,
router_config,
engine_config,
start_engine_loop: bool = True,
) -> None:
self.driver = Driver(router_config=router_config, engine_config=engine_config)
self.background_loop = None
self.start_engine_loop = start_engine_loop
self._request_tracker = RequestTracker()

def _step(self):
"""
Logic for handling requests
"""
request_outputs = self.driver.step()
if request_outputs is not None:
for request_output in request_outputs:
self._request_tracker.process_request_output(request_output)
self._request_tracker.add_stop()

def abort(self, request_id: str):
self.driver.abort(request_id)

def _has_requests_in_progress(self):
return self.driver.is_running()

async def run_loop_fwd(self):
while True:
has_requests_in_progress = self._has_requests_in_progress()
# when the driver is idle, wait for new requests instead of busy-looping
if not has_requests_in_progress:
await self._request_tracker.wait_for_new_requests()
self._step()
await asyncio.sleep(0)

@property
def is_running(self):
return self.background_loop is not None and not self.background_loop.done()

def start_background_loop(self):
if self.is_running:
raise RuntimeError("Background loop is already running.")

self._request_tracker.init_event()

self.background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_loop_fwd())
self.background_loop = asyncio.shield(self.background_loop_unshielded)

async def add_request(self, request_id: str, prompt: str, sampling_params: SamplingParams):
self.driver.add_input(request_id, prompt, sampling_params)
self._request_tracker.add_request(request_id)

async def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams):
"""
The only exposed entry point: add a new request and return an async generator that yields the finished outputs.
"""
try:
if not self.is_running:
self.start_background_loop()

await self.add_request(request_id, prompt, sampling_params)

async for request_output in self._request_tracker:
yield request_output

except (Exception, asyncio.CancelledError) as e:
# if an exception is raised or the coroutine is cancelled, abort the request
self.abort(request_id)
raise e
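For reviewers, here is a minimal usage sketch of the new engine driven from plain asyncio. It only exercises the generate() path defined above; the empty config dicts, the request id, the prompt and the default-constructed SamplingParams are illustrative placeholders, not a tested configuration.

import asyncio

from colossalai.inference.async_engine import Async_Engine
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams


async def main():
    # placeholder configs: a real run needs the Ray router / TP engine configuration for Driver
    engine = Async_Engine(router_config={}, engine_config={})
    sampling_params = SamplingParams()  # default construction assumed; adjust to the real fields

    # generate() lazily starts the background loop, then yields finished RequestOutput objects
    async for request_output in engine.generate("req-0", "Hello, ColossalAI", sampling_params):
        print(request_output)


asyncio.run(main())

The async-for terminates because _step() pushes a StopIteration sentinel after forwarding each round of finished outputs.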
150 changes: 150 additions & 0 deletions colossalai/inference/async_manager.py
@@ -0,0 +1,150 @@
from typing import List

from .dynamic_batching.io_struct import Batch, Req, RequestOutput
from .manager import DynamicBatchManager
from .tensor_parallel import TPInferEngine


class Async_DynamicBatchManager(DynamicBatchManager):
def __init__(
self,
tp_engine: TPInferEngine,
max_total_token_num,
batch_max_tokens,
eos_id,
model,
log_stats=True,
log_stats_interval=10,
running_batch: Batch = None,
waiting_req_list: List = [],
):
"""
Args:
tp_engine : the TP inference engine held by the dynamic batch manager, created before it
max_total_token_num : token budget for the memory manager, defaults to max batch size * (max input len + max output len)
batch_max_tokens : max tokens of one batch, defaults to (max input len + max output len) * num_requests
running_max_req_size : max request count of the running batch, equal to MAX_BATCH_SIZE of the TP engine
eos_id : the end-of-sequence token id
model : the model weight dir path; config, weights and tokenizer are loaded from this dir
log_stats : whether to log stats
log_stats_interval : interval (in steps) between stats logs
running_batch : the currently running batch
waiting_req_list : list of waiting requests, initialized before the dynamic batch manager
"""
super().__init__(
tp_engine,
max_total_token_num,
batch_max_tokens,
eos_id,
model,
log_stats,
log_stats_interval,
running_batch,
waiting_req_list,
)

def _step(self):
"""
One scheduling step: prefill a newly formed batch or decode the running batch, and return any finished outputs.
"""
has_new_finished = False
if self.running_batch is None:
new_batch = self.req_queue.generate_new_batch(self.running_batch)
if new_batch is not None:
self.stats_tool.count_prompt_tokens(new_batch)
self.running_batch = new_batch
has_new_finished, outputs = self._prefill_batch(self.running_batch)
self._filter_runing_batch()
self.has_wait_tokens = 0

else:
if self.has_wait_tokens < self.max_wait_tokens:
self.stats_tool.count_output_tokens(self.running_batch)
has_new_finished, outputs = self._decode_batch(self.running_batch)
self._filter_runing_batch()
self.has_wait_tokens += 1

else:
new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
if new_mini_batch is not None:
self.stats_tool.count_prompt_tokens(new_mini_batch)
has_new_finished, outputs = self._prefill_batch(new_mini_batch)
if not new_mini_batch.is_clear():
self._merge_batch(self.running_batch, new_mini_batch)
self.running_batch.merge(new_mini_batch)
self.has_wait_tokens = 0

else:
self.stats_tool.count_output_tokens(self.running_batch)
has_new_finished, outputs = self._decode_batch(self.running_batch)
self._filter_runing_batch()
self.has_wait_tokens += 1

if has_new_finished:
return outputs
return None

def _prefill_batch(self, batch):
"""
Every batch, whether it is a brand-new batch or a mini batch, is prefilled first.
"""
self._init_batch(batch)

# TODO: figure out if cache and batch id is needed
ans = self.engine._prefill_batch(batch.batch_id)
req_to_out_token_id = ans
self._add_token_id_to_req(batch, req_to_out_token_id)
has_new_finished_req = batch.mark_finished_req(self.eos_id)
# finished requests are removed from the batch inside _handle_finish_req
outputs = self._handle_finish_req(batch, has_new_finished_req)
return has_new_finished_req, outputs

def _decode_batch(self, batch: Batch):
"""
Decoding process
"""
ans = self.engine._decode_batch(batch.batch_id)
req_to_out_token_id = ans
self._add_token_id_to_req(batch, req_to_out_token_id)
has_new_finished_req = batch.mark_finished_req(self.eos_id)
outputs = self._handle_finish_req(batch, has_new_finished_req)
return has_new_finished_req, outputs

def _handle_finish_req(self, batch: Batch, has_new_finished_req):
if has_new_finished_req:
finished_reqs = batch.filter_finished()
if batch.is_clear():
self._remove_batch(batch)
else:
self._filter_batch(batch)
return self._output_process(finished_reqs)
return None

def _output_process(self, finished_reqs: List[Req]):
"""
Decode the finished requests of a batch into RequestOutput objects.
"""
outputs = []
for req in finished_reqs:
output = self.tokenizer.decode(req.output_ids)
outputs.append(RequestOutput(req.request_id, req.prompts, req.prompt_ids, output))
return outputs


def start_dynamic_batching(args, tp_engine, waiting_req_list):
try:
batch_manager = Async_DynamicBatchManager(
tp_engine=tp_engine,
max_total_token_num=args.max_total_token_num,
batch_max_tokens=args.batch_max_tokens,
eos_id=args.eos_id,
model=args.model,
log_stats=not args.disable_log_stats,
log_stats_interval=args.log_stats_interval,
waiting_req_list=waiting_req_list,
)

except Exception:
# if construction fails there is no batch_manager to clean up; re-raise for the caller
raise

return batch_manager
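To make the entry point above easier to exercise, a small wiring sketch follows. The SimpleNamespace args and its numeric values are illustrative assumptions (the field names simply mirror the attributes start_dynamic_batching reads), and the TP engine is taken as a parameter rather than constructed here.

from types import SimpleNamespace

from colossalai.inference.async_manager import start_dynamic_batching
from colossalai.inference.tensor_parallel import TPInferEngine


def build_manager(tp_engine: TPInferEngine):
    # field names mirror what start_dynamic_batching reads; the values are placeholders
    args = SimpleNamespace(
        max_total_token_num=4096,
        batch_max_tokens=2048,
        eos_id=2,
        model="/path/to/model",
        disable_log_stats=False,
        log_stats_interval=10,
    )
    # waiting_req_list starts empty; requests are queued by the serving layer later
    return start_dynamic_batching(args, tp_engine=tp_engine, waiting_req_list=[])

Each subsequent call to the manager's _step() then advances prefill/decode and returns a list of finished RequestOutput objects, or None when nothing finished.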
8 changes: 7 additions & 1 deletion colossalai/inference/dynamic_batching/get_tokenizer.py
@@ -1,6 +1,12 @@
"""
Motivated by vLLM (https://github.com/vllm-project/vllm), this module resolves the tokenizer loading issue.

license: MIT, see LICENSE for more details.
"""

from transformers import AutoTokenizer

-_FAST_LLAMA_TOKENIZER = "/home/lccd/share/llama-tokenizer"
+_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"


def get_tokenizer(
16 changes: 8 additions & 8 deletions colossalai/inference/dynamic_batching/infer_batch.py
@@ -1,15 +1,16 @@
# Adapted from https://github.com/ModelTC/lightllm

import collections
from dataclasses import dataclass
-from typing import Dict, List , Tuple
+from typing import Dict, List, Tuple

import numpy as np
import torch

from colossalai.inference.tensor_parallel import MemoryManager

-# make batch infer state an attr of InferBatch


+# make batch infer state an attr of InferBatch
class InferSamplingParams:
def __init__(
self,
@@ -65,7 +66,7 @@ def init_batch(
cache_manager: MemoryManager,
vocab_size: int,
max_total_len: int,
-) -> 'InferBatch':
+) -> "InferBatch":
input_lengths = []
all_input_ids = []
requests_idx_mapping = {}
@@ -76,7 +77,7 @@
nopad_total_token_num = 0
nopad_max_len_in_batch = 0
nopad_b_loc = torch.empty((len(requests), max_total_len + 12), dtype=torch.long, device="cuda")
-# to avoid memory leak , we pre-allocate 12 more space for each batch.
+    # to avoid memory leak , we pre-allocate 12 more space for each batch.
nopad_b_start_loc = torch.zeros(len(requests), dtype=torch.int32, device="cuda")
for i, r in enumerate(requests):
# request id -> idx in list mapping
@@ -142,10 +143,9 @@ def free_self(self) -> None:
)
remove_index = torch.cat(remove_index, dim=-1)
self.cache_manager.free(remove_index)


@torch.no_grad()
-def filter(self, request_ids: List[int]) -> 'InferBatch':
+def filter(self, request_ids: List[int]) -> "InferBatch":
"""
Filter finished batch and return a new InferBatch with left ones.
"""
@@ -226,7 +226,7 @@ def filter(self, request_ids: List[int]) -> 'InferBatch':

@classmethod
@torch.no_grad()
-def merge(cls, batch1, batch2) -> 'InferBatch':
+def merge(cls, batch1, batch2) -> "InferBatch":
"""
Return the merged InferBatch.
"""