Skip to content

Commit

Permalink
test: test no data lock in batch queue (#6201)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM authored Sep 20, 2024
1 parent 246f596 commit 3894353
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 221 deletions.
214 changes: 103 additions & 111 deletions jina/serve/runtimes/worker/batch_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from asyncio import Event, Task
from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Union
from jina._docarray import docarray_v2
import contextlib

if not docarray_v2:
from docarray import DocumentArray
Expand All @@ -25,18 +24,13 @@ def __init__(
response_docarray_cls,
output_array_type: Optional[str] = None,
params: Optional[Dict] = None,
allow_concurrent: bool = False,
flush_all: bool = False,
preferred_batch_size: int = 4,
timeout: int = 10_000,
custom_metric: Optional[Callable[['DocumentArray'], Union[int, float]]] = None,
use_custom_metric: bool = False,
) -> None:
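# To keep the old user behavior, a real data lock is used unless both allow_concurrent and flush_all are set;
# only in that case is the lock replaced by an empty AsyncExitStack, which does not serialize access.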
if allow_concurrent and flush_all:
self._data_lock = contextlib.AsyncExitStack()
else:
self._data_lock = asyncio.Lock()
self.func = func
if params is None:
params = dict()
Expand Down Expand Up @@ -64,7 +58,7 @@ def __str__(self) -> str:
def _reset(self) -> None:
"""Set all events and reset the batch queue."""
self._requests: List[DataRequest] = []
# a list of every request ID
# a list of every request idx inside self._requests
self._request_idxs: List[int] = []
self._request_lens: List[int] = []
self._docs_metrics: List[int] = []
Expand Down Expand Up @@ -116,26 +110,24 @@ async def push(self, request: DataRequest, http=False) -> asyncio.Queue:
# this push requests the data lock. The order of accessing the data lock guarantees that this request will be put in the `big_doc`
# before the `flush` task processes it.
self._start_timer()
async with self._data_lock:
if not self._flush_task:
self._flush_task = asyncio.create_task(self._await_then_flush(http))

self._big_doc.extend(docs)
next_req_idx = len(self._requests)
num_docs = len(docs)
metric_value = num_docs
if self._custom_metric is not None:
metrics = [self._custom_metric(doc) for doc in docs]
metric_value += sum(metrics)
self._docs_metrics.extend(metrics)
self._metric_value += metric_value
self._request_idxs.extend([next_req_idx] * num_docs)
self._request_lens.append(num_docs)
self._requests.append(request)
queue = asyncio.Queue()
self._requests_completed.append(queue)
if self._metric_value >= self._preferred_batch_size:
self._flush_trigger.set()
if not self._flush_task:
self._flush_task = asyncio.create_task(self._await_then_flush(http))
self._big_doc.extend(docs)
next_req_idx = len(self._requests)
num_docs = len(docs)
metric_value = num_docs
if self._custom_metric is not None:
metrics = [self._custom_metric(doc) for doc in docs]
metric_value += sum(metrics)
self._docs_metrics.extend(metrics)
self._metric_value += metric_value
self._request_idxs.extend([next_req_idx] * num_docs)
self._request_lens.append(num_docs)
self._requests.append(request)
queue = asyncio.Queue()
self._requests_completed.append(queue)
if self._metric_value >= self._preferred_batch_size:
self._flush_trigger.set()

return queue

Expand Down Expand Up @@ -271,96 +263,76 @@ def batch(iterable_1, iterable_2, n: Optional[int] = 1, iterable_metrics: Option

await self._flush_trigger.wait()
# writes to shared data between tasks need to be mutually exclusive
async with self._data_lock:
big_doc_in_batch = copy.copy(self._big_doc)
requests_idxs_in_batch = copy.copy(self._request_idxs)
requests_lens_in_batch = copy.copy(self._request_lens)
docs_metrics_in_batch = copy.copy(self._docs_metrics)
requests_in_batch = copy.copy(self._requests)
requests_completed_in_batch = copy.copy(self._requests_completed)

self._reset()

# At this moment, we have the documents concatenated in big_doc_in_batch, corresponding to the requests in
# requests_idxs_in_batch, with their lengths stored in requests_lens_in_batch. For each request, there is a queue to
# communicate that the request has been processed properly.

if not docarray_v2:
non_assigned_to_response_docs: DocumentArray = DocumentArray.empty()
else:
non_assigned_to_response_docs = self._response_docarray_cls()
big_doc_in_batch = copy.copy(self._big_doc)
requests_idxs_in_batch = copy.copy(self._request_idxs)
requests_lens_in_batch = copy.copy(self._request_lens)
docs_metrics_in_batch = copy.copy(self._docs_metrics)
requests_in_batch = copy.copy(self._requests)
requests_completed_in_batch = copy.copy(self._requests_completed)

non_assigned_to_response_request_idxs = []
sum_from_previous_first_req_idx = 0
for docs_inner_batch, req_idxs in batch(
big_doc_in_batch, requests_idxs_in_batch,
self._preferred_batch_size if not self._flush_all else None, docs_metrics_in_batch if self._custom_metric is not None else None
):
involved_requests_min_indx = req_idxs[0]
involved_requests_max_indx = req_idxs[-1]
input_len_before_call: int = len(docs_inner_batch)
batch_res_docs = None
try:
batch_res_docs = await self.func(
docs=docs_inner_batch,
parameters=self.params,
docs_matrix=None, # joining manually with batch queue is not supported right now
tracing_context=None,
)
# Output validation
if (docarray_v2 and isinstance(batch_res_docs, DocList)) or (
not docarray_v2
and isinstance(batch_res_docs, DocumentArray)
):
if not len(batch_res_docs) == input_len_before_call:
raise ValueError(
f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(batch_res_docs)}'
)
elif batch_res_docs is None:
if not len(docs_inner_batch) == input_len_before_call:
raise ValueError(
f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(docs_inner_batch)}'
)
else:
array_name = (
'DocumentArray' if not docarray_v2 else 'DocList'
self._reset()

# At this moment, we have the documents concatenated in big_doc_in_batch, corresponding to the requests in
# requests_idxs_in_batch, with their lengths stored in requests_lens_in_batch. For each request, there is a queue to
# communicate that the request has been processed properly.

if not docarray_v2:
non_assigned_to_response_docs: DocumentArray = DocumentArray.empty()
else:
non_assigned_to_response_docs = self._response_docarray_cls()

non_assigned_to_response_request_idxs = []
sum_from_previous_first_req_idx = 0
for docs_inner_batch, req_idxs in batch(
big_doc_in_batch, requests_idxs_in_batch,
self._preferred_batch_size if not self._flush_all else None, docs_metrics_in_batch if self._custom_metric is not None else None
):
involved_requests_min_indx = req_idxs[0]
involved_requests_max_indx = req_idxs[-1]
input_len_before_call: int = len(docs_inner_batch)
batch_res_docs = None
try:
batch_res_docs = await self.func(
docs=docs_inner_batch,
parameters=self.params,
docs_matrix=None, # joining manually with batch queue is not supported right now
tracing_context=None,
)
# Output validation
if (docarray_v2 and isinstance(batch_res_docs, DocList)) or (
not docarray_v2
and isinstance(batch_res_docs, DocumentArray)
):
if not len(batch_res_docs) == input_len_before_call:
raise ValueError(
f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(batch_res_docs)}'
)
raise TypeError(
f'The return type must be {array_name} / `None` when using dynamic batching, '
f'but getting {batch_res_docs!r}'
elif batch_res_docs is None:
if not len(docs_inner_batch) == input_len_before_call:
raise ValueError(
f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(docs_inner_batch)}'
)
except Exception as exc:
# Every request that has docs in this failed inner batch must receive the exception through its queue
for request_full in requests_completed_in_batch[
involved_requests_min_indx: involved_requests_max_indx + 1
]:
await request_full.put(exc)
else:
# We need to attribute the docs to their requests
non_assigned_to_response_docs.extend(
batch_res_docs or docs_inner_batch
array_name = (
'DocumentArray' if not docarray_v2 else 'DocList'
)
non_assigned_to_response_request_idxs.extend(req_idxs)
num_assigned_docs = await _assign_results(
non_assigned_to_response_docs,
non_assigned_to_response_request_idxs,
sum_from_previous_first_req_idx,
requests_lens_in_batch,
requests_in_batch,
requests_completed_in_batch,
raise TypeError(
f'The return type must be {array_name} / `None` when using dynamic batching, '
f'but getting {batch_res_docs!r}'
)

sum_from_previous_first_req_idx = (
len(non_assigned_to_response_docs) - num_assigned_docs
)
non_assigned_to_response_docs = non_assigned_to_response_docs[
num_assigned_docs:
]
non_assigned_to_response_request_idxs = (
non_assigned_to_response_request_idxs[num_assigned_docs:]
)
if len(non_assigned_to_response_request_idxs) > 0:
_ = await _assign_results(
except Exception as exc:
# Every request that has docs in this failed inner batch must receive the exception through its queue
for request_full in requests_completed_in_batch[
involved_requests_min_indx: involved_requests_max_indx + 1
]:
await request_full.put(exc)
else:
# We need to attribute the docs to their requests
non_assigned_to_response_docs.extend(
batch_res_docs or docs_inner_batch
)
non_assigned_to_response_request_idxs.extend(req_idxs)
num_assigned_docs = await _assign_results(
non_assigned_to_response_docs,
non_assigned_to_response_request_idxs,
sum_from_previous_first_req_idx,
Expand All @@ -369,6 +341,26 @@ def batch(iterable_1, iterable_2, n: Optional[int] = 1, iterable_metrics: Option
requests_completed_in_batch,
)

sum_from_previous_first_req_idx = (
len(non_assigned_to_response_docs) - num_assigned_docs
)
non_assigned_to_response_docs = non_assigned_to_response_docs[
num_assigned_docs:
]
non_assigned_to_response_request_idxs = (
non_assigned_to_response_request_idxs[num_assigned_docs:]
)
if len(non_assigned_to_response_request_idxs) > 0:
_ = await _assign_results(
non_assigned_to_response_docs,
non_assigned_to_response_request_idxs,
sum_from_previous_first_req_idx,
requests_lens_in_batch,
requests_in_batch,
requests_completed_in_batch,
)


async def close(self):
"""Closes the batch queue by flushing pending requests."""
if not self._is_closed:
Expand Down
1 change: 0 additions & 1 deletion jina/serve/runtimes/worker/request_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,6 @@ async def handle(
].response_schema,
output_array_type=self.args.output_array_type,
params=params,
allow_concurrent=self.args.allow_concurrent,
**self._batchqueue_config[exec_endpoint],
)
# This is necessary because push might need to await for the queue to be emptied
Expand Down
Loading

0 comments on commit 3894353

Please sign in to comment.