Skip to content

Commit

Permalink
Introducing support for numpy arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
freol35241 committed May 29, 2020
1 parent 6cc7b7a commit db2a3e3
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 8 deletions.
28 changes: 23 additions & 5 deletions pytest_pinned.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,20 @@ def pytest_unconfigure(config):
if EXPECTED_RESULTS:
with path.open('w') as f:
json.dump(EXPECTED_RESULTS, f, indent=4, sort_keys=True)

def _is_numpy_array(obj):
import sys
np = sys.modules.get("numpy")
if np:
return isinstance(obj, np.ndarray)
return False

class ExpectedResult:

# Tell numpy to use our `__eq__` operator instead of its.

__array_ufunc__ = None
__array_priority__ = 100

def __init__(self, expected, node, write):
self._expected = expected
Expand Down Expand Up @@ -88,21 +100,27 @@ def _get_expected(self, key):
def __eq__(self, other):
key = self._get_next_key()

# Special treatment of numpy arrays
import sys
np = sys.modules.get("numpy")
array = np and isinstance(other, np.ndarray)

if self._write:
self._expected[key] = other
return True
self._expected[key] = other if not array else other.tolist()

expected = self._get_expected(key)
res = self._compare_func(expected, other)

res = np.all(res) if array else res

self._reset_compare_func()

return res

def __call__(self, *args, **kwargs):
def approx_wrapper(expected, value):
def __call__(self, *args, **kwargs):
def wrapper(expected, value):
return eq(expected, pytest.approx(value, *args, **kwargs))
self._compare_func = approx_wrapper
self._compare_func = wrapper
return self

def __repr__(self):
Expand Down
3 changes: 2 additions & 1 deletion requirements_dev.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pytest-cov
pylint
pylint
numpy
11 changes: 9 additions & 2 deletions test/test_pinpoint.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import json
import numpy as np

def test_passing_with_pinned_results(testdir):
"""Testing that we fail tests prior to having any
Expand All @@ -9,6 +10,7 @@ def test_passing_with_pinned_results(testdir):
# create a temporary pytest test file
testdir.makepyfile(
"""
def test_str(pinned):
assert pinned == "Hello World!"
Expand All @@ -18,6 +20,11 @@ def test_scalar(pinned):
def test_list(pinned):
assert [[1,2,3]] == pinned
def test_array(pinned):
import sys
np = sys.modules.get('numpy')
assert np.ones((5,7)) == pinned
def test_dict(pinned):
assert {'a': 1, 'b': 2, 'c': 3} == pinned
Expand All @@ -32,11 +39,11 @@ def test_multiple(pinned):

# Collect expected results
result = testdir.runpytest('--pinned-rewrite')
result.assert_outcomes(passed=5)
result.assert_outcomes(passed=6)

# Test again, this time ot should pass
result = testdir.runpytest()
result.assert_outcomes(passed=5)
result.assert_outcomes(passed=6)

def test_failing_with_pinned_results(testdir):
"""Testing that we fail tests prior to having any
Expand Down

0 comments on commit db2a3e3

Please sign in to comment.