Skip to content

Commit

Permalink
Review updates
Browse files Browse the repository at this point in the history
* Const-correctness
* Remove hand-generated FFT test inputs
* Limit power-of-two requirement to OMP/Reference

Co-authored-by: Thomas Grützmacher <thomas.gruetzmacher@kit.edu>
  • Loading branch information
upsj and Thomas Grützmacher committed Feb 17, 2021
1 parent 047bdeb commit 0addf0f
Show file tree
Hide file tree
Showing 14 changed files with 219 additions and 10,455 deletions.
5 changes: 2 additions & 3 deletions core/matrix/fft_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,8 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define GKO_CORE_MATRIX_FFT_KERNELS_HPP_


#include <ginkgo/core/matrix/fft.hpp>


#include <ginkgo/core/base/array.hpp>
#include <ginkgo/core/base/types.hpp>
#include <ginkgo/core/matrix/dense.hpp>


Expand Down
2 changes: 1 addition & 1 deletion core/test/utils/assertions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ ::testing::AssertionResult matrices_near_impl(
auto to_remove = [](char c) {
return !std::isalnum(c) && c != '_' && c != '.' && c != '-';
};
// remove all but alphanumerical and _.-<> characters from
// remove all but alphanumerical and _.- characters from
// expressions
firstfile.erase(
std::remove_if(firstfile.begin(), firstfile.end(), to_remove),
Expand Down
33 changes: 19 additions & 14 deletions cuda/matrix/fft_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "core/matrix/fft_kernels.hpp"


#include <array>


#include <cufft.h>


Expand All @@ -52,16 +55,6 @@ namespace cuda {
namespace fft {


struct cufft_deleter {
void operator()(cufftHandle *ptr)
{
auto data = *ptr;
delete ptr;
cufftDestroy(data);
}
};


template <typename InValueType, typename OutValueType>
struct cufft_type_impl {};

Expand Down Expand Up @@ -97,6 +90,15 @@ struct cufft_type_impl<std::complex<double>, std::complex<double>> {


class cufft_handle {
struct cufft_deleter {
void operator()(cufftHandle *ptr)
{
auto data = *ptr;
delete ptr;
cufftDestroy(data);
}
};

public:
operator cufftHandle() const { return *handle_; }

Expand All @@ -112,13 +114,16 @@ public:
{
static_assert(d == 1 || d == 2 || d == 3,
"Only 1D, 2D or 3D FFT supported");
long long cast_n[d];
std::copy_n(&n[0], d, &cast_n[0]);
std::array<long long, d> cast_n;
for (int i = 0; i < d; i++) {
cast_n[i] = static_cast<long long>(n[i]);
}
size_type work_size{};
GKO_ASSERT_NO_CUFFT_ERRORS(cufftSetAutoAllocation(*handle_, false));
GKO_ASSERT_NO_CUFFT_ERRORS(cufftMakePlanMany64(
*handle_, d, cast_n, cast_n, static_cast<int64>(in_batch_stride), 1,
cast_n, static_cast<int64>(out_batch_stride), 1,
*handle_, d, cast_n.data(), cast_n.data(),
static_cast<int64>(in_batch_stride), 1, cast_n.data(),
static_cast<int64>(out_batch_stride), 1,
cufft_type_impl<InValueType, OutValueType>::value,
static_cast<int64>(batch_count), &work_size));
work_area.resize_and_reset(work_size);
Expand Down
13 changes: 8 additions & 5 deletions include/ginkgo/core/base/exception_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,14 @@ inline dim<2> get_size(const dim<2> &size) { return size; }
*
*@throw BadDimension if _val is not a power of two.
*/
#define GKO_ASSERT_IS_POWER_OF_TWO(_val) \
if (_val == 0 || (_val & (_val - 1)) != 0) { \
throw ::gko::BadDimension(__FILE__, __LINE__, __func__, #_val, _val, \
_val, "expected power-of-two dimension"); \
}
#define GKO_ASSERT_IS_POWER_OF_TWO(_val) \
do { \
if (_val == 0 || (_val & (_val - 1)) != 0) { \
throw ::gko::BadDimension(__FILE__, __LINE__, __func__, #_val, \
_val, _val, \
"expected power-of-two dimension"); \
} \
} while (false)


/**
Expand Down
15 changes: 3 additions & 12 deletions include/ginkgo/core/matrix/fft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,7 @@ class Fft : public EnableLinOp<Fft>,
Fft(std::shared_ptr<const Executor> exec, size_type size,
bool inverse = false)
: EnableLinOp<Fft>(exec, dim<2>{size}), buffer_{exec}, inverse_{inverse}
{
GKO_ASSERT_IS_POWER_OF_TWO(size);
}
{}

void apply_impl(const LinOp *b, LinOp *x) const override;

Expand Down Expand Up @@ -183,10 +181,7 @@ class Fft2 : public EnableLinOp<Fft2>,
size1_{size1},
size2_{size2},
inverse_{inverse}
{
GKO_ASSERT_IS_POWER_OF_TWO(size1);
GKO_ASSERT_IS_POWER_OF_TWO(size2);
}
{}

void apply_impl(const LinOp *b, LinOp *x) const override;

Expand Down Expand Up @@ -281,11 +276,7 @@ class Fft3 : public EnableLinOp<Fft3>,
size2_{size2},
size3_{size3},
inverse_{inverse}
{
GKO_ASSERT_IS_POWER_OF_TWO(size1);
GKO_ASSERT_IS_POWER_OF_TWO(size2);
GKO_ASSERT_IS_POWER_OF_TWO(size3);
}
{}

void apply_impl(const LinOp *b, LinOp *x) const override;

Expand Down
4 changes: 0 additions & 4 deletions matrices/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@ configure_file("${Ginkgo_SOURCE_DIR}/matrices/config.hpp.in"

configure_file("test/ani1.mtx" "${Ginkgo_BINARY_DIR}/matrices/test/ani1.mtx")
configure_file("test/ani4.mtx" "${Ginkgo_BINARY_DIR}/matrices/test/ani4.mtx")
configure_file("test/fourier_in.mtx" "${Ginkgo_BINARY_DIR}/matrices/test/fourier_in.mtx")
configure_file("test/fourier_out1.mtx" "${Ginkgo_BINARY_DIR}/matrices/test/fourier_out1.mtx")
configure_file("test/fourier_out2.mtx" "${Ginkgo_BINARY_DIR}/matrices/test/fourier_out2.mtx")
configure_file("test/fourier_out3.mtx" "${Ginkgo_BINARY_DIR}/matrices/test/fourier_out3.mtx")
configure_file("test/isai_l.mtx" "${Ginkgo_BINARY_DIR}/matrices/test/isai_l.mtx")
configure_file("test/isai_l_excess.mtx" "${Ginkgo_BINARY_DIR}/matrices/test/isai_l_excess.mtx")
configure_file("test/isai_l_excess_rhs.mtx" "${Ginkgo_BINARY_DIR}/matrices/test/isai_l_excess_rhs.mtx")
Expand Down
8 changes: 0 additions & 8 deletions matrices/config.hpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,6 @@ namespace matrices {

const char *location_ani1_mtx = "@Ginkgo_BINARY_DIR@/matrices/test/ani1.mtx";
const char *location_ani4_mtx = "@Ginkgo_BINARY_DIR@/matrices/test/ani4.mtx";
const char *location_fourier_in_mtx =
"@Ginkgo_BINARY_DIR@/matrices/test/fourier_in.mtx";
const char *location_fourier_out1_mtx =
"@Ginkgo_BINARY_DIR@/matrices/test/fourier_out1.mtx";
const char *location_fourier_out2_mtx =
"@Ginkgo_BINARY_DIR@/matrices/test/fourier_out2.mtx";
const char *location_fourier_out3_mtx =
"@Ginkgo_BINARY_DIR@/matrices/test/fourier_out3.mtx";
const char *location_isai_mtxs = "@Ginkgo_BINARY_DIR@/matrices/test/";
const char *location_1138_bus_mtx =
"@Ginkgo_BINARY_DIR@/matrices/test/1138_bus.mtx";
Expand Down
Loading

0 comments on commit 0addf0f

Please sign in to comment.