Skip to content

Commit

Permalink
cleanup changes
Browse files Browse the repository at this point in the history
  • Loading branch information
leafty committed Oct 2, 2024
1 parent 9072df2 commit dde4e44
Show file tree
Hide file tree
Showing 12 changed files with 190 additions and 55 deletions.
13 changes: 12 additions & 1 deletion components/renku_data_services/authz/authz.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ async def _get_authz_change(
result, DataConnectorToProjectLink
):
user = _extract_user_from_args(*func_args, **func_kwargs)
authz_change = await db_repo.authz._add_data_connector_to_project_link(user, result)
authz_change = await db_repo.authz._remove_data_connector_to_project_link(user, result)
case _:
resource_id: str | ULID | None = "unknown"
if isinstance(result, (Project, Namespace, Group, DataConnector)):
Expand Down Expand Up @@ -681,6 +681,17 @@ async def _remove_project(
ReadRelationshipsRequest(consistency=consistency, relationship_filter=rel_filter)
)
rels: list[Relationship] = []
async for response in responses:
rels.append(response.relationship)
# Project is also a subject for "linked_to" relations
rel_filter = RelationshipFilter(
optional_subject_filter=SubjectFilter(
subject_type=ResourceType.project.value, optional_subject_id=str(project.id)
)
)
responses: AsyncIterable[ReadRelationshipsResponse] = self.client.ReadRelationships(
ReadRelationshipsRequest(consistency=consistency, relationship_filter=rel_filter)
)
async for response in responses:
rels.append(response.relationship)
apply = WriteRelationshipsRequest(
Expand Down
1 change: 1 addition & 0 deletions components/renku_data_services/base_api/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def validate_body_root_model(
"""Decorator for sanic json payload validation when the model is derived from RootModel.
Should be removed once sanic fixes this error in their validation code.
Issue link: https://github.com/sanic-org/sanic-ext/issues/198
"""

def decorator(
Expand Down
30 changes: 17 additions & 13 deletions components/renku_data_services/crc/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from renku_data_services import errors
from renku_data_services.base_api.auth import authenticate, only_admins
from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint
from renku_data_services.base_api.misc import validate_db_ids, validate_query
from renku_data_services.base_api.misc import validate_body_root_model, validate_db_ids, validate_query
from renku_data_services.base_models.validation import validated_json
from renku_data_services.crc import apispec, models
from renku_data_services.crc.db import ResourcePoolRepository, UserRepository
Expand Down Expand Up @@ -161,9 +161,11 @@ def post(self) -> BlueprintFactoryResponse:
@authenticate(self.authenticator)
@only_admins
@validate_db_ids
async def _post(request: Request, user: base_models.APIUser, resource_pool_id: int) -> HTTPResponse:
users = apispec.PoolUsersWithId.model_validate(request.json) # validation
return await self._put_post(api_user=user, resource_pool_id=resource_pool_id, body=users, post=True)
@validate_body_root_model(json=apispec.PoolUsersWithId)
async def _post(
_: Request, user: base_models.APIUser, resource_pool_id: int, body: apispec.PoolUsersWithId
) -> HTTPResponse:
return await self._put_post(api_user=user, resource_pool_id=resource_pool_id, body=body, post=True)

return "/resource_pools/<resource_pool_id>/users", ["POST"], _post

Expand All @@ -173,9 +175,11 @@ def put(self) -> BlueprintFactoryResponse:
@authenticate(self.authenticator)
@only_admins
@validate_db_ids
async def _put(request: Request, user: base_models.APIUser, resource_pool_id: int) -> HTTPResponse:
users = apispec.PoolUsersWithId.model_validate(request.json) # validation
return await self._put_post(api_user=user, resource_pool_id=resource_pool_id, body=users, post=False)
@validate_body_root_model(json=apispec.PoolUsersWithId)
async def _put(
_: Request, user: base_models.APIUser, resource_pool_id: int, body: apispec.PoolUsersWithId
) -> HTTPResponse:
return await self._put_post(api_user=user, resource_pool_id=resource_pool_id, body=body, post=False)

return "/resource_pools/<resource_pool_id>/users", ["PUT"], _put

Expand Down Expand Up @@ -528,9 +532,9 @@ def post(self) -> BlueprintFactoryResponse:

@authenticate(self.authenticator)
@only_admins
async def _post(request: Request, user: base_models.APIUser, user_id: str) -> HTTPResponse:
ids = apispec.IntegerIds.model_validate(request.json) # validation
return await self._post_put(user_id=user_id, post=True, resource_pool_ids=ids, api_user=user)
@validate_body_root_model(json=apispec.IntegerIds)
async def _post(_: Request, user: base_models.APIUser, user_id: str, body: apispec.IntegerIds) -> HTTPResponse:
return await self._post_put(user_id=user_id, post=True, resource_pool_ids=body, api_user=user)

return "/users/<user_id>/resource_pools", ["POST"], _post

Expand All @@ -539,9 +543,9 @@ def put(self) -> BlueprintFactoryResponse:

@authenticate(self.authenticator)
@only_admins
async def _put(request: Request, user: base_models.APIUser, user_id: str) -> HTTPResponse:
ids = apispec.IntegerIds.model_validate(request.json) # validation
return await self._post_put(user_id=user_id, post=False, resource_pool_ids=ids, api_user=user)
@validate_body_root_model(json=apispec.IntegerIds)
async def _put(_: Request, user: base_models.APIUser, user_id: str, body: apispec.IntegerIds) -> HTTPResponse:
return await self._post_put(user_id=user_id, post=False, resource_pool_ids=body, api_user=user)

return "/users/<user_id>/resource_pools", ["PUT"], _put

Expand Down
9 changes: 4 additions & 5 deletions components/renku_data_services/crc/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ async def filter_resource_pools(
max_storage: int = 0,
gpu: int = 0,
) -> list[models.ResourcePool]:
"""Get resource pools from database with indication of which resource class matches the specified crtieria."""
"""Get resource pools from database with indication of which resource class matches the specified criteria."""
async with self.session_maker() as session:
criteria = models.ResourceClass(
name="criteria",
Expand All @@ -205,18 +205,17 @@ async def filter_resource_pools(
)
stmt = (
select(schemas.ResourcePoolORM)
.join(schemas.ResourcePoolORM.classes)
.distinct()
.options(selectinload(schemas.ResourcePoolORM.classes))
.order_by(
schemas.ResourcePoolORM.id,
schemas.ResourcePoolORM.name,
schemas.ResourceClassORM.id,
schemas.ResourceClassORM.name,
)
)
# NOTE: The line below ensures that the right users can access the right resources, do not remove.
stmt = _resource_pool_access_control(api_user, stmt)
res = await session.execute(stmt)
return [i.dump(self.quotas_repo.get_quota(i.quota), criteria) for i in res.unique().scalars().all()]
return [i.dump(self.quotas_repo.get_quota(i.quota), criteria) for i in res.scalars().all()]

@_only_admins
async def insert_resource_pool(
Expand Down
21 changes: 17 additions & 4 deletions components/renku_data_services/crc/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,23 @@ def __post_init__(self) -> None:
if len(default_classes) != 1:
raise ValidationError(message="One default class is required in each resource pool.")

# We need to sort classes to make '__eq__' reliable
object.__setattr__(
self, "classes", sorted(self.classes, key=lambda x: (x.default, x.cpu, x.memory, x.default_storage, x.name))
)
def __eq__(self, other: Any) -> bool:
"""Check two resource pools for equality."""
if not isinstance(other, ResourcePool):
return False

if self.id != other.id:
return False
if self.name != other.name:
return False
if self.default != other.default or self.public != other.public:
return False
if self.idle_threshold != other.idle_threshold or self.hibernation_threshold != other.hibernation_threshold:
return False

this_classes = sorted(self.classes, key=lambda x: (x.default, x.cpu, x.memory, x.default_storage, x.name))
other_classes = sorted(other.classes, key=lambda x: (x.default, x.cpu, x.memory, x.default_storage, x.name))
return this_classes == other_classes

def set_quota(self, val: Quota) -> "ResourcePool":
"""Set the quota for a resource pool."""
Expand Down
4 changes: 4 additions & 0 deletions components/renku_data_services/crc/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ class ResourcePoolORM(BaseORM):
default_factory=list,
cascade="save-update, merge, delete",
lazy="selectin",
order_by=(
"[ResourceClassORM.gpu,ResourceClassORM.cpu,ResourceClassORM.memory,ResourceClassORM.max_storage,"
"ResourceClassORM.name,ResourceClassORM.id]"
),
)
idle_threshold: Mapped[Optional[int]] = mapped_column(default=None)
hibernation_threshold: Mapped[Optional[int]] = mapped_column(default=None)
Expand Down
12 changes: 6 additions & 6 deletions components/renku_data_services/namespace/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from renku_data_services.authz.models import Role, UnsavedMember
from renku_data_services.base_api.auth import authenticate, only_authenticated
from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint
from renku_data_services.base_api.misc import validate_query
from renku_data_services.base_api.misc import validate_body_root_model, validate_query
from renku_data_services.base_api.pagination import PaginationRequest, paginate
from renku_data_services.base_models.validation import validate_and_dump, validated_json
from renku_data_services.errors import errors
Expand Down Expand Up @@ -118,11 +118,11 @@ def update_members(self) -> BlueprintFactoryResponse:

@authenticate(self.authenticator)
@only_authenticated
async def _update_members(request: Request, user: base_models.APIUser, slug: str) -> JSONResponse:
# TODO: sanic validation does not support validating top-level json lists, switch this to @validate
# once sanic-org/sanic-ext/issues/198 is fixed
body_validated = apispec.GroupMemberPatchRequestList.model_validate(request.json)
members = [UnsavedMember(Role.from_group_role(member.role), member.id) for member in body_validated.root]
@validate_body_root_model(json=apispec.GroupMemberPatchRequestList)
async def _update_members(
_: Request, user: base_models.APIUser, slug: str, body: apispec.GroupMemberPatchRequestList
) -> JSONResponse:
members = [UnsavedMember(Role.from_group_role(member.role), member.id) for member in body.root]
res = await self.group_repo.update_group_members(
user=user,
slug=slug,
Expand Down
10 changes: 6 additions & 4 deletions components/renku_data_services/project/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
)
from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint
from renku_data_services.base_api.etag import if_match_required
from renku_data_services.base_api.misc import validate_query
from renku_data_services.base_api.misc import validate_body_root_model, validate_query
from renku_data_services.base_api.pagination import PaginationRequest, paginate
from renku_data_services.errors import errors
from renku_data_services.project import apispec
Expand Down Expand Up @@ -259,9 +259,11 @@ def update_members(self) -> BlueprintFactoryResponse:

@authenticate(self.authenticator)
@validate_path_project_id
async def _update_members(request: Request, user: base_models.APIUser, project_id: str) -> HTTPResponse:
body_dump = apispec.ProjectMemberListPatchRequest.model_validate(request.json)
members = [Member(Role(i.role.value), i.id, project_id) for i in body_dump.root]
@validate_body_root_model(json=apispec.ProjectMemberListPatchRequest)
async def _update_members(
_: Request, user: base_models.APIUser, project_id: str, body: apispec.ProjectMemberListPatchRequest
) -> HTTPResponse:
members = [Member(Role(i.role.value), i.id, project_id) for i in body.root]
await self.project_member_repo.update_members(user, ULID.from_str(project_id), members)
return HTTPResponse(status=200)

Expand Down
Loading

0 comments on commit dde4e44

Please sign in to comment.