Skip to content

Commit

Permalink
refactored code to add scaled identity, introduced new mixin
Browse files Browse the repository at this point in the history
  • Loading branch information
Slaedr committed Jan 10, 2022
1 parent 23ec8c9 commit a990e91
Show file tree
Hide file tree
Showing 9 changed files with 116 additions and 39 deletions.
14 changes: 14 additions & 0 deletions core/matrix/csr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "core/components/device_matrix_data_kernels.hpp"
#include "core/components/fill_array_kernels.hpp"
#include "core/components/prefix_sum_kernels.hpp"
#include "core/factorization/factorization_kernels.hpp"
#include "core/matrix/csr_kernels.hpp"


Expand Down Expand Up @@ -99,6 +100,9 @@ GKO_REGISTER_OPERATION(outplace_absolute_array,
components::outplace_absolute_array);
GKO_REGISTER_OPERATION(scale, csr::scale);
GKO_REGISTER_OPERATION(inv_scale, csr::inv_scale);
GKO_REGISTER_OPERATION(add_scaled_identity, csr::add_scaled_identity);
GKO_REGISTER_OPERATION(add_diagonal_elems,
factorization::add_diagonal_elements);


} // anonymous namespace
Expand Down Expand Up @@ -619,6 +623,16 @@ void Csr<ValueType, IndexType>::inv_scale_impl(const LinOp* alpha)
}


template <typename ValueType, typename IndexType>
void Csr<ValueType, IndexType>::add_scaled_identity_impl(const LinOp* const a,
const LinOp* const b)
{
this->get_executor()->run(csr::make_add_diagonal_elems(this, false));
this->get_executor()->run(csr::make_add_scaled_identity(
as<Dense<ValueType>>(a), as<Dense<ValueType>>(b), this));
}


#define GKO_DECLARE_CSR_MATRIX(ValueType, IndexType) \
class Csr<ValueType, IndexType>
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_MATRIX);
Expand Down
14 changes: 14 additions & 0 deletions core/matrix/dense.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ GKO_REGISTER_OPERATION(outplace_absolute_dense, dense::outplace_absolute_dense);
GKO_REGISTER_OPERATION(make_complex, dense::make_complex);
GKO_REGISTER_OPERATION(get_real, dense::get_real);
GKO_REGISTER_OPERATION(get_imag, dense::get_imag);
GKO_REGISTER_OPERATION(add_scaled_identity, dense::add_scaled_identity);


} // anonymous namespace
Expand Down Expand Up @@ -1276,6 +1277,19 @@ void Dense<ValueType>::get_imag(
}


template <typename ValueType>
void Dense<ValueType>::add_scaled_identity_impl(const LinOp* const a,
const LinOp* const b)
{
precision_dispatch_real_complex<ValueType>(
[this](auto dense_alpha, auto dense_beta, auto dense_x) {
this->get_executor()->run(dense::make_add_scaled_identity(
dense_alpha, dense_beta, dense_x));
},
a, b, this);
}


#define GKO_DECLARE_DENSE_MATRIX(_type) class Dense<_type>
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_MATRIX);

Expand Down
38 changes: 2 additions & 36 deletions core/matrix/identity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,25 +39,8 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <ginkgo/core/matrix/dense.hpp>


#include "core/factorization/factorization_kernels.hpp"
#include "core/matrix/csr_kernels.hpp"
#include "core/matrix/dense_kernels.hpp"


