Skip to content

Commit

Permalink
Add policy commands (#220)
Browse files Browse the repository at this point in the history
* Add roles and atoms endpoints to name field check

* Use `Mapping` annotation in `APIMixin.patch`

`dict` invariant and will not accept an argument of `dict[str, str]` when the type annotation is `dict[str, str | None]`. Yeah...

* Add type annotations for `Namespace` attributes
  • Loading branch information
pederhan authored May 13, 2024
1 parent 7123311 commit da3335a
Show file tree
Hide file tree
Showing 5 changed files with 475 additions and 332 deletions.
79 changes: 57 additions & 22 deletions mreg_cli/api/abstracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Self, cast
from typing import Any, Mapping, Self, cast

from pydantic import AliasChoices, BaseModel
from pydantic.fields import FieldInfo
Expand Down Expand Up @@ -54,18 +54,61 @@ def get_model_aliases(model: BaseModel) -> dict[str, str]:
Includes field names, alias, and validation alias(es).
"""
fields: dict[str, str] = {}

for field_name, field_info in model.model_fields.items():
aliases = get_field_aliases(field_info)
if model.model_config.get("populate_by_name"):
aliases.add(field_name)
# Assign aliases to field name in mapping
for alias in aliases:
fields[alias] = field_name

return fields


def validate_patched_model(model: BaseModel, fields: dict[str, Any]) -> None:
"""Validate that model fields were patched correctly."""
aliases = get_model_aliases(model)

validators = {
list: _validate_lists,
dict: _validate_dicts,
}
for key, value in fields.items():
field_name = key
if key in aliases:
field_name = aliases[key]

try:
nval = getattr(model, field_name)
except AttributeError as e:
raise PatchError(f"Could not get value for {field_name} in patched object.") from e

# Ensure patched value is the one we tried to set
validator = validators.get(type(nval), _validate_default)
if not validator(nval, value):
raise PatchError(
f"Patch failure! Tried to set {key} to {value}, but server returned {nval}."
)


def _validate_lists(new: list[Any], old: list[Any]) -> bool:
"""Validate that two lists are equal."""
if len(new) != len(old):
return False
return all(x in old for x in new)


def _validate_dicts(new: dict[str, Any], old: dict[str, Any]) -> bool:
"""Validate that two dictionaries are equal."""
if len(new) != len(old):
return False
return all(old.get(k) == v for k, v in new.items())


def _validate_default(new: Any, old: Any) -> bool:
"""Validate that two values are equal."""
return str(new) == str(old)


class FrozenModel(BaseModel):
"""Model for an immutable object."""

Expand Down Expand Up @@ -336,7 +379,7 @@ def refetch(self) -> Self:

return obj

def patch(self, fields: dict[str, Any]) -> Self:
def patch(self, fields: dict[str, Any], use_json: bool = False, validate: bool = True) -> Self:
"""Patch the object with the given values.
Notes
Expand All @@ -346,28 +389,20 @@ def patch(self, fields: dict[str, Any]) -> Self:
are). Odds are you want to pass an empty string instead.
:param fields: The values to patch.
:param validate: Whether to validate the patched object.
:returns: The object refetched from the server.
"""
patch(self.endpoint().with_id(self.id_for_endpoint()), **fields)
if use_json:
patch(self.endpoint().with_id(self.id_for_endpoint()), fields, use_json=True)
else:
patch(self.endpoint().with_id(self.id_for_endpoint()), use_json=False, **fields)

new_object = self.refetch()

# __init_subclass__ guarantees we inherit from BaseModel
# but we can't signal this to the type checker, so we cast here.
aliases = get_model_aliases(cast(BaseModel, new_object))
for key, value in fields.items():
field_name = key
if key in aliases:
field_name = aliases[key]
try:
nval = getattr(new_object, field_name)
except AttributeError as e:
raise PatchError(f"Could not get value for {field_name} in patched object.") from e
if value and str(nval) != str(value):
raise PatchError(
# Should this reference `field_name` instead of `key`?
f"Patch failure! Tried to set {key} to {value}, but server returned {nval}."
)
if validate:
# __init_subclass__ guarantees we inherit from BaseModel
# but we can't signal this to the type checker, so we cast here.
validate_patched_model(cast(BaseModel, new_object), fields)

return new_object

Expand All @@ -384,7 +419,7 @@ def delete(self) -> bool:
return False

@classmethod
def create(cls, params: dict[str, str | None]) -> Self | None:
def create(cls, params: Mapping[str, str | None]) -> Self | None:
"""Create the object.
Note that several endpoints do not support location headers for created objects,
Expand Down
6 changes: 6 additions & 0 deletions mreg_cli/api/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ class Endpoint(str, Enum):
NetworksReservedList = "/api/v1/networks/{}/reserved_list"

HostPolicyRoles = "/api/v1/hostpolicy/roles/"
HostPolicyRolesAddAtom = "/api/v1/hostpolicy/roles/{}/atoms/"
HostPolicyRolesRemoveAtom = "/api/v1/hostpolicy/roles/{}/atoms/{}"
HostPolicyRolesAddHost = "/api/v1/hostpolicy/roles/{}/hosts/"
HostPolicyRolesRemoveHost = "/api/v1/hostpolicy/roles/{}/hosts/{}"
HostPolicyAtoms = "/api/v1/hostpolicy/atoms/"

ForwardZones = f"{Zones}forward/"
Expand Down Expand Up @@ -96,6 +100,8 @@ def external_id_field(self) -> str:
Endpoint.Cnames,
Endpoint.ForwardZones,
Endpoint.ReverseZones,
Endpoint.HostPolicyRoles,
Endpoint.HostPolicyAtoms,
):
return "name"
if self == Endpoint.Networks:
Expand Down
Loading

0 comments on commit da3335a

Please sign in to comment.