Skip to content

Commit

Permalink
[TOPI] sparse_dense Op sparse_data input added (#6889)
Browse files Browse the repository at this point in the history
* [TOPI] sparse_dense op sparse_data input added

* [1] clang issue resolved

* [2] python format resolved

* [3] lint error resolved

* [4] Review comments handled

* [5] Lint error resolved

* [6] Review comments handled

* [7] Review comments handled

* [8] Review comments handled
  • Loading branch information
ANSHUMAN TRIPATHY committed Dec 15, 2020
1 parent 054466b commit 862655b
Show file tree
Hide file tree
Showing 11 changed files with 321 additions and 64 deletions.
10 changes: 9 additions & 1 deletion include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -938,7 +938,15 @@ struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {

/*! \brief Attributes for sparse_dense operator */
struct SparseDenseAttrs : public tvm::AttrsNode<SparseDenseAttrs> {
TVM_DECLARE_ATTRS(SparseDenseAttrs, "relay.attrs.SparseDenseAttrs") {}
bool sparse_lhs;

TVM_DECLARE_ATTRS(SparseDenseAttrs, "relay.attrs.SparseDenseAttrs") {
TVM_ATTR_FIELD(sparse_lhs)
.set_default(false)
.describe(
"Indicate whether sparse matrix is multiplied on the right or the left. If true, then "
"the operation is S * D^T (D dense, S sparse). If false, the operation is D * S^T");
}
};

/*! \brief Attributes for sparse_transpose operator */
Expand Down
29 changes: 23 additions & 6 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,28 +926,45 @@ def _impl(inputs, attr, params, mod):

data = inputs[3]

# By default, in tensorflow the first input ,i.e., data is sparse
sparse_lhs = True

# If both are true means First input was dense and second was sparse
if attr.get("adjoint_a") and attr.get("adjoint_b"):
sparse_lhs = False

rows = [x[0] for x in indices_tensor]
cols = [x[1] for x in indices_tensor]

# Create scipy sparse Tensor(CSR)
weight_sp = csr_matrix(
(values_tensor, (rows, cols)), shape=tuple(dense_shape_tensor.tolist())
)
weight_sp = csr_matrix(weight_sp.transpose())

if sparse_lhs:
data = _op.transpose(data)
else:
weight_sp = csr_matrix(weight_sp.transpose())

weight_data = _expr.const(weight_sp.data, weight_sp.data.dtype)
weight_indptrs = _expr.const(weight_sp.indptr, weight_sp.indptr.dtype)
weight_indices = _expr.const(weight_sp.indices, weight_sp.indices.dtype)

ret = _op.nn.sparse_dense(data, [weight_data, weight_indices, weight_indptrs])
ret = _op.nn.sparse_dense(data, [weight_data, weight_indices, weight_indptrs], sparse_lhs)

# If both are true means First input was dense and second was sparse
# TODO(ANSHUMAN87): Support other adjoint option too
if attr.get("adjoint_a") and attr.get("adjoint_b"):
if not sparse_lhs:
ret = _op.transpose(ret)
else:

# Case 1. If both are true means first input was dense and second was sparse
# Case 2. If both are false means first input was sparse and second was dense
# TODO(ANSHUMAN87): Support other adjoint option too
if not (
(attr.get("adjoint_a") and attr.get("adjoint_b"))
or ((not attr.get("adjoint_a")) and (not attr.get("adjoint_b")))
):
raise tvm.error.OpAttributeUnImplemented(
"Only tf.sparse.sparse_dense_matmul() with adjoint_a=True and adjoint_b=True"
"or with adjoint_a=False and adjoint_b=False"
" is supported, but adjoint_a={} and adjoint_b={} was supplied.".format(
attr.get("adjoint_a"), attr.get("adjoint_b")
)
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def compute_fifo_buffer(attrs, inputs, out_type):
@reg.register_compute("nn.sparse_dense")
def compute_sparse_dense(attrs, inputs, out_type):
"""Compute definition of sparse_dense"""
return [topi.nn.sparse_dense(inputs[0], inputs[1], inputs[2], inputs[3])]
return [topi.nn.sparse_dense(inputs[0], inputs[1], inputs[2], inputs[3], attrs["sparse_lhs"])]


reg.register_strategy("nn.sparse_dense", strategy.sparse_dense_strategy)
Expand Down
44 changes: 31 additions & 13 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1993,17 +1993,27 @@ def batch_matmul(x, y):
return _make.batch_matmul(x, y)


def sparse_dense(data, weight):
# pylint: disable=no-else-return,inconsistent-return-statements
def sparse_dense(dense_mat, sparse_mat, sparse_lhs=False):
r"""
Computes the matrix multiplication of `data` and `weight`, where `data` is
a dense matrix and `weight` is a sparse (either BSR or CSR) namedtuple with
Computes the matrix multiplication of `dense_mat` and `sparse_mat`, where `dense_mat` is
a dense matrix and `sparse_mat` is a sparse (either BSR or CSR) namedtuple with
fields `data`, `indices`, and `indptr`.
.. math::
\if sparse_lhs=False:
.. math::
\mbox{sparse_dense}(dense_mat, sparse_mat)[m, n]
= \mbox{matmul}(D, \mbox{as_dense}(S)^T)[m, n]
\mbox{sparse_dense}(data, weight)[m, n] = \mbox{matmul}(x, \mbox{as_dense}(weight)^T)[m, n]
\if sparse_lhs=True:
.. math::
where `as_dense` returns dense equivalent of the given sparse matrix.
\mbox{sparse_dense}(dense_mat, sparse_mat)[m, n]
= \mbox{matmul}(\mbox{as_dense}(S), (D)^T)[m, n]
where `as_dense` returns dense equivalent of the given S(sparse matrix)
while performing matmul with given D(dense matrix).
See
https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.html
Expand All @@ -2013,20 +2023,28 @@ def sparse_dense(data, weight):
Parameters
----------
data : tvm.relay.Expr
The input data for the matrix multiplication
dense_mat : tvm.relay.Expr
The input dense matrix for the matrix multiplication
weight : Union[namedtuple, Tuple[ndarray, ndarray, ndarray]].
The sparse weight matrix for the matrix multiplication.
sparse_mat : Union[namedtuple, Tuple[ndarray, ndarray, ndarray]].
The input sparse matrix for the matrix multiplication.
sparse_lhs : bool, optional
Indicates whether lhs or rhs matrix is sparse. Default value is False.
Returns
-------
result: tvm.relay.Expr
The computed result.
"""
if hasattr(weight, "indices"):
return _make.sparse_dense(data, weight.data, weight.indices, weight.indptr)
return _make.sparse_dense(data, weight[0], weight[1], weight[2])
if hasattr(sparse_mat, "indices"):
return _make.sparse_dense(
dense_mat, sparse_mat.data, sparse_mat.indices, sparse_mat.indptr, sparse_lhs
)
else:
return _make.sparse_dense(
dense_mat, sparse_mat[0], sparse_mat[1], sparse_mat[2], sparse_lhs
)


def sparse_transpose(x):
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@ def wrap_compute_sparse_dense(topi_compute):
"""wrap sparse dense topi compute"""

def _compute_sparse_dense(attrs, inputs, out_type):
return [topi_compute(inputs[0], inputs[1], inputs[2], inputs[3])]
return [topi_compute(inputs[0], inputs[1], inputs[2], inputs[3], attrs["sparse_lhs"])]

return _compute_sparse_dense

Expand Down
6 changes: 4 additions & 2 deletions python/tvm/topi/cuda/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,11 @@ def schedule_sparse_dense(outs):
# pylint:disable=invalid-name
s = te.create_schedule([x.op for x in outs])

# TODO(ANSHUMAN87): Add for sparse_dense_bsrmm_v1 also
def _callback(op):
if op.tag == "sparse_dense_bsrmm":
if op.tag == "sparse_dense_bsrmm_v2":
y_bsrmm = op.input_tensors[0]
assert y_bsrmm.op.tag == "sparse_dense_bsrmm_block"
assert y_bsrmm.op.tag == "sparse_dense_bsrmm_block_v2"
out = s.outputs[0].output(0)

if op not in s.outputs:
Expand Down Expand Up @@ -362,6 +363,7 @@ def _alter_sparse_dense_layout(_attrs, inputs, _tinfos, _out_type):
sparse_dense implementation for one that operates on a padded matrix. We
also padd the matrix.
"""
# TODO(ANSHUMAN87): Handle for sparse_lhs case too
if (
isinstance(inputs[1], relay.Constant)
and isinstance(inputs[2], relay.Constant)
Expand Down
140 changes: 132 additions & 8 deletions python/tvm/topi/nn/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ..utils import get_const_tuple


def sparse_dense(data, weight_data, weight_indices, weight_indptr):
def sparse_dense_v2(data, weight_data, weight_indices, weight_indptr):
"""
Computes sparse-dense matrix multiplication of `data` and
`(weight_data, weight_indices, weight_indptr).T`
Expand Down Expand Up @@ -52,13 +52,104 @@ def sparse_dense(data, weight_data, weight_indices, weight_indptr):
"""
assert len(weight_data.shape) in (1, 3)
if len(weight_data.shape) == 1:
func = _sparse_dense_csrmm
func = _sparse_dense_csrmm_v2
if len(weight_data.shape) == 3:
func = _sparse_dense_bsrmm
func = _sparse_dense_bsrmm_v2
return func(data, weight_data, weight_indices, weight_indptr)


def _sparse_dense_csrmm(data, weight_data, weight_indices, weight_indptr):
def sparse_dense_v1(data_data, data_indices, data_indptr, weight):
"""
Computes sparse-dense matrix multiplication of
`(data_data, data_indices, data_indptr)` and `weight.T`
Parameters
----------
data_data:
1-D with shape [nnz] (CSR) or
3-D with shape [num_blocks, bs_r, bs_c] (BSR)
data_indices:
1-D with shape [nnz] (CSR) or
1-D with shape [num_blocks] (BSR)
data_indptr:
1-D with shape [M + 1] (CSR) or
1-D with shape [(M + 1) // bs_r] (BSR)
weight:
2-D with shape [N, K], float32
Returns
-------
output : tvm.te.Tensor
2-D with shape [M, N]
"""
assert len(data_data.shape) in (1, 3)
if len(data_data.shape) == 1:
func = _sparse_dense_csrmm_v1
if len(data_data.shape) == 3:
func = _sparse_dense_bsrmm_v1
return func(data_data, data_indices, data_indptr, weight)


# pylint: disable=no-else-return,inconsistent-return-statements
def sparse_dense(dense_data, sparse_data, sparse_indices, sparse_indptr, sparse_lhs=False):
"""
Computes sparse-dense matrix multiplication of `data` and
`(weight_data, weight_indices, weight_indptr).T`, if sparse_lhs=False
or
Computes sparse-dense matrix multiplication of
`(data_data, data_indices, data_indptr)` and `weight.T`, if sparse_lhs=True
Parameters
----------
dense_data : tvm.te.Tensor
2-D with shape [M, K], float32
sparse_data : tvm.te.Tensor
1-D with shape [nnz] (CSR) or
3-D with shape [num_blocks, bs_r, bs_c] (BSR)
sparse_indices : tvm.te.Tensor
1-D with shape [nnz] (CSR) or
1-D with shape [num_blocks] (BSR)
sparse_indptr : tvm.te.Tensor
1-D with shape [N + 1] (CSR) or
1-D with shape [(N + 1) // bs_r] (BSR)
sparse_lhs : bool, optional
Indicates whether lhs or rhs matrix is sparse. Default value is False.
Returns
-------
output : tvm.te.Tensor
2-D with shape [M, N]
"""
if sparse_lhs:
return sparse_dense_v1(sparse_data, sparse_indices, sparse_indptr, dense_data)
else:
return sparse_dense_v2(dense_data, sparse_data, sparse_indices, sparse_indptr)


def _sparse_dense_csrmm_v1(data_data, data_indices, data_indptr, weight):
oshape = (get_const_tuple(data_indptr.shape)[0] - 1, get_const_tuple(weight.shape)[0])

def f(row, i):
row_start = data_indptr[row]
row_end = data_indptr[row + 1]
row_elems = row_end - row_start
elem_idx = te.reduce_axis((0, row_elems), name="elem_idx")
elem = row_start + elem_idx
a_val = data_data[elem]
weight_val = weight[i, data_indices[elem]]
return te.sum(a_val * weight_val, axis=elem_idx)

return te.compute(oshape, f, tag="sparse_dense_csrmm_v1")


def _sparse_dense_csrmm_v2(data, weight_data, weight_indices, weight_indptr):
oshape = (get_const_tuple(data.shape)[0], get_const_tuple(weight_indptr.shape)[0] - 1)

def f(i, row):
Expand All @@ -71,10 +162,41 @@ def f(i, row):
weight_val = data[i, weight_indices[elem]]
return te.sum(a_val * weight_val, axis=elem_idx)

return te.compute(oshape, f, tag="sparse_dense_csrmm")
return te.compute(oshape, f, tag="sparse_dense_csrmm_v2")


def _sparse_dense_bsrmm(data, weight_data, weight_indices, weight_indptr):
def _sparse_dense_bsrmm_v1(data_data, data_indices, data_indptr, weight):
(m, _) = get_const_tuple(weight.shape)
(_, bs_r, bs_c) = get_const_tuple(data_data.shape)
(num_blocks_plus_1,) = get_const_tuple(data_indptr.shape)
num_blocks = num_blocks_plus_1 - 1

def _compute_block(nb_j, j, i):
row_start = data_indptr[nb_j]
row_end = data_indptr[nb_j + 1]
row_elems = row_end - row_start
elem_idx = te.reduce_axis((0, row_elems), name="elem_idx")
block_offset = row_start + elem_idx
c = te.reduce_axis((0, bs_c), name="c")
block_j = data_indices[block_offset]
block_ij_val = data_data[block_offset][j][c]
x_val = weight[i, bs_c * block_j + c]
return te.sum(block_ij_val * x_val, axis=[elem_idx, c])

idxd = tvm.tir.indexdiv
idxm = tvm.tir.indexmod

bsrmm_block = te.compute(
(num_blocks, bs_r, m), _compute_block, tag="sparse_dense_bsrmm_block_v1"
)
return te.compute(
(num_blocks * bs_r, m),
lambda m, n: bsrmm_block[idxd(m, bs_r), idxm(m, bs_r), n],
tag="sparse_dense_bsrmm_v1",
)


def _sparse_dense_bsrmm_v2(data, weight_data, weight_indices, weight_indptr):
(m, _) = get_const_tuple(data.shape)
(_, bs_r, bs_c) = get_const_tuple(weight_data.shape)
(num_blocks_plus_1,) = get_const_tuple(weight_indptr.shape)
Expand All @@ -95,11 +217,13 @@ def _compute_block(i, nb_j, j):
idxd = tvm.tir.indexdiv
idxm = tvm.tir.indexmod

bsrmm_block = te.compute((m, num_blocks, bs_r), _compute_block, tag="sparse_dense_bsrmm_block")
bsrmm_block = te.compute(
(m, num_blocks, bs_r), _compute_block, tag="sparse_dense_bsrmm_block_v2"
)
return te.compute(
(m, num_blocks * bs_r),
lambda m, n: bsrmm_block[m, idxd(n, bs_r), idxm(n, bs_r)],
tag="sparse_dense_bsrmm",
tag="sparse_dense_bsrmm_v2",
)


Expand Down
Loading

0 comments on commit 862655b

Please sign in to comment.