Skip to content

Commit

Permalink
Arm64/SVE: Implemented ConvertToint64 and ConvertToUInt64 (#104069)
Browse files Browse the repository at this point in the history
* Added ConverToInt32 and ConvertToUInt32 for float inputs.

* Added flags to handle only low predicate registers.

* Fix whitespace

* Remove special codegen flag

* Added new test template for operations with different return types.

* Add new test template.

* Added api for ConvertToInt32 and ConvertToUInt 32 for double.

* Completed SVE Apis for ConvertToInt64 and ConvertToUInt64.

* ConvertToSingle for int and uint.

* ConvertToSingle for long and ulong.

* Started ConvertToDouble.

* Changed Validation Template Test name.

* ConvertToInt64.

* ConvertToInt64 passes optimized tests.

* Added cases for ConvertToSingle and ConvertToDouble.

* double or long to 32 bit value.

* Removed ConvertToDouble and ConvertToSingle.

* Removed more of ConvertToSingle and ConvertToDouble.

* all tests pass.

* addressed comments.

* jit format:

* Remove trailing space

---------

Co-authored-by: Kunal Pathak <Kunal.Pathak@microsoft.com>
  • Loading branch information
ebepho and kunalspathak committed Jun 29, 2024
1 parent 3a294ed commit 6f1d8c5
Show file tree
Hide file tree
Showing 10 changed files with 127 additions and 16 deletions.
2 changes: 2 additions & 0 deletions src/coreclr/jit/hwintrinsic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1731,6 +1731,8 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
{
case NI_Sve_ConvertToInt32:
case NI_Sve_ConvertToUInt32:
case NI_Sve_ConvertToInt64:
case NI_Sve_ConvertToUInt64:
// Save the base type of return SIMD. It is used to contain this intrinsic inside
// ConditionalSelect.
retNode->AsHWIntrinsic()->SetAuxiliaryJitType(getBaseJitTypeOfSIMDType(sig->retTypeSigClass));
Expand Down
17 changes: 14 additions & 3 deletions src/coreclr/jit/hwintrinsiccodegenarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -511,12 +511,22 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)

// Special handling for ConvertTo* APIs
// Just need to change the opt here.
insOpts embOpt = opt;
switch (intrinEmbMask.id)
{
case NI_Sve_ConvertToInt32:
case NI_Sve_ConvertToUInt32:
{
opt = intrinEmbMask.baseType == TYP_DOUBLE ? INS_OPTS_D_TO_S : INS_OPTS_SCALABLE_S;
embOpt = emitTypeSize(intrinEmbMask.baseType) == EA_8BYTE ? INS_OPTS_D_TO_S
: INS_OPTS_SCALABLE_S;
break;
}

case NI_Sve_ConvertToInt64:
case NI_Sve_ConvertToUInt64:
{
embOpt = emitTypeSize(intrinEmbMask.baseType) == EA_4BYTE ? INS_OPTS_S_TO_D
: INS_OPTS_SCALABLE_D;
break;
}
default:
Expand Down Expand Up @@ -555,7 +565,8 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)

// We cannot use use `movprfx` here to move falseReg to targetReg because that will
// overwrite the value of embMaskOp1Reg which is present in targetReg.
GetEmitter()->emitIns_R_R_R(insEmbMask, emitSize, targetReg, maskReg, embMaskOp1Reg, opt);
GetEmitter()->emitIns_R_R_R(insEmbMask, emitSize, targetReg, maskReg, embMaskOp1Reg,
embOpt);

GetEmitter()->emitIns_R_R_R_R(INS_sve_sel, emitSize, targetReg, maskReg, targetReg,
falseReg, opt);
Expand All @@ -569,7 +580,7 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
}
}

GetEmitter()->emitIns_R_R_R(insEmbMask, emitSize, targetReg, maskReg, embMaskOp1Reg, opt);
GetEmitter()->emitIns_R_R_R(insEmbMask, emitSize, targetReg, maskReg, embMaskOp1Reg, embOpt);
break;
}

