diff --git a/src/coreclr/jit/compiler.cpp b/src/coreclr/jit/compiler.cpp index 24150005e9954..79fc72cd08bd6 100644 --- a/src/coreclr/jit/compiler.cpp +++ b/src/coreclr/jit/compiler.cpp @@ -3386,9 +3386,19 @@ void Compiler::compInitOptions(JitFlags* jitFlags) { rbmAllMask |= RBM_ALLMASK_EVEX; rbmMskCalleeTrash |= RBM_MSK_CALLEE_TRASH_EVEX; - cntCalleeTrashMask += CNT_CALLEE_TRASH_MASK; + cntCalleeTrashMask += CNT_CALLEE_TRASH_MASK_EVEX; } + // Make sure we copy the register info and initialize the + // trash regs after the underlying fields are initialized + + const regMaskTP vtCalleeTrashRegs[TYP_COUNT]{ +#define DEF_TP(tn, nm, jitType, sz, sze, asze, st, al, regTyp, regFld, csr, ctr, tf) ctr, +#include "typelist.h" +#undef DEF_TP + }; + memcpy(varTypeCalleeTrashRegs, vtCalleeTrashRegs, sizeof(regMaskTP) * TYP_COUNT); + codeGen->CopyRegisterInfo(); #endif // TARGET_XARCH } diff --git a/src/coreclr/jit/compiler.h b/src/coreclr/jit/compiler.h index 965d9809b3d4d..8ba162bd2de6f 100644 --- a/src/coreclr/jit/compiler.h +++ b/src/coreclr/jit/compiler.h @@ -10899,15 +10899,15 @@ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX unsigned cntCalleeTrashFloat; public: - regMaskTP get_RBM_ALLFLOAT() const + FORCEINLINE regMaskTP get_RBM_ALLFLOAT() const { return this->rbmAllFloat; } - regMaskTP get_RBM_FLT_CALLEE_TRASH() const + FORCEINLINE regMaskTP get_RBM_FLT_CALLEE_TRASH() const { return this->rbmFltCalleeTrash; } - unsigned get_CNT_CALLEE_TRASH_FLOAT() const + FORCEINLINE unsigned get_CNT_CALLEE_TRASH_FLOAT() const { return this->cntCalleeTrashFloat; } @@ -10935,17 +10935,18 @@ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX regMaskTP rbmAllMask; regMaskTP rbmMskCalleeTrash; unsigned cntCalleeTrashMask; + regMaskTP varTypeCalleeTrashRegs[TYP_COUNT]; public: - regMaskTP get_RBM_ALLMASK() const + FORCEINLINE regMaskTP get_RBM_ALLMASK() const { return this->rbmAllMask; } - regMaskTP get_RBM_MSK_CALLEE_TRASH() 
const + FORCEINLINE regMaskTP get_RBM_MSK_CALLEE_TRASH() const { return this->rbmMskCalleeTrash; } - unsigned get_CNT_CALLEE_TRASH_MASK() const + FORCEINLINE unsigned get_CNT_CALLEE_TRASH_MASK() const { return this->cntCalleeTrashMask; } diff --git a/src/coreclr/jit/hwintrinsiccodegenxarch.cpp b/src/coreclr/jit/hwintrinsiccodegenxarch.cpp index 8f10f828c0957..accf1fc62552d 100644 --- a/src/coreclr/jit/hwintrinsiccodegenxarch.cpp +++ b/src/coreclr/jit/hwintrinsiccodegenxarch.cpp @@ -2136,6 +2136,39 @@ void CodeGen::genAvxFamilyIntrinsic(GenTreeHWIntrinsic* node) break; } + case NI_AVX512F_NotMask: + { + uint32_t simdSize = node->GetSimdSize(); + uint32_t count = simdSize / genTypeSize(baseType); + + if (count <= 8) + { + assert((count == 2) || (count == 4) || (count == 8)); + ins = INS_knotb; + } + else if (count == 16) + { + ins = INS_knotw; + } + else if (count == 32) + { + ins = INS_knotd; + } + else + { + assert(count == 64); + ins = INS_knotq; + } + + op1Reg = op1->GetRegNum(); + + assert(emitter::isMaskReg(targetReg)); + assert(emitter::isMaskReg(op1Reg)); + + emit->emitIns_R_R(ins, EA_8BYTE, targetReg, op1Reg); + break; + } + case NI_AVX512F_OrMask: { uint32_t simdSize = node->GetSimdSize(); @@ -2174,6 +2207,84 @@ void CodeGen::genAvxFamilyIntrinsic(GenTreeHWIntrinsic* node) break; } + case NI_AVX512F_ShiftLeftMask: + { + uint32_t simdSize = node->GetSimdSize(); + uint32_t count = simdSize / genTypeSize(baseType); + + if (count <= 8) + { + assert((count == 2) || (count == 4) || (count == 8)); + ins = INS_kshiftlb; + } + else if (count == 16) + { + ins = INS_kshiftlw; + } + else if (count == 32) + { + ins = INS_kshiftld; + } + else + { + assert(count == 64); + ins = INS_kshiftlq; + } + + op1Reg = op1->GetRegNum(); + + GenTree* op2 = node->Op(2); + assert(op2->IsCnsIntOrI() && op2->isContained()); + + assert(emitter::isMaskReg(targetReg)); + assert(emitter::isMaskReg(op1Reg)); + + ssize_t ival = op2->AsIntCon()->IconValue(); + assert((ival >= 0) && (ival <= 
255)); + + emit->emitIns_R_R_I(ins, EA_8BYTE, targetReg, op1Reg, (int8_t)ival); + break; + } + + case NI_AVX512F_ShiftRightMask: + { + uint32_t simdSize = node->GetSimdSize(); + uint32_t count = simdSize / genTypeSize(baseType); + + if (count <= 8) + { + assert((count == 2) || (count == 4) || (count == 8)); + ins = INS_kshiftrb; + } + else if (count == 16) + { + ins = INS_kshiftrw; + } + else if (count == 32) + { + ins = INS_kshiftrd; + } + else + { + assert(count == 64); + ins = INS_kshiftrq; + } + + op1Reg = op1->GetRegNum(); + + GenTree* op2 = node->Op(2); + assert(op2->IsCnsIntOrI() && op2->isContained()); + + assert(emitter::isMaskReg(targetReg)); + assert(emitter::isMaskReg(op1Reg)); + + ssize_t ival = op2->AsIntCon()->IconValue(); + assert((ival >= 0) && (ival <= 255)); + + emit->emitIns_R_R_I(ins, EA_8BYTE, targetReg, op1Reg, (int8_t)ival); + break; + } + case NI_AVX512F_XorMask: { uint32_t simdSize = node->GetSimdSize(); diff --git a/src/coreclr/jit/hwintrinsiclistxarch.h b/src/coreclr/jit/hwintrinsiclistxarch.h index 20c3868012373..b17e7d7a3a8ce 100644 --- a/src/coreclr/jit/hwintrinsiclistxarch.h +++ b/src/coreclr/jit/hwintrinsiclistxarch.h @@ -1334,6 +1334,8 @@ HARDWARE_INTRINSIC(AVX512F, NotMask, HARDWARE_INTRINSIC(AVX512F, op_EqualityMask, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Special, HW_Flag_NoContainment|HW_Flag_Commutative) HARDWARE_INTRINSIC(AVX512F, op_InequalityMask, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Special, HW_Flag_NoContainment|HW_Flag_Commutative) HARDWARE_INTRINSIC(AVX512F, OrMask, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Special, 
HW_Flag_NoContainment|HW_Flag_Commutative|HW_Flag_ReturnsPerElementMask) +HARDWARE_INTRINSIC(AVX512F, ShiftLeftMask, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_IMM, HW_Flag_FullRangeIMM|HW_Flag_SpecialCodeGen) +HARDWARE_INTRINSIC(AVX512F, ShiftRightMask, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_IMM, HW_Flag_FullRangeIMM|HW_Flag_SpecialCodeGen) HARDWARE_INTRINSIC(AVX512F, XorMask, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Special, HW_Flag_NoContainment|HW_Flag_Commutative|HW_Flag_ReturnsPerElementMask) #endif // FEATURE_HW_INTRINSIC diff --git a/src/coreclr/jit/lowerxarch.cpp b/src/coreclr/jit/lowerxarch.cpp index 173d0c0177f75..f40acc150de1a 100644 --- a/src/coreclr/jit/lowerxarch.cpp +++ b/src/coreclr/jit/lowerxarch.cpp @@ -2002,10 +2002,37 @@ GenTree* Lowering::LowerHWIntrinsicCmpOp(GenTreeHWIntrinsic* node, genTreeOps cm default: { - maskIntrinsicId = NI_AVX512F_NotMask; - maskNode = comp->gtNewSimdHWIntrinsicNode(TYP_MASK, maskNode, maskIntrinsicId, + // We don't have a well known intrinsic, so we need to inverse the mask keeping the upper + // n-bits clear. If we have 1 element, then the upper 7-bits need to be cleared. If we have + // 2, then the upper 6-bits, and if we have 4, then the upper 4-bits. + // + // There isn't necessarily a trivial way to do this outside not, shift-left by n, + // shift-right by n. 
This preserves count bits, while clearing the upper n-bits + + GenTree* cnsNode; + + maskNode = comp->gtNewSimdHWIntrinsicNode(TYP_MASK, maskNode, NI_AVX512F_NotMask, simdBaseJitType, simdSize); BlockRange().InsertBefore(node, maskNode); + + cnsNode = comp->gtNewIconNode(8 - count); + BlockRange().InsertAfter(maskNode, cnsNode); + + maskNode = + comp->gtNewSimdHWIntrinsicNode(TYP_MASK, maskNode, cnsNode, NI_AVX512F_ShiftLeftMask, + simdBaseJitType, simdSize); + BlockRange().InsertAfter(cnsNode, maskNode); + LowerNode(maskNode); + + cnsNode = comp->gtNewIconNode(8 - count); + BlockRange().InsertAfter(maskNode, cnsNode); + + maskNode = + comp->gtNewSimdHWIntrinsicNode(TYP_MASK, maskNode, cnsNode, NI_AVX512F_ShiftRightMask, + simdBaseJitType, simdSize); + BlockRange().InsertAfter(cnsNode, maskNode); + + maskIntrinsicId = NI_AVX512F_ShiftRightMask; break; } } diff --git a/src/coreclr/jit/lsra.cpp b/src/coreclr/jit/lsra.cpp index 24b071bdc4dea..485ba019ed55f 100644 --- a/src/coreclr/jit/lsra.cpp +++ b/src/coreclr/jit/lsra.cpp @@ -716,11 +716,11 @@ LinearScan::LinearScan(Compiler* theCompiler) #if defined(TARGET_XARCH) rbmAllMask = compiler->rbmAllMask; rbmMskCalleeTrash = compiler->rbmMskCalleeTrash; + memcpy(varTypeCalleeTrashRegs, compiler->varTypeCalleeTrashRegs, sizeof(regMaskTP) * TYP_COUNT); if (!compiler->canUseEvexEncoding()) { - availableRegCount -= CNT_HIGHFLOAT; - availableRegCount -= CNT_MASK_REGS; + availableRegCount -= (CNT_HIGHFLOAT + CNT_MASK_REGS); } #endif // TARGET_XARCH diff --git a/src/coreclr/jit/lsra.h b/src/coreclr/jit/lsra.h index 8ebf1c46782ab..20941e45f9d1b 100644 --- a/src/coreclr/jit/lsra.h +++ b/src/coreclr/jit/lsra.h @@ -2027,11 +2027,11 @@ class LinearScan : public LinearScanInterface regMaskTP rbmAllFloat; regMaskTP rbmFltCalleeTrash; - regMaskTP get_RBM_ALLFLOAT() const + FORCEINLINE regMaskTP get_RBM_ALLFLOAT() const { return this->rbmAllFloat; } - regMaskTP get_RBM_FLT_CALLEE_TRASH() const + FORCEINLINE regMaskTP 
get_RBM_FLT_CALLEE_TRASH() const { return this->rbmFltCalleeTrash; } @@ -2041,11 +2041,11 @@ class LinearScan : public LinearScanInterface regMaskTP rbmAllMask; regMaskTP rbmMskCalleeTrash; - regMaskTP get_RBM_ALLMASK() const + FORCEINLINE regMaskTP get_RBM_ALLMASK() const { return this->rbmAllMask; } - regMaskTP get_RBM_MSK_CALLEE_TRASH() const + FORCEINLINE regMaskTP get_RBM_MSK_CALLEE_TRASH() const { return this->rbmMskCalleeTrash; } @@ -2053,7 +2053,7 @@ class LinearScan : public LinearScanInterface unsigned availableRegCount; - unsigned get_AVAILABLE_REG_COUNT() const + FORCEINLINE unsigned get_AVAILABLE_REG_COUNT() const { return this->availableRegCount; } @@ -2064,7 +2064,7 @@ class LinearScan : public LinearScanInterface // NOTE: we currently don't need a LinearScan `this` pointer for this definition, and some callers // don't have one available, so make is static. // - static regMaskTP calleeSaveRegs(RegisterType rt) + static FORCEINLINE regMaskTP calleeSaveRegs(RegisterType rt) { static const regMaskTP varTypeCalleeSaveRegs[] = { #define DEF_TP(tn, nm, jitType, sz, sze, asze, st, al, regTyp, regFld, csr, ctr, tf) csr, @@ -2076,16 +2076,25 @@ class LinearScan : public LinearScanInterface return varTypeCalleeSaveRegs[rt]; } +#if defined(TARGET_XARCH) + // Not all of the callee trash values are constant, so don't declare this as a method local static + // doing so results in significantly more complex codegen and we'd rather just initialize this once + // as part of initializing LSRA instead + regMaskTP varTypeCalleeTrashRegs[TYP_COUNT]; +#endif // TARGET_XARCH + //------------------------------------------------------------------------ // callerSaveRegs: Get the set of caller-save registers of the given RegisterType // - regMaskTP callerSaveRegs(RegisterType rt) const + FORCEINLINE regMaskTP callerSaveRegs(RegisterType rt) const { +#if !defined(TARGET_XARCH) static const regMaskTP varTypeCalleeTrashRegs[] = { #define DEF_TP(tn, nm, jitType, sz, sze, asze, 
st, al, regTyp, regFld, csr, ctr, tf) ctr, #include "typelist.h" #undef DEF_TP }; +#endif // !TARGET_XARCH assert((unsigned)rt < ArrLen(varTypeCalleeTrashRegs)); return varTypeCalleeTrashRegs[rt]; diff --git a/src/coreclr/jit/lsrabuild.cpp b/src/coreclr/jit/lsrabuild.cpp index 12aad1c8c77de..8c9025f61b703 100644 --- a/src/coreclr/jit/lsrabuild.cpp +++ b/src/coreclr/jit/lsrabuild.cpp @@ -891,10 +891,10 @@ regMaskTP LinearScan::getKillSetForCall(GenTreeCall* call) // if there is no FP used, we can ignore the FP kills if (!compiler->compFloatingPointUsed) { - killMask &= ~RBM_FLT_CALLEE_TRASH; - #if defined(TARGET_XARCH) - killMask &= ~RBM_MSK_CALLEE_TRASH; + killMask &= ~(RBM_FLT_CALLEE_TRASH | RBM_MSK_CALLEE_TRASH); +#else + killMask &= ~RBM_FLT_CALLEE_TRASH; #endif // TARGET_XARCH } #ifdef TARGET_ARM diff --git a/src/coreclr/jit/vartype.h b/src/coreclr/jit/vartype.h index 316bd2a867430..116d5ce2c0519 100644 --- a/src/coreclr/jit/vartype.h +++ b/src/coreclr/jit/vartype.h @@ -328,7 +328,19 @@ inline bool varTypeUsesFloatReg(T vt) template inline bool varTypeUsesMaskReg(T vt) { - return varTypeRegister[TypeGet(vt)] == VTR_MASK; +// The technically correct check is: +// return varTypeRegister[TypeGet(vt)] == VTR_MASK; +// +// However, we only have one type that uses VTR_MASK today +// and so it's quite a bit cheaper to just check that directly + +#if defined(FEATURE_SIMD) && defined(TARGET_XARCH) + assert((TypeGet(vt) == TYP_MASK) || (varTypeRegister[TypeGet(vt)] != VTR_MASK)); + return TypeGet(vt) == TYP_MASK; +#else + assert(varTypeRegister[TypeGet(vt)] != VTR_MASK); + return false; +#endif } template diff --git a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Char.cs b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Char.cs index 1f99365d3e441..b6d92204f9d23 100644 --- a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Char.cs +++ b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Char.cs @@ -647,8 +647,9 @@ public
static unsafe int IndexOfNullCharacter(char* searchSpace) Vector512 search = *(Vector512*)(searchSpace + (nuint)offset); - // Note that MoveMask has converted the equal vector elements into a set of bit flags, - // So the bit position in 'matches' corresponds to the element offset. + // AVX-512 returns comparison results in a mask register, so we want to optimize + // the core check to simply be a "none match" check. This will slightly increase + // the cost for the early match case, but greatly improves perf otherwise. if (!Vector512.EqualsAny(search, Vector512.Zero)) { // Zero flags set so no matches @@ -657,6 +658,9 @@ public static unsafe int IndexOfNullCharacter(char* searchSpace) continue; } + // Note that ExtractMostSignificantBits has converted the equal vector elements into a set of bit flags, + // so the bit position in 'matches' corresponds to the element offset. + // // Find bitflag offset of first match and add to current offset, // flags are in bytes so divide for chars ulong matches = Vector512.Equals(search, Vector512.Zero).ExtractMostSignificantBits();