namespace gko {
namespace matrix {
namespace identity {
namespace {


GKO_REGISTER_OPERATION(dense_add_scaled_identity, dense::add_scaled_identity);
GKO_REGISTER_OPERATION(csr_add_scaled_identity, csr::add_scaled_identity);
GKO_REGISTER_OPERATION(csr_add_diagonal_elems,
factorization::add_diagonal_elements);


} // anonymous namespace
} // namespace identity


template <typename ValueType>
Expand All @@ -72,25 +55,8 @@ void Identity<ValueType>::apply_impl(const LinOp* alpha, const LinOp* b,
const LinOp* beta, LinOp* x) const
{
if (auto bI = dynamic_cast<const Identity<ValueType>*>(b)) {
GKO_ASSERT_IS_SQUARE_MATRIX(x);
if (auto xd = dynamic_cast<Dense<ValueType>*>(x)) {
precision_dispatch_real_complex<ValueType>(
[this](auto dense_alpha, auto dense_beta, auto dense_x) {
this->get_executor()->run(
identity::make_dense_add_scaled_identity(
dense_alpha, dense_beta, dense_x));
},
alpha, beta, x);
} else if (auto xc = dynamic_cast<Csr<ValueType, int32>*>(x)) {
this->get_executor()->run(
identity::make_csr_add_diagonal_elems(xc, false));
this->get_executor()->run(identity::make_csr_add_scaled_identity(
as<Dense<ValueType>>(alpha), as<Dense<ValueType>>(beta), xc));
} else if (auto xc = dynamic_cast<Csr<ValueType, int64>*>(x)) {
this->get_executor()->run(
identity::make_csr_add_diagonal_elems(xc, false));
this->get_executor()->run(identity::make_csr_add_scaled_identity(
as<Dense<ValueType>>(alpha), as<Dense<ValueType>>(beta), xc));
if (auto xs = dynamic_cast<EnableScaledIdentityAddition*>(x)) {
xs->add_scaled_identity(alpha, beta);
} else {
GKO_NOT_IMPLEMENTED;
}
Expand Down
20 changes: 20 additions & 0 deletions include/ginkgo/core/base/exception_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,26 @@ inline T ensure_allocated_impl(T ptr, const std::string& file, int line,
"semi-colon warnings")


/**
* Checks that the operator is a scalar, ie., has size 1x1.
*
* @param _op Operator to be checked.
*
* @throw BadDimension if _op does not have size 1x1.
*/
#define GKO_ASSERT_IS_SCALAR(_op) \
{ \
auto sz = gko::detail::get_size(_op); \
if (sz[0] != 1 || sz[1] != 1) { \
throw ::gko::BadDimension(__FILE__, __LINE__, __func__, #_op, \
sz[0], sz[1], "expected scalar"); \
} \
} \
static_assert(true, \
"This assert is used to counter the false positive extra " \
"semi-colon warnings")


} // namespace gko


Expand Down
27 changes: 26 additions & 1 deletion include/ginkgo/core/base/lin_op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ template <typename ValueType>
class Diagonal;


}
} // namespace matrix


/**
Expand Down Expand Up @@ -767,6 +767,31 @@ class EnableAbsoluteComputation : public AbsoluteComputable {
};


/**
* Mix-in that adds the operation M <- a I + b M for matrix M, identity
* operator I and scalars a and b, where M is the calling object.
*/
class EnableScaledIdentityAddition {
public:
/**
* Scales this and adds another scalar times the identity to it.
*
* @param a Scalar to multiply the identity operator by before adding.
* @param b Scalar to multiply this before adding the scaled identity to
* it.
*/
void add_scaled_identity(const LinOp* const a, const LinOp* const b)
{
GKO_ASSERT_IS_SCALAR(a);
GKO_ASSERT_IS_SCALAR(b);
add_scaled_identity_impl(a, b);
}

private:
virtual void add_scaled_identity_impl(const LinOp* a, const LinOp* b) = 0;
};


/**
* The EnableLinOp mixin can be used to provide sensible default implementations
* of the majority of the LinOp and PolymorphicObject interface.
Expand Down
5 changes: 4 additions & 1 deletion include/ginkgo/core/matrix/csr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ class Csr : public EnableLinOp<Csr<ValueType, IndexType>>,
public Transposable,
public Permutable<IndexType>,
public EnableAbsoluteComputation<
remove_complex<Csr<ValueType, IndexType>>> {
remove_complex<Csr<ValueType, IndexType>>>,
public EnableScaledIdentityAddition {
friend class EnableCreateMethod<Csr>;
friend class EnablePolymorphicObject<Csr, LinOp>;
friend class Coo<ValueType, IndexType>;
Expand Down Expand Up @@ -1168,6 +1169,8 @@ class Csr : public EnableLinOp<Csr<ValueType, IndexType>>,
Array<index_type> row_ptrs_;
Array<index_type> srow_;
std::shared_ptr<strategy_type> strategy_;

void add_scaled_identity_impl(const LinOp* a, const LinOp* b) override;
};


Expand Down
5 changes: 4 additions & 1 deletion include/ginkgo/core/matrix/dense.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ class Dense
public Transposable,
public Permutable<int32>,
public Permutable<int64>,
public EnableAbsoluteComputation<remove_complex<Dense<ValueType>>> {
public EnableAbsoluteComputation<remove_complex<Dense<ValueType>>>,
public EnableScaledIdentityAddition {
friend class EnableCreateMethod<Dense>;
friend class EnablePolymorphicObject<Dense, LinOp>;
friend class Coo<ValueType, int32>;
Expand Down Expand Up @@ -1061,6 +1062,8 @@ class Dense
private:
Array<value_type> values_;
size_type stride_;

void add_scaled_identity_impl(const LinOp* a, const LinOp* b) override;
};


Expand Down
16 changes: 16 additions & 0 deletions reference/test/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1544,6 +1544,22 @@ TYPED_TEST(Csr, InvScalesData)
}


TYPED_TEST(Csr, ScaleCsrAddIdentityRectangular)
{
using Vec = typename TestFixture::Vec;
using T = typename TestFixture::value_type;
using Csr = typename TestFixture::Mtx;
auto alpha = gko::initialize<Vec>({2.0}, this->exec);
auto beta = gko::initialize<Vec>({-1.0}, this->exec);
auto b = gko::initialize<Csr>(
{I<T>{2.0, 0.0}, I<T>{1.0, 2.5}, I<T>{0.0, -4.0}}, this->exec);

b->add_scaled_identity(alpha.get(), beta.get());

GKO_ASSERT_MTX_NEAR(b, l({{0.0, 0.0}, {-1.0, -0.5}, {0.0, 4.0}}), 0.0);
}


template <typename ValueIndexType>
class CsrComplex : public ::testing::Test {
protected:
Expand Down
16 changes: 16 additions & 0 deletions reference/test/matrix/dense_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4081,6 +4081,22 @@ TYPED_TEST(Dense, MakeTemporaryConversionConstDoesntConvertBack)
}


TYPED_TEST(Dense, ScaleAddIdentityRectangular)
{
using T = typename TestFixture::value_type;
using Vec = typename TestFixture::Mtx;
using MixedVec = typename TestFixture::MixedMtx;
auto alpha = gko::initialize<Vec>({2.0}, this->exec);
auto beta = gko::initialize<Vec>({-1.0}, this->exec);
auto b = gko::initialize<Vec>(
{I<T>{2.0, 0.0}, I<T>{1.0, 2.5}, I<T>{0.0, -4.0}}, this->exec);

b->add_scaled_identity(alpha.get(), beta.get());

GKO_ASSERT_MTX_NEAR(b, l({{0.0, 0.0}, {-1.0, -0.5}, {0.0, 4.0}}), 0.0);
}


template <typename T>
class DenseComplex : public ::testing::Test {
protected:
Expand Down

0 comments on commit a990e91

Please sign in to comment.