diff --git a/sycl/doc/syclcompat/README.md b/sycl/doc/syclcompat/README.md index 6dd8708afeb62..38d5946411209 100644 --- a/sycl/doc/syclcompat/README.md +++ b/sycl/doc/syclcompat/README.md @@ -1780,17 +1780,11 @@ template inline typename std::enable_if_t, double> pow(const ValueT a, const ValueU b); -template -inline std::enable_if_t || - std::is_same_v, - ValueT> -relu(const ValueT a); +template inline ValueT relu(const ValueT a); -template -inline std::enable_if_t || - std::is_same_v, - sycl::vec> -relu(const sycl::vec a); +template +inline sycl::vec +relu(const sycl::vec a); template inline std::enable_if_t || @@ -1893,9 +1887,12 @@ inline dot_product_acc_t dp4a(T1 a, T2 b, `vectorized_binary` computes the `BinaryOperation` for two operands, with each value treated as a vector type. `vectorized_unary` offers the same -interface for operations with a single operand. +interface for operations with a single operand. `vectorized_ternary` offers the +interface for three operands with two `BinaryOperation`s. The implemented `BinaryOperation`s are `abs_diff`, `add_sat`, `rhadd`, `hadd`, `maximum`, `minimum`, and `sub_sat`. +`vectorized_with_pred` offers the `BinaryOperation` for two operands, while +also providing the predicates of the high/low halfword operations. ```cpp namespace syclcompat { @@ -1910,7 +1907,19 @@ struct abs { template inline unsigned vectorized_binary(unsigned a, unsigned b, - const BinaryOperation binary_op); + const BinaryOperation binary_op, + bool need_relu = false); + +template +inline unsigned vectorized_ternary(unsigned a, unsigned b, unsigned c, + const BinaryOperation1 binary_op1, + const BinaryOperation2 binary_op2, + bool need_relu = false); + +template +inline unsigned vectorized_with_pred(unsigned a, unsigned b, + const BinaryOperation binary_op, + bool *pred_hi, bool *pred_lo); // A sycl::abs_diff wrapper functor. 
struct abs_diff { @@ -1936,11 +1945,15 @@ struct hadd { struct maximum { template auto operator()(const ValueT x, const ValueT y) const; + template + auto operator()(const ValueT x, const ValueT y, bool *pred) const; }; // A sycl::min wrapper functor. struct minimum { template auto operator()(const ValueT x, const ValueT y) const; + template + auto operator()(const ValueT x, const ValueT y, bool *pred) const; }; // A sycl::sub_sat wrapper functor. struct sub_sat { diff --git a/sycl/include/syclcompat/math.hpp b/sycl/include/syclcompat/math.hpp index 8b06ff376b67c..808327f4873ac 100644 --- a/sycl/include/syclcompat/math.hpp +++ b/sycl/include/syclcompat/math.hpp @@ -794,21 +794,19 @@ pow(const ValueT a, const ValueU b) { /// Performs relu saturation. /// \param [in] a The input value /// \returns the relu saturation result -template -inline std::enable_if_t || - std::is_same_v, - ValueT> -relu(const ValueT a) { - if (!detail::isnan(a) && a < ValueT(0)) +template inline ValueT relu(const ValueT a) { + if constexpr (std::is_floating_point_v || + std::is_same_v) + if (!detail::isnan(a) && a < ValueT(0)) + return ValueT(0); + if (a < ValueT(0)) return ValueT(0); return a; } template -inline std::enable_if_t || - std::is_same_v, - sycl::vec> +inline sycl::vec relu(const sycl::vec a) { - sycl::vec ret; + sycl::vec ret; for (int i = 0; i < NumElements; ++i) ret[i] = relu(a[i]); return ret; @@ -935,8 +933,8 @@ struct maximum { auto operator()(const ValueT x, const ValueT y) const { return sycl::max(x, y); } - template - auto operator()(const T x, const T y, bool *pred) const { + template + auto operator()(const ValueT x, const ValueT y, bool *pred) const { return (x >= y) ? 
((*pred = true), x) : ((*pred = false), y); } }; @@ -947,8 +945,8 @@ struct minimum { auto operator()(const ValueT x, const ValueT y) const { return sycl::min(x, y); } - template - auto operator()(const T x, const T y, bool *pred) const { + template + auto operator()(const ValueT x, const ValueT y, bool *pred) const { return (x <= y) ? ((*pred = true), x) : ((*pred = false), y); } }; @@ -1029,13 +1027,12 @@ inline unsigned vectorized_ternary(unsigned a, unsigned b, unsigned c, const auto v1 = sycl::vec(a).as(); const auto v2 = sycl::vec(b).as(); const auto v3 = sycl::vec(c).as(); - auto temp = + auto v4 = detail::vectorized_binary()(v1, v2, binary_op1); - temp = - detail::vectorized_binary()(temp, v3, binary_op2); + v4 = detail::vectorized_binary()(v4, v3, binary_op2); if (need_relu) - temp = relu(temp); - return temp.template as>(); + v4 = relu(v4); + return v4.template as>(); } /// Compute vectorized binary operation value with pred for two values, with @@ -1049,13 +1046,13 @@ inline unsigned vectorized_ternary(unsigned a, unsigned b, unsigned c, /// \param [in] pred_hi The pred pointer that pass into high halfword operation /// \param [in] pred_lo The pred pointer that pass into low halfword operation /// \returns The vectorized binary operation value of the two values -template +template inline unsigned vectorized_with_pred(unsigned a, unsigned b, const BinaryOperation binary_op, bool *pred_hi, bool *pred_lo) { - auto v1 = sycl::vec(a).as>(); - auto v2 = sycl::vec(b).as>(); - sycl::vec ret; + auto v1 = sycl::vec(a).as>(); + auto v2 = sycl::vec(b).as>(); + sycl::vec ret; ret[0] = binary_op(v1[0], v2[0], pred_lo); ret[1] = binary_op(v1[1], v2[1], pred_hi); return ret.template as>(); diff --git a/sycl/test-e2e/syclcompat/math/math_fixt.hpp b/sycl/test-e2e/syclcompat/math/math_fixt.hpp index 7f6d9e9289230..fc1030c993e9d 100644 --- a/sycl/test-e2e/syclcompat/math/math_fixt.hpp +++ b/sycl/test-e2e/syclcompat/math/math_fixt.hpp @@ -113,6 +113,8 @@ class 
BinaryOpTestLauncher : OpTestLauncher { ValueT *op1_; ValueU *op2_; ResultT res_h_, *res_; + bool *res_hi_; + bool *res_lo_; public: BinaryOpTestLauncher(const syclcompat::dim3 &grid, @@ -126,6 +128,8 @@ class BinaryOpTestLauncher : OpTestLauncher { op1_ = syclcompat::malloc(data_size); op2_ = syclcompat::malloc(data_size); res_ = syclcompat::malloc(data_size); + res_hi_ = syclcompat::malloc(1); + res_lo_ = syclcompat::malloc(1); }; virtual ~BinaryOpTestLauncher() { @@ -134,6 +138,8 @@ class BinaryOpTestLauncher : OpTestLauncher { syclcompat::free(op1_); syclcompat::free(op2_); syclcompat::free(res_); + syclcompat::free(res_hi_); + syclcompat::free(res_lo_); } template @@ -148,6 +154,37 @@ class BinaryOpTestLauncher : OpTestLauncher { CHECK(ResultT, res_h_, expected); }; + template + void launch_test(ValueT op1, ValueU op2, ResultT expected, bool need_relu) { + if (skip_) + return; + syclcompat::memcpy(op1_, &op1, data_size_); + syclcompat::memcpy(op2_, &op2, data_size_); + syclcompat::launch(grid_, threads_, op1_, op2_, res_, need_relu); + syclcompat::wait(); + syclcompat::memcpy(&res_h_, res_, data_size_); + + CHECK(ResultT, res_h_, expected); + }; + template + void launch_test(ValueT op1, ValueU op2, ResultT expected, bool expected_hi, + bool expected_lo) { + if (skip_) + return; + syclcompat::memcpy(op1_, &op1, data_size_); + syclcompat::memcpy(op2_, &op2, data_size_); + syclcompat::launch(grid_, threads_, op1_, op2_, res_, res_hi_, + res_lo_); + syclcompat::wait(); + syclcompat::memcpy(&res_h_, res_, data_size_); + bool res_hi_h_, res_lo_h_; + syclcompat::memcpy(&res_hi_h_, res_hi_, 1); + syclcompat::memcpy(&res_lo_h_, res_lo_, 1); + + CHECK(ResultT, res_h_, expected); + assert(res_hi_h_ == expected_hi); + assert(res_lo_h_ == expected_lo); + }; }; template @@ -195,7 +232,7 @@ class TernaryOpTestLauncher : OpTestLauncher { protected: ValueT *op1_; ValueU *op2_; - ValueV *op2_; + ValueV *op3_; ResultT res_h_, *res_; public: @@ -223,13 +260,15 @@ class 
TernaryOpTestLauncher : OpTestLauncher { } template - void launch_test(ValueT op1, ValueU op2, ValueU op3, ResultT expected) { + void launch_test(ValueT op1, ValueU op2, ValueU op3, ResultT expected, + bool need_relu = false) { if (skip_) return; syclcompat::memcpy(op1_, &op1, data_size_); syclcompat::memcpy(op2_, &op2, data_size_); syclcompat::memcpy(op3_, &op3, data_size_); - syclcompat::launch(grid_, threads_, op1_, op2_, op3_, res_); + syclcompat::launch(grid_, threads_, op1_, op2_, op3_, res_, + need_relu); syclcompat::wait(); syclcompat::memcpy(&res_h_, res_, data_size_); diff --git a/sycl/test-e2e/syclcompat/math/math_vectorized.cpp b/sycl/test-e2e/syclcompat/math/math_vectorized.cpp index 16a4bf92a70fc..315dd27369969 100644 --- a/sycl/test-e2e/syclcompat/math/math_vectorized.cpp +++ b/sycl/test-e2e/syclcompat/math/math_vectorized.cpp @@ -31,21 +31,24 @@ #include "math_fixt.hpp" template -void vectorized_binary_kernel(ValueT *a, ValueT *b, unsigned *r) { +void vectorized_binary_kernel(ValueT *a, ValueT *b, unsigned *r, + bool need_relu) { unsigned ua = static_cast(*a); unsigned ub = static_cast(*b); - *r = syclcompat::vectorized_binary(ua, ub, BinaryOp()); + *r = syclcompat::vectorized_binary(ua, ub, BinaryOp(), + need_relu); } template -void test_vectorized_binary(ValueT op1, ValueT op2, unsigned expected) { +void test_vectorized_binary(ValueT op1, ValueT op2, unsigned expected, + bool need_relu = false) { std::cout << __PRETTY_FUNCTION__ << std::endl; constexpr syclcompat::dim3 grid{1}; constexpr syclcompat::dim3 threads{1}; BinaryOpTestLauncher(grid, threads) .template launch_test>( - op1, op2, expected); + op1, op2, expected, need_relu); } template @@ -103,7 +106,7 @@ void test_vectorized_ternary(ValueT op1, ValueT op2, ValueT op3, TernaryOpTestLauncher(grid, threads) .template launch_test< - vectorized_binary_kernel>( + vectorized_ternary_kernel>( op1, op2, op3, expected, need_relu); } @@ -117,16 +120,16 @@ void vectorized_with_pred_kernel(ValueT *a, 
ValueT *b, unsigned *r, pred_lo); } -template +template void test_vectorized_with_pred(ValueT op1, ValueT op2, unsigned expected, - bool *pred_hi, bool *pred_lo) { + bool expected_hi, bool expected_lo) { std::cout << __PRETTY_FUNCTION__ << std::endl; constexpr syclcompat::dim3 grid{1}; constexpr syclcompat::dim3 threads{1}; BinaryOpTestLauncher(grid, threads) - .template launch_test>( - op1, op2, expected, pred_hi, pred_lo); + .template launch_test>( + op1, op2, expected, expected_hi, expected_lo); } int main() { @@ -144,29 +147,42 @@ int main() { 0x00000000); test_vectorized_binary(0xFFFB0005, 0x00030008, 0xFFF8FFFD); + test_vectorized_binary(0x00010002, 0x00040002, + 0x00030000, true); + test_vectorized_binary(0x00020002, 0xFFFDFFFF, + 0x00000001, true); + test_vectorized_binary(0x00010008, 0x00020001, + 0x00020005, true); + test_vectorized_binary(0x00010003, 0x00020005, + 0x00010004, true); + test_vectorized_binary(0x0FFF0000, 0x00000FFF, + 0x0FFF0FFF, true); + test_vectorized_binary(0x0FFF0000, 0x00000FFF, + 0x00000000, true); + test_vectorized_binary(0xFFFB0005, 0x00030008, + 0x00000000, true); test_vectorized_unary(0xFFFBFFFD, 0x00050003); test_vectorized_sum_abs_diff(0x00010002, 0x00040002, 0x00000003); test_vectorized_ternary, syclcompat::maximum, uint32_t>( - 0x00010002, 0x00040002, 0x00080004, 0x00030000); + 0x00010002, 0x00040002, 0x00080004, 0x00080004); test_vectorized_ternary, syclcompat::maximum, uint32_t>( - 0x00010002, 0x00040002, 0x00080004, 0x00030000, true); + 0x00010002, 0x00040002, 0x00080004, 0x00080004, true); test_vectorized_ternary, syclcompat::minimum, uint32_t>( - 0x00010002, 0x00040002, 0x00080004, 0x00030000); + 0x00010002, 0x00040002, 0x00080004, 0x00050004); test_vectorized_ternary, syclcompat::minimum, uint32_t>( - 0x00010002, 0x00040002, 0x00080004, 0x00030000, true); + 0x00010002, 0x00040002, 0x00080004, 0x00050004, true); test_vectorized_ternary( - 0x00010002, 0x00040002, 0x00080004, 0x00030000); + 0x00010002, 0x00040002, 
0x00080004, 0x00080004); test_vectorized_ternary( - 0x00010002, 0x00040002, 0x00080004, 0x00030000, true); + 0x00010002, 0x00040002, 0x00080004, 0x00080004, true); test_vectorized_ternary( - 0x00010002, 0x00040002, 0x00080004, 0x00030000); + 0x00010002, 0x00040002, 0x00080004, 0x00010002); test_vectorized_ternary( - 0x00010002, 0x00040002, 0x00080004, 0x00030000, true); - bool pred_hi, bool pred_lo; + 0x00010002, 0x00040002, 0x00080004, 0x00010002, true); test_vectorized_with_pred( - 0x00010002, 0x00040002, 0x00030000, &pred_hi, &pred_lo); + 0x00010002, 0x00040002, 0x00040002, true, true); test_vectorized_with_pred( - 0x00010002, 0x00040002, 0x00030000, &pred_hi, &pred_lo); + 0x00010002, 0x00040002, 0x00010002, true, true); return 0; }