Skip to content

Commit

Permalink
Add WithName mixin (#218)
Browse files Browse the repository at this point in the history
* Ensure resource exists

* Add `APIMixin.get_by_field_or_raise`

* Hack: Inherit from `APIMixin[Any]`

* Remove invalid superclass

* Use `BaseModel` as generic type for `WithName`

* Fix `ensure_name_*` error messages

* Use `Any` as type argument for `WithName`

* Add `WithName.__name_field__`

* Clarify `__name_field__` docstring

* Add `get_by_name_or_raise`, refactor `get_by_name`

* Fix `WithName` method docstrings

* Remove TODO
  • Loading branch information
pederhan authored May 7, 2024
1 parent 38f7353 commit 34f0bdc
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 32 deletions.
53 changes: 53 additions & 0 deletions mreg_cli/api/abstracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from mreg_cli.api.endpoints import Endpoint
from mreg_cli.api.history import HistoryItem, HistoryResource
from mreg_cli.exceptions import CliError
from mreg_cli.log import cli_error, cli_warning
from mreg_cli.outputmanager import OutputManager
from mreg_cli.utilities.api import (
Expand Down Expand Up @@ -193,6 +194,58 @@ def get_by_field(cls, field: str, value: str) -> BMT | None:

return cast(BMT, cls(**data))

@classmethod
def get_by_field_or_raise(
cls,
field: str,
value: str,
exc_type: type[Exception] = CliError,
exc_message: str | None = None,
) -> BMT:
"""Get an object by a field and raise if not found.
Used for cases where the object must exist for the operation to continue.
:param field: The field to search by.
:param value: The value to search for.
:param exc_type: The exception type to raise.
:param exc_message: The exception message. Overrides the default message.
:returns: The object if found.
"""
obj = cls.get_by_field(field, value)
if not obj:
if not exc_message:
exc_message = f"{cls.__name__} with {field} {value!r} not found."
raise exc_type(exc_message)
return obj

@classmethod
def get_by_field_and_raise(
cls,
field: str,
value: str,
exc_type: type[Exception] = CliError,
exc_message: str | None = None,
) -> None:
"""Get an object by a field and raise if found.
Used for cases where the object must NOT exist for the operation to continue.
:param field: The field to search by.
:param value: The value to search for.
:param exc_type: The exception type to raise.
:param exc_message: The exception message. Overrides the default message.
:raises Exception: If the object is found.
"""
obj = cls.get_by_field(field, value)
if obj:
if not exc_message:
exc_message = f"{cls.__name__} with {field} {value!r} already exists."
raise exc_type(exc_message)
return None

@classmethod
def get_list_by_field(
cls, field: str, value: str | int, ordering: str | None = None
Expand Down
88 changes: 57 additions & 31 deletions mreg_cli/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import ipaddress
import re
from datetime import date, datetime
from typing import Any, Literal, cast
from typing import Any, Literal, Self, cast

from pydantic import (
AliasChoices,
Expand Down Expand Up @@ -256,6 +256,48 @@ def valid_numeric_ttl(self, ttl: int) -> int:
return ttl


class WithName(APIMixin[Any]):
"""Mixin type for an object that has a name element."""

__name_field__: str = "name"
"""Name of the API field that holds the object's name."""

@classmethod
def ensure_name_not_exists(cls, name: str) -> None:
"""Ensure a name is not already used.
:param name: The name to check for uniqueness.
"""
cls.get_by_field_and_raise(cls.__name_field__, name)

@classmethod
def ensure_name_exists(cls, name: str) -> None:
"""Ensure a resource with the name exists.
:param name: The name to check for existence.
"""
cls.get_by_name_or_raise(name) # pyright: ignore[reportUnusedCallResult]

@classmethod
def get_by_name(cls, name: str) -> Self | None:
"""Get a resource by name.
:param name: The resource name to search for.
:returns: The resource if found.
"""
return cls.get_by_field(cls.__name_field__, name)

@classmethod
def get_by_name_or_raise(cls, name: str) -> Self:
"""Get a resource by name, raising a CliWarning if not found.
:param name: The resource name to search for.
:returns: The resource.
:raises CliWarning: If the role is not found.
"""
return cls.get_by_field_or_raise(cls.__name_field__, name)


class NameServer(FrozenModelWithTimestamps, WithTTL):
"""Model for representing a nameserver within a DNS zone."""

Expand Down Expand Up @@ -385,7 +427,7 @@ def created_at(self) -> datetime:
return self.created_at_tz_naive.replace(tzinfo=self.updated_at.tzinfo)

@classmethod
def get_by_name(cls, name: str) -> Atom | Role:
def get_role_or_atom(cls, name: str) -> Atom | Role:
"""Get an Atom or Role by name.
:param name: The name to search for.
Expand All @@ -401,7 +443,7 @@ def get_by_name(cls, name: str) -> Atom | Role:
else:
break # found a match
else:
cli_warning(f"Could not find an atom or a role with name: {name!r}")
cli_warning(f"Could not find an atom or a role with name {name}")
return role_or_atom

def output_timestamps(self, padding: int = 14) -> None:
Expand All @@ -422,7 +464,7 @@ def output(self, padding: int = 14) -> None:
output_manager.add_line(f"{'Description:':<{padding}}{self.description}")


class Role(HostPolicy, APIMixin["Role"]):
class Role(HostPolicy, WithName, APIMixin["Role"]):
"""Model for a role."""

id: int # noqa: A003
Expand All @@ -435,19 +477,6 @@ def endpoint(cls) -> Endpoint:
"""Return the endpoint for the class."""
return Endpoint.HostPolicyRoles

@classmethod
def get_by_name(cls, name: str) -> Role:
"""Get a Role by name.
:param name: The role name to search for.
:returns: The role if found.
:raises CliWarning: If the role is not found.
"""
data = get_item_by_key_value(Endpoint.HostPolicyRoles, "name", name)
if not data:
cli_warning(f"Role with name {name} not found.")
return cls(**data)

def output(self, padding: int = 14) -> None:
"""Output the role to the console.
Expand Down Expand Up @@ -484,8 +513,18 @@ def get_labels(self) -> list[Label]:
"""
return [Label.get_by_id_or_raise(id_) for id_ in self.labels]

@classmethod
def get_roles_with_atom(cls, name: str) -> list[Role]:
"""Get all roles with a specific atom.
:param atom: Name of the atom to search for.
:returns: A list of Role objects.
"""
data = get_list(cls.endpoint(), params={"atoms__name__exact": name})
return [Role(**item) for item in data]


class Atom(HostPolicy, APIMixin["Atom"]):
class Atom(HostPolicy, WithName, APIMixin["Atom"]):
"""Model for an atom."""

id: int # noqa: A003
Expand All @@ -496,19 +535,6 @@ def endpoint(cls) -> Endpoint:
"""Return the endpoint for the class."""
return Endpoint.HostPolicyAtoms

@classmethod
def get_by_name(cls, name: str) -> Atom:
"""Get an Atom by name.
:param name: The atom name to search for.
:returns: The atom if found.
:raises CliWarning: If the atom is not found.
"""
data = get_item_by_key_value(Endpoint.HostPolicyAtoms, "name", name)
if not data:
cli_warning(f"Atom with name {name} not found.")
return cls(**data)

def output(self, padding: int = 14) -> None:
"""Output the role to the console.
Expand Down
2 changes: 1 addition & 1 deletion mreg_cli/commands/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def info(args: argparse.Namespace) -> None:
"""
names: list[str] = args.name
for name in names:
role_or_atom = HostPolicy.get_by_name(name)
role_or_atom = HostPolicy.get_role_or_atom(name)
role_or_atom.output()


Expand Down

0 comments on commit 34f0bdc

Please sign in to comment.