Skip to content

Commit

Permalink
Split the models.py into smaller files.
Browse files Browse the repository at this point in the history
  • Loading branch information
terjekv committed Apr 23, 2024
1 parent eb37282 commit f0eb55c
Show file tree
Hide file tree
Showing 3 changed files with 356 additions and 339 deletions.
285 changes: 285 additions & 0 deletions mreg_cli/api/abstracts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,285 @@
"""Abstract models for the API."""


from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Dict, Generic, List, Optional, Set, TypeVar, Union, cast

from pydantic import BaseModel
from pydantic.fields import AliasChoices, FieldInfo

from mreg_cli.api.endpoints import Endpoint
from mreg_cli.log import cli_warning
from mreg_cli.outputmanager import OutputManager
from mreg_cli.utilities.api import delete, get, get_item_by_key_value, get_list, patch, post

BMT = TypeVar("BMT", bound="BaseModel")


def get_field_aliases(field_info: FieldInfo) -> Set[str]:
"""Get all aliases for a Pydantic field."""
aliases: set[str] = set()

if field_info.alias:
aliases.add(field_info.alias)

if field_info.validation_alias:
if isinstance(field_info.validation_alias, str):
aliases.add(field_info.validation_alias)
elif isinstance(field_info.validation_alias, AliasChoices):
for choice in field_info.validation_alias.choices:
if isinstance(choice, str):
aliases.add(choice)
return aliases


def get_model_aliases(model: BaseModel) -> Dict[str, str]:
"""Get a mapping of aliases to field names for a Pydantic model.
Includes field names, alias, and validation alias(es).
"""
fields = {} # type: Dict[str, str]

for field_name, field_info in model.model_fields.items():
aliases = get_field_aliases(field_info)
if model.model_config.get("populate_by_name"):
aliases.add(field_name)
# Assign aliases to field name in mapping
for alias in aliases:
fields[alias] = field_name

return fields


class FrozenModel(BaseModel):
"""Model for an immutable object."""

def __setattr__(self, name: str, value: Any):
"""Raise an exception when trying to set an attribute."""
raise AttributeError("Cannot set attribute on a frozen object")

def __delattr__(self, name: str):
"""Raise an exception when trying to delete an attribute."""
raise AttributeError("Cannot delete attribute on a frozen object")

class Config:
"""Pydantic configuration.
Set the class to frozen to make it immutable and thus hashable.
"""

frozen = True


class FrozenModelWithTimestamps(FrozenModel):
"""Model with created_at and updated_at fields."""

created_at: datetime
updated_at: datetime

def output_timestamps(self, padding: int = 14) -> None:
"""Output the created and updated timestamps to the console."""
output_manager = OutputManager()
output_manager.add_line(f"{'Created:':<{padding}}{self.created_at:%c}")
output_manager.add_line(f"{'Updated:':<{padding}}{self.updated_at:%c}")


class APIMixin(Generic[BMT], ABC):
"""A mixin for API-related methods."""

id: int # noqa: A003

def id_for_endpoint(self) -> Union[int, str]:
"""Return the appropriate id for the object for its endpoint.
:returns: The correct identifier for the endpoint.
"""
field = self.endpoint().external_id_field()
return getattr(self, field)

@classmethod
def field_for_endpoint(cls) -> str:
"""Return the appropriate field for the object for its endpoint.
:param field: The field to return.
:returns: The correct field for the endpoint.
"""
return cls.endpoint().external_id_field()

@classmethod
@abstractmethod
def endpoint(cls) -> Endpoint:
"""Return the endpoint for the method."""
raise NotImplementedError("You must define an endpoint.")

@classmethod
def get(cls, _id: int) -> Optional[BMT]:
"""Get an object.
This function is at its base a wrapper around the get_by_id function,
but it can be overridden to provide more specific functionality.
:param _id: The ID of the object.
:returns: The object if found, None otherwise.
"""
return cls.get_by_id(_id)

@classmethod
def get_by_id(cls, _id: int) -> Optional[BMT]:
"""Get an object by its ID.
Note that for Hosts, the ID is the name of the host.
:param _id: The ID of the object.
:returns: The object if found, None otherwise.
"""
endpoint = cls.endpoint()

# Some endpoints do not use the ID field as the endpoint identifier,
# and in these cases we need to search for the ID... Lovely.
if endpoint.requires_search_for_id():
data = get_item_by_key_value(cls.endpoint(), "id", str(_id))
else:
data = get(cls.endpoint().with_id(_id), ok404=True)
if not data:
return None
data = data.json()

if not data:
return None

return cast(BMT, cls(**data))

@classmethod
def get_by_field(cls, field: str, value: str) -> Optional[BMT]:
"""Get an object by a field.
Note that some endpoints do not use the ID field for lookups. We do some
magic mapping via endpoint introspection to perform the following mapping for
classes and their endpoint "id" fields:
- Hosts -> name
- Networks -> network
This implies that doing a get_by_field("name", value) on Hosts will *not*
result in a search, but a direct lookup at ../endpoint/name which is what
the mreg server expects for Hosts (and similar for Network).
:param field: The field to search by.
:param value: The value to search for.
:returns: The object if found, None otherwise.
"""
endpoint = cls.endpoint()

