-
Notifications
You must be signed in to change notification settings - Fork 62
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#289 - Add register, update and delete model group functionality to s…
…upport Model Access Control (#332) * init Signed-off-by: kalyan <kalyan.ben10@live.com> * add search, delete and update Signed-off-by: kalyan <kalyan.ben10@live.com> * add tests for register model group Signed-off-by: kalyan <kalyan.ben10@live.com> * update cluster to 2.11 Signed-off-by: kalyan <kalyan.ben10@live.com> * test skipif Signed-off-by: kalyan <kalyan.ben10@live.com> * fix Signed-off-by: kalyan <kalyan.ben10@live.com> * add tests Signed-off-by: kalyan <kalyan.ben10@live.com> * update matrix Signed-off-by: kalyan <kalyan.ben10@live.com> * cancel in progress Signed-off-by: kalyan <kalyan.ben10@live.com> * update concurrency Signed-off-by: kalyan <kalyan.ben10@live.com> * job level concurrency Signed-off-by: kalyan <kalyan.ben10@live.com> * fix Signed-off-by: kalyan <kalyan.ben10@live.com> * fix Signed-off-by: kalyan <kalyan.ben10@live.com> * fix tests Signed-off-by: kalyan <kalyan.ben10@live.com> * tests passing Signed-off-by: kalyan <kalyan.ben10@live.com> * isort fix Signed-off-by: kalyan <kalyan.ben10@live.com> * fix action Signed-off-by: kalyan <kalyan.ben10@live.com> * fix action Signed-off-by: kalyan <kalyan.ben10@live.com> * fix action Signed-off-by: kalyan <kalyan.ben10@live.com> * fix Signed-off-by: kalyan <kalyan.ben10@live.com> * update changelog Signed-off-by: kalyan <kalyan.ben10@live.com> * fix os dockerfile Signed-off-by: kalyan <kalyan.ben10@live.com> * test Signed-off-by: kalyanr <kalyan.ben10@live.com> * pass opensearch version Signed-off-by: kalyanr <kalyan.ben10@live.com> * fix Signed-off-by: kalyanr <kalyan.ben10@live.com> * fix Signed-off-by: kalyanr <kalyan.ben10@live.com> * fix Signed-off-by: kalyan <kalyan.ben10@live.com> * update OS dockerfile Signed-off-by: kalyan <kalyan.ben10@live.com> * fix failing tests Signed-off-by: kalyan <kalyan.ben10@live.com> * update dockerfile for 2.11.0 Signed-off-by: kalyan <kalyan.ben10@live.com> * remove disable warning Signed-off-by: kalyan <kalyan.ben10@live.com> * fix upload model Signed-off-by: kalyan <kalyan.ben10@live.com> * fix lint Signed-off-by: kalyan <kalyan.ben10@live.com> * fix lint Signed-off-by: kalyan <kalyan.ben10@live.com> * include reference Signed-off-by: kalyan <kalyan.ben10@live.com> * pr fixes Signed-off-by: kalyan <kalyan.ben10@live.com> * lint fix Signed-off-by: kalyan <kalyan.ben10@live.com> * fix lint Signed-off-by: kalyan <kalyan.ben10@live.com> * fix tests Signed-off-by: kalyan <kalyan.ben10@live.com> * skip Signed-off-by: kalyan <kalyan.ben10@live.com> * fix lint Signed-off-by: kalyan <kalyan.ben10@live.com> * fix lint and increase coverage Signed-off-by: kalyan <kalyan.ben10@live.com> * fix lint Signed-off-by: kalyan <kalyan.ben10@live.com> * fix Signed-off-by: kalyan <kalyan.ben10@live.com> * feedback fixes Signed-off-by: kalyan <kalyan.ben10@live.com> * fix Signed-off-by: kalyan <kalyan.ben10@live.com> * lint fix Signed-off-by: kalyan <kalyan.ben10@live.com> * fix test cases Signed-off-by: kalyan <kalyan.ben10@live.com> * pr feedback fixes Signed-off-by: kalyanr <kalyan.ben10@live.com> * revert Signed-off-by: kalyanr <kalyan.ben10@live.com> --------- Signed-off-by: kalyan <kalyan.ben10@live.com> Signed-off-by: kalyanr <kalyan.ben10@live.com>
- Loading branch information
Showing
11 changed files
with
663 additions
and
72 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,18 @@ | ||
ARG OPENSEARCH_VERSION | ||
ARG OPENSEARCH_VERSION=latest | ||
FROM opensearchproject/opensearch:$OPENSEARCH_VERSION | ||
|
||
# OPENSEARCH_VERSION needs to be redefined as any arg before FROM is outside build scope. | ||
# Reference: https://docs.docker.com/engine/reference/builder/#understand-how-arg-and-from-interact | ||
ARG OPENSEARCH_VERSION=latest | ||
ARG opensearch_path=/usr/share/opensearch | ||
ARG opensearch_yml=$opensearch_path/config/opensearch.yml | ||
|
||
ARG SECURE_INTEGRATION | ||
RUN echo "plugins.ml_commons.only_run_on_ml_node: false" >> $opensearch_yml; | ||
RUN echo "plugins.ml_commons.native_memory_threshold: 100" >> $opensearch_yml; | ||
RUN if [ "$OPENSEARCH_VERSION" == "2.11.0" ] ; then \ | ||
echo "plugins.ml_commons.model_access_control_enabled: true" >> $opensearch_yml; \ | ||
echo "plugins.ml_commons.allow_registering_model_via_local_file: true" >> $opensearch_yml; \ | ||
echo "plugins.ml_commons.allow_registering_model_via_url: true" >> $opensearch_yml; \ | ||
fi | ||
RUN if [ "$SECURE_INTEGRATION" != "true" ] ; then echo "plugins.security.disabled: true" >> $opensearch_yml; fi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# The OpenSearch Contributors require contributions made to | ||
# this file be licensed under the Apache-2.0 license or a | ||
# compatible open source license. | ||
# Any modifications Copyright OpenSearch Contributors. See | ||
# GitHub history for details. | ||
|
||
from typing import List, Optional | ||
|
||
from opensearchpy import OpenSearch | ||
from opensearchpy.exceptions import NotFoundError | ||
|
||
from opensearch_py_ml.ml_commons.ml_common_utils import ML_BASE_URI | ||
from opensearch_py_ml.ml_commons.validators.model_access_control import ( | ||
validate_create_model_group_parameters, | ||
validate_delete_model_group_parameters, | ||
validate_search_model_group_parameters, | ||
validate_update_model_group_parameters, | ||
) | ||
|
||
|
||
class ModelAccessControl: | ||
API_ENDPOINT = "model_groups" | ||
|
||
def __init__(self, os_client: OpenSearch): | ||
self.client = os_client | ||
|
||
def register_model_group( | ||
self, | ||
name: str, | ||
description: Optional[str] = None, | ||
access_mode: Optional[str] = "private", | ||
backend_roles: Optional[List[str]] = None, | ||
add_all_backend_roles: Optional[bool] = False, | ||
): | ||
validate_create_model_group_parameters( | ||
name, description, access_mode, backend_roles, add_all_backend_roles | ||
) | ||
|
||
body = {"name": name, "add_all_backend_roles": add_all_backend_roles} | ||
if description: | ||
body["description"] = description | ||
if access_mode: | ||
body["access_mode"] = access_mode | ||
if backend_roles: | ||
body["backend_roles"] = backend_roles | ||
|
||
return self.client.transport.perform_request( | ||
method="POST", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}/_register", body=body | ||
) | ||
|
||
def update_model_group( | ||
self, | ||
update_query: dict, | ||
model_group_id: Optional[str] = None, | ||
): | ||
validate_update_model_group_parameters(update_query, model_group_id) | ||
return self.client.transport.perform_request( | ||
method="PUT", | ||
url=f"{ML_BASE_URI}/{self.API_ENDPOINT}/{model_group_id}", | ||
body=update_query, | ||
) | ||
|
||
def search_model_group(self, query: dict): | ||
validate_search_model_group_parameters(query) | ||
return self.client.transport.perform_request( | ||
method="GET", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}/_search", body=query | ||
) | ||
|
||
def search_model_group_by_name( | ||
self, | ||
model_group_name: str, | ||
_source: Optional[List] = None, | ||
size: Optional[int] = 1, | ||
): | ||
query = {"query": {"match": {"name": model_group_name}}, "size": size} | ||
if _source: | ||
query["_source"] = _source | ||
return self.search_model_group(query) | ||
|
||
def get_model_group_id_by_name(self, model_group_name: str): | ||
try: | ||
res = self.search_model_group_by_name(model_group_name) | ||
if res["hits"]["hits"]: | ||
return res["hits"]["hits"][0]["_id"] | ||
else: | ||
raise NotFoundError | ||
except NotFoundError: | ||
print(f"No model group found with name:{model_group_name}") | ||
return None | ||
except Exception as ex: | ||
print(f"Error in get_model_group_id_by_name: {ex}") | ||
return None | ||
|
||
def delete_model_group(self, model_group_id: str): | ||
validate_delete_model_group_parameters(model_group_id) | ||
return self.client.transport.perform_request( | ||
method="DELETE", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}/{model_group_id}" | ||
) | ||
|
||
def delete_model_group_by_name(self, model_group_name: str): | ||
model_group_id = self.get_model_group_id_by_name(model_group_name) | ||
if model_group_id is None: | ||
raise NotFoundError(f"Model group {model_group_name} not found") | ||
return self.delete_model_group(model_group_id=model_group_id) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# The OpenSearch Contributors require contributions made to | ||
# this file be licensed under the Apache-2.0 license or a | ||
# compatible open source license. | ||
# Any modifications Copyright OpenSearch Contributors. See | ||
# GitHub history for details. |
97 changes: 97 additions & 0 deletions
97
opensearch_py_ml/ml_commons/validators/model_access_control.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# The OpenSearch Contributors require contributions made to | ||
# this file be licensed under the Apache-2.0 license or a | ||
# compatible open source license. | ||
# Any modifications Copyright OpenSearch Contributors. See | ||
# GitHub history for details. | ||
|
||
""" Module for validating model access control parameters """ | ||
|
||
from typing import List, Optional | ||
|
||
ACCESS_MODES = ["public", "private", "restricted"] | ||
|
||
NoneType = type(None) | ||
|
||
|
||
def _validate_model_group_name(name: str): | ||
if not name or not isinstance(name, str): | ||
raise ValueError("name is required and needs to be a string") | ||
|
||
|
||
def _validate_model_group_description(description: Optional[str]): | ||
if not isinstance(description, (NoneType, str)): | ||
raise ValueError("description needs to be a string") | ||
|
||
|
||
def _validate_model_group_access_mode(access_mode: Optional[str]): | ||
if access_mode is None: | ||
return | ||
if access_mode not in ACCESS_MODES: | ||
raise ValueError(f"access_mode can must be in {ACCESS_MODES} or None") | ||
|
||
|
||
def _validate_model_group_backend_roles(backend_roles: Optional[List[str]]): | ||
if not isinstance(backend_roles, (NoneType, list)): | ||
raise ValueError("backend_roles should either be None or a list of roles names") | ||
|
||
|
||
def _validate_model_group_add_all_backend_roles(add_all_backend_roles: Optional[bool]): | ||
if not isinstance(add_all_backend_roles, (NoneType, bool)): | ||
raise ValueError("add_all_backend_roles should be a boolean") | ||
|
||
|
||
def _validate_model_group_query(query: dict, operation: Optional[str] = None): | ||
if not isinstance(query, dict): | ||
raise ValueError("query needs to be a dictionary") | ||
|
||
if operation and not isinstance(operation, str): | ||
raise ValueError("operation needs to be a string") | ||
|
||
|
||
def validate_create_model_group_parameters( | ||
name: str, | ||
description: Optional[str] = None, | ||
access_mode: Optional[str] = "private", | ||
backend_roles: Optional[List[str]] = None, | ||
add_all_backend_roles: Optional[bool] = False, | ||
): | ||
_validate_model_group_name(name) | ||
_validate_model_group_description(description) | ||
_validate_model_group_access_mode(access_mode) | ||
_validate_model_group_backend_roles(backend_roles) | ||
_validate_model_group_add_all_backend_roles(add_all_backend_roles) | ||
|
||
if access_mode == "restricted": | ||
if not backend_roles and not add_all_backend_roles: | ||
raise ValueError( | ||
"You must specify either backend_roles or add_all_backend_roles=True for restricted access_mode" | ||
) | ||
|
||
if backend_roles and add_all_backend_roles: | ||
raise ValueError( | ||
"You cannot specify both backend_roles and add_all_backend_roles=True at the same time" | ||
) | ||
|
||
elif access_mode == "private": | ||
if backend_roles or add_all_backend_roles: | ||
raise ValueError( | ||
"You must not specify backend_roles or add_all_backend_roles=True for a private model group" | ||
) | ||
|
||
|
||
def validate_update_model_group_parameters(update_query: dict, model_group_id: str): | ||
if not isinstance(model_group_id, str): | ||
raise ValueError("Invalid model_group_id. model_group_id needs to be a string") | ||
|
||
if not isinstance(update_query, dict): | ||
raise ValueError("Invalid update_query. update_query needs to be a dictionary") | ||
|
||
|
||
def validate_delete_model_group_parameters(model_group_id: str): | ||
if not isinstance(model_group_id, str): | ||
raise ValueError("Invalid model_group_id. model_group_id needs to be a string") | ||
|
||
|
||
def validate_search_model_group_parameters(query: dict): | ||
_validate_model_group_query(query) |
Oops, something went wrong.