Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ARM64-SVE: GatherPrefetch #103826

Merged
merged 4 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/coreclr/jit/gentree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27513,6 +27513,10 @@ bool GenTreeHWIntrinsic::OperIsMemoryLoad(GenTree** pAddr) const
addr = Op(2);
break;

case NI_Sve_GatherPrefetch8Bit:
case NI_Sve_GatherPrefetch16Bit:
case NI_Sve_GatherPrefetch32Bit:
case NI_Sve_GatherPrefetch64Bit:
case NI_Sve_GatherVector:
case NI_Sve_GatherVectorByteZeroExtend:
case NI_Sve_GatherVectorInt16SignExtend:
Expand Down
24 changes: 24 additions & 0 deletions src/coreclr/jit/hwintrinsic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1587,6 +1587,14 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
break;

#elif defined(TARGET_ARM64)
case NI_Sve_GatherPrefetch8Bit:
case NI_Sve_GatherPrefetch16Bit:
case NI_Sve_GatherPrefetch32Bit:
case NI_Sve_GatherPrefetch64Bit:
assert(varTypeIsSIMD(op2->TypeGet()));
retNode->AsHWIntrinsic()->SetAuxiliaryJitType(getBaseJitTypeOfSIMDType(sigReader.op2ClsHnd));
break;

case NI_Sve_GatherVector:
case NI_Sve_GatherVectorByteZeroExtend:
case NI_Sve_GatherVectorInt16SignExtend:
Expand Down Expand Up @@ -1615,6 +1623,22 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
assert(!isScalar);
retNode =
gtNewSimdHWIntrinsicNode(nodeRetType, op1, op2, op3, op4, intrinsic, simdBaseJitType, simdSize);

#if defined(TARGET_ARM64)
switch (intrinsic)
{
case NI_Sve_GatherPrefetch8Bit:
case NI_Sve_GatherPrefetch16Bit:
case NI_Sve_GatherPrefetch32Bit:
case NI_Sve_GatherPrefetch64Bit:
assert(varTypeIsSIMD(op3->TypeGet()));
retNode->AsHWIntrinsic()->SetAuxiliaryJitType(getBaseJitTypeOfSIMDType(sigReader.op3ClsHnd));
break;

default:
break;
}
#endif
break;
}

Expand Down
4 changes: 4 additions & 0 deletions src/coreclr/jit/hwintrinsicarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,10 @@ void HWIntrinsicInfo::lookupImmBounds(
}
break;

case NI_Sve_GatherPrefetch8Bit:
case NI_Sve_GatherPrefetch16Bit:
case NI_Sve_GatherPrefetch32Bit:
case NI_Sve_GatherPrefetch64Bit:
case NI_Sve_PrefetchBytes:
case NI_Sve_PrefetchInt16:
case NI_Sve_PrefetchInt32:
Expand Down
89 changes: 67 additions & 22 deletions src/coreclr/jit/hwintrinsiccodegenarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1496,22 +1496,11 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
{
assert(hasImmediateOperand);
assert(HWIntrinsicInfo::HasEnumOperand(intrin.id));
if (intrin.op3->IsCnsIntOrI())
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed this block as HWIntrinsicImmOpHelper does all those checks for you.

{
GetEmitter()->emitIns_PRFOP_R_R_I(ins, emitSize,
(insSvePrfop)intrin.op3->AsIntConCommon()->IconValue(), op1Reg,
op2Reg, 0);
}
else
HWIntrinsicImmOpHelper helper(this, intrin.op3, node);
for (helper.EmitBegin(); !helper.Done(); helper.EmitCaseEnd())
{
assert(!intrin.op3->isContainedIntOrIImmed());

HWIntrinsicImmOpHelper helper(this, intrin.op3, node);
for (helper.EmitBegin(); !helper.Done(); helper.EmitCaseEnd())
{
const insSvePrfop prfop = (insSvePrfop)helper.ImmValue();
GetEmitter()->emitIns_PRFOP_R_R_I(ins, emitSize, prfop, op1Reg, op2Reg, 0);
}
const insSvePrfop prfop = (insSvePrfop)helper.ImmValue();
GetEmitter()->emitIns_PRFOP_R_R_I(ins, emitSize, prfop, op1Reg, op2Reg, 0);
}
break;
}
Expand Down Expand Up @@ -1873,6 +1862,61 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
break;
}

