diff --git a/client/clip_client/client.py b/client/clip_client/client.py index cafa7b4fb..52eba3abb 100644 --- a/client/clip_client/client.py +++ b/client/clip_client/client.py @@ -144,11 +144,13 @@ def _prepare_streaming(self, disable, total): ':arrow_down: Recv', total=total, total_size=0, start=False ) - def _gather_result(self, response, results: 'DocumentArray', attribute: str = ''): + def _gather_result( + self, response, results: 'DocumentArray', attribute: Optional[str] = '' + ): from rich import filesize r = response.data.docs - if not attribute: + if attribute: results[r[:, 'id']][:, attribute] = r[:, attribute] if not self._pbar._tasks[self._r_task].started: @@ -580,7 +582,9 @@ def index(self, content, **kwargs): self._client.post( on='/index', **self._get_post_payload(content, results, kwargs), - on_done=partial(self._gather_result, results=results), + on_done=partial( + self._gather_result, results=results, attribute='embedding' + ), parameters=parameters, ) @@ -626,9 +630,10 @@ async def aindex(self, content, **kwargs): **self._get_post_payload(content, results, kwargs), parameters=kwargs.pop('parameters', None), ): - if not results: + results[da[:, 'id']].embeddings = da.embeddings + + if not self._pbar._tasks[self._r_task].started: self._pbar.start_task(self._r_task) - results.extend(da) self._pbar.update( self._r_task, advance=len(da), @@ -763,9 +768,10 @@ async def asearch(self, content, **kwargs): **self._get_post_payload(content, results, kwargs), parameters=parameters, ): - if not results: + results[da[:, 'id']][:, 'matches'] = da[:, 'matches'] + + if not self._pbar._tasks[self._r_task].started: self._pbar.start_task(self._r_task) - results.extend(da) self._pbar.update( self._r_task, advance=len(da), diff --git a/tests/test_client.py b/tests/test_client.py index 13ea24420..bb5f640d3 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -57,12 +57,12 @@ def generate_docs(tag): assert len(set([d.id[:2] for d in r])) == 1 -def test_client_large_input(make_flow): +def test_client_large_input(make_torch_flow): from clip_client.client import Client inputs = ['hello' for _ in range(600)] - c = Client(server=f'grpc://0.0.0.0:{make_flow.port}') + c = Client(server=f'grpc://0.0.0.0:{make_torch_flow.port}') with pytest.warns(UserWarning): c.encode(inputs if not callable(inputs) else inputs())