diff --git a/order/__init__.py b/order/__init__.py index a63325a..2ae98b0 100644 --- a/order/__init__.py +++ b/order/__init__.py @@ -10,7 +10,7 @@ "Settings", "Lazy", "Model", "AdapterModel", "Adapter", "Materialized", "DataProvider", - "UniqueObject", "UniqueObjectIndex", + "UniqueObject", "LazyUniqueObject", "UniqueObjectIndex", "DuplicateObjectException", "DuplicateNameException", "DuplicateIdException", ] @@ -23,14 +23,18 @@ # provisioning imports from order.settings import Settings -from order.models.base import Lazy, Model +from order.types import Lazy +from order.models.base import Model from order.models.unique import ( - UniqueObject, UniqueObjectIndex, DuplicateObjectException, DuplicateNameException, - DuplicateIdException, + UniqueObject, LazyUniqueObject, UniqueObjectIndex, DuplicateObjectException, + DuplicateNameException, DuplicateIdException, ) +from order.models.campaign import Campaign +from order.models.dataset import Dataset from order.adapters.base import AdapterModel, Adapter, Materialized, DataProvider # import adapters to trigger their registration import order.adapters.order -import order.adapters.dbs -import order.adapters.xsdb +import order.adapters.das +# import order.adapters.dbs +# import order.adapters.xsdb diff --git a/order/adapters/base.py b/order/adapters/base.py index 780e2d0..e52cc2e 100644 --- a/order/adapters/base.py +++ b/order/adapters/base.py @@ -16,20 +16,19 @@ import shutil from contextlib import contextmanager from abc import ABCMeta, abstractmethod, abstractproperty -from typing import Any, Sequence, Dict -from types import GeneratorType from pydantic import BaseModel +from order.types import Any, Sequence, Dict, GeneratorType, NonEmptyStrictStr, StrictStr, Field from order.settings import Settings from order.util import create_hash class AdapterModel(BaseModel): - adapter: str - arguments: Dict[str, Any] - key: str + adapter: NonEmptyStrictStr + key: StrictStr + arguments: Dict[NonEmptyStrictStr, Any] = Field(default_factory=lambda: {}) @property def name(self) -> str: @@ -194,17 +193,24 @@ def __init__( shutil.rmtree(self.cache_directory) @contextmanager - def materialize(self, adapter_model: AdapterModel | dict[str, Any]) -> GeneratorType: + def materialize( + self, + adapter_model: AdapterModel | dict[str, Any], + adapter_kwargs: dict[str, Any] | None = None, + ) -> GeneratorType: if not isinstance(adapter_model, AdapterModel): adapter_model = AdapterModel(**adapter_model) # get the adapter class and instantiate it adapter = AdapterMeta.get_cls(adapter_model.name)() + # merge kwargs + adapter_kwargs = {**adapter_model.arguments, **(adapter_kwargs or {})} + # determine the basename of the cache file (if existing) h = ( os.path.realpath(self.data_location), - adapter.get_cache_key(**adapter_model.arguments), + adapter.get_cache_key(**adapter_kwargs), ) cache_name = f"{create_hash(h)}.json" @@ -220,7 +226,7 @@ def materialize(self, adapter_model: AdapterModel | dict[str, Any]) -> Generator # invoke the adapter args = (self.data_location,) if adapter.needs_data_location else () - materialized = adapter.retrieve_data(*args, **adapter_model.arguments) + materialized = adapter.retrieve_data(*args, **adapter_kwargs) # complain when the return value is not a materialized container if not isinstance(materialized, Materialized): @@ -232,10 +238,8 @@ def materialize(self, adapter_model: AdapterModel | dict[str, Any]) -> Generator # yield the materialized data and cache it if the receiving context did not raise try: yield materialized - except 
Exception as e:
-            if isinstance(e, self.SkipCaching):
-                return
-            raise e
+        except self.SkipCaching:
+            return
 
         # cache it
         if writable_path:
diff --git a/order/adapters/das.py b/order/adapters/das.py
new file mode 100644
index 0000000..8e2d046
--- /dev/null
+++ b/order/adapters/das.py
@@ -0,0 +1,29 @@
+# coding: utf-8
+
+from __future__ import annotations
+
+
+__all__ = ["DASDatasetAdapter"]
+
+
+from order.adapters.base import Adapter, Materialized
+
+
+class DASDatasetAdapter(Adapter):
+
+    name = "das_dataset"
+
+    def retrieve_data(self, *, keys: list[str]) -> Materialized:
+        if keys[0].startswith("/SCALE"):
+            return Materialized(n_events=1, n_files=1)
+        return Materialized(n_events=5_000_000, n_files=12)
+
+
+class DASLFNsAdapter(Adapter):
+
+    name = "das_lfns"
+
+    def retrieve_data(self, *, keys: list[str]) -> Materialized:
+        if keys[0].startswith("/SCALE"):
+            return Materialized(lfns=["/SCALE/b/NANOAODSIM"])
+        return Materialized(lfns=["/a/b/NANOAODSIM"])
diff --git a/order/adapters/order.py b/order/adapters/order.py
index 74f3fe0..c14dc4a 100644
--- a/order/adapters/order.py
+++ b/order/adapters/order.py
@@ -42,13 +42,73 @@ def retrieve_data(
         dataset_dir = os.path.join(self.remove_scheme(data_location), "datasets", campaign_name)
 
         # read yaml files in the datasets directory
-        datasets = {}
+        datasets = []
         for path in glob.glob(os.path.join(dataset_dir, "*.yaml")):
             with open(path, "r") as f:
                 # allow multiple documents per file
-                for data in yaml.load_all(f, Loader=yaml.SafeLoader):
-                    if "name" not in data:
-                        raise KeyError(f"no field 'name' defined in dataset yaml file {path}")
-                    datasets[data["name"]] = data
+                stream = yaml.load_all(f, Loader=yaml.SafeLoader)
+                for i, entry in enumerate(stream):
+                    if "name" not in entry:
+                        raise AttributeError(
+                            f"no field 'name' defined in entry {i} of dataset yaml file {path}",
+                        )
+                    if "id" not in entry:
+                        raise AttributeError(
+                            f"no field 'id' defined in entry {i} of dataset yaml file {path}",
+                        )
+                    datasets.append(
+                        self.create_lazy_dataset_dict(campaign_name, entry["name"], entry["id"]),
+                    )
 
         return Materialized(datasets=datasets)
+
+    @classmethod
+    def create_lazy_dataset_dict(cls, campaign_name: str, name: str, id: int) -> dict:
+        return {
+            "name": name,
+            "id": id,
+            "class_name": "Dataset",
+            "adapter": {
+                "adapter": "order_dataset",
+                "arguments": {
+                    "campaign_name": campaign_name,
+                    "dataset_name": name,
+                },
+                "key": "dataset",
+            },
+        }
+
+
+class DatasetAdapter(OrderAdapter):
+
+    name = "order_dataset"
+
+    def retrieve_data(
+        self,
+        data_location: str,
+        *,
+        campaign_name: str,
+        dataset_name: str,
+    ) -> Materialized:
+        # only supporting local evaluation for now
+        if not self.location_is_local(data_location):
+            raise NotImplementedError(f"non-local location {data_location} not handled by {self}")
+
+        # build the yaml file path
+        path = os.path.join(
+            self.remove_scheme(data_location),
+            "datasets",
+            campaign_name,
+            f"{dataset_name}.yaml",
+        )
+        if not os.path.exists(path):
+            raise Exception(f"dataset file {path} does not exist")
+
+        # open the file and look for the dataset
+        with open(path, "r") as f:
+            stream = yaml.load_all(f, Loader=yaml.SafeLoader)
+            for entry in stream:
+                if entry.get("name") == dataset_name:
+                    return Materialized(dataset=entry)
+
+        raise Exception(f"no dataset entry with name '{dataset_name}' found in {path}")
diff --git a/order/models/base.py b/order/models/base.py
index 23591f4..e66fb44 100644
--- a/order/models/base.py
+++ b/order/models/base.py
@@ -7,76 +7,39 @@
 from __future__ import annotations
 
 
-__all__
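# --- illustrative example, not part of the patch ------------------------------
# Rough sketch of how a minimal adapter plugs into the reworked DataProvider API
# above: "arguments" on AdapterModel is now optional, and extra kwargs passed to
# materialize() are merged on top of it before the cache key is computed. The
# adapter name "my_dataset" and all values below are made up, and a configured
# data location via order's settings is assumed.
from __future__ import annotations

from order.adapters.base import Adapter, AdapterModel, DataProvider, Materialized


class MyDatasetAdapter(Adapter):

    name = "my_dataset"

    def retrieve_data(self, *, keys: list[str]) -> Materialized:
        # a real adapter would query an external service here
        return Materialized(n_events=100, n_files=2)


adapter_model = AdapterModel(adapter="my_dataset", key="dataset")  # arguments default to {}

# the second argument overrides entries in adapter_model.arguments
with DataProvider.instance().materialize(adapter_model, {"keys": ["/a/b/NANOAODSIM"]}) as mat:
    print(mat)
# -------------------------------------------------------------------------------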
= ["Lazy", "Model"] +__all__ = ["Model"] -import re -from typing import Union, Any -from types import GeneratorType - -from typing_extensions import Annotated, _AnnotatedAlias as AnnotatedType -from pydantic import BaseModel, Field, Strict, ConfigDict -from pydantic.fields import FieldInfo +from pydantic import BaseModel, ConfigDict +from order.types import Any, GeneratorType, Field, FieldInfo, Lazy from order.adapters.base import AdapterModel, DataProvider from order.util import no_value -class Lazy(object): - - @classmethod - def __class_getitem__(cls, types): - if not isinstance(types, tuple): - types = (types,) - return Union[tuple(map(cls.make_strict, types)) + (AdapterModel,)] - - @classmethod - def make_strict(cls, type_: type) -> AnnotatedType: - # some types cannot be strict - if not cls.can_make_strict(type_): - return type_ - - # when not decorated with strict meta data, just create a new strict tyoe - if ( - not isinstance(type_, AnnotatedType) or - not any(isinstance(m, Strict) for m in getattr(type_, "__metadata__", [])) - ): - return Annotated[type_, Strict()] - - # when already strict, return as is - metadata = type_.__metadata__ - if all(m.strict for m in metadata if isinstance(m, Strict)): - return type_ - - # at this point, strict metadata exists but it is actually disabled, - # so replace it in metadata and return a new annotated type - metadata = [ - (Strict() if isinstance(m, Strict) else m) - for m in metadata - ] - return Annotated[(*type_.__args__, *metadata)] - - @classmethod - def can_make_strict(cls, type_: type) -> bool: - if type_.__dict__.get("_name") in ("Dict", "List"): - return False - - return True - - class ModelMeta(type(BaseModel)): def __new__(meta_cls, class_name: str, bases: tuple, class_dict: dict[str, Any]) -> "ModelMeta": # convert "Lazy" annotations to proper fields and add access properties lazy_attrs = [] for attr, type_str in list(class_dict.get("__annotations__", {}).items()): - type_names = meta_cls.parse_lazy_annotation(type_str) + type_names = Lazy.parse_annotation(type_str) if type_names: - meta_cls.register_lazy_attr(attr, type_names, class_name, class_dict) + meta_cls.register_lazy_attr(attr, type_names, class_name, bases, class_dict) lazy_attrs.append(attr) - # store names of lazy attributes - class_dict["_lazy_attrs"] = [(attr, meta_cls.get_lazy_attr(attr)) for attr in lazy_attrs] + # store names of lazy attributes, considering also bases + lazy_attrs_dict = {} + for base in reversed(bases): + if getattr(base, "_lazy_attrs", None) is None: + continue + lazy_attrs_dict.update({ + attr: lazy_attr + for attr, lazy_attr in base._lazy_attrs.default.items() + if lazy_attr in base.__fields__ + }) + lazy_attrs_dict.update({attr: meta_cls.get_lazy_attr(attr) for attr in lazy_attrs}) + class_dict["_lazy_attrs"] = lazy_attrs_dict # check the model_config class_dict["model_config"] = model_config = class_dict.get("model_config") or ConfigDict() @@ -95,12 +58,12 @@ def __new__(meta_cls, class_name: str, bases: tuple, class_dict: dict[str, Any]) # create the class cls = super().__new__(meta_cls, class_name, bases, class_dict) - return cls + # remove non-existing lazy attributes from above added dict after class was created + for attr, lazy_attr in list(cls._lazy_attrs.default.items()): + if lazy_attr not in cls.__fields__: + del cls._lazy_attrs.default[attr] - @classmethod - def parse_lazy_annotation(meta_cls, type_str: str) -> list[str] | None: - m = re.match(r"^Lazy\[(.+)\]$", type_str) - return m and [s.strip() for s in m.group(1).split(",")] + 
return cls @classmethod def get_lazy_attr(meta_cls, attr: str) -> str: @@ -112,6 +75,7 @@ def register_lazy_attr( attr: str, type_names: list[str], class_name: str, + bases: tuple, class_dict: dict[str, Any], ) -> None: # if a field already exist, get it @@ -122,12 +86,15 @@ def register_lazy_attr( ) class_dict.pop(attr, None) + # store existing fields + class_dict.setdefault("__orig_fields__", {})[attr] = field + # exchange the annotation with the lazy one lazy_attr = meta_cls.get_lazy_attr(attr) class_dict["__annotations__"][lazy_attr] = class_dict["__annotations__"].pop(attr) - # add a field for the lazy attribute with aliases - _field = Field(alias=attr, serialization_alias=attr, repr=False) + # make sure the field has an alias set and is skipped in repr + _field = Field(alias=attr, repr=False) field = FieldInfo.merge_field_infos(field, _field) if field else _field class_dict[lazy_attr] = field @@ -146,7 +113,7 @@ def fget(self): with DataProvider.instance().materialize(adapter_model) as materialized: # loop through known lazy attributes and check which of them is assigned a # materialized value - for attr_, lazy_attr_ in self._lazy_attrs: + for attr_, lazy_attr_ in self._lazy_attrs.items(): # the adapter model must be compatible that the called one adapter_model_ = getattr(self, lazy_attr_) if not adapter_model.compare_signature(adapter_model_): @@ -199,6 +166,11 @@ def __repr_args__(self) -> GeneratorType: """ yield from super().__repr_args__() - for attr, lazy_attr in self._lazy_attrs: + for attr, lazy_attr in self._lazy_attrs.items(): + # skip when field was originally skipped + orig_field = self.__orig_fields__.get(attr) + if orig_field and not orig_field.repr: + continue + value = getattr(self, lazy_attr) yield attr, f"lazy({value.name})" if isinstance(value, AdapterModel) else value diff --git a/order/models/campaign.py b/order/models/campaign.py index de688ad..1dabafb 100644 --- a/order/models/campaign.py +++ b/order/models/campaign.py @@ -10,10 +10,9 @@ __all__ = ["GT", "Campaign"] -from typing import Dict - -from order.models.base import Model, Lazy -from order.models.dataset import Dataset +from order.types import Lazy, Field +from order.models.base import Model +from order.models.dataset import DatasetIndex class GT(Model): @@ -28,4 +27,4 @@ class Campaign(Model): tier: Lazy[str] ecm: Lazy[float] recommended_gt: GT - datasets: Lazy[Dict[str, Dataset]] + datasets: DatasetIndex = Field(default_factory=DatasetIndex) diff --git a/order/models/dataset.py b/order/models/dataset.py index 789ac1e..08fabfe 100644 --- a/order/models/dataset.py +++ b/order/models/dataset.py @@ -6,24 +6,97 @@ __all__ = ["Dataset"] -from typing import List +import enum -from order.models.base import Model, Lazy +from pydantic import field_validator +from order.types import Union, List, Dict, Field, NonEmptyStrictStr, PositiveStrictInt, Lazy +from order.models.base import Model +from order.models.unique import UniqueObject, LazyUniqueObject, UniqueObjectIndex -class File(Model): - logical_file_name: str - block_name: str - check_sum: int - last_modification_date: int - file_type: str +# class File(Model): +# logical_file_name: str +# block_name: str +# check_sum: int +# last_modification_date: int +# file_type: str -class Dataset(Model): - id: int - name: str - das_name: str - nevents: int - # pnf: Lazy[List of string] - files: Lazy[List[File]] +class GenOrder(enum.Enum): + + unknown: str = "unknown" + lo: str = "lo" + nlo: str = "nlo" + nnlo: str = "nnlo" + n3lo: str = "n3lo" + + def __str__(self) -> 
str: + return self.value + + +class DatasetInfo(Model): + + keys: List[NonEmptyStrictStr] + gen_order: NonEmptyStrictStr = Field(default=str(GenOrder.unknown)) + n_files: Lazy[PositiveStrictInt] + n_events: Lazy[PositiveStrictInt] + lfns: Lazy[List[NonEmptyStrictStr]] + + @field_validator("gen_order", mode="after") + @classmethod + def validate_gen_order(cls, gen_order: str) -> str: + try: + return str(GenOrder[gen_order]) + except KeyError: + raise ValueError(f"unknown gen_order '{gen_order}'") + + +class Dataset(UniqueObject): + + info: Dict[str, DatasetInfo] + + def __getitem__(self, name: str) -> DatasetInfo: + return self.get_info(name) + + def get_info(self, name: str) -> DatasetInfo: + return self.info[name] + + def set_info(self, name: str, info: DatasetInfo) -> None: + if not isinstance(info, DatasetInfo): + raise TypeError(f"expected info to be DatasetInfo object, but got '{info}'") + self.info[name] = info + + @property + def keys(self) -> list[NonEmptyStrictStr]: + return self.info["nominal"].keys + + @property + def gen_order(self) -> GenOrder: + return self.info["nominal"].gen_order + + @property + def n_files(self) -> int: + return self.info["nominal"].n_files + + @property + def n_events(self) -> int: + return self.info["nominal"].n_events + + @property + def lfns(self) -> list[NonEmptyStrictStr]: + return self.info["nominal"].lfns + + +class LazyDataset(LazyUniqueObject): + + class_name: NonEmptyStrictStr = Field(default=Dataset) + + +class DatasetIndex(UniqueObjectIndex): + + class_name: NonEmptyStrictStr = Field(default=Dataset) + objects: Lazy[List[Union[Dataset, LazyDataset]]] = Field( + default_factory=lambda: [], + repr=False, + ) diff --git a/order/models/unique.py b/order/models/unique.py index 20baf2d..345468a 100644 --- a/order/models/unique.py +++ b/order/models/unique.py @@ -8,19 +8,21 @@ __all__ = [ - "UniqueObject", "UniqueObjectIndex", + "UniqueObject", "LazyUniqueObject", "UniqueObjectIndex", "DuplicateObjectException", "DuplicateNameException", "DuplicateIdException", ] -from typing import ClassVar, Any, List, Union -from types import GeneratorType +from contextlib import contextmanager -from pydantic import StrictInt, StrictStr, Field, field_validator -from typing_extensions import Annotated -from annotated_types import Ge, Len +from pydantic import field_validator +from order.types import ( + ClassVar, Any, T, List, Union, GeneratorType, Field, PositiveStrictInt, NonEmptyStrictStr, + KeysView, Lazy, +) from order.models.base import Model +from order.adapters.base import AdapterModel, DataProvider from order.util import no_value, DotAccessProxy @@ -41,7 +43,7 @@ def __new__( ) -> "UniqueObjectMeta": # define a separate integer to remember the maximum id class_dict.setdefault("_max_id", 0) - class_dict["__annotations__"]["_max_id"] = "ClassVar[int]" + class_dict.setdefault("__annotations__", {})["_max_id"] = "ClassVar[int]" # create the class cls = super().__new__(meta_cls, class_name, bases, class_dict) @@ -66,20 +68,10 @@ def get_unique_cls(meta_cls, name: str) -> "UniqueObjectMeta": return meta_cls.__unique_classes[name] -class UniqueObject(Model, metaclass=UniqueObjectMeta): +class UniqueObjectBase(Model): - id: Annotated[StrictInt, Ge(0)] - name: Annotated[StrictStr, Len(min_length=1)] - - AUTO_ID: ClassVar[str] = "+" - - @field_validator("id", mode="before") - @classmethod - def evaluate_auto_id(cls, id: str | int) -> int: - if id == cls.AUTO_ID: - cls._max_id += 1 - id = cls._max_id - return id + id: PositiveStrictInt + name: NonEmptyStrictStr def 
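# --- illustrative example, not part of the patch ------------------------------
# Hypothetical use of the new Dataset/DatasetInfo models, with made-up values.
# The "nominal" info entry backs the convenience accessors defined on Dataset.
from order.models.dataset import Dataset, DatasetInfo

ds = Dataset(
    name="tt_powheg",
    id=1,
    info={
        "nominal": DatasetInfo(
            keys=["/TT/campaign-v1/NANOAODSIM"],
            gen_order="nlo",
            n_files=12,
            n_events=5_000_000,
            lfns=["/store/mc/campaign-v1/tt_powheg_0.root"],
        ),
    },
)

assert ds.n_events == 5_000_000
assert ds["nominal"].gen_order == "nlo"   # __getitem__ forwards to get_info()
# -------------------------------------------------------------------------------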
__hash__(self) -> int:
         """
@@ -103,6 +95,13 @@ def __eq__(self, other: Any) -> bool:
         if isinstance(other, self.__class__):
             return self.name == other.name and self.id == other.id
 
+        # TODO: not particularly clean to use a subclass of _this_ class, solve by inheritance
+        if (
+            (isinstance(other, LazyUniqueObject) and other.cls == self.__class__) or
+            (isinstance(self, LazyUniqueObject) and self.cls == other.__class__)
+        ):
+            return self.name == other.name and self.id == other.id
+
         return False
 
     def __ne__(self, other: Any) -> bool:
@@ -119,7 +118,7 @@ def __lt__(self, other: Any) -> bool:
         if isinstance(other, int):
             return self.id < other
 
-        if isinstance(other, self.__class__):
+        if isinstance(other, UniqueObjectBase):
             return self.id < other.id
 
         return False
@@ -132,7 +131,7 @@ def __le__(self, other: Any) -> bool:
         if isinstance(other, int):
             return self.id <= other
 
-        if isinstance(other, self.__class__):
+        if isinstance(other, UniqueObjectBase):
             return self.id <= other.id
 
         return False
@@ -145,7 +144,7 @@ def __gt__(self, other: Any) -> bool:
         if isinstance(other, int):
             return self.id > other
 
-        if isinstance(other, self.__class__):
+        if isinstance(other, UniqueObjectBase):
             return self.id > other.id
 
         return False
@@ -158,13 +157,83 @@ def __ge__(self, other: Any) -> bool:
         if isinstance(other, int):
             return self.id >= other
 
-        if isinstance(other, self.__class__):
+        if isinstance(other, UniqueObjectBase):
             return self.id >= other.id
 
         return False
 
 
-class UniqueObjectIndex(Model):
+class WrapsUniqueClass(Model):
+
+    class_name: NonEmptyStrictStr
+
+    @field_validator("class_name", mode="before")
+    @classmethod
+    def convert_class_to_name(cls, class_name: str | UniqueObjectMeta) -> str:
+        if isinstance(class_name, UniqueObjectMeta):
+            class_name = class_name.__name__
+        return class_name
+
+    @field_validator("class_name", mode="after")
+    @classmethod
+    def validate_class_name(cls, class_name: str) -> str:
+        # check that the model class is existing
+        if not UniqueObjectMeta.has_unique_cls(class_name):
+            raise ValueError(f"class '{class_name}' is not a subclass of UniqueObject")
+        return class_name
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+        # store a reference to the wrapped class
+        self._cls = UniqueObjectMeta.get_unique_cls(self.class_name)
+
+    @property
+    def cls(self) -> UniqueObjectMeta:
+        return self._cls
+
+
+class UniqueObject(UniqueObjectBase, metaclass=UniqueObjectMeta):
+
+    AUTO_ID: ClassVar[str] = "+"
+
+    @field_validator("id", mode="before")
+    @classmethod
+    def evaluate_auto_id(cls, id: str | int) -> int:
+        if id == cls.AUTO_ID:
+            cls._max_id += 1
+            id = cls._max_id
+        return id
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+        # adjust max id class attribute
+        if self.id > self.__class__._max_id:
+            self.__class__._max_id = self.id
+
+
+class LazyUniqueObject(UniqueObjectBase, WrapsUniqueClass):
+
+    adapter: AdapterModel
+
+    @contextmanager
+    def materialize(self, index: "UniqueObjectIndex") -> GeneratorType:
+        with DataProvider.instance().materialize(self.adapter) as materialized:
+            # complain when the adapter did not provide a value for this attribute
+            if self.adapter.key not in materialized:
+                raise KeyError(
+                    f"adapter '{self.adapter.name}' did not provide field "
+                    f"'{self.adapter.key}' required to materialize '{self}'",
+                )
+
+            # create the materialized instance
+            inst = self.cls(**materialized[self.adapter.key])
+
+            yield inst
+
+
+class UniqueObjectIndex(WrapsUniqueClass):
     """
     Index of
:py:class:`UniqueObject` instances which are - as the name suggests - unique within this index, enabling fast lookups by either name or id. @@ -210,43 +279,36 @@ class UniqueObjectIndex(Model): An object that provides simple attribute access to contained objects via name. """ - class_name: Annotated[StrictStr, Len(min_length=1)] = Field(default=UniqueObject) - objects: List[UniqueObject] = Field(default_factory=lambda: [], repr=False) + objects: Lazy[List[Union[UniqueObject, LazyUniqueObject]]] = Field( + default_factory=lambda: [], + repr=False, + ) - @field_validator("class_name", mode="before") + @field_validator("lazy_objects", mode="after") @classmethod - def convert_class_to_name(cls, class_name: str | UniqueObjectMeta) -> str: - if isinstance(class_name, UniqueObjectMeta): - class_name = class_name.__name__ - return class_name - - @field_validator("class_name", mode="after") - @classmethod - def validate_class_name(cls, class_name: str) -> str: - # check that the model class is existing - if not UniqueObjectMeta.has_unique_cls(class_name): - raise ValueError(f"class '{class_name}' is not a subclass of UniqueObject") - return class_name - - @field_validator("objects", mode="after") - @classmethod - def detect_duplicate_objects(cls, objects: List[UniqueObject]) -> List[UniqueObject]: - seen_ids, seen_names = set(), set() + def detect_duplicate_objects( + cls, + objects: Lazy[list[UniqueObject | LazyUniqueObject]], + ) -> Lazy[list[UniqueObject | LazyUniqueObject]]: + # skip adapters + if isinstance(objects, AdapterModel): + return objects + + # detect duplicate ids and names + seen_names, seen_ids = set(), set() for obj in objects: - if obj.id in seen_ids: - raise DuplicateIdException(type(obj), obj.id, cls) if obj.name in seen_names: raise DuplicateNameException(type(obj), obj.name, cls) + if obj.id in seen_ids: + raise DuplicateIdException(type(obj), obj.id, cls) seen_ids.add(obj.id) seen_names.add(obj.name) + return objects def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - # store a reference to the class - self._cls = UniqueObjectMeta.get_unique_cls(self.class_name) - # name-based DotAccessProxy self._n = DotAccessProxy(self.get) @@ -274,7 +336,7 @@ def __iter__(self) -> GeneratorType: Iterates through the index and yields the contained objects (i.e. the *values*). """ for obj in self.objects: - yield obj + yield self.get(obj.name) def __nonzero__(self): """ @@ -282,16 +344,18 @@ def __nonzero__(self): """ return bool(self.objects) + def __getitem__(self, obj: Any) -> UniqueObject: + """ + Shorthand for :py:func:`get` without a default value. + """ + return self.get(obj) + def __repr_args__(self) -> GeneratorType: """ Yields all key-values pairs to be injected into the representation. """ yield from super().__repr_args__() - yield "objects", len(self) - - @property - def cls(self) -> UniqueObjectMeta: - return self._cls + yield "len", len(self) @property def n(self) -> DotAccessProxy: @@ -315,48 +379,52 @@ def _sync_indices(self, force: bool = False) -> None: self._name_index[obj.name] = obj self._id_index[obj.id] = obj - def names(self) -> List[str]: + def names(self) -> KeysView: """ Returns the names of the contained objects in the index. """ self._sync_indices() - return list(self._name_index.keys()) + return self._name_index.keys() - def ids(self) -> List[int]: + def ids(self) -> KeysView: """ Returns the ids of the contained objects in the index. 
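# --- illustrative example, not part of the patch ------------------------------
# Lookup behavior of the reworked UniqueObjectIndex with fully materialized
# members, including the AUTO_ID mechanism ("Process" is a hypothetical
# UniqueObject subclass).
from order.models.unique import UniqueObject, UniqueObjectIndex


class Process(UniqueObject):
    pass


index = UniqueObjectIndex(class_name=Process)
index.add(Process(name="ttbar", id=1))
index.add(Process(name="st", id=Process.AUTO_ID))    # "+" -> max id so far + 1 = 2

assert "ttbar" in index.names() and 2 in index.ids()
assert index.get(1) is index.get("ttbar")            # id and name resolve to the same object
assert index.get("zz", default=None) is None         # default suppresses the lookup error
assert index.n.ttbar.id == 1                         # DotAccessProxy lookup by name
# -------------------------------------------------------------------------------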
""" self._sync_indices() - return list(self._id_index.keys()) + return self._id_index.keys() - def keys(self): + def keys(self) -> GeneratorType: """ Returns the (name, id) pairs of all objects contained in the index. """ self._sync_indices() - return list(zip(self._name_index.keys(), self._id_index.keys())) + return (tpl for tpl in zip(self.names(), self.ids())) - def values(self): + def values(self) -> GeneratorType: """ Returns all objects contained in the index. """ self._sync_indices() - return list(self.objects) + return (obj for obj in self) - def items(self): + def items(self) -> GeneratorType: """ Returns (name, id, object) 3-tuples of all objects contained in the index """ - return list(zip(self.keys(), self.objects)) + return ( + ((obj.name, obj.id), obj) + for obj in self + ) def has(self, obj: Any) -> bool: """ Returns whether an object *obj* is contained in the index. *obj* can be an :py:attr:`id`, a - :py:attr:`name` or a :py:class:`UniqueObject` of type :py:attr:`cls` + :py:attr:`name`, a :py:class:`UniqueObject` of type or a :py:class:`LazyUniqueObject` that + wraps a type :py:attr:`cls`. """ self._sync_indices() - if isinstance(obj, self.cls): + if isinstance(obj, self.cls) or (isinstance(obj, LazyUniqueObject) and obj.cls == self.cls): obj = obj.name if isinstance(obj, str): @@ -367,33 +435,56 @@ def has(self, obj: Any) -> bool: return False - def get(self, obj: Any, default: Any = no_value) -> Union[UniqueObject, Any]: + def get(self, obj: Any, default: T = no_value) -> UniqueObject | T: """ Returns an object *obj* contained in this index. *obj* can be an :py:attr:`id`, a - :py:attr:`name` or a :py:class:`UniqueObject` of type :py:attr:`cls`. If no object could be - found, *default* is returned if set. An exception is raised otherwise. + :py:attr:`name`, a :py:class:`UniqueObject` of type or a :py:class:`LazyUniqueObject` that + wraps a type :py:attr:`cls`. If no object could be found, *default* is returned if set. An + exception is raised otherwise. 
""" self._sync_indices() - obj_orig = obj - if isinstance(obj, self.cls): - obj = obj.name - - if isinstance(obj, str): - if obj in self._name_index: - return self._name_index[obj] - if default != no_value: - return default - - elif isinstance(obj, int): - if obj in self._id_index: - return self._id_index[obj] - if default != no_value: - return default + name_or_id = obj + inst_passed = False + if isinstance(obj, self.cls) or isinstance(obj, LazyUniqueObject) and obj.cls == self.cls: + name_or_id = obj.name + inst_passed = True + + _obj = None + if isinstance(name_or_id, str): + if name_or_id in self._name_index: + _obj = self._name_index[name_or_id] + + elif isinstance(name_or_id, int): + if name_or_id in self._id_index: + _obj = self._id_index[name_or_id] + + # when an obj was an instance, but the found one is not equal to it, reset the found one + if _obj is not None and inst_passed and _obj != obj: + _obj = None + + # prepare and return the found object + if _obj is not None: + # materialize when the found object is lazy + if isinstance(_obj, LazyUniqueObject): + # remember the position of the object + idx = self.objects.index(_obj) + + # materializee + with _obj.materialize(self) as _obj: + # add back the materialized object + self.objects[idx] = _obj + self._name_index[_obj.name] = _obj + self._id_index[_obj.id] = _obj + + return _obj + + if default != no_value: + return default - raise ValueError(f"object '{obj_orig}' not known to index '{self}'") + raise ValueError(f"object '{obj}' not known to index '{self}'") - def get_first(self, default: Any = no_value) -> Union[UniqueObject, Any]: + def get_first(self, default: T = no_value) -> UniqueObject | T: """ Returns the first object of this index. If no object could be found, *default* is returned if set. An exception is raised otherwise. @@ -401,9 +492,9 @@ def get_first(self, default: Any = no_value) -> Union[UniqueObject, Any]: if not self.objects and default != no_value: return default - return self.objects[0] + return self.get(self.objects[0].name) - def get_last(self, default: Any = no_value) -> Union[UniqueObject, Any]: + def get_last(self, default: T = no_value) -> UniqueObject | T: """ Returns the last object of this index. If no object could be found, *default* is returned if set. An exception is raised otherwise. @@ -411,15 +502,28 @@ def get_last(self, default: Any = no_value) -> Union[UniqueObject, Any]: if not self.objects and default != no_value: return default - return self.objects[-1] + return self.get(self.objects[-1].name) - def add(self, obj: UniqueObject, overwrite: bool = False) -> UniqueObject: - """ - Adds a new object *obj* with type :py:attr:`cls` to the index. When an object with the same - :py:attr:`name` or :py:attr:`id` already exists and *overwrite* is *False*, an exception is - raised. Otherwise, the object is overwritten. The added object is returned. - """ - if not isinstance(obj, self.cls): + def add( + self, + obj: UniqueObject | LazyUniqueObject, + overwrite: bool = False, + ) -> UniqueObject | LazyUniqueObject: + """ + Adds *obj*, a :py:class:`UniqueObject` of type or a :py:class:`LazyUniqueObject` that + wraps a type :py:attr:`cls`, to the index. When an object with the same :py:attr:`name` or + :py:attr:`id` already exists and *overwrite* is *False*, an exception is raised. Otherwise, + the object is overwritten. The added object is returned. 
+ """ + if isinstance(obj, LazyUniqueObject): + # unique object type of the lazy object and this index must match + if self.cls != obj.cls: + raise TypeError( + f"LazyUniqueObject '{obj}' must materialize into '{self.cls}' instead of " + f"'{obj.cls}'", + ) + elif not isinstance(obj, self.cls): + # type of the object must match that of the index raise TypeError(f"object '{obj}' to add must be of type '{self.cls}'") self._sync_indices() @@ -428,11 +532,11 @@ def add(self, obj: UniqueObject, overwrite: bool = False) -> UniqueObject: if obj.name in self._name_index: if not overwrite: raise DuplicateNameException(self.cls, obj.name, self) - self.remove(obj) + self.remove(obj.name) if obj.id in self._id_index: if not overwrite: raise DuplicateIdException(self.cls, obj.id, self) - self.remove(obj) + self.remove(obj.id) # add to objects and indices self.objects.append(obj) @@ -443,39 +547,62 @@ def add(self, obj: UniqueObject, overwrite: bool = False) -> UniqueObject: def extend( self, - objects: Union["UniqueObjectIndex", List[UniqueObject]], + objects: "UniqueObjectIndex" | list[UniqueObject | LazyUniqueObject], overwrite: bool = False, ) -> None: """ - Adds multiple new *objects* of type :py:attr:`cls` to this index. + Adds multiple new *objects* of type :py:attr:`cls` to this index. See :py:meth:`add` for + more info. """ - for obj in objects: + # when objects is an index, do not materialize its objects via the normal iterator + gen = objects.objects if isinstance(objects, UniqueObjectIndex) else objects + for obj in gen: self.add(obj, overwrite=overwrite) def index(self, obj: Any) -> int: """ Returns the position of an object *obj* in this index. *obj* can be an :py:attr:`id`, a - :py:attr:`name` or a :py:class:`UniqueObject` of type :py:attr:`cls`. + :py:attr:`name`, a :py:class:`UniqueObject` of type or a :py:class:`LazyUniqueObject` that + wraps a type :py:attr:`cls`. """ return self.objects.index(self.get(obj)) def remove(self, obj: Any) -> bool: """ - Remove an object *obj* from the index. *obj* can be an :py:attr:`id`, a :py:attr:`name` or a - :py:class:`UniqueObject` of type :py:attr:`cls`. *True* is returned in case an object could - be removed, and *False* otherwise. + Remove an object *obj* from the index. *obj* can be an :py:attr:`id`, a :py:attr:`name`, a + :py:class:`UniqueObject` of type or a :py:class:`LazyUniqueObject` that wraps a type + :py:attr:`cls`. *True* is returned in case an object could be removed, and *False* + otherwise. 
""" - # first, get the object - obj = self.get(obj, default=None) + self._sync_indices() + + name_or_id = obj + inst_passed = False + if isinstance(obj, self.cls) or isinstance(obj, LazyUniqueObject) and obj.cls == self.cls: + name_or_id = obj.name + inst_passed = True + + _obj = None + if isinstance(name_or_id, str): + if name_or_id in self._name_index: + _obj = self._name_index[name_or_id] + + elif isinstance(name_or_id, int): + if name_or_id in self._id_index: + _obj = self._id_index[name_or_id] + + # when an obj was an instance, but the found one is not equal to it, reset the found one + if _obj is not None and inst_passed and _obj != obj: + _obj = None - # return when not existing - if obj is None: + # do nothing if no object was found, or if it does not exactly match the passed one + if _obj is None: return False # remove from indices and objects - self._name_index.pop(obj.name) - self._id_index.pop(obj.id) - self.objects.remove(obj) + self._name_index.pop(_obj.name) + self._id_index.pop(_obj.id) + self.objects.remove(_obj) return True diff --git a/order/settings.py b/order/settings.py index c931b16..b5c6ba2 100644 --- a/order/settings.py +++ b/order/settings.py @@ -12,8 +12,8 @@ import os import re -from typing import Any +from order.types import T from order.util import no_value @@ -28,7 +28,7 @@ def instance(cls): return cls.__instance @classmethod - def get_env(cls, name: str, default: Any = no_value) -> Any: + def get_env(cls, name: str, default: T = no_value) -> T: if name not in os.environ: if default != no_value: return default diff --git a/order/types.py b/order/types.py new file mode 100644 index 0000000..f35718d --- /dev/null +++ b/order/types.py @@ -0,0 +1,85 @@ +# coding: utf-8 + +""" +Custom type definitions and shorthands to simplify imports of types that are spread across multiple +packages. +""" + +from __future__ import annotations + + +__all__ = [] + + +import re +from collections.abc import KeysView, ValuesView # noqa +from typing import Any, Union, TypeVar, ClassVar, List, Tuple, Sequence, Set, Dict # noqa +from types import GeneratorType # noqa + +from typing_extensions import Annotated, _AnnotatedAlias as AnnotatedType # noqa +from annotated_types import Ge, Len # noqa +from pydantic import Field, Strict, StrictInt, StrictFloat, StrictStr # noqa +from pydantic.fields import FieldInfo # noqa + + +#: Strict positive integer. +PositiveStrictInt = Annotated[StrictInt, Ge(0)] + +#: Strict non-empty string. +NonEmptyStrictStr = Annotated[StrictStr, Len(min_length=1)] + +#: Generic type variable, more stringent than Any. +T = TypeVar("T") + + +class Lazy(object): + """ + Annotation factory that adds :py:class:`AdapterModel` to the metadata of the returned annotated + type. 
+ """ + + @classmethod + def __class_getitem__(cls, types: tuple[type]) -> type: + from order.adapters.base import AdapterModel + + if not isinstance(types, tuple): + types = (types,) + return Union[tuple(map(cls.make_strict, types)) + (AdapterModel,)] + + @classmethod + def parse_annotation(cls, type_str: str) -> list[str] | None: + m = re.match(r"^Lazy\[(.+)\]$", type_str) + return m and [s.strip() for s in m.group(1).split(",")] + + @classmethod + def make_strict(cls, type_: type) -> AnnotatedType: + # some types cannot be strict + if not cls.can_make_strict(type_): + return type_ + + # when not decorated with strict meta data, just create a new strict tyoe + if ( + not isinstance(type_, AnnotatedType) or + not any(isinstance(m, Strict) for m in getattr(type_, "__metadata__", [])) + ): + return Annotated[type_, Strict()] + + # when already strict, return as is + metadata = type_.__metadata__ + if all(m.strict for m in metadata if isinstance(m, Strict)): + return type_ + + # at this point, strict metadata exists but it is actually disabled, + # so replace it in metadata and return a new annotated type + metadata = [ + (Strict() if isinstance(m, Strict) else m) + for m in metadata + ] + return Annotated[(*type_.__args__, *metadata)] + + @classmethod + def can_make_strict(cls, type_: type) -> bool: + if type_.__dict__.get("_name") in ("Dict", "List"): + return False + + return True