Skip to content

Commit

Permalink
refactor: do not instantiate extension with test_location
Browse files Browse the repository at this point in the history
  • Loading branch information
Noah Negin-Ulster committed Dec 1, 2022
1 parent ae07435 commit ff7a2a2
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 55 deletions.
3 changes: 2 additions & 1 deletion src/syrupy/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __post_init__(self) -> None:
def __init_extension(
self, extension_class: Type["AbstractSyrupyExtension"]
) -> "AbstractSyrupyExtension":
return extension_class(test_location=self.test_location)
return extension_class()

@property
def extension(self) -> "AbstractSyrupyExtension":
Expand Down Expand Up @@ -268,6 +268,7 @@ def _assert(self, data: "SerializableData") -> bool:
if not matches and self.update_snapshots:
self.session.queue_snapshot_write(
extension=self.extension,
test_location=self.test_location,
data=serialized_data,
index=self.index,
)
Expand Down
77 changes: 30 additions & 47 deletions src/syrupy/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,13 @@
ABC,
abstractmethod,
)
from collections import defaultdict
from difflib import ndiff
from gettext import gettext
from itertools import zip_longest
from pathlib import Path
from typing import (
TYPE_CHECKING,
Callable,
DefaultDict,
Dict,
Iterator,
List,
Expand Down Expand Up @@ -79,11 +77,6 @@ def serialize(
class SnapshotCollectionStorage(ABC):
_file_extension = ""

@property
@abstractmethod
def test_location(self) -> "PyTestLocation":
raise NotImplementedError

@classmethod
def get_snapshot_name(
cls, *, test_location: "PyTestLocation", index: "SnapshotIndex" = 0
Expand Down Expand Up @@ -160,63 +153,58 @@ def read_snapshot(
def write_snapshot(
cls,
*,
test_location: "PyTestLocation",
snapshots: List[Tuple["SerializedData", "SnapshotIndex"]],
snapshot_location: str,
snapshots: List[Tuple["SerializedData", "PyTestLocation", "SnapshotIndex"]],
) -> None:
"""
This method is _final_, do not override. You can override
`_write_snapshot_collection` in a subclass to change behaviour.
"""
if not snapshots:
return

# First we group by location since it'll let us batch by file on disk.
# Not as useful for single file snapshots, but useful for the standard
# Amber extension.
locations: DefaultDict[str, List["Snapshot"]] = defaultdict(list)
for data, index in snapshots:
location = cls.get_location(test_location=test_location, index=index)
snapshot_collection = SnapshotCollection(location=snapshot_location)
for data, test_location, index in snapshots:
snapshot_name = cls.get_snapshot_name(
test_location=test_location, index=index
)
locations[location].append(Snapshot(name=snapshot_name, data=data))

# Ensures the folder path for the snapshot file exists.
try:
Path(
cls.get_location(test_location=test_location, index=index)
).parent.mkdir(parents=True)
except FileExistsError:
pass
snapshot = Snapshot(name=snapshot_name, data=data)
snapshot_collection.add(snapshot)

for location, location_snapshots in locations.items():
snapshot_collection = SnapshotCollection(location=location)

if not test_location.matches_snapshot_location(location):
if not test_location.matches_snapshot_location(snapshot_location):
warning_msg = gettext(
"{line_end}Can not relate snapshot location '{}' "
"to the test location.{line_end}"
"Consider adding '{}' to the generated location."
).format(
location,
snapshot_location,
test_location.basename,
line_end="\n",
)
warnings.warn(warning_msg)

for snapshot in location_snapshots:
snapshot_collection.add(snapshot)

if not test_location.matches_snapshot_name(snapshot.name):
warning_msg = gettext(
"{line_end}Can not relate snapshot name '{}' "
"to the test location.{line_end}"
"Consider adding '{}' to the generated name."
).format(
snapshot.name,
test_location.testname,
line_end="\n",
)
warnings.warn(warning_msg)
if not test_location.matches_snapshot_name(snapshot.name):
warning_msg = gettext(
"{line_end}Can not relate snapshot name '{}' "
"to the test location.{line_end}"
"Consider adding '{}' to the generated name."
).format(
snapshot.name,
test_location.testname,
line_end="\n",
)
warnings.warn(warning_msg)

# Ensures the folder path for the snapshot file exists.
try:
Path(snapshot_location).parent.mkdir(parents=True)
except FileExistsError:
pass

cls._write_snapshot_collection(snapshot_collection=snapshot_collection)
cls._write_snapshot_collection(snapshot_collection=snapshot_collection)

@abstractmethod
def delete_snapshots(
Expand Down Expand Up @@ -432,9 +420,4 @@ def matches(
class AbstractSyrupyExtension(
SnapshotSerializer, SnapshotCollectionStorage, SnapshotReporter, SnapshotComparator
):
def __init__(self, test_location: "PyTestLocation"):
self._test_location = test_location

@property
def test_location(self) -> "PyTestLocation":
return self._test_location
pass
26 changes: 19 additions & 7 deletions src/syrupy/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@
Optional,
Set,
Tuple,
Type,
)

import pytest

from syrupy.location import PyTestLocation

from .constants import EXIT_STATUS_FAIL_UNUSED
from .data import SnapshotCollections
from .report import SnapshotReport
Expand Down Expand Up @@ -49,24 +52,33 @@ class SnapshotSession:
)

_queued_snapshot_writes: Dict[
"AbstractSyrupyExtension", List[Tuple["SerializedData", "SnapshotIndex"]]
Tuple[Type["AbstractSyrupyExtension"], str],
List[Tuple["SerializedData", "PyTestLocation", "SnapshotIndex"]],
] = field(default_factory=dict)

def queue_snapshot_write(
self,
extension: "AbstractSyrupyExtension",
test_location: "PyTestLocation",
data: "SerializedData",
index: "SnapshotIndex",
) -> None:
queue = self._queued_snapshot_writes.get(extension, [])
queue.append((data, index))
self._queued_snapshot_writes[extension] = queue
snapshot_location = extension.get_location(
test_location=test_location, index=index
)
key = (extension.__class__, snapshot_location)
queue = self._queued_snapshot_writes.get(key, [])
queue.append((data, test_location, index))
self._queued_snapshot_writes[key] = queue

def flush_snapshot_write_queue(self) -> None:
for extension, queued_write in self._queued_snapshot_writes.items():
for (
extension_class,
snapshot_location,
), queued_write in self._queued_snapshot_writes.items():
if queued_write:
extension.write_snapshot(
test_location=extension.test_location, snapshots=queued_write
extension_class.write_snapshot(
snapshot_location=snapshot_location, snapshots=queued_write
)
self._queued_snapshot_writes = {}

Expand Down

1 comment on commit ff7a2a2

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Benchmark suite Current: ff7a2a2 Previous: 23cca84 Ratio
benchmarks/test_1000x.py::test_1000x_reads 0.6663152805060267 iter/sec (stddev: 0.0768578591462337) 0.6754078596653935 iter/sec (stddev: 0.06391877117699159) 1.01
benchmarks/test_1000x.py::test_1000x_writes 0.5935792286248747 iter/sec (stddev: 0.24067621781047613) 0.6345993135561808 iter/sec (stddev: 0.23174880874067105) 1.07
benchmarks/test_standard.py::test_standard 0.6123782109210533 iter/sec (stddev: 0.07218052425756886) 0.6315599143065584 iter/sec (stddev: 0.0923523543680502) 1.03

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.