Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
m-alisafaee committed Sep 25, 2024
1 parent ccc1870 commit a028799
Show file tree
Hide file tree
Showing 11 changed files with 207 additions and 23 deletions.
1 change: 1 addition & 0 deletions bases/renku_data_services/data_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def register_all_handlers(app: Sanic, config: Config) -> Sanic:
url_prefix=url_prefix,
authenticator=config.authenticator,
session_maker=config.db.async_session_maker,
reprovisioning_repo=config.reprovisioning_repo,
event_repo=config.event_repo,
user_repo=config.kc_user_repo,
group_repo=config.group_repo,
Expand Down
11 changes: 10 additions & 1 deletion components/renku_data_services/app_config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from renku_data_services.k8s.clients import DummyCoreClient, DummySchedulingClient, K8sCoreClient, K8sSchedulingClient
from renku_data_services.k8s.quota import QuotaRepository
from renku_data_services.message_queue.config import RedisConfig
from renku_data_services.message_queue.db import EventRepository
from renku_data_services.message_queue.db import EventRepository, ReprovisioningRepository
from renku_data_services.message_queue.interface import IMessageQueue
from renku_data_services.message_queue.redis_queue import RedisQueue
from renku_data_services.namespace.db import GroupRepository
Expand Down Expand Up @@ -166,6 +166,7 @@ class Config:
_project_repo: ProjectRepository | None = field(default=None, repr=False, init=False)
_group_repo: GroupRepository | None = field(default=None, repr=False, init=False)
_event_repo: EventRepository | None = field(default=None, repr=False, init=False)
_reprovisioning_repo: ReprovisioningRepository | None = field(default=None, repr=False, init=False)
_session_repo: SessionRepository | None = field(default=None, repr=False, init=False)
_user_preferences_repo: UserPreferencesRepository | None = field(default=None, repr=False, init=False)
_kc_user_repo: KcUserRepo | None = field(default=None, repr=False, init=False)
Expand All @@ -176,6 +177,7 @@ class Config:
_platform_repo: PlatformRepository | None = field(default=None, repr=False, init=False)

def __post_init__(self) -> None:
# NOTE: Read spec files required for Swagger
spec_file = Path(renku_data_services.crc.__file__).resolve().parent / "api.spec.yaml"
with open(spec_file) as f:
crc_spec = safe_load(f)
Expand Down Expand Up @@ -292,6 +294,13 @@ def event_repo(self) -> EventRepository:
)
return self._event_repo

@property
def reprovisioning_repo(self) -> ReprovisioningRepository:
"""The DB adapter for reprovisioning."""
if not self._reprovisioning_repo:
self._reprovisioning_repo = ReprovisioningRepository(session_maker=self.db.async_session_maker)
return self._reprovisioning_repo

@property
def project_repo(self) -> ProjectRepository:
"""The DB adapter for Renku native projects."""
Expand Down
34 changes: 22 additions & 12 deletions components/renku_data_services/message_queue/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from renku_data_services.base_api.auth import authenticate, only_admins
from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint
from renku_data_services.message_queue.core import reprovision
from renku_data_services.message_queue.db import EventRepository
from renku_data_services.message_queue.db import EventRepository, ReprovisioningRepository
from renku_data_services.namespace.db import GroupRepository
from renku_data_services.project.db import ProjectRepository
from renku_data_services.users.db import UserRepo
Expand All @@ -24,6 +24,7 @@ class SearchBP(CustomBlueprint):

authenticator: base_models.Authenticator
session_maker: Callable[..., AsyncSession]
reprovisioning_repo: ReprovisioningRepository
event_repo: EventRepository
user_repo: UserRepo
group_repo: GroupRepository
Expand All @@ -35,16 +36,24 @@ def post(self) -> BlueprintFactoryResponse:

