Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: drop image content to boost latency #824

Merged
merged 9 commits into from
Sep 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 49 additions & 53 deletions client/clip_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,17 +176,15 @@ def _iter_doc(
_mime = mimetypes.guess_type(c)[0]
if _mime and _mime.startswith('image'):
d = Document(
tags={'__created_by_CAS__': True, '__loaded_by_CAS__': True},
uri=c,
).load_uri_to_blob()
else:
d = Document(tags={'__created_by_CAS__': True}, text=c)
d = Document(text=c)
elif isinstance(c, Document):
if c.content_type in ('text', 'blob'):
d = c
elif not c.blob and c.uri:
c.load_uri_to_blob()
c.tags['__loaded_by_CAS__'] = True
d = c
elif c.tensor is not None:
d = c
Expand Down Expand Up @@ -288,8 +286,12 @@ def encode(self, content, **kwargs):

results = DocumentArray()
with self._pbar:
parameters = kwargs.pop('parameters', None)
parameters = kwargs.pop('parameters', {})
parameters['drop_image_content'] = parameters.get(
'drop_image_content', True
)
model_name = parameters.pop('model_name', '') if parameters else ''

self._client.post(
on=f'/encode/{model_name}'.rstrip('/'),
**self._get_post_payload(content, results, kwargs),
Expand All @@ -299,10 +301,6 @@ def encode(self, content, **kwargs):
parameters=parameters,
)

for r in results:
if hasattr(r, 'tags') and r.tags.pop('__loaded_by_CAS__', False):
r.pop('blob')

unbox = hasattr(content, '__len__') and isinstance(content[0], str)
return self._unboxed_result(results, unbox)

Expand Down Expand Up @@ -345,7 +343,10 @@ async def aencode(self, content, **kwargs):

results = DocumentArray()
with self._pbar:
parameters = kwargs.pop('parameters', None)
parameters = kwargs.pop('parameters', {})
parameters['drop_image_content'] = parameters.get(
'drop_image_content', True
)
model_name = parameters.get('model_name', '') if parameters else ''

async for da in self._async_client.post(
Expand All @@ -367,10 +368,6 @@ async def aencode(self, content, **kwargs):
),
)

for r in results:
if hasattr(r, 'tags') and r.tags.pop('__loaded_by_CAS__', False):
r.pop('blob')

unbox = hasattr(content, '__len__') and isinstance(content[0], str)
return self._unboxed_result(results, unbox)

Expand Down Expand Up @@ -423,7 +420,6 @@ def _prepare_single_doc(d: 'Document'):
return d
elif not d.blob and d.uri:
d.load_uri_to_blob()
d.tags['__loaded_by_CAS__'] = True
return d
elif d.tensor is not None:
return d
Expand All @@ -439,18 +435,6 @@ def _prepare_rank_doc(d: 'Document', _source: str = 'matches'):
setattr(d, _source, [Client._prepare_single_doc(c) for c in _get(d)])
return d

@staticmethod
def _reset_rank_doc(d: 'Document', _source: str = 'matches'):
_get = lambda d: getattr(d, _source)

if d.tags.pop('__loaded_by_CAS__', False):
d.pop('blob')

for c in _get(d):
if c.tags.pop('__loaded_by_CAS__', False):
c.pop('blob')
return d

