Skip to content

Commit

Permalink
Upgrade models.py to Pydantic V2 semantics
Browse files Browse the repository at this point in the history
  • Loading branch information
pederhan committed Apr 23, 2024
1 parent f0eb55c commit dfd0193
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 31 deletions.
8 changes: 5 additions & 3 deletions mreg_cli/api/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import ipaddress
import re

from pydantic import validator
from pydantic import field_validator

from mreg_cli.api.abstracts import FrozenModel
from mreg_cli.types import IP_AddressT
Expand All @@ -16,7 +16,8 @@ class MACAddressField(FrozenModel):

address: str

@validator("address", pre=True)
@field_validator("address", mode="before")
@classmethod
def validate_and_format_mac(cls, v: str) -> str:
"""Validate and normalize MAC address to 'aa:bb:cc:dd:ee:ff' format.
Expand All @@ -42,7 +43,8 @@ class IPAddressField(FrozenModel):

address: IP_AddressT

@validator("address", pre=True)
@field_validator("address", mode="before")
@classmethod
def parse_ip_address(cls, value: str) -> IP_AddressT:
"""Parse and validate the IP address."""
try:
Expand Down
77 changes: 49 additions & 28 deletions mreg_cli/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,17 @@
from datetime import datetime
from typing import Any, Dict, List, Optional, Union

from pydantic import AliasChoices, BaseModel, Field, root_validator, validator
from pydantic import (
AliasChoices,
BaseModel,
BeforeValidator,
ConfigDict,
Field,
field_validator,
model_validator,
)
from pydantic.fields import FieldInfo
from typing_extensions import Annotated

from mreg_cli.api.abstracts import APIMixin, FrozenModel, FrozenModelWithTimestamps
from mreg_cli.api.endpoints import Endpoint
Expand All @@ -24,7 +34,8 @@ class HostT(BaseModel):

hostname: str

@validator("hostname")
@field_validator("hostname")
@classmethod
def validate_hostname(cls, value: str) -> str:
"""Validate the hostname."""
value = value.lower()
Expand Down Expand Up @@ -166,6 +177,18 @@ def is_delegated(self) -> bool:
return True


def _extract_name(value: Dict[str, Any]) -> str:
"""Extract the name from the dictionary.
:param v: Dictionary containing the name.
:returns: Extracted name as a string.
"""
return value["name"]


NameList = List[Annotated[str, BeforeValidator(_extract_name)]]


class Role(FrozenModelWithTimestamps, APIMixin["Role"]):
"""Model for a role.
Expand All @@ -175,13 +198,14 @@ class Role(FrozenModelWithTimestamps, APIMixin["Role"]):

id: int # noqa: A003
created_at: datetime = Field(..., validation_alias=AliasChoices("create_date", "created_at"))
hosts: List[str]
atoms: List[str]
hosts: NameList
atoms: NameList
description: str
name: str
labels: List[int]

@validator("created_at", pre=True)
@field_validator("created_at", mode="before")
@classmethod
def validate_created_at(cls, value: str) -> datetime:
"""Validate and convert the created_at field to datetime.
Expand All @@ -190,15 +214,6 @@ def validate_created_at(cls, value: str) -> datetime:
"""
return datetime.fromisoformat(value)

@validator("hosts", "atoms", pre=True, each_item=True)
def extract_name(cls, v: Dict[str, str]) -> str:
"""Extract the name from the dictionary.
:param v: Dictionary containing the name.
:returns: Extracted name as a string.
"""
return v["name"]

@classmethod
def endpoint(cls) -> Endpoint:
"""Return the endpoint for the class."""
Expand Down Expand Up @@ -233,7 +248,7 @@ class Network(FrozenModelWithTimestamps, APIMixin["Network"]):
excluded_ranges: List[str]
network: str # for now
description: str
vlan: Optional[int]
vlan: Optional[int] = None
dns_delegated: bool
category: str
location: str
Expand Down Expand Up @@ -282,15 +297,17 @@ class IPAddress(FrozenModelWithTimestamps, WithHost, APIMixin["IPAddress"]):
macaddress: Optional[MACAddressField] = None
ipaddress: IPAddressField

@validator("macaddress", pre=True, allow_reuse=True)
def create_valid_macadress_or_none(cls, v: str):
@field_validator("macaddress", mode="before")
@classmethod
def create_valid_macadress_or_none(cls, v: str) -> MACAddressField | None:
"""Create macaddress or convert empty strings to None."""
if v:
return MACAddressField(address=v)

return None

@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def convert_ip_address(cls, values: Any):
"""Convert ipaddress string to IPAddressField if necessary."""
ip_address = values.get("ipaddress")
Expand Down Expand Up @@ -418,7 +435,8 @@ class CNAME(FrozenModelWithTimestamps, WithHost, WithZone, APIMixin["CNAME"]):
name: HostT
ttl: Optional[int] = None

@validator("name", pre=True)
@field_validator("name", mode="before")
@classmethod
def validate_name(cls, value: str) -> HostT:
"""Validate the hostname."""
return HostT(hostname=value)
Expand Down Expand Up @@ -511,9 +529,9 @@ class NAPTR(FrozenModelWithTimestamps, WithHost, APIMixin["NAPTR"]):
id: int # noqa: A003
preference: int
order: int
flag: Optional[str]
service: Optional[str]
regex: Optional[str]
flag: Optional[str] = None
service: Optional[str] = None
regex: Optional[str] = None
replacement: str

def output(self, padding: int = 14) -> None:
Expand Down Expand Up @@ -572,7 +590,7 @@ class Srv(FrozenModelWithTimestamps, WithHost, WithZone, APIMixin["Srv"]):
priority: int
weight: int
port: int
ttl: Optional[int]
ttl: Optional[int] = None

@classmethod
def endpoint(cls) -> Endpoint:
Expand Down Expand Up @@ -737,13 +755,15 @@ class Host(FrozenModelWithTimestamps, APIMixin["Host"]):
# Note, we do not use WithZone here as this is optional and we resolve it differently.
zone: Optional[int] = None

@validator("name", pre=True)
@field_validator("name", mode="before")
@classmethod
def validate_name(cls, value: str) -> HostT:
"""Validate the hostname."""
return HostT(hostname=value)

@validator("comment", pre=True, allow_reuse=True)
def empty_string_to_none(cls, v: str):
@field_validator("comment", mode="before")
@classmethod
def empty_string_to_none(cls, v: str) -> str | None:
"""Convert empty strings to None."""
return v or None

Expand Down Expand Up @@ -1107,8 +1127,9 @@ def get(cls, params: Optional[Dict[str, Any]] = None) -> "HostList":
data = get_list(cls.endpoint(), params=params)
return cls(results=[Host(**host) for host in data])

@validator("results", pre=True)
def check_results(cls, v: List[Dict[str, str]]):
@field_validator("results", mode="before")
@classmethod
def check_results(cls, v: List[Dict[str, str]]) -> List[Dict[str, str]]:
"""Check that the results are valid."""
return v

Expand Down

0 comments on commit dfd0193

Please sign in to comment.