Skip to content

Commit

Permalink
[SYCL][COMPAT] Add vectorized_ternary and vectorized_with_pred.
Browse files Browse the repository at this point in the history
Add vectorized_ternary and vectorized_with_pred.
Update relu and vectorized_binary.

Signed-off-by: Tang, Jiajun <jiajun.tang@intel.com>
  • Loading branch information
tangjj11 committed Sep 30, 2024
1 parent 8fc9aa5 commit beb04ba
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 5 deletions.
79 changes: 74 additions & 5 deletions sycl/include/syclcompat/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -803,12 +803,15 @@ relu(const ValueT a) {
return ValueT(0);
return a;
}
template <class ValueT>
template <class ValueT, int NumElements>
inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
std::is_same_v<sycl::half, ValueT>,
sycl::vec<ValueT, 2>>
relu(const sycl::vec<ValueT, 2> a) {
return {relu(a[0]), relu(a[1])};
sycl::vec<ValueT, NumElements>>
relu(const sycl::vec<ValueT, NumElements> a) {
sycl::vec<T, NumElements> ret;
for (int i = 0; i < NumElements; ++i)
ret[i] = relu(a[i]);
return ret;
}
template <class ValueT>
inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
Expand Down Expand Up @@ -932,6 +935,10 @@ struct maximum {
auto operator()(const ValueT x, const ValueT y) const {
return sycl::max(x, y);
}
/// Pick the larger of two values and record which operand was chosen.
///
/// \param [in] x The first value
/// \param [in] y The second value
/// \param [out] pred Written with true when x >= y (x is returned), and with
/// false otherwise (y is returned)
/// \returns x when x >= y, otherwise y
template <typename T>
auto operator()(const T x, const T y, bool *pred) const {
  if (x >= y) {
    *pred = true;
    return x;
  }
  *pred = false;
  return y;
}
};

/// A sycl::min wrapper functor.
Expand All @@ -940,6 +947,10 @@ struct minimum {
auto operator()(const ValueT x, const ValueT y) const {
return sycl::min(x, y);
}
/// Pick the smaller of two values and record which operand was chosen.
///
/// \param [in] x The first value
/// \param [in] y The second value
/// \param [out] pred Written with true when x <= y (x is returned), and with
/// false otherwise (y is returned)
/// \returns x when x <= y, otherwise y
template <typename T>
auto operator()(const T x, const T y, bool *pred) const {
  if (x <= y) {
    *pred = true;
    return x;
  }
  *pred = false;
  return y;
}
};

/// A sycl::sub_sat wrapper functor.
Expand Down Expand Up @@ -979,19 +990,77 @@ struct average {
/// \tparam [in] BinaryOperation The binary operation class
/// \param [in] a The first value
/// \param [in] b The second value
/// \param [in] binary_op The operation to do with the two values
/// \param [in] need_relu Whether the result need relu saturation
/// \returns The vectorized binary operation value of the two values
template <typename VecT, class BinaryOperation>
inline unsigned vectorized_binary(unsigned a, unsigned b,
const BinaryOperation binary_op) {
const BinaryOperation binary_op,
bool need_relu = false) {
sycl::vec<unsigned, 1> v0{a}, v1{b};
auto v2 = v0.as<VecT>();
auto v3 = v1.as<VecT>();
auto v4 =
detail::vectorized_binary<VecT, BinaryOperation>()(v2, v3, binary_op);
if (need_relu)
v4 = relu(v4);
v0 = v4.template as<sycl::vec<unsigned, 1>>();
return v0;
}

/// Compute two vectorized binary operation value with pred for three values,
/// with each value treated as a 2 \p T type elements vector type.
///
/// \tparam [in] VecT The type of the vector
/// \tparam [in] BinaryOperation1 The first binary operation class
/// \tparam [in] BinaryOperation2 The second binary operation class
/// \param [in] a The first value
/// \param [in] b The second value
/// \param [in] c The third value
/// \param [in] binary_op1 The first operation to do with the first two values
/// \param [in] binary_op2 The second operation to do with the third values
/// \param [in] need_relu Whether the result need relu saturation
/// \returns The two vectorized binary operation value of the three values
template <typename VecT, typename BinaryOperation1, typename BinaryOperation2>
inline unsigned vectorized_ternary(unsigned a, unsigned b, unsigned c,
const BinaryOperation1 binary_op1,
const BinaryOperation2 binary_op2,
bool need_relu = false) {
const auto v1 = sycl::vec<unsigned, 1>(a).as<VecT>();
const auto v2 = sycl::vec<unsigned, 1>(b).as<VecT>();
const auto v3 = sycl::vec<unsigned, 1>(c).as<VecT>();
auto temp =
detail::vectorized_binary<VecT, BinaryOperation1>()(v1, v2, binary_op1);
temp =
detail::vectorized_binary<VecT, BinaryOperation2>()(temp, v3, binary_op2);
if (need_relu)
temp = relu(temp);
return temp.template as<sycl::vec<unsigned, 1>>();
}

/// Compute vectorized binary operation value with pred for two values, with
/// each value treated as a 2 \p T type elements vector type.
///
/// \tparam [in] T The type of elements type of the vector
/// \tparam [in] BinaryOperation The binary operation class
/// \param [in] a The first value
/// \param [in] b The second value
/// \param [in] binary_op The operation with pred to do with the two values
/// \param [in] pred_hi The pred pointer that pass into high halfword operation
/// \param [in] pred_lo The pred pointer that pass into low halfword operation
/// \returns The vectorized binary operation value of the two values
template <typename T, typename BinaryOperation>
inline unsigned vectorized_with_pred(unsigned a, unsigned b,
const BinaryOperation binary_op,
bool *pred_hi, bool *pred_lo) {
auto v1 = sycl::vec<unsigned, 1>(a).as<sycl::vec<T, 2>>();
auto v2 = sycl::vec<unsigned, 1>(b).as<sycl::vec<T, 2>>();
sycl::vec<T, 2> ret;
ret[0] = binary_op(v1[0], v2[0], pred_lo);
ret[1] = binary_op(v1[1], v2[1], pred_hi);
return ret.template as<sycl::vec<unsigned, 1>>();
}

template <typename T1, typename T2>
using dot_product_acc_t =
std::conditional_t<std::is_unsigned_v<T1> && std::is_unsigned_v<T2>,
Expand Down
49 changes: 49 additions & 0 deletions sycl/test-e2e/syclcompat/math/math_fixt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,52 @@ class UnaryOpTestLauncher : OpTestLauncher {
CHECK(ResultT, res_h_, expected);
}
};

// Templated ResultT to support both arithmetic and boolean operators
template <typename ValueT, typename ValueU, typename ValueV,
typename ResultT = std::common_type_t<ValueT, ValueU, ValueV>>
class TernaryOpTestLauncher : OpTestLauncher {
protected:
ValueT *op1_;
ValueU *op2_;
ValueV *op2_;
ResultT res_h_, *res_;

public:
TernaryOpTestLauncher(const syclcompat::dim3 &grid,
const syclcompat::dim3 &threads,
const size_t data_size = 1)
: OpTestLauncher{
grid, threads, data_size,
should_skip<ValueT>()(syclcompat::get_current_device())} {
if (skip_)
return;
op1_ = syclcompat::malloc<ValueT>(data_size);
op2_ = syclcompat::malloc<ValueU>(data_size);
op3_ = syclcompat::malloc<ValueU>(data_size);
res_ = syclcompat::malloc<ResultT>(data_size);
};

virtual ~TernaryOpTestLauncher() {
if (skip_)
return;
syclcompat::free(op1_);
syclcompat::free(op2_);
syclcompat::free(op3_);
syclcompat::free(res_);
}

template <auto Kernel>
void launch_test(ValueT op1, ValueU op2, ValueU op3, ResultT expected) {
if (skip_)
return;
syclcompat::memcpy<ValueT>(op1_, &op1, data_size_);
syclcompat::memcpy<ValueU>(op2_, &op2, data_size_);
syclcompat::memcpy<ValueU>(op3_, &op3, data_size_);
syclcompat::launch<Kernel>(grid_, threads_, op1_, op2_, op3_, res_);
syclcompat::wait();
syclcompat::memcpy<ResultT>(&res_h_, res_, data_size_);

CHECK(ResultT, res_h_, expected);
};
};
66 changes: 66 additions & 0 deletions sycl/test-e2e/syclcompat/math/math_vectorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,51 @@ void test_vectorized_sum_abs_diff(ValueT op1, ValueT op2, unsigned expected) {
expected);
}