if endpoint.requires_search_for_id() and field == endpoint.external_id_field():
data = get(endpoint.with_id(value), ok404=True)
if not data:
return None
data = data.json()
else:
data = get_item_by_key_value(cls.endpoint(), field, value, ok404=True)

if not data:
return None

return cast(BMT, cls(**data))

@classmethod
def get_list_by_field(
cls, field: str, value: Union[str, int], ordering: Optional[str] = None
) -> List[BMT]:
"""Get a list of objects by a field.
:param field: The field to search by.
:param value: The value to search for.
:param ordering: The ordering to use when fetching the list.
:returns: A list of objects if found, an empty list otherwise.
"""
params = {field: value}
if ordering:
params["ordering"] = ordering

data = get_list(cls.endpoint(), params=params)
return [cast(BMT, cls(**item)) for item in data]

def refetch(self) -> BMT:
"""Fetch an updated version of the object.
Note that the caller (self) of this method will remain unchanged and can contain
outdated information. The returned object will be the updated version.
:returns: The fetched object.
"""
obj = self.__class__.get_by_id(self.id)
if not obj:
cli_warning(f"Could not refresh {self.__class__.__name__} with ID {self.id}.")

return obj

def patch(self, fields: Dict[str, Any]) -> BMT:
"""Patch the object with the given values.
:param kwargs: The values to patch.
:returns: The object refetched from the server.
"""
patch(self.endpoint().with_id(self.id), **fields)

new_object = self.refetch()

aliases = get_model_aliases(new_object)
for key, value in fields.items():
field_name = aliases.get(key)
if field_name is None:
cli_warning(f"Unknown field {key} in patch request.")
try:
nval = getattr(new_object, field_name)
except AttributeError:
cli_warning(f"Could not get value for {field_name} in patched object.")
if str(nval) != str(value):
cli_warning(
# Should this reference `field_name` instead of `key`?
f"Patch failure! Tried to set {key} to {value}, but server returned {nval}."
)

return new_object

def delete(self) -> bool:
"""Delete the object.
:returns: True if the object was deleted, False otherwise.
"""
response = delete(self.endpoint().with_id(self.id_for_endpoint()))

if response and response.ok:
return True

return False

@classmethod
def create(cls, kwargs: Dict[str, Union[str, None]]) -> Union[None, BMT]:
"""Create the object.
:returns: The object if created, None otherwise.
"""
response = post(cls.endpoint(), params=None, **kwargs)

if response and response.ok:
location = response.headers.get("Location")
if location:
obj = None
if cls.endpoint() is Endpoint.Hosts:
obj = cls.get_by_field("name", location.split("/")[-1])
else:
obj = cls.get_by_id(int(location.split("/")[-1]))

if obj:
return obj

cli_warning(f"Could not fetch object from location {location}.")

else:
cli_warning("No location header in response.")

return None
67 changes: 67 additions & 0 deletions mreg_cli/api/fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""Fields for models of the API."""

import ipaddress
import re

from pydantic import validator

from mreg_cli.api.abstracts import FrozenModel
from mreg_cli.types import IP_AddressT

_mac_regex = re.compile(r"^([0-9A-Fa-f]{2}[.:-]){5}([0-9A-Fa-f]{2})$")


class MACAddressField(FrozenModel):
"""Represents a MAC address."""

address: str

@validator("address", pre=True)
def validate_and_format_mac(cls, v: str) -> str:
"""Validate and normalize MAC address to 'aa:bb:cc:dd:ee:ff' format.
:param v: The input MAC address string.
:raises ValueError: If the input does not match the expected MAC address pattern.
:returns: The normalized MAC address.
"""
# Validate input format
if not _mac_regex.match(v):
raise ValueError("Invalid MAC address format")

# Normalize MAC address
v = re.sub(r"[.:-]", "", v).lower()
return ":".join(v[i : i + 2] for i in range(0, 12, 2))

def __str__(self) -> str:
"""Return the MAC address as a string."""
return self.address


class IPAddressField(FrozenModel):
"""Represents an IP address, automatically determines if it's IPv4 or IPv6."""

address: IP_AddressT

@validator("address", pre=True)
def parse_ip_address(cls, value: str) -> IP_AddressT:
"""Parse and validate the IP address."""
try:
return ipaddress.ip_address(value)
except ValueError as e:
raise ValueError(f"Invalid IP address '{value}'.") from e

def is_ipv4(self) -> bool:
"""Check if the IP address is IPv4."""
return isinstance(self.address, ipaddress.IPv4Address)

def is_ipv6(self) -> bool:
"""Check if the IP address is IPv6."""
return isinstance(self.address, ipaddress.IPv6Address)

def __str__(self) -> str:
"""Return the IP address as a string."""
return str(self.address)

def __hash__(self):
"""Return a hash of the IP address."""
return hash(self.address)
Loading

0 comments on commit f0eb55c

Please sign in to comment.