Skip to content

Commit

Permalink
Some general fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed Oct 11, 2023
1 parent 04d3681 commit 964bad7
Show file tree
Hide file tree
Showing 22 changed files with 184 additions and 246 deletions.
13 changes: 6 additions & 7 deletions common/cuda_hip/matrix/batch_ell_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

template <typename ValueType>
__device__ __forceinline__ void simple_apply(
const gko::batch::matrix::batch_ell::batch_item<const ValueType>& mat,
const gko::batch::matrix::ell::batch_item<const ValueType>& mat,
const ValueType* const __restrict__ b, ValueType* const __restrict__ x)
{
const auto num_rows = mat.num_rows;
Expand All @@ -60,7 +60,7 @@ template <typename ValueType>
__global__ __launch_bounds__(
default_block_size,
sm_oversubscription) void simple_apply_kernel(const gko::batch::matrix::
batch_ell::uniform_batch<
ell::uniform_batch<
const ValueType>
mat,
const gko::batch::
Expand Down Expand Up @@ -88,7 +88,7 @@ __global__ __launch_bounds__(
template <typename ValueType>
__device__ __forceinline__ void advanced_apply(
const ValueType alpha,
const gko::batch::matrix::batch_ell::batch_item<const ValueType>& mat,
const gko::batch::matrix::ell::batch_item<const ValueType>& mat,
const ValueType* const __restrict__ b, const ValueType beta,
ValueType* const __restrict__ x)
{
Expand Down Expand Up @@ -121,10 +121,9 @@ __global__ __launch_bounds__(
const ValueType>
alpha,
const gko::batch::matrix::
batch_ell::
uniform_batch<
const ValueType>
mat,
ell::uniform_batch<
const ValueType>
mat,
const gko::batch::
multi_vector::
uniform_batch<
Expand Down
32 changes: 4 additions & 28 deletions core/matrix/batch_ell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,22 +104,10 @@ template <typename ValueType, typename IndexType>
std::unique_ptr<Ell<ValueType, IndexType>>
Ell<ValueType, IndexType>::create_with_config_of(
ptr_param<const Ell<ValueType, IndexType>> other)
{
// De-referencing `other` before calling the functions (instead of
// using operator `->`) is currently required to be compatible with
// CUDA 10.1.
// Otherwise, it results in a compile error.
return (*other).create_with_same_config();
}


template <typename ValueType, typename IndexType>
std::unique_ptr<Ell<ValueType, IndexType>>
Ell<ValueType, IndexType>::create_with_same_config() const
{
return Ell<ValueType, IndexType>::create(
this->get_executor(), this->get_size(),
this->get_num_stored_elements_per_row());
other->get_executor(), other->get_size(),
other->get_num_stored_elements_per_row());
}


Expand Down Expand Up @@ -163,12 +151,7 @@ template <typename ValueType, typename IndexType>
void Ell<ValueType, IndexType>::apply_impl(const MultiVector<ValueType>* b,
MultiVector<ValueType>* x) const
{
GKO_ASSERT_EQ(b->get_num_batch_items(), this->get_num_batch_items());
GKO_ASSERT_EQ(this->get_num_batch_items(), x->get_num_batch_items());

GKO_ASSERT_CONFORMANT(this->get_common_size(), b->get_common_size());
GKO_ASSERT_EQUAL_ROWS(this->get_common_size(), x->get_common_size());
GKO_ASSERT_EQUAL_COLS(b->get_common_size(), x->get_common_size());
this->validate_application_parameters(b, x);
this->get_executor()->run(ell::make_simple_apply(this, b, x));
}

Expand All @@ -179,14 +162,7 @@ void Ell<ValueType, IndexType>::apply_impl(const MultiVector<ValueType>* alpha,
const MultiVector<ValueType>* beta,
MultiVector<ValueType>* x) const
{
GKO_ASSERT_EQ(b->get_num_batch_items(), this->get_num_batch_items());
GKO_ASSERT_EQ(this->get_num_batch_items(), x->get_num_batch_items());

GKO_ASSERT_CONFORMANT(this->get_common_size(), b->get_common_size());
GKO_ASSERT_EQUAL_ROWS(this->get_common_size(), x->get_common_size());
GKO_ASSERT_EQUAL_COLS(b->get_common_size(), x->get_common_size());
GKO_ASSERT_EQUAL_DIMENSIONS(alpha->get_common_size(), gko::dim<2>(1, 1));
GKO_ASSERT_EQUAL_DIMENSIONS(beta->get_common_size(), gko::dim<2>(1, 1));
this->validate_application_parameters(alpha, b, beta, x);
this->get_executor()->run(
ell::make_advanced_apply(alpha, this, b, beta, x));
}
Expand Down
20 changes: 10 additions & 10 deletions core/matrix/batch_struct.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ struct uniform_batch {
} // namespace dense


namespace batch_ell {
namespace ell {


/**
Expand All @@ -109,7 +109,7 @@ struct batch_item {
template <typename ValueType>
struct uniform_batch {
using value_type = ValueType;
using index_type = int;
using index_type = int32;
using entry_type = batch_item<value_type>;

ValueType* values;
Expand All @@ -127,7 +127,7 @@ struct uniform_batch {
};


} // namespace batch_ell
} // namespace ell


template <typename ValueType>
Expand Down Expand Up @@ -165,26 +165,26 @@ GKO_ATTRIBUTES GKO_INLINE dense::batch_item<ValueType> extract_batch_item(


template <typename ValueType>
GKO_ATTRIBUTES GKO_INLINE batch_ell::batch_item<const ValueType> to_const(
const batch_ell::batch_item<ValueType>& b)
GKO_ATTRIBUTES GKO_INLINE ell::batch_item<const ValueType> to_const(
const ell::batch_item<ValueType>& b)
{
return {b.values, b.col_idxs, b.stride,
b.num_rows, b.num_cols, b.num_stored_elems_per_row};
}


template <typename ValueType>
GKO_ATTRIBUTES GKO_INLINE batch_ell::uniform_batch<const ValueType> to_const(
const batch_ell::uniform_batch<ValueType>& ub)
GKO_ATTRIBUTES GKO_INLINE ell::uniform_batch<const ValueType> to_const(
const ell::uniform_batch<ValueType>& ub)
{
return {ub.values, ub.col_idxs, ub.num_batch_items, ub.stride,
ub.num_rows, ub.num_cols, ub.num_stored_elems_per_row};
}


template <typename ValueType>
GKO_ATTRIBUTES GKO_INLINE batch_ell::batch_item<ValueType> extract_batch_item(
const batch_ell::uniform_batch<ValueType>& batch, const size_type batch_idx)
GKO_ATTRIBUTES GKO_INLINE ell::batch_item<ValueType> extract_batch_item(
const ell::uniform_batch<ValueType>& batch, const size_type batch_idx)
{
return {batch.values +
batch_idx * batch.num_stored_elems_per_row * batch.num_rows,
Expand All @@ -196,7 +196,7 @@ GKO_ATTRIBUTES GKO_INLINE batch_ell::batch_item<ValueType> extract_batch_item(
}

template <typename ValueType>
GKO_ATTRIBUTES GKO_INLINE batch_ell::batch_item<ValueType> extract_batch_item(
GKO_ATTRIBUTES GKO_INLINE ell::batch_item<ValueType> extract_batch_item(
ValueType* const batch_values, int* const batch_col_idxs, const int stride,
const int num_rows, const int num_cols, int num_elems_per_row,
const size_type batch_idx)
Expand Down
8 changes: 4 additions & 4 deletions core/test/matrix/batch_ell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,15 @@ TYPED_TEST(Ell, SparseMtxKnowsItsSizeAndValues)
TYPED_TEST(Ell, CanBeEmpty)
{
auto empty = gko::batch::matrix::Ell<TypeParam>::create(this->exec);

this->assert_empty(empty.get());
}


TYPED_TEST(Ell, ReturnsNullValuesArrayWhenEmpty)
{
auto empty = gko::batch::matrix::Ell<TypeParam>::create(this->exec);

ASSERT_EQ(empty->get_const_values(), nullptr);
}

Expand Down Expand Up @@ -284,7 +286,6 @@ TYPED_TEST(Ell, CanBeConstructedFromEllMatrices)
using value_type = typename TestFixture::value_type;
using EllMtx = typename TestFixture::EllMtx;
using size_type = gko::size_type;

auto mat1 = gko::initialize<EllMtx>({{-1.0, 0.0, 0.0}, {0.0, 2.5, 3.5}},
this->exec);
auto mat2 =
Expand All @@ -304,15 +305,14 @@ TYPED_TEST(Ell, CanBeConstructedFromEllMatricesByDuplication)
using index_type = int;
using EllMtx = typename TestFixture::EllMtx;
using size_type = gko::size_type;

auto mat1 =
gko::initialize<EllMtx>({{1.0, 0.0, 0.0}, {0.0, 2.0, 0.0}}, this->exec);

auto bat_m =
gko::batch::create_from_item<gko::batch::matrix::Ell<value_type>>(
this->exec,
std::vector<EllMtx*>{mat1.get(), mat1.get(), mat1.get()},
mat1->get_num_stored_elements_per_row());

auto m = gko::batch::create_from_item<gko::batch::matrix::Ell<value_type>>(
this->exec, 3, mat1.get(), mat1->get_num_stored_elements_per_row());

Expand All @@ -326,7 +326,6 @@ TYPED_TEST(Ell, CanBeConstructedByDuplicatingEllMatrices)
using index_type = int;
using EllMtx = typename TestFixture::EllMtx;
using size_type = gko::size_type;

auto mat1 = gko::initialize<EllMtx>({{-1.0, 0.0, 0.0}, {0.0, 2.5, 0.0}},
this->exec);
auto mat2 =
Expand Down Expand Up @@ -372,6 +371,7 @@ TYPED_TEST(Ell, CanBeListConstructed)
{
using value_type = typename TestFixture::value_type;
using index_type = int;

auto m = gko::batch::initialize<gko::batch::matrix::Ell<TypeParam>>(
{{0.0, -1.0}, {1.0, 0.0}}, this->exec);

Expand Down
3 changes: 2 additions & 1 deletion cuda/matrix/batch_dense_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <thrust/functional.h>


#include <ginkgo/core/base/math.hpp>
#include <ginkgo/core/base/batch_multi_vector.hpp>
#include <ginkgo/core/matrix/batch_dense.hpp>


#include "core/base/batch_struct.hpp"
Expand Down
6 changes: 2 additions & 4 deletions cuda/matrix/batch_ell_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,16 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


#include <thrust/functional.h>
#include <thrust/transform.h>


#include <ginkgo/core/base/math.hpp>
#include <ginkgo/core/base/batch_multi_vector.hpp>
#include <ginkgo/core/matrix/batch_ell.hpp>


#include "core/base/batch_struct.hpp"
#include "core/matrix/batch_struct.hpp"
#include "cuda/base/batch_struct.hpp"
#include "cuda/base/config.hpp"
#include "cuda/base/cublas_bindings.hpp"
#include "cuda/base/pointer_mode_guard.hpp"
#include "cuda/base/thrust.cuh"
#include "cuda/components/cooperative_groups.cuh"
#include "cuda/components/reduction.cuh"
Expand Down
23 changes: 12 additions & 11 deletions cuda/matrix/batch_struct.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


#include <ginkgo/core/matrix/batch_dense.hpp>
#include <ginkgo/core/matrix/batch_ell.hpp>


#include "core/base/batch_struct.hpp"
Expand Down Expand Up @@ -91,33 +92,33 @@ get_batch_struct(batch::matrix::Dense<ValueType>* const op)
* Generates an immutable uniform batch struct from a batch of ell matrices.
*/
template <typename ValueType>
inline batch::matrix::batch_ell::uniform_batch<const cuda_type<ValueType>>
inline batch::matrix::ell::uniform_batch<const cuda_type<ValueType>>
get_batch_struct(const batch::matrix::Ell<ValueType, int32>* const op)
{
return {as_cuda_type(op->get_const_values()),
op->get_const_col_idxs(),
op->get_num_batch_items(),
static_cast<int>(op->get_common_size()[0]),
static_cast<int>(op->get_common_size()[0]),
static_cast<int>(op->get_common_size()[1]),
static_cast<int>(op->get_num_stored_elements_per_row())};
static_cast<int32>(op->get_common_size()[0]),
static_cast<int32>(op->get_common_size()[0]),
static_cast<int32>(op->get_common_size()[1]),
static_cast<int32>(op->get_num_stored_elements_per_row())};
}


/**
* Generates a uniform batch struct from a batch of ell matrices.
*/
template <typename ValueType>
inline batch::matrix::batch_ell::uniform_batch<cuda_type<ValueType>>
get_batch_struct(batch::matrix::Ell<ValueType, int32>* const op)
inline batch::matrix::ell::uniform_batch<cuda_type<ValueType>> get_batch_struct(
batch::matrix::Ell<ValueType, int32>* const op)
{
return {as_cuda_type(op->get_values()),
op->get_col_idxs(),
op->get_num_batch_items(),
static_cast<int>(op->get_common_size()[0]),
static_cast<int>(op->get_common_size()[0]),
static_cast<int>(op->get_common_size()[1]),
static_cast<int>(op->get_num_stored_elements_per_row())};
static_cast<int32>(op->get_common_size()[0]),
static_cast<int32>(op->get_common_size()[0]),
static_cast<int32>(op->get_common_size()[1]),
static_cast<int32>(op->get_num_stored_elements_per_row())};
}


Expand Down
62 changes: 29 additions & 33 deletions dpcpp/matrix/batch_ell_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,13 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <CL/sycl.hpp>


#include <ginkgo/core/base/array.hpp>
#include <ginkgo/core/base/batch_multi_vector.hpp>
#include <ginkgo/core/base/math.hpp>
#include <ginkgo/core/matrix/batch_ell.hpp>


#include "core/base/batch_struct.hpp"
#include "core/components/prefix_sum_kernels.hpp"
#include "core/matrix/batch_struct.hpp"
#include "dpcpp/base/batch_struct.hpp"
#include "dpcpp/base/config.hpp"
#include "dpcpp/base/dim3.dp.hpp"
#include "dpcpp/base/dpct.hpp"
#include "dpcpp/base/helper.hpp"
Expand Down Expand Up @@ -98,19 +94,19 @@ void simple_apply(std::shared_ptr<const DefaultExecutor> exec,
}

// Launch a kernel that has nbatches blocks, each block has max group size
(exec->get_queue())->submit([&](sycl::handler& cgh) {
exec->get_queue()->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl_nd_range(grid, block),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(config::warp_size)]] {
auto group = item_ct1.get_group();
auto group_id = group.get_group_linear_id();
const auto mat_b =
batch::matrix::extract_batch_item(mat_ub, group_id);
const auto b_b = batch::extract_batch_item(b_ub, group_id);
const auto x_b = batch::extract_batch_item(x_ub, group_id);
simple_apply_kernel(mat_b, b_b, x_b, item_ct1);
});
sycl_nd_range(grid, block), [=
](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
config::warp_size)]] {
auto group = item_ct1.get_group();
auto group_id = group.get_group_linear_id();
const auto mat_b =
batch::matrix::extract_batch_item(mat_ub, group_id);
const auto b_b = batch::extract_batch_item(b_ub, group_id);
const auto x_b = batch::extract_batch_item(x_ub, group_id);
simple_apply_kernel(mat_b, b_b, x_b, item_ct1);
});
});
}