case NI_Sve_GatherPrefetch8Bit:
case NI_Sve_GatherPrefetch16Bit:
case NI_Sve_GatherPrefetch32Bit:
case NI_Sve_GatherPrefetch64Bit:
{
assert(hasImmediateOperand);

if (!varTypeIsSIMD(intrin.op2->gtType))
{
// GatherPrefetch...(Vector<T> mask, T* address, Vector<T2> indices, SvePrefetchType prefetchType)

assert(intrin.numOperands == 4);
emitAttr baseSize = emitActualTypeSize(intrin.baseType);
insScalableOpts sopt = INS_SCALABLE_OPTS_NONE;

if (baseSize == EA_8BYTE)
{
// Index is multiplied.
sopt = (ins == INS_sve_prfb) ? INS_SCALABLE_OPTS_NONE : INS_SCALABLE_OPTS_LSL_N;
}
else
{
// Index is sign or zero extended to 64bits, then multiplied.
assert(baseSize == EA_4BYTE);
opt = varTypeIsUnsigned(node->GetAuxiliaryType()) ? INS_OPTS_SCALABLE_S_UXTW
: INS_OPTS_SCALABLE_S_SXTW;

sopt = (ins == INS_sve_prfb) ? INS_SCALABLE_OPTS_NONE : INS_SCALABLE_OPTS_MOD_N;
}

HWIntrinsicImmOpHelper helper(this, intrin.op4, node);
for (helper.EmitBegin(); !helper.Done(); helper.EmitCaseEnd())
{
const insSvePrfop prfop = (insSvePrfop)helper.ImmValue();
GetEmitter()->emitIns_PRFOP_R_R_R(ins, emitSize, prfop, op1Reg, op2Reg, op3Reg, opt, sopt);
}
}
else
{
// GatherPrefetch...(Vector<T> mask, Vector<T2> addresses, SvePrefetchType prefetchType)

opt = emitter::optGetSveInsOpt(emitTypeSize(node->GetAuxiliaryType()));

assert(intrin.numOperands == 3);
HWIntrinsicImmOpHelper helper(this, intrin.op3, node);
for (helper.EmitBegin(); !helper.Done(); helper.EmitCaseEnd())
{
const insSvePrfop prfop = (insSvePrfop)helper.ImmValue();
GetEmitter()->emitIns_PRFOP_R_R_I(ins, emitSize, prfop, op1Reg, op2Reg, 0, opt);
}
}

break;
}

case NI_Sve_GatherVector:
case NI_Sve_GatherVectorByteZeroExtend:
case NI_Sve_GatherVectorInt16SignExtend:
Expand All @@ -1890,14 +1934,14 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
// GatherVector...(Vector<T> mask, T* address, Vector<T2> indices)

assert(intrin.numOperands == 3);
emitAttr baseSize = emitActualTypeSize(intrin.baseType);
emitAttr baseSize = emitActualTypeSize(intrin.baseType);
insScalableOpts sopt = INS_SCALABLE_OPTS_NONE;

