Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: write performance improvements, api clarity #645

Merged
merged 13 commits into from
Dec 1, 2022
24 changes: 24 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,30 @@ Fill in the relevant sections, clearly linking the issue the change is attempting

`debugpy` is installed in local development. A VSCode launch config is provided. Run `inv test -v -d` to enable the debugger (`-d` for debug). It'll then wait for you to attach your VSCode debugging client.

#### Debugging Performance Issues

You can run `inv benchmark` to run the full benchmark suite. Alternatively, write a test file, e.g.:

```py
# test_performance.py
import pytest
import os

SIZE = int(os.environ.get("SIZE", 1000))

@pytest.mark.parametrize("x", range(SIZE))
def test_performance(x, snapshot):
    assert x == snapshot
```

and then run:

```sh
SIZE=1000 python -m cProfile -s cumtime -m pytest test_performance.py --snapshot-update -s > profile.log
```

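The `-s` option controls how the profile output is sorted (e.g. `cumtime`, `tottime`, `ncalls`); see the cProfile docs for the full list of sort keys.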

## Styleguides

### Commit Messages
Expand Down
15 changes: 11 additions & 4 deletions src/syrupy/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __post_init__(self) -> None:
def __init_extension(
self, extension_class: Type["AbstractSyrupyExtension"]
) -> "AbstractSyrupyExtension":
return extension_class(test_location=self.test_location)
return extension_class()

@property
def extension(self) -> "AbstractSyrupyExtension":
Expand Down Expand Up @@ -238,8 +238,12 @@ def __eq__(self, other: "SerializableData") -> bool:
return self._assert(other)

def _assert(self, data: "SerializableData") -> bool:
snapshot_location = self.extension.get_location(index=self.index)
snapshot_name = self.extension.get_snapshot_name(index=self.index)
snapshot_location = self.extension.get_location(
test_location=self.test_location, index=self.index
)
snapshot_name = self.extension.get_snapshot_name(
test_location=self.test_location, index=self.index
)
snapshot_data: Optional["SerializedData"] = None
serialized_data: Optional["SerializedData"] = None
matches = False
Expand All @@ -264,6 +268,7 @@ def _assert(self, data: "SerializableData") -> bool:
if not matches and self.update_snapshots:
self.session.queue_snapshot_write(
extension=self.extension,
test_location=self.test_location,
data=serialized_data,
index=self.index,
)
Expand Down Expand Up @@ -299,7 +304,9 @@ def _post_assert(self) -> None:
def _recall_data(self, index: "SnapshotIndex") -> Optional["SerializableData"]:
try:
return self.extension.read_snapshot(
index=index, session_id=str(id(self.session))
test_location=self.test_location,
index=index,
session_id=str(id(self.session)),
)
except SnapshotDoesNotExist:
return None
4 changes: 2 additions & 2 deletions src/syrupy/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
SNAPSHOT_DIRNAME = "__snapshots__"
SNAPSHOT_EMPTY_FOSSIL_KEY = "empty snapshot fossil"
SNAPSHOT_UNKNOWN_FOSSIL_KEY = "unknown snapshot fossil"
SNAPSHOT_EMPTY_FOSSIL_KEY = "empty snapshot collection"
SNAPSHOT_UNKNOWN_FOSSIL_KEY = "unknown snapshot collection"

EXIT_STATUS_FAIL_UNUSED = 1

Expand Down
52 changes: 26 additions & 26 deletions src/syrupy/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class SnapshotUnknown(Snapshot):


@dataclass
class SnapshotFossil:
class SnapshotCollection:
"""A collection of snapshots at a save location"""

location: str
Expand All @@ -54,8 +54,8 @@ def add(self, snapshot: "Snapshot") -> None:
if snapshot.name != SNAPSHOT_EMPTY_FOSSIL_KEY:
self.remove(SNAPSHOT_EMPTY_FOSSIL_KEY)

def merge(self, snapshot_fossil: "SnapshotFossil") -> None:
for snapshot in snapshot_fossil:
def merge(self, snapshot_collection: "SnapshotCollection") -> None:
for snapshot in snapshot_collection:
self.add(snapshot)

def remove(self, snapshot_name: str) -> None:
Expand All @@ -69,8 +69,8 @@ def __iter__(self) -> Iterator["Snapshot"]:


@dataclass
class SnapshotEmptyFossil(SnapshotFossil):
"""This is a saved fossil that is known to be empty and thus can be removed"""
class SnapshotEmptyCollection(SnapshotCollection):
"""This is a saved collection that is known to be empty and thus can be removed"""

_snapshots: Dict[str, "Snapshot"] = field(
default_factory=lambda: {SnapshotEmpty().name: SnapshotEmpty()}
Expand All @@ -82,42 +82,42 @@ def has_snapshots(self) -> bool:


@dataclass
class SnapshotUnknownFossil(SnapshotFossil):
"""This is a saved fossil that is unclaimed by any extension currently in use"""
class SnapshotUnknownCollection(SnapshotCollection):
"""This is a saved collection that is unclaimed by any extension currently in use"""

_snapshots: Dict[str, "Snapshot"] = field(
default_factory=lambda: {SnapshotUnknown().name: SnapshotUnknown()}
)


@dataclass
class SnapshotFossils:
_snapshot_fossils: Dict[str, "SnapshotFossil"] = field(default_factory=dict)
class SnapshotCollections:
_snapshot_collections: Dict[str, "SnapshotCollection"] = field(default_factory=dict)

def get(self, location: str) -> Optional["SnapshotFossil"]:
return self._snapshot_fossils.get(location)
def get(self, location: str) -> Optional["SnapshotCollection"]:
return self._snapshot_collections.get(location)

def add(self, snapshot_fossil: "SnapshotFossil") -> None:
self._snapshot_fossils[snapshot_fossil.location] = snapshot_fossil
def add(self, snapshot_collection: "SnapshotCollection") -> None:
self._snapshot_collections[snapshot_collection.location] = snapshot_collection

def update(self, snapshot_fossil: "SnapshotFossil") -> None:
snapshot_fossil_to_update = self.get(snapshot_fossil.location)
if snapshot_fossil_to_update is None:
snapshot_fossil_to_update = SnapshotFossil(
location=snapshot_fossil.location
def update(self, snapshot_collection: "SnapshotCollection") -> None:
snapshot_collection_to_update = self.get(snapshot_collection.location)
if snapshot_collection_to_update is None:
snapshot_collection_to_update = SnapshotCollection(
location=snapshot_collection.location
)
self.add(snapshot_fossil_to_update)
snapshot_fossil_to_update.merge(snapshot_fossil)
self.add(snapshot_collection_to_update)
snapshot_collection_to_update.merge(snapshot_collection)

def merge(self, snapshot_fossils: "SnapshotFossils") -> None:
for snapshot_fossil in snapshot_fossils:
self.update(snapshot_fossil)
def merge(self, snapshot_collections: "SnapshotCollections") -> None:
for snapshot_collection in snapshot_collections:
self.update(snapshot_collection)

def __iter__(self) -> Iterator["SnapshotFossil"]:
return iter(self._snapshot_fossils.values())
def __iter__(self) -> Iterator["SnapshotCollection"]:
return iter(self._snapshot_collections.values())

def __contains__(self, key: str) -> bool:
return key in self._snapshot_fossils
return key in self._snapshot_collections


@dataclass
Expand Down
27 changes: 14 additions & 13 deletions src/syrupy/extensions/amber/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
Set,
)

from syrupy.data import SnapshotFossil
from syrupy.data import SnapshotCollection
from syrupy.extensions.base import AbstractSyrupyExtension

from .serializer import DataSerializer
Expand All @@ -21,6 +21,8 @@ class AmberSnapshotExtension(AbstractSyrupyExtension):
An amber snapshot file stores data in the following format:
"""

_file_extension = "ambr"

def serialize(self, data: "SerializableData", **kwargs: Any) -> str:
"""
Returns the serialized form of 'data' to be compared
Expand All @@ -31,27 +33,23 @@ def serialize(self, data: "SerializableData", **kwargs: Any) -> str:
def delete_snapshots(
self, snapshot_location: str, snapshot_names: Set[str]
) -> None:
snapshot_fossil_to_update = DataSerializer.read_file(snapshot_location)
snapshot_collection_to_update = DataSerializer.read_file(snapshot_location)
for snapshot_name in snapshot_names:
snapshot_fossil_to_update.remove(snapshot_name)
snapshot_collection_to_update.remove(snapshot_name)

if snapshot_fossil_to_update.has_snapshots:
DataSerializer.write_file(snapshot_fossil_to_update)
if snapshot_collection_to_update.has_snapshots:
DataSerializer.write_file(snapshot_collection_to_update)
else:
Path(snapshot_location).unlink()

@property
def _file_extension(self) -> str:
return "ambr"

def _read_snapshot_fossil(self, snapshot_location: str) -> "SnapshotFossil":
def _read_snapshot_collection(self, snapshot_location: str) -> "SnapshotCollection":
return DataSerializer.read_file(snapshot_location)

@staticmethod
@lru_cache()
def __cacheable_read_snapshot(
snapshot_location: str, cache_key: str
) -> "SnapshotFossil":
) -> "SnapshotCollection":
return DataSerializer.read_file(snapshot_location)

def _read_snapshot_data_from_location(
Expand All @@ -63,8 +61,11 @@ def _read_snapshot_data_from_location(
snapshot = snapshots.get(snapshot_name)
return snapshot.data if snapshot else None

def _write_snapshot_fossil(self, *, snapshot_fossil: "SnapshotFossil") -> None:
DataSerializer.write_file(snapshot_fossil, merge=True)
@classmethod
def _write_snapshot_collection(
cls, *, snapshot_collection: "SnapshotCollection"
) -> None:
DataSerializer.write_file(snapshot_collection, merge=True)


__all__ = ["AmberSnapshotExtension", "DataSerializer"]
22 changes: 12 additions & 10 deletions src/syrupy/extensions/amber/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
from syrupy.data import (
Snapshot,
SnapshotFossil,
SnapshotCollection,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -70,18 +70,20 @@ class DataSerializer:
_marker_crn: str = "\r\n"

@classmethod
def write_file(cls, snapshot_fossil: "SnapshotFossil", merge: bool = False) -> None:
def write_file(
cls, snapshot_collection: "SnapshotCollection", merge: bool = False
) -> None:
"""
Writes the snapshot data into the snapshot file that can be read later.
"""
filepath = snapshot_fossil.location
filepath = snapshot_collection.location
if merge:
base_snapshot = cls.read_file(filepath)
base_snapshot.merge(snapshot_fossil)
snapshot_fossil = base_snapshot
base_snapshot.merge(snapshot_collection)
snapshot_collection = base_snapshot

with open(filepath, "w", encoding=TEXT_ENCODING, newline=None) as f:
for snapshot in sorted(snapshot_fossil, key=lambda s: s.name):
for snapshot in sorted(snapshot_collection, key=lambda s: s.name):
snapshot_data = str(snapshot.data)
if snapshot_data is not None:
f.write(f"{cls._marker_name} {snapshot.name}\n")
Expand All @@ -90,15 +92,15 @@ def write_file(cls, snapshot_fossil: "SnapshotFossil", merge: bool = False) -> N
f.write(f"\n{cls._marker_divider}\n")

@classmethod
def read_file(cls, filepath: str) -> "SnapshotFossil":
def read_file(cls, filepath: str) -> "SnapshotCollection":
"""
Read the raw snapshot data (str) from the snapshot file into a dict
of snapshot name to raw data. This does not attempt any deserialization
of the snapshot data.
"""
name_marker_len = len(cls._marker_name)
indent_len = len(cls._indent)
snapshot_fossil = SnapshotFossil(location=filepath)
snapshot_collection = SnapshotCollection(location=filepath)
try:
with open(filepath, "r", encoding=TEXT_ENCODING, newline=None) as f:
test_name = None
Expand All @@ -112,7 +114,7 @@ def read_file(cls, filepath: str) -> "SnapshotFossil":
if line.startswith(cls._indent):
snapshot_data += line[indent_len:]
elif line.startswith(cls._marker_divider) and snapshot_data:
snapshot_fossil.add(
snapshot_collection.add(
Snapshot(
name=test_name,
data=snapshot_data.rstrip(os.linesep),
Expand All @@ -121,7 +123,7 @@ def read_file(cls, filepath: str) -> "SnapshotFossil":
except FileNotFoundError:
pass

return snapshot_fossil
return snapshot_collection

@classmethod
def serialize(
Expand Down
Loading