// Kernel wrapper: reads the three operands, converts each to unsigned, and
// stores the result of syclcompat::vectorized_ternary (operating on
// sycl::short2 elements with default-constructed BinaryOp1 / BinaryOp2) in *r.
template <typename BinaryOp1, typename BinaryOp2, typename ValueT>
void vectorized_ternary_kernel(ValueT *a, ValueT *b, ValueT *c, unsigned *r,
                               bool need_relu) {
  const auto packed_a = static_cast<unsigned>(*a);
  const auto packed_b = static_cast<unsigned>(*b);
  const auto packed_c = static_cast<unsigned>(*c);

  const unsigned result = syclcompat::vectorized_ternary<sycl::short2>(
      packed_a, packed_b, packed_c, BinaryOp1(), BinaryOp2(), need_relu);
  *r = result;
}

// Runs vectorized_ternary_kernel<BinaryOp1, BinaryOp2, ValueT> on a single
// work-item and checks the packed result against `expected`.
//
// op1 / op2 / op3: the three packed operands.
// expected: the packed value the kernel is expected to produce.
// need_relu: forwarded to the kernel as its trailing argument.
template <typename BinaryOp1, typename BinaryOp2, typename ValueT>
void test_vectorized_ternary(ValueT op1, ValueT op2, ValueT op3,
                             unsigned expected, bool need_relu = false) {
  std::cout << __PRETTY_FUNCTION__ << std::endl;
  constexpr syclcompat::dim3 grid{1};
  constexpr syclcompat::dim3 threads{1};

  // All three operands are ValueT and the kernel writes an unsigned result,
  // so every launcher type parameter is spelled out explicitly.
  // NOTE(review): launch_test must accept and forward the trailing need_relu
  // to the kernel - confirm the launcher signature.
  TernaryOpTestLauncher<ValueT, ValueT, ValueT, unsigned>(grid, threads)
      .template launch_test<
          vectorized_ternary_kernel<BinaryOp1, BinaryOp2, ValueT>>(
          op1, op2, op3, expected, need_relu);
}

