diff --git a/jina/clients/base/http.py b/jina/clients/base/http.py
index 653a98f051629..c55156bf69365 100644
--- a/jina/clients/base/http.py
+++ b/jina/clients/base/http.py
@@ -219,7 +219,7 @@ def _result_handler(result):
                 resp = DataRequest(r_str)
                 if da is not None:
-                    resp.data.docs = da
+                    resp.direct_docs = da

                 callback_exec(
                     response=resp,
diff --git a/jina/clients/mixin.py b/jina/clients/mixin.py
index ec0c52049d200..a6960fa355f63 100644
--- a/jina/clients/mixin.py
+++ b/jina/clients/mixin.py
@@ -414,7 +414,7 @@ async def _get_results(*args, **kwargs):
                 if return_responses:
                     result.append(resp)
                 else:
-                    result.extend(resp.data.docs)
+                    result.extend(resp.docs)
             if return_results:
                 if not return_responses and is_singleton and len(result) == 1:
                     return result[0]
@@ -438,6 +438,7 @@ async def _get_results(*args, **kwargs):
                 results_in_order=results_in_order,
                 stream=stream,
                 prefetch=prefetch,
+                return_type=return_type,
                 on=on,
                 **kwargs,
             )
@@ -507,7 +508,6 @@ async def post(
         c.continue_on_error = continue_on_error
         parameters = _include_results_field_in_param(parameters)
-
         async for result in c._get_results(
             on=on,
             inputs=inputs,
@@ -538,7 +538,7 @@ async def post(
                 is_singleton = True
                 result.document_array_cls = DocList[return_type]
             if not return_responses:
-                ret_docs = result.data.docs
+                ret_docs = result.docs
                 if is_singleton and len(ret_docs) == 1:
                     yield ret_docs[0]
                 else:
diff --git a/jina/serve/runtimes/worker/http_fastapi_app.py b/jina/serve/runtimes/worker/http_fastapi_app.py
index b45b94f7c62cf..889166d8aeb63 100644
--- a/jina/serve/runtimes/worker/http_fastapi_app.py
+++ b/jina/serve/runtimes/worker/http_fastapi_app.py
@@ -99,16 +99,16 @@ async def post(body: input_model, response: Response):
         data = body.data
         if isinstance(data, list):
             if not docarray_v2:
-                req.data.docs = DocumentArray.from_pydantic_model(data)
+                req.direct_docs = DocumentArray.from_pydantic_model(data)
             else:
                 req.document_array_cls = DocList[input_doc_model]
-                req.data.docs = DocList[input_doc_list_model](data)
+                req.direct_docs = DocList[input_doc_list_model](data)
         else:
             if not docarray_v2:
-                req.data.docs = DocumentArray([Document.from_pydantic_model(data)])
+                req.direct_docs = DocumentArray([Document.from_pydantic_model(data)])
             else:
                 req.document_array_cls = DocList[input_doc_model]
-                req.data.docs = DocList[input_doc_list_model]([data])
+                req.direct_docs = DocList[input_doc_list_model]([data])

         if body.header is None:
             req.header.request_id = req.docs[0].id
@@ -122,7 +122,6 @@ async def post(body: input_model, response: Response):
             docs_response = resp.docs.to_dict()
         else:
             docs_response = resp.docs
-
         ret = output_model(data=docs_response, parameters=resp.parameters)
         return ret
@@ -152,10 +151,10 @@ async def streaming_get(request: Request = None, body: input_doc_model = None):
         req = DataRequest()
         req.header.exec_endpoint = endpoint_path
         if not docarray_v2:
-            req.data.docs = DocumentArray([body])
+            req.direct_docs = DocumentArray([body])
         else:
             req.document_array_cls = DocList[input_doc_model]
-            req.data.docs = DocList[input_doc_model]([body])
+            req.direct_docs = DocList[input_doc_model]([body])
         event_generator = _gen_dict_documents(await caller(req))
         return EventSourceResponse(event_generator)
diff --git a/jina/serve/runtimes/worker/request_handling.py b/jina/serve/runtimes/worker/request_handling.py
index af3786f2886d3..7d9958c35c049 100644
--- a/jina/serve/runtimes/worker/request_handling.py
+++ b/jina/serve/runtimes/worker/request_handling.py
@@ -177,7 +177,7 @@ def call_handle(request):
                 'is_generator'
             ]

-            return self.process_single_data(request, None, is_generator=is_generator)
+            return self.process_single_data(request, None, http=True, is_generator=is_generator)

         app = get_fastapi_app(
             request_models_map=request_models_map, caller=call_handle, **kwargs
@@ -201,7 +201,7 @@ def call_handle(request):
                 'is_generator'
             ]

-            return self.process_single_data(request, None, is_generator=is_generator)
+            return self.process_single_data(request, None, http=True, is_generator=is_generator)

         app = get_fastapi_app(
             request_models_map=request_models_map, caller=call_handle, **kwargs
@@ -548,7 +548,7 @@ def _record_response_size_monitoring(self, requests):
                 requests[0].nbytes, attributes=attributes
             )

-    def _set_result(self, requests, return_data, docs):
+    def _set_result(self, requests, return_data, docs, http=False):
         # assigning result back to request
         if return_data is not None:
             if isinstance(return_data, DocumentArray):
@@ -568,10 +568,12 @@ def _set_result(self, requests, return_data, docs):
                     f'The return type must be DocList / Dict / `None`, '
                     f'but getting {return_data!r}'
                 )
-
-        WorkerRequestHandler.replace_docs(
-            requests[0], docs, self.args.output_array_type
-        )
+        if not http:
+            WorkerRequestHandler.replace_docs(
+                requests[0], docs, self.args.output_array_type
+            )
+        else:
+            requests[0].direct_docs = docs
         return docs

     def _setup_req_doc_array_cls(self, requests, exec_endpoint, is_response=False):
@@ -659,11 +661,12 @@ async def handle_generator(
         )

     async def handle(
-        self, requests: List['DataRequest'], tracing_context: Optional['Context'] = None
+        self, requests: List['DataRequest'], http: bool = False, tracing_context: Optional['Context'] = None
     ) -> DataRequest:
         """Initialize private parameters and execute private loading functions.

         :param requests: The messages to handle containing a DataRequest
+        :param http: Flag indicating if the request comes from the HTTP server, to enable some optimizations
         :param tracing_context: Optional OpenTelemetry tracing context from the originating request.
         :returns: the processed message
         """
@@ -721,7 +724,7 @@ async def handle(
                     docs_map=docs_map,
                     tracing_context=tracing_context,
                 )
-            _ = self._set_result(requests, return_data, docs)
+            _ = self._set_result(requests, return_data, docs, http=http)

         for req in requests:
             req.add_executor(self.deployment_name)
@@ -909,18 +912,19 @@ def reduce_requests(requests: List['DataRequest']) -> 'DataRequest':
     # serving part
     async def process_single_data(
-        self, request: DataRequest, context, is_generator: bool = False
+        self, request: DataRequest, context, http: bool = False, is_generator: bool = False
     ) -> DataRequest:
         """
         Process the received requests and return the result as a new request

         :param request: the data request to process
         :param context: grpc context
+        :param http: Flag indicating if the request comes from the HTTP server, to enable some optimizations
         :param is_generator: whether the request should be handled with streaming
         :returns: the response request
         """
         self.logger.debug('recv a process_single_data request')

-        return await self.process_data([request], context, is_generator=is_generator)
+        return await self.process_data([request], context, http=http, is_generator=is_generator)

     async def stream_doc(
         self, request: SingleDocumentRequest, context: 'grpc.aio.ServicerContext'
@@ -1065,13 +1069,14 @@ def _extract_tracing_context(
         return None

     async def process_data(
-        self, requests: List[DataRequest], context, is_generator: bool = False
+        self, requests: List[DataRequest], context, http: bool = False, is_generator: bool = False
     ) -> DataRequest:
         """
         Process the received requests and return the result as a new request

         :param requests: the data requests to process
         :param context: grpc context
+        :param http: Flag indicating if the request comes from the HTTP server, to enable some optimizations
         :param is_generator: whether the request should be handled with streaming
         :returns: the response request
         """
@@ -1094,11 +1099,11 @@ async def process_data(
                 if is_generator:
                     result = await self.handle_generator(
                         requests=requests, tracing_context=tracing_context
                     )
                 else:
                     result = await self.handle(
-                        requests=requests, tracing_context=tracing_context
+                        requests=requests, http=http, tracing_context=tracing_context
                     )

                 if self._successful_requests_metrics:
diff --git a/jina/types/request/data.py b/jina/types/request/data.py
index c3fd12822e8c1..9c936833f376f 100644
--- a/jina/types/request/data.py
+++ b/jina/types/request/data.py
@@ -114,6 +114,8 @@ def __init__(
         self._pb_body = None
         self._document_array_cls = DocumentArray
         self._data = None
+        # used to bypass the extra proto transforms
+        self.direct_docs = None

         try:
             if isinstance(request, jina_pb2.DataRequestProto):
@@ -275,7 +277,10 @@ def docs(self) -> 'DocumentArray':
         """Get the :class: `DocumentArray` with sequence `data.docs` as content.

         .. # noqa: DAR201"""
-        return self.data.docs
+        if self.direct_docs is not None:
+            return self.direct_docs
+        else:
+            return self.data.docs

     @property
     def data(self) -> 'DataRequest._DataContent':
@@ -441,6 +446,8 @@ def __init__(
         self._document_cls = Document
         self.buffer = None
         self._data = None
+        # used to bypass the extra proto transforms
+        self.direct_doc = None

         try:
             if isinstance(request, jina_pb2.SingleDocumentRequestProto):
@@ -606,7 +613,10 @@ def doc(self) -> 'Document':
         """Get the :class: `DocumentArray` with sequence `data.docs` as content.

         .. # noqa: DAR201"""
-        return self.data.doc
+        if self.direct_doc is not None:
+            return self.direct_doc
+        else:
+            return self.data.doc

     @property
     def data(self) -> 'SingleDocumentRequest._DataContent':
diff --git a/tests/integration/dynamic_batching/test_dynamic_batching.py b/tests/integration/dynamic_batching/test_dynamic_batching.py
index 355e771c52fc7..0a9bf57847e8c 100644
--- a/tests/integration/dynamic_batching/test_dynamic_batching.py
+++ b/tests/integration/dynamic_batching/test_dynamic_batching.py
@@ -528,7 +528,7 @@ def _assert_all_docs_processed(port, num_docs, endpoint):
         target=f'0.0.0.0:{port}',
         endpoint=endpoint,
     )
-    docs = resp.data.docs
+    docs = resp.docs
     assert docs.texts == ['long timeout' for _ in range(num_docs)]
diff --git a/tests/integration/inspect_deployments_flow/test_inspect_deployments_flow.py b/tests/integration/inspect_deployments_flow/test_inspect_deployments_flow.py
index 84d6443a7a5c7..d1e422b1f9a8d 100644
--- a/tests/integration/inspect_deployments_flow/test_inspect_deployments_flow.py
+++ b/tests/integration/inspect_deployments_flow/test_inspect_deployments_flow.py
@@ -145,7 +145,7 @@ def test_flow_returned_collect(protocol, port_generator):
     def validate_func(resp):
         num_evaluations = 0
         scores = set()
-        for doc in resp.data.docs:
+        for doc in resp.docs:
             num_evaluations += len(doc.evaluations)
             scores.add(doc.evaluations['evaluate'].value)
         assert num_evaluations == 1
diff --git a/tests/unit/serve/dynamic_batching/test_batch_queue.py b/tests/unit/serve/dynamic_batching/test_batch_queue.py
index 2d0a172ca5a27..bb922ed60d970 100644
--- a/tests/unit/serve/dynamic_batching/test_batch_queue.py
+++ b/tests/unit/serve/dynamic_batching/test_batch_queue.py
@@ -27,7 +27,7 @@ async def foo(docs, **kwargs):
     three_data_requests = [DataRequest() for _ in range(3)]
     for req in three_data_requests:
         req.data.docs = DocumentArray.empty(1)
-        assert req.data.docs[0].text == ''
+        assert req.docs[0].text == ''

     async def process_request(req):
         q = await bq.push(req)
@@ -42,12 +42,12 @@ async def process_request(req):
     assert time_spent >= 2000
     # Test that since no more docs arrived, the function was triggered after timeout
     for resp in responses:
-        assert resp.data.docs[0].text == 'Done'
+        assert resp.docs[0].text == 'Done'

     four_data_requests = [DataRequest() for _ in range(4)]
     for req in four_data_requests:
         req.data.docs = DocumentArray.empty(1)
-        assert req.data.docs[0].text == ''
+        assert req.docs[0].text == ''
     init_time = time.time()
     tasks = [asyncio.create_task(process_request(req)) for req in four_data_requests]
     responses = await asyncio.gather(*tasks)
@@ -55,7 +55,7 @@ async def process_request(req):
     assert time_spent < 2000
     # Test that since no more docs arrived, the function was triggered after timeout
     for resp in responses:
-        assert resp.data.docs[0].text == 'Done'
+        assert resp.docs[0].text == 'Done'

     await bq.close()
@@ -135,7 +135,7 @@ async def foo(docs, **kwargs):
     data_requests = [DataRequest() for _ in range(3)]
     for req in data_requests:
         req.data.docs = DocumentArray.empty(10)  # 30 docs in total
-        assert req.data.docs[0].text == ''
+        assert req.docs[0].text == ''

     async def process_request(req):
         q = await bq.push(req)
@@ -150,7 +150,7 @@ async def process_request(req):
     assert time_spent < 2000
     # Test that since no more docs arrived, the function was triggered after timeout
     for resp in responses:
-        assert resp.data.docs[0].text == 'Done'
+        assert resp.docs[0].text == 'Done'

     await bq.close()
@@ -196,9 +196,9 @@ async def process_request(req):
         assert isinstance(item, Exception)
     for i, req in enumerate(data_requests):
         if i not in BAD_REQUEST_IDX:
-            assert req.data.docs[0].text == f'{i} Processed'
+            assert req.docs[0].text == f'{i} Processed'
         else:
-            assert req.data.docs[0].text == 'Bad'
+            assert req.docs[0].text == 'Bad'


 @pytest.mark.asyncio
@@ -246,11 +246,11 @@ async def process_request(req):
         assert isinstance(item, Exception)
     for i, req in enumerate(data_requests):
         if i not in EXPECTED_BAD_REQUESTS:
-            assert req.data.docs[0].text == 'Processed'
+            assert req.docs[0].text == 'Processed'
         elif i in TRIGGER_BAD_REQUEST_IDX:
-            assert req.data.docs[0].text == 'Bad'
+            assert req.docs[0].text == 'Bad'
         else:
-            assert req.data.docs[0].text == ''
+            assert req.docs[0].text == ''


 @pytest.mark.asyncio
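
Taken together, the patch adds a fast path for HTTP serving: instead of writing docs into the request proto and decoding them again downstream, the FastAPI app and the worker attach docs to the request object as `direct_docs`, and the `docs` property serves them as-is. Below is a minimal runnable sketch of that short-circuit pattern; it is a toy stand-in, not jina's actual `DataRequest`, and every name except `direct_docs` and `docs` is made up for illustration:

from typing import Any, List, Optional


class ToyDataRequest:
    """Toy stand-in for jina's DataRequest; only `direct_docs` and the
    `docs` property mirror the real class, everything else is illustrative."""

    def __init__(self, wire_payload: Optional[bytes] = None) -> None:
        self._wire_payload = wire_payload
        # Fast path: docs attached in-process (e.g. by the HTTP app),
        # bypassing the proto (de)serialization round-trip.
        self.direct_docs: Optional[List[Any]] = None

    @property
    def docs(self) -> List[Any]:
        # Same short-circuit as the new `docs` property in data.py:
        # prefer directly attached docs, otherwise decode the proto payload.
        if self.direct_docs is not None:
            return self.direct_docs
        return self._decode_payload()

    def _decode_payload(self) -> List[Any]:
        # Stand-in for the comparatively expensive proto -> DocumentArray decode.
        return []


req = ToyDataRequest()
req.direct_docs = ['doc-1', 'doc-2']  # what http_fastapi_app.py now does
assert req.docs == ['doc-1', 'doc-2']  # served without touching the proto body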
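
On the client side, usage is unchanged: `Client.post` still accepts `return_type`, which the mixin.py hunk now forwards through `_get_results`, and results are read via `resp.docs`, which prefers `direct_docs` when the server set it. A hedged usage sketch, assuming docarray v2 and a Flow already serving HTTP; the port, endpoint, and `MyDoc` schema are made up for the example:

from docarray import BaseDoc, DocList
from jina import Client


class MyDoc(BaseDoc):
    text: str = ''


client = Client(port=12345, protocol='http')
# `return_type` is forwarded down to _get_results, so the typed DocList is
# reconstructed once and exposed via resp.docs without an extra proto round-trip.
docs = client.post(
    on='/',
    inputs=DocList[MyDoc]([MyDoc(text='hello')]),
    return_type=DocList[MyDoc],
)
print(docs[0].text)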