Skip to content

Commit

Permalink
Move add_documents API query parameters into request body (#535)
Browse files Browse the repository at this point in the history
Move `POST /index_name/documents` (add documents) query parameters `non_tensor_fields`, `use_existing_tensors`, `image_download_headers`, `model_auth`, `mappings` to the request body, while still supporting the old request format for backwards compatibility.
  • Loading branch information
wanliAlex authored Jul 14, 2023
1 parent 767fe4f commit e8d078c
Show file tree
Hide file tree
Showing 4 changed files with 229 additions and 42 deletions.
60 changes: 33 additions & 27 deletions src/marqo/tensor_search/api.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,28 @@
"""The API entrypoint for Tensor Search"""
import json
import os
import typing
from fastapi.responses import JSONResponse
from fastapi import Request, Depends
from marqo.tensor_search.models.add_docs_objects import AddDocsParams
from marqo.tensor_search.models.add_docs_objects import ModelAuth
from marqo.errors import InvalidArgError, MarqoWebError, MarqoError
from typing import List, Dict, Optional, Union

import pydantic
from fastapi import FastAPI, Query
import json
from marqo.tensor_search import tensor_search
from fastapi import Request, Depends
from fastapi.responses import JSONResponse

from marqo import config
from typing import List, Dict
import os
from marqo.tensor_search.models.api_models import BulkSearchQuery, SearchQuery
from marqo.tensor_search.web import api_validation, api_utils
from marqo.tensor_search.on_start_script import on_start
from marqo import version
from marqo.errors import InvalidArgError, MarqoWebError, MarqoError, BadRequestError
from marqo.tensor_search import tensor_search
from marqo.tensor_search.backend import get_index_info
from marqo.tensor_search.enums import RequestType
from marqo.tensor_search.models.add_docs_objects import (AddDocsParams, ModelAuth,
AddDocsBodyParams)
from marqo.tensor_search.models.api_models import BulkSearchQuery, SearchQuery
from marqo.tensor_search.on_start_script import on_start
from marqo.tensor_search.telemetry import RequestMetricsStore, TelemetryMiddleware
from marqo.tensor_search.throttling.redis_throttle import throttle
from marqo.tensor_search.utils import add_timing
import pydantic

from marqo.tensor_search.telemetry import RequestMetricsStore, TelemetryMiddleware
from marqo.tensor_search.web import api_validation, api_utils


def replace_host_localhosts(OPENSEARCH_IS_INTERNAL: str, OS_URL: str):
Expand Down Expand Up @@ -168,33 +169,38 @@ def search(search_query: SearchQuery, index_name: str, device: str = Depends(api
@app.post("/indexes/{index_name}/documents")
@throttle(RequestType.INDEX)
def add_or_replace_documents(
docs: List[Dict],
request: Request,
body: typing.Union[AddDocsBodyParams, List[Dict]],
index_name: str,
refresh: bool = True,
marqo_config: config.Config = Depends(generate_config),
non_tensor_fields: List[str] = Query(default=[]),
non_tensor_fields: Optional[List[str]] = Query(default=[]),
device: str = Depends(api_validation.validate_device),
use_existing_tensors: bool = False,
image_download_headers: typing.Optional[dict] = Depends(
use_existing_tensors: Optional[bool] = False,
image_download_headers: Optional[dict] = Depends(
api_utils.decode_image_download_headers
),
model_auth: typing.Optional[ModelAuth] = Depends(
model_auth: Optional[ModelAuth] = Depends(
api_utils.decode_query_string_model_auth
),
mappings: typing.Optional[dict] = Depends(api_utils.decode_mappings)):
mappings: Optional[dict] = Depends(api_utils.decode_mappings)):

"""add_documents endpoint (replace existing docs with the same id)"""
add_docs_params = AddDocsParams(
index_name=index_name, docs=docs, auto_refresh=refresh,
device=device, non_tensor_fields=non_tensor_fields,
use_existing_tensors=use_existing_tensors, image_download_headers=image_download_headers,
mappings=mappings, model_auth=model_auth
)
add_docs_params = api_utils.add_docs_params_orchestrator(index_name=index_name, body=body,
device=device, auto_refresh=refresh,
non_tensor_fields=non_tensor_fields, mappings=mappings,
model_auth=model_auth,
image_download_headers=image_download_headers,
use_existing_tensors=use_existing_tensors,
query_parameters=request.query_params)

with RequestMetricsStore.for_request().time(f"POST /indexes/{index_name}/documents"):
return tensor_search.add_documents(
config=marqo_config, add_docs_params=add_docs_params
)



@app.get("/indexes/{index_name}/documents/{document_id}")
def get_document_by_id(index_name: str, document_id: str,
marqo_config: config.Config = Depends(generate_config),
Expand Down
18 changes: 17 additions & 1 deletion src/marqo/tensor_search/models/add_docs_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,28 @@
from marqo.errors import InternalError
from pydantic import BaseModel
from marqo.tensor_search.utils import get_best_available_device
from typing import List, Dict


class AddDocsParamsConfig:
arbitrary_types_allowed = True


class AddDocsBodyParams(BaseModel):
    """Request-body parameters for the add-documents API.

    Carries the documents to add together with the per-request options
    (nonTensorFields, useExistingTensors, imageDownloadHeaders, modelAuth,
    mappings). Field names are camelCase, as they appear in the request body.
    """
    class Config:
        # NOTE(review): presumably required because `documents` may be typed as
        # np.ndarray, which is not a pydantic-native type -- confirm.
        arbitrary_types_allowed = True
        # Instances are read-only once constructed.
        allow_mutation = False
        extra = "forbid" # Raise error on unknown fields

    # Optional settings; each has a default so a body may contain only `documents`.
    nonTensorFields: List = Field(default_factory=list)
    useExistingTensors: bool = False
    imageDownloadHeaders: dict = Field(default_factory=dict)
    modelAuth: Optional[ModelAuth] = None
    mappings: Optional[dict] = None
    # The only required field: the documents to be added.
    documents: Union[Sequence[Union[dict, Any]], np.ndarray]


class AddDocsParams(BaseModel):
"""Represents the parameters of the tensor_search.add_documents() function
Expand Down Expand Up @@ -56,4 +72,4 @@ def __init__(self, **data: dict):
# Ensure `None` and passing nothing are treated the same for device
if "device" not in data or data["device"] is None:
data["device"] = get_best_available_device()
super().__init__(**data)
super().__init__(**data)
50 changes: 49 additions & 1 deletion src/marqo/tensor_search/web/api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
from typing import Optional
from marqo.tensor_search.utils import construct_authorized_url
from marqo.tensor_search.models.add_docs_objects import ModelAuth
from marqo.tensor_search.models.add_docs_objects import AddDocsParams, AddDocsBodyParams
from marqo.errors import BadRequestError
from typing import Union, List, Optional, Dict
from fastapi import Request


def upconstruct_authorized_url(opensearch_url: str) -> str:
Expand Down Expand Up @@ -48,7 +52,7 @@ def translate_api_device(device: Optional[str]) -> Optional[str]:
lowered_device.startswith(acceptable),
lowered_device.replace(acceptable, ""),
acceptable
)
)
for acceptable in acceptable_devices]

try:
Expand Down Expand Up @@ -129,3 +133,47 @@ def decode_mappings(mappings: Optional[str] = None) -> dict:
return as_dict
except json.JSONDecodeError as e:
raise InvalidArgError(f"Error parsing mappings. Message: {e}")


def add_docs_params_orchestrator(index_name: str, body: Union[AddDocsBodyParams, List[Dict]],
                                 device: str, auto_refresh: bool = True, non_tensor_fields: Optional[List[str]] = [],
                                 mappings: Optional[dict] = dict(), model_auth: Optional[ModelAuth] = None,
                                 image_download_headers: Optional[dict] = dict(),
                                 use_existing_tensors: Optional[bool] = False, query_parameters: Optional[Dict] = dict()) -> AddDocsParams:
    """Build an AddDocsParams object from either format of the add_documents request.

    Two request formats are supported:

    * New format: `body` is an `AddDocsBodyParams`. The documents and all options
      (`mappings`, `nonTensorFields`, `useExistingTensors`, `modelAuth`,
      `imageDownloadHeaders`) are read from the body, replacing whatever was passed
      in the corresponding keyword arguments. Supplying any of the deprecated
      options in the query string at the same time is rejected.
    * Old format: `body` is a list of dicts (the documents) and the options are
      taken from the keyword arguments, which the API layer decoded from the
      query string.

    The arguments are expected to be decoded and validated by the API function;
    this function only decides which source each value is taken from.

    Args:
        index_name: name of the index the documents are added to.
        body: the request body, in either the new or the old format.
        device: device forwarded to AddDocsParams.
        auto_refresh: forwarded to AddDocsParams as `auto_refresh`.
        non_tensor_fields: old-format option; ignored for a new-format body.
        mappings: old-format option; ignored for a new-format body.
        model_auth: old-format option; ignored for a new-format body.
        image_download_headers: old-format option; ignored for a new-format body.
        use_existing_tensors: old-format option; ignored for a new-format body.
        query_parameters: mapping of the raw query-string parameters of the
            request. Only key membership is checked. `None` is treated the same
            as an empty mapping.

    Returns:
        AddDocsParams: An AddDocsParams object for internal use

    Raises:
        BadRequestError: a new-format body was sent together with a deprecated
            query-string parameter.
        InternalError: `body` is neither an AddDocsBodyParams nor a list of dicts.
    """
    # NOTE: the mutable defaults in the signature are kept for interface
    # compatibility; they are only rebound or forwarded here, never mutated in
    # this function.

    if isinstance(body, AddDocsBodyParams):
        docs = body.documents

        # Check for query parameters that are not supported in the new API
        deprecated_fields = ["non_tensor_fields", "use_existing_tensors", "image_download_headers", "model_auth", "mappings"]
        # The annotation allows None, so guard the membership test against it.
        supplied_query_parameters = query_parameters if query_parameters is not None else {}
        if any(field in supplied_query_parameters for field in deprecated_fields):
            raise BadRequestError("Marqo is not accepting any of the following parameters in the query string: "
                                  "`non_tensor_fields`, `use_existing_tensors`, `image_download_headers`, `model_auth`, `mappings`. "
                                  "Please move these parameters to the request body as "
                                  "`nonTensorFields`, `useExistingTensors`, `imageDownloadHeaders`, `modelAuth`, `mappings` and try again. "
                                  "Please check `https://docs.marqo.ai/latest/API-Reference/documents/` for the correct APIs.")

        # In the new format the body is the single source of truth for all options.
        mappings = body.mappings
        non_tensor_fields = body.nonTensorFields
        use_existing_tensors = body.useExistingTensors
        model_auth = body.modelAuth
        image_download_headers = body.imageDownloadHeaders

    elif isinstance(body, list) and all(isinstance(item, dict) for item in body):
        # Old format: the body is the bare list of documents.
        docs = body

    else:
        raise InternalError(f"Unexpected request body type `{type(body).__name__}` for `/documents` API.")

    return AddDocsParams(
        index_name=index_name, docs=docs, auto_refresh=auto_refresh,
        device=device, non_tensor_fields=non_tensor_fields,
        use_existing_tensors=use_existing_tensors, image_download_headers=image_download_headers,
        mappings=mappings, model_auth=model_auth
    )
143 changes: 130 additions & 13 deletions tests/tensor_search/test_api_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import pydantic
from marqo.tensor_search.models.add_docs_objects import ModelAuth
from marqo.tensor_search.models.add_docs_objects import ModelAuth, AddDocsParams, AddDocsBodyParams
from marqo.tensor_search.web.api_utils import add_docs_params_orchestrator
from marqo.tensor_search.models.private_models import S3Auth
import urllib.parse
from marqo.tensor_search.web import api_utils
from marqo.errors import InvalidArgError, InternalError
from marqo.errors import InvalidArgError, InternalError, BadRequestError
from tests.marqo_test import MarqoTestCase
import unittest


class TestApiUtils(MarqoTestCase):
Expand All @@ -26,15 +28,17 @@ def test_translate_api_device_bad(self):

def test_generate_config(self):
for opensearch_url, authorized_url in [
("http://admin:admin@localhost:9200", "http://admin:admin@localhost:9200"),
("http://localhost:9200", "http://admin:admin@localhost:9200"),
("https://admin:admin@localhost:9200", "https://admin:admin@localhost:9200"),
("https://localhost:9200", "https://admin:admin@localhost:9200"),
("http://king_user:mysecretpw@unusual.com/happy@chappy:9200", "http://king_user:mysecretpw@unusual.com/happy@chappy:9200"),
("http://unusual.com/happy@chappy:9200", "http://admin:admin@unusual.com/happy@chappy:9200"),
("http://www.unusual.com/happy@@@@#chappy:9200", "http://admin:admin@www.unusual.com/happy@@@@#chappy:9200"),
("://", "://admin:admin@")
]:
("http://admin:admin@localhost:9200", "http://admin:admin@localhost:9200"),
("http://localhost:9200", "http://admin:admin@localhost:9200"),
("https://admin:admin@localhost:9200", "https://admin:admin@localhost:9200"),
("https://localhost:9200", "https://admin:admin@localhost:9200"),
("http://king_user:mysecretpw@unusual.com/happy@chappy:9200",
"http://king_user:mysecretpw@unusual.com/happy@chappy:9200"),
("http://unusual.com/happy@chappy:9200", "http://admin:admin@unusual.com/happy@chappy:9200"),
(
"http://www.unusual.com/happy@@@@#chappy:9200", "http://admin:admin@www.unusual.com/happy@@@@#chappy:9200"),
("://", "://admin:admin@")
]:
c = api_utils.upconstruct_authorized_url(opensearch_url=opensearch_url)
assert authorized_url == c

Expand All @@ -45,7 +49,8 @@ def test_generate_config_bad_url(self):
raise AssertionError
except InternalError:
pass



class TestDecodeQueryStringModelAuth(MarqoTestCase):

def test_decode_query_string_model_auth_none(self):
Expand All @@ -71,4 +76,116 @@ def test_decode_query_string_model_auth_valid(self):

def test_decode_query_string_model_auth_invalid(self):
with self.assertRaises(pydantic.ValidationError):
api_utils.decode_query_string_model_auth("invalid_url_encoded_string")
api_utils.decode_query_string_model_auth("invalid_url_encoded_string")


class TestAddDocsParamsOrchestrator(unittest.TestCase):
    """Tests for `add_docs_params_orchestrator`, covering both request-body formats."""

    def test_add_docs_params_orchestrator(self):
        """Options given in an AddDocsBodyParams body end up in the resulting AddDocsParams."""
        # Set up the arguments for the function
        index_name = "test-index"
        body = AddDocsBodyParams(documents=[{"test": "doc"}],
                                 nonTensorFields=["field1"],
                                 useExistingTensors=True,
                                 imageDownloadHeaders={"header1": "value1"},
                                 modelAuth=ModelAuth(s3=S3Auth(aws_secret_access_key="test", aws_access_key_id="test")),
                                 mappings={"map1": "value1"})
        device = "test-device"
        auto_refresh = True

        # Query parameters should be parsed as default values
        non_tensor_fields = []
        use_existing_tensors = False
        image_download_headers = dict()
        model_auth = None
        mappings = dict()

        # Call the function with the arguments
        result = add_docs_params_orchestrator(index_name, body, device, auto_refresh, non_tensor_fields, mappings,
                                              model_auth, image_download_headers, use_existing_tensors)

        # The values from the body must win over the (default) query-string values
        self.assertIsInstance(result, AddDocsParams)
        self.assertEqual(result.index_name, "test-index")
        self.assertEqual(result.docs, body.documents)
        self.assertEqual(result.device, "test-device")
        self.assertEqual(result.non_tensor_fields, ["field1"])
        self.assertTrue(result.use_existing_tensors)
        self.assertEqual(result.docs, [{"test": "doc"}])
        self.assertEqual(result.image_download_headers, {"header1": "value1"})

    def test_add_docs_params_orchestrator_deprecated_query_parameters(self):
        """With an old-format body (list of dicts) the options come from the arguments."""
        # Set up the arguments for the function
        index_name = "test-index"
        model_auth = ModelAuth(s3=S3Auth(aws_secret_access_key="test", aws_access_key_id="test"))

        body = [{"test": "doc"}]

        device = "test-device"
        non_tensor_fields = ["field1"]
        use_existing_tensors = True
        image_download_headers = {"header1": "value1"}
        mappings = {"map1": "value1"}
        auto_refresh = True

        # Call the function with the arguments
        result = add_docs_params_orchestrator(index_name, body, device, auto_refresh, non_tensor_fields, mappings,
                                              model_auth, image_download_headers, use_existing_tensors)

        # Assert that the result is as expected
        self.assertIsInstance(result, AddDocsParams)
        self.assertEqual(result.index_name, "test-index")
        self.assertEqual(result.docs, body)
        self.assertEqual(result.device, "test-device")
        self.assertEqual(result.non_tensor_fields, ["field1"])
        self.assertTrue(result.use_existing_tensors)
        self.assertEqual(result.docs, [{"test": "doc"}])
        self.assertEqual(result.image_download_headers, {"header1": "value1"})

    def test_add_docs_params_orchestrator_error(self):
        """A body that is neither AddDocsBodyParams nor a list of dicts raises InternalError."""
        body = "invalid body type"  # Not an instance of AddDocsBodyParams or List[Dict]

        index_name = "test-index"
        model_auth = ModelAuth(s3=S3Auth(aws_secret_access_key="test", aws_access_key_id="test"))

        device = "test-device"
        non_tensor_fields = ["field1"]
        use_existing_tensors = True
        image_download_headers = {"header1": "value1"}
        mappings = {"map1": "value1"}
        auto_refresh = True

        # assertRaises makes the test fail when no error is raised; a bare
        # try/except would pass silently in that case.
        with self.assertRaises(InternalError) as raised:
            add_docs_params_orchestrator(index_name, body, device, auto_refresh, non_tensor_fields, mappings,
                                         model_auth, image_download_headers, use_existing_tensors)
        self.assertIn("Unexpected request body type", str(raised.exception))

    def test_add_docs_params_orchestrator_deprecated_query_parameters_error(self):
        """Each deprecated query parameter, on its own, is rejected alongside a new-format body."""
        index_name = "test-index"
        model_auth = ModelAuth(s3=S3Auth(aws_secret_access_key="test", aws_access_key_id="test"))
        device = "test-device"
        auto_refresh = True
        body = AddDocsBodyParams(documents=[{"test": "doc"}],
                                 nonTensorFields=["field1"],
                                 useExistingTensors=True,
                                 imageDownloadHeaders={"header1": "value1"},
                                 modelAuth=ModelAuth(s3=S3Auth(aws_secret_access_key="test", aws_access_key_id="test")),
                                 mappings={"map1": "value1"})

        params = {"non_tensor_fields": ["what"], "use_existing_tensors": True,
                  "image_download_headers": {"header2": "value2"}, "model_auth": model_auth,
                  "mappings": {"map2": "value2"}}

        for param, value in params.items():
            kwargs = {key: None for key in params}
            kwargs[param] = value
            with self.subTest(param=param):
                # Only the parameter under test is present in the query string, so
                # every deprecated name is checked individually.
                with self.assertRaises(BadRequestError) as raised:
                    add_docs_params_orchestrator(index_name, body, device, auto_refresh=auto_refresh,
                                                 query_parameters={param: value}, **kwargs)
                self.assertIn("Marqo is not accepting any of the following parameters in the query string",
                              str(raised.exception))

0 comments on commit e8d078c

Please sign in to comment.