From 1d58d6bf3643f7ec51c5b6420d863983adff818e Mon Sep 17 00:00:00 2001 From: Pietro Ghiglio Date: Wed, 9 Oct 2024 13:08:47 +0200 Subject: [PATCH] [SYCL][NATIVECPU] Implement missing work group collectives in Native CPU libdevice (#15618) Fixes some issues in Native CPU's libdevice: * Remove an unused definition of `DefineBroadCastImpl` * Fix typo in `DefineBroadCastImpl` that lead to incorrect results for broadcast * Define `__spirv_GroupAny`/` __spirv_GroupAll` for work groups --- libdevice/nativecpu_utils.cpp | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/libdevice/nativecpu_utils.cpp b/libdevice/nativecpu_utils.cpp index 8dc2807539653..0c431b8814dbe 100644 --- a/libdevice/nativecpu_utils.cpp +++ b/libdevice/nativecpu_utils.cpp @@ -92,16 +92,19 @@ template struct vtypes { DefSubgroupBlockINTEL(uint32_t) DefSubgroupBlockINTEL(uint64_t) DefSubgroupBlockINTEL(uint8_t) DefSubgroupBlockINTEL(uint16_t) -#define DefineGOp1(spir_sfx, mux_name)\ -DEVICE_EXTERN_C bool mux_name(bool);\ +#define DefineGOp1(spir_sfx, name)\ +DEVICE_EXTERN_C bool __mux_sub_group_##name##_i1(bool);\ +DEVICE_EXTERN_C bool __mux_work_group_##name##_i1(uint32_t id, bool val);\ DEVICE_EXTERNAL bool __spirv_Group ## spir_sfx(unsigned g, bool val) {\ if (__spv::Scope::Flag::Subgroup == g)\ - return mux_name(val);\ + return __mux_sub_group_##name##_i1(val);\ + else if (__spv::Scope::Flag::Workgroup == g)\ + return __mux_work_group_##name##_i1(0, val);\ return false;\ } -DefineGOp1(Any, __mux_sub_group_any_i1) -DefineGOp1(All, __mux_sub_group_all_i1) +DefineGOp1(Any, any) +DefineGOp1(All, all) #define DefineGOp(Type, MuxType, spir_sfx, mux_sfx) \ @@ -184,18 +187,6 @@ DefineBitwiseGroupOp(uint64_t, int64_t, i64) DefineLogicalGroupOp(bool, bool, i1) -#define DefineBroadCastImpl(Type, Sfx, MuxType, IDType) \ - DEVICE_EXTERN_C MuxType __mux_work_group_broadcast_##Sfx( \ - int32_t id, MuxType val, int64_t lidx, int64_t lidy, int64_t lidz); \ - DEVICE_EXTERN_C MuxType __mux_sub_group_broadcast_##Sfx(MuxType val, \ - int32_t sg_lid); \ - DEVICE_EXTERNAL Type __spirv_GroupBroadcast(uint32_t g, Type v, \ - IDType l) { \ - if (__spv::Scope::Flag::Subgroup == g) \ - return __mux_sub_group_broadcast_##Sfx(v, l); \ - return Type(); /*todo: add support for other flags as they are tested*/ \ - } - #define DefineBroadcastMuxType(Type, Sfx, MuxType, IDType) \ DEVICE_EXTERN_C MuxType __mux_work_group_broadcast_##Sfx( \ int32_t id, MuxType val, uint64_t lidx, uint64_t lidy, uint64_t lidz); \ @@ -216,7 +207,7 @@ DefineLogicalGroupOp(bool, bool, i1) if (__spv::Scope::Flag::Subgroup == g) \ return __mux_sub_group_broadcast_##Sfx(v, l[0]); \ else \ - return __mux_work_group_broadcast_##Sfx(0, v, l[0], l[0], 0); \ + return __mux_work_group_broadcast_##Sfx(0, v, l[0], l[1], 0); \ } \ \ DEVICE_EXTERNAL Type __spirv_GroupBroadcast(uint32_t g, Type v, \