Skip to content

Commit

Permalink
Improve type safety of validators (#286)
Browse files Browse the repository at this point in the history
* Improve type safety of validators

* Remove unused import
  • Loading branch information
pederhan authored Jul 26, 2024
1 parent 2ac32b2 commit 815b969
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 35 deletions.
15 changes: 11 additions & 4 deletions mreg_cli/api/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class MACAddressField(FrozenModel):

address: str

@field_validator("address", mode="before")
@field_validator("address", mode="after")
@classmethod
def validate_and_format_mac(cls, v: str) -> str:
"""Validate and normalize MAC address to 'aa:bb:cc:dd:ee:ff' format.
Expand Down Expand Up @@ -49,8 +49,13 @@ class IPAddressField(FrozenModel):

@classmethod
def from_string(cls, address: str) -> IPAddressField:
"""Create an IPAddressField from a string."""
return cls(address=address) # type: ignore # validator handles this
"""Create an IPAddressField from a string.
Shortcut for creating an IPAddressField from a string,
without having to convince the type checker that we can
pass in a string to the address field each time.
"""
return cls(address=address) # pyright: ignore[reportArgumentType] # validator handles this

@field_validator("address", mode="before")
@classmethod
Expand All @@ -69,11 +74,13 @@ def is_ipv6(self) -> bool:
"""Check if the IP address is IPv6."""
return isinstance(self.address, ipaddress.IPv6Address)

@staticmethod
def is_valid(value: str) -> bool:
"""Check if the value is a valid IP address."""
try:
ipaddress.ip_address(value)
return True
except:
except ValueError:
return False

def __str__(self) -> str:
Expand Down
14 changes: 8 additions & 6 deletions mreg_cli/api/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,14 @@ class HistoryItem(BaseModel):
data: dict[str, Any]

@field_validator("data", mode="before")
def parse_json_data(cls, v: Any) -> dict[str, Any]:
"""Ensure that data is always treated as a dictionary."""
def parse_json_data(cls, v: Any) -> Any:
"""Ensure that non-dict values are treated as JSON."""
if isinstance(v, dict):
return v # type: ignore
else:
return v # pyright: ignore[reportUnknownVariableType]
try:
return json.loads(v)
except json.JSONDecodeError as e:
raise ValueError("Failed to parse history data as JSON") from e

def clean_timestamp(self) -> str:
"""Clean up the timestamp for output."""
Expand All @@ -89,8 +91,8 @@ def msg(self, basename: str) -> str:
rel = self.data["relation"][:-1]
cls = str(self.resource)
if "." in cls:
cls = cls[cls.rindex(".")+1:]
cls = cls.replace("HostPolicy_","")
cls = cls[cls.rindex(".") + 1 :]
cls = cls.replace("HostPolicy_", "")
cls = cls.lower()
msg = f"{rel} {self.data['name']} {direction} {cls} {self.name}"
elif action == "create":
Expand Down
38 changes: 14 additions & 24 deletions mreg_cli/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
Field,
computed_field,
field_validator,
model_validator,
)
from typing_extensions import Unpack

Expand Down Expand Up @@ -397,7 +396,7 @@ class Permission(FrozenModelWithTimestamps, APIMixin):

@field_validator("range", mode="before")
@classmethod
def validate_ip_or_network(cls, value: str) -> IP_NetworkT:
def validate_ip_or_network(cls, value: Any) -> IP_NetworkT:
"""Validate and convert the input to a network."""
try:
return ipaddress.ip_network(value)
Expand Down Expand Up @@ -1862,21 +1861,19 @@ class IPAddress(FrozenModelWithTimestamps, WithHost, APIMixin):

@field_validator("macaddress", mode="before")
@classmethod
def create_valid_macadress_or_none(cls, v: str) -> MACAddressField | None:
def create_valid_macadress_or_none(cls, v: Any) -> MACAddressField | None:
"""Create macaddress or convert empty strings to None."""
if v:
return MACAddressField(address=v)

return None

@model_validator(mode="before")
@field_validator("ipaddress", mode="before")
@classmethod
def convert_ip_address(cls, values: Any):
"""Convert ipaddress string to IPAddressField if necessary."""
ip_address = values.get("ipaddress")
if isinstance(ip_address, str):
values["ipaddress"] = {"address": ip_address}
return values
def create_valid_ipaddress(cls, v: Any) -> IPAddressField:
"""Create macaddress or convert empty strings to None."""
if isinstance(v, str):
return IPAddressField.from_string(v)
return v # let Pydantic handle it

@classmethod
def get_by_ip(cls, ip: IP_AddressT) -> list[Self]:
Expand Down Expand Up @@ -2039,7 +2036,7 @@ class CNAME(FrozenModelWithTimestamps, WithHost, WithZone, WithTTL, APIMixin):

@field_validator("name", mode="before")
@classmethod
def validate_name(cls, value: str) -> HostT:
def validate_name(cls, value: Any) -> HostT:
"""Validate the hostname."""
return HostT(hostname=value)

Expand Down Expand Up @@ -2502,17 +2499,16 @@ class Host(FrozenModelWithTimestamps, WithTTL, WithHistory, APIMixin):

@field_validator("name", mode="before")
@classmethod
def validate_name(cls, value: str) -> HostT:
def validate_name(cls, value: Any) -> HostT:
"""Validate the hostname."""
return HostT(hostname=value)

@field_validator("bacnetid", mode="before")
@classmethod
def convert_bacnetid(cls, v: dict[str, int] | None) -> int | None:
"""Convert json id field to int or None."""
if v and "id" in v:
return v["id"]

def convert_bacnetid(cls, v: Any) -> Any:
"""Use nested ID value in bacnetid value."""
if isinstance(v, dict):
return v.get("id") # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType]
return None

@classmethod
Expand Down Expand Up @@ -3133,12 +3129,6 @@ def get_by_ip(cls, ip: IP_AddressT) -> HostList:
"""
return cls.get(params={"ipaddresses__ipaddress": str(ip), "ordering": "name"})

@field_validator("results", mode="before")
@classmethod
def check_results(cls, v: list[dict[str, str]]) -> list[dict[str, str]]:
"""Check that the results are valid."""
return v

def __len__(self):
"""Return the number of results."""
return len(self.results)
Expand Down
3 changes: 2 additions & 1 deletion mreg_cli/utilities/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,8 +428,9 @@ class PaginatedResponse(BaseModel):
def _none_count_is_0(cls, v: Any) -> Any:
"""Ensure `count` is never `None`."""
# Django count doesn't seem to be guaranteed to be an integer.
# Ensures here that None is treated as 0.
# https://github.com/django/django/blob/bcbc4b9b8a4a47c8e045b060a9860a5c038192de/django/core/paginator.py#L105-L111
# Theoretically any callable can be passed to the "count" attribute of the paginator.
# Ensures here that None (and any falsey value) is treated as 0.
return v or 0

@classmethod
Expand Down

0 comments on commit 815b969

Please sign in to comment.