From 096f9b04e9abf666e442f6cfd87b18c9395bc157 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Sat, 1 Jul 2023 08:15:17 +0200 Subject: [PATCH] Fix upcast on in-place. (#598) --- sparse/_sparse_array.py | 11 +++++++++++ sparse/tests/test_elemwise.py | 7 +++++++ 2 files changed, 18 insertions(+) diff --git a/sparse/_sparse_array.py b/sparse/_sparse_array.py index 8bd5db62..546addc7 100644 --- a/sparse/_sparse_array.py +++ b/sparse/_sparse_array.py @@ -289,6 +289,17 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): ) if out is not None: + test_args = [ + np.empty(1, dtype=a.dtype) if hasattr(a, "dtype") else [a] + for a in inputs + ] + test_kwargs = kwargs.copy() + if method == "reduce": + test_kwargs["axis"] = None + test_out = tuple(np.empty(1, dtype=a.dtype) for a in out) + if len(test_out) == 1: + test_out = test_out[0] + getattr(ufunc, method)(*test_args, out=test_out, **test_kwargs) kwargs["dtype"] = out[0].dtype if method == "outer": diff --git a/sparse/tests/test_elemwise.py b/sparse/tests/test_elemwise.py index 746cbeb0..86cd6a05 100644 --- a/sparse/tests/test_elemwise.py +++ b/sparse/tests/test_elemwise.py @@ -720,3 +720,10 @@ def test_no_deprecation_warning(): a = np.array([1, 2]) s = sparse.COO(a, a, shape=(3,)) s == s + + +# Regression test for gh-587 +def test_no_out_upcast(): + a = sparse.COO([[0, 1], [0, 1]], [1, 1], shape=(2, 2)) + with pytest.raises(TypeError): + a *= 0.5