Skip to content

Commit

Permalink
Merge pull request #37 from sidlak-c137/condense
Browse files Browse the repository at this point in the history
  • Loading branch information
yzh119 committed Sep 21, 2022
2 parents 88b74d4 + f5deea4 commit 1b05efb
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 18 deletions.
2 changes: 1 addition & 1 deletion python/tvm/sparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@
"""Python-interface for Sparse-TIR"""

from .lower import lower_sparse_iter, lower_sparse_buffer
from .format_rewrite import FormatRewriteRule, column_part_hyb
from .format_rewrite import FormatRewriteRule, column_part_hyb, condense
from .specialize import specialize_buffer
6 changes: 6 additions & 0 deletions python/tvm/sparse/format_rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,9 @@ def column_part_hyb(num_rows, num_cols, indptr_nd, indices_nd, num_col_parts, bu
return _ffi_api.ColumnPartHyb(
num_rows, num_cols, indptr_nd, indices_nd, num_col_parts, buckets # type: ignore
)


def condense(indptr_nd, indices_nd, bucket_size):
    """Condense a CSR matrix: group every `bucket_size` consecutive rows into
    one block and keep, per block, the sorted unique column indices touched
    by any row in that block (computed by the C++ `ConDense` FFI routine).

    Parameters
    ----------
    indptr_nd : tvm.nd.NDArray
        CSR row-pointer array (int32, on CPU).
    indices_nd : tvm.nd.NDArray
        CSR column-index array (int32, on CPU).
    bucket_size : int
        Number of consecutive rows condensed into one block.

    Returns
    -------
    The condensed (indptr, indices) pair of int32 CPU NDArrays.
    """
    return _ffi_api.ConDense(
        indptr_nd, indices_nd, bucket_size  # type: ignore
    )
45 changes: 45 additions & 0 deletions src/sparse/format.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,52 @@ Array<Array<Array<NDArray>>> ColumnPartHyb(int num_rows, int num_cols, NDArray i
return {row_indices_nd, col_indices_nd, mask_nd};
}

/*!
 * \brief Condense a CSR matrix: group every `block_size` consecutive rows
 *        into one block and record, per block, the sorted unique column
 *        indices touched by any row in that block.
 * \param indptr CSR row-pointer array (int32, CPU), length num_rows + 1.
 * \param indices CSR column-index array (int32, CPU), length nnz.
 * \param block_size Number of consecutive rows condensed into one block.
 * \return {condensed indptr, condensed indices} as int32 CPU NDArrays.
 */
Array<NDArray> ConDense(NDArray indptr, NDArray indices, int block_size) {
  // Check inputs. Verify the dtype *code* as well as the bit width: the
  // original check accepted any 32-bit dtype (e.g. float32).
  CHECK_GT(block_size, 0) << "block_size must be positive, got " << block_size << ".";
  CHECK_EQ(indptr->dtype.code, kDLInt) << "Only support int32 index data type, got type code "
                                       << int(indptr->dtype.code) << " for indptr.";
  CHECK_EQ(indices->dtype.code, kDLInt) << "Only support int32 index data type, got type code "
                                        << int(indices->dtype.code) << " for indices.";
  CHECK_EQ(indptr->dtype.bits, 32) << "Only support int32 index data type, got "
                                   << int(indptr->dtype.bits) << " bits for indptr.";
  CHECK_EQ(indices->dtype.bits, 32) << "Only support int32 index data type, got "
                                    << int(indices->dtype.bits) << " bits for indices.";
  CHECK_EQ(indptr->device.device_type, kDLCPU) << "Only support ConDense conversion on CPU.";
  CHECK_EQ(indices->device.device_type, kDLCPU) << "Only support ConDense conversion on CPU.";
  // Get data from NDArrays
  int* indptr_data = static_cast<int*>(indptr->data);
  int* indices_data = static_cast<int*>(indices->data);
  // Set up return values
  int n = indptr->shape[0] - 1;                        // number of rows
  int num_blocks = (n + block_size - 1) / block_size;  // ceil(n / block_size)
  std::vector<int> ret_indptr(num_blocks + 1);
  std::vector<int> ret_indices;
  // nnz is an upper bound on the condensed index count; reserve once.
  ret_indices.reserve(static_cast<size_t>(indices->shape[0]));
  ret_indptr[0] = 0;
  // Condense matrix block by block.
  for (int block_id = 0; block_id < num_blocks; block_id++) {
    int curr_block = block_id * block_size;
    int next_block = curr_block + block_size;
    // Nonzeros of rows [curr_block, min(next_block, n)) occupy [lo, hi) in `indices`.
    int lo = indptr_data[curr_block];
    int hi = next_block > n ? indptr_data[n] : indptr_data[next_block];
    // Collect the sorted unique column indices in [lo, hi).
    std::vector<int> unique(indices_data + lo, indices_data + hi);
    std::sort(unique.begin(), unique.end());
    unique.erase(std::unique(unique.begin(), unique.end()), unique.end());
    ret_indices.insert(ret_indices.end(), unique.begin(), unique.end());
    ret_indptr[block_id + 1] = ret_indptr[block_id] + static_cast<int>(unique.size());
  }

  // Convert to NDArray
  int ret_indptr_size = static_cast<int>(ret_indptr.size());
  int ret_indices_size = static_cast<int>(ret_indices.size());
  NDArray indptr_nd = NDArray::Empty({ret_indptr_size}, {kDLInt, 32, 1}, {kDLCPU, 0});
  NDArray indices_nd = NDArray::Empty({ret_indices_size}, {kDLInt, 32, 1}, {kDLCPU, 0});
  indptr_nd.CopyFromBytes(ret_indptr.data(), ret_indptr_size * sizeof(int));
  indices_nd.CopyFromBytes(ret_indices.data(), ret_indices_size * sizeof(int));
  return {indptr_nd, indices_nd};
}

namespace sparse {
// Register the format-conversion routines in TVM's global function table so
// the Python side can reach them (e.g. `_ffi_api.ConDense` calls
// "tir.sparse.ConDense").
TVM_REGISTER_GLOBAL("tir.sparse.ColumnPartHyb").set_body_typed(ColumnPartHyb);
TVM_REGISTER_GLOBAL("tir.sparse.ConDense").set_body_typed(ConDense);
} // namespace sparse
} // namespace tvm
84 changes: 67 additions & 17 deletions tests/python/sparsetir/test_format_rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import tvm
import dgl
import numpy as np
from tvm.sparse import FormatRewriteRule, column_part_hyb
from tvm.sparse import FormatRewriteRule, column_part_hyb, condense
from sparse_tir_scripts import csrmm
from sparse_tir_format_rewrite_scripts import (
bsr,
Expand Down Expand Up @@ -71,7 +71,8 @@ def test_csrmm_bsr_rewrite():
mod = tvm.IRModule.from_expr(csrmm)
mod = tvm.tir.transform.SparseFormatRewrite(rewrites)(mod)
print(mod["main"].script())
tvm.ir.assert_structural_equal(mod["main"], bsr_rewrite_with_preprocess, True)
tvm.ir.assert_structural_equal(
mod["main"], bsr_rewrite_with_preprocess, True)


def test_csrmm_ell_rewrite():
Expand All @@ -92,7 +93,8 @@ def test_csrmm_ell_rewrite():
)
mod = tvm.IRModule.from_expr(csrmm)
mod = tvm.tir.transform.SparseFormatRewrite(rewrites)(mod)
tvm.ir.assert_structural_equal(mod["main"], ell_rewrite_with_preprocess, True)
tvm.ir.assert_structural_equal(
mod["main"], ell_rewrite_with_preprocess, True)


