From 9b3c121bf2946fe99b898a5fcebf78893c9c38aa Mon Sep 17 00:00:00 2001 From: Yihan Wang Date: Mon, 2 Sep 2024 14:25:42 +0800 Subject: [PATCH] [SYCLomatic] Refine SYCLCompat support for CUB (#2302) Signed-off-by: Wang, Yihan --- clang/lib/DPCT/CallExprRewriter.cpp | 4 +- clang/lib/DPCT/CallExprRewriter.h | 4 + .../Rewriters/CUB/RewriterClassMethods.cpp | 349 +++++++----------- .../CUB/RewriterUtilityFunctions.cpp | 156 ++++---- .../lib/DPCT/Rewriters/RewriterSYCLcompat.cpp | 48 ++- 5 files changed, 245 insertions(+), 316 deletions(-) diff --git a/clang/lib/DPCT/CallExprRewriter.cpp b/clang/lib/DPCT/CallExprRewriter.cpp index 5cbb11f0fad8..f60331458e60 100644 --- a/clang/lib/DPCT/CallExprRewriter.cpp +++ b/clang/lib/DPCT/CallExprRewriter.cpp @@ -129,8 +129,10 @@ std::unique_ptr>>(); void CallExprRewriterFactoryBase::initRewriterMap() { - if (DpctGlobalInfo::useSYCLCompat()) + if (DpctGlobalInfo::useSYCLCompat()) { initRewriterMapSYCLcompat(*RewriterMap); + initRewriterMethodMapSYCLcompat(*MethodRewriterMap); + } initRewriterMapAtomic(); initRewriterMapCUB(); initRewriterMapCUFFT(); diff --git a/clang/lib/DPCT/CallExprRewriter.h b/clang/lib/DPCT/CallExprRewriter.h index fb6be4f8fd37..b8a605b47c62 100644 --- a/clang/lib/DPCT/CallExprRewriter.h +++ b/clang/lib/DPCT/CallExprRewriter.h @@ -51,6 +51,10 @@ class CallExprRewriterFactoryBase { std::unordered_map> &RewriterMap); + static void initRewriterMethodMapSYCLcompat( + std::unordered_map> + &MethodRewriterMap); static void initRewriterMapAtomic(); static void initRewriterMapCUB(); static void initRewriterMapCUFFT(); diff --git a/clang/lib/DPCT/Rewriters/CUB/RewriterClassMethods.cpp b/clang/lib/DPCT/Rewriters/CUB/RewriterClassMethods.cpp index 12cc14b75a6d..5ed5c620f7fa 100644 --- a/clang/lib/DPCT/Rewriters/CUB/RewriterClassMethods.cpp +++ b/clang/lib/DPCT/Rewriters/CUB/RewriterClassMethods.cpp @@ -71,232 +71,167 @@ RewriterMap dpct::createClassMethodsRewriterMap() { MEMBER_CALL(MemberExprBase(), false, LITERAL("create_normalize"))))) // cub::BlockRadixSort.Sort - CONDITIONAL_FACTORY_ENTRY( - UseSYCLCompat, - UNSUPPORT_FACTORY_ENTRY("cub::BlockRadixSort.Sort", - Diagnostics::UNSUPPORT_SYCLCOMPAT, - LITERAL("cub::BlockRadixSort.Sort")), - HEADER_INSERT_FACTORY( - HeaderType::HT_DPCT_GROUP_Utils, - CASE_FACTORY_ENTRY( - CASE(makeCheckAnd( - makeCheckAnd( - CheckArgCount(3, std::equal_to<>(), - /*IncludeDefaultArg=*/false), - CheckParamType(1, "int", /*isStrict=*/true)), - CheckParamType(2, "int", /*isStrict=*/true)), - MEMBER_CALL_FACTORY_ENTRY( - "cub::BlockRadixSort.Sort", MemberExprBase(), false, - "sort", NDITEM, ARG(0), ARG(1), ARG(2))), - CASE( - makeCheckAnd(CheckArgCount(2, std::equal_to<>(), + HEADER_INSERT_FACTORY( + HeaderType::HT_DPCT_GROUP_Utils, + CASE_FACTORY_ENTRY( + CASE( + makeCheckAnd( + makeCheckAnd(CheckArgCount(3, std::equal_to<>(), /*IncludeDefaultArg=*/false), CheckParamType(1, "int", /*isStrict=*/true)), - MEMBER_CALL_FACTORY_ENTRY("cub::BlockRadixSort.Sort", - MemberExprBase(), false, "sort", - NDITEM, ARG(0), ARG(1))), - CASE(CheckArgCount(1, std::equal_to<>(), - /*IncludeDefaultArg=*/false), - MEMBER_CALL_FACTORY_ENTRY("cub::BlockRadixSort.Sort", - MemberExprBase(), false, - "sort", NDITEM, ARG(0))), - OTHERWISE(UNSUPPORT_FACTORY_ENTRY( - "cub::BlockRadixSort.Sort", Diagnostics::API_NOT_MIGRATED, - printCallExprPretty()))))) + CheckParamType(2, "int", /*isStrict=*/true)), + MEMBER_CALL_FACTORY_ENTRY("cub::BlockRadixSort.Sort", + MemberExprBase(), false, "sort", + NDITEM, ARG(0), ARG(1), ARG(2))), + CASE(makeCheckAnd(CheckArgCount(2, std::equal_to<>(), + /*IncludeDefaultArg=*/false), + CheckParamType(1, "int", /*isStrict=*/true)), + MEMBER_CALL_FACTORY_ENTRY("cub::BlockRadixSort.Sort", + MemberExprBase(), false, "sort", + NDITEM, ARG(0), ARG(1))), + CASE(CheckArgCount(1, std::equal_to<>(), + /*IncludeDefaultArg=*/false), + MEMBER_CALL_FACTORY_ENTRY("cub::BlockRadixSort.Sort", + MemberExprBase(), false, "sort", + NDITEM, ARG(0))), + OTHERWISE(UNSUPPORT_FACTORY_ENTRY("cub::BlockRadixSort.Sort", + Diagnostics::API_NOT_MIGRATED, + printCallExprPretty())))) // cub::BlockRadixSort.SortDescending - CONDITIONAL_FACTORY_ENTRY( - UseSYCLCompat, - UNSUPPORT_FACTORY_ENTRY( - "cub::BlockRadixSort.SortDescending", - Diagnostics::UNSUPPORT_SYCLCOMPAT, - LITERAL("cub::BlockRadixSort.SortDescending")), - HEADER_INSERT_FACTORY( - HeaderType::HT_DPCT_GROUP_Utils, - CASE_FACTORY_ENTRY( - CASE(makeCheckAnd( - makeCheckAnd( - CheckArgCount(3, std::equal_to<>(), - /*IncludeDefaultArg=*/false), - CheckParamType(1, "int", /*isStrict=*/true)), - CheckParamType(2, "int", /*isStrict=*/true)), - MEMBER_CALL_FACTORY_ENTRY( - "cub::BlockRadixSort.SortDescending", - MemberExprBase(), false, "sort_descending", NDITEM, - ARG(0), ARG(1), ARG(2))), - CASE( - makeCheckAnd(CheckArgCount(2, std::equal_to<>(), + HEADER_INSERT_FACTORY( + HeaderType::HT_DPCT_GROUP_Utils, + CASE_FACTORY_ENTRY( + CASE( + makeCheckAnd( + makeCheckAnd(CheckArgCount(3, std::equal_to<>(), /*IncludeDefaultArg=*/false), CheckParamType(1, "int", /*isStrict=*/true)), - MEMBER_CALL_FACTORY_ENTRY( - "cub::BlockRadixSort.SortDescending", - MemberExprBase(), false, "sort_descending", NDITEM, - ARG(0), ARG(1))), - CASE(CheckArgCount(1, std::equal_to<>(), - /*IncludeDefaultArg=*/false), - MEMBER_CALL_FACTORY_ENTRY( - "cub::BlockRadixSort.SortDescending", - MemberExprBase(), false, "sort_descending", NDITEM, - ARG(0))), - OTHERWISE(UNSUPPORT_FACTORY_ENTRY( - "cub::BlockRadixSort.SortDescending", - Diagnostics::API_NOT_MIGRATED, printCallExprPretty()))))) + CheckParamType(2, "int", /*isStrict=*/true)), + MEMBER_CALL_FACTORY_ENTRY( + "cub::BlockRadixSort.SortDescending", MemberExprBase(), + false, "sort_descending", NDITEM, ARG(0), ARG(1), + ARG(2))), + CASE(makeCheckAnd(CheckArgCount(2, std::equal_to<>(), + /*IncludeDefaultArg=*/false), + CheckParamType(1, "int", /*isStrict=*/true)), + MEMBER_CALL_FACTORY_ENTRY( + "cub::BlockRadixSort.SortDescending", MemberExprBase(), + false, "sort_descending", NDITEM, ARG(0), ARG(1))), + CASE(CheckArgCount(1, std::equal_to<>(), + /*IncludeDefaultArg=*/false), + MEMBER_CALL_FACTORY_ENTRY( + "cub::BlockRadixSort.SortDescending", MemberExprBase(), + false, "sort_descending", NDITEM, ARG(0))), + OTHERWISE(UNSUPPORT_FACTORY_ENTRY( + "cub::BlockRadixSort.SortDescending", + Diagnostics::API_NOT_MIGRATED, printCallExprPretty())))) // cub::BlockRadixSort.SortBlockedToStriped - CONDITIONAL_FACTORY_ENTRY( - UseSYCLCompat, - UNSUPPORT_FACTORY_ENTRY( - "cub::BlockRadixSort.SortBlockedToStriped", - Diagnostics::UNSUPPORT_SYCLCOMPAT, - LITERAL("cub::BlockRadixSort.SortBlockedToStriped")), - HEADER_INSERT_FACTORY( - HeaderType::HT_DPCT_GROUP_Utils, - CASE_FACTORY_ENTRY( - CASE(makeCheckAnd( - makeCheckAnd( - CheckArgCount(3, std::equal_to<>(), - /*IncludeDefaultArg=*/false), - CheckParamType(1, "int", /*isStrict=*/true)), - CheckParamType(2, "int", /*isStrict=*/true)), - MEMBER_CALL_FACTORY_ENTRY( - "cub::BlockRadixSort.SortBlockedToStriped", - MemberExprBase(), false, "sort_blocked_to_striped", - NDITEM, ARG(0), ARG(1), ARG(2))), - CASE( - makeCheckAnd(CheckArgCount(2, std::equal_to<>(), + HEADER_INSERT_FACTORY( + HeaderType::HT_DPCT_GROUP_Utils, + CASE_FACTORY_ENTRY( + CASE( + makeCheckAnd( + makeCheckAnd(CheckArgCount(3, std::equal_to<>(), /*IncludeDefaultArg=*/false), CheckParamType(1, "int", /*isStrict=*/true)), - MEMBER_CALL_FACTORY_ENTRY( - "cub::BlockRadixSort.SortBlockedToStriped", - MemberExprBase(), false, "sort_blocked_to_striped", - NDITEM, ARG(0), ARG(1))), - CASE(CheckArgCount(1, std::equal_to<>(), - /*IncludeDefaultArg=*/false), - MEMBER_CALL_FACTORY_ENTRY( - "cub::BlockRadixSort.SortBlockedToStriped", - MemberExprBase(), false, "sort_blocked_to_striped", - NDITEM, ARG(0))), - OTHERWISE(UNSUPPORT_FACTORY_ENTRY( + CheckParamType(2, "int", /*isStrict=*/true)), + MEMBER_CALL_FACTORY_ENTRY( "cub::BlockRadixSort.SortBlockedToStriped", - Diagnostics::API_NOT_MIGRATED, printCallExprPretty()))))) + MemberExprBase(), false, "sort_blocked_to_striped", + NDITEM, ARG(0), ARG(1), ARG(2))), + CASE(makeCheckAnd(CheckArgCount(2, std::equal_to<>(), + /*IncludeDefaultArg=*/false), + CheckParamType(1, "int", /*isStrict=*/true)), + MEMBER_CALL_FACTORY_ENTRY( + "cub::BlockRadixSort.SortBlockedToStriped", + MemberExprBase(), false, "sort_blocked_to_striped", + NDITEM, ARG(0), ARG(1))), + CASE(CheckArgCount(1, std::equal_to<>(), + /*IncludeDefaultArg=*/false), + MEMBER_CALL_FACTORY_ENTRY( + "cub::BlockRadixSort.SortBlockedToStriped", + MemberExprBase(), false, "sort_blocked_to_striped", + NDITEM, ARG(0))), + OTHERWISE(UNSUPPORT_FACTORY_ENTRY( + "cub::BlockRadixSort.SortBlockedToStriped", + Diagnostics::API_NOT_MIGRATED, printCallExprPretty())))) // cub::BlockRadixSort.SortDescendingBlockedToStriped - CONDITIONAL_FACTORY_ENTRY( - UseSYCLCompat, - UNSUPPORT_FACTORY_ENTRY( - "cub::BlockRadixSort.SortDescendingBlockedToStriped", - Diagnostics::UNSUPPORT_SYCLCOMPAT, - LITERAL("cub::BlockRadixSort.SortDescendingBlockedToStriped")), - HEADER_INSERT_FACTORY( - HeaderType::HT_DPCT_GROUP_Utils, - CASE_FACTORY_ENTRY( - CASE(makeCheckAnd( - makeCheckAnd( - CheckArgCount(3, std::equal_to<>(), - /*IncludeDefaultArg=*/false), - CheckParamType(1, "int", /*isStrict=*/true)), - CheckParamType(2, "int", /*isStrict=*/true)), - MEMBER_CALL_FACTORY_ENTRY( - "cub::BlockRadixSort.SortDescendingBlockedToStriped", - MemberExprBase(), false, - "sort_descending_blocked_to_striped", NDITEM, ARG(0), - ARG(1), ARG(2))), - CASE( - makeCheckAnd(CheckArgCount(2, std::equal_to<>(), + HEADER_INSERT_FACTORY( + HeaderType::HT_DPCT_GROUP_Utils, + CASE_FACTORY_ENTRY( + CASE( + makeCheckAnd( + makeCheckAnd(CheckArgCount(3, std::equal_to<>(), /*IncludeDefaultArg=*/false), CheckParamType(1, "int", /*isStrict=*/true)), - MEMBER_CALL_FACTORY_ENTRY( - "cub::BlockRadixSort.SortDescendingBlockedToStriped", - MemberExprBase(), false, - "sort_descending_blocked_to_striped", NDITEM, ARG(0), - ARG(1))), - CASE(CheckArgCount(1, std::equal_to<>(), - /*IncludeDefaultArg=*/false), - MEMBER_CALL_FACTORY_ENTRY( - "cub::BlockRadixSort.SortDescendingBlockedToStriped", - MemberExprBase(), false, - "sort_descending_blocked_to_striped", NDITEM, - ARG(0))), - OTHERWISE(UNSUPPORT_FACTORY_ENTRY( + CheckParamType(2, "int", /*isStrict=*/true)), + MEMBER_CALL_FACTORY_ENTRY( "cub::BlockRadixSort.SortDescendingBlockedToStriped", - Diagnostics::API_NOT_MIGRATED, printCallExprPretty()))))) + MemberExprBase(), false, + "sort_descending_blocked_to_striped", NDITEM, ARG(0), + ARG(1), ARG(2))), + CASE(makeCheckAnd(CheckArgCount(2, std::equal_to<>(), + /*IncludeDefaultArg=*/false), + CheckParamType(1, "int", /*isStrict=*/true)), + MEMBER_CALL_FACTORY_ENTRY( + "cub::BlockRadixSort.SortDescendingBlockedToStriped", + MemberExprBase(), false, + "sort_descending_blocked_to_striped", NDITEM, ARG(0), + ARG(1))), + CASE(CheckArgCount(1, std::equal_to<>(), + /*IncludeDefaultArg=*/false), + MEMBER_CALL_FACTORY_ENTRY( + "cub::BlockRadixSort.SortDescendingBlockedToStriped", + MemberExprBase(), false, + "sort_descending_blocked_to_striped", NDITEM, ARG(0))), + OTHERWISE(UNSUPPORT_FACTORY_ENTRY( + "cub::BlockRadixSort.SortDescendingBlockedToStriped", + Diagnostics::API_NOT_MIGRATED, printCallExprPretty())))) // cub::BlockExchange.BlockedToStriped - CONDITIONAL_FACTORY_ENTRY( - UseSYCLCompat, - UNSUPPORT_FACTORY_ENTRY( - "cub::BlockExchange.BlockedToStriped", - Diagnostics::UNSUPPORT_SYCLCOMPAT, - LITERAL("cub::BlockExchange.BlockedToStriped")), - HEADER_INSERT_FACTORY( - HeaderType::HT_DPCT_GROUP_Utils, - MEMBER_CALL_FACTORY_ENTRY( - "cub::BlockExchange.BlockedToStriped", MemberExprBase(), - false, "blocked_to_striped", NDITEM, ARG(0), ARG(1)))) + HEADER_INSERT_FACTORY(HeaderType::HT_DPCT_GROUP_Utils, + MEMBER_CALL_FACTORY_ENTRY( + "cub::BlockExchange.BlockedToStriped", + MemberExprBase(), false, "blocked_to_striped", + NDITEM, ARG(0), ARG(1))) // cub::BlockExchange.StripedToBlocked - CONDITIONAL_FACTORY_ENTRY( - UseSYCLCompat, - UNSUPPORT_FACTORY_ENTRY( - "cub::BlockExchange.StripedToBlocked", - Diagnostics::UNSUPPORT_SYCLCOMPAT, - LITERAL("cub::BlockExchange.StripedToBlocked")), - HEADER_INSERT_FACTORY( - HeaderType::HT_DPCT_GROUP_Utils, - MEMBER_CALL_FACTORY_ENTRY( - "cub::BlockExchange.StripedToBlocked", MemberExprBase(), - false, "striped_to_blocked", NDITEM, ARG(0), ARG(1)))) + HEADER_INSERT_FACTORY(HeaderType::HT_DPCT_GROUP_Utils, + MEMBER_CALL_FACTORY_ENTRY( + "cub::BlockExchange.StripedToBlocked", + MemberExprBase(), false, "striped_to_blocked", + NDITEM, ARG(0), ARG(1))) // cub::BlockExchange.ScatterToBlocked - CONDITIONAL_FACTORY_ENTRY( - UseSYCLCompat, - UNSUPPORT_FACTORY_ENTRY( - "cub::BlockExchange.ScatterToBlocked", - Diagnostics::UNSUPPORT_SYCLCOMPAT, - LITERAL("cub::BlockExchange.ScatterToBlocked")), - HEADER_INSERT_FACTORY( - HeaderType::HT_DPCT_GROUP_Utils, - MEMBER_CALL_FACTORY_ENTRY( - "cub::BlockExchange.ScatterToBlocked", MemberExprBase(), - false, "scatter_to_blocked", NDITEM, ARG(0), ARG(1)))) + HEADER_INSERT_FACTORY(HeaderType::HT_DPCT_GROUP_Utils, + MEMBER_CALL_FACTORY_ENTRY( + "cub::BlockExchange.ScatterToBlocked", + MemberExprBase(), false, "scatter_to_blocked", + NDITEM, ARG(0), ARG(1))) // cub::BlockExchange.ScatterToStriped - CONDITIONAL_FACTORY_ENTRY( - UseSYCLCompat, - UNSUPPORT_FACTORY_ENTRY( - "cub::BlockExchange.ScatterToStriped", - Diagnostics::UNSUPPORT_SYCLCOMPAT, - LITERAL("cub::BlockExchange.ScatterToStriped")), - HEADER_INSERT_FACTORY( - HeaderType::HT_DPCT_GROUP_Utils, - MEMBER_CALL_FACTORY_ENTRY( - "cub::BlockExchange.ScatterToStriped", MemberExprBase(), - false, "scatter_to_striped", NDITEM, ARG(0), ARG(1)))) + HEADER_INSERT_FACTORY(HeaderType::HT_DPCT_GROUP_Utils, + MEMBER_CALL_FACTORY_ENTRY( + "cub::BlockExchange.ScatterToStriped", + MemberExprBase(), false, "scatter_to_striped", + NDITEM, ARG(0), ARG(1))) // cub::BlockLoad.Load - CONDITIONAL_FACTORY_ENTRY( - UseSYCLCompat, - UNSUPPORT_FACTORY_ENTRY("cub::BlockLoad.Load", - Diagnostics::UNSUPPORT_SYCLCOMPAT, - LITERAL("cub::BlockLoad::Load")), - HEADER_INSERT_FACTORY( - HeaderType::HT_DPCT_GROUP_Utils, - CONDITIONAL_FACTORY_ENTRY( - makeCheckAnd(CheckArgCount(2), CheckCUBEnumTemplateArg(3)), - MEMBER_CALL_FACTORY_ENTRY("cub::BlockLoad.Load", - MemberExprBase(), false, "load", - NDITEM, ARG(0), ARG(1)), - UNSUPPORT_FACTORY_ENTRY("cub::BlockLoad.Load", - Diagnostics::API_NOT_MIGRATED, - printCallExprPretty())))) + HEADER_INSERT_FACTORY( + HeaderType::HT_DPCT_GROUP_Utils, + CONDITIONAL_FACTORY_ENTRY( + makeCheckAnd(CheckArgCount(2), CheckCUBEnumTemplateArg(3)), + MEMBER_CALL_FACTORY_ENTRY("cub::BlockLoad.Load", MemberExprBase(), + false, "load", NDITEM, ARG(0), ARG(1)), + UNSUPPORT_FACTORY_ENTRY("cub::BlockLoad.Load", + Diagnostics::API_NOT_MIGRATED, + printCallExprPretty()))) // cub::BlockStore.Store - CONDITIONAL_FACTORY_ENTRY( - UseSYCLCompat, - UNSUPPORT_FACTORY_ENTRY("cub::BlockStore.Store", - Diagnostics::UNSUPPORT_SYCLCOMPAT, - LITERAL("cub::BlockStore::Store")), - HEADER_INSERT_FACTORY( - HeaderType::HT_DPCT_GROUP_Utils, - CONDITIONAL_FACTORY_ENTRY( - makeCheckAnd(CheckArgCount(2), CheckCUBEnumTemplateArg(3)), - MEMBER_CALL_FACTORY_ENTRY("cub::BlockStore.Store", - MemberExprBase(), false, "store", - NDITEM, ARG(0), ARG(1)), - UNSUPPORT_FACTORY_ENTRY("cub::BlockStore.Store", - Diagnostics::API_NOT_MIGRATED, - printCallExprPretty())))) + HEADER_INSERT_FACTORY( + HeaderType::HT_DPCT_GROUP_Utils, + CONDITIONAL_FACTORY_ENTRY( + makeCheckAnd(CheckArgCount(2), CheckCUBEnumTemplateArg(3)), + MEMBER_CALL_FACTORY_ENTRY("cub::BlockStore.Store", + MemberExprBase(), false, "store", + NDITEM, ARG(0), ARG(1)), + UNSUPPORT_FACTORY_ENTRY("cub::BlockStore.Store", + Diagnostics::API_NOT_MIGRATED, + printCallExprPretty()))) }; } diff --git a/clang/lib/DPCT/Rewriters/CUB/RewriterUtilityFunctions.cpp b/clang/lib/DPCT/Rewriters/CUB/RewriterUtilityFunctions.cpp index 0adc0ef467a0..4d50c0b36de3 100644 --- a/clang/lib/DPCT/Rewriters/CUB/RewriterUtilityFunctions.cpp +++ b/clang/lib/DPCT/Rewriters/CUB/RewriterUtilityFunctions.cpp @@ -159,103 +159,71 @@ RewriterMap dpct::createUtilityFunctionsRewriterMap() { MEMBER_CALL_FACTORY_ENTRY("cub::RowMajorTid", NDITEM, /*IsArrow=*/false, "get_local_linear_id") // cub::LoadDirectBlocked - CONDITIONAL_FACTORY_ENTRY( - UseSYCLCompat, - UNSUPPORT_FACTORY_ENTRY("cub::LoadDirectBlocked", - Diagnostics::UNSUPPORT_SYCLCOMPAT, - LITERAL("cub::LoadDirectBlocked")), - HEADER_INSERT_FACTORY( - HeaderType::HT_DPCT_GROUP_Utils, - CALL_FACTORY_ENTRY( - "cub::LoadDirectBlocked", - CALL(PRETTY_TEMPLATED_CALLEE(MapNames::getDpctNamespace() + - "group::load_direct_blocked", - 0, 1, 2), - NDITEM, ARG(1), ARG(2))))) + HEADER_INSERT_FACTORY( + HeaderType::HT_DPCT_GROUP_Utils, + CALL_FACTORY_ENTRY( + "cub::LoadDirectBlocked", + CALL(PRETTY_TEMPLATED_CALLEE(MapNames::getDpctNamespace() + + "group::load_direct_blocked", + 0, 1, 2), + NDITEM, ARG(1), ARG(2)))) // cub::LoadDirectStriped - CONDITIONAL_FACTORY_ENTRY( - UseSYCLCompat, - UNSUPPORT_FACTORY_ENTRY("cub::LoadDirectStriped", - Diagnostics::UNSUPPORT_SYCLCOMPAT, - LITERAL("cub::LoadDirectStriped")), - HEADER_INSERT_FACTORY( - HeaderType::HT_DPCT_GROUP_Utils, - CALL_FACTORY_ENTRY( - "cub::LoadDirectStriped", - CALL(PRETTY_TEMPLATED_CALLEE(MapNames::getDpctNamespace() + - "group::load_direct_striped", - 1, 2, 3), - NDITEM, ARG(1), ARG(2))))) + HEADER_INSERT_FACTORY( + HeaderType::HT_DPCT_GROUP_Utils, + CALL_FACTORY_ENTRY( + "cub::LoadDirectStriped", + CALL(PRETTY_TEMPLATED_CALLEE(MapNames::getDpctNamespace() + + "group::load_direct_striped", + 1, 2, 3), + NDITEM, ARG(1), ARG(2)))) // cub::StoreDirectBlocked - CONDITIONAL_FACTORY_ENTRY( - UseSYCLCompat, - UNSUPPORT_FACTORY_ENTRY("cub::StoreDirectBlocked", - Diagnostics::UNSUPPORT_SYCLCOMPAT, - LITERAL("cub::StoreDirectBlocked")), - HEADER_INSERT_FACTORY( - HeaderType::HT_DPCT_GROUP_Utils, - CALL_FACTORY_ENTRY("cub::StoreDirectBlocked", - CALL(PRETTY_TEMPLATED_CALLEE( - MapNames::getDpctNamespace() + - "group::store_direct_blocked", - 0, 1, 2), - NDITEM, ARG(1), ARG(2))))) + HEADER_INSERT_FACTORY( + HeaderType::HT_DPCT_GROUP_Utils, + CALL_FACTORY_ENTRY( + "cub::StoreDirectBlocked", + CALL(PRETTY_TEMPLATED_CALLEE(MapNames::getDpctNamespace() + + "group::store_direct_blocked", + 0, 1, 2), + NDITEM, ARG(1), ARG(2)))) // cub::StoreDirectStriped - CONDITIONAL_FACTORY_ENTRY( - UseSYCLCompat, - UNSUPPORT_FACTORY_ENTRY("cub::StoreDirectStriped", - Diagnostics::UNSUPPORT_SYCLCOMPAT, - LITERAL("cub::StoreDirectStriped")), - HEADER_INSERT_FACTORY( - HeaderType::HT_DPCT_GROUP_Utils, - CALL_FACTORY_ENTRY("cub::StoreDirectStriped", - CALL(PRETTY_TEMPLATED_CALLEE( - MapNames::getDpctNamespace() + - "group::store_direct_striped", - 1, 2, 3), - NDITEM, ARG(1), ARG(2))))) + HEADER_INSERT_FACTORY( + HeaderType::HT_DPCT_GROUP_Utils, + CALL_FACTORY_ENTRY( + "cub::StoreDirectStriped", + CALL(PRETTY_TEMPLATED_CALLEE(MapNames::getDpctNamespace() + + "group::store_direct_striped", + 1, 2, 3), + NDITEM, ARG(1), ARG(2)))) // cub::ShuffleDown - CONDITIONAL_FACTORY_ENTRY( - UseSYCLCompat, - UNSUPPORT_FACTORY_ENTRY("cub::ShuffleDown", - Diagnostics::UNSUPPORT_SYCLCOMPAT, - LITERAL("cub::ShuffleDown")), - SUBGROUPSIZE_FACTORY( - UINT_MAX, - MapNames::getDpctNamespace() + - "experimental::shift_sub_group_left", - CONDITIONAL_FACTORY_ENTRY( - UseNonUniformGroups, - CALL_FACTORY_ENTRY( - "cub::ShuffleDown", - CALL(TEMPLATED_CALLEE( - MapNames::getDpctNamespace() + - "experimental::shift_sub_group_left", - 0, 1), - SUBGROUP, ARG(0), ARG(1), ARG(2), ARG(3))), - UNSUPPORT_FACTORY_ENTRY("cub::ShuffleDown", - Diagnostics::API_NOT_MIGRATED, - LITERAL("cub::ShuffleDown"))))) + SUBGROUPSIZE_FACTORY( + UINT_MAX, + MapNames::getDpctNamespace() + "experimental::shift_sub_group_left", + CONDITIONAL_FACTORY_ENTRY( + UseNonUniformGroups, + CALL_FACTORY_ENTRY( + "cub::ShuffleDown", + CALL( + TEMPLATED_CALLEE(MapNames::getDpctNamespace() + + "experimental::shift_sub_group_left", + 0, 1), + SUBGROUP, ARG(0), ARG(1), ARG(2), ARG(3))), + UNSUPPORT_FACTORY_ENTRY("cub::ShuffleDown", + Diagnostics::API_NOT_MIGRATED, + LITERAL("cub::ShuffleDown")))) // cub::ShuffleUp - CONDITIONAL_FACTORY_ENTRY( - UseSYCLCompat, - UNSUPPORT_FACTORY_ENTRY("cub::ShuffleUp", - Diagnostics::UNSUPPORT_SYCLCOMPAT, - LITERAL("cub::ShuffleUp")), - SUBGROUPSIZE_FACTORY( - UINT_MAX, - MapNames::getDpctNamespace() + - "experimental::shift_sub_group_right", - CONDITIONAL_FACTORY_ENTRY( - UseNonUniformGroups, - CALL_FACTORY_ENTRY( - "cub::ShuffleUp", - CALL(TEMPLATED_CALLEE( - MapNames::getDpctNamespace() + - "experimental::shift_sub_group_right", - 0, 1), - SUBGROUP, ARG(0), ARG(1), ARG(2), ARG(3))), - UNSUPPORT_FACTORY_ENTRY("cub::ShuffleUp", - Diagnostics::API_NOT_MIGRATED, - LITERAL("cub::ShuffleUp")))))}; + SUBGROUPSIZE_FACTORY( + UINT_MAX, + MapNames::getDpctNamespace() + "experimental::shift_sub_group_right", + CONDITIONAL_FACTORY_ENTRY( + UseNonUniformGroups, + CALL_FACTORY_ENTRY( + "cub::ShuffleUp", + CALL(TEMPLATED_CALLEE( + MapNames::getDpctNamespace() + + "experimental::shift_sub_group_right", + 0, 1), + SUBGROUP, ARG(0), ARG(1), ARG(2), ARG(3))), + UNSUPPORT_FACTORY_ENTRY("cub::ShuffleUp", + Diagnostics::API_NOT_MIGRATED, + LITERAL("cub::ShuffleUp"))))}; } diff --git a/clang/lib/DPCT/Rewriters/RewriterSYCLcompat.cpp b/clang/lib/DPCT/Rewriters/RewriterSYCLcompat.cpp index 2d5f54e75b3e..e7e5157e6afb 100644 --- a/clang/lib/DPCT/Rewriters/RewriterSYCLcompat.cpp +++ b/clang/lib/DPCT/Rewriters/RewriterSYCLcompat.cpp @@ -11,16 +11,11 @@ namespace clang { namespace dpct { -void initRewriterMapSYCLcompatUnsupport( - std::unordered_map> - &RewriterMap) { - #define ARG(x) makeCallArgCreator(x) -#define UNSUPPORT_FACTORY_ENTRY(FuncName, ...) \ - std::make_pair( \ - FuncName, createUnsupportRewriterFactory(FuncName, __VA_ARGS__)), -#define SYCLCOMPAT_UNSUPPORT(NAME) \ +#define UNSUPPORT_FACTORY_ENTRY(FuncName, ...) \ + std::make_pair(FuncName, \ + createUnsupportRewriterFactory(FuncName, __VA_ARGS__)), +#define SYCLCOMPAT_UNSUPPORT(NAME) \ UNSUPPORT_FACTORY_ENTRY(NAME, Diagnostics::UNSUPPORT_SYCLCOMPAT, ARG(NAME)) #define ENTRY_UNSUPPORTED(...) UNSUPPORT_FACTORY_ENTRY(__VA_ARGS__) #define CONDITIONAL_FACTORY_ENTRY(COND, A, B) A @@ -37,7 +32,12 @@ void initRewriterMapSYCLcompatUnsupport( #define MULTI_STMTS_FACTORY_ENTRY(NAME, ...) SYCLCOMPAT_UNSUPPORT(NAME) #define WARNING_FACTORY_ENTRY(NAME, ...) SYCLCOMPAT_UNSUPPORT(NAME) -RewriterMap.insert({ +void CallExprRewriterFactoryBase::initRewriterMapSYCLcompat( + std::unordered_map> + &RewriterMap) { + // clang-format off + RewriterMap.insert({ #include "../APINamesGraph.inc" #include "../APINamesTexture.inc" SYCLCOMPAT_UNSUPPORT("cudaMemcpy2DArrayToArray") @@ -61,14 +61,34 @@ SYCLCOMPAT_UNSUPPORT("cuMemcpyPeer") SYCLCOMPAT_UNSUPPORT("cuMemcpyPeerAsync") SYCLCOMPAT_UNSUPPORT("cudaMemcpyPeer") SYCLCOMPAT_UNSUPPORT("cudaMemcpyPeerAsync") - }); +SYCLCOMPAT_UNSUPPORT("cub::LoadDirectBlocked") +SYCLCOMPAT_UNSUPPORT("cub::LoadDirectStriped") +SYCLCOMPAT_UNSUPPORT("cub::StoreDirectBlocked") +SYCLCOMPAT_UNSUPPORT("cub::StoreDirectStriped") +SYCLCOMPAT_UNSUPPORT("cub::ShuffleDown") +SYCLCOMPAT_UNSUPPORT("cub::ShuffleUp") + }); + // clang-format on } -void CallExprRewriterFactoryBase::initRewriterMapSYCLcompat( +void CallExprRewriterFactoryBase::initRewriterMethodMapSYCLcompat( std::unordered_map> - &RewriterMap) { - initRewriterMapSYCLcompatUnsupport(RewriterMap); + &MethodRewriterMap) { + // clang-format off + MethodRewriterMap.insert({ +SYCLCOMPAT_UNSUPPORT("cub::BlockRadixSort.Sort") +SYCLCOMPAT_UNSUPPORT("cub::BlockRadixSort.SortDescending") +SYCLCOMPAT_UNSUPPORT("cub::BlockRadixSort.SortBlockedToStriped") +SYCLCOMPAT_UNSUPPORT("cub::BlockRadixSort.SortDescendingBlockedToStriped") +SYCLCOMPAT_UNSUPPORT("cub::BlockExchange.BlockedToStriped") +SYCLCOMPAT_UNSUPPORT("cub::BlockExchange.StripedToBlocked") +SYCLCOMPAT_UNSUPPORT("cub::BlockExchange.ScatterToBlocked") +SYCLCOMPAT_UNSUPPORT("cub::BlockExchange.ScatterToStriped") +SYCLCOMPAT_UNSUPPORT("cub::BlockLoad.Load") +SYCLCOMPAT_UNSUPPORT("cub::BlockStore.Store") + }); + // clang-format on } } // namespace dpct