Skip to content

Commit

Permalink
#289 - Add register, update and delete model group functionality to s…
Browse files Browse the repository at this point in the history
…upport Model Access Control (#332)

* init

Signed-off-by: kalyan <kalyan.ben10@live.com>

* add search, delete and update

Signed-off-by: kalyan <kalyan.ben10@live.com>

* add tests for register model group

Signed-off-by: kalyan <kalyan.ben10@live.com>

* update cluster to 2.11

Signed-off-by: kalyan <kalyan.ben10@live.com>

* test skipif

Signed-off-by: kalyan <kalyan.ben10@live.com>

* fix

Signed-off-by: kalyan <kalyan.ben10@live.com>

* add tests

Signed-off-by: kalyan <kalyan.ben10@live.com>

* update matrix

Signed-off-by: kalyan <kalyan.ben10@live.com>

* cancel in progress

Signed-off-by: kalyan <kalyan.ben10@live.com>

* update concurrency

Signed-off-by: kalyan <kalyan.ben10@live.com>

* job level concurrency

Signed-off-by: kalyan <kalyan.ben10@live.com>

* fix

Signed-off-by: kalyan <kalyan.ben10@live.com>

* fix

Signed-off-by: kalyan <kalyan.ben10@live.com>

* fix tests

Signed-off-by: kalyan <kalyan.ben10@live.com>

* tests passing

Signed-off-by: kalyan <kalyan.ben10@live.com>

* isort fix

Signed-off-by: kalyan <kalyan.ben10@live.com>

* fix action

Signed-off-by: kalyan <kalyan.ben10@live.com>

* fix action

Signed-off-by: kalyan <kalyan.ben10@live.com>

* fix action

Signed-off-by: kalyan <kalyan.ben10@live.com>

* fix

Signed-off-by: kalyan <kalyan.ben10@live.com>

* update changelog

Signed-off-by: kalyan <kalyan.ben10@live.com>

* fix os dockerfile

Signed-off-by: kalyan <kalyan.ben10@live.com>

* test

Signed-off-by: kalyanr <kalyan.ben10@live.com>

* pass opensearch version

Signed-off-by: kalyanr <kalyan.ben10@live.com>

* fix

Signed-off-by: kalyanr <kalyan.ben10@live.com>

* fix

Signed-off-by: kalyanr <kalyan.ben10@live.com>

* fix

Signed-off-by: kalyan <kalyan.ben10@live.com>

* update OS dockerfile

Signed-off-by: kalyan <kalyan.ben10@live.com>

* fix failing tests

Signed-off-by: kalyan <kalyan.ben10@live.com>

* update dockerfile for 2.11.0

Signed-off-by: kalyan <kalyan.ben10@live.com>

* remove disable warning

Signed-off-by: kalyan <kalyan.ben10@live.com>

* fix upload model

Signed-off-by: kalyan <kalyan.ben10@live.com>

* fix lint

Signed-off-by: kalyan <kalyan.ben10@live.com>

* fix lint

Signed-off-by: kalyan <kalyan.ben10@live.com>

* include reference

Signed-off-by: kalyan <kalyan.ben10@live.com>

* pr fixes

Signed-off-by: kalyan <kalyan.ben10@live.com>

* lint fix

Signed-off-by: kalyan <kalyan.ben10@live.com>

* fix lint

Signed-off-by: kalyan <kalyan.ben10@live.com>

* fix tests

Signed-off-by: kalyan <kalyan.ben10@live.com>

* skip

Signed-off-by: kalyan <kalyan.ben10@live.com>

* fix lint

Signed-off-by: kalyan <kalyan.ben10@live.com>

* fix lint and increase coverage

Signed-off-by: kalyan <kalyan.ben10@live.com>

* fix lint

Signed-off-by: kalyan <kalyan.ben10@live.com>

* fix

Signed-off-by: kalyan <kalyan.ben10@live.com>

* feedback fixes

Signed-off-by: kalyan <kalyan.ben10@live.com>

* fix

Signed-off-by: kalyan <kalyan.ben10@live.com>

* lint fix

Signed-off-by: kalyan <kalyan.ben10@live.com>

* fix test cases

Signed-off-by: kalyan <kalyan.ben10@live.com>

* pr feedback fixes

Signed-off-by: kalyanr <kalyan.ben10@live.com>

* revert

Signed-off-by: kalyanr <kalyan.ben10@live.com>

---------

Signed-off-by: kalyan <kalyan.ben10@live.com>
Signed-off-by: kalyanr <kalyan.ben10@live.com>
  • Loading branch information
rawwar committed Nov 13, 2023
1 parent fc698a7 commit 2c3b744
Show file tree
Hide file tree
Showing 11 changed files with 663 additions and 72 deletions.
10 changes: 9 additions & 1 deletion .ci/opensearch/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
ARG OPENSEARCH_VERSION
ARG OPENSEARCH_VERSION=latest
FROM opensearchproject/opensearch:$OPENSEARCH_VERSION

# OPENSEARCH_VERSION needs to be redefined as any arg before FROM is outside build scope.
# Reference: https://docs.docker.com/engine/reference/builder/#understand-how-arg-and-from-interact
ARG OPENSEARCH_VERSION=latest
ARG opensearch_path=/usr/share/opensearch
ARG opensearch_yml=$opensearch_path/config/opensearch.yml

ARG SECURE_INTEGRATION
RUN echo "plugins.ml_commons.only_run_on_ml_node: false" >> $opensearch_yml;
RUN echo "plugins.ml_commons.native_memory_threshold: 100" >> $opensearch_yml;
RUN if [ "$OPENSEARCH_VERSION" == "2.11.0" ] ; then \
echo "plugins.ml_commons.model_access_control_enabled: true" >> $opensearch_yml; \
echo "plugins.ml_commons.allow_registering_model_via_local_file: true" >> $opensearch_yml; \
echo "plugins.ml_commons.allow_registering_model_via_url: true" >> $opensearch_yml; \
fi
RUN if [ "$SECURE_INTEGRATION" != "true" ] ; then echo "plugins.security.disabled: true" >> $opensearch_yml; fi
1 change: 0 additions & 1 deletion .ci/run-opensearch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# to form a cluster suitable for running the REST API tests.
#
# Export the NUMBER_OF_NODES variable to start more than 1 node

script_path=$(dirname $(realpath -s $0))
source $script_path/imports.sh
set -euo pipefail
Expand Down
5 changes: 5 additions & 0 deletions .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ name: Integration tests

on: [push, pull_request]

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true

jobs:
integration:
name: Integ
Expand All @@ -13,6 +17,7 @@ jobs:
secured: ["true"]
entry:
- { opensearch_version: 2.7.0 }
- { opensearch_version: 2.11.0 }

steps:
- name: Checkout
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
- Add script to trigger ml-models-release jenkins workflow with generic webhook by @thanawan-atc in ([#211](https://github.com/opensearch-project/opensearch-py-ml/pull/211))
- Add example notebook for tracing and registering a CLIPTextModel to OpenSearch with the Neural Search plugin by @patrickbarnhart in ([#283](https://github.com/opensearch-project/opensearch-py-ml/pull/283))
- Add support for train api functionality by @rawwar in ([#310](https://github.com/opensearch-project/opensearch-py-ml/pull/310))
- Add support for Model Access Control - Register, Update, Search and Delete by @rawwar in ([#332](https://github.com/opensearch-project/opensearch-py-ml/pull/332))

### Changed
- Modify ml-models.JenkinsFile so that it takes model format into account and can be triggered with generic webhook by @thanawan-atc in ([#211](https://github.com/opensearch-project/opensearch-py-ml/pull/211))
Expand Down
2 changes: 2 additions & 0 deletions opensearch_py_ml/ml_commons/ml_commons_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
MODEL_VERSION_FIELD,
TIMEOUT,
)
from opensearch_py_ml.ml_commons.model_access_control import ModelAccessControl
from opensearch_py_ml.ml_commons.model_execute import ModelExecute
from opensearch_py_ml.ml_commons.model_uploader import ModelUploader

Expand All @@ -35,6 +36,7 @@ def __init__(self, os_client: OpenSearch):
self._client = os_client
self._model_uploader = ModelUploader(os_client)
self._model_execute = ModelExecute(os_client)
self.model_access_control = ModelAccessControl(os_client)

def execute(self, algorithm_name: str, input_json: dict) -> dict:
"""
Expand Down
105 changes: 105 additions & 0 deletions opensearch_py_ml/ml_commons/model_access_control.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# SPDX-License-Identifier: Apache-2.0
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
# Any modifications Copyright OpenSearch Contributors. See
# GitHub history for details.

from typing import List, Optional

from opensearchpy import OpenSearch
from opensearchpy.exceptions import NotFoundError

from opensearch_py_ml.ml_commons.ml_common_utils import ML_BASE_URI
from opensearch_py_ml.ml_commons.validators.model_access_control import (
validate_create_model_group_parameters,
validate_delete_model_group_parameters,
validate_search_model_group_parameters,
validate_update_model_group_parameters,
)


class ModelAccessControl:
API_ENDPOINT = "model_groups"

def __init__(self, os_client: OpenSearch):
self.client = os_client

def register_model_group(
self,
name: str,
description: Optional[str] = None,
access_mode: Optional[str] = "private",
backend_roles: Optional[List[str]] = None,
add_all_backend_roles: Optional[bool] = False,
):
validate_create_model_group_parameters(
name, description, access_mode, backend_roles, add_all_backend_roles
)

body = {"name": name, "add_all_backend_roles": add_all_backend_roles}
if description:
body["description"] = description
if access_mode:
body["access_mode"] = access_mode
if backend_roles:
body["backend_roles"] = backend_roles

return self.client.transport.perform_request(
method="POST", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}/_register", body=body
)

def update_model_group(
self,
update_query: dict,
model_group_id: Optional[str] = None,
):
validate_update_model_group_parameters(update_query, model_group_id)
return self.client.transport.perform_request(
method="PUT",
url=f"{ML_BASE_URI}/{self.API_ENDPOINT}/{model_group_id}",
body=update_query,
)

def search_model_group(self, query: dict):
validate_search_model_group_parameters(query)
return self.client.transport.perform_request(
method="GET", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}/_search", body=query
)

def search_model_group_by_name(
self,
model_group_name: str,
_source: Optional[List] = None,
size: Optional[int] = 1,
):
query = {"query": {"match": {"name": model_group_name}}, "size": size}
if _source:
query["_source"] = _source
return self.search_model_group(query)

def get_model_group_id_by_name(self, model_group_name: str):
try:
res = self.search_model_group_by_name(model_group_name)
if res["hits"]["hits"]:
return res["hits"]["hits"][0]["_id"]
else:
raise NotFoundError
except NotFoundError:
print(f"No model group found with name:{model_group_name}")
return None
except Exception as ex:
print(f"Error in get_model_group_id_by_name: {ex}")
return None

def delete_model_group(self, model_group_id: str):
validate_delete_model_group_parameters(model_group_id)
return self.client.transport.perform_request(
method="DELETE", url=f"{ML_BASE_URI}/{self.API_ENDPOINT}/{model_group_id}"
)

def delete_model_group_by_name(self, model_group_name: str):
model_group_id = self.get_model_group_id_by_name(model_group_name)
if model_group_id is None:
raise NotFoundError(f"Model group {model_group_name} not found")
return self.delete_model_group(model_group_id=model_group_id)
6 changes: 6 additions & 0 deletions opensearch_py_ml/ml_commons/validators/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
# Any modifications Copyright OpenSearch Contributors. See
# GitHub history for details.
97 changes: 97 additions & 0 deletions opensearch_py_ml/ml_commons/validators/model_access_control.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# SPDX-License-Identifier: Apache-2.0
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
# Any modifications Copyright OpenSearch Contributors. See
# GitHub history for details.

""" Module for validating model access control parameters """

from typing import List, Optional

ACCESS_MODES = ["public", "private", "restricted"]

NoneType = type(None)


def _validate_model_group_name(name: str):
if not name or not isinstance(name, str):
raise ValueError("name is required and needs to be a string")


def _validate_model_group_description(description: Optional[str]):
if not isinstance(description, (NoneType, str)):
raise ValueError("description needs to be a string")


def _validate_model_group_access_mode(access_mode: Optional[str]):
if access_mode is None:
return
if access_mode not in ACCESS_MODES:
raise ValueError(f"access_mode can must be in {ACCESS_MODES} or None")


def _validate_model_group_backend_roles(backend_roles: Optional[List[str]]):
if not isinstance(backend_roles, (NoneType, list)):
raise ValueError("backend_roles should either be None or a list of roles names")


def _validate_model_group_add_all_backend_roles(add_all_backend_roles: Optional[bool]):
if not isinstance(add_all_backend_roles, (NoneType, bool)):
raise ValueError("add_all_backend_roles should be a boolean")


def _validate_model_group_query(query: dict, operation: Optional[str] = None):
if not isinstance(query, dict):
raise ValueError("query needs to be a dictionary")

if operation and not isinstance(operation, str):
raise ValueError("operation needs to be a string")


def validate_create_model_group_parameters(
name: str,
description: Optional[str] = None,
access_mode: Optional[str] = "private",
backend_roles: Optional[List[str]] = None,
add_all_backend_roles: Optional[bool] = False,
):
_validate_model_group_name(name)
_validate_model_group_description(description)
_validate_model_group_access_mode(access_mode)
_validate_model_group_backend_roles(backend_roles)
_validate_model_group_add_all_backend_roles(add_all_backend_roles)

if access_mode == "restricted":
if not backend_roles and not add_all_backend_roles:
raise ValueError(
"You must specify either backend_roles or add_all_backend_roles=True for restricted access_mode"
)

if backend_roles and add_all_backend_roles:
raise ValueError(
"You cannot specify both backend_roles and add_all_backend_roles=True at the same time"
)

elif access_mode == "private":
if backend_roles or add_all_backend_roles:
raise ValueError(
"You must not specify backend_roles or add_all_backend_roles=True for a private model group"
)


def validate_update_model_group_parameters(update_query: dict, model_group_id: str):
if not isinstance(model_group_id, str):
raise ValueError("Invalid model_group_id. model_group_id needs to be a string")

if not isinstance(update_query, dict):
raise ValueError("Invalid update_query. update_query needs to be a dictionary")


def validate_delete_model_group_parameters(model_group_id: str):
if not isinstance(model_group_id, str):
raise ValueError("Invalid model_group_id. model_group_id needs to be a string")


def validate_search_model_group_parameters(query: dict):
_validate_model_group_query(query)
Loading

0 comments on commit 2c3b744

Please sign in to comment.