def rank(
self, docs: Union['DocumentArray', Iterable['Document']], **kwargs
) -> 'DocumentArray':
Expand All @@ -474,8 +458,12 @@ def rank(

results = DocumentArray()
with self._pbar:
parameters = kwargs.pop('parameters', None)
parameters = kwargs.pop('parameters', {})
parameters['drop_image_content'] = parameters.get(
'drop_image_content', True
)
model_name = parameters.get('model_name', '') if parameters else ''

self._client.post(
on=f'/rank/{model_name}'.rstrip('/'),
**self._get_rank_payload(docs, results, kwargs),
Expand All @@ -485,9 +473,6 @@ def rank(
parameters=parameters,
)

for r in results:
self._reset_rank_doc(r, _source=kwargs.get('source', 'matches'))

return results

async def arank(
Expand All @@ -507,8 +492,12 @@ async def arank(

results = DocumentArray()
with self._pbar:
parameters = kwargs.pop('parameters', None)
parameters = kwargs.pop('parameters', {})
parameters['drop_image_content'] = parameters.get(
'drop_image_content', True
)
model_name = parameters.get('model_name', '') if parameters else ''

async for da in self._async_client.post(
on=f'/rank/{model_name}'.rstrip('/'),
**self._get_rank_payload(docs, results, kwargs),
Expand All @@ -528,9 +517,6 @@ async def arank(
),
)

for r in results:
self._reset_rank_doc(r, _source=kwargs.get('source', 'matches'))

return results

@overload
Expand Down Expand Up @@ -581,14 +567,21 @@ def index(self, content, **kwargs):
raise TypeError(
f'content must be an Iterable of [str, Document], try `.index(["{content}"])` instead'
)
if hasattr(content, '__len__') and len(content) == 0:
return DocumentArray()

self._prepare_streaming(
not kwargs.get('show_progress'),
total=len(content) if hasattr(content, '__len__') else None,
)

results = DocumentArray()
with self._pbar:
parameters = kwargs.pop('parameters', None)
parameters = kwargs.pop('parameters', {})
parameters['drop_image_content'] = parameters.get(
'drop_image_content', True
)

self._client.post(
on='/index',
**self._get_post_payload(content, results, kwargs),
Expand All @@ -598,10 +591,6 @@ def index(self, content, **kwargs):
parameters=parameters,
)

for r in results:
if hasattr(r, 'tags') and r.tags.pop('__loaded_by_CAS__', False):
r.pop('blob')

return results

@overload
Expand Down Expand Up @@ -633,17 +622,25 @@ async def aindex(self, content, **kwargs):
raise TypeError(
f'content must be an Iterable of [str, Document], try `.aindex(["{content}"])` instead'
)
if hasattr(content, '__len__') and len(content) == 0:
return DocumentArray()

self._prepare_streaming(
not kwargs.get('show_progress'),
total=len(content) if hasattr(content, '__len__') else None,
)

results = DocumentArray()
with self._pbar:
parameters = kwargs.pop('parameters', {})
parameters['drop_image_content'] = parameters.get(
'drop_image_content', True
)

async for da in self._async_client.post(
on='/index',
**self._get_post_payload(content, results, kwargs),
parameters=kwargs.pop('parameters', None),
parameters=parameters,
):
results[da[:, 'id']].embeddings = da.embeddings

Expand All @@ -659,10 +656,6 @@ async def aindex(self, content, **kwargs):
),
)

for r in results:
if hasattr(r, 'tags') and r.tags.pop('__loaded_by_CAS__', False):
r.pop('blob')

return results

@overload
Expand Down Expand Up @@ -716,15 +709,21 @@ def search(self, content, limit: int = 10, **kwargs) -> 'DocumentArray':
raise TypeError(
f'content must be an Iterable of [str, Document], try `.search(["{content}"])` instead'
)
if hasattr(content, '__len__') and len(content) == 0:
return DocumentArray()

self._prepare_streaming(
not kwargs.get('show_progress'),
total=len(content) if hasattr(content, '__len__') else None,
)

results = DocumentArray()
with self._pbar:
parameters = kwargs.pop('parameters', {})
parameters['limit'] = limit
parameters['drop_image_content'] = parameters.get(
'drop_image_content', True
)

self._client.post(
on='/search',
Expand All @@ -735,10 +734,6 @@ def search(self, content, limit: int = 10, **kwargs) -> 'DocumentArray':
),
)

for r in results:
if hasattr(r, 'tags') and r.tags.pop('__loaded_by_CAS__', False):
r.pop('blob')

return results

@overload
Expand Down Expand Up @@ -772,16 +767,21 @@ async def asearch(self, content, limit: int = 10, **kwargs):
raise TypeError(
f'content must be an Iterable of [str, Document], try `.asearch(["{content}"])` instead'
)
if hasattr(content, '__len__') and len(content) == 0:
return DocumentArray()

self._prepare_streaming(
not kwargs.get('show_progress'),
total=len(content) if hasattr(content, '__len__') else None,
)
results = DocumentArray()

results = DocumentArray()
with self._pbar:
parameters = kwargs.pop('parameters', {})
parameters['limit'] = limit
parameters['drop_image_content'] = parameters.get(
'drop_image_content', True
)

async for da in self._async_client.post(
on='/search',
Expand All @@ -802,8 +802,4 @@ async def asearch(self, content, limit: int = 10, **kwargs):
),
)

