Skip to content

Commit

Permalink
Merge add scaled identity (M <- aI + bM) feature interface and impls …
Browse files Browse the repository at this point in the history
…for Csr and Dense

In the case of Csr, this checks that the sparsity pattern has all diagonal entries.
This was done to avoid the need for reallocation and ensure user-wrapped data can be used.

Pull request #942
  • Loading branch information
Slaedr authored Feb 19, 2022
2 parents fc6322b + 534d1b8 commit 6ee7bfe
Show file tree
Hide file tree
Showing 39 changed files with 856 additions and 7 deletions.
60 changes: 60 additions & 0 deletions common/cuda_hip/matrix/csr_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -842,4 +842,64 @@ __global__
}


template <typename ValueType, typename IndexType>
__global__ __launch_bounds__(default_block_size) void add_scaled_identity(
const ValueType* const __restrict__ alpha,
const ValueType* const __restrict__ beta, const IndexType num_rows,
const IndexType* const __restrict__ row_ptrs,
const IndexType* const __restrict__ col_idxs,
ValueType* const __restrict__ values)
{
constexpr int warp_size = config::warp_size;
auto tile_grp =
group::tiled_partition<warp_size>(group::this_thread_block());
const auto warpid = thread::get_subwarp_id_flat<warp_size, IndexType>();
const auto num_warps = thread::get_subwarp_num_flat<warp_size, IndexType>();
if (warpid < num_rows) {
const auto tid_in_warp = tile_grp.thread_rank();
const IndexType row_start = row_ptrs[warpid];
const IndexType num_nz = row_ptrs[warpid + 1] - row_start;
for (IndexType iz = tid_in_warp; iz < num_nz; iz += warp_size) {
values[iz + row_start] *= beta[0];
if (col_idxs[iz + row_start] == warpid) {
values[iz + row_start] += alpha[0];
}
}
}
}


template <typename IndexType>
__global__ __launch_bounds__(default_block_size) void check_diagonal_entries(
const IndexType num_min_rows_cols,
const IndexType* const __restrict__ row_ptrs,
const IndexType* const __restrict__ col_idxs,
bool* const __restrict__ has_all_diags)
{
constexpr int warp_size = config::warp_size;
auto tile_grp =
group::tiled_partition<warp_size>(group::this_thread_block());
const auto row = thread::get_subwarp_id_flat<warp_size, IndexType>();
if (row < num_min_rows_cols) {
const auto tid_in_warp = tile_grp.thread_rank();
const IndexType row_start = row_ptrs[row];
const IndexType num_nz = row_ptrs[row + 1] - row_start;
bool row_has_diag_local{false};
for (IndexType iz = tid_in_warp; iz < num_nz; iz += warp_size) {
if (col_idxs[iz + row_start] == row) {
row_has_diag_local = true;
break;
}
}
auto row_has_diag = static_cast<bool>(tile_grp.any(row_has_diag_local));
if (!row_has_diag) {
if (tile_grp.thread_rank() == 0) {
*has_all_diags = false;
}
return;
}
}
}


} // namespace kernel
10 changes: 10 additions & 0 deletions common/unified/base/kernel_launch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,16 @@ struct to_device_type_impl {

template <typename ValueType>
struct to_device_type_impl<matrix::Dense<ValueType>*&> {
using type = matrix_accessor<device_type<ValueType>>;
static type map_to_device(matrix::Dense<ValueType>* mtx)
{
return to_device_type_impl<
matrix::Dense<ValueType>* const&>::map_to_device(mtx);
}
};

template <typename ValueType>
struct to_device_type_impl<matrix::Dense<ValueType>* const&> {
using type = matrix_accessor<device_type<ValueType>>;
static type map_to_device(matrix::Dense<ValueType>* mtx)
{
Expand Down
22 changes: 22 additions & 0 deletions common/unified/matrix/dense_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,28 @@ void get_imag(std::shared_ptr<const DefaultExecutor> exec,
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_GET_IMAG_KERNEL);


template <typename ValueType, typename ScalarType>
void add_scaled_identity(std::shared_ptr<const DefaultExecutor> exec,
const matrix::Dense<ScalarType>* const alpha,
const matrix::Dense<ScalarType>* const beta,
matrix::Dense<ValueType>* const mtx)
{
run_kernel(
exec,
[] GKO_KERNEL(auto row, auto col, auto alpha, auto beta, auto mtx) {
mtx(row, col) = beta[0] * mtx(row, col);
if (row == col) {
mtx(row, row) += alpha[0];
}
},
mtx->get_size(), alpha->get_const_values(), beta->get_const_values(),
mtx);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE(
GKO_DECLARE_DENSE_ADD_SCALED_IDENTITY_KERNEL);


} // namespace dense
} // namespace GKO_DEVICE_NAMESPACE
} // namespace kernels
Expand Down
3 changes: 3 additions & 0 deletions core/device_hooks/common_kernels.inc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ GKO_STUB_VALUE_AND_SCALAR_TYPE(GKO_DECLARE_DENSE_SCALE_KERNEL);
GKO_STUB_VALUE_AND_SCALAR_TYPE(GKO_DECLARE_DENSE_INV_SCALE_KERNEL);
GKO_STUB_VALUE_AND_SCALAR_TYPE(GKO_DECLARE_DENSE_ADD_SCALED_KERNEL);
GKO_STUB_VALUE_AND_SCALAR_TYPE(GKO_DECLARE_DENSE_SUB_SCALED_KERNEL);
GKO_STUB_VALUE_AND_SCALAR_TYPE(GKO_DECLARE_DENSE_ADD_SCALED_IDENTITY_KERNEL);
GKO_STUB_VALUE_TYPE(GKO_DECLARE_DENSE_ADD_SCALED_DIAG_KERNEL);
GKO_STUB_VALUE_TYPE(GKO_DECLARE_DENSE_SUB_SCALED_DIAG_KERNEL);
GKO_STUB_VALUE_TYPE(GKO_DECLARE_DENSE_COMPUTE_DOT_KERNEL);
Expand Down Expand Up @@ -494,6 +495,8 @@ GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_IS_SORTED_BY_COLUMN_INDEX);
GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_EXTRACT_DIAGONAL);
GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_CALC_NNZ_PER_ROW_IN_SPAN_KERNEL);
GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_COMPUTE_SUB_MATRIX_KERNEL);
GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_CHECK_DIAGONAL_ENTRIES_EXIST);
GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_ADD_SCALED_IDENTITY_KERNEL);

