Skip to content

Commit

Permalink
Respond to PR feedback and try to reduce TP regression more
Browse files Browse the repository at this point in the history
  • Loading branch information
tannergooding committed Jul 25, 2023
1 parent 136e898 commit 8a0c9a3
Show file tree
Hide file tree
Showing 10 changed files with 200 additions and 24 deletions.
12 changes: 11 additions & 1 deletion src/coreclr/jit/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3386,9 +3386,19 @@ void Compiler::compInitOptions(JitFlags* jitFlags)
{
rbmAllMask |= RBM_ALLMASK_EVEX;
rbmMskCalleeTrash |= RBM_MSK_CALLEE_TRASH_EVEX;
cntCalleeTrashMask += CNT_CALLEE_TRASH_MASK;
cntCalleeTrashMask += CNT_CALLEE_TRASH_MASK_EVEX;
}

// Make sure we copy the register info and initialize the
// trash regs after the underlying fields are initialized

const regMaskTP vtCalleeTrashRegs[TYP_COUNT]{
#define DEF_TP(tn, nm, jitType, sz, sze, asze, st, al, regTyp, regFld, csr, ctr, tf) ctr,
#include "typelist.h"
#undef DEF_TP
};
memcpy(varTypeCalleeTrashRegs, vtCalleeTrashRegs, sizeof(regMaskTP) * TYP_COUNT);

codeGen->CopyRegisterInfo();
#endif // TARGET_XARCH
}
Expand Down
13 changes: 7 additions & 6 deletions src/coreclr/jit/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -10899,15 +10899,15 @@ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
unsigned cntCalleeTrashFloat;

public:
regMaskTP get_RBM_ALLFLOAT() const
FORCEINLINE regMaskTP get_RBM_ALLFLOAT() const
{
return this->rbmAllFloat;
}
regMaskTP get_RBM_FLT_CALLEE_TRASH() const
FORCEINLINE regMaskTP get_RBM_FLT_CALLEE_TRASH() const
{
return this->rbmFltCalleeTrash;
}
unsigned get_CNT_CALLEE_TRASH_FLOAT() const
FORCEINLINE unsigned get_CNT_CALLEE_TRASH_FLOAT() const
{
return this->cntCalleeTrashFloat;
}
Expand Down Expand Up @@ -10935,17 +10935,18 @@ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
regMaskTP rbmAllMask;
regMaskTP rbmMskCalleeTrash;
unsigned cntCalleeTrashMask;
regMaskTP varTypeCalleeTrashRegs[TYP_COUNT];

public:
regMaskTP get_RBM_ALLMASK() const
FORCEINLINE regMaskTP get_RBM_ALLMASK() const
{
return this->rbmAllMask;
}
regMaskTP get_RBM_MSK_CALLEE_TRASH() const
FORCEINLINE regMaskTP get_RBM_MSK_CALLEE_TRASH() const
{
return this->rbmMskCalleeTrash;
}
unsigned get_CNT_CALLEE_TRASH_MASK() const
FORCEINLINE unsigned get_CNT_CALLEE_TRASH_MASK() const
{
return this->cntCalleeTrashMask;
}
Expand Down
111 changes: 111 additions & 0 deletions src/coreclr/jit/hwintrinsiccodegenxarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2136,6 +2136,39 @@ void CodeGen::genAvxFamilyIntrinsic(GenTreeHWIntrinsic* node)
break;
}

case NI_AVX512F_NotMask:
{
uint32_t simdSize = node->GetSimdSize();
uint32_t count = simdSize / genTypeSize(baseType);

if (count <= 8)
{
assert((count == 2) || (count == 4) || (count == 8));
ins = INS_knotb;
}
else if (count == 16)
{
ins = INS_knotw;
}
else if (count == 32)
{
ins = INS_knotd;
}
else
{
assert(count == 64);
ins = INS_knotq;
}

op1Reg = op1->GetRegNum();

assert(emitter::isMaskReg(targetReg));
assert(emitter::isMaskReg(op1Reg));

emit->emitIns_R_R(ins, EA_8BYTE, targetReg, op1Reg);
break;
}