if (baseSize == EA_8BYTE)
{
// Index is multiplied.
insScalableOpts sopt = (ins == INS_sve_ld1b || ins == INS_sve_ld1sb) ? INS_SCALABLE_OPTS_NONE
: INS_SCALABLE_OPTS_LSL_N;
GetEmitter()->emitIns_R_R_R_R(ins, emitSize, targetReg, op1Reg, op2Reg, op3Reg, opt, sopt);
sopt = (ins == INS_sve_ld1b || ins == INS_sve_ld1sb) ? INS_SCALABLE_OPTS_NONE
: INS_SCALABLE_OPTS_LSL_N;
}
else
{
Expand All @@ -1906,10 +1950,11 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
opt = varTypeIsUnsigned(node->GetAuxiliaryType()) ? INS_OPTS_SCALABLE_S_UXTW
: INS_OPTS_SCALABLE_S_SXTW;

insScalableOpts sopt = (ins == INS_sve_ld1b || ins == INS_sve_ld1sb) ? INS_SCALABLE_OPTS_NONE
: INS_SCALABLE_OPTS_MOD_N;
GetEmitter()->emitIns_R_R_R_R(ins, emitSize, targetReg, op1Reg, op2Reg, op3Reg, opt, sopt);
sopt = (ins == INS_sve_ld1b || ins == INS_sve_ld1sb) ? INS_SCALABLE_OPTS_NONE
: INS_SCALABLE_OPTS_MOD_N;
}

GetEmitter()->emitIns_R_R_R_R(ins, emitSize, targetReg, op1Reg, op2Reg, op3Reg, opt, sopt);
}
else
{
Expand Down
4 changes: 4 additions & 0 deletions src/coreclr/jit/hwintrinsiclistarm64sve.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ HARDWARE_INTRINSIC(Sve, FusedMultiplyAddNegated,
HARDWARE_INTRINSIC(Sve, FusedMultiplySubtract, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fmls, INS_sve_fmls}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation|HW_Flag_FmaIntrinsic|HW_Flag_SpecialCodeGen)
HARDWARE_INTRINSIC(Sve, FusedMultiplySubtractBySelectedScalar, -1, 4, true, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fmls, INS_sve_fmls}, HW_Category_SIMDByIndexedElement, HW_Flag_Scalable|HW_Flag_HasImmediateOperand|HW_Flag_HasRMWSemantics|HW_Flag_FmaIntrinsic|HW_Flag_LowVectorOperation)
HARDWARE_INTRINSIC(Sve, FusedMultiplySubtractNegated, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fnmls, INS_sve_fnmls}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation|HW_Flag_FmaIntrinsic|HW_Flag_SpecialCodeGen)
HARDWARE_INTRINSIC(Sve, GatherPrefetch16Bit, -1, -1, false, {INS_invalid, INS_invalid, INS_sve_prfh, INS_sve_prfh, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_MemoryLoad, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_SpecialCodeGen|HW_Flag_ExplicitMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_HasImmediateOperand|HW_Flag_HasEnumOperand)
HARDWARE_INTRINSIC(Sve, GatherPrefetch32Bit, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_prfw, INS_sve_prfw, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_MemoryLoad, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_SpecialCodeGen|HW_Flag_ExplicitMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_HasImmediateOperand|HW_Flag_HasEnumOperand)
HARDWARE_INTRINSIC(Sve, GatherPrefetch64Bit, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_prfd, INS_sve_prfd, INS_invalid, INS_invalid}, HW_Category_MemoryLoad, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_SpecialCodeGen|HW_Flag_ExplicitMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_HasImmediateOperand|HW_Flag_HasEnumOperand)
HARDWARE_INTRINSIC(Sve, GatherPrefetch8Bit, -1, -1, false, {INS_sve_prfb, INS_sve_prfb, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_MemoryLoad, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_SpecialCodeGen|HW_Flag_ExplicitMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_HasImmediateOperand|HW_Flag_HasEnumOperand)
HARDWARE_INTRINSIC(Sve, GatherVector, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_ld1w, INS_sve_ld1w, INS_sve_ld1d, INS_sve_ld1d, INS_sve_ld1w, INS_sve_ld1d}, HW_Category_MemoryLoad, HW_Flag_Scalable|HW_Flag_SpecialCodeGen|HW_Flag_ExplicitMaskedOperation|HW_Flag_LowMaskedOperation)
HARDWARE_INTRINSIC(Sve, GatherVectorByteZeroExtend, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_ld1b, INS_sve_ld1b, INS_sve_ld1b, INS_sve_ld1b, INS_invalid, INS_invalid}, HW_Category_MemoryLoad, HW_Flag_Scalable|HW_Flag_SpecialCodeGen|HW_Flag_ExplicitMaskedOperation|HW_Flag_LowMaskedOperation)
HARDWARE_INTRINSIC(Sve, GatherVectorInt16SignExtend, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_ld1sh, INS_sve_ld1sh, INS_sve_ld1sh, INS_sve_ld1sh, INS_invalid, INS_invalid}, HW_Category_MemoryLoad, HW_Flag_Scalable|HW_Flag_SpecialCodeGen|HW_Flag_ExplicitMaskedOperation|HW_Flag_LowMaskedOperation)
Expand Down
23 changes: 23 additions & 0 deletions src/coreclr/jit/lowerarmarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3386,6 +3386,29 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node)
}
break;

