Skip to content

Commit

Permalink
Fix comment.
Browse files Browse the repository at this point in the history
  • Loading branch information
tangjj11 committed Oct 8, 2024
1 parent beb04ba commit 9884622
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 58 deletions.
37 changes: 25 additions & 12 deletions sycl/doc/syclcompat/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1780,17 +1780,11 @@ template <typename ValueT, typename ValueU>
inline typename std::enable_if_t<!std::is_floating_point_v<ValueT>, double>
pow(const ValueT a, const ValueU b);
template <typename ValueT>
inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
std::is_same_v<sycl::half, ValueT>,
ValueT>
relu(const ValueT a);
template <typename ValueT> inline ValueT relu(const ValueT a);
template <class ValueT>
inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
std::is_same_v<sycl::half, ValueT>,
sycl::vec<ValueT, 2>>
relu(const sycl::vec<ValueT, 2> a);
template <class ValueT, int NumElements>
inline sycl::vec<ValueT, NumElements>
relu(const sycl::vec<ValueT, NumElements> a);
template <class ValueT>
inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
Expand Down Expand Up @@ -1893,9 +1887,12 @@ inline dot_product_acc_t<T1, T2> dp4a(T1 a, T2 b,

`vectorized_binary` computes the `BinaryOperation` for two operands,
with each value treated as a vector type. `vectorized_unary` offers the same
interface for operations with a single operand.
interface for operations with a single operand. `vectorized_ternary` offers the
interface for three operands with two `BinaryOperation`.
The implemented `BinaryOperation`s are `abs_diff`, `add_sat`, `rhadd`, `hadd`,
`maximum`, `minimum`, and `sub_sat`.
And the `vectorized_with_pred` offers the `BinaryOperation` for two operands,
meanwihle provides the pred of high/low halfword operation.

```cpp
namespace syclcompat {
Expand All @@ -1910,7 +1907,19 @@ struct abs {

template <typename VecT, class BinaryOperation>
inline unsigned vectorized_binary(unsigned a, unsigned b,
const BinaryOperation binary_op);
const BinaryOperation binary_op,
bool need_relu = false);

template <typename VecT, typename BinaryOperation1, typename BinaryOperation2>
inline unsigned vectorized_ternary(unsigned a, unsigned b, unsigned c,
const BinaryOperation1 binary_op1,
const BinaryOperation2 binary_op2,
bool need_relu = false);

template <typename ValueT, typename BinaryOperation>
inline unsigned vectorized_with_pred(unsigned a, unsigned b,
const BinaryOperation binary_op,
bool *pred_hi, bool *pred_lo);

// A sycl::abs_diff wrapper functor.
struct abs_diff {
Expand All @@ -1936,11 +1945,15 @@ struct hadd {
struct maximum {
template <typename ValueT>
auto operator()(const ValueT x, const ValueT y) const;
template <typename ValueT>
auto operator()(const ValueT x, const ValueT y, bool *pred) const;
};
// A sycl::min wrapper functor.
struct minimum {
template <typename ValueT>
auto operator()(const ValueT x, const ValueT y) const;
template <typename ValueT>
auto operator()(const ValueT x, const ValueT y, bool *pred) const;
};
// A sycl::sub_sat wrapper functor.
struct sub_sat {
Expand Down
43 changes: 20 additions & 23 deletions sycl/include/syclcompat/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -794,21 +794,19 @@ pow(const ValueT a, const ValueU b) {
/// Performs relu saturation.
/// \param [in] a The input value
/// \returns the relu saturation result
template <typename ValueT>
inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
std::is_same_v<sycl::half, ValueT>,
ValueT>
relu(const ValueT a) {
if (!detail::isnan(a) && a < ValueT(0))
template <typename ValueT> inline ValueT relu(const ValueT a) {
if constexpr (std::is_floating_point_v<ValueT> ||
std::is_same_v<sycl::half, ValueT>)
if (!detail::isnan(a) && a < ValueT(0))
return ValueT(0);
if (a < ValueT(0))
return ValueT(0);
return a;
}
template <class ValueT, int NumElements>
inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
std::is_same_v<sycl::half, ValueT>,
sycl::vec<ValueT, NumElements>>
inline sycl::vec<ValueT, NumElements>
relu(const sycl::vec<ValueT, NumElements> a) {
sycl::vec<T, NumElements> ret;
sycl::vec<ValueT, NumElements> ret;
for (int i = 0; i < NumElements; ++i)
ret[i] = relu(a[i]);
return ret;
Expand Down Expand Up @@ -935,8 +933,8 @@ struct maximum {
auto operator()(const ValueT x, const ValueT y) const {
return sycl::max(x, y);
}
template <typename T>
auto operator()(const T x, const T y, bool *pred) const {
template <typename ValueT>
auto operator()(const ValueT x, const ValueT y, bool *pred) const {
return (x >= y) ? ((*pred = true), x) : ((*pred = false), y);
}
};
Expand All @@ -947,8 +945,8 @@ struct minimum {
auto operator()(const ValueT x, const ValueT y) const {
return sycl::min(x, y);
}
template <typename T>
auto operator()(const T x, const T y, bool *pred) const {
template <typename ValueT>
auto operator()(const ValueT x, const ValueT y, bool *pred) const {
return (x <= y) ? ((*pred = true), x) : ((*pred = false), y);
}
};
Expand Down Expand Up @@ -1029,13 +1027,12 @@ inline unsigned vectorized_ternary(unsigned a, unsigned b, unsigned c,
const auto v1 = sycl::vec<unsigned, 1>(a).as<VecT>();
const auto v2 = sycl::vec<unsigned, 1>(b).as<VecT>();
const auto v3 = sycl::vec<unsigned, 1>(c).as<VecT>();
auto temp =
auto v4 =
detail::vectorized_binary<VecT, BinaryOperation1>()(v1, v2, binary_op1);
temp =
detail::vectorized_binary<VecT, BinaryOperation2>()(temp, v3, binary_op2);
v4 = detail::vectorized_binary<VecT, BinaryOperation2>()(v4, v3, binary_op2);
if (need_relu)
temp = relu(temp);
return temp.template as<sycl::vec<unsigned, 1>>();
v4 = relu(v4);
return v4.template as<sycl::vec<unsigned, 1>>();
}

/// Compute vectorized binary operation value with pred for two values, with
Expand All @@ -1049,13 +1046,13 @@ inline unsigned vectorized_ternary(unsigned a, unsigned b, unsigned c,
/// \param [in] pred_hi The pred pointer that pass into high halfword operation
/// \param [in] pred_lo The pred pointer that pass into low halfword operation
/// \returns The vectorized binary operation value of the two values
template <typename T, typename BinaryOperation>
template <typename ValueT, typename BinaryOperation>
inline unsigned vectorized_with_pred(unsigned a, unsigned b,
const BinaryOperation binary_op,
bool *pred_hi, bool *pred_lo) {
auto v1 = sycl::vec<unsigned, 1>(a).as<sycl::vec<T, 2>>();
auto v2 = sycl::vec<unsigned, 1>(b).as<sycl::vec<T, 2>>();
sycl::vec<T, 2> ret;
auto v1 = sycl::vec<unsigned, 1>(a).as<sycl::vec<ValueT, 2>>();
auto v2 = sycl::vec<unsigned, 1>(b).as<sycl::vec<ValueT, 2>>();
sycl::vec<ValueT, 2> ret;
ret[0] = binary_op(v1[0], v2[0], pred_lo);
ret[1] = binary_op(v1[1], v2[1], pred_hi);
return ret.template as<sycl::vec<unsigned, 1>>();
Expand Down
45 changes: 42 additions & 3 deletions sycl/test-e2e/syclcompat/math/math_fixt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ class BinaryOpTestLauncher : OpTestLauncher {
ValueT *op1_;
ValueU *op2_;
ResultT res_h_, *res_;
bool *res_hi_;
bool *res_lo_;

public:
BinaryOpTestLauncher(const syclcompat::dim3 &grid,
Expand All @@ -126,6 +128,8 @@ class BinaryOpTestLauncher : OpTestLauncher {
op1_ = syclcompat::malloc<ValueT>(data_size);
op2_ = syclcompat::malloc<ValueU>(data_size);
res_ = syclcompat::malloc<ResultT>(data_size);
res_hi_ = syclcompat::malloc<bool>(1);
res_lo_ = syclcompat::malloc<bool>(1);
};

virtual ~BinaryOpTestLauncher() {
Expand All @@ -134,6 +138,8 @@ class BinaryOpTestLauncher : OpTestLauncher {
syclcompat::free(op1_);
syclcompat::free(op2_);
syclcompat::free(res_);
syclcompat::free(res_hi_);
syclcompat::free(res_lo_);
}

template <auto Kernel>
Expand All @@ -148,6 +154,37 @@ class BinaryOpTestLauncher : OpTestLauncher {

CHECK(ResultT, res_h_, expected);
};
template <auto Kernel>
void launch_test(ValueT op1, ValueU op2, ResultT expected, bool need_relu) {
if (skip_)
return;
syclcompat::memcpy<ValueT>(op1_, &op1, data_size_);
syclcompat::memcpy<ValueU>(op2_, &op2, data_size_);
syclcompat::launch<Kernel>(grid_, threads_, op1_, op2_, res_, need_relu);
syclcompat::wait();
syclcompat::memcpy<ResultT>(&res_h_, res_, data_size_);

CHECK(ResultT, res_h_, expected);
};
template <auto Kernel>
void launch_test(ValueT op1, ValueU op2, ResultT expected, bool expected_hi,
bool expected_lo) {
if (skip_)
return;
syclcompat::memcpy<ValueT>(op1_, &op1, data_size_);
syclcompat::memcpy<ValueU>(op2_, &op2, data_size_);
syclcompat::launch<Kernel>(grid_, threads_, op1_, op2_, res_, res_hi_,
res_lo_);
syclcompat::wait();
syclcompat::memcpy<ResultT>(&res_h_, res_, data_size_);
bool res_hi_h_, res_lo_h_;
syclcompat::memcpy<bool>(&res_hi_h_, res_hi_, 1);
syclcompat::memcpy<bool>(&res_lo_h_, res_lo_, 1);

CHECK(ResultT, res_h_, expected);
assert(res_hi_h_ == expected_hi);
assert(res_lo_h_ == expected_lo);
};
};

template <typename ValueT, typename ResultT = ValueT>
Expand Down Expand Up @@ -195,7 +232,7 @@ class TernaryOpTestLauncher : OpTestLauncher {
protected:
ValueT *op1_;
ValueU *op2_;
ValueV *op2_;
ValueV *op3_;
ResultT res_h_, *res_;

public:
Expand Down Expand Up @@ -223,13 +260,15 @@ class TernaryOpTestLauncher : OpTestLauncher {
}

template <auto Kernel>
void launch_test(ValueT op1, ValueU op2, ValueU op3, ResultT expected) {
void launch_test(ValueT op1, ValueU op2, ValueU op3, ResultT expected,
bool need_relu = false) {
if (skip_)
return;
syclcompat::memcpy<ValueT>(op1_, &op1, data_size_);
syclcompat::memcpy<ValueU>(op2_, &op2, data_size_);
syclcompat::memcpy<ValueU>(op3_, &op3, data_size_);
syclcompat::launch<Kernel>(grid_, threads_, op1_, op2_, op3_, res_);
syclcompat::launch<Kernel>(grid_, threads_, op1_, op2_, op3_, res_,
need_relu);
syclcompat::wait();
syclcompat::memcpy<ResultT>(&res_h_, res_, data_size_);

Expand Down
56 changes: 36 additions & 20 deletions sycl/test-e2e/syclcompat/math/math_vectorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,24 @@
#include "math_fixt.hpp"

template <typename BinaryOp, typename ValueT>
void vectorized_binary_kernel(ValueT *a, ValueT *b, unsigned *r) {
void vectorized_binary_kernel(ValueT *a, ValueT *b, unsigned *r,
bool need_relu) {
unsigned ua = static_cast<unsigned>(*a);
unsigned ub = static_cast<unsigned>(*b);
*r = syclcompat::vectorized_binary<sycl::short2>(ua, ub, BinaryOp());
*r = syclcompat::vectorized_binary<sycl::short2>(ua, ub, BinaryOp(),
need_relu);
}

template <typename BinaryOp, typename ValueT>
void test_vectorized_binary(ValueT op1, ValueT op2, unsigned expected) {
void test_vectorized_binary(ValueT op1, ValueT op2, unsigned expected,
bool need_relu = false) {
std::cout << __PRETTY_FUNCTION__ << std::endl;
constexpr syclcompat::dim3 grid{1};
constexpr syclcompat::dim3 threads{1};

BinaryOpTestLauncher<ValueT, ValueT, unsigned>(grid, threads)
.template launch_test<vectorized_binary_kernel<BinaryOp, ValueT>>(
op1, op2, expected);
op1, op2, expected, need_relu);
}

template <typename UnaryOp, typename ValueT>
Expand Down Expand Up @@ -103,7 +106,7 @@ void test_vectorized_ternary(ValueT op1, ValueT op2, ValueT op3,

TernaryOpTestLauncher<ValueT, ValueT, unsigned>(grid, threads)
.template launch_test<
vectorized_binary_kernel<BinaryOp1, BinaryOp2, ValueT>>(
vectorized_ternary_kernel<BinaryOp1, BinaryOp2, ValueT>>(
op1, op2, op3, expected, need_relu);
}

Expand All @@ -117,16 +120,16 @@ void vectorized_with_pred_kernel(ValueT *a, ValueT *b, unsigned *r,
pred_lo);
}

template <typename ValueT>
template <typename BinaryOp, typename ValueT>
void test_vectorized_with_pred(ValueT op1, ValueT op2, unsigned expected,
bool *pred_hi, bool *pred_lo) {
bool expected_hi, bool expected_lo) {
std::cout << __PRETTY_FUNCTION__ << std::endl;
constexpr syclcompat::dim3 grid{1};
constexpr syclcompat::dim3 threads{1};

BinaryOpTestLauncher<ValueT, ValueT, unsigned>(grid, threads)
.template launch_test<vectorized_with_pred_kernel<ValueT>>(
op1, op2, expected, pred_hi, pred_lo);
.template launch_test<vectorized_with_pred_kernel<BinaryOp, ValueT>>(
op1, op2, expected, expected_hi, expected_lo);
}

int main() {
Expand All @@ -144,29 +147,42 @@ int main() {
0x00000000);
test_vectorized_binary<syclcompat::sub_sat, uint32_t>(0xFFFB0005, 0x00030008,
0xFFF8FFFD);
test_vectorized_binary<syclcompat::abs_diff, uint32_t>(0x00010002, 0x00040002,
0x00030000, true);
test_vectorized_binary<syclcompat::add_sat, uint32_t>(0x00020002, 0xFFFDFFFF,
0x00000001, true);
test_vectorized_binary<syclcompat::rhadd, uint32_t>(0x00010008, 0x00020001,
0x00020005, true);
test_vectorized_binary<syclcompat::hadd, uint32_t>(0x00010003, 0x00020005,
0x00010004, true);
test_vectorized_binary<syclcompat::maximum, uint32_t>(0x0FFF0000, 0x00000FFF,
0x0FFF0FFF, true);
test_vectorized_binary<syclcompat::minimum, uint32_t>(0x0FFF0000, 0x00000FFF,
0x00000000, true);
test_vectorized_binary<syclcompat::sub_sat, uint32_t>(0xFFFB0005, 0x00030008,
0x00000000, true);
test_vectorized_unary<syclcompat::abs, uint32_t>(0xFFFBFFFD, 0x00050003);
test_vectorized_sum_abs_diff<uint32_t>(0x00010002, 0x00040002, 0x00000003);
test_vectorized_ternary<std::plus<>, syclcompat::maximum, uint32_t>(
0x00010002, 0x00040002, 0x00080004, 0x00030000);
0x00010002, 0x00040002, 0x00080004, 0x00080004);
test_vectorized_ternary<std::plus<>, syclcompat::maximum, uint32_t>(
0x00010002, 0x00040002, 0x00080004, 0x00030000, true);
0x00010002, 0x00040002, 0x00080004, 0x00080004, true);
test_vectorized_ternary<std::plus<>, syclcompat::minimum, uint32_t>(
0x00010002, 0x00040002, 0x00080004, 0x00030000);
0x00010002, 0x00040002, 0x00080004, 0x00050004);
test_vectorized_ternary<std::plus<>, syclcompat::minimum, uint32_t>(
0x00010002, 0x00040002, 0x00080004, 0x00030000, true);
0x00010002, 0x00040002, 0x00080004, 0x00050004, true);
test_vectorized_ternary<syclcompat::maximum, syclcompat::maximum, uint32_t>(
0x00010002, 0x00040002, 0x00080004, 0x00030000);
0x00010002, 0x00040002, 0x00080004, 0x00080004);
test_vectorized_ternary<syclcompat::maximum, syclcompat::maximum, uint32_t>(
0x00010002, 0x00040002, 0x00080004, 0x00030000, true);
0x00010002, 0x00040002, 0x00080004, 0x00080004, true);
test_vectorized_ternary<syclcompat::minimum, syclcompat::minimum, uint32_t>(
0x00010002, 0x00040002, 0x00080004, 0x00030000);
0x00010002, 0x00040002, 0x00080004, 0x00010002);
test_vectorized_ternary<syclcompat::minimum, syclcompat::minimum, uint32_t>(
0x00010002, 0x00040002, 0x00080004, 0x00030000, true);
bool pred_hi, bool pred_lo;
0x00010002, 0x00040002, 0x00080004, 0x00010002, true);
test_vectorized_with_pred<syclcompat::maximum, uint32_t>(
0x00010002, 0x00040002, 0x00030000, &pred_hi, &pred_lo);
0x00010002, 0x00040002, 0x00040002, true, true);
test_vectorized_with_pred<syclcompat::minimum, uint32_t>(
0x00010002, 0x00040002, 0x00030000, &pred_hi, &pred_lo);
0x00010002, 0x00040002, 0x00010002, true, true);

return 0;
}

0 comments on commit 9884622

Please sign in to comment.