Skip to content

Commit

Permalink
Feature/reflected ops noncommutative testing (#1)
Browse files Browse the repository at this point in the history
* np array solution

* cleanup

* np solution for division

* full reflected ops tests
  • Loading branch information
beckernick authored Aug 22, 2018
1 parent 7b0be07 commit 2745b29
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 15 deletions.
24 changes: 21 additions & 3 deletions pygdf/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,13 @@ def __sub__(self, other):
return self._binaryop(other, 'sub')

def __rsub__(self, other):
return self.__sub__(other)
if isinstance(other, (int, float,
np.int32, np.int64,
np.float32, np.float64)):
empty = np.empty(len(self))
empty.fill(other)
other = Series(empty)
return other.__sub__(self)

def __mul__(self, other):
return self._binaryop(other, 'mul')
Expand All @@ -300,13 +306,25 @@ def __floordiv__(self, other):
return self._binaryop(other, 'floordiv')

def __rfloordiv__(self, other):
return self.__floordiv__(other)
if isinstance(other, (int, float,
np.int32, np.int64,
np.float32, np.float64)):
empty = np.empty(len(self))
empty.fill(other)
other = Series(empty)
return other.__floordiv__(self)

def __truediv__(self, other):
return self._binaryop(other, 'truediv')

def __rtruediv__(self, other):
return self.__truediv__(other)
if isinstance(other, (int, float,
np.int32, np.int64,
np.float32, np.float64)):
empty = np.empty(len(self))
empty.fill(other)
other = Series(empty)
return other.__truediv__(self)

__div__ = __truediv__

Expand Down
16 changes: 4 additions & 12 deletions pygdf/tests/test_binops.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,20 +140,17 @@ def test_series_cmpop_mixed_dtype(cmpop, lhs_dtype, rhs_dtype):
cmpop(lhs, rhs))


_commut_ops = [
reflected_ops = [
lambda x: 1 + x,
lambda x: 1 * x,
]

_noncommut_ops = [
lambda x: 1 - x,
lambda x: 1 / x,
lambda x: 1 // x,
]


@pytest.mark.parametrize('func, dtype', list(product(_commut_ops, _dtypes)))
def test_commutative_reflected_op_scalar(func, dtype):
@pytest.mark.parametrize('func, dtype', list(product(reflected_ops, _dtypes)))
def test_reflected_ops_scalar(func, dtype):
import pandas as pd

# create random series
Expand All @@ -168,9 +165,4 @@ def test_commutative_reflected_op_scalar(func, dtype):
ps_result = func(random_series)

# verify
np.testing.assert_array_equal(ps_result, gs_result)


@pytest.mark.parametrize('func, dtype', list(product(_noncommut_ops, _dtypes)))
def test_noncommutative_reflected_ops(func, dtype):
pass
np.testing.assert_allclose(ps_result, gs_result)

0 comments on commit 2745b29

Please sign in to comment.