Skip to content

Commit

Permalink
Add tensor_fields parameter to add_documents (#123)
Browse files Browse the repository at this point in the history
add_documents will now accept tensor_fields. When this is provided, all fields not in tensor_fields are treated as non-tensor fields. non_tensor_fields is now deprecated; if it is used, py-marqo will warn the user.
  • Loading branch information
farshidz authored Jul 17, 2023
1 parent a7836c2 commit 02aefeb
Show file tree
Hide file tree
Showing 18 changed files with 157 additions and 116 deletions.
34 changes: 17 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,17 +107,19 @@ import marqo
mq = marqo.Client(url='http://localhost:8882')

mq.create_index("my-first-index")
mq.index("my-first-index").add_documents([
{
"Title": "The Travels of Marco Polo",
"Description": "A 13th-century travelogue describing Polo's travels"
},
{
"Title": "Extravehicular Mobility Unit (EMU)",
"Description": "The EMU is a spacesuit that provides environmental protection, "
"mobility, life support, and communications for astronauts",
"_id": "article_591"
}]
mq.index("my-first-index").add_documents(
[
{
"Title": "The Travels of Marco Polo",
"Description": "A 13th-century travelogue describing Polo's travels"},
{
"Title": "Extravehicular Mobility Unit (EMU)",
"Description": "The EMU is a spacesuit that provides environmental protection, "
"mobility, life support, and communications for astronauts",
"_id": "article_591"
}
],
tensor_fields=["Title", "Description"]
)

results = mq.index("my-first-index").search(
Expand Down Expand Up @@ -278,8 +280,7 @@ import pprint
mq = marqo.Client(url="http://localhost:8882")

mq.create_index("my-weighted-query-index")
mq.index("my-weighted-query-index").add_documents(
[
mq.index("my-weighted-query-index").add_documents([
{
"Title": "Smartphone",
"Description": "A smartphone is a portable computer device that combines mobile telephone "
Expand All @@ -296,7 +297,8 @@ mq.index("my-weighted-query-index").add_documents(
"is an extinct carnivorous marsupial."
"The last known of its species died in 1936.",
},
]
],
tensor_fields=["Title", "Description"]
)

# initially we ask for a type of communications device which is popular in the 21st century
Expand Down Expand Up @@ -371,9 +373,6 @@ mq.index("my-first-multimodal-index").add_documents(
},
},
],
# Create the mappings, here we define our captioned_image mapping
# which weights the image more heavily than the caption - these pairs
# will be represented by a single vector in the index
mappings={
"captioned_image": {
"type": "multimodal_combination",
Expand All @@ -383,6 +382,7 @@ mq.index("my-first-multimodal-index").add_documents(
},
}
},
tensor_fields=["captioned_image"],
)

# Search this index with a simple text query
Expand Down
44 changes: 33 additions & 11 deletions src/marqo/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ def search(self, q: Union[str, dict], searchable_attributes: Optional[List[str]]

start_time_client_request = timer()
if highlights is not None:
logging.warning("Deprecation warning for parameter 'highlights'. "
"Please use the 'showHighlights' instead. ")
mq_logger.warning("Deprecation warning for parameter 'highlights'. "
"Please use the 'showHighlights' instead. ")
show_highlights = highlights if show_highlights is True else show_highlights

path_with_query_str = (
Expand Down Expand Up @@ -268,6 +268,7 @@ def add_documents(
auto_refresh: bool = True,
client_batch_size: int = None,
device: str = None,
tensor_fields: List[str] = None,
non_tensor_fields: List[str] = None,
use_existing_tensors: bool = False,
image_download_headers: dict = None,
Expand All @@ -287,7 +288,11 @@ def add_documents(
client-side.
device: the device used to index the data. Examples include "cpu",
"cuda" and "cuda:2"
non_tensor_fields: fields within documents to not create and store tensors against.
tensor_fields: fields within documents to create and store tensors against.
non_tensor_fields: fields within documents to not create and store tensors against. Cannot be used with
tensor_fields.
.. deprecated:: 2.0.0
                This parameter has been deprecated and will be removed in Marqo 2.0.0. Use tensor_fields instead.
use_existing_tensors: use vectors that already exist in the docs.
image_download_headers: a dictionary of headers to be passed while downloading images,
for URLs found in documents
Expand All @@ -296,15 +301,26 @@ def add_documents(
Returns:
Response body outlining indexing result
"""
if non_tensor_fields is None:
non_tensor_fields = []
if tensor_fields is not None and non_tensor_fields is not None:
raise errors.InvalidArgError('Cannot define `non_tensor_fields` when `tensor_fields` is defined. '
'`non_tensor_fields` has been deprecated and will be removed in Marqo 2.0.0. '
'Its use is discouraged.')

if tensor_fields is None and non_tensor_fields is None:
raise errors.InvalidArgError('You must include the `tensor_fields` parameter. '
'Use `tensor_fields=[]` to index for lexical-only search.')

if non_tensor_fields is not None:
mq_logger.warning('The `non_tensor_fields` parameter has been deprecated and will be removed in '
'Marqo 2.0.0. Use `tensor_fields` instead.')

if image_download_headers is None:
image_download_headers = dict()
return self._add_docs_organiser(
documents=documents, auto_refresh=auto_refresh,
client_batch_size=client_batch_size, device=device, non_tensor_fields=non_tensor_fields,
use_existing_tensors=use_existing_tensors, image_download_headers=image_download_headers, mappings=mappings,
model_auth=model_auth
client_batch_size=client_batch_size, device=device, tensor_fields=tensor_fields,
non_tensor_fields=non_tensor_fields, use_existing_tensors=use_existing_tensors,
image_download_headers=image_download_headers, mappings=mappings, model_auth=model_auth
)

def _add_docs_organiser(
Expand All @@ -313,17 +329,20 @@ def _add_docs_organiser(
auto_refresh=True,
client_batch_size: int = None,
device: str = None,
tensor_fields: List = None,
non_tensor_fields: List = None,
use_existing_tensors: bool = False,
image_download_headers: dict = None,
mappings: dict = None,
model_auth: dict = None
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:

if (tensor_fields is None and non_tensor_fields is None) \
or (tensor_fields is not None and non_tensor_fields is not None):
raise ValueError("Exactly one of tensor_fields or non_tensor_fields must be provided.")

error_detected_message = ('Errors detected in add documents call. '
'Please examine the returned result object for more information.')
if non_tensor_fields is None:
non_tensor_fields = []

num_docs = len(documents)

Expand All @@ -336,12 +355,15 @@ def _add_docs_organiser(
)

base_body = {
"nonTensorFields" : non_tensor_fields,
"useExistingTensors" : use_existing_tensors,
"imageDownloadHeaders" : image_download_headers,
"mappings" : mappings,
"modelAuth": model_auth,
}
if tensor_fields is not None:
base_body['tensorFields'] = tensor_fields
else:
base_body['nonTensorFields'] = non_tensor_fields

end_time_client_process = timer()
total_client_process_time = end_time_client_process - start_time_client_process
Expand Down
2 changes: 1 addition & 1 deletion src/marqo/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__marqo_version__ = "0.1.0"
__marqo_version__ = "1.0.0"
__marqo_release_page__ = f"https://github.com/marqo-ai/marqo/releases/tag/{__marqo_version__}"

__minimum_supported_marqo_version__ = "0.1.0"
Expand Down
2 changes: 1 addition & 1 deletion tests/marqo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def wrapper(self, *args, **kwargs):
if len(docs) == 0:
continue
self.client.create_index(index_name=index_name)
self.client.index(index_name).add_documents(docs)
self.client.index(index_name).add_documents(docs, non_tensor_fields=[])
if self.IS_MULTI_INSTANCE:
self.warm_request(self.client.bulk_search, [{
"index": index_name,
Expand Down
49 changes: 29 additions & 20 deletions tests/v0_tests/test_add_documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def test_add_documents_with_ids(self):
}
res = self.client.index(self.index_name_1).add_documents([
d1, d2
])
], tensor_fields=["field X", "field 1", "doc title"])
retrieved_d1 = self.client.index(self.index_name_1).get_document(
document_id="e197e580-0393-4f4e-90e9-8cdf4b17e339")
assert retrieved_d1 == d1
Expand All @@ -115,7 +115,7 @@ def test_add_documents(self):
"doc title": "Just Your Average Doc",
"field X": "this is a solid doc"
}
res = self.client.index(self.index_name_1).add_documents([d1, d2])
res = self.client.index(self.index_name_1).add_documents([d1, d2], tensor_fields=["field X", "field 1", "doc title"])
ids = [item["_id"] for item in res["items"]]
assert len(ids) == 2
assert ids[0] != ids[1]
Expand All @@ -133,14 +133,14 @@ def test_add_documents_with_ids_twice(self):
"field X": "this is a solid doc",
"_id": "56"
}
self.client.index(self.index_name_1).add_documents([d1])
self.client.index(self.index_name_1).add_documents([d1], tensor_fields=["field X", "doc title"])
assert d1 == self.client.index(self.index_name_1).get_document("56")
d2 = {
"_id": "56",
"completely": "different doc.",
"field X": "this is a solid doc"
}
self.client.index(self.index_name_1).add_documents([d2])
self.client.index(self.index_name_1).add_documents([d2], tensor_fields=["field X", "completely"])
assert d2 == self.client.index(self.index_name_1).get_document("56")

def test_add_batched_documents(self):
Expand All @@ -158,7 +158,7 @@ def test_add_batched_documents(self):
"_id": doc_id}
for doc_id in doc_ids]
assert len(docs) == 100
ix.add_documents(docs, client_batch_size=4)
ix.add_documents(docs, client_batch_size=4, tensor_fields=["Title", "Generic text"])
ix.refresh()
# takes too long to search for all...
for _id in [0, 19, 20, 99]:
Expand All @@ -180,7 +180,7 @@ def test_delete_docs(self):
self.client.index(self.index_name_1).add_documents([
{"abc": "wow camel", "_id": "123"},
{"abc": "camels are cool", "_id": "foo"}
])
], tensor_fields=["abc"])

if self.IS_MULTI_INSTANCE:
self.warm_request(self.client.index(self.index_name_1).search, "wow camel")
Expand All @@ -200,7 +200,7 @@ def test_delete_docs(self):

def test_delete_docs_empty_ids(self):
self.client.create_index(index_name=self.index_name_1)
self.client.index(self.index_name_1).add_documents([{"abc": "efg", "_id": "123"}])
self.client.index(self.index_name_1).add_documents([{"abc": "efg", "_id": "123"}], tensor_fields=["abc"])
try:
self.client.index(self.index_name_1).delete_documents([])
raise AssertionError
Expand All @@ -213,13 +213,13 @@ def test_delete_docs_empty_ids(self):
def test_get_document(self):
my_doc = {"abc": "efg", "_id": "123"}
self.client.create_index(index_name=self.index_name_1)
self.client.index(self.index_name_1).add_documents([my_doc])
self.client.index(self.index_name_1).add_documents([my_doc], tensor_fields=["abc"])
retrieved = self.client.index(self.index_name_1).get_document(document_id='123')
assert retrieved == my_doc

def test_add_documents_missing_index_fails(self):
with pytest.raises(MarqoWebError) as ex:
self.client.index(self.index_name_1).add_documents([{"abd": "efg"}])
self.client.index(self.index_name_1).add_documents([{"abd": "efg"}], tensor_fields=["abc"])

assert "index_not_found" == ex.value.code

Expand All @@ -232,7 +232,7 @@ def test_add_documents_with_device(self):
def run():
temp_client.index(self.index_name_1).add_documents(documents=[
{"d1": "blah"}, {"d2", "some data"}
], device="cuda:45")
], device="cuda:45", tensor_fields=["d1", "d2"])
return True

assert run()
Expand All @@ -249,7 +249,7 @@ def test_add_documents_with_device_batching(self):
def run():
temp_client.index(self.index_name_1).add_documents(documents=[
{"d1": "blah"}, {"d2", "some data"}, {"d2331": "blah"}, {"45d2", "some data"}
], client_batch_size=2, device="cuda:37")
], client_batch_size=2, device="cuda:37", tensor_fields=["d1", "d2", "d2331", "45d2"])
return True

assert run()
Expand All @@ -267,7 +267,7 @@ def test_add_documents_device_not_set(self):
def run():
temp_client.index(self.index_name_1).add_documents(documents=[
{"d1": "blah"}, {"d2", "some data"}
])
], tensor_fields=["d1", "d2"])
return True

assert run()
Expand All @@ -286,10 +286,10 @@ def test_add_documents_set_refresh(self):
def run():
temp_client.index(self.index_name_1).add_documents(documents=[
{"d1": "blah"}, {"d2", "some data"}
], auto_refresh=False)
], auto_refresh=False, tensor_fields=["d1", "d2"])
temp_client.index(self.index_name_1).add_documents(documents=[
{"d1": "blah"}, {"d2", "some data"}
], auto_refresh=True)
], auto_refresh=True, tensor_fields=["d1", "d2"])
return True

assert run()
Expand All @@ -306,7 +306,7 @@ def test_add_documents_with_no_processes(self):
def run():
self.client.index(self.index_name_1).add_documents(documents=[
{"d1": "blah"}, {"d2", "some data"}
])
], tensor_fields=["d1", "d2"])
return True

assert run()
Expand All @@ -324,7 +324,7 @@ def test_resilient_indexing(self):
d1 = {"d1": "blah", "_id": "1234"}
d2 = {"d2": "blah", "_id": "5678"}
docs = [d1, {"content": "some terrible doc", "d3": "blah", "_id": 12345}, d2]
self.client.index(self.index_name_1).add_documents(documents=docs)
self.client.index(self.index_name_1).add_documents(documents=docs, tensor_fields=["d1", "d2", "d3", "content"])

if self.IS_MULTI_INSTANCE:
time.sleep(1)
Expand Down Expand Up @@ -356,7 +356,8 @@ def test_batching_add_docs(self):
@mock.patch("marqo._httprequests.HttpRequests.post", mock__post)
def run():
res = self.client.index(self.index_name_1).add_documents(
auto_refresh=auto_refresh, documents=docs, client_batch_size=client_batch_size, )
auto_refresh=auto_refresh, documents=docs, client_batch_size=client_batch_size,
tensor_fields=["Title", "Description"])
if client_batch_size is not None:
assert isinstance(res, list)
assert len(res) == math.ceil(docs_to_add / client_batch_size)
Expand Down Expand Up @@ -431,7 +432,7 @@ def test_use_existing_fields(self):
"title 1": "content 1",
"desc 2": "content 2. blah blah blah",
"new f": "12345 "
}], use_existing_tensors=True
}], use_existing_tensors=True, tensor_fields=["desc 2", "new f", "title 1"]
)
# we don't get desc 2 facets, because it was already a non_tensor_field
assert {"title 1", "_embedding", "new f"} == functools.reduce(
Expand Down Expand Up @@ -478,7 +479,7 @@ def test_multimodal_combination_doc(self):
"type": "multimodal_combination", "weights": {
"space child 1": 0.5,
"space child 2": 0.5,
}}}, auto_refresh=True)
}}}, auto_refresh=True, tensor_fields=["combo_text_image", "space field"])

if self.IS_MULTI_INSTANCE:
self.warm_request(self.client.index(self.index_name_1).search,
Expand Down Expand Up @@ -534,7 +535,7 @@ def test_add_docs_image_download_headers(self):
def run():
image_download_headers = {"Authentication": "my-secret-key"}
self.client.index(index_name=self.index_name_1).add_documents(
documents=[{"some": "data"}], image_download_headers=image_download_headers)
documents=[{"some": "data"}], image_download_headers=image_download_headers, tensor_fields=["some"])
args, kwargs = mock__post.call_args
assert "imageDownloadHeaders" in kwargs['body']
assert kwargs['body']['imageDownloadHeaders'] == image_download_headers
Expand All @@ -543,4 +544,12 @@ def run():

assert run()

def test_add_docs_logs_deprecation_warning_if_non_tensor_fields(self):
# Arrange
documents = [{'id': 'doc1', 'text': 'Test document'}]
non_tensor_fields = ['text']

with self.assertLogs('marqo', level='WARNING') as cm:
self.client.create_index(self.index_name_1)
self.client.index(self.index_name_1).add_documents(documents=documents, non_tensor_fields=non_tensor_fields)
self.assertTrue({'`non_tensor_fields`', 'Marqo', '2.0.0.'}.issubset(set(cm.output[0].split(" "))))
3 changes: 2 additions & 1 deletion tests/v0_tests/test_boost_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def setUp(self) -> None:
"Description": "A history of household pets",
"_id": "d2"
}
]
],
tensor_fields=["Title", "Description"]
)

self.query = "What are the best pets"
Expand Down
Loading

0 comments on commit 02aefeb

Please sign in to comment.