From f0eb55ccda45b41060bcc93f382510c50fb8f2c0 Mon Sep 17 00:00:00 2001 From: Terje Kvernes Date: Tue, 23 Apr 2024 12:59:31 +0200 Subject: [PATCH] Split the models.py into smaller files. --- mreg_cli/api/abstracts.py | 285 +++++++++++++++++++++++++++++++ mreg_cli/api/fields.py | 67 ++++++++ mreg_cli/api/models.py | 343 +------------------------------------- 3 files changed, 356 insertions(+), 339 deletions(-) create mode 100644 mreg_cli/api/abstracts.py create mode 100644 mreg_cli/api/fields.py diff --git a/mreg_cli/api/abstracts.py b/mreg_cli/api/abstracts.py new file mode 100644 index 00000000..35ec12a4 --- /dev/null +++ b/mreg_cli/api/abstracts.py @@ -0,0 +1,285 @@ +"""Abstract models for the API.""" + + +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Any, Dict, Generic, List, Optional, Set, TypeVar, Union, cast + +from pydantic import BaseModel +from pydantic.fields import AliasChoices, FieldInfo + +from mreg_cli.api.endpoints import Endpoint +from mreg_cli.log import cli_warning +from mreg_cli.outputmanager import OutputManager +from mreg_cli.utilities.api import delete, get, get_item_by_key_value, get_list, patch, post + +BMT = TypeVar("BMT", bound="BaseModel") + + +def get_field_aliases(field_info: FieldInfo) -> Set[str]: + """Get all aliases for a Pydantic field.""" + aliases: set[str] = set() + + if field_info.alias: + aliases.add(field_info.alias) + + if field_info.validation_alias: + if isinstance(field_info.validation_alias, str): + aliases.add(field_info.validation_alias) + elif isinstance(field_info.validation_alias, AliasChoices): + for choice in field_info.validation_alias.choices: + if isinstance(choice, str): + aliases.add(choice) + return aliases + + +def get_model_aliases(model: BaseModel) -> Dict[str, str]: + """Get a mapping of aliases to field names for a Pydantic model. + + Includes field names, alias, and validation alias(es). + """ + fields = {} # type: Dict[str, str] + + for field_name, field_info in model.model_fields.items(): + aliases = get_field_aliases(field_info) + if model.model_config.get("populate_by_name"): + aliases.add(field_name) + # Assign aliases to field name in mapping + for alias in aliases: + fields[alias] = field_name + + return fields + + +class FrozenModel(BaseModel): + """Model for an immutable object.""" + + def __setattr__(self, name: str, value: Any): + """Raise an exception when trying to set an attribute.""" + raise AttributeError("Cannot set attribute on a frozen object") + + def __delattr__(self, name: str): + """Raise an exception when trying to delete an attribute.""" + raise AttributeError("Cannot delete attribute on a frozen object") + + class Config: + """Pydantic configuration. + + Set the class to frozen to make it immutable and thus hashable. + """ + + frozen = True + + +class FrozenModelWithTimestamps(FrozenModel): + """Model with created_at and updated_at fields.""" + + created_at: datetime + updated_at: datetime + + def output_timestamps(self, padding: int = 14) -> None: + """Output the created and updated timestamps to the console.""" + output_manager = OutputManager() + output_manager.add_line(f"{'Created:':<{padding}}{self.created_at:%c}") + output_manager.add_line(f"{'Updated:':<{padding}}{self.updated_at:%c}") + + +class APIMixin(Generic[BMT], ABC): + """A mixin for API-related methods.""" + + id: int # noqa: A003 + + def id_for_endpoint(self) -> Union[int, str]: + """Return the appropriate id for the object for its endpoint. + + :returns: The correct identifier for the endpoint. + """ + field = self.endpoint().external_id_field() + return getattr(self, field) + + @classmethod + def field_for_endpoint(cls) -> str: + """Return the appropriate field for the object for its endpoint. + + :param field: The field to return. + :returns: The correct field for the endpoint. + """ + return cls.endpoint().external_id_field() + + @classmethod + @abstractmethod + def endpoint(cls) -> Endpoint: + """Return the endpoint for the method.""" + raise NotImplementedError("You must define an endpoint.") + + @classmethod + def get(cls, _id: int) -> Optional[BMT]: + """Get an object. + + This function is at its base a wrapper around the get_by_id function, + but it can be overridden to provide more specific functionality. + + :param _id: The ID of the object. + :returns: The object if found, None otherwise. + """ + return cls.get_by_id(_id) + + @classmethod + def get_by_id(cls, _id: int) -> Optional[BMT]: + """Get an object by its ID. + + Note that for Hosts, the ID is the name of the host. + + :param _id: The ID of the object. + :returns: The object if found, None otherwise. + """ + endpoint = cls.endpoint() + + # Some endpoints do not use the ID field as the endpoint identifier, + # and in these cases we need to search for the ID... Lovely. + if endpoint.requires_search_for_id(): + data = get_item_by_key_value(cls.endpoint(), "id", str(_id)) + else: + data = get(cls.endpoint().with_id(_id), ok404=True) + if not data: + return None + data = data.json() + + if not data: + return None + + return cast(BMT, cls(**data)) + + @classmethod + def get_by_field(cls, field: str, value: str) -> Optional[BMT]: + """Get an object by a field. + + Note that some endpoints do not use the ID field for lookups. We do some + magic mapping via endpoint introspection to perform the following mapping for + classes and their endpoint "id" fields: + + - Hosts -> name + - Networks -> network + + This implies that doing a get_by_field("name", value) on Hosts will *not* + result in a search, but a direct lookup at ../endpoint/name which is what + the mreg server expects for Hosts (and similar for Network). + + :param field: The field to search by. + :param value: The value to search for. + + :returns: The object if found, None otherwise. + """ + endpoint = cls.endpoint() + + if endpoint.requires_search_for_id() and field == endpoint.external_id_field(): + data = get(endpoint.with_id(value), ok404=True) + if not data: + return None + data = data.json() + else: + data = get_item_by_key_value(cls.endpoint(), field, value, ok404=True) + + if not data: + return None + + return cast(BMT, cls(**data)) + + @classmethod + def get_list_by_field( + cls, field: str, value: Union[str, int], ordering: Optional[str] = None + ) -> List[BMT]: + """Get a list of objects by a field. + + :param field: The field to search by. + :param value: The value to search for. + :param ordering: The ordering to use when fetching the list. + + :returns: A list of objects if found, an empty list otherwise. + """ + params = {field: value} + if ordering: + params["ordering"] = ordering + + data = get_list(cls.endpoint(), params=params) + return [cast(BMT, cls(**item)) for item in data] + + def refetch(self) -> BMT: + """Fetch an updated version of the object. + + Note that the caller (self) of this method will remain unchanged and can contain + outdated information. The returned object will be the updated version. + + :returns: The fetched object. + """ + obj = self.__class__.get_by_id(self.id) + if not obj: + cli_warning(f"Could not refresh {self.__class__.__name__} with ID {self.id}.") + + return obj + + def patch(self, fields: Dict[str, Any]) -> BMT: + """Patch the object with the given values. + + :param kwargs: The values to patch. + :returns: The object refetched from the server. + """ + patch(self.endpoint().with_id(self.id), **fields) + + new_object = self.refetch() + + aliases = get_model_aliases(new_object) + for key, value in fields.items(): + field_name = aliases.get(key) + if field_name is None: + cli_warning(f"Unknown field {key} in patch request.") + try: + nval = getattr(new_object, field_name) + except AttributeError: + cli_warning(f"Could not get value for {field_name} in patched object.") + if str(nval) != str(value): + cli_warning( + # Should this reference `field_name` instead of `key`? + f"Patch failure! Tried to set {key} to {value}, but server returned {nval}." + ) + + return new_object + + def delete(self) -> bool: + """Delete the object. + + :returns: True if the object was deleted, False otherwise. + """ + response = delete(self.endpoint().with_id(self.id_for_endpoint())) + + if response and response.ok: + return True + + return False + + @classmethod + def create(cls, kwargs: Dict[str, Union[str, None]]) -> Union[None, BMT]: + """Create the object. + + :returns: The object if created, None otherwise. + """ + response = post(cls.endpoint(), params=None, **kwargs) + + if response and response.ok: + location = response.headers.get("Location") + if location: + obj = None + if cls.endpoint() is Endpoint.Hosts: + obj = cls.get_by_field("name", location.split("/")[-1]) + else: + obj = cls.get_by_id(int(location.split("/")[-1])) + + if obj: + return obj + + cli_warning(f"Could not fetch object from location {location}.") + + else: + cli_warning("No location header in response.") + + return None diff --git a/mreg_cli/api/fields.py b/mreg_cli/api/fields.py new file mode 100644 index 00000000..187b4929 --- /dev/null +++ b/mreg_cli/api/fields.py @@ -0,0 +1,67 @@ +"""Fields for models of the API.""" + +import ipaddress +import re + +from pydantic import validator + +from mreg_cli.api.abstracts import FrozenModel +from mreg_cli.types import IP_AddressT + +_mac_regex = re.compile(r"^([0-9A-Fa-f]{2}[.:-]){5}([0-9A-Fa-f]{2})$") + + +class MACAddressField(FrozenModel): + """Represents a MAC address.""" + + address: str + + @validator("address", pre=True) + def validate_and_format_mac(cls, v: str) -> str: + """Validate and normalize MAC address to 'aa:bb:cc:dd:ee:ff' format. + + :param v: The input MAC address string. + :raises ValueError: If the input does not match the expected MAC address pattern. + :returns: The normalized MAC address. + """ + # Validate input format + if not _mac_regex.match(v): + raise ValueError("Invalid MAC address format") + + # Normalize MAC address + v = re.sub(r"[.:-]", "", v).lower() + return ":".join(v[i : i + 2] for i in range(0, 12, 2)) + + def __str__(self) -> str: + """Return the MAC address as a string.""" + return self.address + + +class IPAddressField(FrozenModel): + """Represents an IP address, automatically determines if it's IPv4 or IPv6.""" + + address: IP_AddressT + + @validator("address", pre=True) + def parse_ip_address(cls, value: str) -> IP_AddressT: + """Parse and validate the IP address.""" + try: + return ipaddress.ip_address(value) + except ValueError as e: + raise ValueError(f"Invalid IP address '{value}'.") from e + + def is_ipv4(self) -> bool: + """Check if the IP address is IPv4.""" + return isinstance(self.address, ipaddress.IPv4Address) + + def is_ipv6(self) -> bool: + """Check if the IP address is IPv6.""" + return isinstance(self.address, ipaddress.IPv6Address) + + def __str__(self) -> str: + """Return the IP address as a string.""" + return str(self.address) + + def __hash__(self): + """Return a hash of the IP address.""" + return hash(self.address) diff --git a/mreg_cli/api/models.py b/mreg_cli/api/models.py index 0aeb4476..db3d217c 100644 --- a/mreg_cli/api/models.py +++ b/mreg_cli/api/models.py @@ -2,67 +2,22 @@ import ipaddress import re -from abc import ABC, abstractmethod from datetime import datetime -from typing import Any, Dict, Generic, List, Optional, Set, TypeVar, Union, cast +from typing import Any, Dict, List, Optional, Union from pydantic import AliasChoices, BaseModel, Field, root_validator, validator -from pydantic.fields import FieldInfo +from mreg_cli.api.abstracts import APIMixin, FrozenModel, FrozenModelWithTimestamps from mreg_cli.api.endpoints import Endpoint +from mreg_cli.api.fields import IPAddressField, MACAddressField from mreg_cli.config import MregCliConfig from mreg_cli.log import cli_warning from mreg_cli.outputmanager import OutputManager from mreg_cli.types import IP_AddressT -from mreg_cli.utilities.api import ( - delete, - get, - get_item_by_key_value, - get_list, - get_list_in, - patch, - post, -) +from mreg_cli.utilities.api import delete, get, get_item_by_key_value, get_list, get_list_in _mac_regex = re.compile(r"^([0-9A-Fa-f]{2}[.:-]){5}([0-9A-Fa-f]{2})$") -BMT = TypeVar("BMT", bound="BaseModel") - - -def get_field_aliases(field_info: FieldInfo) -> Set[str]: - """Get all aliases for a Pydantic field.""" - aliases: set[str] = set() - - if field_info.alias: - aliases.add(field_info.alias) - - if field_info.validation_alias: - if isinstance(field_info.validation_alias, str): - aliases.add(field_info.validation_alias) - elif isinstance(field_info.validation_alias, AliasChoices): - for choice in field_info.validation_alias.choices: - if isinstance(choice, str): - aliases.add(choice) - return aliases - - -def get_model_aliases(model: BaseModel) -> Dict[str, str]: - """Get a mapping of aliases to field names for a Pydantic model. - - Includes field names, alias, and validation alias(es). - """ - fields = {} # type: Dict[str, str] - - for field_name, field_info in model.model_fields.items(): - aliases = get_field_aliases(field_info) - if model.model_config.get("populate_by_name"): - aliases.add(field_name) - # Assign aliases to field name in mapping - for alias in aliases: - fields[alias] = field_name - - return fields - class HostT(BaseModel): """A type for a hostname.""" @@ -101,39 +56,6 @@ def __repr__(self) -> str: return self.hostname -class FrozenModel(BaseModel): - """Model for an immutable object.""" - - def __setattr__(self, name: str, value: Any): - """Raise an exception when trying to set an attribute.""" - raise AttributeError("Cannot set attribute on a frozen object") - - def __delattr__(self, name: str): - """Raise an exception when trying to delete an attribute.""" - raise AttributeError("Cannot delete attribute on a frozen object") - - class Config: - """Pydantic configuration. - - Set the class to frozen to make it immutable and thus hashable. - """ - - frozen = True - - -class FrozenModelWithTimestamps(FrozenModel): - """Model with created_at and updated_at fields.""" - - created_at: datetime - updated_at: datetime - - def output_timestamps(self, padding: int = 14) -> None: - """Output the created and updated timestamps to the console.""" - output_manager = OutputManager() - output_manager.add_line(f"{'Created:':<{padding}}{self.created_at:%c}") - output_manager.add_line(f"{'Updated:':<{padding}}{self.updated_at:%c}") - - class WithHost(BaseModel): """Model for an object that has a host element.""" @@ -178,207 +100,6 @@ def resolve_zone(self) -> Union["Zone", None]: return Zone(**data) -class APIMixin(Generic[BMT], ABC): - """A mixin for API-related methods.""" - - id: int # noqa: A003 - - def id_for_endpoint(self) -> Union[int, str]: - """Return the appropriate id for the object for its endpoint. - - :returns: The correct identifier for the endpoint. - """ - field = self.endpoint().external_id_field() - return getattr(self, field) - - @classmethod - def field_for_endpoint(cls) -> str: - """Return the appropriate field for the object for its endpoint. - - :param field: The field to return. - :returns: The correct field for the endpoint. - """ - return cls.endpoint().external_id_field() - - @classmethod - @abstractmethod - def endpoint(cls) -> Endpoint: - """Return the endpoint for the method.""" - raise NotImplementedError("You must define an endpoint.") - - @classmethod - def get(cls, _id: int) -> Optional[BMT]: - """Get an object. - - This function is at its base a wrapper around the get_by_id function, - but it can be overridden to provide more specific functionality. - - :param _id: The ID of the object. - :returns: The object if found, None otherwise. - """ - return cls.get_by_id(_id) - - @classmethod - def get_by_id(cls, _id: int) -> Optional[BMT]: - """Get an object by its ID. - - Note that for Hosts, the ID is the name of the host. - - :param _id: The ID of the object. - :returns: The object if found, None otherwise. - """ - endpoint = cls.endpoint() - - # Some endpoints do not use the ID field as the endpoint identifier, - # and in these cases we need to search for the ID... Lovely. - if endpoint.requires_search_for_id(): - data = get_item_by_key_value(cls.endpoint(), "id", str(_id)) - else: - data = get(cls.endpoint().with_id(_id), ok404=True) - if not data: - return None - data = data.json() - - if not data: - return None - - return cast(BMT, cls(**data)) - - @classmethod - def get_by_field(cls, field: str, value: str) -> Optional[BMT]: - """Get an object by a field. - - Note that some endpoints do not use the ID field for lookups. We do some - magic mapping via endpoint introspection to perform the following mapping for - classes and their endpoint "id" fields: - - - Hosts -> name - - Networks -> network - - This implies that doing a get_by_field("name", value) on Hosts will *not* - result in a search, but a direct lookup at ../endpoint/name which is what - the mreg server expects for Hosts (and similar for Network). - - :param field: The field to search by. - :param value: The value to search for. - - :returns: The object if found, None otherwise. - """ - endpoint = cls.endpoint() - - if endpoint.requires_search_for_id() and field == endpoint.external_id_field(): - data = get(endpoint.with_id(value), ok404=True) - if not data: - return None - data = data.json() - else: - data = get_item_by_key_value(cls.endpoint(), field, value, ok404=True) - - if not data: - return None - - return cast(BMT, cls(**data)) - - @classmethod - def get_list_by_field( - cls, field: str, value: Union[str, int], ordering: Optional[str] = None - ) -> List[BMT]: - """Get a list of objects by a field. - - :param field: The field to search by. - :param value: The value to search for. - :param ordering: The ordering to use when fetching the list. - - :returns: A list of objects if found, an empty list otherwise. - """ - params = {field: value} - if ordering: - params["ordering"] = ordering - - data = get_list(cls.endpoint(), params=params) - return [cast(BMT, cls(**item)) for item in data] - - def refetch(self) -> BMT: - """Fetch an updated version of the object. - - Note that the caller (self) of this method will remain unchanged and can contain - outdated information. The returned object will be the updated version. - - :returns: The fetched object. - """ - obj = self.__class__.get_by_id(self.id) - if not obj: - cli_warning(f"Could not refresh {self.__class__.__name__} with ID {self.id}.") - - return obj - - def patch(self, fields: Dict[str, Any]) -> BMT: - """Patch the object with the given values. - - :param kwargs: The values to patch. - :returns: The object refetched from the server. - """ - patch(self.endpoint().with_id(self.id), **fields) - - new_object = self.refetch() - - aliases = get_model_aliases(new_object) - for key, value in fields.items(): - field_name = aliases.get(key) - if field_name is None: - cli_warning(f"Unknown field {key} in patch request.") - try: - nval = getattr(new_object, field_name) - except AttributeError: - cli_warning(f"Could not get value for {field_name} in patched object.") - if str(nval) != str(value): - cli_warning( - # Should this reference `field_name` instead of `key`? - f"Patch failure! Tried to set {key} to {value}, but server returned {nval}." - ) - - return new_object - - def delete(self) -> bool: - """Delete the object. - - :returns: True if the object was deleted, False otherwise. - """ - response = delete(self.endpoint().with_id(self.id_for_endpoint())) - - if response and response.ok: - return True - - return False - - @classmethod - def create(cls, kwargs: Dict[str, Union[str, None]]) -> Union[None, BMT]: - """Create the object. - - :returns: The object if created, None otherwise. - """ - response = post(cls.endpoint(), params=None, **kwargs) - - if response and response.ok: - location = response.headers.get("Location") - if location: - obj = None - if cls.endpoint() is Endpoint.Hosts: - obj = cls.get_by_field("name", location.split("/")[-1]) - else: - obj = cls.get_by_id(int(location.split("/")[-1])) - - if obj: - return obj - - cli_warning(f"Could not fetch object from location {location}.") - - else: - cli_warning("No location header in response.") - - return None - - class NameServer(FrozenModelWithTimestamps): """Model for representing a nameserver within a DNS zone.""" @@ -554,62 +275,6 @@ def __hash__(self): return hash((self.id, self.network)) -class MACAddressField(FrozenModel): - """Represents a MAC address.""" - - address: str - - @validator("address", pre=True) - def validate_and_format_mac(cls, v: str) -> str: - """Validate and normalize MAC address to 'aa:bb:cc:dd:ee:ff' format. - - :param v: The input MAC address string. - :raises ValueError: If the input does not match the expected MAC address pattern. - :returns: The normalized MAC address. - """ - # Validate input format - if not _mac_regex.match(v): - raise ValueError("Invalid MAC address format") - - # Normalize MAC address - v = re.sub(r"[.:-]", "", v).lower() - return ":".join(v[i : i + 2] for i in range(0, 12, 2)) - - def __str__(self) -> str: - """Return the MAC address as a string.""" - return self.address - - -class IPAddressField(FrozenModel): - """Represents an IP address, automatically determines if it's IPv4 or IPv6.""" - - address: IP_AddressT - - @validator("address", pre=True) - def parse_ip_address(cls, value: str) -> IP_AddressT: - """Parse and validate the IP address.""" - try: - return ipaddress.ip_address(value) - except ValueError as e: - raise ValueError(f"Invalid IP address '{value}'.") from e - - def is_ipv4(self) -> bool: - """Check if the IP address is IPv4.""" - return isinstance(self.address, ipaddress.IPv4Address) - - def is_ipv6(self) -> bool: - """Check if the IP address is IPv6.""" - return isinstance(self.address, ipaddress.IPv6Address) - - def __str__(self) -> str: - """Return the IP address as a string.""" - return str(self.address) - - def __hash__(self): - """Return a hash of the IP address.""" - return hash(self.address) - - class IPAddress(FrozenModelWithTimestamps, WithHost, APIMixin["IPAddress"]): """Represents an IP address with associated details."""