@authenticate(self.authenticator)
@only_admins
async def _post(_: Request, user: base_models.APIUser) -> HTTPResponse:
await reprovision(
session_maker=self.session_maker,
requested_by=user,
event_repo=self.event_repo,
user_repo=self.user_repo,
group_repo=self.group_repo,
project_repo=self.project_repo,
authz=self.authz,
async def _post(request: Request, user: base_models.APIUser) -> HTTPResponse:
reprovisioning = await self.reprovisioning_repo.start()

request.app.add_task(
reprovision(
session_maker=self.session_maker,
requested_by=user,
reprovisioning=reprovisioning,
reprovisioning_repo=self.reprovisioning_repo,
event_repo=self.event_repo,
user_repo=self.user_repo,
group_repo=self.group_repo,
project_repo=self.project_repo,
authz=self.authz,
),
name=f"reprovisioning-{reprovisioning.id}",
)

return HTTPResponse(status=201)

return "/search/reprovision", ["POST"], _post
Expand All @@ -54,7 +63,8 @@ def get_status(self) -> BlueprintFactoryResponse:

@authenticate(self.authenticator)
async def _get_status(_: Request, __: base_models.APIUser) -> JSONResponse | HTTPResponse:
return json({"active": True}, 200)
active_reprovisioning = await self.reprovisioning_repo.get_active_reprovisioning()
return json({"active": bool(active_reprovisioning)}, 200)

return "/search/reprovision", ["GET"], _get_status

Expand All @@ -64,7 +74,7 @@ def delete(self) -> BlueprintFactoryResponse:
@authenticate(self.authenticator)
@only_admins
async def _delete(_: Request, __: base_models.APIUser) -> HTTPResponse:
# await self.project_repo.delete_project(user=user, project_id=ULID.from_str(project_id))
await self.reprovisioning_repo.stop()
return HTTPResponse(status=204)

return "/search/reprovision", ["DELETE"], _delete
11 changes: 9 additions & 2 deletions components/renku_data_services/message_queue/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
make_group_member_added_event,
make_project_member_added_event,
)
from renku_data_services.message_queue.db import EventRepository
from renku_data_services.message_queue.db import EventRepository, ReprovisioningRepository
from renku_data_services.message_queue.models import Reprovisioning
from renku_data_services.namespace.db import GroupRepository
from renku_data_services.project.db import ProjectRepository
from renku_data_services.users.db import UserRepo
Expand All @@ -22,6 +23,8 @@
async def reprovision(
session_maker: Callable[..., AsyncSession],
requested_by: APIUser,
reprovisioning: Reprovisioning,
reprovisioning_repo: ReprovisioningRepository,
event_repo: EventRepository,
user_repo: UserRepo,
group_repo: GroupRepository,
Expand All @@ -30,7 +33,9 @@ async def reprovision(
) -> None:
"""Create and send various data service events required for reprovisioning the search index."""
async with session_maker() as session, session.begin():
start_event = make_event(message_type="reprovisioning.started", payload=v2.ReprovisioningStarted(id="42"))
start_event = make_event(
message_type="reprovisioning.started", payload=v2.ReprovisioningStarted(id=str(reprovisioning.id))
)
await event_repo.store_event(session, start_event)

all_users = await user_repo.get_users(requested_by=requested_by)
Expand Down Expand Up @@ -60,3 +65,5 @@ async def reprovision(

start_event = make_event(message_type="reprovisioning.finished", payload=v2.ReprovisioningFinished(id="42"))
await event_repo.store_event(session, start_event)

await reprovisioning_repo.stop()
34 changes: 33 additions & 1 deletion components/renku_data_services/message_queue/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@
from __future__ import annotations

from collections.abc import Callable
from datetime import UTC, datetime

from sanic.log import logger
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session

from renku_data_services import errors
from renku_data_services.message_queue import orm as schemas
from renku_data_services.message_queue.interface import IMessageQueue
from renku_data_services.message_queue.models import Event
from renku_data_services.message_queue.models import Event, Reprovisioning


class EventRepository:
Expand Down Expand Up @@ -80,3 +82,33 @@ async def clear(self) -> None:
"""Delete all events. This is only used when testing reprovisioning."""
async with self.session_maker() as session, session.begin():
await session.execute(delete(schemas.EventORM))


class ReprovisioningRepository:
"""Repository for Reprovisioning."""

def __init__(self, session_maker: Callable[..., AsyncSession]) -> None:
self.session_maker = session_maker

async def start(self) -> Reprovisioning:
"""Create a new reprovisioning."""
async with self.session_maker() as session, session.begin():
active_reprovisioning = await session.scalar(select(schemas.ReprovisioningORM))
if active_reprovisioning:
raise errors.ConflictError(message="A reprovisioning is already in progress")

reprovisioning_orm = schemas.ReprovisioningORM(start_date=datetime.now(UTC).replace(microsecond=0))
session.add(reprovisioning_orm)

return reprovisioning_orm.dump()

async def get_active_reprovisioning(self) -> Reprovisioning | None:
"""Get current reprovisioning."""
async with self.session_maker() as session:
active_reprovisioning = await session.scalar(select(schemas.ReprovisioningORM))
return active_reprovisioning.dump() if active_reprovisioning else None

async def stop(self) -> None:
"""Stop current reprovisioning."""
async with self.session_maker() as session, session.begin():
await session.execute(delete(schemas.ReprovisioningORM))
10 changes: 9 additions & 1 deletion components/renku_data_services/message_queue/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import glob
import json
from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import UTC, datetime
from io import BytesIO
from pathlib import Path
Expand Down Expand Up @@ -88,3 +88,11 @@ def create(cls, queue: str, message_type: str, payload: AvroModel) -> Self:
"payload": _serialize_binary(payload),
}
return cls(queue, message)


@dataclass
class Reprovisioning:
"""A reprovisioning."""

id: ULID
start_date: datetime = field(default_factory=lambda: datetime.now(UTC).replace(microsecond=0))
21 changes: 20 additions & 1 deletion components/renku_data_services/message_queue/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
from sqlalchemy import JSON, DateTime, MetaData, String
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column
from ulid import ULID

from renku_data_services.message_queue.models import Event
from renku_data_services.message_queue.models import Event, Reprovisioning
from renku_data_services.utils.sqlalchemy import ULIDType

JSONVariant = JSON().with_variant(JSONB(), "postgresql")

Expand Down Expand Up @@ -71,3 +73,20 @@ def get_message_type(self) -> Optional[str]:
return None
else:
return message_type


class ReprovisioningORM(BaseORM):
"""Reprovisioning table.
This table is used to make sure that only one instance of reprovisioning is run at any given time.
It gets updated with the reprovisioning progress.
"""

__tablename__ = "reprovisioning"

id: Mapped[ULID] = mapped_column("id", ULIDType, primary_key=True, default_factory=lambda: str(ULID()), init=False)
start_date: Mapped[datetime] = mapped_column("start_date", DateTime(timezone=True), nullable=False)

def dump(self) -> Reprovisioning:
"""Create a Reprovisioning from the ORM object."""
return Reprovisioning(id=self.id, start_date=self.start_date)
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""create search revision table
Revision ID: 726d5d0e1f28
Revises: 9058bf0a1a12
Create Date: 2024-09-18 11:45:38.734042
"""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "726d5d0e1f28"
down_revision = "9058bf0a1a12"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"reprovisioning",
sa.Column("id", sa.String(26), nullable=False),
sa.Column("start_date", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("id"),
schema="events",
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("reprovisioning", schema="events")
# ### end Alembic commands ###
3 changes: 0 additions & 3 deletions components/renku_data_services/project/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,6 @@ async def get_all_projects(self, requested_by: base_models.APIUser) -> list[mode
raise errors.ForbiddenError(message="You do not have the required permissions for this operation.")

async with self.session_maker() as session:
# NOTE: without awaiting the connection below, there are failures about how a connection has not
# been established in the DB but the query is getting executed.
_ = await session.connection()
stmt = select(schemas.ProjectORM)
results = await session.execute(stmt)
projects_orms = results.scalars().all()
Expand Down
71 changes: 69 additions & 2 deletions test/bases/renku_data_services/data_api/test_message_queue.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import asyncio

import pytest

import renku_data_services.message_queue.blueprints
from test.bases.renku_data_services.data_api.utils import dataclass_to_str, deserialize_event


@pytest.mark.asyncio
async def test_reprovisioning(
sanic_client, app_config, create_project, create_group, admin_headers, user_headers, project_members
async def test_search_reprovisioning(
sanic_client, app_config, create_project, create_group, admin_headers, project_members
) -> None:
await create_project("Project 1")
await create_project("Project 2", visibility="public")
Expand All @@ -25,9 +28,73 @@ async def test_reprovisioning(

assert response.status_code == 201, response.text

# NOTE: Wait for server to finish the reprovisioning task
await asyncio.sleep(2)

reprovisioning_events = await app_config.event_repo._get_pending_events()

events_before = {dataclass_to_str(deserialize_event(e)) for e in events}
events_after = {dataclass_to_str(deserialize_event(e)) for e in reprovisioning_events[1:-1]}

assert events_after == events_before


@pytest.mark.asyncio
async def test_search_only_admins_can_start_reprovisioning(sanic_client, user_headers) -> None:
_, response = await sanic_client.post("/api/data/search/reprovision", headers=user_headers)

assert response.status_code == 403, response.text
assert "You do not have the required permissions for this operation." in response.json["error"]["message"]


async def long_reprovisioning_mock(*_, **__):
# NOTE: we do not delete the reprovision instance at the end to simulate a long reprovisioning
print("Running")


@pytest.mark.asyncio
async def test_search_multiple_reprovisioning_not_allowed(sanic_client, admin_headers, monkeypatch) -> None:
monkeypatch.setattr(renku_data_services.message_queue.blueprints, "reprovision", long_reprovisioning_mock)

_, response = await sanic_client.post("/api/data/search/reprovision", headers=admin_headers)
assert response.status_code == 201, response.text

_, response = await sanic_client.post("/api/data/search/reprovision", headers=admin_headers)

assert response.status_code == 409, response.text
assert "A reprovisioning is already in progress" in response.json["error"]["message"]


@pytest.mark.asyncio
async def test_search_get_reprovisioning_status(sanic_client, admin_headers, user_headers, monkeypatch):
monkeypatch.setattr(renku_data_services.message_queue.blueprints, "reprovision", long_reprovisioning_mock)

_, response = await sanic_client.get("/api/data/search/reprovision", headers=user_headers)

assert response.status_code == 200, response.text
assert response.json["active"] is False

# NOTE: Start a reprovisioning
_, response = await sanic_client.post("/api/data/search/reprovision", headers=admin_headers)
assert response.status_code == 201, response.text

_, response = await sanic_client.get("/api/data/search/reprovision", headers=user_headers)

assert response.status_code == 200, response.text
assert response.json["active"] is True


@pytest.mark.asyncio
async def test_search_can_stop_reprovisioning(sanic_client, admin_headers, monkeypatch) -> None:
monkeypatch.setattr(renku_data_services.message_queue.blueprints, "reprovision", long_reprovisioning_mock)

_, response = await sanic_client.post("/api/data/search/reprovision", headers=admin_headers)
assert response.status_code == 201, response.text

_, response = await sanic_client.delete("/api/data/search/reprovision", headers=admin_headers)
assert response.status_code == 204, response.text

_, response = await sanic_client.get("/api/data/search/reprovision", headers=admin_headers)

assert response.status_code == 200, response.text
assert response.json["active"] is False
Empty file.

0 comments on commit a028799

Please sign in to comment.