From aa983119f21473b4aac8270b28200a41d925d0e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sok=C3=B3=C5=82?= Date: Tue, 17 Sep 2024 14:49:09 +0000 Subject: [PATCH 1/2] ENH: Implement `reshape` function --- sparse/mlir_backend/__init__.py | 2 + sparse/mlir_backend/_constructors.py | 2 +- sparse/mlir_backend/_ops.py | 53 ++++++++++++++- sparse/mlir_backend/tests/test_simple.py | 82 ++++++++++++++++++++++-- 4 files changed, 131 insertions(+), 8 deletions(-) diff --git a/sparse/mlir_backend/__init__.py b/sparse/mlir_backend/__init__.py index 192217a3..93eefe1d 100644 --- a/sparse/mlir_backend/__init__.py +++ b/sparse/mlir_backend/__init__.py @@ -15,10 +15,12 @@ ) from ._ops import ( add, + reshape, ) __all__ = [ "add", "asarray", "asdtype", + "reshape", ] diff --git a/sparse/mlir_backend/_constructors.py b/sparse/mlir_backend/_constructors.py index 9a975ffa..658a7291 100644 --- a/sparse/mlir_backend/_constructors.py +++ b/sparse/mlir_backend/_constructors.py @@ -239,7 +239,7 @@ def from_sps(cls, arr: np.ndarray) -> "Dense": return dense_instance - def to_sps(self, shape: tuple[int, ...]) -> sps.csr_array: + def to_sps(self, shape: tuple[int, ...]) -> np.ndarray: data = ranked_memref_to_numpy(self.data) return data.reshape(shape) diff --git a/sparse/mlir_backend/_ops.py b/sparse/mlir_backend/_ops.py index 351c16b6..06573fb2 100644 --- a/sparse/mlir_backend/_ops.py +++ b/sparse/mlir_backend/_ops.py @@ -5,10 +5,12 @@ from mlir import ir from mlir.dialects import arith, func, linalg, sparse_tensor, tensor +import numpy as np + from ._common import fn_cache -from ._constructors import Tensor +from ._constructors import Tensor, numpy_to_ranked_memref from ._core import CWD, DEBUG, MLIR_C_RUNNER_UTILS, ctx, pm -from ._dtypes import DType, FloatingDType +from ._dtypes import DType, FloatingDType, Index @fn_cache @@ -68,6 +70,31 @@ def add(a, b): return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS]) +# @fn_cache +def get_reshape_module( + a_tensor_type: ir.RankedTensorType, + shape_tensor_type: ir.RankedTensorType, + out_tensor_type: ir.RankedTensorType, +) -> ir.Module: + with ir.Location.unknown(ctx): + module = ir.Module.create() + + with ir.InsertionPoint(module.body): + + @func.FuncOp.from_py_func(a_tensor_type, shape_tensor_type) + def reshape(a, shape): + return tensor.reshape(out_tensor_type, a, shape) + + reshape.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + if DEBUG: + (CWD / "reshape_module.mlir").write_text(str(module)) + pm.run(module.operation) + if DEBUG: + (CWD / "reshape_module_opt.mlir").write_text(str(module)) + + return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS]) + + def add(x1: Tensor, x2: Tensor) -> Tensor: ret_obj = x1._format_class() out_tensor_type = x1._obj.get_tensor_definition(x1.shape) @@ -88,3 +115,25 @@ def add(x1: Tensor, x2: Tensor) -> Tensor: *x2._obj.to_module_arg(), ) return Tensor(ret_obj, shape=out_tensor_type.shape) + + +def reshape(x: Tensor, /, shape: tuple[int, ...]) -> Tensor: + ret_obj = x._format_class() + x_tensor_type = x._obj.get_tensor_definition(x.shape) + out_tensor_type = x._obj.get_tensor_definition(shape) + + with ir.Location.unknown(ctx): + shape_tensor_type = ir.RankedTensorType.get([len(shape)], Index.get_mlir_type()) + + # TODO: Add proper caching + reshape_module = get_reshape_module(x_tensor_type, shape_tensor_type, out_tensor_type) + + shape = np.array(shape) + reshape_module.invoke( + "reshape", + ctypes.pointer(ctypes.pointer(ret_obj)), + *x._obj.to_module_arg(), + ctypes.pointer(ctypes.pointer(numpy_to_ranked_memref(shape))), + ) + + return Tensor(ret_obj, shape=out_tensor_type.shape) diff --git a/sparse/mlir_backend/tests/test_simple.py b/sparse/mlir_backend/tests/test_simple.py index fb462fc5..fd5276ba 100644 --- a/sparse/mlir_backend/tests/test_simple.py +++ b/sparse/mlir_backend/tests/test_simple.py @@ -75,6 +75,15 @@ def sampler_real_floating(size: tuple[int, ...]): raise NotImplementedError(f"{dtype=} not yet supported.") +def get_exampe_csf_arrays(dtype: np.dtype) -> tuple: + pos_1 = np.array([0, 1, 3], dtype=np.int64) + crd_1 = np.array([1, 0, 1], dtype=np.int64) + pos_2 = np.array([0, 3, 5, 7], dtype=np.int64) + crd_2 = np.array([0, 1, 3, 0, 3, 0, 1], dtype=np.int64) + data = np.array([1, 2, 3, 4, 5, 6, 7], dtype=dtype) + return pos_1, crd_1, pos_2, crd_2, data + + @parametrize_dtypes @pytest.mark.parametrize("shape", [(100,), (10, 200), (5, 10, 20)]) def test_dense_format(dtype, shape): @@ -176,11 +185,7 @@ def test_add(rng, dtype): @parametrize_dtypes def test_csf_format(dtype): SHAPE = (2, 2, 4) - pos_1 = np.array([0, 1, 3], dtype=np.int64) - crd_1 = np.array([1, 0, 1], dtype=np.int64) - pos_2 = np.array([0, 3, 5, 7], dtype=np.int64) - crd_2 = np.array([0, 1, 3, 0, 3, 0, 1], dtype=np.int64) - data = np.array([1, 2, 3, 4, 5, 6, 7], dtype=dtype) + pos_1, crd_1, pos_2, crd_2, data = get_exampe_csf_arrays(dtype) csf = [pos_1, crd_1, pos_2, crd_2, data] csf_tensor = sparse.asarray(csf, shape=SHAPE, dtype=sparse.asdtype(dtype), format="csf") @@ -192,3 +197,70 @@ def test_csf_format(dtype): csf_2 = [pos_1, crd_1, pos_2, crd_2, data * 2] for actual, expected in zip(res_tensor, csf_2, strict=False): np.testing.assert_array_equal(actual, expected) + + +@parametrize_dtypes +def test_reshape(rng, dtype): + DENSITY = 0.5 + sampler = generate_sampler(dtype, rng) + + # CSR, CSC, COO + for shape, new_shape in [((100, 50), (25, 200)), ((80, 1), (8, 10))]: + for format in ["csr", "csc", "coo"]: + if format == "coo": + # NOTE: Blocked by https://github.com/llvm/llvm-project/pull/109135 + continue + if format == "csc": + # NOTE: Blocked by https://github.com/llvm/llvm-project/issues/109641 + continue + + arr = sps.random_array( + shape, density=DENSITY, format=format, dtype=dtype, random_state=rng, data_sampler=sampler + ) + if format == "coo": + arr.sum_duplicates() + + tensor = sparse.asarray(arr) + + actual = sparse.reshape(tensor, shape=new_shape).to_scipy_sparse() + expected = arr.todense().reshape(new_shape) + + np.testing.assert_array_equal(actual.todense(), expected) + + # CSF + csf_shape = (2, 2, 4) + for shape, new_shape, expected_arrs in [ + ( + csf_shape, + (4, 4, 1), + [ + np.array([0, 0, 3, 5, 7]), + np.array([0, 1, 3, 0, 3, 0, 1]), + np.array([0, 1, 2, 3, 4, 5, 6, 7]), + np.array([0, 0, 0, 0, 0, 0, 0]), + np.array([1, 2, 3, 4, 5, 6, 7]), + ], + ), + ( + csf_shape, + (2, 1, 8), + [ + np.array([0, 1, 2]), + np.array([0, 0]), + np.array([0, 3, 7]), + np.array([4, 5, 7, 0, 3, 4, 5]), + np.array([1, 2, 3, 4, 5, 6, 7]), + ], + ), + ]: + csf = get_exampe_csf_arrays(dtype) + csf_tensor = sparse.asarray(csf, shape=shape, dtype=sparse.asdtype(dtype), format="csf") + + result = sparse.reshape(csf_tensor, shape=new_shape).to_scipy_sparse() + + for actual, expected in zip(result, expected_arrs, strict=False): + np.testing.assert_array_equal(actual, expected) + + # DENSE + # NOTE: dense reshape is probably broken in MLIR + # dense = np.arange(math.prod(SHAPE), dtype=dtype).reshape(SHAPE) From 932547726bba74000e381d42ae1c67d7e661e597 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sok=C3=B3=C5=82?= Date: Mon, 23 Sep 2024 11:55:26 +0000 Subject: [PATCH 2/2] Apply review comments --- sparse/mlir_backend/_ops.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sparse/mlir_backend/_ops.py b/sparse/mlir_backend/_ops.py index 06573fb2..8e8a0ba7 100644 --- a/sparse/mlir_backend/_ops.py +++ b/sparse/mlir_backend/_ops.py @@ -70,7 +70,7 @@ def add(a, b): return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS]) -# @fn_cache +@fn_cache def get_reshape_module( a_tensor_type: ir.RankedTensorType, shape_tensor_type: ir.RankedTensorType, @@ -99,7 +99,6 @@ def add(x1: Tensor, x2: Tensor) -> Tensor: ret_obj = x1._format_class() out_tensor_type = x1._obj.get_tensor_definition(x1.shape) - # TODO: Add proper caching # TODO: Decide what will be the output tensor_type add_module = get_add_module( x1._obj.get_tensor_definition(x1.shape), @@ -125,7 +124,6 @@ def reshape(x: Tensor, /, shape: tuple[int, ...]) -> Tensor: with ir.Location.unknown(ctx): shape_tensor_type = ir.RankedTensorType.get([len(shape)], Index.get_mlir_type()) - # TODO: Add proper caching reshape_module = get_reshape_module(x_tensor_type, shape_tensor_type, out_tensor_type) shape = np.array(shape)