Skip to content

Commit

Permalink
[SYCL][NATIVECPU] Implement missing work group collectives in Native …
Browse files Browse the repository at this point in the history
…CPU libdevice (#15618)

Fixes some issues in Native CPU's libdevice:
* Remove an unused definition of `DefineBroadCastImpl`
* Fix typo in `DefineBroadCastImpl` that lead to incorrect results for
broadcast
* Define  `__spirv_GroupAny`/` __spirv_GroupAll` for work groups
  • Loading branch information
PietroGhg authored Oct 9, 2024
1 parent 5581c34 commit 1d58d6b
Showing 1 changed file with 9 additions and 18 deletions.
27 changes: 9 additions & 18 deletions libdevice/nativecpu_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,19 @@ template <class T> struct vtypes {
DefSubgroupBlockINTEL(uint32_t) DefSubgroupBlockINTEL(uint64_t)
DefSubgroupBlockINTEL(uint8_t) DefSubgroupBlockINTEL(uint16_t)

#define DefineGOp1(spir_sfx, mux_name)\
DEVICE_EXTERN_C bool mux_name(bool);\
#define DefineGOp1(spir_sfx, name)\
DEVICE_EXTERN_C bool __mux_sub_group_##name##_i1(bool);\
DEVICE_EXTERN_C bool __mux_work_group_##name##_i1(uint32_t id, bool val);\
DEVICE_EXTERNAL bool __spirv_Group ## spir_sfx(unsigned g, bool val) {\
if (__spv::Scope::Flag::Subgroup == g)\
return mux_name(val);\
return __mux_sub_group_##name##_i1(val);\
else if (__spv::Scope::Flag::Workgroup == g)\
return __mux_work_group_##name##_i1(0, val);\
return false;\
}

DefineGOp1(Any, __mux_sub_group_any_i1)
DefineGOp1(All, __mux_sub_group_all_i1)
DefineGOp1(Any, any)
DefineGOp1(All, all)


#define DefineGOp(Type, MuxType, spir_sfx, mux_sfx) \
Expand Down Expand Up @@ -184,18 +187,6 @@ DefineBitwiseGroupOp(uint64_t, int64_t, i64)

DefineLogicalGroupOp(bool, bool, i1)

#define DefineBroadCastImpl(Type, Sfx, MuxType, IDType) \
DEVICE_EXTERN_C MuxType __mux_work_group_broadcast_##Sfx( \
int32_t id, MuxType val, int64_t lidx, int64_t lidy, int64_t lidz); \
DEVICE_EXTERN_C MuxType __mux_sub_group_broadcast_##Sfx(MuxType val, \
int32_t sg_lid); \
DEVICE_EXTERNAL Type __spirv_GroupBroadcast(uint32_t g, Type v, \
IDType l) { \
if (__spv::Scope::Flag::Subgroup == g) \
return __mux_sub_group_broadcast_##Sfx(v, l); \
return Type(); /*todo: add support for other flags as they are tested*/ \
}

#define DefineBroadcastMuxType(Type, Sfx, MuxType, IDType) \
DEVICE_EXTERN_C MuxType __mux_work_group_broadcast_##Sfx( \
int32_t id, MuxType val, uint64_t lidx, uint64_t lidy, uint64_t lidz); \
Expand All @@ -216,7 +207,7 @@ DefineLogicalGroupOp(bool, bool, i1)
if (__spv::Scope::Flag::Subgroup == g) \
return __mux_sub_group_broadcast_##Sfx(v, l[0]); \
else \
return __mux_work_group_broadcast_##Sfx(0, v, l[0], l[0], 0); \
return __mux_work_group_broadcast_##Sfx(0, v, l[0], l[1], 0); \
} \
\
DEVICE_EXTERNAL Type __spirv_GroupBroadcast(uint32_t g, Type v, \
Expand Down

0 comments on commit 1d58d6b

Please sign in to comment.