def csrpadding_inv_index_map(i, jo, ji):
Expand Down Expand Up @@ -120,7 +122,8 @@ def test_csrmm_padding_rewrite():
]
mod = tvm.IRModule.from_expr(csrmm)
mod = tvm.tir.transform.SparseFormatRewrite(rewrites)(mod)
tvm.ir.assert_structural_equal(mod["main"], padding_rewrite_with_preprocess, True)
tvm.ir.assert_structural_equal(
mod["main"], padding_rewrite_with_preprocess, True)


def scipy_column_part_hyb(g, column_part, bucket_sizes):
Expand All @@ -131,7 +134,7 @@ def scipy_column_part_hyb(g, column_part, bucket_sizes):
nnz = mat.nnz
per_column_part_size = (n + column_part - 1) // column_part
sub_mats = [
mat[:, i * per_column_part_size : (i + 1) * per_column_part_size]
mat[:, i * per_column_part_size: (i + 1) * per_column_part_size]
for i in range(column_part)
]

Expand All @@ -143,9 +146,11 @@ def scipy_column_part_hyb(g, column_part, bucket_sizes):
in_degrees = sub_mat.indptr[1:] - sub_mat.indptr[:-1]
for i, bucket_size in enumerate(bucket_sizes[:-1]):
last_bucket_size = 0 if i == 0 else bucket_sizes[i - 1]
ell_n.append(int(((in_degrees > last_bucket_size) & (in_degrees <= bucket_size)).sum()))
ell_n.append(int(((in_degrees > last_bucket_size)
& (in_degrees <= bucket_size)).sum()))
sub_indegrees = in_degrees[in_degrees > bucket_sizes[-2]]
ell_n.append(int(((sub_indegrees + bucket_sizes[-1] - 1) // bucket_sizes[-1]).sum()))
ell_n.append(
int(((sub_indegrees + bucket_sizes[-1] - 1) // bucket_sizes[-1]).sum()))

ell_rows = []
ell_indices = []
Expand All @@ -157,24 +162,29 @@ def scipy_column_part_hyb(g, column_part, bucket_sizes):
for i, bucket_size in enumerate(bucket_sizes[:-1]):
last_bucket_size = 0 if i == 0 else bucket_sizes[i - 1]
ell_rows.append(
((in_degrees > last_bucket_size) & (in_degrees <= bucket_size)).nonzero()[0]
((in_degrees > last_bucket_size) &
(in_degrees <= bucket_size)).nonzero()[0]
)
ell_rows.append((in_degrees > bucket_sizes[-2]).nonzero()[0])

for i, bucket_size in enumerate(bucket_sizes[:-1]):
indices = np.zeros(
(ell_n[partition * len(bucket_sizes) + i], bucket_size), dtype=np.int32
(ell_n[partition * len(bucket_sizes) + i],
bucket_size), dtype=np.int32
)
for j, row_id in enumerate(ell_rows[partition * len(bucket_sizes) + i]):
row = sub_mat[row_id]
indices[j, : row.nnz] = row.indices + partition * per_column_part_size
indices[j, : row.nnz] = row.indices + \
partition * per_column_part_size
ell_indices.append(indices)

# split rows for the last bucket
indices = np.zeros(
(ell_n[(partition + 1) * len(bucket_sizes) - 1], bucket_sizes[-1]), dtype=np.int32
(ell_n[(partition + 1) * len(bucket_sizes) - 1],
bucket_sizes[-1]), dtype=np.int32
)
new_rows = np.zeros((ell_n[(partition + 1) * len(bucket_sizes) - 1],), dtype=np.int32)
new_rows = np.zeros(
(ell_n[(partition + 1) * len(bucket_sizes) - 1],), dtype=np.int32)
bucket_size = bucket_sizes[-1]
i = 0
for row_id in ell_rows[-1]:
Expand All @@ -183,11 +193,12 @@ def scipy_column_part_hyb(g, column_part, bucket_sizes):
if start_offset + bucket_size >= row.nnz:
# last bucket
indices[i, : row.nnz - start_offset] = (
row.indices[start_offset:] + partition * per_column_part_size
row.indices[start_offset:] +
partition * per_column_part_size
)
else:
indices[i] = (
row.indices[start_offset : start_offset + bucket_size]
row.indices[start_offset: start_offset + bucket_size]
+ partition * per_column_part_size
)
new_rows[i] = row_id
Expand All @@ -202,7 +213,7 @@ def scipy_column_part_hyb(g, column_part, bucket_sizes):
def test_column_part_hyb():
g = dgl.rand_graph(1000, 10000).int()
column_parts = 4
buckets = [1, 2, 4, 8]
buckets = [1, 2, 4]
indptr, indices, _ = g.adj_sparse("csc")
indptr_nd = tvm.nd.array(indptr.numpy(), device=tvm.cpu())
indices_nd = tvm.nd.array(indices.numpy(), device=tvm.cpu())
Expand All @@ -211,7 +222,8 @@ def test_column_part_hyb():
g.num_dst_nodes(), g.num_src_nodes(), indptr_nd, indices_nd, column_parts, buckets
)
# compute indices with scipy
row_indices_scipy, col_indices_scipy = scipy_column_part_hyb(g, column_parts, buckets)
row_indices_scipy, col_indices_scipy = scipy_column_part_hyb(
g, column_parts, buckets)

for part_id in range(column_parts):
for bucket_id, _ in enumerate(buckets):
Expand All @@ -225,9 +237,47 @@ def test_column_part_hyb():
)


def condense_py(indptr, indices, block_size):
    """NumPy reference implementation of CSR condense.

    Groups every `block_size` consecutive rows into one block and records,
    per block, the sorted unique column indices its rows touch.

    Returns the condensed (indptr, indices) pair as NumPy arrays.
    """
    num_rows = len(indptr) - 1
    num_blocks = (num_rows + block_size - 1) // block_size
    offsets = [0]
    per_block_cols = []
    for blk in range(num_blocks):
        row_lo = blk * block_size
        row_hi = min(row_lo + block_size, num_rows)
        # np.unique yields the sorted distinct column ids of this row block.
        cols = np.unique(indices[indptr[row_lo]:indptr[row_hi]])
        offsets.append(offsets[-1] + len(cols))
        per_block_cols.append(cols)
    return np.array(offsets), np.concatenate(per_block_cols)


def test_condense():
    """Check that the C++ ConDense FFI routine agrees with the NumPy
    reference implementation on a random graph for several block sizes."""
    graph = dgl.rand_graph(1000, 10000).int()
    csc_indptr, csc_indices, _ = graph.adj_sparse("csc")
    indptr_np = csc_indptr.numpy()
    indices_np = csc_indices.numpy()
    indptr_dev = tvm.nd.array(indptr_np, device=tvm.cpu())
    indices_dev = tvm.nd.array(indices_np, device=tvm.cpu())
    for bucket_size in [1, 2, 4, 8]:
        # Built-in C++ function vs. the Python reference version.
        got_indptr, got_indices = condense(indptr_dev, indices_dev, bucket_size)
        want_indptr, want_indices = condense_py(indptr_np, indices_np, bucket_size)
        assert np.array_equal(got_indptr.numpy(), want_indptr)
        assert np.array_equal(got_indices.numpy(), want_indices)


if __name__ == "__main__":
# test_csrmm_bsr_rewrite()
# test_csrmm_ell_rewrite()
# test_csrmm_padding_rewrite()
test_column_part_hyb()
# test_condense()
test_condense()

0 comments on commit 1b05efb

Please sign in to comment.