Skip to content

Commit

Permalink
JIT: Added SVE APIs CreateMaskForFirstActiveElement and `CreateMask…
Browse files Browse the repository at this point in the history
…ForNextActiveElement` (#104002)

* Initial work

* Added tests. Fixed parameter names.

* Use delay free for op1 if the target preference is op2. Use sve_mov instead of mov.

* Feedback

* Feedback

* Update Helpers.cs

* Handle RMW for non-explicit masked operation

* Remove handling as its already handled it looks like

* Feedback

* Feedback

* Feedback
  • Loading branch information
TIHan committed Jun 29, 2024
1 parent 0746dd3 commit 3a294ed
Show file tree
Hide file tree
Showing 10 changed files with 454 additions and 4 deletions.
2 changes: 2 additions & 0 deletions src/coreclr/jit/hwintrinsic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1922,6 +1922,8 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
break;
}

case NI_Sve_CreateMaskForFirstActiveElement:
case NI_Sve_CreateMaskForNextActiveElement:
case NI_Sve_GetActiveElementCount:
case NI_Sve_TestAnyTrue:
case NI_Sve_TestFirstTrue:
Expand Down
33 changes: 32 additions & 1 deletion src/coreclr/jit/hwintrinsiccodegenarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,23 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
assert(!node->IsEmbMaskOp());
if (HWIntrinsicInfo::IsExplicitMaskedOperation(intrin.id))
{
GetEmitter()->emitIns_R_R_R(ins, emitSize, targetReg, op1Reg, op2Reg, opt);
if (isRMW)
{
if (targetReg != op2Reg)
{
assert(targetReg != op1Reg);

GetEmitter()->emitIns_Mov(ins_Move_Extend(intrin.op2->TypeGet(), false),
emitTypeSize(node), targetReg, op2Reg,
/* canSkip */ true);
}

GetEmitter()->emitIns_R_R(ins, emitSize, targetReg, op1Reg, opt);
}
else
{
GetEmitter()->emitIns_R_R_R(ins, emitSize, targetReg, op1Reg, op2Reg, opt);
}
}
else
{
Expand Down Expand Up @@ -2211,6 +2227,21 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
break;
}

case NI_Sve_CreateMaskForFirstActiveElement:
{
assert(isRMW);
assert(HWIntrinsicInfo::IsExplicitMaskedOperation(intrin.id));

if (targetReg != op2Reg)
{
assert(targetReg != op1Reg);
GetEmitter()->emitIns_Mov(INS_sve_mov, emitTypeSize(node), targetReg, op2Reg, /* canSkip */ true);
}

GetEmitter()->emitIns_R_R(ins, emitSize, targetReg, op1Reg, INS_OPTS_SCALABLE_B);
break;
}

default:
unreached();
}
Expand Down
2 changes: 2 additions & 0 deletions src/coreclr/jit/hwintrinsiclistarm64sve.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ HARDWARE_INTRINSIC(Sve, CreateFalseMaskSingle,
HARDWARE_INTRINSIC(Sve, CreateFalseMaskUInt16, -1, 0, false, {INS_invalid, INS_invalid, INS_invalid, INS_sve_pfalse, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateFalseMaskUInt32, -1, 0, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_pfalse, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateFalseMaskUInt64, -1, 0, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_pfalse, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateMaskForFirstActiveElement, -1, 2, true, {INS_sve_pfirst, INS_sve_pfirst, INS_sve_pfirst, INS_sve_pfirst, INS_sve_pfirst, INS_sve_pfirst, INS_sve_pfirst, INS_sve_pfirst, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_ExplicitMaskedOperation|HW_Flag_ReturnsPerElementMask|HW_Flag_SpecialCodeGen|HW_Flag_HasRMWSemantics)
HARDWARE_INTRINSIC(Sve, CreateMaskForNextActiveElement, -1, 2, true, {INS_invalid, INS_sve_pnext, INS_invalid, INS_sve_pnext, INS_invalid, INS_sve_pnext, INS_invalid, INS_sve_pnext, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_ExplicitMaskedOperation|HW_Flag_ReturnsPerElementMask|HW_Flag_HasRMWSemantics)
HARDWARE_INTRINSIC(Sve, CreateTrueMaskByte, -1, 1, false, {INS_invalid, INS_sve_ptrue, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_HasEnumOperand|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateTrueMaskDouble, -1, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_ptrue}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_HasEnumOperand|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateTrueMaskInt16, -1, 1, false, {INS_invalid, INS_invalid, INS_sve_ptrue, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_HasEnumOperand|HW_Flag_ReturnsPerElementMask)
Expand Down
2 changes: 0 additions & 2 deletions src/coreclr/jit/instr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1721,7 +1721,6 @@ instruction CodeGen::ins_Move_Extend(var_types srcType, bool srcInReg)
#if defined(TARGET_XARCH)
return INS_kmovq_msk;
#elif defined(TARGET_ARM64)
unreached(); // TODO-SVE: This needs testing
return INS_sve_mov;
#endif
}
Expand Down Expand Up @@ -2085,7 +2084,6 @@ instruction CodeGen::ins_Copy(regNumber srcReg, var_types dstType)
#if defined(TARGET_XARCH)
return INS_kmovq_gpr;
#elif defined(TARGET_ARM64)
unreached(); // TODO-SVE: This needs testing
return INS_sve_mov;
#endif
}
Expand Down
9 changes: 8 additions & 1 deletion src/coreclr/jit/lsraarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1627,7 +1627,14 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCou
predMask = RBM_LOWMASK.GetPredicateRegSet();
}

srcCount += BuildOperandUses(intrin.op1, predMask);
if (tgtPrefOp2)
{
srcCount += BuildDelayFreeUses(intrin.op1, intrin.op2, predMask);
}
else
{
srcCount += BuildOperandUses(intrin.op1, predMask);
}
}
}
else if (intrinsicTree->OperIsMemoryLoadOrStore())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1016,6 +1016,79 @@ internal Arm64() { }
public static unsafe Vector<ulong> CreateFalseMaskUInt64() { throw new PlatformNotSupportedException(); }