Expand Down Expand Up @@ -145,24 +141,24 @@ void advanced_apply(std::shared_ptr<const DefaultExecutor> exec,
const dim3 grid(num_batch_items);

// Launch a kernel that has nbatches blocks, each block has max group size
(exec->get_queue())->submit([&](sycl::handler& cgh) {
exec->get_queue()->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl_nd_range(grid, block),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(config::warp_size)]] {
auto group = item_ct1.get_group();
auto group_id = group.get_group_linear_id();
const auto mat_b =
batch::matrix::extract_batch_item(mat_ub, group_id);
const auto b_b = batch::extract_batch_item(b_ub, group_id);
const auto x_b = batch::extract_batch_item(x_ub, group_id);
const auto alpha_b =
batch::extract_batch_item(alpha_ub, group_id);
const auto beta_b =
batch::extract_batch_item(beta_ub, group_id);
advanced_apply_kernel(alpha_b, mat_b, b_b, beta_b, x_b,
item_ct1);
});
sycl_nd_range(grid, block), [=
](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
config::warp_size)]] {
auto group = item_ct1.get_group();
auto group_id = group.get_group_linear_id();
const auto mat_b =
batch::matrix::extract_batch_item(mat_ub, group_id);
const auto b_b = batch::extract_batch_item(b_ub, group_id);
const auto x_b = batch::extract_batch_item(x_ub, group_id);
const auto alpha_b =
batch::extract_batch_item(alpha_ub, group_id);
const auto beta_b =
batch::extract_batch_item(beta_ub, group_id);
advanced_apply_kernel(alpha_b, mat_b, b_b, beta_b, x_b,
item_ct1);
});
});
}

Expand Down
Loading

0 comments on commit 964bad7

Please sign in to comment.