From eedd5b500ed6bffb8dbc64dd6bbb5e08b4f348eb Mon Sep 17 00:00:00 2001 From: Mohammad Alisafaee Date: Thu, 3 Oct 2024 01:27:42 +0200 Subject: [PATCH] address review comments --- components/renku_data_services/authz/authz.py | 31 ++++---- .../message_queue/api.spec.yaml | 41 ++++++---- .../message_queue/apispec.py | 32 +++++--- .../message_queue/blueprints.py | 10 ++- .../renku_data_services/message_queue/core.py | 75 ++++++++++++------- .../renku_data_services/message_queue/db.py | 2 +- .../renku_data_services/namespace/db.py | 9 +-- components/renku_data_services/project/db.py | 11 ++- .../data_api/test_message_queue.py | 50 +++++++++---- 9 files changed, 163 insertions(+), 98 deletions(-) diff --git a/components/renku_data_services/authz/authz.py b/components/renku_data_services/authz/authz.py index 401aa16cb..182c5b34b 100644 --- a/components/renku_data_services/authz/authz.py +++ b/components/renku_data_services/authz/authz.py @@ -1,7 +1,7 @@ """Projects authorization adapter.""" import asyncio -from collections.abc import AsyncIterable, Awaitable, Callable +from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Callable from dataclasses import dataclass, field from enum import StrEnum from functools import wraps @@ -385,10 +385,14 @@ async def users_with_permission( ids.append(response.subject.subject_object_id) return ids - async def get_all_members(self, resource_type: ResourceType, *, zed_token: ZedToken | None = None) -> list[Member]: + async def get_all_members( + self, resource_type: ResourceType, *, zed_token: ZedToken | None = None + ) -> AsyncGenerator[Member, None]: """Get all users that are members of a specific resource.""" - members = await self._get_members_helper(resource_type, resource_id=None, zed_token=zed_token) - return [m for m in members if m.user_id and m.user_id != "*"] + members = self._get_members_helper(resource_type, resource_id=None, zed_token=zed_token) + async for member in members: + if member.user_id and member.user_id != "*": + yield member @_is_allowed(Scope.READ) async def members( @@ -401,7 +405,7 @@ async def members( zed_token: ZedToken | None = None, ) -> list[Member]: """Get all users that are members of a specific resource type, if role is None then all roles are retrieved.""" - return await self._get_members_helper(resource_type, str(resource_id), role, zed_token=zed_token) + return [m async for m in self._get_members_helper(resource_type, str(resource_id), role, zed_token=zed_token)] async def _get_members_helper( self, @@ -410,7 +414,7 @@ async def _get_members_helper( role: Role | None = None, *, zed_token: ZedToken | None = None, - ) -> list[Member]: + ) -> AsyncGenerator[Member, None]: """Get all users that are members of a resource, if role is None then all roles are retrieved.""" consistency = Consistency(at_least_as_fresh=zed_token) if zed_token else Consistency(fully_consistent=True) sub_filter = SubjectFilter(subject_type=ResourceType.user.value) @@ -443,20 +447,19 @@ async def _get_members_helper( relationship_filter=rel_filter, ) ) - members: list[Member] = [] + async for response in responses: # Skip "public_viewer" relationships if response.relationship.relation == _Relation.public_viewer.value: continue member_role = _Relation(response.relationship.relation).to_role() - members.append( - Member( - user_id=response.relationship.subject.object.object_id, - role=member_role, - resource_id=response.relationship.resource.object_id, - ) + member = Member( + user_id=response.relationship.subject.object.object_id, + role=member_role, + resource_id=response.relationship.resource.object_id, ) - return members + + yield member @staticmethod def authz_change( diff --git a/components/renku_data_services/message_queue/api.spec.yaml b/components/renku_data_services/message_queue/api.spec.yaml index f406f620f..999c78fa7 100644 --- a/components/renku_data_services/message_queue/api.spec.yaml +++ b/components/renku_data_services/message_queue/api.spec.yaml @@ -9,17 +9,23 @@ servers: - url: /api/data - url: /ui-server/api/data paths: - /search/reprovision: + /message_queue/reprovision: post: summary: Start a new reprovisioning description: Only a single reprovisioning is active at any time responses: "201": description: The reprovisioning is/will be started + content: + application/json: + schema: + $ref: "#/components/schemas/Reprovisioning" + "409": + description: A reprovisioning is already started default: $ref: "#/components/responses/Error" tags: - - search + - message_queue get: summary: Return status of reprovisioning responses: @@ -38,7 +44,7 @@ paths: default: $ref: "#/components/responses/Error" tags: - - search + - message_queue delete: summary: Stop an active reprovisioning responses: @@ -47,28 +53,37 @@ paths: default: $ref: "#/components/responses/Error" tags: - - search + - message_queue components: schemas: - ReprovisioningStatus: - description: Status of a reprovisioning + Reprovisioning: + description: A reprovisioning type: object properties: - active: - type: boolean - description: Whether a reprovisioning is in progress or not + id: + $ref: "#/components/schemas/Ulid" start_date: - description: The date and time the resource was created (in UTC and ISO-8601 format) + description: The date and time the reprovisioning was started (in UTC and ISO-8601 format) type: string format: date-time example: "2023-11-01T17:32:28Z" required: - - active + - id + - start_date example: - - active: true + - id: 01BX5ZZK2KAC4AV9WEV3EMM0S0 start_date: "2023-11-01T17:32:28Z" - - active: false + ReprovisioningStatus: + description: Status of a reprovisioning + allOf: + - $ref: "#/components/schemas/Reprovisioning" + Ulid: + description: ULID identifier + type: string + minLength: 26 + maxLength: 26 + pattern: "^[0-7][0-9A-HJKMNP-TV-Z]{25}$" # This is case-insensitive ErrorResponse: type: object properties: diff --git a/components/renku_data_services/message_queue/apispec.py b/components/renku_data_services/message_queue/apispec.py index cb2d6193f..18cf76abb 100644 --- a/components/renku_data_services/message_queue/apispec.py +++ b/components/renku_data_services/message_queue/apispec.py @@ -1,6 +1,6 @@ # generated by datamodel-codegen: # filename: api.spec.yaml -# timestamp: 2024-09-10T00:17:57+00:00 +# timestamp: 2024-10-03T07:51:44+00:00 from __future__ import annotations @@ -11,17 +11,6 @@ from renku_data_services.message_queue.apispec_base import BaseAPISpec -class ReprovisioningStatus(BaseAPISpec): - active: bool = Field( - ..., description="Whether a reprovisioning is in progress or not" - ) - start_date: Optional[datetime] = Field( - None, - description="The date and time the resource was created (in UTC and ISO-8601 format)", - example="2023-11-01T17:32:28Z", - ) - - class Error(BaseAPISpec): code: int = Field(..., example=1404, gt=0) detail: Optional[str] = Field( @@ -32,3 +21,22 @@ class Error(BaseAPISpec): class ErrorResponse(BaseAPISpec): error: Error + + +class Reprovisioning(BaseAPISpec): + id: str = Field( + ..., + description="ULID identifier", + max_length=26, + min_length=26, + pattern="^[0-7][0-9A-HJKMNP-TV-Z]{25}$", + ) + start_date: datetime = Field( + ..., + description="The date and time the reprovisioning was started (in UTC and ISO-8601 format)", + example="2023-11-01T17:32:28Z", + ) + + +class ReprovisioningStatus(Reprovisioning): + pass diff --git a/components/renku_data_services/message_queue/blueprints.py b/components/renku_data_services/message_queue/blueprints.py index b58688423..f1408b912 100644 --- a/components/renku_data_services/message_queue/blueprints.py +++ b/components/renku_data_services/message_queue/blueprints.py @@ -36,7 +36,7 @@ def post(self) -> BlueprintFactoryResponse: @authenticate(self.authenticator) @only_admins - async def _post(request: Request, user: base_models.APIUser) -> HTTPResponse: + async def _post(request: Request, user: base_models.APIUser) -> HTTPResponse | JSONResponse: reprovisioning = await self.reprovisioning_repo.start() request.app.add_task( @@ -54,7 +54,7 @@ async def _post(request: Request, user: base_models.APIUser) -> HTTPResponse: name=f"reprovisioning-{reprovisioning.id}", ) - return HTTPResponse(status=201) + return json({"id": reprovisioning.id, "start_date": reprovisioning.start_date.isoformat()}, 201) return "/search/reprovision", ["POST"], _post @@ -63,8 +63,10 @@ def get_status(self) -> BlueprintFactoryResponse: @authenticate(self.authenticator) async def _get_status(_: Request, __: base_models.APIUser) -> JSONResponse | HTTPResponse: - active_reprovisioning = await self.reprovisioning_repo.get_active_reprovisioning() - return json({"active": bool(active_reprovisioning)}, 200) + reprovisioning = await self.reprovisioning_repo.get_active_reprovisioning() + if not reprovisioning: + return HTTPResponse(status=404) + return json({"id": str(reprovisioning.id), "start_date": reprovisioning.start_date.isoformat()}) return "/search/reprovision", ["GET"], _get_status diff --git a/components/renku_data_services/message_queue/core.py b/components/renku_data_services/message_queue/core.py index f4dcee639..8ccc9cf1e 100644 --- a/components/renku_data_services/message_queue/core.py +++ b/components/renku_data_services/message_queue/core.py @@ -2,6 +2,7 @@ from collections.abc import Callable +from sanic.log import logger from sqlalchemy.ext.asyncio import AsyncSession from renku_data_services.authz.authz import Authz, ResourceType @@ -31,39 +32,55 @@ async def reprovision( project_repo: ProjectRepository, authz: Authz, ) -> None: - """Create and send various data service events required for reprovisioning the search index.""" - async with session_maker() as session, session.begin(): - start_event = make_event( - message_type="reprovisioning.started", payload=v2.ReprovisioningStarted(id=str(reprovisioning.id)) - ) - await event_repo.store_event(session, start_event) + """Create and send various data service events required for reprovisioning the message queue.""" + logger.info(f"Starting reprovisioning with ID {reprovisioning.id}") - all_users = await user_repo.get_users(requested_by=requested_by) - for user in all_users: - user_event = EventConverter.to_events(user, event_type=v2.UserAdded) - await event_repo.store_event(session, user_event[0]) + try: + async with session_maker() as session, session.begin(): + start_event = make_event( + message_type="reprovisioning.started", payload=v2.ReprovisioningStarted(id=str(reprovisioning.id)) + ) + await event_repo.store_event(session, start_event) - all_groups = await group_repo.get_all_groups(requested_by=requested_by) - for group in all_groups: - group_event = EventConverter.to_events(group, event_type=v2.GroupAdded) - await event_repo.store_event(session, group_event[0]) + logger.info("Reprovisioning users") + all_users = await user_repo.get_users(requested_by=requested_by) + for user in all_users: + user_event = EventConverter.to_events(user, event_type=v2.UserAdded) + await event_repo.store_event(session, user_event[0]) - all_groups_members = await authz.get_all_members(ResourceType.group) - for group_member in all_groups_members: - group_member_event = make_group_member_added_event(member=group_member) - await event_repo.store_event(session, group_member_event) + logger.info("Reprovisioning groups") + all_groups = group_repo.get_all_groups(requested_by=requested_by) + async for group in all_groups: + group_event = EventConverter.to_events(group, event_type=v2.GroupAdded) + await event_repo.store_event(session, group_event[0]) - all_projects = await project_repo.get_all_projects(requested_by=requested_by) - for project in all_projects: - project_event = EventConverter.to_events(project, event_type=v2.ProjectCreated) - await event_repo.store_event(session, project_event[0]) + logger.info("Reprovisioning group members") + all_groups_members = authz.get_all_members(ResourceType.group) + async for group_member in all_groups_members: + group_member_event = make_group_member_added_event(member=group_member) + await event_repo.store_event(session, group_member_event) - all_projects_members = await authz.get_all_members(ResourceType.project) - for project_member in all_projects_members: - project_member_event = make_project_member_added_event(member=project_member) - await event_repo.store_event(session, project_member_event) + logger.info("Reprovisioning projects") + all_projects = project_repo.get_all_projects(requested_by=requested_by) + async for project in all_projects: + project_event = EventConverter.to_events(project, event_type=v2.ProjectCreated) + await event_repo.store_event(session, project_event[0]) - start_event = make_event(message_type="reprovisioning.finished", payload=v2.ReprovisioningFinished(id="42")) - await event_repo.store_event(session, start_event) + logger.info("Reprovisioning project members") + all_projects_members = authz.get_all_members(ResourceType.project) + async for project_member in all_projects_members: + project_member_event = make_project_member_added_event(member=project_member) + await event_repo.store_event(session, project_member_event) - await reprovisioning_repo.stop() + finish_event = make_event( + message_type="reprovisioning.finished", payload=v2.ReprovisioningFinished(id=str(reprovisioning.id)) + ) + await event_repo.store_event(session, finish_event) + + logger.info(f"Trying to commit reprovisioning with ID {reprovisioning.id}") + except Exception as e: + logger.exception(f"An error occurred during reprovisioning with ID {reprovisioning.id}: {e}") + else: + logger.info(f"Reprovisioning with ID {reprovisioning.id} is successfully finished") + finally: + await reprovisioning_repo.stop() diff --git a/components/renku_data_services/message_queue/db.py b/components/renku_data_services/message_queue/db.py index 366bd1037..6c22abdc6 100644 --- a/components/renku_data_services/message_queue/db.py +++ b/components/renku_data_services/message_queue/db.py @@ -78,7 +78,7 @@ async def delete_event(self, id: int) -> None: stmt = delete(schemas.EventORM).where(schemas.EventORM.id == id) await session.execute(stmt) - async def clear(self) -> None: + async def delete_all_events(self) -> None: """Delete all events. This is only used when testing reprovisioning.""" async with self.session_maker() as session, session.begin(): await session.execute(delete(schemas.EventORM)) diff --git a/components/renku_data_services/namespace/db.py b/components/renku_data_services/namespace/db.py index 9d99d8064..babb3de50 100644 --- a/components/renku_data_services/namespace/db.py +++ b/components/renku_data_services/namespace/db.py @@ -90,16 +90,15 @@ async def get_groups( n_total_elements = result.scalar() or 0 return [g.dump() for g in groups_orm], n_total_elements - async def get_all_groups(self, requested_by: base_models.APIUser) -> list[models.Group]: + async def get_all_groups(self, requested_by: base_models.APIUser) -> AsyncGenerator[models.Group, None]: """Get all groups when reprovisioning.""" if not requested_by.is_admin: raise errors.ForbiddenError(message="You do not have the required permissions for this operation.") async with self.session_maker() as session, session.begin(): - stmt = select(schemas.GroupORM) - results = await session.execute(stmt) - groups_orms = results.scalars().all() - return [g.dump() for g in groups_orms] + groups = await session.stream_scalars(select(schemas.GroupORM)) + async for group in groups: + yield group.dump() async def _get_group( self, session: AsyncSession, user: base_models.APIUser, slug: str, load_members: bool = False diff --git a/components/renku_data_services/project/db.py b/components/renku_data_services/project/db.py index c231a9833..3ea19555f 100644 --- a/components/renku_data_services/project/db.py +++ b/components/renku_data_services/project/db.py @@ -3,7 +3,7 @@ from __future__ import annotations import functools -from collections.abc import Awaitable, Callable +from collections.abc import AsyncGenerator, Awaitable, Callable from datetime import UTC, datetime from typing import Any, Concatenate, ParamSpec, TypeVar @@ -70,16 +70,15 @@ async def get_projects( total_elements = results[1] or 0 return [p.dump() for p in projects_orm], total_elements - async def get_all_projects(self, requested_by: base_models.APIUser) -> list[models.Project]: + async def get_all_projects(self, requested_by: base_models.APIUser) -> AsyncGenerator[models.Project, None]: """Get all projects from the database when reprovisioning.""" if not requested_by.is_admin: raise errors.ForbiddenError(message="You do not have the required permissions for this operation.") async with self.session_maker() as session: - stmt = select(schemas.ProjectORM) - results = await session.execute(stmt) - projects_orms = results.scalars().all() - return [p.dump() for p in projects_orms] + projects = await session.stream_scalars(select(schemas.ProjectORM)) + async for project in projects: + yield project.dump() async def get_project(self, user: base_models.APIUser, project_id: ULID) -> models.Project: """Get one project from the database.""" diff --git a/test/bases/renku_data_services/data_api/test_message_queue.py b/test/bases/renku_data_services/data_api/test_message_queue.py index e80171f27..90db083f6 100644 --- a/test/bases/renku_data_services/data_api/test_message_queue.py +++ b/test/bases/renku_data_services/data_api/test_message_queue.py @@ -6,9 +6,29 @@ from test.bases.renku_data_services.data_api.utils import dataclass_to_str, deserialize_event +@pytest.fixture +def _reprovisioning(sanic_client, user_headers): + """Wait for the data service to finish the reprovisioning task.""" + + async def wait_helper(): + total_wait_time = 0 + while True: + await asyncio.sleep(0.1) + total_wait_time += 0.1 + + _, response = await sanic_client.get("/api/data/search/reprovision", headers=user_headers) + + if response.status_code == 404: + break + elif total_wait_time > 30: + assert False, "Reprovisioning was not finished after 30 seconds" + + return wait_helper + + @pytest.mark.asyncio -async def test_search_reprovisioning( - sanic_client, app_config, create_project, create_group, admin_headers, project_members +async def test_message_queue_reprovisioning( + sanic_client, app_config, create_project, create_group, admin_headers, project_members, _reprovisioning ) -> None: await create_project("Project 1") await create_project("Project 2", visibility="public") @@ -22,14 +42,15 @@ async def test_search_reprovisioning( events = await app_config.event_repo._get_pending_events() # NOTE: Clear all events before reprovisioning - await app_config.event_repo.clear() + await app_config.event_repo.delete_all_events() _, response = await sanic_client.post("/api/data/search/reprovision", headers=admin_headers) assert response.status_code == 201, response.text + assert response.json["id"] is not None + assert response.json["start_date"] is not None - # NOTE: Wait for server to finish the reprovisioning task - await asyncio.sleep(2) + await _reprovisioning() reprovisioning_events = await app_config.event_repo._get_pending_events() @@ -40,7 +61,7 @@ async def test_search_reprovisioning( @pytest.mark.asyncio -async def test_search_only_admins_can_start_reprovisioning(sanic_client, user_headers) -> None: +async def test_message_queue_only_admins_can_start_reprovisioning(sanic_client, user_headers) -> None: _, response = await sanic_client.post("/api/data/search/reprovision", headers=user_headers) assert response.status_code == 403, response.text @@ -53,7 +74,7 @@ async def long_reprovisioning_mock(*_, **__): @pytest.mark.asyncio -async def test_search_multiple_reprovisioning_not_allowed(sanic_client, admin_headers, monkeypatch) -> None: +async def test_message_queue_multiple_reprovisioning_not_allowed(sanic_client, admin_headers, monkeypatch) -> None: monkeypatch.setattr(renku_data_services.message_queue.blueprints, "reprovision", long_reprovisioning_mock) _, response = await sanic_client.post("/api/data/search/reprovision", headers=admin_headers) @@ -66,13 +87,12 @@ async def test_search_multiple_reprovisioning_not_allowed(sanic_client, admin_he @pytest.mark.asyncio -async def test_search_get_reprovisioning_status(sanic_client, admin_headers, user_headers, monkeypatch): +async def test_message_queue_get_reprovisioning_status(sanic_client, admin_headers, user_headers, monkeypatch): monkeypatch.setattr(renku_data_services.message_queue.blueprints, "reprovision", long_reprovisioning_mock) _, response = await sanic_client.get("/api/data/search/reprovision", headers=user_headers) - assert response.status_code == 200, response.text - assert response.json["active"] is False + assert response.status_code == 404, response.text # NOTE: Start a reprovisioning _, response = await sanic_client.post("/api/data/search/reprovision", headers=admin_headers) @@ -81,20 +101,22 @@ async def test_search_get_reprovisioning_status(sanic_client, admin_headers, use _, response = await sanic_client.get("/api/data/search/reprovision", headers=user_headers) assert response.status_code == 200, response.text - assert response.json["active"] is True + assert response.json["id"] is not None + assert response.json["start_date"] is not None @pytest.mark.asyncio -async def test_search_can_stop_reprovisioning(sanic_client, admin_headers, monkeypatch) -> None: +async def test_message_queue_can_stop_reprovisioning(sanic_client, admin_headers, monkeypatch) -> None: monkeypatch.setattr(renku_data_services.message_queue.blueprints, "reprovision", long_reprovisioning_mock) _, response = await sanic_client.post("/api/data/search/reprovision", headers=admin_headers) assert response.status_code == 201, response.text + _, response = await sanic_client.get("/api/data/search/reprovision", headers=admin_headers) + assert response.status_code == 200, response.text _, response = await sanic_client.delete("/api/data/search/reprovision", headers=admin_headers) assert response.status_code == 204, response.text _, response = await sanic_client.get("/api/data/search/reprovision", headers=admin_headers) - assert response.status_code == 200, response.text - assert response.json["active"] is False + assert response.status_code == 404, response.text