Skip to content

Commit

Permalink
Permission commands migrated. (#228)
Browse files Browse the repository at this point in the history
* Permission commands migrated.

---------

Co-authored-by: pederhan <pederhan@usit.uio.no>
  • Loading branch information
terjekv and pederhan authored May 21, 2024
1 parent b7c7370 commit b79a1b0
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 112 deletions.
65 changes: 59 additions & 6 deletions mreg_cli/api/abstracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,47 +303,100 @@ def get_by_field_and_raise(

@classmethod
def get_list_by_field(
cls, field: str, value: str | int, ordering: str | None = None
cls, field: str, value: str | int, ordering: str | None = None, limit: int = 500
) -> list[Self]:
"""Get a list of objects by a field.
:param field: The field to search by.
:param value: The value to search for.
:param ordering: The ordering to use when fetching the list.
:param limit: The maximum number of hits to allow (default 500)
:returns: A list of objects if found, an empty list otherwise.
"""
params = {field: value}
if ordering:
params["ordering"] = ordering

data = get_list(cls.endpoint(), params=params)
data = get_list(cls.endpoint(), params=params, max_hits_to_allow=limit)
return [cls(**item) for item in data]

@classmethod
def get_by_query(cls, query: dict[str, str], ordering: str | None = None) -> list[Self]:
def get_by_query(
cls, query: dict[str, str], ordering: str | None = None, limit: int = 500
) -> list[Self]:
"""Get a list of objects by a query.
:param query: The query to search by.
:param ordering: The ordering to use when fetching the list.
:param limit: The maximum number of hits to allow (default 500)
:returns: A list of objects if found, an empty list otherwise.
"""
if ordering:
query["ordering"] = ordering

data = get_list(cls.endpoint().with_query(query))
data = get_list(cls.endpoint().with_query(query), max_hits_to_allow=limit)
return [cls(**item) for item in data]

@classmethod
def get_by_query_unique(cls, data: dict[str, str]) -> Self:
def get_by_query_unique_or_raise(
cls,
query: dict[str, str],
exc_type: type[Exception] = EntityNotFound,
exc_message: str | None = None,
) -> Self:
"""Get an object by a query and raise if not found.
Used for cases where the object must exist for the operation to continue.
:param query: The query to search by.
:param exc_type: The exception type to raise.
:param exc_message: The exception message. Overrides the default message.
:returns: The object if found.
"""
obj = cls.get_by_query_unique(query)
if not obj:
if not exc_message:
exc_message = f"{cls.__name__} with query {query} not found."
raise exc_type(exc_message)
return obj

@classmethod
def get_by_query_unique_and_raise(
cls,
query: dict[str, str],
exc_type: type[Exception] = EntityAlreadyExists,
exc_message: str | None = None,
) -> None:
"""Get an object by a query and raise if found.
Used for cases where the object must NOT exist for the operation to continue.
:param query: The query to search by.
:param exc_type: The exception type to raise.
:param exc_message: The exception message. Overrides the default message.
:raises Exception: If the object is found.
"""
obj = cls.get_by_query_unique(query)
if obj:
if not exc_message:
exc_message = f"{cls.__name__} with query {query} already exists."
raise exc_type(exc_message)
return None

@classmethod
def get_by_query_unique(cls, data: dict[str, str]) -> Self | None:
"""Get an object with the given data.
:param data: The data to search for.
:returns: The object if found, None otherwise.
"""
obj_dict = get_list_unique(cls.endpoint(), params=data)
if not obj_dict:
raise EntityNotFound(f"{cls.__name__} record for {data} not found.")
return None
return cls(**obj_dict)

def refetch(self) -> Self:
Expand Down
39 changes: 38 additions & 1 deletion mreg_cli/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,10 +342,19 @@ class Permission(FrozenModelWithTimestamps, APIMixin):

id: int # noqa: A003
group: str
range: str # noqa: A003
range: IP_NetworkT # noqa: A003
regex: str
labels: list[int]

@field_validator("range", mode="before")
@classmethod
def validate_ip_or_network(cls, value: str) -> IP_NetworkT:
"""Validate and convert the input to a network."""
try:
return ipaddress.ip_network(value)
except ValueError as e:
raise InputFailure(f"Invalid input for network: {value}") from e

@classmethod
def endpoint(cls) -> Endpoint:
"""Return the endpoint for the class."""
Expand All @@ -364,6 +373,34 @@ def output_multiple(cls, permissions: list[Permission], indent: int = 4) -> None
indent=indent,
)

def add_label(self, label_name: str) -> Self:
"""Add a label to the permission.
:param label_name: The name of the label to add.
:returns: The updated Permission object.
"""
label = Label.get_by_name_or_raise(label_name)
if label.id in self.labels:
raise EntityAlreadyExists(f"The permission already has the label {label_name!r}")

label_ids = self.labels.copy()
label_ids.append(label.id)
return self.patch({"labels": label_ids})

def remove_label(self, label_name: str) -> Self:
"""Remove a label from the permission.
:param label_name: The name of the label to remove.
:returns: The updated Permission object.
"""
label = Label.get_by_name_or_raise(label_name)
if label.id not in self.labels:
raise EntityNotFound(f"The permission does not have the label {label_name!r}")

label_ids = self.labels.copy()
label_ids.remove(label.id)
return self.patch({"labels": label_ids}, use_json=True)


class Zone(FrozenModelWithTimestamps, WithTTL):
"""Model representing a DNS zone with various attributes and related nameservers."""
Expand Down
Loading

0 comments on commit b79a1b0

Please sign in to comment.