Skip to content

Commit

Permalink
Fix comment.
Browse files Browse the repository at this point in the history
  • Loading branch information
tangjj11 committed Oct 8, 2024
1 parent beb04ba commit 2944ced
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 39 deletions.
27 changes: 12 additions & 15 deletions sycl/include/syclcompat/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -794,21 +794,19 @@ pow(const ValueT a, const ValueU b) {
/// Performs relu saturation.
/// \param [in] a The input value
/// \returns the relu saturation result
template <typename ValueT>
inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
std::is_same_v<sycl::half, ValueT>,
ValueT>
relu(const ValueT a) {
if (!detail::isnan(a) && a < ValueT(0))
template <typename ValueT> inline ValueT relu(const ValueT a) {
if constexpr (std::is_floating_point_v<ValueT> ||
std::is_same_v<sycl::half, ValueT>)
if (!detail::isnan(a) && a < ValueT(0))
return ValueT(0);
if (a < ValueT(0))
return ValueT(0);
return a;
}
template <class ValueT, int NumElements>
inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
std::is_same_v<sycl::half, ValueT>,
sycl::vec<ValueT, NumElements>>
inline sycl::vec<ValueT, NumElements>
relu(const sycl::vec<ValueT, NumElements> a) {
sycl::vec<T, NumElements> ret;
sycl::vec<ValueT, NumElements> ret;
for (int i = 0; i < NumElements; ++i)
ret[i] = relu(a[i]);
return ret;
Expand Down Expand Up @@ -1029,13 +1027,12 @@ inline unsigned vectorized_ternary(unsigned a, unsigned b, unsigned c,
const auto v1 = sycl::vec<unsigned, 1>(a).as<VecT>();
const auto v2 = sycl::vec<unsigned, 1>(b).as<VecT>();
const auto v3 = sycl::vec<unsigned, 1>(c).as<VecT>();
auto temp =
auto v4 =
detail::vectorized_binary<VecT, BinaryOperation1>()(v1, v2, binary_op1);
temp =
detail::vectorized_binary<VecT, BinaryOperation2>()(temp, v3, binary_op2);
v4 = detail::vectorized_binary<VecT, BinaryOperation2>()(v4, v3, binary_op2);
if (need_relu)
temp = relu(temp);
return temp.template as<sycl::vec<unsigned, 1>>();
v4 = relu(v4);
return v4.template as<sycl::vec<unsigned, 1>>();
}

/// Compute vectorized binary operation value with pred for two values, with
Expand Down
27 changes: 22 additions & 5 deletions sycl/test-e2e/syclcompat/math/math_fixt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,27 @@ class BinaryOpTestLauncher : OpTestLauncher {
}

template <auto Kernel>
void launch_test(ValueT op1, ValueU op2, ResultT expected) {
void launch_test(ValueT op1, ValueU op2, ResultT expected,
bool need_relu = false) {
if (skip_)
return;
syclcompat::memcpy<ValueT>(op1_, &op1, data_size_);
syclcompat::memcpy<ValueU>(op2_, &op2, data_size_);
syclcompat::launch<Kernel>(grid_, threads_, op1_, op2_, res_);
syclcompat::launch<Kernel>(grid_, threads_, op1_, op2_, res_, need_relu);
syclcompat::wait();
syclcompat::memcpy<ResultT>(&res_h_, res_, data_size_);

CHECK(ResultT, res_h_, expected);
};
template <auto Kernel>
void launch_test(ValueT op1, ValueU op2, ResultT expected, bool *pred_hi,
bool *pred_lo) {
if (skip_)
return;
syclcompat::memcpy<ValueT>(op1_, &op1, data_size_);
syclcompat::memcpy<ValueU>(op2_, &op2, data_size_);
syclcompat::launch<Kernel>(grid_, threads_, op1_, op2_, res_, pred_hi,
pred_lo);
syclcompat::wait();
syclcompat::memcpy<ResultT>(&res_h_, res_, data_size_);

Expand Down Expand Up @@ -195,7 +210,7 @@ class TernaryOpTestLauncher : OpTestLauncher {
protected:
ValueT *op1_;
ValueU *op2_;
ValueV *op2_;
ValueV *op3_;
ResultT res_h_, *res_;

public:
Expand Down Expand Up @@ -223,13 +238,15 @@ class TernaryOpTestLauncher : OpTestLauncher {
}

template <auto Kernel>
void launch_test(ValueT op1, ValueU op2, ValueU op3, ResultT expected) {
void launch_test(ValueT op1, ValueU op2, ValueU op3, ResultT expected,
bool need_relu = false) {
if (skip_)
return;
syclcompat::memcpy<ValueT>(op1_, &op1, data_size_);
syclcompat::memcpy<ValueU>(op2_, &op2, data_size_);
syclcompat::memcpy<ValueU>(op3_, &op3, data_size_);
syclcompat::launch<Kernel>(grid_, threads_, op1_, op2_, op3_, res_);
syclcompat::launch<Kernel>(grid_, threads_, op1_, op2_, op3_, res_,
need_relu);
syclcompat::wait();
syclcompat::memcpy<ResultT>(&res_h_, res_, data_size_);

Expand Down
60 changes: 41 additions & 19 deletions sycl/test-e2e/syclcompat/math/math_vectorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,24 @@
#include "math_fixt.hpp"

template <typename BinaryOp, typename ValueT>
void vectorized_binary_kernel(ValueT *a, ValueT *b, unsigned *r) {
void vectorized_binary_kernel(ValueT *a, ValueT *b, unsigned *r,
bool need_relu) {
unsigned ua = static_cast<unsigned>(*a);
unsigned ub = static_cast<unsigned>(*b);
*r = syclcompat::vectorized_binary<sycl::short2>(ua, ub, BinaryOp());
*r = syclcompat::vectorized_binary<sycl::short2>(ua, ub, BinaryOp(),
need_relu);
}

template <typename BinaryOp, typename ValueT>
void test_vectorized_binary(ValueT op1, ValueT op2, unsigned expected) {
void test_vectorized_binary(ValueT op1, ValueT op2, unsigned expected,
bool need_relu = false) {
std::cout << __PRETTY_FUNCTION__ << std::endl;
constexpr syclcompat::dim3 grid{1};
constexpr syclcompat::dim3 threads{1};

BinaryOpTestLauncher<ValueT, ValueT, unsigned>(grid, threads)
.template launch_test<vectorized_binary_kernel<BinaryOp, ValueT>>(
op1, op2, expected);
op1, op2, expected, need_relu);
}

template <typename UnaryOp, typename ValueT>
Expand All @@ -66,7 +69,8 @@ void test_vectorized_unary(ValueT op1, unsigned expected) {
}

template <typename ValueT>
void vectorized_sum_abs_diff_kernel(ValueT *a, ValueT *b, unsigned *r) {
void vectorized_sum_abs_diff_kernel(ValueT *a, ValueT *b, unsigned *r,
bool need_relu) {
unsigned ua = static_cast<unsigned>(*a);
unsigned ub = static_cast<unsigned>(*b);

Expand Down Expand Up @@ -103,7 +107,7 @@ void test_vectorized_ternary(ValueT op1, ValueT op2, ValueT op3,

TernaryOpTestLauncher<ValueT, ValueT, unsigned>(grid, threads)
.template launch_test<
vectorized_binary_kernel<BinaryOp1, BinaryOp2, ValueT>>(
vectorized_ternary_kernel<BinaryOp1, BinaryOp2, ValueT>>(
op1, op2, op3, expected, need_relu);
}

Expand All @@ -117,15 +121,15 @@ void vectorized_with_pred_kernel(ValueT *a, ValueT *b, unsigned *r,
pred_lo);
}

template <typename ValueT>
template <typename BinaryOp, typename ValueT>
void test_vectorized_with_pred(ValueT op1, ValueT op2, unsigned expected,
bool *pred_hi, bool *pred_lo) {
std::cout << __PRETTY_FUNCTION__ << std::endl;
constexpr syclcompat::dim3 grid{1};
constexpr syclcompat::dim3 threads{1};

BinaryOpTestLauncher<ValueT, ValueT, unsigned>(grid, threads)
.template launch_test<vectorized_with_pred_kernel<ValueT>>(
.template launch_test<vectorized_with_pred_kernel<BinaryOp, ValueT>>(
op1, op2, expected, pred_hi, pred_lo);
}

Expand All @@ -144,29 +148,47 @@ int main() {
0x00000000);
test_vectorized_binary<syclcompat::sub_sat, uint32_t>(0xFFFB0005, 0x00030008,
0xFFF8FFFD);
test_vectorized_binary<syclcompat::abs_diff, uint32_t>(0x00010002, 0x00040002,
0x00030000, true);
test_vectorized_binary<syclcompat::add_sat, uint32_t>(0x00020002, 0xFFFDFFFF,
0x00000001, true);
test_vectorized_binary<syclcompat::rhadd, uint32_t>(0x00010008, 0x00020001,
0x00020005, true);
test_vectorized_binary<syclcompat::hadd, uint32_t>(0x00010003, 0x00020005,
0x00010004, true);
test_vectorized_binary<syclcompat::maximum, uint32_t>(0x0FFF0000, 0x00000FFF,
0x0FFF0FFF, true);
test_vectorized_binary<syclcompat::minimum, uint32_t>(0x0FFF0000, 0x00000FFF,
0x00000000, true);
test_vectorized_binary<syclcompat::sub_sat, uint32_t>(0xFFFB0005, 0x00030008,
0x00000000, true);
test_vectorized_unary<syclcompat::abs, uint32_t>(0xFFFBFFFD, 0x00050003);
test_vectorized_sum_abs_diff<uint32_t>(0x00010002, 0x00040002, 0x00000003);
test_vectorized_ternary<std::plus<>, syclcompat::maximum, uint32_t>(
0x00010002, 0x00040002, 0x00080004, 0x00030000);
0x00010002, 0x00040002, 0x00080004, 0x00080004);
test_vectorized_ternary<std::plus<>, syclcompat::maximum, uint32_t>(
0x00010002, 0x00040002, 0x00080004, 0x00030000, true);
0x00010002, 0x00040002, 0x00080004, 0x00080004, true);
test_vectorized_ternary<std::plus<>, syclcompat::minimum, uint32_t>(
0x00010002, 0x00040002, 0x00080004, 0x00030000);
0x00010002, 0x00040002, 0x00080004, 0x00050004);
test_vectorized_ternary<std::plus<>, syclcompat::minimum, uint32_t>(
0x00010002, 0x00040002, 0x00080004, 0x00030000, true);
0x00010002, 0x00040002, 0x00080004, 0x00050004, true);
test_vectorized_ternary<syclcompat::maximum, syclcompat::maximum, uint32_t>(
0x00010002, 0x00040002, 0x00080004, 0x00030000);
0x00010002, 0x00040002, 0x00080004, 0x00080004);
test_vectorized_ternary<syclcompat::maximum, syclcompat::maximum, uint32_t>(
0x00010002, 0x00040002, 0x00080004, 0x00030000, true);
0x00010002, 0x00040002, 0x00080004, 0x00080004, true);
test_vectorized_ternary<syclcompat::minimum, syclcompat::minimum, uint32_t>(
0x00010002, 0x00040002, 0x00080004, 0x00030000);
0x00010002, 0x00040002, 0x00080004, 0x00010002);
test_vectorized_ternary<syclcompat::minimum, syclcompat::minimum, uint32_t>(
0x00010002, 0x00040002, 0x00080004, 0x00030000, true);
bool pred_hi, bool pred_lo;
0x00010002, 0x00040002, 0x00080004, 0x00010002, true);
bool pred_hi, pred_lo;
test_vectorized_with_pred<syclcompat::maximum, uint32_t>(
0x00010002, 0x00040002, 0x00030000, &pred_hi, &pred_lo);
0x00010002, 0x00040002, 0x00040002, &pred_hi, &pred_lo);
assert(pred_hi == true);
assert(pred_lo == true);
test_vectorized_with_pred<syclcompat::minimum, uint32_t>(
0x00010002, 0x00040002, 0x00030000, &pred_hi, &pred_lo);
0x00010002, 0x00040002, 0x00010002, &pred_hi, &pred_lo);
assert(pred_hi == true);
assert(pred_lo == true);

return 0;
}

0 comments on commit 2944ced

Please sign in to comment.