Skip to content

Commit

Permalink
perf: Refactor SchemaBase.from_dict and co
Browse files Browse the repository at this point in the history
There's quite a lot in here, so I've left my notes in `_subclasses` temporarily.
- Removed unused `use_json=False` branch
- Evaluate the hash table **once** and not every time `SchemaBase.from_dict` is called
  • Loading branch information
dangotbanned committed Aug 25, 2024
1 parent bd1d580 commit 8ca4266
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 60 deletions.
32 changes: 22 additions & 10 deletions tests/utils/test_schemapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import warnings
from collections import deque
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence
from typing import TYPE_CHECKING, Any, Callable, Iterable, Literal, Sequence

import jsonschema
import jsonschema.exceptions
Expand Down Expand Up @@ -48,7 +48,23 @@ def test_actual_json_schema_draft_is_same_as_hardcoded_default():
class _TestSchema(SchemaBase):
@classmethod
def _default_wrapper_classes(cls):
return _TestSchema.__subclasses__()
return schemapi._subclasses(_TestSchema)

@classmethod
def from_dict(
cls: type[schemapi.TSchemaBase], dct: dict[str, Any], validate: bool = True
) -> schemapi.TSchemaBase:
"""
Construct an instance of ``cls`` from ``dct`` via a freshly built converter.

Overrides ``SchemaBase.from_dict``, which uses a cached ``_FromDict.hash_tps``.
The cached version is based on an iterator over:
schemapi._subclasses(VegaLiteSchema)

Parameters
----------
dct
Dictionary to convert.
validate
If True, call ``cls.validate(dct)`` before converting.
"""
if validate:
cls.validate(dct)
# A new converter is built from this class's wrapper classes on every call.
# NOTE(review): ``_FromDict.__init__`` in this commit writes into the class-level
# ``hash_tps``, so this presumably also registers the test classes in that shared
# table on each call - confirm that is intended.
converter = schemapi._FromDict(cls._default_wrapper_classes())
return converter.from_dict(dct, cls)


class MySchema(_TestSchema):
Expand Down Expand Up @@ -383,14 +399,10 @@ class BadSchema(SchemaBase):
assert str(err.value).startswith("Cannot instantiate object")


@pytest.mark.parametrize("use_json", [True, False])
def test_hash_schema(use_json):
classes = _TestSchema._default_wrapper_classes()
FromDict = schemapi._FromDict

for cls in classes:
hsh1 = FromDict.hash_schema(cls._schema, use_json=use_json)
hsh2 = FromDict.hash_schema(cls._schema, use_json=use_json)
def test_hash_schema():
for cls in _TestSchema._default_wrapper_classes():
hsh1 = schemapi._hash_schema(cls._schema)
hsh2 = schemapi._hash_schema(cls._schema)
assert hsh1 == hsh2
assert hash(hsh1) == hash(hsh2)

Expand Down
149 changes: 99 additions & 50 deletions tools/schemapi/schemapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,17 +646,6 @@ def _prune_subset_enum(iterable: _Errs, /) -> _ErrsLazy:
}


def _subclasses(cls: type[Any]) -> Iterator[type[Any]]:
"""Breadth-first sequence of all classes which inherit from cls."""
seen = set()
current: set[type[Any]] = {cls}
while current:
seen |= current
current = set(chain.from_iterable(cls.__subclasses__() for cls in current))
for cls in current - seen:
yield cls


