Skip to content

Commit

Permalink
JIT: Unify arm64 and x64 GT_SELECT handling (#82610)
Browse files Browse the repository at this point in the history
This unifies GT_SELECT/GT_SELECTCC handling between arm64 and x64. The
arm64 backend no longer uses containment for compare chains; instead,
there is a new GT_CCMP node that both produces and consumes flags, and
lowering can lower GT_AND(op, relop) down to this node.
  • Loading branch information
jakobbotsch committed Mar 1, 2023
1 parent 0a15c3b commit 9cfd11d
Show file tree
Hide file tree
Showing 14 changed files with 331 additions and 405 deletions.
4 changes: 1 addition & 3 deletions src/coreclr/jit/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -883,8 +883,7 @@ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
void genCkfinite(GenTree* treeNode);
void genCodeForCompare(GenTreeOp* tree);
#ifdef TARGET_ARM64
void genCodeForConditionalCompare(GenTreeOp* tree, GenCondition prevCond);
void genCodeForContainedCompareChain(GenTree* tree, bool* inchain, GenCondition* prevCond);
void genCodeForCCMP(GenTreeCCMP* ccmp);
#endif
void genCodeForSelect(GenTreeOp* select);
void genIntrinsic(GenTreeIntrinsic* treeNode);
Expand Down Expand Up @@ -1559,7 +1558,6 @@ XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
#endif // TARGET_XARCH

#if defined(TARGET_ARM64)
static insCflags InsCflagsForCcmp(GenCondition cond);
static insCond JumpKindToInsCond(emitJumpKind condition);
#elif defined(TARGET_XARCH)
static instruction JumpKindToCmov(emitJumpKind condition);
Expand Down
223 changes: 35 additions & 188 deletions src/coreclr/jit/codegenarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2696,35 +2696,6 @@ void CodeGen::genCodeForBinary(GenTreeOp* tree)
return;
}

if (tree->isContainedCompareChainSegment(op2))
{
GenCondition cond;
bool chain = false;

JITDUMP("Generating compare chain:\n");
if (op1->isContained())
{
// Generate Op1 into flags.
genCodeForContainedCompareChain(op1, &chain, &cond);
assert(chain);
}
else
{
// Op1 is not contained, move it from a register into flags.
emit->emitIns_R_I(INS_cmp, emitActualTypeSize(op1), op1->GetRegNum(), 0);
cond = GenCondition::NE;
chain = true;
}
// Gen Op2 into flags.
genCodeForContainedCompareChain(op2, &chain, &cond);
assert(chain);

// Move the result from flags into a register.
inst_SETCC(cond, tree->TypeGet(), targetReg);
genProduceReg(tree);
return;
}

instruction ins = genGetInsForOper(tree->OperGet(), targetType);