template <typename ValueType, typename IndexType>
GKO_DECLARE_CSR_SCALE_KERNEL(ValueType, IndexType)
Expand Down
20 changes: 20 additions & 0 deletions core/matrix/csr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ GKO_REGISTER_OPERATION(outplace_absolute_array,
components::outplace_absolute_array);
GKO_REGISTER_OPERATION(scale, csr::scale);
GKO_REGISTER_OPERATION(inv_scale, csr::inv_scale);
GKO_REGISTER_OPERATION(add_scaled_identity, csr::add_scaled_identity);
GKO_REGISTER_OPERATION(check_diagonal_entries,
csr::check_diagonal_entries_exist);


} // anonymous namespace
Expand Down Expand Up @@ -672,6 +675,23 @@ void Csr<ValueType, IndexType>::inv_scale_impl(const LinOp* alpha)
}


template <typename ValueType, typename IndexType>
void Csr<ValueType, IndexType>::add_scaled_identity_impl(const LinOp* const a,
const LinOp* const b)
{
bool has_diags{false};
this->get_executor()->run(
csr::make_check_diagonal_entries(this, has_diags));
if (!has_diags) {
GKO_UNSUPPORTED_MATRIX_PROPERTY(
"The matrix has one or more structurally zero diagonal entries!");
}
this->get_executor()->run(csr::make_add_scaled_identity(
make_temporary_conversion<ValueType>(a).get(),
make_temporary_conversion<ValueType>(b).get(), this));
}


#define GKO_DECLARE_CSR_MATRIX(ValueType, IndexType) \
class Csr<ValueType, IndexType>
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_MATRIX);
Expand Down
17 changes: 16 additions & 1 deletion core/matrix/csr_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,17 @@ namespace kernels {
const matrix::Dense<ValueType>* alpha, \
matrix::Csr<ValueType, IndexType>* to_scale)

#define GKO_DECLARE_CSR_CHECK_DIAGONAL_ENTRIES_EXIST(ValueType, IndexType) \
void check_diagonal_entries_exist( \
std::shared_ptr<const DefaultExecutor> exec, \
const matrix::Csr<ValueType, IndexType>* mtx, bool& has_all_diags)

