diff --git a/.ci/opensearch/Dockerfile b/.ci/opensearch/Dockerfile
index d3c3d525..74fce351 100755
--- a/.ci/opensearch/Dockerfile
+++ b/.ci/opensearch/Dockerfile
@@ -1,10 +1,18 @@
-ARG OPENSEARCH_VERSION
+ARG OPENSEARCH_VERSION=latest
 FROM opensearchproject/opensearch:$OPENSEARCH_VERSION
 
+# OPENSEARCH_VERSION must be redeclared here: an ARG declared before FROM is outside the build stage's scope.
+# Reference: https://docs.docker.com/engine/reference/builder/#understand-how-arg-and-from-interact
+ARG OPENSEARCH_VERSION=latest
 ARG opensearch_path=/usr/share/opensearch
 ARG opensearch_yml=$opensearch_path/config/opensearch.yml
 ARG SECURE_INTEGRATION
 RUN echo "plugins.ml_commons.only_run_on_ml_node: false" >> $opensearch_yml;
 RUN echo "plugins.ml_commons.native_memory_threshold: 100" >> $opensearch_yml;
+RUN if [ "$OPENSEARCH_VERSION" = "2.11.0" ] ; then \
+    echo "plugins.ml_commons.model_access_control_enabled: true" >> $opensearch_yml; \
+    echo "plugins.ml_commons.allow_registering_model_via_local_file: true" >> $opensearch_yml; \
+    echo "plugins.ml_commons.allow_registering_model_via_url: true" >> $opensearch_yml; \
+fi
 RUN if [ "$SECURE_INTEGRATION" != "true" ] ; then echo "plugins.security.disabled: true" >> $opensearch_yml; fi
diff --git a/.ci/run-opensearch.sh b/.ci/run-opensearch.sh
index 54de7660..4e1fcd3b 100644
--- a/.ci/run-opensearch.sh
+++ b/.ci/run-opensearch.sh
@@ -4,7 +4,6 @@
 # to form a cluster suitable for running the REST API tests.
 #
 # Export the NUMBER_OF_NODES variable to start more than 1 node
-
 script_path=$(dirname $(realpath -s $0))
 source $script_path/imports.sh
 set -euo pipefail
diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml
index e36c7735..1640df98 100644
--- a/.github/workflows/integration.yml
+++ b/.github/workflows/integration.yml
@@ -2,6 +2,10 @@ name: Integration tests
 
 on: [push, pull_request]
 
+concurrency:
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
+  cancel-in-progress: true
+
 jobs:
   integration:
     name: Integ
@@ -13,6 +17,7 @@ jobs:
         secured: ["true"]
         entry:
           - { opensearch_version: 2.7.0 }
+          - { opensearch_version: 2.11.0 }
 
     steps:
       - name: Checkout
diff --git a/CHANGELOG.md b/CHANGELOG.md
index a810d313..d6c19bd3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
 - Add script to trigger ml-models-release jenkins workflow with generic webhook by @thanawan-atc in ([#211](https://github.com/opensearch-project/opensearch-py-ml/pull/211))
 - Add example notebook for tracing and registering a CLIPTextModel to OpenSearch with the Neural Search plugin by @patrickbarnhart in ([#283](https://github.com/opensearch-project/opensearch-py-ml/pull/283))
 - Add support for train api functionality by @rawwar in ([#310](https://github.com/opensearch-project/opensearch-py-ml/pull/310))
+- Add support for Model Access Control - Register, Update, Search and Delete by @rawwar in ([#332](https://github.com/opensearch-project/opensearch-py-ml/pull/332))
 
 ### Changed
 - Modify ml-models.JenkinsFile so that it takes model format into account and can be triggered with generic webhook by @thanawan-atc in ([#211](https://github.com/opensearch-project/opensearch-py-ml/pull/211))
diff --git a/opensearch_py_ml/ml_commons/ml_commons_client.py b/opensearch_py_ml/ml_commons/ml_commons_client.py
index 99162e30..72e2e158 100644
--- a/opensearch_py_ml/ml_commons/ml_commons_client.py
+++ b/opensearch_py_ml/ml_commons/ml_commons_client.py
@@ -21,6 +21,7 @@
     MODEL_VERSION_FIELD,
     TIMEOUT,
 )
+from opensearch_py_ml.ml_commons.model_access_control import ModelAccessControl
 from opensearch_py_ml.ml_commons.model_execute import ModelExecute
 from opensearch_py_ml.ml_commons.model_uploader import ModelUploader
 
@@ -35,6 +36,7 @@ def __init__(self, os_client: OpenSearch):
         self._client = os_client
         self._model_uploader = ModelUploader(os_client)
         self._model_execute = ModelExecute(os_client)
+        self.model_access_control = ModelAccessControl(os_client)
 
     def execute(self, algorithm_name: str, input_json: dict) -> dict:
         """
diff --git a/opensearch_py_ml/ml_commons/model_access_control.py b/opensearch_py_ml/ml_commons/model_access_control.py
new file mode 100644
index 00000000..bae4e603
--- /dev/null
+++ b/opensearch_py_ml/ml_commons/model_access_control.py
@@ -0,0 +1,105 @@
+# SPDX-License-Identifier: Apache-2.0
+# The OpenSearch Contributors require contributions made to
+# this file be licensed under the Apache-2.0 license or a
+# compatible open source license.
+# Any modifications Copyright OpenSearch Contributors. See
+# GitHub history for details.
+
+from typing import List, Optional
+
+from opensearchpy import OpenSearch
+from opensearchpy.exceptions import NotFoundError
+
+from opensearch_py_ml.ml_commons.ml_common_utils import ML_BASE_URI
+from opensearch_py_ml.ml_commons.validators.model_access_control import (
+    validate_create_model_group_parameters,
+    validate_delete_model_group_parameters,
+    validate_search_model_group_parameters,
+    validate_update_model_group_parameters,
+)
+
+
+class ModelAccessControl:
+    API_ENDPOINT = "model_groups"
+
+    def __init__(self, os_client: OpenSearch):
+        self.client = os_client
+
+    def register_model_group(
+        self,
+        name: str,
+        description: Optional[str] = None,
+        access_mode: Optional[str] = "private",
+        backend_roles: Optional[List[str]] = None,
+        add_all_backend_roles: Optional[bool] = False,
+    ):
+        validate_create_model_group_parameters(
+            name, description, access_mode, backend_roles, add_all_backend_roles
+        )
+
+        body = {"name": name, "add_all_backend_roles": add_all_backend_roles}
+        if description:
+            body["description"] = description
+        if access_mode:
+            body["access_mode"] = access_mode
+        if backend_roles:
+            body["backend_roles"] = backend_roles
+
+        return self.client.transport.perform_request(
+            method="POST", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}/_register", body=body
+        )
+
+    def update_model_group(
+        self,
+        update_query: dict,
+        model_group_id: Optional[str] = None,
+    ):
+        validate_update_model_group_parameters(update_query, model_group_id)
+        return self.client.transport.perform_request(
+            method="PUT",
+            url=f"{ML_BASE_URI}/{self.API_ENDPOINT}/{model_group_id}",
+            body=update_query,
+        )
+
+    def search_model_group(self, query: dict):
+        validate_search_model_group_parameters(query)
+        return self.client.transport.perform_request(
+            method="GET", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}/_search", body=query
+        )
+
+    def search_model_group_by_name(
+        self,
+        model_group_name: str,
+        _source: Optional[List] = None,
+        size: Optional[int] = 1,
+    ):
+        query = {"query": {"match": {"name": model_group_name}}, "size": size}
+        if _source:
+            query["_source"] = _source
+        return self.search_model_group(query)
+
+    def get_model_group_id_by_name(self, model_group_name: str):
+        try:
+            res = self.search_model_group_by_name(model_group_name)
+            if res["hits"]["hits"]:
+                return res["hits"]["hits"][0]["_id"]
+            else:
+                raise NotFoundError
+        except NotFoundError:
+            print(f"No model group found with name: {model_group_name}")
+            return None
+        except Exception as ex:
+            print(f"Error in get_model_group_id_by_name: {ex}")
+            return None
+
+    def delete_model_group(self, model_group_id: str):
+        validate_delete_model_group_parameters(model_group_id)
+        return self.client.transport.perform_request(
+            method="DELETE", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}/{model_group_id}"
+        )
+
+    def delete_model_group_by_name(self, model_group_name: str):
+        model_group_id = self.get_model_group_id_by_name(model_group_name)
+        if model_group_id is None:
+            raise NotFoundError(f"Model group {model_group_name} not found")
+        return self.delete_model_group(model_group_id=model_group_id)
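For reference, a minimal usage sketch of the ModelAccessControl surface added above, assuming the package's existing MLCommonClient entry point and a local test cluster (host, port, credentials, and the group name are illustrative, not part of this change):

    from opensearchpy import OpenSearch
    from opensearch_py_ml.ml_commons import MLCommonClient

    # Hypothetical local cluster; adjust host, port and auth for your setup.
    os_client = OpenSearch(hosts=[{"host": "localhost", "port": 9200}])
    ml_client = MLCommonClient(os_client)
    mac = ml_client.model_access_control

    # Register a private model group, look up its id by name, then remove it.
    mac.register_model_group(name="demo_model_group", description="example group")
    group_id = mac.get_model_group_id_by_name("demo_model_group")
    mac.delete_model_group(model_group_id=group_id)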
+ +""" Module for validating model access control parameters """ + +from typing import List, Optional + +ACCESS_MODES = ["public", "private", "restricted"] + +NoneType = type(None) + + +def _validate_model_group_name(name: str): + if not name or not isinstance(name, str): + raise ValueError("name is required and needs to be a string") + + +def _validate_model_group_description(description: Optional[str]): + if not isinstance(description, (NoneType, str)): + raise ValueError("description needs to be a string") + + +def _validate_model_group_access_mode(access_mode: Optional[str]): + if access_mode is None: + return + if access_mode not in ACCESS_MODES: + raise ValueError(f"access_mode can must be in {ACCESS_MODES} or None") + + +def _validate_model_group_backend_roles(backend_roles: Optional[List[str]]): + if not isinstance(backend_roles, (NoneType, list)): + raise ValueError("backend_roles should either be None or a list of roles names") + + +def _validate_model_group_add_all_backend_roles(add_all_backend_roles: Optional[bool]): + if not isinstance(add_all_backend_roles, (NoneType, bool)): + raise ValueError("add_all_backend_roles should be a boolean") + + +def _validate_model_group_query(query: dict, operation: Optional[str] = None): + if not isinstance(query, dict): + raise ValueError("query needs to be a dictionary") + + if operation and not isinstance(operation, str): + raise ValueError("operation needs to be a string") + + +def validate_create_model_group_parameters( + name: str, + description: Optional[str] = None, + access_mode: Optional[str] = "private", + backend_roles: Optional[List[str]] = None, + add_all_backend_roles: Optional[bool] = False, +): + _validate_model_group_name(name) + _validate_model_group_description(description) + _validate_model_group_access_mode(access_mode) + _validate_model_group_backend_roles(backend_roles) + _validate_model_group_add_all_backend_roles(add_all_backend_roles) + + if access_mode == "restricted": + if not backend_roles and not add_all_backend_roles: + raise ValueError( + "You must specify either backend_roles or add_all_backend_roles=True for restricted access_mode" + ) + + if backend_roles and add_all_backend_roles: + raise ValueError( + "You cannot specify both backend_roles and add_all_backend_roles=True at the same time" + ) + + elif access_mode == "private": + if backend_roles or add_all_backend_roles: + raise ValueError( + "You must not specify backend_roles or add_all_backend_roles=True for a private model group" + ) + + +def validate_update_model_group_parameters(update_query: dict, model_group_id: str): + if not isinstance(model_group_id, str): + raise ValueError("Invalid model_group_id. model_group_id needs to be a string") + + if not isinstance(update_query, dict): + raise ValueError("Invalid update_query. update_query needs to be a dictionary") + + +def validate_delete_model_group_parameters(model_group_id: str): + if not isinstance(model_group_id, str): + raise ValueError("Invalid model_group_id. 
diff --git a/tests/ml_commons/test_ml_commons_client.py b/tests/ml_commons/test_ml_commons_client.py
index 86c5af24..27cd79dc 100644
--- a/tests/ml_commons/test_ml_commons_client.py
+++ b/tests/ml_commons/test_ml_commons_client.py
@@ -300,7 +300,6 @@ def test_DEPRECATED_integration_model_train_upload_full_cycle():
     assert model_file_exists == True
     assert model_config_file_exists == True
     if model_file_exists and model_config_file_exists:
-        raised = False
         model_id = ""
         task_id = ""
         try:
@@ -308,12 +307,10 @@
                 MODEL_PATH, MODEL_CONFIG_FILE_PATH, load_model=False, isVerbose=True
             )
            print("Model_id:", model_id)
-        except:  # noqa: E722
-            raised = True
-        assert raised == False, "Raised Exception during model registration"
+        except Exception as ex:  # noqa: E722
+            assert False, f"Exception occurred when uploading model: {ex}"
 
         if model_id:
-            raised = False
             try:
                 ml_load_status = ml_client.load_model(model_id, wait_until_loaded=False)
                 task_id = ml_load_status.get("task_id")
@@ -321,21 +318,17 @@
                 ml_model_status = ml_client.get_model_info(model_id)
 
                 assert ml_model_status.get("model_state") != "DEPLOY_FAILED"
-            except:  # noqa: E722
-                raised = True
-            assert raised == False, "Raised Exception in model deployment"
+            except Exception as ex:  # noqa: E722
+                assert False, f"Exception occurred when loading model: {ex}"
 
-            raised = False
             try:
                 ml_model_status = ml_client.get_model_info(model_id)
                 assert ml_model_status.get("model_format") == "TORCH_SCRIPT"
                 assert ml_model_status.get("algorithm") == "TEXT_EMBEDDING"
-            except:  # noqa: E722
-                raised = True
-            assert raised == False, "Raised Exception in getting model info"
+            except Exception as ex:  # noqa: E722
+                assert False, f"Exception occurred when getting model info: {ex}"
 
         if task_id:
-            raised = False
             ml_task_status = None
             try:
                 ml_task_status = ml_client.get_task_info(
@@ -344,47 +337,39 @@ def test_DEPRECATED_integration_model_train_upload_full_cycle():
                 assert ml_task_status.get("task_type") == "DEPLOY_MODEL"
                 print("State:", ml_task_status.get("state"))
                 assert ml_task_status.get("state") != "FAILED"
-            except:  # noqa: E722
-                print("Model Task Status:", ml_task_status)
-                raised = True
-            assert raised == False, "Raised Exception in pulling task info"
+            except Exception as ex:  # noqa: E722
+                assert False, f"Exception occurred when getting task info: {ex}"
        # This is test is being flaky. Sometimes the test is passing and sometimes showing 500 error
        # due to memory circuit breaker.
        # Todo: We need to revisit this test.
        try:
-            raised = False
            sentences = ["First test sentence", "Second test sentence"]
            embedding_result = ml_client.generate_embedding(model_id, sentences)
            print(embedding_result)
            assert len(embedding_result.get("inference_results")) == 2
-        except:  # noqa: E722
-            raised = True
-            assert (
-                raised == False
-            ), "Raised Exception in generating sentence embedding"
+        except Exception as ex:  # noqa: E722
+            assert False, f"Exception occurred when generating embedding: {ex}"
 
        try:
            delete_task_obj = ml_client.delete_task(task_id)
            assert delete_task_obj.get("result") == "deleted"
-        except:  # noqa: E722
-            raised = True
-            assert raised == False, "Raised Exception in deleting task"
+        except Exception as ex:  # noqa: E722
+            assert False, f"Exception occurred when deleting task: {ex}"
 
        try:
            ml_client.unload_model(model_id)
            ml_model_status = ml_client.get_model_info(model_id)
            assert ml_model_status.get("model_state") != "UNDEPLOY_FAILED"
-        except:  # noqa: E722
-            raised = True
-            assert raised == False, "Raised Exception in model undeployment"
+        except Exception as ex:  # noqa: E722
+            assert (
+                False
+            ), f"Exception occurred during model undeployment: {ex}"
 
-        raised = False
        try:
            delete_model_obj = ml_client.delete_model(model_id)
            assert delete_model_obj.get("result") == "deleted"
-        except:  # noqa: E722
-            raised = True
-            assert raised == False, "Raised Exception in deleting model"
+        except Exception as ex:  # noqa: E722
+            assert False, f"Exception occurred when deleting model: {ex}"
 
 
 def test_integration_model_train_register_full_cycle():
@@ -408,7 +393,6 @@ def test_integration_model_train_register_full_cycle():
     task_id = ""
 
     # Testing deploy_model = True for codecov/patch
-    raised = False
     try:
         ml_client.register_model(
             model_path=MODEL_PATH,
@@ -416,11 +400,9 @@
             deploy_model=True,
             isVerbose=True,
         )
-    except:  # noqa: E722
-        raised = True
-    assert raised == False, "Raised Exception during first model registration"
+    except Exception as ex:  # noqa: E722
+        assert False, f"Exception occurred during first model registration: {ex}"
 
-    raised = False
     try:
         model_id = ml_client.register_model(
             model_path=MODEL_PATH,
@@ -429,12 +411,10 @@
             deploy_model=False,
             isVerbose=True,
         )
         print("Model_id:", model_id)
-    except:  # noqa: E722
-        raised = True
-    assert raised == False, "Raised Exception during second model registration"
+    except Exception as ex:  # noqa: E722
+        assert False, f"Exception occurred during second model registration: {ex}"
     if model_id:
-        raised = False
         try:
             ml_load_status = ml_client.deploy_model(
                 model_id, wait_until_deployed=False
@@ -444,21 +424,17 @@
             ml_model_status = ml_client.get_model_info(model_id)
 
             assert ml_model_status.get("model_state") != "DEPLOY_FAILED"
-        except:  # noqa: E722
-            raised = True
-        assert raised == False, "Raised Exception in model deployment"
+        except Exception as ex:  # noqa: E722
+            assert False, f"Exception occurred during model deployment: {ex}"
 
-        raised = False
         try:
             ml_model_status = ml_client.get_model_info(model_id)
             assert ml_model_status.get("model_format") == "TORCH_SCRIPT"
             assert ml_model_status.get("algorithm") == "TEXT_EMBEDDING"
-        except:  # noqa: E722
-            raised = True
-        assert raised == False, "Raised Exception in getting model info"
+        except Exception as ex:  # noqa: E722
+            assert False, f"Exception occurred when getting model info: {ex}"
 
     if task_id:
-        raised = False
         ml_task_status = None
         try:
            ml_task_status = ml_client.get_task_info(
@@ -467,48 +443,40 @@ def test_integration_model_train_register_full_cycle():
             assert ml_task_status.get("task_type") == "DEPLOY_MODEL"
             print("State:", ml_task_status.get("state"))
             assert ml_task_status.get("state") != "FAILED"
-        except:  # noqa: E722
-            print("Model Task Status:", ml_task_status)
-            raised = True
-        assert raised == False, "Raised Exception in pulling task info"
+        except Exception as ex:  # noqa: E722
+            assert False, f"Exception occurred when pulling task info: {ex}"
 
    # This is test is being flaky. Sometimes the test is passing and sometimes showing 500 error
    # due to memory circuit breaker.
    # Todo: We need to revisit this test.
    try:
-        raised = False
        sentences = ["First test sentence", "Second test sentence"]
        embedding_result = ml_client.generate_embedding(model_id, sentences)
        print(embedding_result)
        assert len(embedding_result.get("inference_results")) == 2
-    except:  # noqa: E722
-        raised = True
-        assert (
-            raised == False
-        ), "Raised Exception in generating sentence embedding"
+    except Exception as ex:  # noqa: E722
+        assert (
+            False
+        ), f"Exception occurred when generating sentence embedding: {ex}"
 
    try:
        delete_task_obj = ml_client.delete_task(task_id)
        assert delete_task_obj.get("result") == "deleted"
-    except:  # noqa: E722
-        raised = True
-        assert raised == False, "Raised Exception in deleting task"
+    except Exception as ex:  # noqa: E722
+        assert False, f"Exception occurred when deleting task: {ex}"
 
    try:
        ml_client.undeploy_model(model_id)
        ml_model_status = ml_client.get_model_info(model_id)
        assert ml_model_status.get("model_state") != "UNDEPLOY_FAILED"
-    except:  # noqa: E722
-        raised = True
-        assert raised == False, "Raised Exception in model undeployment"
+    except Exception as ex:  # noqa: E722
+        assert False, f"Exception occurred during model undeployment: {ex}"
 
-    raised = False
    try:
        delete_model_obj = ml_client.delete_model(model_id)
        assert delete_model_obj.get("result") == "deleted"
-    except:  # noqa: E722
-        raised = True
-        assert raised == False, "Raised Exception in deleting model"
+    except Exception as ex:  # noqa: E722
+        assert False, f"Exception occurred during model deletion: {ex}"
 
 
 def test_search():
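The `assert False, f"..."` pattern used in these handlers works, but pytest.fail expresses the same failure without an always-false assert that linters tend to flag; a sketch of the equivalent handler, shown here only for comparison and not part of this change:

    import pytest

    try:
        delete_model_obj = ml_client.delete_model(model_id)
        assert delete_model_obj.get("result") == "deleted"
    except Exception as ex:
        # Fails the test with the same message, without an unconditional assert.
        pytest.fail(f"Exception occurred when deleting model: {ex}")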
diff --git a/tests/ml_commons/test_model_access_control.py b/tests/ml_commons/test_model_access_control.py
new file mode 100644
index 00000000..8666baac
--- /dev/null
+++ b/tests/ml_commons/test_model_access_control.py
@@ -0,0 +1,252 @@
+# SPDX-License-Identifier: Apache-2.0
+# The OpenSearch Contributors require contributions made to
+# this file be licensed under the Apache-2.0 license or a
+# compatible open source license.
+# Any modifications Copyright OpenSearch Contributors. See
+# GitHub history for details.
+
+import os
+import time
+from unittest.mock import patch
+
+import pytest
+from opensearchpy.exceptions import NotFoundError, RequestError
+from packaging.version import parse as parse_version
+
+from opensearch_py_ml.ml_commons.model_access_control import ModelAccessControl
+from tests import OPENSEARCH_TEST_CLIENT
+
+OPENSEARCH_VERSION = parse_version(os.environ.get("OPENSEARCH_VERSION", "2.11.0"))
+
+# MAC = Model Access Control
+# Minimum OpenSearch version that supports model access control.
+MAC_MIN_VERSION = parse_version("2.8.0")
+
+# Minimum OpenSearch version that supports model group updates.
+MAC_UPDATE_MIN_VERSION = parse_version("2.11.0")
+
+
+@pytest.fixture
+def client():
+    return ModelAccessControl(OPENSEARCH_TEST_CLIENT)
+
+
+def _safe_delete_model_group(client, model_group_name):
+    try:
+        client.delete_model_group_by_name(model_group_name=model_group_name)
+    except NotFoundError:
+        pass
+
+
+@pytest.fixture
+def test_model_group(client):
+    model_group_name = "__test__model_group_1"
+    _safe_delete_model_group(client, model_group_name)
+    time.sleep(2)
+    client.register_model_group(
+        name=model_group_name,
+        description="test model group for opensearch-py-ml test cases",
+    )
+    yield model_group_name
+
+    _safe_delete_model_group(client, model_group_name)
+
+
+@pytest.mark.skipif(
+    OPENSEARCH_VERSION < MAC_MIN_VERSION,
+    reason="Model groups are supported in OpenSearch 2.8.0 and above",
+)
+def test_register_model_group(client):
+    model_group_name1 = "__test__model_group_A"
+    try:
+        _safe_delete_model_group(client, model_group_name1)
+        time.sleep(2)
+        res = client.register_model_group(name=model_group_name1)
+        assert isinstance(res, dict)
+        assert "model_group_id" in res
+        assert "status" in res
+        assert res["status"] == "CREATED"
+    except Exception as ex:
+        assert False, f"Failed to register model group due to {ex}"
+
+    model_group_name2 = "__test__model_group_B"
+
+    try:
+        _safe_delete_model_group(client, model_group_name2)
+        time.sleep(2)
+        res = client.register_model_group(
+            name=model_group_name2,
+            description="test",
+            access_mode="restricted",
+            backend_roles=["admin"],
+        )
+        assert "model_group_id" in res
+        assert "status" in res
+        assert res["status"] == "CREATED"
+    except Exception as ex:
+        assert False, f"Failed to register restricted model group due to {ex}"
+
+    model_group_name3 = "__test__model_group_C"
+    with pytest.raises(RequestError) as exec_info:
+        _safe_delete_model_group(client, model_group_name3)
+        time.sleep(2)
+        res = client.register_model_group(
+            name=model_group_name3,
+            description="test",
+            access_mode="restricted",
+            add_all_backend_roles=True,
+        )
+    assert exec_info.value.status_code == 400
+    assert exec_info.match("Admin users cannot add all backend roles to a model group")
+
+    with pytest.raises(RequestError) as exec_info:
+        client.register_model_group(name=model_group_name2)
+    assert exec_info.value.status_code == 400
+    assert exec_info.match(
+        "The name you provided is already being used by a model group"
+    )
+
+
+@pytest.mark.skipif(
+    OPENSEARCH_VERSION < MAC_MIN_VERSION,
+    reason="Model groups are supported in OpenSearch 2.8.0 and above",
+)
+def test_get_model_group_id_by_name(client, test_model_group):
+    model_group_id = client.get_model_group_id_by_name(test_model_group)
+    assert model_group_id is not None
+
+    model_group_id = client.get_model_group_id_by_name("test-unknown")
+    assert model_group_id is None
+
+    # Mock NotFoundError as it only happens when the index isn't created
+    with patch.object(client, "search_model_group_by_name", side_effect=NotFoundError):
+        model_group_id = client.get_model_group_id_by_name(test_model_group)
+        assert model_group_id is None
+
+
+@pytest.mark.skipif(
+    OPENSEARCH_VERSION < MAC_UPDATE_MIN_VERSION,
+    reason="Model group updates are supported in OpenSearch 2.11.0 and above",
+)
+def test_update_model_group(client, test_model_group):
+    # update the model group description
+    update_query = {
+        "description": "updated description",
+    }
+    try:
+        model_group_id = client.get_model_group_id_by_name(test_model_group)
+        if model_group_id is None:
+            raise Exception(f"No model group found with the name: {test_model_group}")
+        res = client.update_model_group(update_query, model_group_id=model_group_id)
+        assert isinstance(res, dict)
+        assert "status" in res
+        assert res["status"] == "Updated"
+    except Exception as ex:
+        assert False, f"Failed to update model group due to unhandled error: {ex}"
+
+
+@pytest.mark.skipif(
+    OPENSEARCH_VERSION < MAC_MIN_VERSION,
+    reason="Model groups are supported in OpenSearch 2.8.0 and above",
+)
+def test_search_model_group(client, test_model_group):
+    query1 = {"query": {"match": {"name": test_model_group}}, "size": 1}
+    try:
+        res = client.search_model_group(query1)
+        assert isinstance(res, dict)
+        assert "hits" in res and "hits" in res["hits"]
+        assert len(res["hits"]["hits"]) == 1
+        assert "_source" in res["hits"]["hits"][0]
+        assert "name" in res["hits"]["hits"][0]["_source"]
+        assert test_model_group == res["hits"]["hits"][0]["_source"]["name"]
+    except Exception as ex:
+        assert False, f"Failed to search model group due to unhandled error: {ex}"
+
+    query2 = {"query": {"match": {"name": "test-unknown"}}, "size": 1}
+    try:
+        res = client.search_model_group(query2)
+        assert isinstance(res, dict)
+        assert "hits" in res and "hits" in res["hits"]
+        assert len(res["hits"]["hits"]) == 0
+    except Exception as ex:
+        assert False, f"Failed to search model group due to unhandled error: {ex}"
+
+
+@pytest.mark.skipif(
+    OPENSEARCH_VERSION < MAC_MIN_VERSION,
+    reason="Model groups are supported in OpenSearch 2.8.0 and above",
+)
+def test_search_model_group_by_name(client, test_model_group):
+    try:
+        res = client.search_model_group_by_name(model_group_name=test_model_group)
+        assert isinstance(res, dict)
+        assert "hits" in res and "hits" in res["hits"]
+        assert len(res["hits"]["hits"]) == 1
+        assert "_source" in res["hits"]["hits"][0]
+        assert len(res["hits"]["hits"][0]["_source"]) > 1
+        assert "name" in res["hits"]["hits"][0]["_source"]
+        assert test_model_group == res["hits"]["hits"][0]["_source"]["name"]
+    except Exception as ex:
+        assert False, f"Failed to search model group due to unhandled error: {ex}"
+
+    try:
+        res = client.search_model_group_by_name(
+            model_group_name=test_model_group, _source="name"
+        )
+        assert isinstance(res, dict)
+        assert "hits" in res and "hits" in res["hits"]
+        assert len(res["hits"]["hits"]) == 1
+        assert "_source" in res["hits"]["hits"][0]
+        assert len(res["hits"]["hits"][0]["_source"]) == 1
+        assert "name" in res["hits"]["hits"][0]["_source"]
+    except Exception as ex:
+        assert False, f"Failed to search model group due to unhandled error: {ex}"
+
+    try:
+        res = client.search_model_group_by_name(model_group_name="test-unknown")
+        assert isinstance(res, dict)
+        assert "hits" in res and "hits" in res["hits"]
+        assert len(res["hits"]["hits"]) == 0
+    except Exception as ex:
+        assert False, f"Failed to search model group due to unhandled error: {ex}"
"result" in res + assert res["result"] in ["not_found", "deleted"] + + res = client.delete_model_group(model_group_id="test-unknown") + assert isinstance(res, dict) + assert "result" in res + assert res["result"] == "not_found" + + +@pytest.mark.skipif( + OPENSEARCH_VERSION < MAC_MIN_VERSION, + reason="Model groups are supported in OpenSearch 2.8.0 and above", +) +def test_delete_model_group_by_name(client): + with pytest.raises(NotFoundError): + client.delete_model_group_by_name(model_group_name="test-unknown") + + model_group = "__test__model_group_5" + client.register_model_group(name=model_group) + time.sleep(2) + res = client.delete_model_group_by_name(model_group_name=model_group) + assert isinstance(res, dict) + assert "result" in res + assert res["result"] == "deleted" diff --git a/tests/ml_commons/test_validators/test_model_access_control_validators.py b/tests/ml_commons/test_validators/test_model_access_control_validators.py new file mode 100644 index 00000000..d5701e70 --- /dev/null +++ b/tests/ml_commons/test_validators/test_model_access_control_validators.py @@ -0,0 +1,148 @@ +# SPDX-License-Identifier: Apache-2.0 +# The OpenSearch Contributors require contributions made to +# this file be licensed under the Apache-2.0 license or a +# compatible open source license. +# Any modifications Copyright OpenSearch Contributors. See +# GitHub history for details. + +import pytest + +from opensearch_py_ml.ml_commons.validators.model_access_control import ( + _validate_model_group_access_mode, + _validate_model_group_add_all_backend_roles, + _validate_model_group_backend_roles, + _validate_model_group_description, + _validate_model_group_name, + _validate_model_group_query, + validate_create_model_group_parameters, + validate_delete_model_group_parameters, + validate_search_model_group_parameters, + validate_update_model_group_parameters, +) + + +def test_validate_model_group_name(): + with pytest.raises(ValueError): + _validate_model_group_name(None) + + with pytest.raises(ValueError): + _validate_model_group_name("") + + with pytest.raises(ValueError): + _validate_model_group_name(123) + + res = _validate_model_group_name("ValidName") + assert res is None + + +def test_validate_model_group_description(): + with pytest.raises(ValueError): + _validate_model_group_description(123) + + res = _validate_model_group_description("") + assert res is None + + res = _validate_model_group_description(None) + assert res is None + + res = _validate_model_group_description("ValidName") + assert res is None + + +def test_validate_model_group_access_mode(): + with pytest.raises(ValueError): + _validate_model_group_access_mode(123) + + res = _validate_model_group_access_mode("private") + assert res is None + + res = _validate_model_group_access_mode("restricted") + assert res is None + + res = _validate_model_group_access_mode(None) + assert res is None + + +def test_validate_model_group_backend_roles(): + with pytest.raises(ValueError): + _validate_model_group_backend_roles(123) + + res = _validate_model_group_backend_roles(["admin"]) + assert res is None + + res = _validate_model_group_backend_roles(None) + assert res is None + + +def test_validate_model_group_add_all_backend_roles(): + with pytest.raises(ValueError): + _validate_model_group_add_all_backend_roles(123) + + res = _validate_model_group_add_all_backend_roles(False) + assert res is None + + res = _validate_model_group_add_all_backend_roles(True) + assert res is None + + res = _validate_model_group_add_all_backend_roles(None) + assert res 
diff --git a/tests/ml_commons/test_validators/test_model_access_control_validators.py b/tests/ml_commons/test_validators/test_model_access_control_validators.py
new file mode 100644
index 00000000..d5701e70
--- /dev/null
+++ b/tests/ml_commons/test_validators/test_model_access_control_validators.py
@@ -0,0 +1,148 @@
+# SPDX-License-Identifier: Apache-2.0
+# The OpenSearch Contributors require contributions made to
+# this file be licensed under the Apache-2.0 license or a
+# compatible open source license.
+# Any modifications Copyright OpenSearch Contributors. See
+# GitHub history for details.
+
+import pytest
+
+from opensearch_py_ml.ml_commons.validators.model_access_control import (
+    _validate_model_group_access_mode,
+    _validate_model_group_add_all_backend_roles,
+    _validate_model_group_backend_roles,
+    _validate_model_group_description,
+    _validate_model_group_name,
+    _validate_model_group_query,
+    validate_create_model_group_parameters,
+    validate_delete_model_group_parameters,
+    validate_search_model_group_parameters,
+    validate_update_model_group_parameters,
+)
+
+
+def test_validate_model_group_name():
+    with pytest.raises(ValueError):
+        _validate_model_group_name(None)
+
+    with pytest.raises(ValueError):
+        _validate_model_group_name("")
+
+    with pytest.raises(ValueError):
+        _validate_model_group_name(123)
+
+    res = _validate_model_group_name("ValidName")
+    assert res is None
+
+
+def test_validate_model_group_description():
+    with pytest.raises(ValueError):
+        _validate_model_group_description(123)
+
+    res = _validate_model_group_description("")
+    assert res is None
+
+    res = _validate_model_group_description(None)
+    assert res is None
+
+    res = _validate_model_group_description("ValidName")
+    assert res is None
+
+
+def test_validate_model_group_access_mode():
+    with pytest.raises(ValueError):
+        _validate_model_group_access_mode(123)
+
+    res = _validate_model_group_access_mode("private")
+    assert res is None
+
+    res = _validate_model_group_access_mode("restricted")
+    assert res is None
+
+    res = _validate_model_group_access_mode(None)
+    assert res is None
+
+
+def test_validate_model_group_backend_roles():
+    with pytest.raises(ValueError):
+        _validate_model_group_backend_roles(123)
+
+    res = _validate_model_group_backend_roles(["admin"])
+    assert res is None
+
+    res = _validate_model_group_backend_roles(None)
+    assert res is None
+
+
+def test_validate_model_group_add_all_backend_roles():
+    with pytest.raises(ValueError):
+        _validate_model_group_add_all_backend_roles(123)
+
+    res = _validate_model_group_add_all_backend_roles(False)
+    assert res is None
+
+    res = _validate_model_group_add_all_backend_roles(True)
+    assert res is None
+
+    res = _validate_model_group_add_all_backend_roles(None)
+    assert res is None
+
+
+def test_validate_model_group_query():
+    with pytest.raises(ValueError):
+        _validate_model_group_query(123)
+
+    res = _validate_model_group_query({})
+    assert res is None
+
+    with pytest.raises(ValueError):
+        _validate_model_group_query(None)
+
+    res = _validate_model_group_query({"query": {"match": {"name": "test"}}})
+    assert res is None
+
+    with pytest.raises(ValueError):
+        _validate_model_group_query({}, 123)
+
+
+def test_validate_create_model_group_parameters():
+    with pytest.raises(ValueError):
+        validate_create_model_group_parameters(123)
+
+    res = validate_create_model_group_parameters("test")
+    assert res is None
+
+    with pytest.raises(ValueError):
+        validate_create_model_group_parameters("test", access_mode="restricted")
+
+    with pytest.raises(ValueError):
+        validate_create_model_group_parameters(
+            "test", access_mode="private", add_all_backend_roles=True
+        )
+
+
+def test_validate_update_model_group_parameters():
+    with pytest.raises(ValueError):
+        validate_update_model_group_parameters(123, 123)
+
+    with pytest.raises(ValueError):
+        validate_update_model_group_parameters(123, "id")
+
+    res = validate_update_model_group_parameters({"query": {}}, "test")
+    assert res is None
+
+
+def test_validate_delete_model_group_parameters():
+    with pytest.raises(ValueError):
+        validate_delete_model_group_parameters(123)
+
+    res = validate_delete_model_group_parameters("test")
+    assert res is None
+
+
+def test_validate_search_model_group_parameters():
+    with pytest.raises(ValueError):
+        validate_search_model_group_parameters(123)
+
+    res = validate_search_model_group_parameters({"query": {}})
+    assert res is None