From db2a3e3969218a3ff8759ba36985c12233bed488 Mon Sep 17 00:00:00 2001 From: freol35241 Date: Fri, 29 May 2020 21:27:38 +0200 Subject: [PATCH] Introducing support for numpy arrays --- pytest_pinned.py | 28 +++++++++++++++++++++++----- requirements_dev.txt | 3 ++- test/test_pinpoint.py | 11 +++++++++-- 3 files changed, 34 insertions(+), 8 deletions(-) diff --git a/pytest_pinned.py b/pytest_pinned.py index bcdc50f..8f00787 100644 --- a/pytest_pinned.py +++ b/pytest_pinned.py @@ -48,8 +48,20 @@ def pytest_unconfigure(config): if EXPECTED_RESULTS: with path.open('w') as f: json.dump(EXPECTED_RESULTS, f, indent=4, sort_keys=True) + +def _is_numpy_array(obj): + import sys + np = sys.modules.get("numpy") + if np: + return isinstance(obj, np.ndarray) + return False class ExpectedResult: + + # Tell numpy to use our `__eq__` operator instead of its. + + __array_ufunc__ = None + __array_priority__ = 100 def __init__(self, expected, node, write): self._expected = expected @@ -88,21 +100,27 @@ def _get_expected(self, key): def __eq__(self, other): key = self._get_next_key() + # Special treatment of numpy arrays + import sys + np = sys.modules.get("numpy") + array = np and isinstance(other, np.ndarray) + if self._write: - self._expected[key] = other - return True + self._expected[key] = other if not array else other.tolist() expected = self._get_expected(key) res = self._compare_func(expected, other) + res = np.all(res) if array else res + self._reset_compare_func() return res - def __call__(self, *args, **kwargs): - def approx_wrapper(expected, value): + def __call__(self, *args, **kwargs): + def wrapper(expected, value): return eq(expected, pytest.approx(value, *args, **kwargs)) - self._compare_func = approx_wrapper + self._compare_func = wrapper return self def __repr__(self): diff --git a/requirements_dev.txt b/requirements_dev.txt index c832390..d8e9943 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -1,2 +1,3 @@ pytest-cov -pylint \ No newline at end of file +pylint +numpy \ No newline at end of file diff --git a/test/test_pinpoint.py b/test/test_pinpoint.py index f994f9e..3d18756 100644 --- a/test/test_pinpoint.py +++ b/test/test_pinpoint.py @@ -1,5 +1,6 @@ import pytest import json +import numpy as np def test_passing_with_pinned_results(testdir): """Testing that we fail tests prior to having any @@ -9,6 +10,7 @@ def test_passing_with_pinned_results(testdir): # create a temporary pytest test file testdir.makepyfile( """ + def test_str(pinned): assert pinned == "Hello World!" @@ -18,6 +20,11 @@ def test_scalar(pinned): def test_list(pinned): assert [[1,2,3]] == pinned + def test_array(pinned): + import sys + np = sys.modules.get('numpy') + assert np.ones((5,7)) == pinned + def test_dict(pinned): assert {'a': 1, 'b': 2, 'c': 3} == pinned @@ -32,11 +39,11 @@ def test_multiple(pinned): # Collect expected results result = testdir.runpytest('--pinned-rewrite') - result.assert_outcomes(passed=5) + result.assert_outcomes(passed=6) # Test again, this time ot should pass result = testdir.runpytest() - result.assert_outcomes(passed=5) + result.assert_outcomes(passed=6) def test_failing_with_pinned_results(testdir): """Testing that we fail tests prior to having any