diff --git a/src/marqo/tensor_search/api.py b/src/marqo/tensor_search/api.py index 40273c0bf..8cfd1a195 100644 --- a/src/marqo/tensor_search/api.py +++ b/src/marqo/tensor_search/api.py @@ -1,27 +1,28 @@ """The API entrypoint for Tensor Search""" +import json +import os import typing -from fastapi.responses import JSONResponse -from fastapi import Request, Depends -from marqo.tensor_search.models.add_docs_objects import AddDocsParams -from marqo.tensor_search.models.add_docs_objects import ModelAuth -from marqo.errors import InvalidArgError, MarqoWebError, MarqoError +from typing import List, Dict, Optional, Union + +import pydantic from fastapi import FastAPI, Query -import json -from marqo.tensor_search import tensor_search +from fastapi import Request, Depends +from fastapi.responses import JSONResponse + from marqo import config -from typing import List, Dict -import os -from marqo.tensor_search.models.api_models import BulkSearchQuery, SearchQuery -from marqo.tensor_search.web import api_validation, api_utils -from marqo.tensor_search.on_start_script import on_start from marqo import version +from marqo.errors import InvalidArgError, MarqoWebError, MarqoError, BadRequestError +from marqo.tensor_search import tensor_search from marqo.tensor_search.backend import get_index_info from marqo.tensor_search.enums import RequestType +from marqo.tensor_search.models.add_docs_objects import (AddDocsParams, ModelAuth, + AddDocsBodyParams) +from marqo.tensor_search.models.api_models import BulkSearchQuery, SearchQuery +from marqo.tensor_search.on_start_script import on_start +from marqo.tensor_search.telemetry import RequestMetricsStore, TelemetryMiddleware from marqo.tensor_search.throttling.redis_throttle import throttle from marqo.tensor_search.utils import add_timing -import pydantic - -from marqo.tensor_search.telemetry import RequestMetricsStore, TelemetryMiddleware +from marqo.tensor_search.web import api_validation, api_utils def replace_host_localhosts(OPENSEARCH_IS_INTERNAL: str, OS_URL: str): @@ -168,33 +169,38 @@ def search(search_query: SearchQuery, index_name: str, device: str = Depends(api @app.post("/indexes/{index_name}/documents") @throttle(RequestType.INDEX) def add_or_replace_documents( - docs: List[Dict], + request: Request, + body: typing.Union[AddDocsBodyParams, List[Dict]], index_name: str, refresh: bool = True, marqo_config: config.Config = Depends(generate_config), - non_tensor_fields: List[str] = Query(default=[]), + non_tensor_fields: Optional[List[str]] = Query(default=[]), device: str = Depends(api_validation.validate_device), - use_existing_tensors: bool = False, - image_download_headers: typing.Optional[dict] = Depends( + use_existing_tensors: Optional[bool] = False, + image_download_headers: Optional[dict] = Depends( api_utils.decode_image_download_headers ), - model_auth: typing.Optional[ModelAuth] = Depends( + model_auth: Optional[ModelAuth] = Depends( api_utils.decode_query_string_model_auth ), - mappings: typing.Optional[dict] = Depends(api_utils.decode_mappings)): + mappings: Optional[dict] = Depends(api_utils.decode_mappings)): + """add_documents endpoint (replace existing docs with the same id)""" - add_docs_params = AddDocsParams( - index_name=index_name, docs=docs, auto_refresh=refresh, - device=device, non_tensor_fields=non_tensor_fields, - use_existing_tensors=use_existing_tensors, image_download_headers=image_download_headers, - mappings=mappings, model_auth=model_auth - ) + add_docs_params = api_utils.add_docs_params_orchestrator(index_name=index_name, body=body, + device=device, auto_refresh=refresh, + non_tensor_fields=non_tensor_fields, mappings=mappings, + model_auth=model_auth, + image_download_headers=image_download_headers, + use_existing_tensors=use_existing_tensors, + query_parameters=request.query_params) + with RequestMetricsStore.for_request().time(f"POST /indexes/{index_name}/documents"): return tensor_search.add_documents( config=marqo_config, add_docs_params=add_docs_params ) + @app.get("/indexes/{index_name}/documents/{document_id}") def get_document_by_id(index_name: str, document_id: str, marqo_config: config.Config = Depends(generate_config), diff --git a/src/marqo/tensor_search/models/add_docs_objects.py b/src/marqo/tensor_search/models/add_docs_objects.py index 326725630..0b5f0caac 100644 --- a/src/marqo/tensor_search/models/add_docs_objects.py +++ b/src/marqo/tensor_search/models/add_docs_objects.py @@ -9,12 +9,28 @@ from marqo.errors import InternalError from pydantic import BaseModel from marqo.tensor_search.utils import get_best_available_device +from typing import List, Dict class AddDocsParamsConfig: arbitrary_types_allowed = True +class AddDocsBodyParams(BaseModel): + """The parameters of the body parameters of tensor_search_add_documents() function""" + class Config: + arbitrary_types_allowed = True + allow_mutation = False + extra = "forbid" # Raise error on unknown fields + + nonTensorFields: List = Field(default_factory=list) + useExistingTensors: bool = False + imageDownloadHeaders: dict = Field(default_factory=dict) + modelAuth: Optional[ModelAuth] = None + mappings: Optional[dict] = None + documents: Union[Sequence[Union[dict, Any]], np.ndarray] + + class AddDocsParams(BaseModel): """Represents the parameters of the tensor_search.add_documents() function @@ -56,4 +72,4 @@ def __init__(self, **data: dict): # Ensure `None` and passing nothing are treated the same for device if "device" not in data or data["device"] is None: data["device"] = get_best_available_device() - super().__init__(**data) + super().__init__(**data) \ No newline at end of file diff --git a/src/marqo/tensor_search/web/api_utils.py b/src/marqo/tensor_search/web/api_utils.py index 7bf060a4e..16b12291c 100644 --- a/src/marqo/tensor_search/web/api_utils.py +++ b/src/marqo/tensor_search/web/api_utils.py @@ -5,6 +5,10 @@ from typing import Optional from marqo.tensor_search.utils import construct_authorized_url from marqo.tensor_search.models.add_docs_objects import ModelAuth +from marqo.tensor_search.models.add_docs_objects import AddDocsParams, AddDocsBodyParams +from marqo.errors import BadRequestError +from typing import Union, List, Optional, Dict +from fastapi import Request def upconstruct_authorized_url(opensearch_url: str) -> str: @@ -48,7 +52,7 @@ def translate_api_device(device: Optional[str]) -> Optional[str]: lowered_device.startswith(acceptable), lowered_device.replace(acceptable, ""), acceptable - ) + ) for acceptable in acceptable_devices] try: @@ -129,3 +133,47 @@ def decode_mappings(mappings: Optional[str] = None) -> dict: return as_dict except json.JSONDecodeError as e: raise InvalidArgError(f"Error parsing mappings. Message: {e}") + + +def add_docs_params_orchestrator(index_name: str, body: Union[AddDocsBodyParams, List[Dict]], + device: str, auto_refresh: bool = True, non_tensor_fields: Optional[List[str]] = [], + mappings: Optional[dict] = dict(), model_auth: Optional[ModelAuth] = None, + image_download_headers: Optional[dict] = dict(), + use_existing_tensors: Optional[bool] = False, query_parameters: Optional[Dict] = dict()) -> AddDocsParams: + """An orchestrator for the add_documents API to support both old and new versions of the API. + All the arguments are decoded and validated in the API function. This function is only responsible for orchestrating. + + Returns: + AddDocsParams: An AddDocsParams object for internal use + """ + + if isinstance(body, AddDocsBodyParams): + docs = body.documents + + # Check for query parameters that are not supported in the new API + deprecated_fields = ["non_tensor_fields", "use_existing_tensors", "image_download_headers", "model_auth", "mappings"] + if any(field in query_parameters for field in deprecated_fields): + raise BadRequestError("Marqo is not accepting any of the following parameters in the query string: " + "`non_tensor_fields`, `use_existing_tensors`, `image_download_headers`, `model_auth`, `mappings`. " + "Please move these parameters to the request body as " + "`nonTensorFields`,` useExistingTensors`, `imageDownloadHeaders`, `modelAuth`, `mappings`. and try again. " + "Please check `https://docs.marqo.ai/latest/API-Reference/documents/` for the correct APIs.") + + mappings = body.mappings + non_tensor_fields = body.nonTensorFields + use_existing_tensors = body.useExistingTensors + model_auth = body.modelAuth + image_download_headers = body.imageDownloadHeaders + + elif isinstance(body, list) and all(isinstance(item, dict) for item in body): + docs = body + + else: + raise InternalError(f"Unexpected request body type `{type(body).__name__} for `/documents` API. ") + + return AddDocsParams( + index_name=index_name, docs=docs, auto_refresh=auto_refresh, + device=device, non_tensor_fields=non_tensor_fields, + use_existing_tensors=use_existing_tensors, image_download_headers=image_download_headers, + mappings=mappings, model_auth=model_auth + ) diff --git a/tests/tensor_search/test_api_utils.py b/tests/tensor_search/test_api_utils.py index 920834648..2bc2677c6 100644 --- a/tests/tensor_search/test_api_utils.py +++ b/tests/tensor_search/test_api_utils.py @@ -1,10 +1,12 @@ import pydantic -from marqo.tensor_search.models.add_docs_objects import ModelAuth +from marqo.tensor_search.models.add_docs_objects import ModelAuth, AddDocsParams, AddDocsBodyParams +from marqo.tensor_search.web.api_utils import add_docs_params_orchestrator from marqo.tensor_search.models.private_models import S3Auth import urllib.parse from marqo.tensor_search.web import api_utils -from marqo.errors import InvalidArgError, InternalError +from marqo.errors import InvalidArgError, InternalError, BadRequestError from tests.marqo_test import MarqoTestCase +import unittest class TestApiUtils(MarqoTestCase): @@ -26,15 +28,17 @@ def test_translate_api_device_bad(self): def test_generate_config(self): for opensearch_url, authorized_url in [ - ("http://admin:admin@localhost:9200", "http://admin:admin@localhost:9200"), - ("http://localhost:9200", "http://admin:admin@localhost:9200"), - ("https://admin:admin@localhost:9200", "https://admin:admin@localhost:9200"), - ("https://localhost:9200", "https://admin:admin@localhost:9200"), - ("http://king_user:mysecretpw@unusual.com/happy@chappy:9200", "http://king_user:mysecretpw@unusual.com/happy@chappy:9200"), - ("http://unusual.com/happy@chappy:9200", "http://admin:admin@unusual.com/happy@chappy:9200"), - ("http://www.unusual.com/happy@@@@#chappy:9200", "http://admin:admin@www.unusual.com/happy@@@@#chappy:9200"), - ("://", "://admin:admin@") - ]: + ("http://admin:admin@localhost:9200", "http://admin:admin@localhost:9200"), + ("http://localhost:9200", "http://admin:admin@localhost:9200"), + ("https://admin:admin@localhost:9200", "https://admin:admin@localhost:9200"), + ("https://localhost:9200", "https://admin:admin@localhost:9200"), + ("http://king_user:mysecretpw@unusual.com/happy@chappy:9200", + "http://king_user:mysecretpw@unusual.com/happy@chappy:9200"), + ("http://unusual.com/happy@chappy:9200", "http://admin:admin@unusual.com/happy@chappy:9200"), + ( + "http://www.unusual.com/happy@@@@#chappy:9200", "http://admin:admin@www.unusual.com/happy@@@@#chappy:9200"), + ("://", "://admin:admin@") + ]: c = api_utils.upconstruct_authorized_url(opensearch_url=opensearch_url) assert authorized_url == c @@ -45,7 +49,8 @@ def test_generate_config_bad_url(self): raise AssertionError except InternalError: pass - + + class TestDecodeQueryStringModelAuth(MarqoTestCase): def test_decode_query_string_model_auth_none(self): @@ -71,4 +76,116 @@ def test_decode_query_string_model_auth_valid(self): def test_decode_query_string_model_auth_invalid(self): with self.assertRaises(pydantic.ValidationError): - api_utils.decode_query_string_model_auth("invalid_url_encoded_string") \ No newline at end of file + api_utils.decode_query_string_model_auth("invalid_url_encoded_string") + + +class TestAddDocsParamsOrchestrator(unittest.TestCase): + def test_add_docs_params_orchestrator(self): + # Set up the arguments for the function + index_name = "test-index" + body = AddDocsBodyParams(documents=[{"test": "doc"}], + nonTensorFields=["field1"], + useExistingTensors=True, + imageDownloadHeaders={"header1": "value1"}, + modelAuth=ModelAuth(s3=S3Auth(aws_secret_access_key="test", aws_access_key_id="test")), + mappings={"map1": "value1"}) + device = "test-device" + auto_refresh = True + + # Query parameters should be parsed as default values + non_tensor_fields = [] + use_existing_tensors = False + image_download_headers = dict() + model_auth = None + mappings = dict() + + # Call the function with the arguments + result = add_docs_params_orchestrator(index_name, body, device, auto_refresh, non_tensor_fields, mappings, + model_auth, image_download_headers, use_existing_tensors) + + # Assert that the result is as expected + assert isinstance(result, AddDocsParams) + assert result.index_name == "test-index" + assert result.docs == body.documents + assert result.device == "test-device" + assert result.non_tensor_fields == ["field1"] + assert result.use_existing_tensors == True + assert result.docs == [{"test": "doc"}] + assert result.image_download_headers == {"header1": "value1"} + + def test_add_docs_params_orchestrator_deprecated_query_parameters(self): + # Set up the arguments for the function + index_name = "test-index" + model_auth = ModelAuth(s3=S3Auth(aws_secret_access_key="test", aws_access_key_id="test")) + + body = [{"test": "doc"}] + + device = "test-device" + non_tensor_fields = ["field1"] + use_existing_tensors = True + image_download_headers = {"header1": "value1"} + model_auth = model_auth + mappings = {"map1": "value1"} + auto_refresh = True + + # Call the function with the arguments + result = add_docs_params_orchestrator(index_name, body, device, auto_refresh, non_tensor_fields, mappings, + model_auth, image_download_headers, use_existing_tensors) + + # Assert that the result is as expected + assert isinstance(result, AddDocsParams) + assert result.index_name == "test-index" + assert result.docs == body + assert result.device == "test-device" + assert result.non_tensor_fields == ["field1"] + assert result.use_existing_tensors == True + assert result.docs == [{"test": "doc"}] + assert result.image_download_headers == {"header1": "value1"} + + def test_add_docs_params_orchestrator_error(self): + # Test the case where the function should raise an error due to invalid input + body = "invalid body type" # Not an instance of AddDocsBodyParams or List[Dict] + + index_name = "test-index" + model_auth = ModelAuth(s3=S3Auth(aws_secret_access_key="test", aws_access_key_id="test")) + + device = "test-device" + non_tensor_fields = ["field1"] + use_existing_tensors = True + image_download_headers = {"header1": "value1"} + model_auth = model_auth + mappings = {"map1": "value1"} + auto_refresh = True + + # Use pytest.raises to check for the error + try: + _ = add_docs_params_orchestrator(index_name, body, device, auto_refresh, non_tensor_fields, mappings, + model_auth, image_download_headers, use_existing_tensors) + except InternalError as e: + self.assertIn("Unexpected request body type", str(e)) + + def test_add_docs_params_orchestrator_deprecated_query_parameters_error(self): + # Test the case where the function should raise an error due to deprecated query parameters + index_name = "test-index" + model_auth = ModelAuth(s3=S3Auth(aws_secret_access_key="test", aws_access_key_id="test")) + device = "test-device" + auto_refresh = True + body = AddDocsBodyParams(documents=[{"test": "doc"}], + nonTensorFields=["field1"], + useExistingTensors=True, + imageDownloadHeaders={"header1": "value1"}, + modelAuth=ModelAuth(s3=S3Auth(aws_secret_access_key="test", aws_access_key_id="test")), + mappings={"map1": "value1"}) + + params = {"non_tensor_fields": ["what"], "use_existing_tensors": True, + "image_download_headers": {"header2": "value2"}, "model_auth": model_auth, + "mappings": {"map2": "value2"}} + + for param, value in params.items(): + kwargs = {key: None for key in params.keys()} + kwargs[param] = value + try: + add_docs_params_orchestrator(index_name, body, device, auto_refresh=auto_refresh, + query_parameters=kwargs, **kwargs) + except BadRequestError as e: + self.assertIn("Marqo is not accepting any of the following parameters in the query string", str(e))