Expand Down
2 changes: 2 additions & 0 deletions src/coreclr/jit/hwintrinsiclistarm64sve.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ HARDWARE_INTRINSIC(Sve, Compute64BitAddresses,
HARDWARE_INTRINSIC(Sve, Compute8BitAddresses, -1, 2, true, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_adr, INS_invalid, INS_sve_adr, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialCodeGen)
HARDWARE_INTRINSIC(Sve, ConditionalSelect, -1, 3, true, {INS_sve_sel, INS_sve_sel, INS_sve_sel, INS_sve_sel, INS_sve_sel, INS_sve_sel, INS_sve_sel, INS_sve_sel, INS_sve_sel, INS_sve_sel}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_ExplicitMaskedOperation|HW_Flag_SupportsContainment)
HARDWARE_INTRINSIC(Sve, ConvertToInt32, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fcvtzs, INS_sve_fcvtzs}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation)
HARDWARE_INTRINSIC(Sve, ConvertToInt64, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fcvtzs, INS_sve_fcvtzs}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation)
HARDWARE_INTRINSIC(Sve, ConvertToUInt32, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fcvtzu, INS_sve_fcvtzu}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation)
HARDWARE_INTRINSIC(Sve, ConvertToUInt64, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fcvtzu, INS_sve_fcvtzu}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation)
HARDWARE_INTRINSIC(Sve, Count16BitElements, 0, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_cnth, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Scalar, HW_Flag_Scalable|HW_Flag_HasEnumOperand|HW_Flag_SpecialCodeGen|HW_Flag_NoFloatingPointUsed)
HARDWARE_INTRINSIC(Sve, Count32BitElements, 0, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_cntw, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Scalar, HW_Flag_Scalable|HW_Flag_HasEnumOperand|HW_Flag_SpecialCodeGen|HW_Flag_NoFloatingPointUsed)
HARDWARE_INTRINSIC(Sve, Count64BitElements, 0, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_cntd, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Scalar, HW_Flag_Scalable|HW_Flag_HasEnumOperand|HW_Flag_SpecialCodeGen|HW_Flag_NoFloatingPointUsed)
Expand Down
4 changes: 3 additions & 1 deletion src/coreclr/jit/lowerarmarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3390,7 +3390,9 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node)
// For now, make sure that we get here only for intrinsics that we are
// sure about to rely on auxiliary type's size.
assert((embOp->GetHWIntrinsicId() == NI_Sve_ConvertToInt32) ||
(embOp->GetHWIntrinsicId() == NI_Sve_ConvertToUInt32));
(embOp->GetHWIntrinsicId() == NI_Sve_ConvertToUInt32) ||
(embOp->GetHWIntrinsicId() == NI_Sve_ConvertToInt64) ||
(embOp->GetHWIntrinsicId() == NI_Sve_ConvertToUInt64));

uint32_t auxSize = genTypeSize(embOp->GetAuxiliaryType());
if (maskSize == auxSize)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,26 @@ internal Arm64() { }
public static unsafe Vector<int> ConvertToInt32(Vector<float> value) { throw new PlatformNotSupportedException(); }


/// ConvertToInt64 : Floating-point convert

/// <summary>
/// svint64_t svcvt_s64[_f64]_m(svint64_t inactive, svbool_t pg, svfloat64_t op)
/// FCVTZS Ztied.D, Pg/M, Zop.D
/// svint64_t svcvt_s64[_f64]_x(svbool_t pg, svfloat64_t op)
/// FCVTZS Ztied.D, Pg/M, Ztied.D
/// svint64_t svcvt_s64[_f64]_z(svbool_t pg, svfloat64_t op)
/// </summary>
public static unsafe Vector<long> ConvertToInt64(Vector<double> value) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svint64_t svcvt_s64[_f32]_m(svint64_t inactive, svbool_t pg, svfloat32_t op)
/// FCVTZS Ztied.D, Pg/M, Zop.S
/// svint64_t svcvt_s64[_f32]_x(svbool_t pg, svfloat32_t op)
/// FCVTZS Ztied.D, Pg/M, Ztied.S
/// svint64_t svcvt_s64[_f32]_z(svbool_t pg, svfloat32_t op)
/// </summary>
public static unsafe Vector<long> ConvertToInt64(Vector<float> value) { throw new PlatformNotSupportedException(); }

/// ConvertToUInt32 : Floating-point convert

/// <summary>
Expand All @@ -890,6 +910,27 @@ internal Arm64() { }
public static unsafe Vector<uint> ConvertToUInt32(Vector<float> value) { throw new PlatformNotSupportedException(); }


/// ConvertToUInt64 : Floating-point convert

/// <summary>
/// svuint64_t svcvt_u64[_f64]_m(svuint64_t inactive, svbool_t pg, svfloat64_t op)
/// FCVTZU Ztied.D, Pg/M, Zop.D
/// svuint64_t svcvt_u64[_f64]_x(svbool_t pg, svfloat64_t op)
/// FCVTZU Ztied.D, Pg/M, Ztied.D
/// svuint64_t svcvt_u64[_f64]_z(svbool_t pg, svfloat64_t op)
/// </summary>
public static unsafe Vector<ulong> ConvertToUInt64(Vector<double> value) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svuint64_t svcvt_u64[_f32]_m(svuint64_t inactive, svbool_t pg, svfloat32_t op)
/// FCVTZU Ztied.D, Pg/M, Zop.S
/// svuint64_t svcvt_u64[_f32]_x(svbool_t pg, svfloat32_t op)
/// FCVTZU Ztied.D, Pg/M, Ztied.S
/// svuint64_t svcvt_u64[_f32]_z(svbool_t pg, svfloat32_t op)
/// </summary>
public static unsafe Vector<ulong> ConvertToUInt64(Vector<float> value) { throw new PlatformNotSupportedException(); }