case NI_Sve_GatherPrefetch8Bit:
case NI_Sve_GatherPrefetch16Bit:
case NI_Sve_GatherPrefetch32Bit:
case NI_Sve_GatherPrefetch64Bit:
assert(hasImmediateOperand);
if (!varTypeIsSIMD(intrin.op2->gtType))
{
assert(varTypeIsIntegral(intrin.op4));
if (intrin.op4->IsCnsIntOrI())
{
MakeSrcContained(node, intrin.op4);
}
}
else
{
assert(varTypeIsIntegral(intrin.op3));
if (intrin.op3->IsCnsIntOrI())
{
MakeSrcContained(node, intrin.op3);
}
}
break;

case NI_Sve_SaturatingDecrementBy16BitElementCount:
case NI_Sve_SaturatingDecrementBy32BitElementCount:
case NI_Sve_SaturatingDecrementBy64BitElementCount:
Expand Down
37 changes: 37 additions & 0 deletions src/coreclr/jit/lsraarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1476,6 +1476,20 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou
needBranchTargetReg = !intrin.op1->isContainedIntOrIImmed();
break;

case NI_Sve_GatherPrefetch8Bit:
case NI_Sve_GatherPrefetch16Bit:
case NI_Sve_GatherPrefetch32Bit:
case NI_Sve_GatherPrefetch64Bit:
if (!varTypeIsSIMD(intrin.op2->gtType))
{
needBranchTargetReg = !intrin.op4->isContainedIntOrIImmed();
}
else
{
needBranchTargetReg = !intrin.op3->isContainedIntOrIImmed();
}
break;

case NI_Sve_SaturatingDecrementBy16BitElementCount:
case NI_Sve_SaturatingDecrementBy32BitElementCount:
case NI_Sve_SaturatingDecrementBy64BitElementCount:
Expand Down Expand Up @@ -1993,6 +2007,29 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou
srcCount += BuildAddrUses(intrin.op2);
break;

case NI_Sve_GatherPrefetch8Bit:
case NI_Sve_GatherPrefetch16Bit:
case NI_Sve_GatherPrefetch32Bit:
case NI_Sve_GatherPrefetch64Bit:
case NI_Sve_GatherVector:
case NI_Sve_GatherVectorByteZeroExtend:
case NI_Sve_GatherVectorInt16SignExtend:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems we missed adding this previously when we added other GatherVector APIs.

case NI_Sve_GatherVectorInt16WithByteOffsetsSignExtend:
case NI_Sve_GatherVectorInt32SignExtend:
case NI_Sve_GatherVectorInt32WithByteOffsetsSignExtend:
case NI_Sve_GatherVectorSByteSignExtend:
case NI_Sve_GatherVectorUInt16WithByteOffsetsZeroExtend:
case NI_Sve_GatherVectorUInt16ZeroExtend:
case NI_Sve_GatherVectorUInt32WithByteOffsetsZeroExtend:
case NI_Sve_GatherVectorUInt32ZeroExtend:
assert(intrinsicTree->OperIsMemoryLoadOrStore());
if (!varTypeIsSIMD(intrin.op2->gtType))
{
srcCount += BuildAddrUses(intrin.op2);
break;
}
FALLTHROUGH;

default:
{
SingleTypeRegSet candidates = lowVectorOperandNum == 2 ? lowVectorCandidates : RBM_NONE;
Expand Down
Loading