case NI_AVX512F_OrMask:
{
uint32_t simdSize = node->GetSimdSize();
Expand Down Expand Up @@ -2174,6 +2207,84 @@ void CodeGen::genAvxFamilyIntrinsic(GenTreeHWIntrinsic* node)
break;
}

case NI_AVX512F_ShiftLeftMask:
{
uint32_t simdSize = node->GetSimdSize();
uint32_t count = simdSize / genTypeSize(baseType);

if (count <= 8)
{
assert((count == 2) || (count == 4) || (count == 8));
ins = INS_kshiftlb;
}
else if (count == 16)
{
ins = INS_kshiftlw;
}
else if (count == 32)
{
ins = INS_kshiftld;
}
else
{
assert(count == 64);
ins = INS_kshiftlq;
}

op1Reg = op1->GetRegNum();

GenTree* op2 = node->Op(2);
assert(op2->IsCnsIntOrI() && op2->isContained());

assert(emitter::isMaskReg(targetReg));
assert(emitter::isMaskReg(op1Reg));

ssize_t ival = op2->AsIntCon()->IconValue();
assert((ival >= 0) && (ival <= 255));

emit->emitIns_R_R_I(ins, EA_8BYTE, targetReg, op1Reg, (int8_t)ival);
break;
}

case NI_AVX512F_ShiftRightMask:
{
uint32_t simdSize = node->GetSimdSize();
uint32_t count = simdSize / genTypeSize(baseType);

if (count <= 8)
{
assert((count == 2) || (count == 4) || (count == 8));
ins = INS_kshiftrb;
}
else if (count == 16)
{
ins = INS_kshiftrw;
}
else if (count == 32)
{
ins = INS_kshiftrd;
}
else
{
assert(count == 64);
ins = INS_kshiftrq;
}

op1Reg = op1->GetRegNum();

GenTree* op2 = node->Op(2);
assert(op2->IsCnsIntOrI() && op2->isContained());

assert(emitter::isMaskReg(targetReg));
assert(emitter::isMaskReg(op1Reg));

ssize_t ival = op2->AsIntCon()->IconValue();
assert((ival >= 0) && (ival <= 255));

emit->emitIns_R_R_I(ins, EA_8BYTE, targetReg, op1Reg, (int8_t)ival);
break;
}

