diff --git a/client/clip_client/client.py b/client/clip_client/client.py index 108291e27..575b6645f 100644 --- a/client/clip_client/client.py +++ b/client/clip_client/client.py @@ -1,5 +1,7 @@ import mimetypes +import os import time +import warnings from typing import ( overload, TYPE_CHECKING, @@ -12,9 +14,10 @@ ) from urllib.parse import urlparse + if TYPE_CHECKING: - from docarray import DocumentArray, Document import numpy as np + from docarray import DocumentArray, Document class Client: @@ -103,23 +106,50 @@ def encode(self, content, **kwargs): f'content must be an Iterable of [str, Document], try `.encode(["{content}"])` instead' ) - r = self._client.post(**self._get_post_payload(content, kwargs)) - return self._pack_result(r) + self._prepare_streaming( + not kwargs.get('show_progress'), + total=len(content) if hasattr(content, '__len__') else None, + ) + with self._pbar: + self._client.post( + **self._get_post_payload(content, kwargs), on_done=self._gather_result + ) + return self._unboxed_result + + def _gather_result(self, r): + from rich import filesize + + if not self._results: + self._pbar.start_task(self._r_task) + r = r.data.docs + self._results.extend(r) + self._pbar.update( + self._r_task, + advance=len(r), + total_size=str( + filesize.decimal(int(os.environ.get('JINA_GRPC_RECV_BYTES', '0'))) + ), + ) - def _pack_result(self, r): - if r.embeddings is None: + @property + def _unboxed_result(self): + if self._results.embeddings is None: raise ValueError( 'empty embedding returned from the server. ' 'This often due to a mis-config of the server, ' 'restarting the server or changing the serving port number often solves the problem' ) - return r.embeddings if self._return_plain else r + return self._results.embeddings if self._return_plain else self._results def _iter_doc(self, content) -> Generator['Document', None, None]: + from rich import filesize from docarray import Document self._return_plain = True + if hasattr(self, '_pbar'): + self._pbar.start_task(self._s_task) + for c in content: if isinstance(c, str): self._return_plain = True @@ -141,11 +171,21 @@ def _iter_doc(self, content) -> Generator['Document', None, None]: else: raise TypeError(f'unsupported input type {c!r}') + if hasattr(self, '_pbar'): + self._pbar.update( + self._s_task, + advance=1, + total_size=str( + filesize.decimal( + int(os.environ.get('JINA_GRPC_SEND_BYTES', '0')) + ) + ), + ) + def _get_post_payload(self, content, kwargs): return dict( on='/', inputs=self._iter_doc(content), - show_progress=kwargs.get('show_progress'), request_size=kwargs.get('batch_size', 8), total_docs=len(content) if hasattr(content, '__len__') else None, ) @@ -224,12 +264,72 @@ async def aencode( ... async def aencode(self, content, **kwargs): - from docarray import DocumentArray + from rich import filesize + + self._prepare_streaming( + not kwargs.get('show_progress'), + total=len(content) if hasattr(content, '__len__') else None, + ) - r = DocumentArray() async for da in self._async_client.post( **self._get_post_payload(content, kwargs) ): - r.extend(da) + if not self._results: + self._pbar.start_task(self._r_task) + self._results.extend(da) + self._pbar.update( + self._r_task, + advance=len(da), + total_size=str( + filesize.decimal(int(os.environ.get('JINA_GRPC_RECV_BYTES', '0'))) + ), + ) + + return self._unboxed_result + + def _prepare_streaming(self, disable, total): + + if total is None: + total = 500 + warnings.warn( + 'the length of the input is unknown, the progressbar would not be accurate.' + ) + + from rich.progress import ( + Progress, + BarColumn, + SpinnerColumn, + MofNCompleteColumn, + TextColumn, + TimeRemainingColumn, + ) + + self._pbar = Progress( + SpinnerColumn(), + TextColumn('[bold]{task.description}'), + BarColumn(), + MofNCompleteColumn(), + '•', + TimeRemainingColumn(), + '•', + TextColumn( + '[bold blue]{task.fields[total_size]}', + justify='right', + style='progress.filesize', + ), + transient=True, + disable=disable, + ) + os.environ['JINA_GRPC_SEND_BYTES'] = '0' + os.environ['JINA_GRPC_RECV_BYTES'] = '0' + + self._s_task = self._pbar.add_task( + ':arrow_up: Send', total=total, total_size=0, start=False + ) + self._r_task = self._pbar.add_task( + ':arrow_down: Recv', total=total, total_size=0, start=False + ) + + from docarray import DocumentArray - return self._pack_result(r) + self._results = DocumentArray() diff --git a/client/setup.py b/client/setup.py index 5b395d700..0133370e8 100644 --- a/client/setup.py +++ b/client/setup.py @@ -41,7 +41,7 @@ long_description_content_type='text/markdown', zip_safe=False, setup_requires=['setuptools>=18.0', 'wheel'], - install_requires=['jina', 'docarray[common]>=0.9.18'], + install_requires=['jina>=3.2.10', 'docarray[common]>=0.9.18'], extras_require={ 'test': [ 'pytest', diff --git a/server/setup.py b/server/setup.py index 75bb72f61..afd5a317a 100644 --- a/server/setup.py +++ b/server/setup.py @@ -41,7 +41,7 @@ long_description_content_type='text/markdown', zip_safe=False, setup_requires=['setuptools>=18.0', 'wheel'], - install_requires=['ftfy', 'torch', 'regex', 'torchvision', 'jina'], + install_requires=['ftfy', 'torch', 'regex', 'torchvision', 'jina>=3.2.10'], extras_require={ 'onnx': ['onnxruntime', 'onnx', 'onnxruntime-gpu'], },