Commit c81d2e2

Merge pull request #776 from pydata/reshape-func

ENH: Implement `reshape` function

mtsokol committed Sep 23, 2024
2 parents 373f29f + 9325477
Showing 4 changed files with 130 additions and 9 deletions.
sparse/mlir_backend/__init__.py: 2 additions, 0 deletions
@@ -15,10 +15,12 @@
 )
 from ._ops import (
     add,
+    reshape,
 )
 
 __all__ = [
     "add",
     "asarray",
     "asdtype",
+    "reshape",
 ]
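With this export, `reshape` joins `add`, `asarray`, and `asdtype` in the backend's public namespace. A minimal import sketch (the package path is an assumption based on the repository layout above):

    # Minimal sketch; the import path is assumed from the file layout shown above.
    from sparse.mlir_backend import asarray, reshape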
sparse/mlir_backend/_constructors.py: 1 addition, 1 deletion
@@ -239,7 +239,7 @@ def from_sps(cls, arr: np.ndarray) -> "Dense":
 
         return dense_instance
 
-    def to_sps(self, shape: tuple[int, ...]) -> sps.csr_array:
+    def to_sps(self, shape: tuple[int, ...]) -> np.ndarray:
         data = ranked_memref_to_numpy(self.data)
         return data.reshape(shape)
 
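The annotation fix above matches what the dense format actually produces: `ranked_memref_to_numpy` yields a plain NumPy array, which `to_sps` reshapes to the logical shape. A hedged sketch via the public API (the module alias and the dense round-trip behavior are assumptions based on the tests):

    # Hedged sketch: a dense round trip should yield np.ndarray, per the new annotation.
    import numpy as np
    import sparse.mlir_backend as sparse  # hypothetical alias, as in the tests

    arr = np.arange(12, dtype=np.float64).reshape(3, 4)
    out = sparse.asarray(arr).to_scipy_sparse()  # dense path returns np.ndarray
    assert isinstance(out, np.ndarray)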
sparse/mlir_backend/_ops.py: 50 additions, 3 deletions
@@ -5,10 +5,12 @@
 from mlir import ir
 from mlir.dialects import arith, func, linalg, sparse_tensor, tensor
 
+import numpy as np
+
 from ._common import fn_cache
-from ._constructors import Tensor
+from ._constructors import Tensor, numpy_to_ranked_memref
 from ._core import CWD, DEBUG, MLIR_C_RUNNER_UTILS, ctx, pm
-from ._dtypes import DType, FloatingDType
+from ._dtypes import DType, FloatingDType, Index
 
 
 @fn_cache
@@ -68,11 +70,35 @@ def add(a, b):
     return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
 
 
+@fn_cache
+def get_reshape_module(
+    a_tensor_type: ir.RankedTensorType,
+    shape_tensor_type: ir.RankedTensorType,
+    out_tensor_type: ir.RankedTensorType,
+) -> ir.Module:
+    with ir.Location.unknown(ctx):
+        module = ir.Module.create()
+
+        with ir.InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(a_tensor_type, shape_tensor_type)
+            def reshape(a, shape):
+                return tensor.reshape(out_tensor_type, a, shape)
+
+            reshape.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+        if DEBUG:
+            (CWD / "reshape_module.mlir").write_text(str(module))
+        pm.run(module.operation)
+        if DEBUG:
+            (CWD / "reshape_module_opt.mlir").write_text(str(module))
+
+    return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
+
+
 def add(x1: Tensor, x2: Tensor) -> Tensor:
     ret_obj = x1._format_class()
     out_tensor_type = x1._obj.get_tensor_definition(x1.shape)
 
     # TODO: Add proper caching
     # TODO: Decide what will be the output tensor_type
     add_module = get_add_module(
         x1._obj.get_tensor_definition(x1.shape),
@@ -88,3 +114,24 @@ def add(x1: Tensor, x2: Tensor) -> Tensor:
         *x2._obj.to_module_arg(),
     )
     return Tensor(ret_obj, shape=out_tensor_type.shape)
+
+
+def reshape(x: Tensor, /, shape: tuple[int, ...]) -> Tensor:
+    ret_obj = x._format_class()
+    x_tensor_type = x._obj.get_tensor_definition(x.shape)
+    out_tensor_type = x._obj.get_tensor_definition(shape)
+
+    with ir.Location.unknown(ctx):
+        shape_tensor_type = ir.RankedTensorType.get([len(shape)], Index.get_mlir_type())
+
+    reshape_module = get_reshape_module(x_tensor_type, shape_tensor_type, out_tensor_type)
+
+    shape = np.array(shape)
+    reshape_module.invoke(
+        "reshape",
+        ctypes.pointer(ctypes.pointer(ret_obj)),
+        *x._obj.to_module_arg(),
+        ctypes.pointer(ctypes.pointer(numpy_to_ranked_memref(shape))),
+    )
+
+    return Tensor(ret_obj, shape=out_tensor_type.shape)
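Taken together: `reshape` compiles a module specialized to the input, shape, and output tensor types on first use (`fn_cache` memoizes subsequent calls), then invokes it, passing the target shape as the 1-D index tensor that the builder declares for `tensor.reshape`. A hedged end-to-end sketch mirroring the tests (the `sparse` alias for the backend module is an assumption):

    # End-to-end sketch; assumes the backend is imported under the alias
    # `sparse`, as the test module below appears to do.
    import numpy as np
    import scipy.sparse as sps
    import sparse.mlir_backend as sparse

    arr = sps.random_array((80, 1), density=0.5, format="csr", dtype=np.float64)
    tensor = sparse.asarray(arr)

    reshaped = sparse.reshape(tensor, shape=(8, 10))  # compiled once, then cached
    result = reshaped.to_scipy_sparse()

    np.testing.assert_array_equal(result.todense(), arr.todense().reshape(8, 10))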
sparse/mlir_backend/tests/test_simple.py: 77 additions, 5 deletions
@@ -75,6 +75,15 @@ def sampler_real_floating(size: tuple[int, ...]):
     raise NotImplementedError(f"{dtype=} not yet supported.")
 
 
+def get_example_csf_arrays(dtype: np.dtype) -> tuple:
+    pos_1 = np.array([0, 1, 3], dtype=np.int64)
+    crd_1 = np.array([1, 0, 1], dtype=np.int64)
+    pos_2 = np.array([0, 3, 5, 7], dtype=np.int64)
+    crd_2 = np.array([0, 1, 3, 0, 3, 0, 1], dtype=np.int64)
+    data = np.array([1, 2, 3, 4, 5, 6, 7], dtype=dtype)
+    return pos_1, crd_1, pos_2, crd_2, data
+
+
 @parametrize_dtypes
 @pytest.mark.parametrize("shape", [(100,), (10, 200), (5, 10, 20)])
 def test_dense_format(dtype, shape):
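The helper above encodes a 2×2×4 CSF (compressed sparse fiber) tensor: the outer dimension is stored densely, and each inner level is compressed by a (pos, crd) pair. A hypothetical decoder, not part of the diff, makes the layout concrete:

    # Hypothetical decoder for the arrays returned by get_example_csf_arrays;
    # it expands them into the dense (2, 2, 4) array they represent.
    import numpy as np

    def csf_to_dense(pos_1, crd_1, pos_2, crd_2, data, shape=(2, 2, 4)):
        out = np.zeros(shape, dtype=data.dtype)
        for i in range(shape[0]):                        # level 0: dense
            for p in range(pos_1[i], pos_1[i + 1]):      # level 1: compressed
                j = crd_1[p]
                for q in range(pos_2[p], pos_2[p + 1]):  # level 2: compressed
                    out[i, j, crd_2[q]] = data[q]
        return out

Decoding places the values 1 through 7 at (0,1,0), (0,1,1), (0,1,3), (1,0,0), (1,0,3), (1,1,0), and (1,1,1).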
@@ -176,11 +185,7 @@ def test_add(rng, dtype):
 @parametrize_dtypes
 def test_csf_format(dtype):
     SHAPE = (2, 2, 4)
-    pos_1 = np.array([0, 1, 3], dtype=np.int64)
-    crd_1 = np.array([1, 0, 1], dtype=np.int64)
-    pos_2 = np.array([0, 3, 5, 7], dtype=np.int64)
-    crd_2 = np.array([0, 1, 3, 0, 3, 0, 1], dtype=np.int64)
-    data = np.array([1, 2, 3, 4, 5, 6, 7], dtype=dtype)
+    pos_1, crd_1, pos_2, crd_2, data = get_example_csf_arrays(dtype)
     csf = [pos_1, crd_1, pos_2, crd_2, data]
 
     csf_tensor = sparse.asarray(csf, shape=SHAPE, dtype=sparse.asdtype(dtype), format="csf")
@@ -192,3 +197,70 @@ def test_csf_format(dtype):
     csf_2 = [pos_1, crd_1, pos_2, crd_2, data * 2]
     for actual, expected in zip(res_tensor, csf_2, strict=False):
         np.testing.assert_array_equal(actual, expected)
+
+
+@parametrize_dtypes
+def test_reshape(rng, dtype):
+    DENSITY = 0.5
+    sampler = generate_sampler(dtype, rng)
+
+    # CSR, CSC, COO
+    for shape, new_shape in [((100, 50), (25, 200)), ((80, 1), (8, 10))]:
+        for format in ["csr", "csc", "coo"]:
+            if format == "coo":
+                # NOTE: Blocked by https://github.com/llvm/llvm-project/pull/109135
+                continue
+            if format == "csc":
+                # NOTE: Blocked by https://github.com/llvm/llvm-project/issues/109641
+                continue
+
+            arr = sps.random_array(
+                shape, density=DENSITY, format=format, dtype=dtype, random_state=rng, data_sampler=sampler
+            )
+            if format == "coo":
+                arr.sum_duplicates()
+
+            tensor = sparse.asarray(arr)
+
+            actual = sparse.reshape(tensor, shape=new_shape).to_scipy_sparse()
+            expected = arr.todense().reshape(new_shape)
+
+            np.testing.assert_array_equal(actual.todense(), expected)
+
+    # CSF
+    csf_shape = (2, 2, 4)
+    for shape, new_shape, expected_arrs in [
+        (
+            csf_shape,
+            (4, 4, 1),
+            [
+                np.array([0, 0, 3, 5, 7]),
+                np.array([0, 1, 3, 0, 3, 0, 1]),
+                np.array([0, 1, 2, 3, 4, 5, 6, 7]),
+                np.array([0, 0, 0, 0, 0, 0, 0]),
+                np.array([1, 2, 3, 4, 5, 6, 7]),
+            ],
+        ),
+        (
+            csf_shape,
+            (2, 1, 8),
+            [
+                np.array([0, 1, 2]),
+                np.array([0, 0]),
+                np.array([0, 3, 7]),
+                np.array([4, 5, 7, 0, 3, 4, 5]),
+                np.array([1, 2, 3, 4, 5, 6, 7]),
+            ],
+        ),
+    ]:
+        csf = get_example_csf_arrays(dtype)
+        csf_tensor = sparse.asarray(csf, shape=shape, dtype=sparse.asdtype(dtype), format="csf")
+
+        result = sparse.reshape(csf_tensor, shape=new_shape).to_scipy_sparse()
+
+        for actual, expected in zip(result, expected_arrs, strict=False):
+            np.testing.assert_array_equal(actual, expected)
+
+    # DENSE
+    # NOTE: dense reshape is probably broken in MLIR
+    # dense = np.arange(math.prod(SHAPE), dtype=dtype).reshape(SHAPE)
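As a cross-check on the hard-coded CSF expectations, reshaping the dense equivalent with NumPy reproduces the level-1 structure asserted for the (4, 4, 1) case. A hypothetical verification sketch, not part of the diff:

    # Hypothetical cross-check of the (4, 4, 1) expectations above, using the
    # dense equivalent of get_example_csf_arrays (see the decoder sketch earlier).
    import numpy as np

    dense = np.zeros((2, 2, 4))
    dense[0, 1, [0, 1, 3]] = [1, 2, 3]
    dense[1, 0, [0, 3]] = [4, 5]
    dense[1, 1, [0, 1]] = [6, 7]

    rows, cols, _ = np.nonzero(dense.reshape(4, 4, 1))
    # Rows 0..3 own 0, 3, 2, 2 nonzeros; cumulative counts match pos_1[1:]
    # of the expected [0, 0, 3, 5, 7].
    assert np.bincount(rows, minlength=4).cumsum().tolist() == [0, 3, 5, 7]
    # crd_1 lists the level-1 coordinate of each nonzero in row-major order.
    assert cols.tolist() == [0, 1, 3, 0, 3, 0, 1]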
