Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL][COMPAT] Add vectorized_ternary and vectorized_with_pred. #15550

Open
wants to merge 2 commits into
base: sycl
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 25 additions & 12 deletions sycl/doc/syclcompat/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1780,17 +1780,11 @@ template <typename ValueT, typename ValueU>
inline typename std::enable_if_t<!std::is_floating_point_v<ValueT>, double>
pow(const ValueT a, const ValueU b);
template <typename ValueT>
inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
std::is_same_v<sycl::half, ValueT>,
ValueT>
relu(const ValueT a);
template <typename ValueT> inline ValueT relu(const ValueT a);
template <class ValueT>
inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
std::is_same_v<sycl::half, ValueT>,
sycl::vec<ValueT, 2>>
relu(const sycl::vec<ValueT, 2> a);
template <class ValueT, int NumElements>
inline sycl::vec<ValueT, NumElements>
relu(const sycl::vec<ValueT, NumElements> a);
template <class ValueT>
inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
Expand Down Expand Up @@ -1893,9 +1887,12 @@ inline dot_product_acc_t<T1, T2> dp4a(T1 a, T2 b,

`vectorized_binary` computes the `BinaryOperation` for two operands,
with each value treated as a vector type. `vectorized_unary` offers the same
interface for operations with a single operand.
interface for operations with a single operand. `vectorized_ternary` offers the
interface for three operands with two `BinaryOperation`.
The implemented `BinaryOperation`s are `abs_diff`, `add_sat`, `rhadd`, `hadd`,
`maximum`, `minimum`, and `sub_sat`.
And the `vectorized_with_pred` offers the `BinaryOperation` for two operands,
meanwihle provides the pred of high/low halfword operation.

```cpp
namespace syclcompat {
Expand All @@ -1910,7 +1907,19 @@ struct abs {

template <typename VecT, class BinaryOperation>
inline unsigned vectorized_binary(unsigned a, unsigned b,
const BinaryOperation binary_op);
const BinaryOperation binary_op,
bool need_relu = false);

template <typename VecT, typename BinaryOperation1, typename BinaryOperation2>
inline unsigned vectorized_ternary(unsigned a, unsigned b, unsigned c,
const BinaryOperation1 binary_op1,
const BinaryOperation2 binary_op2,
bool need_relu = false);

template <typename ValueT, typename BinaryOperation>
inline unsigned vectorized_with_pred(unsigned a, unsigned b,
const BinaryOperation binary_op,
bool *pred_hi, bool *pred_lo);

// A sycl::abs_diff wrapper functor.
struct abs_diff {
Expand All @@ -1936,11 +1945,15 @@ struct hadd {
struct maximum {
template <typename ValueT>
auto operator()(const ValueT x, const ValueT y) const;
template <typename ValueT>
auto operator()(const ValueT x, const ValueT y, bool *pred) const;
};
// A sycl::min wrapper functor.
struct minimum {
template <typename ValueT>
auto operator()(const ValueT x, const ValueT y) const;
template <typename ValueT>
auto operator()(const ValueT x, const ValueT y, bool *pred) const;
};
// A sycl::sub_sat wrapper functor.
struct sub_sat {
Expand Down
92 changes: 79 additions & 13 deletions sycl/include/syclcompat/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -794,21 +794,22 @@ pow(const ValueT a, const ValueU b) {
/// Performs relu saturation.
/// \param [in] a The input value
/// \returns the relu saturation result
template <typename ValueT>
inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
std::is_same_v<sycl::half, ValueT>,
ValueT>
relu(const ValueT a) {
if (!detail::isnan(a) && a < ValueT(0))
template <typename ValueT> inline ValueT relu(const ValueT a) {
if constexpr (std::is_floating_point_v<ValueT> ||
std::is_same_v<sycl::half, ValueT>)
if (!detail::isnan(a) && a < ValueT(0))
return ValueT(0);
if (a < ValueT(0))
return ValueT(0);
return a;
}
template <class ValueT>
inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
std::is_same_v<sycl::half, ValueT>,
sycl::vec<ValueT, 2>>
relu(const sycl::vec<ValueT, 2> a) {
return {relu(a[0]), relu(a[1])};
template <class ValueT, int NumElements>
inline sycl::vec<ValueT, NumElements>
relu(const sycl::vec<ValueT, NumElements> a) {
sycl::vec<ValueT, NumElements> ret;
for (int i = 0; i < NumElements; ++i)
ret[i] = relu(a[i]);
return ret;
}
template <class ValueT>
inline std::enable_if_t<std::is_floating_point_v<ValueT> ||
Expand Down Expand Up @@ -932,6 +933,10 @@ struct maximum {
auto operator()(const ValueT x, const ValueT y) const {
return sycl::max(x, y);
}
template <typename ValueT>
auto operator()(const ValueT x, const ValueT y, bool *pred) const {
return (x >= y) ? ((*pred = true), x) : ((*pred = false), y);
}
};

/// A sycl::min wrapper functors.
Expand All @@ -940,6 +945,10 @@ struct minimum {
auto operator()(const ValueT x, const ValueT y) const {
return sycl::min(x, y);
}
template <typename ValueT>
auto operator()(const ValueT x, const ValueT y, bool *pred) const {
return (x <= y) ? ((*pred = true), x) : ((*pred = false), y);
}
};

/// A sycl::sub_sat wrapper functors.
Expand Down Expand Up @@ -979,19 +988,76 @@ struct average {
/// \tparam [in] BinaryOperation The binary operation class
/// \param [in] a The first value
/// \param [in] b The second value
/// \param [in] binary_op The operation to do with the two values
/// \param [in] need_relu Whether the result need relu saturation
/// \returns The vectorized binary operation value of the two values
template <typename VecT, class BinaryOperation>
inline unsigned vectorized_binary(unsigned a, unsigned b,
const BinaryOperation binary_op) {
const BinaryOperation binary_op,
bool need_relu = false) {
sycl::vec<unsigned, 1> v0{a}, v1{b};
auto v2 = v0.as<VecT>();
auto v3 = v1.as<VecT>();
auto v4 =
detail::vectorized_binary<VecT, BinaryOperation>()(v2, v3, binary_op);
if (need_relu)
v4 = relu(v4);
v0 = v4.template as<sycl::vec<unsigned, 1>>();
return v0;
}

/// Compute two vectorized binary operation value with pred for three values,
/// with each value treated as a 2 \p T type elements vector type.
///
/// \tparam [in] VecT The type of the vector
/// \tparam [in] BinaryOperation1 The first binary operation class
/// \tparam [in] BinaryOperation2 The second binary operation class
/// \param [in] a The first value
/// \param [in] b The second value
/// \param [in] c The third value
/// \param [in] binary_op1 The first operation to do with the first two values
/// \param [in] binary_op2 The second operation to do with the third values
/// \param [in] need_relu Whether the result need relu saturation
/// \returns The two vectorized binary operation value of the three values
template <typename VecT, typename BinaryOperation1, typename BinaryOperation2>
inline unsigned vectorized_ternary(unsigned a, unsigned b, unsigned c,
const BinaryOperation1 binary_op1,
const BinaryOperation2 binary_op2,
bool need_relu = false) {
const auto v1 = sycl::vec<unsigned, 1>(a).as<VecT>();
const auto v2 = sycl::vec<unsigned, 1>(b).as<VecT>();
const auto v3 = sycl::vec<unsigned, 1>(c).as<VecT>();
auto v4 =
detail::vectorized_binary<VecT, BinaryOperation1>()(v1, v2, binary_op1);
v4 = detail::vectorized_binary<VecT, BinaryOperation2>()(v4, v3, binary_op2);
if (need_relu)
v4 = relu(v4);
return v4.template as<sycl::vec<unsigned, 1>>();
}

/// Compute vectorized binary operation value with pred for two values, with
/// each value treated as a 2 \p T type elements vector type.
///
/// \tparam [in] T The type of elements type of the vector
/// \tparam [in] BinaryOperation The binary operation class
/// \param [in] a The first value
/// \param [in] b The second value
/// \param [in] binary_op The operation with pred to do with the two values
/// \param [in] pred_hi The pred pointer that pass into high halfword operation
/// \param [in] pred_lo The pred pointer that pass into low halfword operation
/// \returns The vectorized binary operation value of the two values
template <typename ValueT, typename BinaryOperation>
inline unsigned vectorized_with_pred(unsigned a, unsigned b,
const BinaryOperation binary_op,
bool *pred_hi, bool *pred_lo) {
auto v1 = sycl::vec<unsigned, 1>(a).as<sycl::vec<ValueT, 2>>();
auto v2 = sycl::vec<unsigned, 1>(b).as<sycl::vec<ValueT, 2>>();
sycl::vec<ValueT, 2> ret;
ret[0] = binary_op(v1[0], v2[0], pred_lo);
ret[1] = binary_op(v1[1], v2[1], pred_hi);
return ret.template as<sycl::vec<unsigned, 1>>();
}

template <typename T1, typename T2>
using dot_product_acc_t =
std::conditional_t<std::is_unsigned_v<T1> && std::is_unsigned_v<T2>,
Expand Down
88 changes: 88 additions & 0 deletions sycl/test-e2e/syclcompat/math/math_fixt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ class BinaryOpTestLauncher : OpTestLauncher {
ValueT *op1_;
ValueU *op2_;
ResultT res_h_, *res_;
bool *res_hi_;
bool *res_lo_;

public:
BinaryOpTestLauncher(const syclcompat::dim3 &grid,
Expand All @@ -126,6 +128,8 @@ class BinaryOpTestLauncher : OpTestLauncher {
op1_ = syclcompat::malloc<ValueT>(data_size);
op2_ = syclcompat::malloc<ValueU>(data_size);
res_ = syclcompat::malloc<ResultT>(data_size);
res_hi_ = syclcompat::malloc<bool>(1);
res_lo_ = syclcompat::malloc<bool>(1);
};

virtual ~BinaryOpTestLauncher() {
Expand All @@ -134,6 +138,8 @@ class BinaryOpTestLauncher : OpTestLauncher {
syclcompat::free(op1_);
syclcompat::free(op2_);
syclcompat::free(res_);
syclcompat::free(res_hi_);
syclcompat::free(res_lo_);
}

template <auto Kernel>
Expand All @@ -148,6 +154,37 @@ class BinaryOpTestLauncher : OpTestLauncher {

CHECK(ResultT, res_h_, expected);
};
template <auto Kernel>
void launch_test(ValueT op1, ValueU op2, ResultT expected, bool need_relu) {
if (skip_)
return;
syclcompat::memcpy<ValueT>(op1_, &op1, data_size_);
syclcompat::memcpy<ValueU>(op2_, &op2, data_size_);
syclcompat::launch<Kernel>(grid_, threads_, op1_, op2_, res_, need_relu);
syclcompat::wait();
syclcompat::memcpy<ResultT>(&res_h_, res_, data_size_);

CHECK(ResultT, res_h_, expected);
};
template <auto Kernel>
void launch_test(ValueT op1, ValueU op2, ResultT expected, bool expected_hi,
bool expected_lo) {
if (skip_)
return;
syclcompat::memcpy<ValueT>(op1_, &op1, data_size_);
syclcompat::memcpy<ValueU>(op2_, &op2, data_size_);
syclcompat::launch<Kernel>(grid_, threads_, op1_, op2_, res_, res_hi_,
res_lo_);
syclcompat::wait();
syclcompat::memcpy<ResultT>(&res_h_, res_, data_size_);
bool res_hi_h_, res_lo_h_;
syclcompat::memcpy<bool>(&res_hi_h_, res_hi_, 1);
syclcompat::memcpy<bool>(&res_lo_h_, res_lo_, 1);

CHECK(ResultT, res_h_, expected);
assert(res_hi_h_ == expected_hi);
assert(res_lo_h_ == expected_lo);
};
};

template <typename ValueT, typename ResultT = ValueT>
Expand Down Expand Up @@ -187,3 +224,54 @@ class UnaryOpTestLauncher : OpTestLauncher {
CHECK(ResultT, res_h_, expected);
}
};

// Templated ResultT to support both arithmetic and boolean operators
template <typename ValueT, typename ValueU, typename ValueV,
typename ResultT = std::common_type_t<ValueT, ValueU, ValueV>>
class TernaryOpTestLauncher : OpTestLauncher {
protected:
ValueT *op1_;
ValueU *op2_;
ValueV *op3_;
ResultT res_h_, *res_;

public:
TernaryOpTestLauncher(const syclcompat::dim3 &grid,
const syclcompat::dim3 &threads,
const size_t data_size = 1)
: OpTestLauncher{
grid, threads, data_size,
should_skip<ValueT>()(syclcompat::get_current_device())} {
if (skip_)
return;
op1_ = syclcompat::malloc<ValueT>(data_size);
op2_ = syclcompat::malloc<ValueU>(data_size);
op3_ = syclcompat::malloc<ValueU>(data_size);
res_ = syclcompat::malloc<ResultT>(data_size);
};

virtual ~TernaryOpTestLauncher() {
if (skip_)
return;
syclcompat::free(op1_);
syclcompat::free(op2_);
syclcompat::free(op3_);
syclcompat::free(res_);
}

template <auto Kernel>
void launch_test(ValueT op1, ValueU op2, ValueU op3, ResultT expected,
bool need_relu = false) {
if (skip_)
return;
syclcompat::memcpy<ValueT>(op1_, &op1, data_size_);
syclcompat::memcpy<ValueU>(op2_, &op2, data_size_);
syclcompat::memcpy<ValueU>(op3_, &op3, data_size_);
syclcompat::launch<Kernel>(grid_, threads_, op1_, op2_, op3_, res_,
need_relu);
syclcompat::wait();
syclcompat::memcpy<ResultT>(&res_h_, res_, data_size_);

CHECK(ResultT, res_h_, expected);
};
};
Loading
Loading