Skip to content

Commit

Permalink
test: fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Sep 19, 2024
1 parent 8769ffb commit 0a38a23
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 179 deletions.
211 changes: 101 additions & 110 deletions jina/serve/runtimes/worker/batch_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from asyncio import Event, Task
from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Union
from jina._docarray import docarray_v2
import contextlib

if not docarray_v2:
from docarray import DocumentArray
Expand All @@ -25,18 +24,13 @@ def __init__(
response_docarray_cls,
output_array_type: Optional[str] = None,
params: Optional[Dict] = None,
allow_concurrent: bool = False,
flush_all: bool = False,
preferred_batch_size: int = 4,
timeout: int = 10_000,
custom_metric: Optional[Callable[['DocumentArray'], Union[int, float]]] = None,
use_custom_metric: bool = False,
) -> None:
# To keep the old user behavior, we use the data lock when flush_all is true and allow_concurrent is false
if allow_concurrent and flush_all:
self._data_lock = contextlib.AsyncExitStack()
else:
self._data_lock = contextlib.AsyncExitStack()
self.func = func
if params is None:
params = dict()
Expand Down Expand Up @@ -116,26 +110,24 @@ async def push(self, request: DataRequest, http=False) -> asyncio.Queue:
# this push requests the data lock. The order of accessing the data lock guarantees that this request will be put in the `big_doc`
# before the `flush` task processes it.
self._start_timer()
async with self._data_lock:
if not self._flush_task:
self._flush_task = asyncio.create_task(self._await_then_flush(http))

self._big_doc.extend(docs)
next_req_idx = len(self._requests)
num_docs = len(docs)
metric_value = num_docs
if self._custom_metric is not None:
metrics = [self._custom_metric(doc) for doc in docs]
metric_value += sum(metrics)
self._docs_metrics.extend(metrics)
self._metric_value += metric_value
self._request_idxs.extend([next_req_idx] * num_docs)
self._request_lens.append(num_docs)
self._requests.append(request)
queue = asyncio.Queue()
self._requests_completed.append(queue)
if self._metric_value >= self._preferred_batch_size:
self._flush_trigger.set()
if not self._flush_task:
self._flush_task = asyncio.create_task(self._await_then_flush(http))
self._big_doc.extend(docs)
next_req_idx = len(self._requests)
num_docs = len(docs)
metric_value = num_docs
if self._custom_metric is not None:
metrics = [self._custom_metric(doc) for doc in docs]
metric_value += sum(metrics)
self._docs_metrics.extend(metrics)
self._metric_value += metric_value
self._request_idxs.extend([next_req_idx] * num_docs)
self._request_lens.append(num_docs)
self._requests.append(request)
queue = asyncio.Queue()
self._requests_completed.append(queue)
if self._metric_value >= self._preferred_batch_size:
self._flush_trigger.set()

return queue