def _from_array_like(obj: Iterable[Any], /) -> list[Any]:
try:
ser = nw.from_native(obj, strict=True, series_only=True)
Expand Down Expand Up @@ -1288,7 +1277,13 @@ def from_dict(
"""
if validate:
cls.validate(dct)
converter = _FromDict(cls._default_wrapper_classes())
# NOTE: the breadth-first search occurs only once now
# `_FromDict` is purely ClassVar/classmethods
converter: type[_FromDict] | _FromDict = (
_FromDict
if _FromDict.hash_tps
else _FromDict(cls._default_wrapper_classes())
)
return converter.from_dict(dct, cls)

@classmethod
Expand Down Expand Up @@ -1389,6 +1384,9 @@ def _passthrough(*args: Any, **kwds: Any) -> Any | dict[str, Any]:


def _freeze(val):
# NOTE: No longer referenced
# - Previously only called during tests
# - Not during any library code
if isinstance(val, dict):
return frozenset((k, _freeze(v)) for k, v in val.items())
elif isinstance(val, set):
Expand All @@ -1399,6 +1397,64 @@ def _freeze(val):
return val


def _hash_schema(
schema: _JsonParameter,
/,
*,
exclude: Iterable[str] = frozenset(
("definitions", "title", "description", "$schema", "id")
),
) -> int:
"""
Return the hash value for a ``schema``.
Parameters
----------
schema
``SchemaBase._schema``.
exclude
``schema`` keys which are not considered when identifying equivalence.
"""
if isinstance(schema, Mapping):
schema = {k: v for k, v in schema.items() if k not in exclude}
return hash(json.dumps(schema, sort_keys=True))


def _subclasses(cls: type[TSchemaBase]) -> Iterator[type[TSchemaBase]]:
"""
Breadth-first sequence of all classes which inherit from ``cls``.
Notes
-----
- `__subclasses__()` alone isn't helpful, as that is only immediate subclasses
- Deterministic
- Used for `SchemaBase` & `VegaLiteSchema`
- In practice, it provides an iterator over all classes in the schema below `VegaLiteSchema`
- The first one is `Root`
- The order itself, I don't think is important
- But probably important that it doesn't change
- Thinking they used an iterator so that the subclasses are evaluated after they have all been defined
- `Chart` seems to try to avoid calling this
- Using `TopLevelMixin.__subclasses__()` first if possible
- It is always called during `Chart.encode()`
- Chart.encode()
- altair.utils.core.infer_encoding_types
- _ChannelCache.infer_encoding_types
- _ChannelCache._wrap_in_channel
- SchemaBase.from_dict (recursive, hot loop, validate =False, within a try/except)
- _FromDict(cls._default_wrapper_classes())
- schemapi._subclasses(schema.core.VegaLiteSchema)
"""
seen = set()
current: set[type[TSchemaBase]] = {cls}
while current:
seen |= current
current = set(chain.from_iterable(cls.__subclasses__() for cls in current))
for cls in current - seen:
yield cls


class _FromDict:
"""
Class used to construct SchemaBase class hierarchies from a dict.
Expand All @@ -1408,84 +1464,80 @@ class _FromDict:
specified in the ``wrapper_classes`` positional-only argument to the constructor.
"""

_hash_exclude_keys = ("definitions", "title", "description", "$schema", "id")
hash_tps: ClassVar[defaultdict[int, deque[type[SchemaBase]]]] = defaultdict(deque)
"""
Maps unique schemas to corresponding types.
def __init__(self, wrapper_classes: Iterable[type[SchemaBase]], /) -> None:
# Create a mapping of a schema hash to a list of matching classes
# This lets us quickly determine the correct class to construct
self.class_dict: dict[int, list[type[SchemaBase]]] = defaultdict(list)
for tp in wrapper_classes:
if tp._schema is not None:
self.class_dict[self.hash_schema(tp._schema)].append(tp)
The logic is that after removing a subset of keys, some schemas are identical.
@classmethod
def hash_schema(cls, schema: dict[str, Any], use_json: bool = True) -> int:
"""
Compute a python hash for a nested dictionary which properly handles dicts, lists, sets, and tuples.
If there are multiple matches, we use the first one in the ``deque``.
At the top level, the function excludes from the hashed schema all keys
listed in `exclude_keys`.
``_subclasses`` yields the results of a `breadth-first search`_,
so the first matching class is the most general match.
This implements two methods: one based on conversion to JSON, and one based
on recursive conversions of unhashable to hashable types; the former seems
to be slightly faster in several benchmarks.
"""
if cls._hash_exclude_keys and isinstance(schema, dict):
schema = {
key: val
for key, val in schema.items()
if key not in cls._hash_exclude_keys
}
s: Any = json.dumps(schema, sort_keys=True) if use_json else _freeze(schema)
return hash(s)
.. _breadth-first search:
https://en.wikipedia.org/wiki/Breadth-first_search
"""

def __init__(self, wrapper_classes: Iterator[type[SchemaBase]], /) -> None:
    """
    Register ``wrapper_classes`` in the class-level ``hash_tps`` table.

    Each class with a non-``None`` ``_schema`` is stored under the hash of that
    schema. Classes are appended in iteration order, so the first class
    registered for a given hash stays first.

    Parameters
    ----------
    wrapper_classes
        Classes to register; those whose ``_schema`` is ``None`` are ignored.
    """
    hash_tps = type(self).hash_tps
    for tp in wrapper_classes:
        if tp._schema is None:
            continue
        bucket = hash_tps[_hash_schema(tp._schema)]
        # ``hash_tps`` is shared by every instance, so constructing another
        # converter must not append the same class again.
        if tp not in bucket:
            bucket.append(tp)

@overload
@classmethod
def from_dict(
self,
cls,
dct: TSchemaBase,
tp: None = ...,
schema: None = ...,
rootschema: None = ...,
default_class: Any = ...,
) -> TSchemaBase: ...
@overload
@classmethod
def from_dict(
self,
cls,
dct: dict[str, Any] | list[dict[str, Any]],
tp: Any = ...,
schema: Any = ...,
rootschema: Any = ...,
default_class: type[TSchemaBase] = ..., # pyright: ignore[reportInvalidTypeVarUse]
) -> TSchemaBase: ...
@overload
@classmethod
def from_dict(
self,
cls,
dct: dict[str, Any],
tp: None = ...,
schema: dict[str, Any] = ...,
rootschema: None = ...,
default_class: Any = ...,
) -> SchemaBase: ...
@overload
@classmethod
def from_dict(
self,
cls,
dct: dict[str, Any],
tp: type[TSchemaBase],
schema: None = ...,
rootschema: None = ...,
default_class: Any = ...,
) -> TSchemaBase: ...
@overload
@classmethod
def from_dict(
self,
cls,
dct: dict[str, Any] | list[dict[str, Any]],
tp: type[TSchemaBase],
schema: dict[str, Any],
rootschema: dict[str, Any] | None = ...,
default_class: Any = ...,
) -> Never: ...
@classmethod
def from_dict(
self,
cls,
dct: dict[str, Any] | list[dict[str, Any]] | TSchemaBase,
tp: type[TSchemaBase] | None = None,
schema: dict[str, Any] | None = None,
Expand All @@ -1502,18 +1554,15 @@ def from_dict(
root_schema: dict[str, Any] = rootschema or tp._rootschema or current_schema
target_tp = tp
elif schema is not None:
# If there are multiple matches, we use the first one in the dict.
# Our class dict is constructed breadth-first from top to bottom,
# so the first class that matches is the most general match.
current_schema = schema
root_schema = rootschema or current_schema
matches = self.class_dict[self.hash_schema(current_schema)]
target_tp = matches[0] if matches else default_class
matches = cls.hash_tps[_hash_schema(current_schema)]
target_tp = next(iter(matches), default_class)
else:
msg = "Must provide either `tp` or `schema`, but not both."
raise ValueError(msg)

from_dict = partial(self.from_dict, rootschema=root_schema)
from_dict = partial(cls.from_dict, rootschema=root_schema)
# Can also return a list?
resolved = _resolve_references(current_schema, root_schema)
if "anyOf" in resolved or "oneOf" in resolved:
Expand Down

0 comments on commit 8ca4266

Please sign in to comment.