if ((tree->gtFlags & GTF_SET_FLAGS) != 0)
Expand Down Expand Up @@ -4600,108 +4571,36 @@ void CodeGen::genCodeForCompare(GenTreeOp* tree)
// tree - a compare node (GT_EQ etc)
// cond - the condition of the previous generated compare.
//
void CodeGen::genCodeForConditionalCompare(GenTreeOp* tree, GenCondition prevCond)
void CodeGen::genCodeForCCMP(GenTreeCCMP* ccmp)
{
emitter* emit = GetEmitter();

GenTree* op1 = tree->gtGetOp1();
GenTree* op2 = tree->gtGetOp2();
var_types op1Type = genActualType(op1->TypeGet());
var_types op2Type = genActualType(op2->TypeGet());
emitAttr cmpSize = EA_ATTR(genTypeSize(op1Type));
regNumber targetReg = tree->GetRegNum();
regNumber srcReg1 = op1->GetRegNum();
genConsumeOperands(ccmp);
GenTree* op1 = ccmp->gtGetOp1();
GenTree* op2 = ccmp->gtGetOp2();
var_types op1Type = genActualType(op1->TypeGet());
var_types op2Type = genActualType(op2->TypeGet());
emitAttr cmpSize = emitActualTypeSize(op1Type);
regNumber srcReg1 = op1->GetRegNum();

// No float support or swapping op1 and op2 to generate cmp reg, imm.
assert(!varTypeIsFloating(op2Type));
assert(!op1->isContainedIntOrIImmed());

// Should only be called on contained nodes.
assert(targetReg == REG_NA);

// Should not be called for test conditionals (Arm64 does not have a ctst).
assert(tree->OperIsCmpCompare());

// For the ccmp flags, invert the condition of the compare.
insCflags cflags = InsCflagsForCcmp(GenCondition::FromRelop(tree));

// For the condition, use the previous compare.
const GenConditionDesc& prevDesc = GenConditionDesc::Get(prevCond);
insCond prevInsCond = JumpKindToInsCond(prevDesc.jumpKind1);
const GenConditionDesc& condDesc = GenConditionDesc::Get(ccmp->gtCondition);
insCond insCond = JumpKindToInsCond(condDesc.jumpKind1);

if (op2->isContainedIntOrIImmed())
{
GenTreeIntConCommon* intConst = op2->AsIntConCommon();
emit->emitIns_R_I_FLAGS_COND(INS_ccmp, cmpSize, srcReg1, (int)intConst->IconValue(), cflags, prevInsCond);
emit->emitIns_R_I_FLAGS_COND(INS_ccmp, cmpSize, srcReg1, (int)intConst->IconValue(), ccmp->gtFlagsVal, insCond);
}
else
{
regNumber srcReg2 = op2->GetRegNum();
emit->emitIns_R_R_FLAGS_COND(INS_ccmp, cmpSize, srcReg1, srcReg2, cflags, prevInsCond);
}
}

//------------------------------------------------------------------------
// genCodeForContainedCompareChain: Produce code for a chain of conditional compares.
//
// Only generates for contained nodes. Nodes that are not contained are assumed to be
// generated as part of standard tree generation.
//
// Arguments:
// tree - the node. Either a compare or a tree of compares connected by ANDs.
// inChain - whether a contained chain is in progress.
// prevCond - If a chain is in progress, the condition of the previous compare.
// Return:
// The last compare node generated.
//
void CodeGen::genCodeForContainedCompareChain(GenTree* tree, bool* inChain, GenCondition* prevCond)
{
assert(tree->isContained());

if (tree->OperIs(GT_AND))
{
GenTree* op1 = tree->gtGetOp1();
GenTree* op2 = tree->gtGetOp2();

assert(op2->isContained());

// If Op1 is contained, generate into flags. Otherwise, move the result into flags.
if (op1->isContained())
{
genCodeForContainedCompareChain(op1, inChain, prevCond);
assert(*inChain);
}
else
{
emitter* emit = GetEmitter();
emit->emitIns_R_I(INS_cmp, emitActualTypeSize(op1), op1->GetRegNum(), 0);
*prevCond = GenCondition::NE;
*inChain = true;
}

// Generate Op2 based on Op1.
genCodeForContainedCompareChain(op2, inChain, prevCond);
assert(*inChain);
}
else
{
assert(tree->OperIsCmpCompare());

// Generate the compare, putting the result in the flags register.
if (!*inChain)
{
// First item in a chain. Use a standard compare.
genCodeForCompare(tree->AsOp());
}
else
{
// Within the chain. Use a conditional compare (which is
// dependent on the previous emitted compare).
genCodeForConditionalCompare(tree->AsOp(), *prevCond);
}

*inChain = true;
*prevCond = GenCondition::FromRelop(tree);
emit->emitIns_R_R_FLAGS_COND(INS_ccmp, cmpSize, srcReg1, srcReg2, ccmp->gtFlagsVal, insCond);
}
}