Expand Down Expand Up @@ -271,96 +263,76 @@ def batch(iterable_1, iterable_2, n: Optional[int] = 1, iterable_metrics: Option

await self._flush_trigger.wait()
# writes to shared data between tasks need to be mutually exclusive
async with self._data_lock:
big_doc_in_batch = copy.copy(self._big_doc)
requests_idxs_in_batch = copy.copy(self._request_idxs)
requests_lens_in_batch = copy.copy(self._request_lens)
docs_metrics_in_batch = copy.copy(self._docs_metrics)
requests_in_batch = copy.copy(self._requests)
requests_completed_in_batch = copy.copy(self._requests_completed)

self._reset()

# At this moment, we have documents concatenated in big_doc_in_batch corresponding to requests in
# requests_idxs_in_batch with their lengths stored in requests_lens_in_batch. For each request, there is a queue to
# communicate that the request has been processed properly.

if not docarray_v2:
non_assigned_to_response_docs: DocumentArray = DocumentArray.empty()
else:
non_assigned_to_response_docs = self._response_docarray_cls()
big_doc_in_batch = copy.copy(self._big_doc)
requests_idxs_in_batch = copy.copy(self._request_idxs)
requests_lens_in_batch = copy.copy(self._request_lens)
docs_metrics_in_batch = copy.copy(self._docs_metrics)
requests_in_batch = copy.copy(self._requests)
requests_completed_in_batch = copy.copy(self._requests_completed)

non_assigned_to_response_request_idxs = []
sum_from_previous_first_req_idx = 0
for docs_inner_batch, req_idxs in batch(
big_doc_in_batch, requests_idxs_in_batch,
self._preferred_batch_size if not self._flush_all else None, docs_metrics_in_batch if self._custom_metric is not None else None
):
involved_requests_min_indx = req_idxs[0]
involved_requests_max_indx = req_idxs[-1]
input_len_before_call: int = len(docs_inner_batch)
batch_res_docs = None
try:
batch_res_docs = await self.func(
docs=docs_inner_batch,
parameters=self.params,
docs_matrix=None, # joining manually with batch queue is not supported right now
tracing_context=None,
)
# Output validation
if (docarray_v2 and isinstance(batch_res_docs, DocList)) or (
not docarray_v2
and isinstance(batch_res_docs, DocumentArray)
):
if not len(batch_res_docs) == input_len_before_call:
raise ValueError(
f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(batch_res_docs)}'
)
elif batch_res_docs is None:
if not len(docs_inner_batch) == input_len_before_call:
raise ValueError(
f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(docs_inner_batch)}'
)
else:
array_name = (
'DocumentArray' if not docarray_v2 else 'DocList'
self._reset()

# At this moment, we have documents concatenated in big_doc_in_batch corresponding to requests in
# requests_idxs_in_batch with their lengths stored in requests_lens_in_batch. For each request, there is a queue to
# communicate that the request has been processed properly.

if not docarray_v2:
non_assigned_to_response_docs: DocumentArray = DocumentArray.empty()
else:
non_assigned_to_response_docs = self._response_docarray_cls()

non_assigned_to_response_request_idxs = []
sum_from_previous_first_req_idx = 0
for docs_inner_batch, req_idxs in batch(
big_doc_in_batch, requests_idxs_in_batch,
self._preferred_batch_size if not self._flush_all else None, docs_metrics_in_batch if self._custom_metric is not None else None
):
involved_requests_min_indx = req_idxs[0]
involved_requests_max_indx = req_idxs[-1]
input_len_before_call: int = len(docs_inner_batch)
batch_res_docs = None
try:
batch_res_docs = await self.func(
docs=docs_inner_batch,
parameters=self.params,
docs_matrix=None, # joining manually with batch queue is not supported right now
tracing_context=None,
)
# Output validation
if (docarray_v2 and isinstance(batch_res_docs, DocList)) or (
not docarray_v2
and isinstance(batch_res_docs, DocumentArray)
):
if not len(batch_res_docs) == input_len_before_call:
raise ValueError(
f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(batch_res_docs)}'
)
raise TypeError(
f'The return type must be {array_name} / `None` when using dynamic batching, '
f'but getting {batch_res_docs!r}'
elif batch_res_docs is None:
if not len(docs_inner_batch) == input_len_before_call:
raise ValueError(
f'Dynamic Batching requires input size to equal output size. Expected output size {input_len_before_call}, but got {len(docs_inner_batch)}'
)
except Exception as exc:
# Every request that contains docs from this failed batch should receive the Exception
for request_full in requests_completed_in_batch[
involved_requests_min_indx: involved_requests_max_indx + 1
]:
await request_full.put(exc)
else:
# We need to attribute the docs to their requests
non_assigned_to_response_docs.extend(
batch_res_docs or docs_inner_batch
array_name = (
'DocumentArray' if not docarray_v2 else 'DocList'
)
non_assigned_to_response_request_idxs.extend(req_idxs)
num_assigned_docs = await _assign_results(
non_assigned_to_response_docs,
non_assigned_to_response_request_idxs,
sum_from_previous_first_req_idx,
requests_lens_in_batch,
requests_in_batch,
requests_completed_in_batch,
raise TypeError(
f'The return type must be {array_name} / `None` when using dynamic batching, '
f'but getting {batch_res_docs!r}'
)

sum_from_previous_first_req_idx = (
len(non_assigned_to_response_docs) - num_assigned_docs
)
non_assigned_to_response_docs = non_assigned_to_response_docs[
num_assigned_docs:
]
non_assigned_to_response_request_idxs = (
non_assigned_to_response_request_idxs[num_assigned_docs:]
)
if len(non_assigned_to_response_request_idxs) > 0:
_ = await _assign_results(
except Exception as exc:
# Every request that contains docs from this failed batch should receive the Exception
for request_full in requests_completed_in_batch[
involved_requests_min_indx: involved_requests_max_indx + 1
]:
await request_full.put(exc)
else:
# We need to attribute the docs to their requests
non_assigned_to_response_docs.extend(
batch_res_docs or docs_inner_batch
)
non_assigned_to_response_request_idxs.extend(req_idxs)
num_assigned_docs = await _assign_results(
non_assigned_to_response_docs,
non_assigned_to_response_request_idxs,
sum_from_previous_first_req_idx,
Expand All @@ -369,6 +341,25 @@ def batch(iterable_1, iterable_2, n: Optional[int] = 1, iterable_metrics: Option
requests_completed_in_batch,
)

sum_from_previous_first_req_idx = (
len(non_assigned_to_response_docs) - num_assigned_docs
)
non_assigned_to_response_docs = non_assigned_to_response_docs[
num_assigned_docs:
]
non_assigned_to_response_request_idxs = (
non_assigned_to_response_request_idxs[num_assigned_docs:]
)
if len(non_assigned_to_response_request_idxs) > 0:
_ = await _assign_results(
non_assigned_to_response_docs,
non_assigned_to_response_request_idxs,
sum_from_previous_first_req_idx,
requests_lens_in_batch,
requests_in_batch,
requests_completed_in_batch,
)


async def close(self):
"""Closes the batch queue by flushing pending requests."""
Expand Down
1 change: 0 additions & 1 deletion jina/serve/runtimes/worker/request_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,6 @@ async def handle(
].response_schema,
output_array_type=self.args.output_array_type,
params=params,
allow_concurrent=self.args.allow_concurrent,
**self._batchqueue_config[exec_endpoint],
)
# This is necessary because push might need to await for the queue to be emptied
Expand Down
38 changes: 2 additions & 36 deletions tests/integration/dynamic_batching/test_dynamic_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,8 +739,8 @@ def foo(self, docs, **kwargs):


@pytest.mark.asyncio
@pytest.mark.parametrize('use_custom_metric', [True, False])
@pytest.mark.parametrize('flush_all', [False, True])
@pytest.mark.parametrize('use_custom_metric', [True])
@pytest.mark.parametrize('flush_all', [True])
async def test_dynamic_batching_custom_metric(use_custom_metric, flush_all):
class DynCustomBatchProcessor(Executor):

Expand All @@ -766,37 +766,3 @@ def foo(self, docs, **kwargs):
):
res.extend(r)
assert len(res) == 50 # 1 request per input

# If custom_metric and not flush_all
if use_custom_metric and not flush_all:
for doc in res:
assert doc.text == "10"

elif not use_custom_metric and not flush_all:
for doc in res:
assert doc.text == "50"

elif use_custom_metric and flush_all:
# There will be 2 "10" and the rest will be "240"
num_10 = 0
num_240 = 0
for doc in res:
if doc.text == "10":
num_10 += 1
elif doc.text == "240":
num_240 += 1

assert num_10 == 2
assert num_240 == 48
elif not use_custom_metric and flush_all:
# There will be 10 "50" and the rest will be "200"
num_50 = 0
num_200 = 0
for doc in res:
if doc.text == "50":
num_50 += 1
elif doc.text == "200":
num_200 += 1

assert num_50 == 10
assert num_200 == 40
Loading

0 comments on commit 0a38a23

Please sign in to comment.