From 3a75082ceac0a7e1fc52b648411f1cf1f1c9e996 Mon Sep 17 00:00:00 2001 From: Tasko Olevski Date: Mon, 30 Sep 2024 18:11:54 +0200 Subject: [PATCH] fix: allow session launcher parameters to be reset Allows the API to accept None as input for args, command and the session launcher resource class ID so that they can be reset to their defaults in patch endpoints. --- .../renku_data_services/base_models/core.py | 14 ++ .../renku_data_services/session/blueprints.py | 38 ++--- .../renku_data_services/session/converters.py | 83 ++++++++++ components/renku_data_services/session/db.py | 147 ++++++++---------- .../renku_data_services/session/models.py | 30 ++++ .../data_api/test_sessions.py | 18 ++- test/conftest.py | 2 + 7 files changed, 223 insertions(+), 109 deletions(-) create mode 100644 components/renku_data_services/session/converters.py diff --git a/components/renku_data_services/base_models/core.py b/components/renku_data_services/base_models/core.py index fe4f32fe7..484731e15 100644 --- a/components/renku_data_services/base_models/core.py +++ b/components/renku_data_services/base_models/core.py @@ -212,3 +212,17 @@ class Authenticator(Protocol[AnyAPIUser]): async def authenticate(self, access_token: str, request: Request) -> AnyAPIUser: """Validates the user credentials (i.e. we can say that the user is a valid Renku user).""" ... + + +@dataclass(frozen=True, eq=True, kw_only=True) +class Null: + """Parent class for distinguishing between None values.""" + + value: None = field(default=None, init=False, repr=False) + + +@dataclass(frozen=True, eq=True, kw_only=True) +class Reset(Null): + """Used to indicate a None value that has been deliberately set by the user or caller.""" + + ... diff --git a/components/renku_data_services/session/blueprints.py b/components/renku_data_services/session/blueprints.py index 75fbf25a4..b42892571 100644 --- a/components/renku_data_services/session/blueprints.py +++ b/components/renku_data_services/session/blueprints.py @@ -11,7 +11,7 @@ import renku_data_services.base_models as base_models from renku_data_services.base_api.auth import authenticate, validate_path_project_id from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint -from renku_data_services.session import apispec, models +from renku_data_services.session import apispec, converters, models from renku_data_services.session.db import SessionRepository @@ -75,9 +75,11 @@ def patch(self) -> BlueprintFactoryResponse: async def _patch( _: Request, user: base_models.APIUser, environment_id: ULID, body: apispec.EnvironmentPatch ) -> JSONResponse: - body_dict = body.model_dump(exclude_none=True) + update = converters.environment_update_from_patch(body) environment = await self.session_repo.update_environment( - user=user, environment_id=environment_id, **body_dict + user=user, + environment_id=environment_id, + update=update, ) return json(apispec.Environment.model_validate(environment).model_dump(exclude_none=True, mode="json")) @@ -172,34 +174,14 @@ def patch(self) -> BlueprintFactoryResponse: async def _patch( _: Request, user: base_models.APIUser, launcher_id: ULID, body: apispec.SessionLauncherPatch ) -> JSONResponse: - body_dict = body.model_dump(exclude_none=True, mode="json") async with self.session_repo.session_maker() as session, session.begin(): current_launcher = await self.session_repo.get_launcher(user, launcher_id) - new_env: models.UnsavedEnvironment | None = None - if ( - isinstance(body.environment, apispec.EnvironmentPatchInLauncher) - and current_launcher.environment.environment_kind == models.EnvironmentKind.GLOBAL - and body.environment.environment_kind == apispec.EnvironmentKind.CUSTOM - ): - # This means that the global environment is being swapped for a custom one, - # so we have to create a brand new environment, but we have to validate here. - validated_env = apispec.EnvironmentPostInLauncher.model_validate(body_dict.pop("environment")) - new_env = models.UnsavedEnvironment( - name=validated_env.name, - description=validated_env.description, - container_image=validated_env.container_image, - default_url=validated_env.default_url, - port=validated_env.port, - working_directory=PurePosixPath(validated_env.working_directory), - mount_directory=PurePosixPath(validated_env.mount_directory), - uid=validated_env.uid, - gid=validated_env.gid, - environment_kind=models.EnvironmentKind(validated_env.environment_kind.value), - args=validated_env.args, - command=validated_env.command, - ) + update = converters.launcher_update_from_patch(body, current_launcher) launcher = await self.session_repo.update_launcher( - user=user, launcher_id=launcher_id, new_custom_environment=new_env, session=session, **body_dict + user=user, + launcher_id=launcher_id, + session=session, + update=update, ) return json(apispec.SessionLauncher.model_validate(launcher).model_dump(exclude_none=True, mode="json")) diff --git a/components/renku_data_services/session/converters.py b/components/renku_data_services/session/converters.py new file mode 100644 index 000000000..bbdade105 --- /dev/null +++ b/components/renku_data_services/session/converters.py @@ -0,0 +1,83 @@ +"""Code used to convert from/to apispec and models.""" + +from pathlib import PurePosixPath + +from renku_data_services.base_models.core import Reset +from renku_data_services.session import apispec, models + + +def environment_update_from_patch(data: apispec.EnvironmentPatch) -> models.EnvironmentUpdate: + """Create an update object from an apispec or any other pydantic model.""" + data_dict = data.model_dump(exclude_unset=True, mode="json") + working_directory: PurePosixPath | None = None + if data.working_directory is not None: + working_directory = PurePosixPath(data.working_directory) + mount_directory: PurePosixPath | None = None + if data.mount_directory is not None: + mount_directory = PurePosixPath(data.mount_directory) + # NOTE: If the args or command are present in the data_dict and they are None they were passed in by the user. + # The None specifically passed by the user indicates that the value should be removed from the DB. + args = Reset() if "args" in data_dict and data_dict["args"] is None else data.args + command = Reset() if "command" in data_dict and data_dict["command"] is None else data.command + return models.EnvironmentUpdate( + name=data.name, + description=data.description, + container_image=data.container_image, + default_url=data.default_url, + port=data.port, + working_directory=working_directory, + mount_directory=mount_directory, + uid=data.uid, + gid=data.gid, + args=args, + command=command, + ) + + +def launcher_update_from_patch( + data: apispec.SessionLauncherPatch, + current_launcher: models.SessionLauncher | None = None, +) -> models.SessionLauncherUpdate: + """Create an update object from an apispec or any other pydantic model.""" + data_dict = data.model_dump(exclude_unset=True, mode="json") + environment: str | models.EnvironmentUpdate | models.UnsavedEnvironment | None = None + if ( + isinstance(data.environment, apispec.EnvironmentPatchInLauncher) + and current_launcher is not None + and current_launcher.environment.environment_kind == models.EnvironmentKind.GLOBAL + and data.environment.environment_kind == apispec.EnvironmentKind.CUSTOM + ): + # This means that the global environment is being swapped for a custom one, + # so we have to create a brand new environment, but we have to validate here. + validated_env = apispec.EnvironmentPostInLauncher.model_validate(data_dict["environment"]) + environment = models.UnsavedEnvironment( + name=validated_env.name, + description=validated_env.description, + container_image=validated_env.container_image, + default_url=validated_env.default_url, + port=validated_env.port, + working_directory=PurePosixPath(validated_env.working_directory), + mount_directory=PurePosixPath(validated_env.mount_directory), + uid=validated_env.uid, + gid=validated_env.gid, + environment_kind=models.EnvironmentKind(validated_env.environment_kind.value), + args=validated_env.args, + command=validated_env.command, + ) + elif isinstance(data.environment, apispec.EnvironmentPatchInLauncher): + environment = environment_update_from_patch(data.environment) + elif isinstance(data.environment, apispec.EnvironmentIdOnlyPatch): + environment = data.environment.id + resource_class_id: int | None | Reset = None + if "resource_class_id" in data_dict and data_dict["resource_class_id"] is None: + # NOTE: This means that the resource class set in the DB should be removed so that the + # default resource class currently set in the CRC will be used. + resource_class_id = Reset() + else: + resource_class_id = data_dict.get("resource_class_id") + return models.SessionLauncherUpdate( + name=data_dict.get("name"), + description=data_dict.get("description"), + environment=environment, + resource_class_id=resource_class_id, + ) diff --git a/components/renku_data_services/session/db.py b/components/renku_data_services/session/db.py index 417820e40..9a1215136 100644 --- a/components/renku_data_services/session/db.py +++ b/components/renku_data_services/session/db.py @@ -5,7 +5,6 @@ from collections.abc import Callable from contextlib import AbstractAsyncContextManager, nullcontext from datetime import UTC, datetime -from typing import Any from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -15,6 +14,7 @@ from renku_data_services import errors from renku_data_services.authz.authz import Authz, ResourceType from renku_data_services.authz.models import Scope +from renku_data_services.base_models.core import Reset from renku_data_services.crc.db import ResourcePoolRepository from renku_data_services.session import models from renku_data_services.session import orm as schemas @@ -101,53 +101,59 @@ async def insert_environment( env = await self.__insert_environment(user, session, new_environment) return env.dump() - async def __update_environment( + def __update_environment( self, - user: base_models.APIUser, - session: AsyncSession, - environment_id: ULID, - kind: models.EnvironmentKind, - **kwargs: dict, - ) -> models.Environment: - res = await session.scalars( - select(schemas.EnvironmentORM) - .where(schemas.EnvironmentORM.id == str(environment_id)) - .where(schemas.EnvironmentORM.environment_kind == kind.value) - ) - environment = res.one_or_none() - if environment is None: - raise errors.MissingResourceError(message=f"Session environment with id '{environment_id}' does not exist.") - - for key, value in kwargs.items(): - # NOTE: Only some fields can be edited - if key in [ - "name", - "description", - "container_image", - "default_url", - "port", - "working_directory", - "mount_directory", - "uid", - "gid", - "args", - "command", - ]: - setattr(environment, key, value) - - return environment.dump() + environment: schemas.EnvironmentORM, + update: models.EnvironmentUpdate, + ) -> None: + # NOTE: this is more verbose than a loop and setattr but this way we get mypy type checks + if update.name is not None: + environment.name = update.name + if update.description is not None: + environment.description = update.description + if update.container_image is not None: + environment.container_image = update.container_image + if update.default_url is not None: + environment.default_url = update.default_url + if update.port is not None: + environment.port = update.port + if update.working_directory is not None: + environment.working_directory = update.working_directory + if update.mount_directory is not None: + environment.mount_directory = update.mount_directory + if update.uid is not None: + environment.uid = update.uid + if update.gid is not None: + environment.gid = update.gid + if isinstance(update.args, Reset): + environment.args = None + elif isinstance(update.args, list): + environment.args = update.args + if isinstance(update.command, Reset): + environment.command = None + elif isinstance(update.command, list): + environment.command = update.command async def update_environment( - self, user: base_models.APIUser, environment_id: ULID, **kwargs: dict + self, user: base_models.APIUser, environment_id: ULID, update: models.EnvironmentUpdate ) -> models.Environment: """Update a global session environment entry.""" if not user.is_admin: raise errors.UnauthorizedError(message="You do not have the required permissions for this operation.") async with self.session_maker() as session, session.begin(): - return await self.__update_environment( - user, session, environment_id, models.EnvironmentKind.GLOBAL, **kwargs + res = await session.scalars( + select(schemas.EnvironmentORM) + .where(schemas.EnvironmentORM.id == str(environment_id)) + .where(schemas.EnvironmentORM.environment_kind == models.EnvironmentKind.GLOBAL) ) + environment = res.one_or_none() + if environment is None: + raise errors.MissingResourceError( + message=f"Session environment with id '{environment_id}' does not exist." + ) + self.__update_environment(environment, update) + return environment.dump() async def delete_environment(self, user: base_models.APIUser, environment_id: ULID) -> None: """Delete a global session environment entry.""" @@ -300,9 +306,8 @@ async def update_launcher( self, user: base_models.APIUser, launcher_id: ULID, - new_custom_environment: models.UnsavedEnvironment | None, + update: models.SessionLauncherUpdate, session: AsyncSession | None = None, - **kwargs: Any, ) -> models.SessionLauncher: """Update a session launcher entry.""" if not user.is_authenticated or user.id is None: @@ -336,8 +341,8 @@ async def update_launcher( if not authorized: raise errors.ForbiddenError(message="You do not have the required permissions for this operation.") - resource_class_id = kwargs.get("resource_class_id") - if resource_class_id is not None: + resource_class_id = update.resource_class_id + if isinstance(resource_class_id, int): res = await session.scalars( select(schemas.ResourceClassORM).where(schemas.ResourceClassORM.id == resource_class_id) ) @@ -354,17 +359,20 @@ async def update_launcher( message=f"You do not have access to resource class with id '{resource_class_id}'." ) - for key, value in kwargs.items(): - # NOTE: Only some fields can be updated. - if key in [ - "name", - "description", - "resource_class_id", - ]: - setattr(launcher, key, value) - - env_payload = kwargs.get("environment", {}) - await self.__update_launcher_environment(user, launcher, session, new_custom_environment, **env_payload) + # NOTE: Only some fields can be updated. + if update.name is not None: + launcher.name = update.name + if update.description is not None: + launcher.description = update.description + if isinstance(update.resource_class_id, int): + launcher.resource_class_id = update.resource_class_id + elif isinstance(update.resource_class_id, Reset): + launcher.resource_class_id = None + + if update.environment is None: + return launcher.dump() + + await self.__update_launcher_environment(user, launcher, session, update.environment) return launcher.dump() async def __update_launcher_environment( @@ -372,12 +380,11 @@ async def __update_launcher_environment( user: base_models.APIUser, launcher: schemas.SessionLauncherORM, session: AsyncSession, - new_custom_environment: models.UnsavedEnvironment | None, - **kwargs: Any, + update: models.EnvironmentUpdate | models.UnsavedEnvironment | str, ) -> None: current_env_kind = launcher.environment.environment_kind - match new_custom_environment, current_env_kind, kwargs: - case None, _, {"id": env_id, **nothing_else} if len(nothing_else) == 0: + match update, current_env_kind: + case str() as env_id, _: # The environment in the launcher is set via ID, the new ID has to refer # to an environment that is GLOBAL. old_environment = launcher.environment @@ -404,29 +411,11 @@ async def __update_launcher_environment( # We remove the custom environment to avoid accumulating custom environments that are not associated # with any launchers. await session.delete(old_environment) - case None, models.EnvironmentKind.CUSTOM, {**rest} if ( - rest.get("environment_kind") is None - or rest.get("environment_kind") == models.EnvironmentKind.CUSTOM.value - ): + case models.EnvironmentUpdate(), models.EnvironmentKind.CUSTOM: # Custom environment being updated - for key, val in rest.items(): - # NOTE: Only some fields can be updated. - if key in [ - "name", - "description", - "container_image", - "default_url", - "port", - "working_directory", - "mount_directory", - "uid", - "gid", - "args", - "command", - ]: - setattr(launcher.environment, key, val) - case models.UnsavedEnvironment(), models.EnvironmentKind.GLOBAL, {**nothing_else} if ( - len(nothing_else) == 0 and new_custom_environment.environment_kind == models.EnvironmentKind.CUSTOM + self.__update_environment(launcher.environment, update) + case models.UnsavedEnvironment() as new_custom_environment, models.EnvironmentKind.GLOBAL if ( + new_custom_environment.environment_kind == models.EnvironmentKind.CUSTOM ): # Global environment replaced by a custom one new_env = await self.__insert_environment(user, session, new_custom_environment) diff --git a/components/renku_data_services/session/models.py b/components/renku_data_services/session/models.py index 6dcff46c2..9ce6f0fbe 100644 --- a/components/renku_data_services/session/models.py +++ b/components/renku_data_services/session/models.py @@ -8,6 +8,7 @@ from ulid import ULID from renku_data_services import errors +from renku_data_services.base_models.core import Reset class EnvironmentKind(StrEnum): @@ -70,6 +71,23 @@ class Environment(BaseEnvironment): created_by: str +@dataclass(kw_only=True, frozen=True, eq=True) +class EnvironmentUpdate: + """Model for the update of some or all parts of an environment.""" + + name: str | None = None + description: str | None = None + container_image: str | None = None + default_url: str | None = None + port: int | None = None + working_directory: PurePosixPath | None = None + mount_directory: PurePosixPath | None = None + uid: int | None = None + gid: int | None = None + args: list[str] | None | Reset = None + command: list[str] | None | Reset = None + + @dataclass(frozen=True, eq=True, kw_only=True) class BaseSessionLauncher: """Session launcher model.""" @@ -99,3 +117,15 @@ class SessionLauncher(BaseSessionLauncher): creation_date: datetime created_by: str environment: Environment + + +@dataclass(frozen=True, eq=True, kw_only=True) +class SessionLauncherUpdate: + """Model for the update of a session launcher.""" + + name: str | None = None + description: str | None = None + # NOTE: When unsaved environment is used it means a brand new environment should be created for the + # launcher with the update of the launcher. + environment: str | EnvironmentUpdate | UnsavedEnvironment | None = None + resource_class_id: int | None | Reset = None diff --git a/test/bases/renku_data_services/data_api/test_sessions.py b/test/bases/renku_data_services/data_api/test_sessions.py index 732e5001a..d920cba9d 100644 --- a/test/bases/renku_data_services/data_api/test_sessions.py +++ b/test/bases/renku_data_services/data_api/test_sessions.py @@ -3,6 +3,7 @@ import os import shutil from asyncio import AbstractEventLoop +from collections.abc import Iterator from typing import Any import pytest @@ -18,8 +19,8 @@ os.environ["KUBECONFIG"] = ".k3d-config.yaml" -@pytest.fixture(scope="module", autouse=True) -def cluster() -> K3DCluster: +@pytest.fixture(scope="module") +def cluster() -> Iterator[K3DCluster]: if shutil.which("k3d") is None: pytest.skip("Requires k3d for cluster creation") @@ -172,10 +173,14 @@ async def test_patch_session_environment( env = await create_session_environment("Environment 1") environment_id = env["id"] + command = ["python", "test.py"] + args = ["arg1", "arg2"] payload = { "name": "New name", "description": "New description.", "container_image": "new_image:new_tag", + "command": command, + "args": args, } _, res = await sanic_client.patch(f"/api/data/environments/{environment_id}", headers=admin_headers, json=payload) @@ -185,6 +190,14 @@ async def test_patch_session_environment( assert res.json.get("name") == "New name" assert res.json.get("description") == "New description." assert res.json.get("container_image") == "new_image:new_tag" + assert res.json.get("args") == args + assert res.json.get("command") == command + + # Test that patching with None will reset the command and args + payload = {"args": None, "command": None} + _, res = await sanic_client.patch(f"/api/data/environments/{environment_id}", headers=admin_headers, json=payload) + assert res.json.get("args") is None + assert res.json.get("command") is None @pytest.mark.asyncio @@ -540,6 +553,7 @@ async def test_starting_session_anonymous( admin_headers, launch_session, anonymous_user_headers, + cluster, ) -> None: _, res = await sanic_client.post( "/api/data/resource_pools", diff --git a/test/conftest.py b/test/conftest.py index fdbd357d3..5f5658a64 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -115,6 +115,8 @@ def secrets_key_pair(monkeypatch, tmp_path) -> None: @pytest.fixture def app_config(authz_config, db_config, monkeypatch, worker_id, secrets_key_pair) -> Generator[DataConfig, None, None]: monkeypatch.setenv("MAX_PINNED_PROJECTS", "5") + monkeypatch.setenv("NB_SERVER_OPTIONS__DEFAULTS_PATH", "server_defaults.json") + monkeypatch.setenv("NB_SERVER_OPTIONS__UI_CHOICES_PATH", "server_options.json") config = DataConfig.from_env() app_name = "app_" + str(ULID()).lower() + "_" + worker_id