Skip to content

Commit

Permalink
debugged client and added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
RaynorChavez committed Sep 4, 2024
1 parent c0fafbd commit 5e9f539
Show file tree
Hide file tree
Showing 7 changed files with 196 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/marqo/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ def create_index(
normalize_embeddings: Optional[bool] = None,
text_preprocessing: Optional[marqo_index.TextPreProcessing] = None,
image_preprocessing: Optional[marqo_index.ImagePreProcessing] = None,
audio_preprocessing: Optional[marqo_index.AudioPreProcessing] = None,
video_preprocessing: Optional[marqo_index.VideoPreProcessing] = None,
vector_numeric_type: Optional[marqo_index.VectorNumericType] = None,
ann_parameters: Optional[marqo_index.AnnParameters] = None,
wait_for_readiness: bool = True,
Expand Down Expand Up @@ -141,6 +143,8 @@ def create_index(
normalize_embeddings=normalize_embeddings,
text_preprocessing=text_preprocessing,
image_preprocessing=image_preprocessing,
audio_preprocessing=audio_preprocessing,
video_preprocessing=video_preprocessing,
vector_numeric_type=vector_numeric_type,
ann_parameters=ann_parameters,
wait_for_readiness=wait_for_readiness,
Expand Down
6 changes: 6 additions & 0 deletions src/marqo/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ def create(config: Config,
normalize_embeddings: Optional[bool] = None,
text_preprocessing: Optional[marqo_index.TextPreProcessing] = None,
image_preprocessing: Optional[marqo_index.ImagePreProcessing] = None,
audio_preprocessing: Optional[marqo_index.AudioPreProcessing] = None,
video_preprocessing: Optional[marqo_index.VideoPreProcessing] = None,
vector_numeric_type: Optional[marqo_index.VectorNumericType] = None,
ann_parameters: Optional[marqo_index.AnnParameters] = None,
inference_type: Optional[str] = None,
Expand Down Expand Up @@ -158,6 +160,8 @@ def create(config: Config,
normalizeEmbeddings=normalize_embeddings,
textPreprocessing=text_preprocessing,
imagePreprocessing=image_preprocessing,
audioPreprocessing=audio_preprocessing,
videoPreprocessing=video_preprocessing,
vectorNumericType=vector_numeric_type,
annParameters=ann_parameters,
textChunkPrefix=text_chunk_prefix,
Expand All @@ -181,6 +185,8 @@ def create(config: Config,
normalizeEmbeddings=normalize_embeddings,
textPreprocessing=text_preprocessing,
imagePreprocessing=image_preprocessing,
audioPreprocessing=audio_preprocessing,
videoPreprocessing=video_preprocessing,
vectorNumericType=vector_numeric_type,
annParameters=ann_parameters,
numberOfInferences=number_of_inferences,
Expand Down
2 changes: 2 additions & 0 deletions src/marqo/models/create_index_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class IndexSettings(MarqoBaseModel):
normalizeEmbeddings: Optional[bool] = None
textPreprocessing: Optional[marqo_index.TextPreProcessing] = None
imagePreprocessing: Optional[marqo_index.ImagePreProcessing] = None
audioPreprocessing: Optional[marqo_index.AudioPreProcessing] = None
videoPreprocessing: Optional[marqo_index.VideoPreProcessing] = None
vectorNumericType: Optional[marqo_index.VectorNumericType] = None
annParameters: Optional[marqo_index.AnnParameters] = None
textQueryPrefix: Optional[str] = None
Expand Down
10 changes: 10 additions & 0 deletions src/marqo/models/marqo_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class FieldType(str, Enum):
ArrayFloat = 'array<float>'
ArrayDouble = 'array<double>'
ImagePointer = 'image_pointer'
VideoPointer = 'video_pointer'
AudioPointer = 'audio_pointer'
MultimodalCombination = 'multimodal_combination'
CustomVector = "custom_vector"
MapInt = 'map<text, int>'
Expand Down Expand Up @@ -77,6 +79,14 @@ class TextPreProcessing(StrictBaseModel):
class ImagePreProcessing(StrictBaseModel):
patchMethod: Optional[PatchMethod] = Field(None, alias="patch_method")

class VideoPreProcessing(StrictBaseModel):
splitLength: Optional[int] = Field(None, alias="split_length")
splitOverlap: Optional[int] = Field(None, alias="split_overlap")

class AudioPreProcessing(StrictBaseModel):
splitLength: Optional[int] = Field(None, alias="split_length")
splitOverlap: Optional[int] = Field(None, alias="split_overlap")


class Model(StrictBaseModel):
name: Optional[str] = None
Expand Down
22 changes: 22 additions & 0 deletions tests/marqo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,9 @@ def setUpClass(cls) -> None:
cls.unstructured_no_model_index_name = "unstructured_no_model_index"
cls.structured_image_index_name_simple_preprocessing_method = \
"structured_image_index_simple_preprocessing_method"
cls.unstructured_languagebind_index_name = "unstructured_languagebind_index"
cls.structured_languagebind_index_name = "structured_languagebind_index"

# TODO: include structured when boolean_field bug for structured is fixed
cls.test_cases = [
(CloudTestIndex.unstructured_image, cls.unstructured_index_name),
Expand Down Expand Up @@ -262,6 +265,25 @@ def setUpClass(cls) -> None:
"type": "no_model",
"dimensions": 512
}
},
{
"indexName": cls.unstructured_languagebind_index_name,
"type": "unstructured",
"model": "LanguageBind/Video_V1.5_FT_Audio_FT_Image",
"treatUrlsAndPointersAsMedia": True,
"treatUrlsAndPointersAsImages": True
},
{
"indexName": cls.structured_languagebind_index_name,
"type": "structured",
"model": "LanguageBind/Video_V1.5_FT_Audio_FT_Image",
"allFields": [
{"name": "text_field", "type": "text"},
{"name": "video_field", "type": "video_pointer"},
{"name": "audio_field", "type": "audio_pointer"},
{"name": "image_field", "type": "image_pointer"}
],
"tensorFields": ["text_field", "video_field", "audio_field", "image_field"]
}
])
except Exception as e:
Expand Down
95 changes: 95 additions & 0 deletions tests/v2_tests/test_create_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,3 +311,98 @@ def test_dash_and_underscore_in_index_name(self):
self.assertEqual(1, len(res['hits']))
self.client.delete_index("test-dash-and-under-score")
self.client.delete_index("test_dash_and_under_score")

def test_create_unstructured_index_with_languagebind(self):
self.client.create_index(
index_name=self.index_name,
type="unstructured",
model="LanguageBind/Video_V1.5_FT_Audio_FT_Image",
treat_urls_and_pointers_as_media=True,
treat_urls_and_pointers_as_images=True
)

index_settings = self.client.index(self.index_name).get_settings()

expected_settings = {
"type": "unstructured",
"model": "LanguageBind/Video_V1.5_FT_Audio_FT_Image",
"normalizeEmbeddings": True,
"treatUrlsAndPointersAsMedia": True,
"treatUrlsAndPointersAsImages": True,
"vectorNumericType": "float"
}

for key, value in expected_settings.items():
self.assertEqual(value, index_settings[key])

# Test adding and searching documents
ix = self.client.index(self.index_name)

res = ix.add_documents(
documents = [
{"audio_field": "https://audio-previews.elements.envatousercontent.com/files/187680354/preview.mp3", "_id": "corporate"},
{"audio_field": "https://audio-previews.elements.envatousercontent.com/files/492763015/preview.mp3", "_id": "lofi"},
],
tensor_fields=["audio_field"]
)

doc = ix.search(
q="corporate video background music",
limit=5
)

self.assertEqual(2, len(doc['hits']))
self.assertEqual("corporate", doc['hits'][0]['_id'])
self.assertEqual("lofi", doc['hits'][1]['_id'])

def test_create_structured_index_with_languagebind(self):
self.client.create_index(
index_name=self.index_name,
type="structured",
model="LanguageBind/Video_V1.5_FT_Audio_FT_Image",
all_fields=[
{"name": "text_field", "type": "text"},
{"name": "video_field", "type": "video_pointer"},
{"name": "audio_field", "type": "audio_pointer"},
{"name": "image_field", "type": "image_pointer"}
],
tensor_fields=["text_field", "video_field", "audio_field", "image_field"]
)

index_settings = self.client.index(self.index_name).get_settings()

expected_settings = {
"type": "structured",
"model": "LanguageBind/Video_V1.5_FT_Audio_FT_Image",
"normalizeEmbeddings": True,
"vectorNumericType": "float",
"tensorFields": ["text_field", "video_field", "audio_field", "image_field"],
"allFields": [
{"features": [], "name": "text_field", "type": "text"},
{"features": [], "name": "video_field", "type": "video_pointer"},
{"features": [], "name": "audio_field", "type": "audio_pointer"},
{"features": [], "name": "image_field", "type": "image_pointer"},
]
}

for key, value in expected_settings.items():
self.assertEqual(value, index_settings[key])

# Test adding and searching documents
ix = self.client.index(self.index_name)

res = ix.add_documents(
documents = [
{"audio_field": "https://audio-previews.elements.envatousercontent.com/files/187680354/preview.mp3", "_id": "corporate"},
{"audio_field": "https://audio-previews.elements.envatousercontent.com/files/492763015/preview.mp3", "_id": "lofi"},
],
)

doc = ix.search(
q="corporate video background music",
limit=5
)

self.assertEqual(2, len(doc['hits']))
self.assertEqual("corporate", doc['hits'][0]['_id'])
self.assertEqual("lofi", doc['hits'][1]['_id'])
57 changes: 57 additions & 0 deletions tests/v2_tests/test_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,60 @@ def test_embed_non_numeric_weight_fails(self):
self.client.index(test_index_name).embed(content={"text to embed": "not a number"})

self.assertIn("not a valid float", str(e.exception))

def test_embed_images_with_languagebind(self):
"""Embeds multiple images using LanguageBind model."""
test_index_name = self.unstructured_languagebind_index_name

image_urls = [
"https://raw.githubusercontent.com/marqo-ai/marqo-api-tests/mainline/assets/ai_hippo_realistic.png",
"https://raw.githubusercontent.com/marqo-ai/marqo-api-tests/mainline/assets/ai_hippo_realistic.png",
"https://raw.githubusercontent.com/marqo-ai/marqo-api-tests/mainline/assets/ai_hippo_realistic.png"
]

embed_res = self.client.index(test_index_name).embed(content=image_urls)

self.assertIn("processingTimeMs", embed_res)
self.assertEqual(embed_res["content"], image_urls)
self.assertEqual(len(embed_res["embeddings"]), 3)

# Check that embeddings are non-zero and have the expected shape
for embedding in embed_res["embeddings"]:
self.assertGreater(len(embedding), 0)
self.assertTrue(any(abs(x) > 1e-6 for x in embedding))

# Check that embeddings are close to the expected values
expected_embedding = [0.019889963790774345, -0.01263524405658245,
0.026028314605355263, 0.005291664972901344, -0.013181567192077637]
for embedding in embed_res["embeddings"]:
for i, value in enumerate(expected_embedding):
self.assertAlmostEqual(embedding[i], value, places=5)


def test_embed_videos_with_languagebind(self):
"""Embeds multiple videos using LanguageBind model."""
test_index_name = self.structured_languagebind_index_name

video_urls = [
"https://marqo-k400-video-test-dataset.s3.amazonaws.com/videos/---QUuC4vJs_000084_000094.mp4",
"https://marqo-k400-video-test-dataset.s3.amazonaws.com/videos/---QUuC4vJs_000084_000094.mp4",
"https://marqo-k400-video-test-dataset.s3.amazonaws.com/videos/---QUuC4vJs_000084_000094.mp4"
]

embed_res = self.client.index(test_index_name).embed(content=video_urls)

self.assertIn("processingTimeMs", embed_res)
self.assertEqual(embed_res["content"], video_urls)
self.assertEqual(len(embed_res["embeddings"]), 3)

# Check that embeddings are non-zero and have the expected shape
for embedding in embed_res["embeddings"]:
self.assertGreater(len(embedding), 0)
self.assertTrue(any(abs(x) > 1e-6 for x in embedding))

# Check that embeddings are close to the expected values
expected_embedding = [0.0394694060087204, 0.049264926463365555,
-0.014714145101606846, 0.05715121701359749, -0.019508328288793564]
for embedding in embed_res["embeddings"]:
for i, value in enumerate(expected_embedding):
self.assertAlmostEqual(embedding[i], value, places=5)

0 comments on commit 5e9f539

Please sign in to comment.