#define GKO_DECLARE_CSR_ADD_SCALED_IDENTITY_KERNEL(ValueType, IndexType) \
void add_scaled_identity(std::shared_ptr<const DefaultExecutor> exec, \
const matrix::Dense<ValueType>* alpha, \
const matrix::Dense<ValueType>* beta, \
matrix::Csr<ValueType, IndexType>* mtx)


#define GKO_DECLARE_ALL_AS_TEMPLATES \
template <typename ValueType, typename IndexType> \
Expand Down Expand Up @@ -244,7 +255,11 @@ namespace kernels {
template <typename ValueType, typename IndexType> \
GKO_DECLARE_CSR_SCALE_KERNEL(ValueType, IndexType); \
template <typename ValueType, typename IndexType> \
GKO_DECLARE_CSR_INV_SCALE_KERNEL(ValueType, IndexType)
GKO_DECLARE_CSR_INV_SCALE_KERNEL(ValueType, IndexType); \
template <typename ValueType, typename IndexType> \
GKO_DECLARE_CSR_CHECK_DIAGONAL_ENTRIES_EXIST(ValueType, IndexType); \
template <typename ValueType, typename IndexType> \
GKO_DECLARE_CSR_ADD_SCALED_IDENTITY_KERNEL(ValueType, IndexType)


GKO_DECLARE_FOR_ALL_EXECUTOR_NAMESPACES(csr, GKO_DECLARE_ALL_AS_TEMPLATES);
Expand Down
14 changes: 14 additions & 0 deletions core/matrix/dense.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ GKO_REGISTER_OPERATION(outplace_absolute_dense, dense::outplace_absolute_dense);
GKO_REGISTER_OPERATION(make_complex, dense::make_complex);
GKO_REGISTER_OPERATION(get_real, dense::get_real);
GKO_REGISTER_OPERATION(get_imag, dense::get_imag);
GKO_REGISTER_OPERATION(add_scaled_identity, dense::add_scaled_identity);


} // anonymous namespace
Expand Down Expand Up @@ -1427,6 +1428,19 @@ void Dense<ValueType>::get_imag(
}


template <typename ValueType>
void Dense<ValueType>::add_scaled_identity_impl(const LinOp* const a,
const LinOp* const b)
{
precision_dispatch_real_complex<ValueType>(
[this](auto dense_alpha, auto dense_beta, auto dense_x) {
this->get_executor()->run(dense::make_add_scaled_identity(
dense_alpha, dense_beta, dense_x));
},
a, b, this);
}


#define GKO_DECLARE_DENSE_MATRIX(_type) class Dense<_type>
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_MATRIX);

Expand Down
10 changes: 9 additions & 1 deletion core/matrix/dense_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,12 @@ namespace kernels {
const matrix::Dense<_vtype>* source, \
matrix::Dense<remove_complex<_vtype>>* result)

#define GKO_DECLARE_DENSE_ADD_SCALED_IDENTITY_KERNEL(_vtype, _scalar_type) \
void add_scaled_identity(std::shared_ptr<const DefaultExecutor> exec, \
const matrix::Dense<_scalar_type>* alpha, \
const matrix::Dense<_scalar_type>* beta, \
matrix::Dense<_vtype>* mtx)