/// Count16BitElements : Count the number of 16-bit elements in a vector

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -926,6 +926,27 @@ internal Arm64() { }
public static unsafe Vector<int> ConvertToInt32(Vector<float> value) => ConvertToInt32(value);


/// ConvertToInt64 : Floating-point convert

/// <summary>
/// svint64_t svcvt_s64[_f64]_m(svint64_t inactive, svbool_t pg, svfloat64_t op)
/// FCVTZS Ztied.D, Pg/M, Zop.D
/// svint64_t svcvt_s64[_f64]_x(svbool_t pg, svfloat64_t op)
/// FCVTZS Ztied.D, Pg/M, Ztied.D
/// svint64_t svcvt_s64[_f64]_z(svbool_t pg, svfloat64_t op)
/// </summary>
public static unsafe Vector<long> ConvertToInt64(Vector<double> value) => ConvertToInt64(value);

/// <summary>
/// svint64_t svcvt_s64[_f32]_m(svint64_t inactive, svbool_t pg, svfloat32_t op)
/// FCVTZS Ztied.D, Pg/M, Zop.S
/// svint64_t svcvt_s64[_f32]_x(svbool_t pg, svfloat32_t op)
/// FCVTZS Ztied.D, Pg/M, Ztied.S
/// svint64_t svcvt_s64[_f32]_z(svbool_t pg, svfloat32_t op)
/// </summary>
public static unsafe Vector<long> ConvertToInt64(Vector<float> value) => ConvertToInt64(value);


/// ConvertToUInt32 : Floating-point convert

/// <summary>
Expand All @@ -947,6 +968,27 @@ internal Arm64() { }
public static unsafe Vector<uint> ConvertToUInt32(Vector<float> value) => ConvertToUInt32(value);


/// ConvertToUInt64 : Floating-point convert

/// <summary>
/// svuint64_t svcvt_u64[_f64]_m(svuint64_t inactive, svbool_t pg, svfloat64_t op)
/// FCVTZU Ztied.D, Pg/M, Zop.D
/// svuint64_t svcvt_u64[_f64]_x(svbool_t pg, svfloat64_t op)
/// FCVTZU Ztied.D, Pg/M, Ztied.D
/// svuint64_t svcvt_u64[_f64]_z(svbool_t pg, svfloat64_t op)
/// </summary>
public static unsafe Vector<ulong> ConvertToUInt64(Vector<double> value) => ConvertToUInt64(value);

/// <summary>
/// svuint64_t svcvt_u64[_f32]_m(svuint64_t inactive, svbool_t pg, svfloat32_t op)
/// FCVTZU Ztied.D, Pg/M, Zop.S
/// svuint64_t svcvt_u64[_f32]_x(svbool_t pg, svfloat32_t op)
/// FCVTZU Ztied.D, Pg/M, Ztied.S
/// svuint64_t svcvt_u64[_f32]_z(svbool_t pg, svfloat32_t op)
/// </summary>
public static unsafe Vector<ulong> ConvertToUInt64(Vector<float> value) => ConvertToUInt64(value);


/// Count16BitElements : Count the number of 16-bit elements in a vector

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4317,8 +4317,12 @@ internal Arm64() { }

public static System.Numerics.Vector<int> ConvertToInt32(System.Numerics.Vector<double> value) { throw null; }
public static System.Numerics.Vector<int> ConvertToInt32(System.Numerics.Vector<float> value) { throw null; }
public static System.Numerics.Vector<long> ConvertToInt64(System.Numerics.Vector<double> value) { throw null; }
public static System.Numerics.Vector<long> ConvertToInt64(System.Numerics.Vector<float> value) { throw null; }
public static System.Numerics.Vector<uint> ConvertToUInt32(System.Numerics.Vector<double> value) { throw null; }
public static System.Numerics.Vector<uint> ConvertToUInt32(System.Numerics.Vector<float> value) { throw null; }
public static System.Numerics.Vector<ulong> ConvertToUInt64(System.Numerics.Vector<double> value) { throw null; }
public static System.Numerics.Vector<ulong> ConvertToUInt64(System.Numerics.Vector<float> value) { throw null; }

public static ulong Count16BitElements([ConstantExpected] SveMaskPattern pattern = SveMaskPattern.All) { throw null; }
public static ulong Count32BitElements([ConstantExpected] SveMaskPattern pattern = SveMaskPattern.All) { throw null; }
Expand Down
Loading

0 comments on commit 6f1d8c5

Please sign in to comment.