Skip to content

Commit

Permalink
fix: in-place for index and search
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiniuYu committed Sep 13, 2022
1 parent 9825200 commit 69ef7b4
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
20 changes: 13 additions & 7 deletions client/clip_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,13 @@ def _prepare_streaming(self, disable, total):
':arrow_down: Recv', total=total, total_size=0, start=False
)

def _gather_result(self, response, results: 'DocumentArray', attribute: str = ''):
def _gather_result(
self, response, results: 'DocumentArray', attribute: Optional[str] = ''
):
from rich import filesize

r = response.data.docs
if not attribute:
if attribute:
results[r[:, 'id']][:, attribute] = r[:, attribute]

if not self._pbar._tasks[self._r_task].started:
Expand Down Expand Up @@ -580,7 +582,9 @@ def index(self, content, **kwargs):
self._client.post(
on='/index',
**self._get_post_payload(content, results, kwargs),
on_done=partial(self._gather_result, results=results),
on_done=partial(
self._gather_result, results=results, attribute='embedding'
),
parameters=parameters,
)

Expand Down Expand Up @@ -626,9 +630,10 @@ async def aindex(self, content, **kwargs):
**self._get_post_payload(content, results, kwargs),
parameters=kwargs.pop('parameters', None),
):
if not results:
results[da[:, 'id']].embeddings = da.embeddings

if not self._pbar._tasks[self._r_task].started:
self._pbar.start_task(self._r_task)
results.extend(da)
self._pbar.update(
self._r_task,
advance=len(da),
Expand Down Expand Up @@ -763,9 +768,10 @@ async def asearch(self, content, **kwargs):
**self._get_post_payload(content, results, kwargs),
parameters=parameters,
):
if not results:
results[da[:, 'id']][:, 'matches'] = da[:, 'matches']

if not self._pbar._tasks[self._r_task].started:
self._pbar.start_task(self._r_task)
results.extend(da)
self._pbar.update(
self._r_task,
advance=len(da),
Expand Down
4 changes: 2 additions & 2 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@ def generate_docs(tag):
assert len(set([d.id[:2] for d in r])) == 1


def test_client_large_input(make_flow):
def test_client_large_input(make_torch_flow):
from clip_client.client import Client

inputs = ['hello' for _ in range(600)]

c = Client(server=f'grpc://0.0.0.0:{make_flow.port}')
c = Client(server=f'grpc://0.0.0.0:{make_torch_flow.port}')
with pytest.warns(UserWarning):
c.encode(inputs if not callable(inputs) else inputs())

Expand Down

0 comments on commit 69ef7b4

Please sign in to comment.