Expand All @@ -4713,45 +4612,37 @@ void CodeGen::genCodeForContainedCompareChain(GenTree* tree, bool* inChain, GenC
//
void CodeGen::genCodeForSelect(GenTreeOp* tree)
{
assert(tree->OperIs(GT_SELECT));
GenTreeConditional* select = tree->AsConditional();
emitter* emit = GetEmitter();
assert(tree->OperIs(GT_SELECT, GT_SELECTCC));
GenTree* opcond = nullptr;
if (tree->OperIs(GT_SELECT))
{
opcond = tree->AsConditional()->gtCond;
genConsumeRegs(opcond);
}

GenTree* opcond = select->gtCond;
GenTree* op1 = select->gtOp1;
GenTree* op2 = select->gtOp2;
var_types op1Type = genActualType(op1->TypeGet());
var_types op2Type = genActualType(op2->TypeGet());
emitAttr attr = emitActualTypeSize(select->TypeGet());
emitter* emit = GetEmitter();

GenTree* op1 = tree->gtOp1;
GenTree* op2 = tree->gtOp2;
var_types op1Type = genActualType(op1);
var_types op2Type = genActualType(op2);
emitAttr attr = emitActualTypeSize(tree);

assert(!op1->isUsedFromMemory());
assert(genTypeSize(op1Type) == genTypeSize(op2Type));

GenCondition prevCond;
genConsumeRegs(opcond);
if (opcond->isContained())
GenCondition cond;

if (opcond != nullptr)
{
// Generate the contained condition.
if (opcond->OperIsCompare())
{
genCodeForCompare(opcond->AsOp());
prevCond = GenCondition::FromRelop(opcond);
}
else
{
// Condition is a compare chain. Try to contain it.
assert(opcond->OperIs(GT_AND));
bool chain = false;
JITDUMP("Generating compare chain:\n");
genCodeForContainedCompareChain(opcond, &chain, &prevCond);
assert(chain);
}
// Condition has been generated into a register - move it into flags.
emit->emitIns_R_I(INS_cmp, emitActualTypeSize(opcond), opcond->GetRegNum(), 0);
cond = GenCondition::NE;
}
else
{
// Condition has been generated into a register - move it into flags.
emit->emitIns_R_I(INS_cmp, emitActualTypeSize(opcond), opcond->GetRegNum(), 0);
prevCond = GenCondition::NE;
assert(tree->OperIs(GT_SELECTCC));
cond = tree->AsOpCC()->gtCondition;
}

assert(!op1->isContained() || op1->IsIntegralConst(0));
Expand All @@ -4760,7 +4651,7 @@ void CodeGen::genCodeForSelect(GenTreeOp* tree)
regNumber targetReg = tree->GetRegNum();
regNumber srcReg1 = op1->IsIntegralConst(0) ? REG_ZR : genConsumeReg(op1);
regNumber srcReg2 = op2->IsIntegralConst(0) ? REG_ZR : genConsumeReg(op2);
const GenConditionDesc& prevDesc = GenConditionDesc::Get(prevCond);
const GenConditionDesc& prevDesc = GenConditionDesc::Get(cond);

emit->emitIns_R_R_R_COND(INS_csel, attr, targetReg, srcReg1, srcReg2, JumpKindToInsCond(prevDesc.jumpKind1));

Expand Down Expand Up @@ -10382,50 +10273,6 @@ void CodeGen::genCodeForCond(GenTreeOp* tree)
genProduceReg(tree);
}

//------------------------------------------------------------------------
// InsCflagsForCcmp: Get the Cflags for a required for a CCMP instruction.
//
// Consider:
// cmp w, x
// ccmp y, z, A, COND
// This is: compare w and x, if this matches condition COND, then compare y and z.
// Otherwise set flags to A - this should match the case where cmp failed.
// Given COND, this function returns A.
//
// Arguments:
// cond - the GenCondition.
//
insCflags CodeGen::InsCflagsForCcmp(GenCondition cond)
{
GenCondition inverted = GenCondition::Reverse(cond);
switch (inverted.GetCode())
{
case GenCondition::EQ:
return INS_FLAGS_Z;
case GenCondition::NE:
return INS_FLAGS_NONE;
case GenCondition::SGE:
return INS_FLAGS_Z;
case GenCondition::SGT:
return INS_FLAGS_NONE;
case GenCondition::SLT:
return INS_FLAGS_NC;
case GenCondition::SLE:
return INS_FLAGS_NZC;
case GenCondition::UGE:
return INS_FLAGS_C;
case GenCondition::UGT:
return INS_FLAGS_C;
case GenCondition::ULT:
return INS_FLAGS_NONE;
case GenCondition::ULE:
return INS_FLAGS_Z;
default:
NO_WAY("unexpected condition type");
return INS_FLAGS_NONE;
}
}

//------------------------------------------------------------------------
// JumpKindToInsCond: Convert a Jump Kind to a condition.
//
Expand Down
8 changes: 8 additions & 0 deletions src/coreclr/jit/codegenarmarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,12 +361,20 @@ void CodeGen::genCodeForTreeNode(GenTree* treeNode)
case GT_SELECT:
genCodeForSelect(treeNode->AsConditional());
break;

case GT_SELECTCC:
genCodeForSelect(treeNode->AsOp());
break;
#endif

#ifdef TARGET_ARM64
case GT_JCMP:
genCodeForJumpCompare(treeNode->AsOp());
break;

case GT_CCMP:
genCodeForCCMP(treeNode->AsCCMP());
break;
#endif // TARGET_ARM64

case GT_JCC:
Expand Down
30 changes: 29 additions & 1 deletion src/coreclr/jit/gentree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,10 @@ void GenTree::InitNodeSize()
static_assert_no_msg(sizeof(GenTreeLclFld) <= TREE_NODE_SZ_SMALL);
static_assert_no_msg(sizeof(GenTreeCC) <= TREE_NODE_SZ_SMALL);
static_assert_no_msg(sizeof(GenTreeOpCC) <= TREE_NODE_SZ_SMALL);
#ifdef TARGET_ARM64
static_assert_no_msg(sizeof(GenTreeCCMP) <= TREE_NODE_SZ_SMALL);
#endif
static_assert_no_msg(sizeof(GenTreeConditional) <= TREE_NODE_SZ_SMALL);
static_assert_no_msg(sizeof(GenTreeCast) <= TREE_NODE_SZ_LARGE); // *** large node
static_assert_no_msg(sizeof(GenTreeBox) <= TREE_NODE_SZ_LARGE); // *** large node
static_assert_no_msg(sizeof(GenTreeField) <= TREE_NODE_SZ_LARGE); // *** large node
Expand Down Expand Up @@ -11317,6 +11321,17 @@ void Compiler::gtDispLclVarStructType(unsigned lclNum)
}
}

