fix: set direct docs (#6183)
JoanFM authored Jul 23, 2024
1 parent b28c633 commit 4eaeb2d
Showing 8 changed files with 54 additions and 40 deletions.
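The common thread across these files: HTTP-served requests can now attach a DocumentArray/DocList directly to the request object via a new `direct_docs` attribute, and the `docs` property prefers it over the protobuf-backed `data.docs`, skipping the proto serialization round-trip. A minimal sketch of the pattern (the `Request` class below is a hypothetical stand-in for `DataRequest`, not Jina's actual implementation):

from typing import List, Optional

class Request:
    """Hypothetical stand-in for DataRequest, for illustration only."""

    def __init__(self) -> None:
        # proto-backed storage: writing/reading here implies (de)serialization
        self._proto_docs: List[str] = []
        # the new escape hatch: plain Python objects, no proto involved
        self.direct_docs: Optional[List[str]] = None

    @property
    def docs(self) -> List[str]:
        # prefer directly attached docs, fall back to the proto payload
        if self.direct_docs is not None:
            return self.direct_docs
        return self._proto_docs

req = Request()
req.direct_docs = ['hello']   # what the HTTP server now does
assert req.docs == ['hello']  # existing callers of .docs keep working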
2 changes: 1 addition & 1 deletion jina/clients/base/http.py
@@ -219,7 +219,7 @@ def _result_handler(result):
 
             resp = DataRequest(r_str)
             if da is not None:
-                resp.data.docs = da
+                resp.direct_docs = da
 
             callback_exec(
                 response=resp,
6 changes: 3 additions & 3 deletions jina/clients/mixin.py
@@ -414,7 +414,7 @@ async def _get_results(*args, **kwargs):
                 if return_responses:
                     result.append(resp)
                 else:
-                    result.extend(resp.data.docs)
+                    result.extend(resp.docs)
         if return_results:
             if not return_responses and is_singleton and len(result) == 1:
                 return result[0]
@@ -438,6 +438,7 @@ async def _get_results(*args, **kwargs):
             results_in_order=results_in_order,
             stream=stream,
             prefetch=prefetch,
+            return_type=return_type,
             on=on,
             **kwargs,
         )
@@ -507,7 +508,6 @@ async def post(
         c.continue_on_error = continue_on_error
 
         parameters = _include_results_field_in_param(parameters)
-
         async for result in c._get_results(
             on=on,
             inputs=inputs,
@@ -538,7 +538,7 @@ async def post(
                     is_singleton = True
                     result.document_array_cls = DocList[return_type]
                 if not return_responses:
-                    ret_docs = result.data.docs
+                    ret_docs = result.docs
                     if is_singleton and len(ret_docs) == 1:
                         yield ret_docs[0]
                     else:
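On the client side the change is transparent: results are read through `resp.docs`, which resolves via the new property, and `return_type` is now also forwarded into `_get_results`. A hedged usage sketch (the `MyDoc` schema, the `/foo` endpoint, and the port are assumptions for illustration, not from this diff):

from docarray import BaseDoc, DocList
from jina import Client

class MyDoc(BaseDoc):  # hypothetical schema
    text: str = ''

client = Client(port=12345, protocol='http')  # assumes a running Flow
docs = client.post(
    on='/foo',
    inputs=DocList[MyDoc]([MyDoc(text='hi')]),
    return_type=DocList[MyDoc],  # now passed through to _get_results
)
print(docs[0].text)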
13 changes: 6 additions & 7 deletions jina/serve/runtimes/worker/http_fastapi_app.py
@@ -99,16 +99,16 @@ async def post(body: input_model, response: Response):
         data = body.data
         if isinstance(data, list):
             if not docarray_v2:
-                req.data.docs = DocumentArray.from_pydantic_model(data)
+                req.direct_docs = DocumentArray.from_pydantic_model(data)
             else:
                 req.document_array_cls = DocList[input_doc_model]
-                req.data.docs = DocList[input_doc_list_model](data)
+                req.direct_docs = DocList[input_doc_list_model](data)
         else:
             if not docarray_v2:
-                req.data.docs = DocumentArray([Document.from_pydantic_model(data)])
+                req.direct_docs = DocumentArray([Document.from_pydantic_model(data)])
             else:
                 req.document_array_cls = DocList[input_doc_model]
-                req.data.docs = DocList[input_doc_list_model]([data])
+                req.direct_docs = DocList[input_doc_list_model]([data])
         if body.header is None:
             req.header.request_id = req.docs[0].id
 
@@ -122,7 +122,6 @@ async def post(body: input_model, response: Response):
             docs_response = resp.docs.to_dict()
         else:
             docs_response = resp.docs
-
         ret = output_model(data=docs_response, parameters=resp.parameters)
 
         return ret
@@ -152,10 +151,10 @@ async def streaming_get(request: Request = None, body: input_doc_model = None):
         req = DataRequest()
         req.header.exec_endpoint = endpoint_path
         if not docarray_v2:
-            req.data.docs = DocumentArray([body])
+            req.direct_docs = DocumentArray([body])
         else:
            req.document_array_cls = DocList[input_doc_model]
-            req.data.docs = DocList[input_doc_model]([body])
+            req.direct_docs = DocList[input_doc_model]([body])
         event_generator = _gen_dict_documents(await caller(req))
         return EventSourceResponse(event_generator)
 
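With this, the FastAPI ingress keeps the validated docarray objects as-is instead of writing them into the request proto. A sketch of the docarray-v2 branch using Jina's real classes (the `InputDoc` schema is hypothetical):

from docarray import BaseDoc, DocList
from jina.types.request.data import DataRequest

class InputDoc(BaseDoc):  # hypothetical endpoint schema
    text: str = ''

req = DataRequest()
req.document_array_cls = DocList[InputDoc]
req.direct_docs = DocList[InputDoc]([InputDoc(text='hi')])  # no proto write
assert req.docs[0].text == 'hi'  # served straight from direct_docs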
33 changes: 19 additions & 14 deletions jina/serve/runtimes/worker/request_handling.py
@@ -177,7 +177,7 @@ def call_handle(request):
                 'is_generator'
             ]
 
-            return self.process_single_data(request, None, is_generator=is_generator)
+            return self.process_single_data(request, None, http=True, is_generator=is_generator)
 
         app = get_fastapi_app(
             request_models_map=request_models_map, caller=call_handle, **kwargs
@@ -201,7 +201,7 @@ def call_handle(request):
                 'is_generator'
             ]
 
-            return self.process_single_data(request, None, is_generator=is_generator)
+            return self.process_single_data(request, None, http=True, is_generator=is_generator)
 
         app = get_fastapi_app(
             request_models_map=request_models_map, caller=call_handle, **kwargs
@@ -548,7 +548,7 @@ def _record_response_size_monitoring(self, requests):
                 requests[0].nbytes, attributes=attributes
             )
 
-    def _set_result(self, requests, return_data, docs):
+    def _set_result(self, requests, return_data, docs, http=False):
         # assigning result back to request
         if return_data is not None:
             if isinstance(return_data, DocumentArray):
@@ -568,10 +568,12 @@ def _set_result(self, requests, return_data, docs):
                     f'The return type must be DocList / Dict / `None`, '
                     f'but getting {return_data!r}'
                 )
-
-        WorkerRequestHandler.replace_docs(
-            requests[0], docs, self.args.output_array_type
-        )
+        if not http:
+            WorkerRequestHandler.replace_docs(
+                requests[0], docs, self.args.output_array_type
+            )
+        else:
+            requests[0].direct_docs = docs
         return docs
 
     def _setup_req_doc_array_cls(self, requests, exec_endpoint, is_response=False):
@@ -659,11 +661,12 @@ async def handle_generator(
         )
 
     async def handle(
-        self, requests: List['DataRequest'], tracing_context: Optional['Context'] = None
+        self, requests: List['DataRequest'], http=False, tracing_context: Optional['Context'] = None
     ) -> DataRequest:
         """Initialize private parameters and execute private loading functions.
 
         :param requests: The messages to handle containing a DataRequest
+        :param http: Flag indicating if it is used by the HTTP server for some optims
         :param tracing_context: Optional OpenTelemetry tracing context from the originating request.
         :returns: the processed message
         """
@@ -721,7 +724,7 @@ async def handle(
             docs_map=docs_map,
             tracing_context=tracing_context,
         )
-        _ = self._set_result(requests, return_data, docs)
+        _ = self._set_result(requests, return_data, docs, http=http)
 
         for req in requests:
             req.add_executor(self.deployment_name)
@@ -909,18 +912,19 @@ def reduce_requests(requests: List['DataRequest']) -> 'DataRequest':
 
     # serving part
     async def process_single_data(
-        self, request: DataRequest, context, is_generator: bool = False
+        self, request: DataRequest, context, http: bool = False, is_generator: bool = False
    ) -> DataRequest:
         """
         Process the received requests and return the result as a new request
 
         :param request: the data request to process
         :param context: grpc context
+        :param http: Flag indicating if it is used by the HTTP server for some optims
         :param is_generator: whether the request should be handled with streaming
         :returns: the response request
         """
         self.logger.debug('recv a process_single_data request')
-        return await self.process_data([request], context, is_generator=is_generator)
+        return await self.process_data([request], context, http=http, is_generator=is_generator)
 
     async def stream_doc(
         self, request: SingleDocumentRequest, context: 'grpc.aio.ServicerContext'
@@ -1065,13 +1069,14 @@ def _extract_tracing_context(
         return None
 
     async def process_data(
-        self, requests: List[DataRequest], context, is_generator: bool = False
+        self, requests: List[DataRequest], context, http=False, is_generator: bool = False
     ) -> DataRequest:
         """
         Process the received requests and return the result as a new request
 
         :param requests: the data requests to process
         :param context: grpc context
+        :param http: Flag indicating if it is used by the HTTP server for some optims
         :param is_generator: whether the request should be handled with streaming
         :returns: the response request
         """
@@ -1094,11 +1099,11 @@ async def process_data(
 
             if is_generator:
                 result = await self.handle_generator(
-                    requests=requests, tracing_context=tracing_context
+                    requests=requests,tracing_context=tracing_context
                 )
             else:
                 result = await self.handle(
-                    requests=requests, tracing_context=tracing_context
+                    requests=requests, http=http, tracing_context=tracing_context
                 )
 
             if self._successful_requests_metrics:
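Taken together, the `http=True` flag threads from the FastAPI caller down to `_set_result`, where it swaps the proto-writing `replace_docs` for a plain attribute assignment. A minimal runnable sketch of that branch (mock objects with hypothetical names; the real method also handles Dict returns, output array types, etc.):

class FakeRequest:
    """Mock request; hypothetical, for illustration only."""
    def __init__(self):
        self.proto_payload = None
        self.direct_docs = None

def replace_docs(req, docs):
    # stands in for WorkerRequestHandler.replace_docs: a proto write
    req.proto_payload = list(docs)

def set_result(requests, docs, http=False):
    # mirrors the new branch in _set_result
    if not http:
        replace_docs(requests[0], docs)   # gRPC path: serialize into proto
    else:
        requests[0].direct_docs = docs    # HTTP path: skip the proto write
    return docs

req = FakeRequest()
set_result([req], ['doc'], http=True)
assert req.direct_docs == ['doc'] and req.proto_payload is None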
14 changes: 12 additions & 2 deletions jina/types/request/data.py
@@ -114,6 +114,8 @@ def __init__(
         self._pb_body = None
         self._document_array_cls = DocumentArray
         self._data = None
+        # to be used to bypass proto extra transforms
+        self.direct_docs = None
 
         try:
             if isinstance(request, jina_pb2.DataRequestProto):
@@ -275,7 +277,10 @@ def docs(self) -> 'DocumentArray':
         """Get the :class: `DocumentArray` with sequence `data.docs` as content.
 
         .. # noqa: DAR201"""
-        return self.data.docs
+        if self.direct_docs is not None:
+            return self.direct_docs
+        else:
+            return self.data.docs
 
     @property
     def data(self) -> 'DataRequest._DataContent':
@@ -441,6 +446,8 @@ def __init__(
         self._document_cls = Document
         self.buffer = None
         self._data = None
+        # to be used to bypass proto extra transforms
+        self.direct_doc = None
 
         try:
             if isinstance(request, jina_pb2.SingleDocumentRequestProto):
@@ -606,7 +613,10 @@ def doc(self) -> 'Document':
         """Get the :class: `DocumentArray` with sequence `data.docs` as content.
 
         .. # noqa: DAR201"""
-        return self.data.doc
+        if self.direct_doc is not None:
+            return self.direct_doc
+        else:
+            return self.data.doc
 
     @property
     def data(self) -> 'SingleDocumentRequest._DataContent':
@@ -528,7 +528,7 @@ def _assert_all_docs_processed(port, num_docs, endpoint):
         target=f'0.0.0.0:{port}',
         endpoint=endpoint,
     )
-    docs = resp.data.docs
+    docs = resp.docs
     assert docs.texts == ['long timeout' for _ in range(num_docs)]
 
 
@@ -145,7 +145,7 @@ def test_flow_returned_collect(protocol, port_generator):
     def validate_func(resp):
         num_evaluations = 0
         scores = set()
-        for doc in resp.data.docs:
+        for doc in resp.docs:
             num_evaluations += len(doc.evaluations)
             scores.add(doc.evaluations['evaluate'].value)
         assert num_evaluations == 1
22 changes: 11 additions & 11 deletions tests/unit/serve/dynamic_batching/test_batch_queue.py
@@ -27,7 +27,7 @@ async def foo(docs, **kwargs):
     three_data_requests = [DataRequest() for _ in range(3)]
     for req in three_data_requests:
         req.data.docs = DocumentArray.empty(1)
-        assert req.data.docs[0].text == ''
+        assert req.docs[0].text == ''
 
     async def process_request(req):
         q = await bq.push(req)
@@ -42,20 +42,20 @@ async def process_request(req):
     assert time_spent >= 2000
     # Test that since no more docs arrived, the function was triggerred after timeout
     for resp in responses:
-        assert resp.data.docs[0].text == 'Done'
+        assert resp.docs[0].text == 'Done'
 
     four_data_requests = [DataRequest() for _ in range(4)]
     for req in four_data_requests:
         req.data.docs = DocumentArray.empty(1)
-        assert req.data.docs[0].text == ''
+        assert req.docs[0].text == ''
     init_time = time.time()
     tasks = [asyncio.create_task(process_request(req)) for req in four_data_requests]
     responses = await asyncio.gather(*tasks)
     time_spent = (time.time() - init_time) * 1000
     assert time_spent < 2000
     # Test that since no more docs arrived, the function was triggerred after timeout
     for resp in responses:
-        assert resp.data.docs[0].text == 'Done'
+        assert resp.docs[0].text == 'Done'
 
     await bq.close()
 
@@ -135,7 +135,7 @@ async def foo(docs, **kwargs):
     data_requests = [DataRequest() for _ in range(3)]
     for req in data_requests:
         req.data.docs = DocumentArray.empty(10)  # 30 docs in total
-        assert req.data.docs[0].text == ''
+        assert req.docs[0].text == ''
 
     async def process_request(req):
         q = await bq.push(req)
@@ -150,7 +150,7 @@ async def process_request(req):
     assert time_spent < 2000
     # Test that since no more docs arrived, the function was triggerred after timeout
     for resp in responses:
-        assert resp.data.docs[0].text == 'Done'
+        assert resp.docs[0].text == 'Done'
 
     await bq.close()
 
@@ -196,9 +196,9 @@ async def process_request(req):
             assert isinstance(item, Exception)
     for i, req in enumerate(data_requests):
         if i not in BAD_REQUEST_IDX:
-            assert req.data.docs[0].text == f'{i} Processed'
+            assert req.docs[0].text == f'{i} Processed'
         else:
-            assert req.data.docs[0].text == 'Bad'
+            assert req.docs[0].text == 'Bad'
 
 
 @pytest.mark.asyncio
@@ -246,11 +246,11 @@ async def process_request(req):
             assert isinstance(item, Exception)
     for i, req in enumerate(data_requests):
         if i not in EXPECTED_BAD_REQUESTS:
-            assert req.data.docs[0].text == 'Processed'
+            assert req.docs[0].text == 'Processed'
         elif i in TRIGGER_BAD_REQUEST_IDX:
-            assert req.data.docs[0].text == 'Bad'
+            assert req.docs[0].text == 'Bad'
         else:
-            assert req.data.docs[0].text == ''
+            assert req.docs[0].text == ''
 
 
 @pytest.mark.asyncio
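A note on why these tests still pass while writing through `req.data.docs`: `direct_docs` defaults to `None`, so the `docs` property falls back to the proto-backed payload. A sketch using Jina's actual classes (assumes jina with docarray<2, as these tests do):

from docarray import DocumentArray
from jina.types.request.data import DataRequest

req = DataRequest()
req.data.docs = DocumentArray.empty(1)  # write via the proto-backed payload
assert req.direct_docs is None          # nothing attached directly
assert req.docs[0].text == ''           # .docs falls back to data.docs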