#define GKO_DECLARE_ALL_AS_TEMPLATES \
template <typename ValueType> \
Expand Down Expand Up @@ -385,7 +391,9 @@ namespace kernels {
template <typename ValueType> \
GKO_DECLARE_GET_REAL_KERNEL(ValueType); \
template <typename ValueType> \
GKO_DECLARE_GET_IMAG_KERNEL(ValueType)
GKO_DECLARE_GET_IMAG_KERNEL(ValueType); \
template <typename ValueType, typename ScalarType> \
GKO_DECLARE_DENSE_ADD_SCALED_IDENTITY_KERNEL(ValueType, ScalarType)


GKO_DECLARE_FOR_ALL_EXECUTOR_NAMESPACES(dense, GKO_DECLARE_ALL_AS_TEMPLATES);
Expand Down
1 change: 0 additions & 1 deletion core/test/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "core/test/utils/array_generator.hpp"
#include "core/test/utils/assertions.hpp"
#include "core/test/utils/matrix_generator.hpp"
#include "core/test/utils/matrix_utils.hpp"
#include "core/test/utils/value_generator.hpp"


Expand Down
45 changes: 45 additions & 0 deletions core/test/utils/matrix_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,51 @@ void make_hpd(matrix::Dense<ValueType>* mtx,
}


/**
* Changes the diagonal entry in the requested row, shrinking the
* matrix by 1 nonzero entry.
*
* @param mtx The matrix to remove a diagonal entry from.
* @param row_to_process The row from which to remove the diagonal entry.
*/
template <typename MtxType>
void remove_diagonal_entry_from_row(
MtxType* const mtx, const typename MtxType::index_type row_to_process)
{
using value_type = typename MtxType::value_type;
using index_type = typename MtxType::index_type;
matrix_data<value_type, index_type> mdata;
mtx->write(mdata);
auto it = std::remove_if(mdata.nonzeros.begin(), mdata.nonzeros.end(),
[&](auto entry) {
return entry.row == row_to_process &&
entry.column == row_to_process;
});
mdata.nonzeros.erase(it, mdata.nonzeros.end());
mtx->read(mdata);
}


/**
* Ensures each row has a diagonal entry.
*/
template <typename MtxType>
void ensure_all_diagonal_entries(MtxType* const mtx)
{
using value_type = typename MtxType::value_type;
using index_type = typename MtxType::index_type;
matrix_data<value_type, index_type> mdata;
mtx->write(mdata);
const auto nrows = static_cast<index_type>(mtx->get_size()[0]);
mdata.nonzeros.reserve(mtx->get_num_stored_elements() + nrows);
for (index_type i = 0; i < nrows; i++) {
mdata.nonzeros.push_back({i, i, zero<value_type>()});
}
mdata.sum_duplicates();
mtx->read(mdata);
}


} // namespace test
} // namespace gko

Expand Down
61 changes: 61 additions & 0 deletions core/test/utils/matrix_utils_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <gtest/gtest.h>


#include <ginkgo/core/matrix/csr.hpp>


#include "core/test/utils.hpp"
#include "core/test/utils/matrix_generator.hpp"

Expand Down Expand Up @@ -198,4 +201,62 @@ TYPED_TEST(MatrixUtils, MakeHpdMatrixWithRatioCorrectly)
}


TEST(MatrixUtils, RemoveDiagonalEntry)
{
using T = float;
using Csr = gko::matrix::Csr<T, int>;
auto exec = gko::ReferenceExecutor::create();
auto b = gko::initialize<Csr>(
{I<T>{2.0, 0.0, 1.1, 0.0}, I<T>{1.0, 2.4, 0.0, -1.0},
I<T>{0.0, -4.0, 2.2, -2.0}, I<T>{0.0, -3.0, 1.5, 1.0}},
exec);
const int row_to_remove = 2;

gko::test::remove_diagonal_entry_from_row(b.get(), row_to_remove);

const auto rowptrs = b->get_const_row_ptrs();
const auto colidxs = b->get_const_col_idxs();
for (int i = 0; i < 4; i++) {
bool has_diag = false;
for (int j = rowptrs[i]; j < rowptrs[i + 1]; j++) {
if (colidxs[j] == i) {
has_diag = true;
}
}
ASSERT_EQ(has_diag, i != row_to_remove);
}
}


TEST(MatrixUtils, ModifyToEnsureAllDiagonalEntries)
{
using T = float;
using Csr = gko::matrix::Csr<T, int>;
auto exec = gko::ReferenceExecutor::create();
auto b = gko::initialize<Csr>(
{I<T>{2.0, 0.0, 1.1, 0.0}, I<T>{1.0, 2.4, 0.0, -1.0},
I<T>{0.0, -4.0, 2.2, -2.0}, I<T>{0.0, -3.0, 1.5, 1.0}},
exec);

gko::test::ensure_all_diagonal_entries(b.get());

const auto rowptrs = b->get_const_row_ptrs();
const auto colidxs = b->get_const_col_idxs();
bool all_diags = true;
for (int i = 0; i < 3; i++) {
bool has_diag = false;
for (int j = rowptrs[i]; j < rowptrs[i + 1]; j++) {
if (colidxs[j] == i) {
has_diag = true;
}
}
if (!has_diag) {
all_diags = false;
break;
}
}
ASSERT_TRUE(all_diags);
}


} // namespace
Loading

0 comments on commit 6ee7bfe

Please sign in to comment.