#if defined(DEBUG) && defined(TARGET_ARM64)
static const char* InsCflagsToString(insCflags flags)
{
const static char* s_table[16] = {"0", "v", "c", "cv", "z", "zv", "zc", "zcv",
"n", "nv", "nc", "ncv", "nz", "nzv", "nzc", "nzcv"};
unsigned index = (unsigned)flags;
assert((0 <= index) && (index < ArrLen(s_table)));
return s_table[index];
}
#endif

//------------------------------------------------------------------------
// gtDispSsaName: Display the SSA use/def for a given local.
//
Expand Down Expand Up @@ -12162,6 +12177,13 @@ void Compiler::gtDispTree(GenTree* tree,
{
printf(" cond=%s", tree->AsOpCC()->gtCondition.Name());
}
#ifdef TARGET_ARM64
else if (tree->OperIs(GT_CCMP))
{
printf(" cond=%s flags=%s", tree->AsCCMP()->gtCondition.Name(),
InsCflagsToString(tree->AsCCMP()->gtFlagsVal));
}
#endif

gtDispCommonEndLine(tree);

Expand Down Expand Up @@ -18837,7 +18859,13 @@ bool GenTree::SupportsSettingZeroFlag()
}
#endif
#elif defined(TARGET_ARM64)
if (OperIs(GT_AND, GT_ADD, GT_SUB))
if (OperIs(GT_AND))
{
return true;
}

// We do not support setting zero flag for madd/msub.
if (OperIs(GT_ADD, GT_SUB) && (!gtGetOp2()->OperIs(GT_MUL) || !gtGetOp2()->isContained()))
{
return true;
}
Expand Down
Loading

0 comments on commit 9cfd11d

Please sign in to comment.