// Kernel wrapper: reads both operands, converts each to unsigned, and stores
// the result of syclcompat::vectorized_with_pred<short> (using a
// default-constructed BinaryOp) in *r. pred_hi / pred_lo are passed through
// unchanged.
template <typename BinaryOp, typename ValueT>
void vectorized_with_pred_kernel(ValueT *a, ValueT *b, unsigned *r,
                                 bool *pred_hi, bool *pred_lo) {
  const auto packed_a = static_cast<unsigned>(*a);
  const auto packed_b = static_cast<unsigned>(*b);

  const unsigned result = syclcompat::vectorized_with_pred<short>(
      packed_a, packed_b, BinaryOp(), pred_hi, pred_lo);
  *r = result;
}

// Runs vectorized_with_pred_kernel<BinaryOp, ValueT> on a single work-item
// and checks the packed result against `expected`.
//
// BinaryOp: the predicate-reporting operation, named explicitly by callers as
//           the first template argument.
// op1 / op2: the two packed operands.
// expected: the packed value the kernel is expected to produce.
// pred_hi / pred_lo: forwarded to the kernel as its trailing arguments.
//
// NOTE(review): the pred pointers are dereferenced inside the kernel, so they
// presumably need to be device-accessible, and launch_test must accept the two
// trailing pointers - confirm against BinaryOpTestLauncher and the callers.
template <typename BinaryOp, typename ValueT>
void test_vectorized_with_pred(ValueT op1, ValueT op2, unsigned expected,
                               bool *pred_hi, bool *pred_lo) {
  std::cout << __PRETTY_FUNCTION__ << std::endl;
  constexpr syclcompat::dim3 grid{1};
  constexpr syclcompat::dim3 threads{1};

  BinaryOpTestLauncher<ValueT, ValueT, unsigned>(grid, threads)
      .template launch_test<vectorized_with_pred_kernel<BinaryOp, ValueT>>(
          op1, op2, expected, pred_hi, pred_lo);
}

int main() {
test_vectorized_binary<syclcompat::abs_diff, uint32_t>(0x00010002, 0x00040002,
0x00030000);
Expand All @@ -101,6 +146,27 @@ int main() {
0xFFF8FFFD);
test_vectorized_unary<syclcompat::abs, uint32_t>(0xFFFBFFFD, 0x00050003);
test_vectorized_sum_abs_diff<uint32_t>(0x00010002, 0x00040002, 0x00000003);
test_vectorized_ternary<std::plus<>, syclcompat::maximum, uint32_t>(
0x00010002, 0x00040002, 0x00080004, 0x00030000);
test_vectorized_ternary<std::plus<>, syclcompat::maximum, uint32_t>(
0x00010002, 0x00040002, 0x00080004, 0x00030000, true);
test_vectorized_ternary<std::plus<>, syclcompat::minimum, uint32_t>(
0x00010002, 0x00040002, 0x00080004, 0x00030000);
test_vectorized_ternary<std::plus<>, syclcompat::minimum, uint32_t>(
0x00010002, 0x00040002, 0x00080004, 0x00030000, true);
test_vectorized_ternary<syclcompat::maximum, syclcompat::maximum, uint32_t>(
0x00010002, 0x00040002, 0x00080004, 0x00030000);
test_vectorized_ternary<syclcompat::maximum, syclcompat::maximum, uint32_t>(
0x00010002, 0x00040002, 0x00080004, 0x00030000, true);
test_vectorized_ternary<syclcompat::minimum, syclcompat::minimum, uint32_t>(
0x00010002, 0x00040002, 0x00080004, 0x00030000);
test_vectorized_ternary<syclcompat::minimum, syclcompat::minimum, uint32_t>(
0x00010002, 0x00040002, 0x00080004, 0x00030000, true);
bool pred_hi, bool pred_lo;
test_vectorized_with_pred<syclcompat::maximum, uint32_t>(
0x00010002, 0x00040002, 0x00030000, &pred_hi, &pred_lo);
test_vectorized_with_pred<syclcompat::minimum, uint32_t>(
0x00010002, 0x00040002, 0x00030000, &pred_hi, &pred_lo);

return 0;
}

0 comments on commit beb04ba

Please sign in to comment.