From ff7a2a2949469f3742936bdc66763d863854340c Mon Sep 17 00:00:00 2001 From: Noah Negin-Ulster Date: Thu, 1 Dec 2022 17:38:16 -0500 Subject: [PATCH] refactor: do not instantiate extension with test_location --- src/syrupy/assertion.py | 3 +- src/syrupy/extensions/base.py | 77 ++++++++++++++--------------------- src/syrupy/session.py | 26 ++++++++---- 3 files changed, 51 insertions(+), 55 deletions(-) diff --git a/src/syrupy/assertion.py b/src/syrupy/assertion.py index e345341e..c32fa695 100644 --- a/src/syrupy/assertion.py +++ b/src/syrupy/assertion.py @@ -94,7 +94,7 @@ def __post_init__(self) -> None: def __init_extension( self, extension_class: Type["AbstractSyrupyExtension"] ) -> "AbstractSyrupyExtension": - return extension_class(test_location=self.test_location) + return extension_class() @property def extension(self) -> "AbstractSyrupyExtension": @@ -268,6 +268,7 @@ def _assert(self, data: "SerializableData") -> bool: if not matches and self.update_snapshots: self.session.queue_snapshot_write( extension=self.extension, + test_location=self.test_location, data=serialized_data, index=self.index, ) diff --git a/src/syrupy/extensions/base.py b/src/syrupy/extensions/base.py index 1067e7fd..306f4579 100644 --- a/src/syrupy/extensions/base.py +++ b/src/syrupy/extensions/base.py @@ -3,7 +3,6 @@ ABC, abstractmethod, ) -from collections import defaultdict from difflib import ndiff from gettext import gettext from itertools import zip_longest @@ -11,7 +10,6 @@ from typing import ( TYPE_CHECKING, Callable, - DefaultDict, Dict, Iterator, List, @@ -79,11 +77,6 @@ def serialize( class SnapshotCollectionStorage(ABC): _file_extension = "" - @property - @abstractmethod - def test_location(self) -> "PyTestLocation": - raise NotImplementedError - @classmethod def get_snapshot_name( cls, *, test_location: "PyTestLocation", index: "SnapshotIndex" = 0 @@ -160,63 +153,58 @@ def read_snapshot( def write_snapshot( cls, *, - test_location: "PyTestLocation", - snapshots: List[Tuple["SerializedData", "SnapshotIndex"]], + snapshot_location: str, + snapshots: List[Tuple["SerializedData", "PyTestLocation", "SnapshotIndex"]], ) -> None: """ This method is _final_, do not override. You can override `_write_snapshot_collection` in a subclass to change behaviour. """ + if not snapshots: + return + # First we group by location since it'll let us batch by file on disk. # Not as useful for single file snapshots, but useful for the standard # Amber extension. - locations: DefaultDict[str, List["Snapshot"]] = defaultdict(list) - for data, index in snapshots: - location = cls.get_location(test_location=test_location, index=index) + snapshot_collection = SnapshotCollection(location=snapshot_location) + for data, test_location, index in snapshots: snapshot_name = cls.get_snapshot_name( test_location=test_location, index=index ) - locations[location].append(Snapshot(name=snapshot_name, data=data)) - - # Ensures the folder path for the snapshot file exists. - try: - Path( - cls.get_location(test_location=test_location, index=index) - ).parent.mkdir(parents=True) - except FileExistsError: - pass + snapshot = Snapshot(name=snapshot_name, data=data) + snapshot_collection.add(snapshot) - for location, location_snapshots in locations.items(): - snapshot_collection = SnapshotCollection(location=location) - - if not test_location.matches_snapshot_location(location): + if not test_location.matches_snapshot_location(snapshot_location): warning_msg = gettext( "{line_end}Can not relate snapshot location '{}' " "to the test location.{line_end}" "Consider adding '{}' to the generated location." ).format( - location, + snapshot_location, test_location.basename, line_end="\n", ) warnings.warn(warning_msg) - for snapshot in location_snapshots: - snapshot_collection.add(snapshot) - - if not test_location.matches_snapshot_name(snapshot.name): - warning_msg = gettext( - "{line_end}Can not relate snapshot name '{}' " - "to the test location.{line_end}" - "Consider adding '{}' to the generated name." - ).format( - snapshot.name, - test_location.testname, - line_end="\n", - ) - warnings.warn(warning_msg) + if not test_location.matches_snapshot_name(snapshot.name): + warning_msg = gettext( + "{line_end}Can not relate snapshot name '{}' " + "to the test location.{line_end}" + "Consider adding '{}' to the generated name." + ).format( + snapshot.name, + test_location.testname, + line_end="\n", + ) + warnings.warn(warning_msg) + + # Ensures the folder path for the snapshot file exists. + try: + Path(snapshot_location).parent.mkdir(parents=True) + except FileExistsError: + pass - cls._write_snapshot_collection(snapshot_collection=snapshot_collection) + cls._write_snapshot_collection(snapshot_collection=snapshot_collection) @abstractmethod def delete_snapshots( @@ -432,9 +420,4 @@ def matches( class AbstractSyrupyExtension( SnapshotSerializer, SnapshotCollectionStorage, SnapshotReporter, SnapshotComparator ): - def __init__(self, test_location: "PyTestLocation"): - self._test_location = test_location - - @property - def test_location(self) -> "PyTestLocation": - return self._test_location + pass diff --git a/src/syrupy/session.py b/src/syrupy/session.py index 15f535cb..5f74833c 100644 --- a/src/syrupy/session.py +++ b/src/syrupy/session.py @@ -14,10 +14,13 @@ Optional, Set, Tuple, + Type, ) import pytest +from syrupy.location import PyTestLocation + from .constants import EXIT_STATUS_FAIL_UNUSED from .data import SnapshotCollections from .report import SnapshotReport @@ -49,24 +52,33 @@ class SnapshotSession: ) _queued_snapshot_writes: Dict[ - "AbstractSyrupyExtension", List[Tuple["SerializedData", "SnapshotIndex"]] + Tuple[Type["AbstractSyrupyExtension"], str], + List[Tuple["SerializedData", "PyTestLocation", "SnapshotIndex"]], ] = field(default_factory=dict) def queue_snapshot_write( self, extension: "AbstractSyrupyExtension", + test_location: "PyTestLocation", data: "SerializedData", index: "SnapshotIndex", ) -> None: - queue = self._queued_snapshot_writes.get(extension, []) - queue.append((data, index)) - self._queued_snapshot_writes[extension] = queue + snapshot_location = extension.get_location( + test_location=test_location, index=index + ) + key = (extension.__class__, snapshot_location) + queue = self._queued_snapshot_writes.get(key, []) + queue.append((data, test_location, index)) + self._queued_snapshot_writes[key] = queue def flush_snapshot_write_queue(self) -> None: - for extension, queued_write in self._queued_snapshot_writes.items(): + for ( + extension_class, + snapshot_location, + ), queued_write in self._queued_snapshot_writes.items(): if queued_write: - extension.write_snapshot( - test_location=extension.test_location, snapshots=queued_write + extension_class.write_snapshot( + snapshot_location=snapshot_location, snapshots=queued_write ) self._queued_snapshot_writes = {}