/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<byte> CreateMaskForFirstActiveElement(Vector<byte> mask, Vector<byte> srcMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<short> CreateMaskForFirstActiveElement(Vector<short> mask, Vector<short> srcMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<int> CreateMaskForFirstActiveElement(Vector<int> mask, Vector<int> srcMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<long> CreateMaskForFirstActiveElement(Vector<long> mask, Vector<long> srcMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<sbyte> CreateMaskForFirstActiveElement(Vector<sbyte> mask, Vector<sbyte> srcMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<ushort> CreateMaskForFirstActiveElement(Vector<ushort> mask, Vector<ushort> srcMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<uint> CreateMaskForFirstActiveElement(Vector<uint> mask, Vector<uint> srcMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<ulong> CreateMaskForFirstActiveElement(Vector<ulong> mask, Vector<ulong> srcMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svpnext_b8(svbool_t pg, svbool_t op)
/// PNEXT Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<byte> CreateMaskForNextActiveElement(Vector<byte> mask, Vector<byte> srcMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svpnext_b16(svbool_t pg, svbool_t op)
/// PNEXT Ptied.H, Pg, Ptied.H
/// </summary>
public static unsafe Vector<ushort> CreateMaskForNextActiveElement(Vector<ushort> mask, Vector<ushort> srcMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svpnext_b32(svbool_t pg, svbool_t op)
/// PNEXT Ptied.S, Pg, Ptied.S
/// </summary>
public static unsafe Vector<uint> CreateMaskForNextActiveElement(Vector<uint> mask, Vector<uint> srcMask) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svbool_t svpnext_b64(svbool_t pg, svbool_t op)
/// PNEXT Ptied.D, Pg, Ptied.D
/// </summary>
public static unsafe Vector<ulong> CreateMaskForNextActiveElement(Vector<ulong> mask, Vector<ulong> srcMask) { throw new PlatformNotSupportedException(); }


/// CreateTrueMaskByte : Set predicate elements to true

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1073,6 +1073,79 @@ internal Arm64() { }
public static unsafe Vector<ulong> CreateFalseMaskUInt64() => CreateFalseMaskUInt64();


/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<byte> CreateMaskForFirstActiveElement(Vector<byte> mask, Vector<byte> srcMask) => CreateMaskForFirstActiveElement(mask, srcMask);

/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<short> CreateMaskForFirstActiveElement(Vector<short> mask, Vector<short> srcMask) => CreateMaskForFirstActiveElement(mask, srcMask);

/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<int> CreateMaskForFirstActiveElement(Vector<int> mask, Vector<int> srcMask) => CreateMaskForFirstActiveElement(mask, srcMask);

/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<long> CreateMaskForFirstActiveElement(Vector<long> mask, Vector<long> srcMask) => CreateMaskForFirstActiveElement(mask, srcMask);

/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<sbyte> CreateMaskForFirstActiveElement(Vector<sbyte> mask, Vector<sbyte> srcMask) => CreateMaskForFirstActiveElement(mask, srcMask);

/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<ushort> CreateMaskForFirstActiveElement(Vector<ushort> mask, Vector<ushort> srcMask) => CreateMaskForFirstActiveElement(mask, srcMask);

/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<uint> CreateMaskForFirstActiveElement(Vector<uint> mask, Vector<uint> srcMask) => CreateMaskForFirstActiveElement(mask, srcMask);

/// <summary>
/// svbool_t svpfirst[_b](svbool_t pg, svbool_t op)
/// PFIRST Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<ulong> CreateMaskForFirstActiveElement(Vector<ulong> mask, Vector<ulong> srcMask) => CreateMaskForFirstActiveElement(mask, srcMask);

/// <summary>
/// svbool_t svpnext_b8(svbool_t pg, svbool_t op)
/// PNEXT Ptied.B, Pg, Ptied.B
/// </summary>
public static unsafe Vector<byte> CreateMaskForNextActiveElement(Vector<byte> mask, Vector<byte> srcMask) => CreateMaskForNextActiveElement(mask, srcMask);

/// <summary>
/// svbool_t svpnext_b16(svbool_t pg, svbool_t op)
/// PNEXT Ptied.H, Pg, Ptied.H
/// </summary>
public static unsafe Vector<ushort> CreateMaskForNextActiveElement(Vector<ushort> mask, Vector<ushort> srcMask) => CreateMaskForNextActiveElement(mask, srcMask);

/// <summary>
/// svbool_t svpnext_b32(svbool_t pg, svbool_t op)
/// PNEXT Ptied.S, Pg, Ptied.S
/// </summary>
public static unsafe Vector<uint> CreateMaskForNextActiveElement(Vector<uint> mask, Vector<uint> srcMask) => CreateMaskForNextActiveElement(mask, srcMask);

/// <summary>
/// svbool_t svpnext_b64(svbool_t pg, svbool_t op)
/// PNEXT Ptied.D, Pg, Ptied.D
/// </summary>
public static unsafe Vector<ulong> CreateMaskForNextActiveElement(Vector<ulong> mask, Vector<ulong> srcMask) => CreateMaskForNextActiveElement(mask, srcMask);


/// CreateTrueMaskByte : Set predicate elements to true

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4335,6 +4335,20 @@ internal Arm64() { }
public static System.Numerics.Vector<ushort> CreateFalseMaskUInt16() { throw null; }
public static System.Numerics.Vector<uint> CreateFalseMaskUInt32() { throw null; }
public static System.Numerics.Vector<ulong> CreateFalseMaskUInt64() { throw null; }

public static unsafe System.Numerics.Vector<byte> CreateMaskForFirstActiveElement(System.Numerics.Vector<byte> mask, System.Numerics.Vector<byte> srcMask) { throw null; }
public static unsafe System.Numerics.Vector<short> CreateMaskForFirstActiveElement(System.Numerics.Vector<short> mask, System.Numerics.Vector<short> srcMask) { throw null; }
public static unsafe System.Numerics.Vector<int> CreateMaskForFirstActiveElement(System.Numerics.Vector<int> mask, System.Numerics.Vector<int> srcMask) { throw null; }
public static unsafe System.Numerics.Vector<long> CreateMaskForFirstActiveElement(System.Numerics.Vector<long> mask, System.Numerics.Vector<long> srcMask) { throw null; }
public static unsafe System.Numerics.Vector<sbyte> CreateMaskForFirstActiveElement(System.Numerics.Vector<sbyte> mask, System.Numerics.Vector<sbyte> srcMask) { throw null; }
public static unsafe System.Numerics.Vector<ushort> CreateMaskForFirstActiveElement(System.Numerics.Vector<ushort> mask, System.Numerics.Vector<ushort> srcMask) { throw null; }
public static unsafe System.Numerics.Vector<uint> CreateMaskForFirstActiveElement(System.Numerics.Vector<uint> mask, System.Numerics.Vector<uint> srcMask) { throw null; }
public static unsafe System.Numerics.Vector<ulong> CreateMaskForFirstActiveElement(System.Numerics.Vector<ulong> mask, System.Numerics.Vector<ulong> srcMask) { throw null; }
public static unsafe System.Numerics.Vector<byte> CreateMaskForNextActiveElement(System.Numerics.Vector<byte> mask, System.Numerics.Vector<byte> srcMask) { throw null; }
public static unsafe System.Numerics.Vector<ushort> CreateMaskForNextActiveElement(System.Numerics.Vector<ushort> mask, System.Numerics.Vector<ushort> srcMask) { throw null; }
public static unsafe System.Numerics.Vector<uint> CreateMaskForNextActiveElement(System.Numerics.Vector<uint> mask, System.Numerics.Vector<uint> srcMask) { throw null; }
public static unsafe System.Numerics.Vector<ulong> CreateMaskForNextActiveElement(System.Numerics.Vector<ulong> mask, System.Numerics.Vector<ulong> srcMask) { throw null; }

public static System.Numerics.Vector<byte> CreateTrueMaskByte([ConstantExpected] SveMaskPattern pattern = SveMaskPattern.All) { throw null; }
public static System.Numerics.Vector<double> CreateTrueMaskDouble([ConstantExpected] SveMaskPattern pattern = SveMaskPattern.All) { throw null; }
public static System.Numerics.Vector<short> CreateTrueMaskInt16([ConstantExpected] SveMaskPattern pattern = SveMaskPattern.All) { throw null; }
Expand Down
Loading

0 comments on commit 3a294ed

Please sign in to comment.