for r in results:
if hasattr(r, 'tags') and r.tags.pop('__loaded_by_CAS__', False):
r.pop('blob')

return results
14 changes: 10 additions & 4 deletions server/clip_server/executors/clip_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import warnings
from multiprocessing.pool import ThreadPool
from typing import Optional, Dict
from functools import partial

import onnxruntime as ort
from clip_server.executors.helper import (
Expand Down Expand Up @@ -99,13 +100,16 @@ def __init__(

self._model.start_sessions(sess_options=sess_options, providers=providers)

def _preproc_images(self, docs: 'DocumentArray'):
def _preproc_images(self, docs: 'DocumentArray', drop_image_content: bool):
with self.monitor(
name='preprocess_images_seconds',
documentation='images preprocess time in seconds',
):
return preproc_image(
docs, preprocess_fn=self._image_transform, return_np=True
docs,
preprocess_fn=self._image_transform,
return_np=True,
drop_image_content=drop_image_content,
)

def _preproc_texts(self, docs: 'DocumentArray'):
Expand All @@ -117,7 +121,8 @@ def _preproc_texts(self, docs: 'DocumentArray'):

@requests(on='/rank')
async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
await self.encode(docs['@r,m'])
_drop_image_content = parameters.get('drop_image_content', False)
await self.encode(docs['@r,m'], drop_image_content=_drop_image_content)

set_rank(docs)

Expand All @@ -129,6 +134,7 @@ async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs):
f'`traversal_paths` is deprecated. Use `access_paths` instead.'
)
access_paths = parameters['traversal_paths']
_drop_image_content = parameters.get('drop_image_content', False)

_img_da = DocumentArray()
_txt_da = DocumentArray()
Expand All @@ -138,7 +144,7 @@ async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs):
# for image
if _img_da:
for minibatch, batch_data in _img_da.map_batch(
self._preproc_images,
partial(self._preproc_images, drop_image_content=_drop_image_content),
batch_size=self._minibatch_size,
pool=self._pool,
):
Expand Down
10 changes: 7 additions & 3 deletions server/clip_server/executors/clip_tensorrt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings
from multiprocessing.pool import ThreadPool
from typing import Optional, Dict
from functools import partial

import numpy as np
from clip_server.executors.helper import (
Expand Down Expand Up @@ -67,7 +68,7 @@ def __init__(
self._tokenizer = Tokenizer(name)
self._image_transform = clip._transform_ndarray(self._model.image_size)

def _preproc_images(self, docs: 'DocumentArray'):
def _preproc_images(self, docs: 'DocumentArray', drop_image_content: bool):
with self.monitor(
name='preprocess_images_seconds',
documentation='images preprocess time in seconds',
Expand All @@ -77,6 +78,7 @@ def _preproc_images(self, docs: 'DocumentArray'):
preprocess_fn=self._image_transform,
device=self._device,
return_np=False,
drop_image_content=drop_image_content,
)

def _preproc_texts(self, docs: 'DocumentArray'):
Expand All @@ -90,7 +92,8 @@ def _preproc_texts(self, docs: 'DocumentArray'):

@requests(on='/rank')
async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
await self.encode(docs['@r,m'])
_drop_image_content = parameters.get('drop_image_content', False)
await self.encode(docs['@r,m'], drop_image_content=_drop_image_content)

set_rank(docs)

Expand All @@ -102,6 +105,7 @@ async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs):
f'`traversal_paths` is deprecated. Use `access_paths` instead.'
)
access_paths = parameters['traversal_paths']
_drop_image_content = parameters.get('drop_image_content', False)

_img_da = DocumentArray()
_txt_da = DocumentArray()
Expand All @@ -111,7 +115,7 @@ async def encode(self, docs: 'DocumentArray', parameters: Dict = {}, **kwargs):
# for image
if _img_da:
for minibatch, batch_data in _img_da.map_batch(
self._preproc_images,
partial(self._preproc_images, drop_image_content=_drop_image_content),
batch_size=self._minibatch_size,
pool=self._pool,
):
Expand Down
Loading