From b1139bc62336301b439cc6a4f69bfc81fba3e687 Mon Sep 17 00:00:00 2001 From: Joan Fontanals Date: Wed, 25 Sep 2024 12:51:09 +0200 Subject: [PATCH] fix: dyn batching configs (#6204) --- .../serve/runtimes/worker/request_handling.py | 27 +++++++---- .../dynamic_batching/test_dynamic_batching.py | 45 +++++++++++++++++++ 2 files changed, 64 insertions(+), 8 deletions(-) diff --git a/jina/serve/runtimes/worker/request_handling.py b/jina/serve/runtimes/worker/request_handling.py index b6edd7cddc090..a813e60bddb95 100644 --- a/jina/serve/runtimes/worker/request_handling.py +++ b/jina/serve/runtimes/worker/request_handling.py @@ -263,6 +263,9 @@ def _init_batchqueue_dict(self): if getattr(self._executor, 'dynamic_batching', None) is not None: # We need to sort the keys into endpoints and functions # Endpoints allow specific configurations while functions allow configs to be applied to all endpoints of the function + self.logger.debug( + f'Executor Dynamic Batching configs: {self._executor.dynamic_batching}' + ) dbatch_endpoints = [] dbatch_functions = [] request_models_map = self._executor._get_endpoint_models_dict() @@ -275,11 +278,10 @@ def _init_batchqueue_dict(self): ) raise Exception(error_msg) - if dbatch_config.get("use_dynamic_batching", True): - if key.startswith('/'): - dbatch_endpoints.append((key, dbatch_config)) - else: - dbatch_functions.append((key, dbatch_config)) + if key.startswith('/'): + dbatch_endpoints.append((key, dbatch_config)) + else: + dbatch_functions.append((key, dbatch_config)) # Specific endpoint configs take precedence over function configs for endpoint, dbatch_config in dbatch_endpoints: @@ -295,10 +297,19 @@ def _init_batchqueue_dict(self): for endpoint in func_endpoints[func_name]: if endpoint not in self._batchqueue_config: self._batchqueue_config[endpoint] = dbatch_config + else: + # we need to eventually copy the `custom_metric` + if dbatch_config.get('custom_metric', None) is not None: + self._batchqueue_config[endpoint]['custom_metric'] = dbatch_config.get('custom_metric') + + keys_to_remove = [] + for k, batch_config in self._batchqueue_config.items(): + if not batch_config.get('use_dynamic_batching', True): + keys_to_remove.append(k) + + for k in keys_to_remove: + self._batchqueue_config.pop(k) - self.logger.debug( - f'Executor Dynamic Batching configs: {self._executor.dynamic_batching}' - ) self.logger.debug( f'Endpoint Batch Queue Configs: {self._batchqueue_config}' ) diff --git a/tests/integration/dynamic_batching/test_dynamic_batching.py b/tests/integration/dynamic_batching/test_dynamic_batching.py index 8f08d364899a4..0e42785d1b8be 100644 --- a/tests/integration/dynamic_batching/test_dynamic_batching.py +++ b/tests/integration/dynamic_batching/test_dynamic_batching.py @@ -244,6 +244,51 @@ def test_timeout(add_parameters, use_stream): assert time_taken < 2 + TIMEOUT_TOLERANCE, 'Timeout ended too slowly' +@pytest.mark.parametrize( + 'add_parameters', + [ + { + 'uses': PlaceholderExecutorWrongDecorator, + 'uses_dynamic_batching': USES_DYNAMIC_BATCHING_PLACE_HOLDER_EXECUTOR, + } + ], +) +@pytest.mark.parametrize('use_stream', [False, True]) +@pytest.mark.parametrize('use_dynamic_batching', [False, True]) +def test_timeout_no_use(add_parameters, use_stream, use_dynamic_batching): + for k, v in add_parameters["uses_dynamic_batching"].items(): + v["use_dynamic_batching"] = use_dynamic_batching + f = Flow().add(**add_parameters) + with f: + start_time = time.time() + f.post('/bar', inputs=DocumentArray.empty(2), stream=use_stream) + time_taken = time.time() - start_time + if use_dynamic_batching: + assert time_taken > 2, 'Timeout ended too fast' + assert time_taken < 2 + TIMEOUT_TOLERANCE, 'Timeout ended too slowly' + else: + assert time_taken < 2 + + with mp.Pool(3) as p: + start_time = time.time() + list( + p.map( + call_api, + [ + RequestStruct(f.port, '/bar', range(1), use_stream), + RequestStruct(f.port, '/bar', range(1), not use_stream), + RequestStruct(f.port, '/bar', range(1), use_stream), + ], + ) + ) + time_taken = time.time() - start_time + if use_dynamic_batching: + assert time_taken > 2, 'Timeout ended too fast' + assert time_taken < 2 + TIMEOUT_TOLERANCE, 'Timeout ended too slowly' + else: + assert time_taken < 2 + + @pytest.mark.parametrize( 'add_parameters', [