ARM64-SVE: Ensure MOVPRFX is next to SVE instruction in imm tables (#106125)

* ARM64-SVE: Ensure MOVPRFX is next to SVE instruction in immediate jump tables

* Add emitInsMovPrfxHelper

* Fix formatting

* Restore a predicated movprfx use

* Fix use of predicated movprfx
a74nh committed Aug 13, 2024
1 parent 506e749 commit 0e74147
Showing 2 changed files with 99 additions and 46 deletions.
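For background on the fix (a sketch, not part of the commit; register names and labels are hypothetical): the Arm architecture requires that a movprfx be immediately followed by the destructive SVE instruction it prefixes. When the immediate operand is not a constant, the JIT dispatches through a jump table of per-immediate cases, so emitting the movprfx once, before the dispatch sequence, separates it from the instruction it prefixes. This commit emits the movprfx inside each case instead, widening each case from 8 to 12 bytes:

// Before (invalid: movprfx separated from the SVE instruction):
//   movprfx z0, z1
//   adr     x9, beginLabel
//   add     x9, x9, x10, lsl #3      // 8-byte cases
//   br      x9
// beginLabel:
//   lsr     z0.s, p0/m, z0.s, #1     // case for immediate == 1
//   b       endLabel
//   ...
//
// After (valid: each case carries its own movprfx, directly before the
// SVE instruction):
//   adr     x9, beginLabel
//   add     x9, x9, x10, lsl #3
//   add     x9, x9, x10, lsl #2      // 12-byte cases: imm*8 + imm*4
//   br      x9
// beginLabel:
//   movprfx z0.s, p0/m, z1.s         // case for immediate == 1
//   lsr     z0.s, p0/m, z0.s, #1
//   b       endLabel
//   ...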
src/coreclr/jit/codegen.h (3 changes: 2 additions & 1 deletion)
@@ -1010,7 +1010,7 @@ class CodeGen final : public CodeGenInterface
class HWIntrinsicImmOpHelper final
{
public:
HWIntrinsicImmOpHelper(CodeGen* codeGen, GenTree* immOp, GenTreeHWIntrinsic* intrin);
HWIntrinsicImmOpHelper(CodeGen* codeGen, GenTree* immOp, GenTreeHWIntrinsic* intrin, int numInstrs = 1);

HWIntrinsicImmOpHelper(
CodeGen* codeGen, regNumber immReg, int immLowerBound, int immUpperBound, GenTreeHWIntrinsic* intrin);
@@ -1058,6 +1058,7 @@ class CodeGen final : public CodeGenInterface
int immUpperBound;
regNumber nonConstImmReg;
regNumber branchTargetReg;
int numInstrs;
};

#endif // TARGET_ARM64
src/coreclr/jit/hwintrinsiccodegenarm64.cpp (142 changes: 97 additions & 45 deletions)
@@ -23,6 +23,7 @@
// codeGen -- an instance of CodeGen class.
// immOp -- an immediate operand of the intrinsic.
// intrin -- a hardware intrinsic tree node.
// numInstrs -- number of instructions that will be in each switch entry. Default 1.
//
// Note: This class is designed to be used in the following way
// HWIntrinsicImmOpHelper helper(this, immOp, intrin);
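// A minimal sketch of the loop this note describes, matching the call
// sites later in this diff:
//
//       for (helper.EmitBegin(); !helper.Done(); helper.EmitCaseEnd())
//       {
//           // emit the instruction(s) for the value helper.ImmValue()
//       }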
@@ -35,11 +36,15 @@
// This allows combining the logic for the cases when immOp->isContainedIntOrIImmed() is either true or false into
// a single for-loop.
//
CodeGen::HWIntrinsicImmOpHelper::HWIntrinsicImmOpHelper(CodeGen* codeGen, GenTree* immOp, GenTreeHWIntrinsic* intrin)
CodeGen::HWIntrinsicImmOpHelper::HWIntrinsicImmOpHelper(CodeGen* codeGen,
GenTree* immOp,
GenTreeHWIntrinsic* intrin,
int numInstrs)
: codeGen(codeGen)
, endLabel(nullptr)
, nonZeroLabel(nullptr)
, branchTargetReg(REG_NA)
, numInstrs(numInstrs)
{
assert(codeGen != nullptr);
assert(varTypeIsIntegral(immOp));
@@ -132,6 +137,7 @@ CodeGen::HWIntrinsicImmOpHelper::HWIntrinsicImmOpHelper(
, immUpperBound(immUpperBound)
, nonConstImmReg(immReg)
, branchTargetReg(REG_NA)
, numInstrs(1)
{
assert(codeGen != nullptr);

@@ -181,18 +187,32 @@ void CodeGen::HWIntrinsicImmOpHelper::EmitBegin()
}
else
{
// Here we assume that each case consists of one arm64 instruction followed by "b endLabel".
assert(numInstrs == 1 || numInstrs == 2);

// Here we assume that each case consists of numInstrs arm64 instructions followed by "b endLabel".
// Since an arm64 instruction is 4 bytes, we branch to AddressOf(beginLabel) + (nonConstImmReg << 3).
GetEmitter()->emitIns_R_L(INS_adr, EA_8BYTE, beginLabel, branchTargetReg);
GetEmitter()->emitIns_R_R_R_I(INS_add, EA_8BYTE, branchTargetReg, branchTargetReg, nonConstImmReg, 3,
INS_OPTS_LSL);

// For two instructions, add the extra one.
if (numInstrs == 2)
{
GetEmitter()->emitIns_R_R_R_I(INS_add, EA_8BYTE, branchTargetReg, branchTargetReg, nonConstImmReg, 2,
INS_OPTS_LSL);
}

// If the lower bound is non zero we need to adjust the branch target value by subtracting
// (immLowerBound << 3).
// the lower bound
if (immLowerBound != 0)
{
GetEmitter()->emitIns_R_R_I(INS_sub, EA_8BYTE, branchTargetReg, branchTargetReg,
((ssize_t)immLowerBound << 3));
ssize_t lowerReduce = ((ssize_t)immLowerBound << 3);
if (numInstrs == 2)
{
lowerReduce += ((ssize_t)immLowerBound << 2);
}

GetEmitter()->emitIns_R_R_I(INS_sub, EA_8BYTE, branchTargetReg, branchTargetReg, lowerReduce);
}

GetEmitter()->emitIns_R(INS_br, EA_8BYTE, branchTargetReg);
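A standalone sketch of the case-offset arithmetic above (an illustration, not JIT code; CaseOffset is a hypothetical name, and each case is assumed to be numInstrs 4-byte arm64 instructions plus one 4-byte branch):

#include <cassert>
#include <cstddef>

// Byte offset, relative to beginLabel, of the jump-table case for imm.
std::ptrdiff_t CaseOffset(int imm, int immLowerBound, int numInstrs)
{
    assert(numInstrs == 1 || numInstrs == 2);
    assert(imm >= immLowerBound);

    // Each case is (numInstrs + 1) * 4 bytes: 8 bytes per case when
    // numInstrs == 1, 12 bytes when numInstrs == 2. The shifts mirror
    // the add/sub sequence emitted above.
    std::ptrdiff_t offset      = (std::ptrdiff_t)imm << 3;           // imm * 8
    std::ptrdiff_t lowerReduce = (std::ptrdiff_t)immLowerBound << 3;
    if (numInstrs == 2)
    {
        offset += (std::ptrdiff_t)imm << 2;                          // + imm * 4
        lowerReduce += (std::ptrdiff_t)immLowerBound << 2;
    }
    return offset - lowerReduce; // (imm - immLowerBound) * (numInstrs + 1) * 4
}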
@@ -516,6 +536,15 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
}

// Shared code for setting up embedded mask arg for intrinsics with 3+ operands

auto emitEmbeddedMaskSetupInstrs = [&] {
if (intrin.op3->IsVectorZero() || (targetReg != falseReg) || (targetReg != embMaskOp1Reg))
{
return 1;
}
return 0;
};
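The count returned here is passed to HWIntrinsicImmOpHelper, so that each jump-table case reserves room for the optional setup instruction as well as the SVE instruction itself; the call sites later in this diff read:

    HWIntrinsicImmOpHelper helper(this, intrinEmbMask.op3, op2->AsHWIntrinsic(),
                                  emitEmbeddedMaskSetupInstrs() + 1);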

auto emitEmbeddedMaskSetup = [&] {
if (intrin.op3->IsVectorZero())
{
@@ -721,6 +750,24 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
}
};

auto emitInsMovPrfxHelper = [&](regNumber reg1, regNumber reg2, regNumber reg3, regNumber reg4) {
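// reg1: destination; reg2: governing predicate (mask); reg3: source copied
// by the movprfx; reg4: second source, used only by the non-immediate
// (register) form in the else branch below.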
if (hasShift)
{
HWIntrinsicImmOpHelper helper(this, intrinEmbMask.op2, op2->AsHWIntrinsic(), 2);
for (helper.EmitBegin(); !helper.Done(); helper.EmitCaseEnd())
{
GetEmitter()->emitIns_R_R_R(INS_sve_movprfx, EA_SCALABLE, reg1, reg2, reg3, opt);
GetEmitter()->emitInsSve_R_R_I(insEmbMask, emitSize, reg1, reg2, helper.ImmValue(),
embOpt, sopt);
}
}
else
{
GetEmitter()->emitIns_R_R_R(INS_sve_movprfx, EA_SCALABLE, reg1, reg2, reg3, opt);
GetEmitter()->emitIns_R_R_R(insEmbMask, emitSize, reg1, reg2, reg4, embOpt, sopt);
}
};

if (intrin.op3->IsVectorZero())
{
// If `falseReg` is zero, then move the first operand of `intrinEmbMask` in the
@@ -739,12 +786,11 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)

default:
assert(targetReg != embMaskOp2Reg);
GetEmitter()->emitIns_R_R_R(INS_sve_movprfx, emitSize, targetReg, maskReg,
embMaskOp1Reg, opt);

// Finally, perform the actual "predicated" operation so that `targetReg` is the first
// operand and `embMaskOp2Reg` is the second operand.
emitInsHelper(targetReg, maskReg, embMaskOp2Reg);

emitInsMovPrfxHelper(targetReg, maskReg, embMaskOp1Reg, embMaskOp2Reg);
break;
}
}
@@ -768,30 +814,28 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
// into targetReg. Next, do the predicated operation on the targetReg and last,
// use "sel" to select the active lanes based on mask, and set inactive lanes
// to falseReg.
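Roughly, the sequence this comment describes (a sketch with hypothetical registers; <op> stands for the embedded-mask instruction):

//   movprfx z0.s, p0/m, z1.s       // targetReg = embMaskOp1Reg
//   <op>    z0.s, p0/m, z0.s, ...  // predicated operation into targetReg
//   sel     z0.s, p0, z0.s, z2.s   // active lanes from targetReg,
//                                  // inactive lanes from falseReg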

assert(targetReg != embMaskOp2Reg);
assert(HWIntrinsicInfo::IsEmbeddedMaskedOperation(intrinEmbMask.id));

GetEmitter()->emitIns_R_R(INS_sve_movprfx, EA_SCALABLE, targetReg, embMaskOp1Reg);

emitInsHelper(targetReg, maskReg, embMaskOp2Reg);
emitInsMovPrfxHelper(targetReg, maskReg, embMaskOp1Reg, embMaskOp2Reg);
}

GetEmitter()->emitIns_R_R_R_R(INS_sve_sel, emitSize, targetReg, maskReg, targetReg,
falseReg, opt);
break;
}
else if (targetReg != embMaskOp1Reg)
{
// embMaskOp1Reg is same as `falseReg`, but not same as `targetReg`. Move the
// `embMaskOp1Reg` i.e. `falseReg` in `targetReg`, using "unpredicated movprfx", so the
// subsequent `insEmbMask` operation can be merged on top of it.
GetEmitter()->emitIns_R_R(INS_sve_movprfx, EA_SCALABLE, targetReg, falseReg);
emitInsMovPrfxHelper(targetReg, maskReg, falseReg, embMaskOp2Reg);
}
else
{
// Finally, perform the actual "predicated" operation so that `targetReg` is the first
// operand and `embMaskOp2Reg` is the second operand.
emitInsHelper(targetReg, maskReg, embMaskOp2Reg);
}

// Finally, perform the actual "predicated" operation so that `targetReg` is the first operand
// and `embMaskOp2Reg` is the second operand.
emitInsHelper(targetReg, maskReg, embMaskOp2Reg);
}
else
{
@@ -907,21 +951,22 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
}
}

emitEmbeddedMaskSetup();

// Finally, perform the desired operation.
if (HWIntrinsicInfo::HasImmediateOperand(intrinEmbMask.id))
{
HWIntrinsicImmOpHelper helper(this, intrinEmbMask.op3, op2->AsHWIntrinsic());
HWIntrinsicImmOpHelper helper(this, intrinEmbMask.op3, op2->AsHWIntrinsic(),
emitEmbeddedMaskSetupInstrs() + 1);
for (helper.EmitBegin(); !helper.Done(); helper.EmitCaseEnd())
{
emitEmbeddedMaskSetup();
GetEmitter()->emitInsSve_R_R_R_I(insEmbMask, emitSize, targetReg, maskReg, embMaskOp2Reg,
helper.ImmValue(), opt);
}
}
else
{
assert(HWIntrinsicInfo::IsFmaIntrinsic(intrinEmbMask.id));
emitEmbeddedMaskSetup();
GetEmitter()->emitInsSve_R_R_R_R(insEmbMask, emitSize, targetReg, maskReg, embMaskOp2Reg,
embMaskOp3Reg, opt);
}
Expand All @@ -935,11 +980,11 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
assert(intrinEmbMask.op4->isContained() == (embMaskOp4Reg == REG_NA));
assert(HWIntrinsicInfo::HasImmediateOperand(intrinEmbMask.id));

emitEmbeddedMaskSetup();

HWIntrinsicImmOpHelper helper(this, intrinEmbMask.op4, op2->AsHWIntrinsic());
HWIntrinsicImmOpHelper helper(this, intrinEmbMask.op4, op2->AsHWIntrinsic(),
emitEmbeddedMaskSetupInstrs() + 1);
for (helper.EmitBegin(); !helper.Done(); helper.EmitCaseEnd())
{
emitEmbeddedMaskSetup();
GetEmitter()->emitInsSve_R_R_R_R_I(insEmbMask, emitSize, targetReg, maskReg, embMaskOp2Reg,
embMaskOp3Reg, helper.ImmValue(), opt);
}
@@ -2333,17 +2378,17 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
{
assert(isRMW);

if (targetReg != op1Reg)
{
assert(targetReg != op2Reg);

GetEmitter()->emitIns_R_R(INS_sve_movprfx, EA_SCALABLE, targetReg, op1Reg);
}

HWIntrinsicImmOpHelper helper(this, intrin.op3, node);

for (helper.EmitBegin(); !helper.Done(); helper.EmitCaseEnd())
{
if (targetReg != op1Reg)
{
assert(targetReg != op2Reg);

GetEmitter()->emitIns_R_R(INS_sve_movprfx, EA_SCALABLE, targetReg, op1Reg);
}

const int elementIndex = helper.ImmValue();
const int byteIndex = genTypeSize(intrin.baseType) * elementIndex;

@@ -2483,17 +2528,17 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
{
assert(isRMW);

if (targetReg != op1Reg)
{
assert(targetReg != op2Reg);

GetEmitter()->emitInsSve_R_R(INS_sve_movprfx, EA_SCALABLE, targetReg, op1Reg);
}

HWIntrinsicImmOpHelper helper(this, intrin.op3, node);

for (helper.EmitBegin(); !helper.Done(); helper.EmitCaseEnd())
{
if (targetReg != op1Reg)
{
assert(targetReg != op2Reg);

GetEmitter()->emitInsSve_R_R(INS_sve_movprfx, EA_SCALABLE, targetReg, op1Reg);
}

GetEmitter()->emitInsSve_R_R_I(ins, emitSize, targetReg, op2Reg, helper.ImmValue(), opt);
}
break;
Expand All @@ -2504,16 +2549,16 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
assert(isRMW);
assert(hasImmediateOperand);

if (targetReg != op1Reg)
{
assert(targetReg != op2Reg);
assert(targetReg != op3Reg);
GetEmitter()->emitInsSve_R_R(INS_sve_movprfx, EA_SCALABLE, targetReg, op1Reg);
}

// If both immediates are constant, we don't need a jump table
if (intrin.op4->IsCnsIntOrI() && intrin.op5->IsCnsIntOrI())
{
if (targetReg != op1Reg)
{
assert(targetReg != op2Reg);
assert(targetReg != op3Reg);
GetEmitter()->emitInsSve_R_R(INS_sve_movprfx, EA_SCALABLE, targetReg, op1Reg);
}

assert(intrin.op4->isContainedIntOrIImmed() && intrin.op5->isContainedIntOrIImmed());
GetEmitter()->emitInsSve_R_R_R_I_I(ins, emitSize, targetReg, op2Reg, op3Reg,
intrin.op4->AsIntCon()->gtIconVal,
@@ -2537,6 +2582,13 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
HWIntrinsicImmOpHelper helper(this, op4Reg, 0, 7, node);
for (helper.EmitBegin(); !helper.Done(); helper.EmitCaseEnd())
{
if (targetReg != op1Reg)
{
assert(targetReg != op2Reg);
assert(targetReg != op3Reg);
GetEmitter()->emitInsSve_R_R(INS_sve_movprfx, EA_SCALABLE, targetReg, op1Reg);
}

// Extract index and rotation from the immediate
const int value = helper.ImmValue();
const ssize_t index = value & 1;