case NI_AVX512F_XorMask:
{
uint32_t simdSize = node->GetSimdSize();
Expand Down
2 changes: 2 additions & 0 deletions src/coreclr/jit/hwintrinsiclistxarch.h
Original file line number Diff line number Diff line change
Expand Up @@ -1334,6 +1334,8 @@ HARDWARE_INTRINSIC(AVX512F, NotMask,
HARDWARE_INTRINSIC(AVX512F, op_EqualityMask, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Special, HW_Flag_NoContainment|HW_Flag_Commutative)
HARDWARE_INTRINSIC(AVX512F, op_InequalityMask, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Special, HW_Flag_NoContainment|HW_Flag_Commutative)
HARDWARE_INTRINSIC(AVX512F, OrMask, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Special, HW_Flag_NoContainment|HW_Flag_Commutative|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(AVX512F, ShiftLeftMask, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_IMM, HW_Flag_FullRangeIMM|HW_Flag_SpecialCodeGen)
HARDWARE_INTRINSIC(AVX512F, ShiftRightMask, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_IMM, HW_Flag_FullRangeIMM|HW_Flag_SpecialCodeGen)
HARDWARE_INTRINSIC(AVX512F, XorMask, -1, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Special, HW_Flag_NoContainment|HW_Flag_Commutative|HW_Flag_ReturnsPerElementMask)

#endif // FEATURE_HW_INTRINSIC
Expand Down
31 changes: 29 additions & 2 deletions src/coreclr/jit/lowerxarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2002,10 +2002,37 @@ GenTree* Lowering::LowerHWIntrinsicCmpOp(GenTreeHWIntrinsic* node, genTreeOps cm

default:
{
maskIntrinsicId = NI_AVX512F_NotMask;
maskNode = comp->gtNewSimdHWIntrinsicNode(TYP_MASK, maskNode, maskIntrinsicId,
// We don't have a well known intrinsic, so we need to inverse the mask keeping the upper
// n-bits clear. If we have 1 element, then the upper 7-bits need to be cleared. If we have
// 2, then the upper 6-bits, and if we have 4, then the upper 4-bits.
//
// There isn't necessarily a trivial way to do this outside not, shift-left by n,
// shift-right by n. This preserves count bits, while clearing the upper n-bits

GenTree* cnsNode;

maskNode = comp->gtNewSimdHWIntrinsicNode(TYP_MASK, maskNode, NI_AVX512F_NotMask,
simdBaseJitType, simdSize);
BlockRange().InsertBefore(node, maskNode);

cnsNode = comp->gtNewIconNode(8 - count);
BlockRange().InsertAfter(maskNode, cnsNode);

maskNode =
comp->gtNewSimdHWIntrinsicNode(TYP_MASK, maskNode, cnsNode, NI_AVX512F_ShiftLeftMask,
simdBaseJitType, simdSize);
BlockRange().InsertAfter(cnsNode, maskNode);
LowerNode(maskNode);

cnsNode = comp->gtNewIconNode(8 - count);
BlockRange().InsertAfter(maskNode, cnsNode);

maskNode =
comp->gtNewSimdHWIntrinsicNode(TYP_MASK, maskNode, cnsNode, NI_AVX512F_ShiftRightMask,
simdBaseJitType, simdSize);
BlockRange().InsertAfter(cnsNode, maskNode);

maskIntrinsicId = NI_AVX512F_ShiftRightMask;
break;
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/coreclr/jit/lsra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -716,11 +716,11 @@ LinearScan::LinearScan(Compiler* theCompiler)
#if defined(TARGET_XARCH)
rbmAllMask = compiler->rbmAllMask;
rbmMskCalleeTrash = compiler->rbmMskCalleeTrash;
memcpy(varTypeCalleeTrashRegs, compiler->varTypeCalleeTrashRegs, sizeof(regMaskTP) * TYP_COUNT);

if (!compiler->canUseEvexEncoding())
{
availableRegCount -= CNT_HIGHFLOAT;
availableRegCount -= CNT_MASK_REGS;
availableRegCount -= (CNT_HIGHFLOAT + CNT_MASK_REGS);
}
#endif // TARGET_XARCH

Expand Down
23 changes: 16 additions & 7 deletions src/coreclr/jit/lsra.h
Original file line number Diff line number Diff line change
Expand Up @@ -2027,11 +2027,11 @@ class LinearScan : public LinearScanInterface
regMaskTP rbmAllFloat;
regMaskTP rbmFltCalleeTrash;

regMaskTP get_RBM_ALLFLOAT() const
FORCEINLINE regMaskTP get_RBM_ALLFLOAT() const
{
return this->rbmAllFloat;
}
regMaskTP get_RBM_FLT_CALLEE_TRASH() const
FORCEINLINE regMaskTP get_RBM_FLT_CALLEE_TRASH() const
{
return this->rbmFltCalleeTrash;
}
Expand All @@ -2041,19 +2041,19 @@ class LinearScan : public LinearScanInterface
regMaskTP rbmAllMask;
regMaskTP rbmMskCalleeTrash;

regMaskTP get_RBM_ALLMASK() const
FORCEINLINE regMaskTP get_RBM_ALLMASK() const
{
return this->rbmAllMask;
}
regMaskTP get_RBM_MSK_CALLEE_TRASH() const
FORCEINLINE regMaskTP get_RBM_MSK_CALLEE_TRASH() const
{
return this->rbmMskCalleeTrash;
}
#endif // TARGET_XARCH

unsigned availableRegCount;

unsigned get_AVAILABLE_REG_COUNT() const
FORCEINLINE unsigned get_AVAILABLE_REG_COUNT() const
{
return this->availableRegCount;
}
Expand All @@ -2064,7 +2064,7 @@ class LinearScan : public LinearScanInterface
// NOTE: we currently don't need a LinearScan `this` pointer for this definition, and some callers
// don't have one available, so make is static.
//
static regMaskTP calleeSaveRegs(RegisterType rt)
static FORCEINLINE regMaskTP calleeSaveRegs(RegisterType rt)
{
static const regMaskTP varTypeCalleeSaveRegs[] = {
#define DEF_TP(tn, nm, jitType, sz, sze, asze, st, al, regTyp, regFld, csr, ctr, tf) csr,
Expand All @@ -2076,16 +2076,25 @@ class LinearScan : public LinearScanInterface
return varTypeCalleeSaveRegs[rt];
}

#if defined(TARGET_XARCH)
// Not all of the callee trash values are constant, so don't declare this as a method local static
// doing so results in significantly more complex codegen and we'd rather just initialize this once
// as part of initializing LSRA instead
regMaskTP varTypeCalleeTrashRegs[TYP_COUNT];
#endif // TARGET_XARCH

//------------------------------------------------------------------------
// callerSaveRegs: Get the set of caller-save registers of the given RegisterType
//
regMaskTP callerSaveRegs(RegisterType rt) const
FORCEINLINE regMaskTP callerSaveRegs(RegisterType rt) const
{
#if !defined(TARGET_XARCH)
static const regMaskTP varTypeCalleeTrashRegs[] = {
#define DEF_TP(tn, nm, jitType, sz, sze, asze, st, al, regTyp, regFld, csr, ctr, tf) ctr,
#include "typelist.h"
#undef DEF_TP
};
#endif // !TARGET_XARCH

assert((unsigned)rt < ArrLen(varTypeCalleeTrashRegs));
return varTypeCalleeTrashRegs[rt];
Expand Down
6 changes: 3 additions & 3 deletions src/coreclr/jit/lsrabuild.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -891,10 +891,10 @@ regMaskTP LinearScan::getKillSetForCall(GenTreeCall* call)
// if there is no FP used, we can ignore the FP kills
if (!compiler->compFloatingPointUsed)
{
killMask &= ~RBM_FLT_CALLEE_TRASH;

#if defined(TARGET_XARCH)
killMask &= ~RBM_MSK_CALLEE_TRASH;
killMask &= ~(RBM_FLT_CALLEE_TRASH | RBM_MSK_CALLEE_TRASH);
#else
killMask &= ~RBM_FLT_CALLEE_TRASH;
#endif // TARGET_XARCH
}
#ifdef TARGET_ARM
Expand Down
14 changes: 13 additions & 1 deletion src/coreclr/jit/vartype.h
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,19 @@ inline bool varTypeUsesFloatReg(T vt)
template <class T>
inline bool varTypeUsesMaskReg(T vt)
{
return varTypeRegister[TypeGet(vt)] == VTR_MASK;
// The technically correct check is:
// return varTypeRegister[TypeGet(vt)] == VTR_MASK;
//
// However, we only have one type that uses VTR_MASK today
// and so its quite a bit cheaper to just check that directly

#if defined(FEATURE_SIMD) && defined(TARGET_XARCH)
assert((TypeGet(vt) == TYP_MASK) || (varTypeRegister[TypeGet(vt)] != VTR_MASK));
return TypeGet(vt) == TYP_MASK;
#else
assert(varTypeRegister[TypeGet(vt)] != VTR_MASK);
return false;
#endif
}

template <class T>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -647,8 +647,9 @@ public static unsafe int IndexOfNullCharacter(char* searchSpace)

Vector512<ushort> search = *(Vector512<ushort>*)(searchSpace + (nuint)offset);

// Note that MoveMask has converted the equal vector elements into a set of bit flags,
// So the bit position in 'matches' corresponds to the element offset.
// AVX-512 returns comparison results in a mask register, so we want to optimize
// the core check to simply be an "none match" check. This will slightly increase
// the cost for the early match case, but greatly improves perf otherwise.
if (!Vector512.EqualsAny(search, Vector512<ushort>.Zero))
{
// Zero flags set so no matches
Expand All @@ -657,6 +658,9 @@ public static unsafe int IndexOfNullCharacter(char* searchSpace)
continue;
}

// Note that ExtractMostSignificantBits has converted the equal vector elements into a set of bit flags,
// So the bit position in 'matches' corresponds to the element offset.
//
// Find bitflag offset of first match and add to current offset,
// flags are in bytes so divide for chars
ulong matches = Vector512.Equals(search, Vector512<ushort>.Zero).ExtractMostSignificantBits();
Expand Down

0 comments on commit 8a0c9a3

Please sign in to comment.