Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
m-alisafaee committed Oct 3, 2024
1 parent a028799 commit eedd5b5
Show file tree
Hide file tree
Showing 9 changed files with 163 additions and 98 deletions.
31 changes: 17 additions & 14 deletions components/renku_data_services/authz/authz.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Projects authorization adapter."""

import asyncio
from collections.abc import AsyncIterable, Awaitable, Callable
from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Callable
from dataclasses import dataclass, field
from enum import StrEnum
from functools import wraps
Expand Down Expand Up @@ -385,10 +385,14 @@ async def users_with_permission(
ids.append(response.subject.subject_object_id)
return ids

async def get_all_members(self, resource_type: ResourceType, *, zed_token: ZedToken | None = None) -> list[Member]:
async def get_all_members(
self, resource_type: ResourceType, *, zed_token: ZedToken | None = None
) -> AsyncGenerator[Member, None]:
"""Get all users that are members of a specific resource."""
members = await self._get_members_helper(resource_type, resource_id=None, zed_token=zed_token)
return [m for m in members if m.user_id and m.user_id != "*"]
members = self._get_members_helper(resource_type, resource_id=None, zed_token=zed_token)
async for member in members:
if member.user_id and member.user_id != "*":
yield member

@_is_allowed(Scope.READ)
async def members(
Expand All @@ -401,7 +405,7 @@ async def members(
zed_token: ZedToken | None = None,
) -> list[Member]:
"""Get all users that are members of a specific resource type, if role is None then all roles are retrieved."""
return await self._get_members_helper(resource_type, str(resource_id), role, zed_token=zed_token)
return [m async for m in self._get_members_helper(resource_type, str(resource_id), role, zed_token=zed_token)]

async def _get_members_helper(
self,
Expand All @@ -410,7 +414,7 @@ async def _get_members_helper(
role: Role | None = None,
*,
zed_token: ZedToken | None = None,
) -> list[Member]:
) -> AsyncGenerator[Member, None]:
"""Get all users that are members of a resource, if role is None then all roles are retrieved."""
consistency = Consistency(at_least_as_fresh=zed_token) if zed_token else Consistency(fully_consistent=True)
sub_filter = SubjectFilter(subject_type=ResourceType.user.value)
Expand Down Expand Up @@ -443,20 +447,19 @@ async def _get_members_helper(
relationship_filter=rel_filter,
)
)
members: list[Member] = []

async for response in responses:
# Skip "public_viewer" relationships
if response.relationship.relation == _Relation.public_viewer.value:
continue
member_role = _Relation(response.relationship.relation).to_role()
members.append(
Member(
user_id=response.relationship.subject.object.object_id,
role=member_role,
resource_id=response.relationship.resource.object_id,
)
member = Member(
user_id=response.relationship.subject.object.object_id,
role=member_role,
resource_id=response.relationship.resource.object_id,
)
return members

yield member

@staticmethod
def authz_change(
Expand Down
41 changes: 28 additions & 13 deletions components/renku_data_services/message_queue/api.spec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,23 @@ servers:
- url: /api/data
- url: /ui-server/api/data
paths:
/search/reprovision:
/message_queue/reprovision:
post:
summary: Start a new reprovisioning
description: Only a single reprovisioning is active at any time
responses:
"201":
description: The reprovisioning is/will be started
content:
application/json:
schema:
$ref: "#/components/schemas/Reprovisioning"
"409":
description: A reprovisioning is already started
default:
$ref: "#/components/responses/Error"
tags:
- search
- message_queue
get:
summary: Return status of reprovisioning
responses:
Expand All @@ -38,7 +44,7 @@ paths:
default:
$ref: "#/components/responses/Error"
tags:
- search
- message_queue
delete:
summary: Stop an active reprovisioning
responses:
Expand All @@ -47,28 +53,37 @@ paths:
default:
$ref: "#/components/responses/Error"
tags:
- search
- message_queue

components:
schemas:
ReprovisioningStatus:
description: Status of a reprovisioning
Reprovisioning:
description: A reprovisioning
type: object
properties:
active:
type: boolean
description: Whether a reprovisioning is in progress or not
id:
$ref: "#/components/schemas/Ulid"
start_date:
description: The date and time the resource was created (in UTC and ISO-8601 format)
description: The date and time the reprovisioning was started (in UTC and ISO-8601 format)
type: string
format: date-time
example: "2023-11-01T17:32:28Z"
required:
- active
- id
- start_date
example:
- active: true
- id: 01BX5ZZK2KAC4AV9WEV3EMM0S0
start_date: "2023-11-01T17:32:28Z"
- active: false
ReprovisioningStatus:
description: Status of a reprovisioning
allOf:
- $ref: "#/components/schemas/Reprovisioning"
Ulid:
description: ULID identifier
type: string
minLength: 26
maxLength: 26
pattern: "^[0-7][0-9A-HJKMNP-TV-Z]{25}$" # This is case-insensitive
ErrorResponse:
type: object
properties:
Expand Down
32 changes: 20 additions & 12 deletions components/renku_data_services/message_queue/apispec.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# generated by datamodel-codegen:
# filename: api.spec.yaml
# timestamp: 2024-09-10T00:17:57+00:00
# timestamp: 2024-10-03T07:51:44+00:00

from __future__ import annotations

Expand All @@ -11,17 +11,6 @@
from renku_data_services.message_queue.apispec_base import BaseAPISpec


class ReprovisioningStatus(BaseAPISpec):
active: bool = Field(
..., description="Whether a reprovisioning is in progress or not"
)
start_date: Optional[datetime] = Field(
None,
description="The date and time the resource was created (in UTC and ISO-8601 format)",
example="2023-11-01T17:32:28Z",
)


class Error(BaseAPISpec):
code: int = Field(..., example=1404, gt=0)
detail: Optional[str] = Field(
Expand All @@ -32,3 +21,22 @@ class Error(BaseAPISpec):

class ErrorResponse(BaseAPISpec):
error: Error


class Reprovisioning(BaseAPISpec):
id: str = Field(
...,
description="ULID identifier",
max_length=26,
min_length=26,
pattern="^[0-7][0-9A-HJKMNP-TV-Z]{25}$",
)
start_date: datetime = Field(
...,
description="The date and time the reprovisioning was started (in UTC and ISO-8601 format)",
example="2023-11-01T17:32:28Z",
)


class ReprovisioningStatus(Reprovisioning):
pass
10 changes: 6 additions & 4 deletions components/renku_data_services/message_queue/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def post(self) -> BlueprintFactoryResponse:

@authenticate(self.authenticator)
@only_admins
async def _post(request: Request, user: base_models.APIUser) -> HTTPResponse:
async def _post(request: Request, user: base_models.APIUser) -> HTTPResponse | JSONResponse:
reprovisioning = await self.reprovisioning_repo.start()

request.app.add_task(
Expand All @@ -54,7 +54,7 @@ async def _post(request: Request, user: base_models.APIUser) -> HTTPResponse:
name=f"reprovisioning-{reprovisioning.id}",
)

return HTTPResponse(status=201)
return json({"id": reprovisioning.id, "start_date": reprovisioning.start_date.isoformat()}, 201)

return "/search/reprovision", ["POST"], _post

Expand All @@ -63,8 +63,10 @@ def get_status(self) -> BlueprintFactoryResponse:

@authenticate(self.authenticator)
async def _get_status(_: Request, __: base_models.APIUser) -> JSONResponse | HTTPResponse:
active_reprovisioning = await self.reprovisioning_repo.get_active_reprovisioning()
return json({"active": bool(active_reprovisioning)}, 200)
reprovisioning = await self.reprovisioning_repo.get_active_reprovisioning()
if not reprovisioning:
return HTTPResponse(status=404)
return json({"id": str(reprovisioning.id), "start_date": reprovisioning.start_date.isoformat()})

return "/search/reprovision", ["GET"], _get_status

Expand Down
75 changes: 46 additions & 29 deletions components/renku_data_services/message_queue/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from collections.abc import Callable

from sanic.log import logger
from sqlalchemy.ext.asyncio import AsyncSession

from renku_data_services.authz.authz import Authz, ResourceType
Expand Down Expand Up @@ -31,39 +32,55 @@ async def reprovision(
project_repo: ProjectRepository,
authz: Authz,
) -> None:
"""Create and send various data service events required for reprovisioning the search index."""
async with session_maker() as session, session.begin():
start_event = make_event(
message_type="reprovisioning.started", payload=v2.ReprovisioningStarted(id=str(reprovisioning.id))
)
await event_repo.store_event(session, start_event)
"""Create and send various data service events required for reprovisioning the message queue."""
logger.info(f"Starting reprovisioning with ID {reprovisioning.id}")

all_users = await user_repo.get_users(requested_by=requested_by)
for user in all_users:
user_event = EventConverter.to_events(user, event_type=v2.UserAdded)
await event_repo.store_event(session, user_event[0])
try:
async with session_maker() as session, session.begin():
start_event = make_event(
message_type="reprovisioning.started", payload=v2.ReprovisioningStarted(id=str(reprovisioning.id))
)
await event_repo.store_event(session, start_event)

all_groups = await group_repo.get_all_groups(requested_by=requested_by)
for group in all_groups:
group_event = EventConverter.to_events(group, event_type=v2.GroupAdded)
await event_repo.store_event(session, group_event[0])
logger.info("Reprovisioning users")
all_users = await user_repo.get_users(requested_by=requested_by)
for user in all_users:
user_event = EventConverter.to_events(user, event_type=v2.UserAdded)
await event_repo.store_event(session, user_event[0])

all_groups_members = await authz.get_all_members(ResourceType.group)
for group_member in all_groups_members:
group_member_event = make_group_member_added_event(member=group_member)
await event_repo.store_event(session, group_member_event)
logger.info("Reprovisioning groups")
all_groups = group_repo.get_all_groups(requested_by=requested_by)
async for group in all_groups:
group_event = EventConverter.to_events(group, event_type=v2.GroupAdded)
await event_repo.store_event(session, group_event[0])

all_projects = await project_repo.get_all_projects(requested_by=requested_by)
for project in all_projects:
project_event = EventConverter.to_events(project, event_type=v2.ProjectCreated)
await event_repo.store_event(session, project_event[0])
logger.info("Reprovisioning group members")
all_groups_members = authz.get_all_members(ResourceType.group)
async for group_member in all_groups_members:
group_member_event = make_group_member_added_event(member=group_member)
await event_repo.store_event(session, group_member_event)

all_projects_members = await authz.get_all_members(ResourceType.project)
for project_member in all_projects_members:
project_member_event = make_project_member_added_event(member=project_member)
await event_repo.store_event(session, project_member_event)
logger.info("Reprovisioning projects")
all_projects = project_repo.get_all_projects(requested_by=requested_by)
async for project in all_projects:
project_event = EventConverter.to_events(project, event_type=v2.ProjectCreated)
await event_repo.store_event(session, project_event[0])

start_event = make_event(message_type="reprovisioning.finished", payload=v2.ReprovisioningFinished(id="42"))
await event_repo.store_event(session, start_event)
logger.info("Reprovisioning project members")
all_projects_members = authz.get_all_members(ResourceType.project)
async for project_member in all_projects_members:
project_member_event = make_project_member_added_event(member=project_member)
await event_repo.store_event(session, project_member_event)

await reprovisioning_repo.stop()
finish_event = make_event(
message_type="reprovisioning.finished", payload=v2.ReprovisioningFinished(id=str(reprovisioning.id))
)
await event_repo.store_event(session, finish_event)

logger.info(f"Trying to commit reprovisioning with ID {reprovisioning.id}")
except Exception as e:
logger.exception(f"An error occurred during reprovisioning with ID {reprovisioning.id}: {e}")
else:
logger.info(f"Reprovisioning with ID {reprovisioning.id} is successfully finished")
finally:
await reprovisioning_repo.stop()
2 changes: 1 addition & 1 deletion components/renku_data_services/message_queue/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ async def delete_event(self, id: int) -> None:
stmt = delete(schemas.EventORM).where(schemas.EventORM.id == id)
await session.execute(stmt)

async def clear(self) -> None:
async def delete_all_events(self) -> None:
"""Delete all events. This is only used when testing reprovisioning."""
async with self.session_maker() as session, session.begin():
await session.execute(delete(schemas.EventORM))
Expand Down
9 changes: 4 additions & 5 deletions components/renku_data_services/namespace/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,16 +90,15 @@ async def get_groups(
n_total_elements = result.scalar() or 0
return [g.dump() for g in groups_orm], n_total_elements

async def get_all_groups(self, requested_by: base_models.APIUser) -> list[models.Group]:
async def get_all_groups(self, requested_by: base_models.APIUser) -> AsyncGenerator[models.Group, None]:
"""Get all groups when reprovisioning."""
if not requested_by.is_admin:
raise errors.ForbiddenError(message="You do not have the required permissions for this operation.")

async with self.session_maker() as session, session.begin():
stmt = select(schemas.GroupORM)
results = await session.execute(stmt)
groups_orms = results.scalars().all()
return [g.dump() for g in groups_orms]
groups = await session.stream_scalars(select(schemas.GroupORM))
async for group in groups:
yield group.dump()

async def _get_group(
self, session: AsyncSession, user: base_models.APIUser, slug: str, load_members: bool = False
Expand Down
11 changes: 5 additions & 6 deletions components/renku_data_services/project/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import functools
from collections.abc import Awaitable, Callable
from collections.abc import AsyncGenerator, Awaitable, Callable
from datetime import UTC, datetime
from typing import Any, Concatenate, ParamSpec, TypeVar

Expand Down Expand Up @@ -70,16 +70,15 @@ async def get_projects(
total_elements = results[1] or 0
return [p.dump() for p in projects_orm], total_elements

async def get_all_projects(self, requested_by: base_models.APIUser) -> list[models.Project]:
async def get_all_projects(self, requested_by: base_models.APIUser) -> AsyncGenerator[models.Project, None]:
"""Get all projects from the database when reprovisioning."""
if not requested_by.is_admin:
raise errors.ForbiddenError(message="You do not have the required permissions for this operation.")

async with self.session_maker() as session:
stmt = select(schemas.ProjectORM)
results = await session.execute(stmt)
projects_orms = results.scalars().all()
return [p.dump() for p in projects_orms]
projects = await session.stream_scalars(select(schemas.ProjectORM))
async for project in projects:
yield project.dump()

async def get_project(self, user: base_models.APIUser, project_id: ULID) -> models.Project:
"""Get one project from the database."""
Expand Down
Loading

0 comments on commit eedd5b5

Please sign in to comment.