Skip to content

Commit

Permalink
Extend RuntimeHelpers.IsBitwiseEquatable to more types (#75640)
Browse files Browse the repository at this point in the history
* Extend RuntimeHelpers.IsBitwiseEquatable to more types

Today, RuntimeHelpers.IsBitwiseEquatable is hardcoded to a fixed list of types.  This means that almost all of the vectorization we've done with arrays and spans is limited to just those types; developers can themselves use MemoryMarshal.Cast to convert spans of other types to spans of supported one, but it doesn't naturally happen.

This extends IsBitwiseEquatable a bit more. We already have a CanCompareBitsOrUseFastGetHashCode helper used by ValueType.Equals to determine whether structs that don't override Equals can be compared with the equivalent of memcmp.  This extends that same helper to be used by IsBitwiseEquatable.  However, IsBitwiseEquatable also needs to rule out types that implement `IEquatable<T>` (the existing helper doesn't because it's about the implementation of the object.Equals override where the interface doesn't come into play).

The upside of this is APIs like Array.IndexOf will now automatically vectorize with more types.  The main downside is that types which provide their own equality implementation still don't benefit, which in turn means adding an `IEquality<T>` implementation could in the future be a deoptimization (we should consider some kind of attribute or marker interface a type can use to say "I promise my equality implementation is the same as a bitwise comparison").  We also currently constrain most of our MemoryExtensions methods to types that implement `IEquatable<T>`, so there are only a handful of public methods today that benefit from this.

* Fix contract on CanCompareBitsOrUseFastGetHashCode

* Add more SequenceEqual tests

* Remove duplicative check

* Add IsBitwiseEquatable extension to ilc

* Address PR feedback

* Update src/coreclr/tools/Common/TypeSystem/IL/Stubs/ComparerIntrinsics.cs
  • Loading branch information
stephentoub committed Sep 27, 2022
1 parent b2af65a commit c10520d
Show file tree
Hide file tree
Showing 7 changed files with 340 additions and 127 deletions.
113 changes: 113 additions & 0 deletions src/coreclr/tools/Common/TypeSystem/IL/Stubs/ComparerIntrinsics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -234,5 +234,118 @@ private static bool ImplementsInterfaceOfSelf(TypeDesc type, string interfaceNam

return false;
}

public static bool CanCompareValueTypeBits(MetadataType type, MethodDesc objectEqualsMethod)
{
Debug.Assert(type.IsValueType);

if (type.ContainsGCPointers)
return false;

if (type.IsGenericDefinition)
return false;

OverlappingFieldTracker overlappingFieldTracker = new OverlappingFieldTracker(type);

bool result = true;
foreach (var field in type.GetFields())
{
if (field.IsStatic)
continue;

if (!overlappingFieldTracker.TrackField(field))
{
// This field overlaps with another field - can't compare memory
result = false;
break;
}

TypeDesc fieldType = field.FieldType;
if (fieldType.IsPrimitive || fieldType.IsEnum || fieldType.IsPointer || fieldType.IsFunctionPointer)
{
TypeFlags category = fieldType.UnderlyingType.Category;
if (category == TypeFlags.Single || category == TypeFlags.Double)
{
// Double/Single have weird behaviors around negative/positive zero
result = false;
break;
}
}
else
{
// Would be a suprise if this wasn't a valuetype. We checked ContainsGCPointers above.
Debug.Assert(fieldType.IsValueType);

// If the field overrides Equals, we can't use the fast helper because we need to call the method.
if (fieldType.FindVirtualFunctionTargetMethodOnObjectType(objectEqualsMethod).OwningType == fieldType)
{
result = false;
break;
}

if (!CanCompareValueTypeBits((MetadataType)fieldType, objectEqualsMethod))
{
result = false;
break;
}
}
}

// If there are gaps, we can't memcompare
if (result && overlappingFieldTracker.HasGaps)
result = false;

return result;
}

private struct OverlappingFieldTracker
{
private bool[] _usedBytes;

public OverlappingFieldTracker(MetadataType type)
{
_usedBytes = new bool[type.InstanceFieldSize.AsInt];
}

public bool TrackField(FieldDesc field)
{
int fieldBegin = field.Offset.AsInt;

TypeDesc fieldType = field.FieldType;

int fieldEnd;
if (fieldType.IsPointer || fieldType.IsFunctionPointer)
{
fieldEnd = fieldBegin + field.Context.Target.PointerSize;
}
else
{
Debug.Assert(fieldType.IsValueType);
fieldEnd = fieldBegin + ((DefType)fieldType).InstanceFieldSize.AsInt;
}

for (int i = fieldBegin; i < fieldEnd; i++)
{
if (_usedBytes[i])
return false;
_usedBytes[i] = true;
}

return true;
}

public bool HasGaps
{
get
{
for (int i = 0; i < _usedBytes.Length; i++)
if (!_usedBytes[i])
return true;

return false;
}
}
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,24 @@ public static MethodIL EmitIL(MethodDesc method)
result = true;
break;
default:
var mdType = elementType as MetadataType;
if (mdType != null && mdType.Name == "Rune" && mdType.Namespace == "System.Text")
result = true;
else if (mdType != null && mdType.Name == "Char8" && mdType.Namespace == "System")
result = true;
else
result = false;
result = false;
if (elementType is MetadataType mdType)
{
if (mdType.Module == mdType.Context.SystemModule &&
mdType.Namespace == "System.Text" &&
mdType.Name == "Rune")
{
result = true;
}
else if (mdType.IsValueType && !ComparerIntrinsics.ImplementsIEquatable(mdType.GetTypeDefinition()))
{
// Value type that can use memcmp and that doesn't override object.Equals or implement IEquatable<T>.Equals.
MethodDesc objectEquals = mdType.Context.GetWellKnownType(WellKnownType.Object).GetMethod("Equals", null);
result =
mdType.FindVirtualFunctionTargetMethodOnObjectType(objectEquals).OwningType != mdType &&
ComparerIntrinsics.CanCompareValueTypeBits(mdType, objectEquals);
}
}
break;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ public bool CanCompareValueTypeBits
if ((flags & Flags.CanCompareValueTypeBitsComputed) == 0)
{
Debug.Assert(Type.IsValueType);
if (ComputeCanCompareValueTypeBits((MetadataType)Type))
MetadataType mdType = (MetadataType)Type;
if (ComparerIntrinsics.CanCompareValueTypeBits(mdType, ((CompilerTypeSystemContext)mdType.Context)._objectEqualsMethod))
flags |= Flags.CanCompareValueTypeBits;
flags |= Flags.CanCompareValueTypeBitsComputed;

Expand All @@ -112,71 +113,6 @@ public TypeState(TypeDesc type, TypeStateHashtable hashtable)
Type = type;
_hashtable = hashtable;
}

private bool ComputeCanCompareValueTypeBits(MetadataType type)
{
Debug.Assert(type.IsValueType);

if (type.ContainsGCPointers)
return false;

if (type.IsGenericDefinition)
return false;

OverlappingFieldTracker overlappingFieldTracker = new OverlappingFieldTracker(type);

bool result = true;
foreach (var field in type.GetFields())
{
if (field.IsStatic)
continue;

if (!overlappingFieldTracker.TrackField(field))
{
// This field overlaps with another field - can't compare memory
result = false;
break;
}

TypeDesc fieldType = field.FieldType;
if (fieldType.IsPrimitive || fieldType.IsEnum || fieldType.IsPointer || fieldType.IsFunctionPointer)
{
TypeFlags category = fieldType.UnderlyingType.Category;
if (category == TypeFlags.Single || category == TypeFlags.Double)
{
// Double/Single have weird behaviors around negative/positive zero
result = false;
break;
}
}
else
{
// Would be a suprise if this wasn't a valuetype. We checked ContainsGCPointers above.
Debug.Assert(fieldType.IsValueType);

MethodDesc objectEqualsMethod = ((CompilerTypeSystemContext)fieldType.Context)._objectEqualsMethod;

// If the field overrides Equals, we can't use the fast helper because we need to call the method.
if (fieldType.FindVirtualFunctionTargetMethodOnObjectType(objectEqualsMethod).OwningType == fieldType)
{
result = false;
break;
}

if (!_hashtable.GetOrCreateValue((MetadataType)fieldType).CanCompareValueTypeBits)
{
result = false;
break;
}
}
}

// If there are gaps, we can't memcompare
if (result && overlappingFieldTracker.HasGaps)
result = false;

return result;
}
}

private sealed class TypeStateHashtable : LockFreeReaderHashtable<TypeDesc, TypeState>
Expand All @@ -192,54 +128,5 @@ protected override TypeState CreateValueFromKey(TypeDesc key)
}
}
private TypeStateHashtable _typeStateHashtable = new TypeStateHashtable();

private struct OverlappingFieldTracker
{
private bool[] _usedBytes;

public OverlappingFieldTracker(MetadataType type)
{
_usedBytes = new bool[type.InstanceFieldSize.AsInt];
}

public bool TrackField(FieldDesc field)
{
int fieldBegin = field.Offset.AsInt;

TypeDesc fieldType = field.FieldType;

int fieldEnd;
if (fieldType.IsPointer || fieldType.IsFunctionPointer)
{
fieldEnd = fieldBegin + field.Context.Target.PointerSize;
}
else
{
Debug.Assert(fieldType.IsValueType);
fieldEnd = fieldBegin + ((DefType)fieldType).InstanceFieldSize.AsInt;
}

for (int i = fieldBegin; i < fieldEnd; i++)
{
if (_usedBytes[i])
return false;
_usedBytes[i] = true;
}

return true;
}

public bool HasGaps
{
get
{
for (int i = 0; i < _usedBytes.Length; i++)
if (!_usedBytes[i])
return true;

return false;
}
}
}
}
}
4 changes: 2 additions & 2 deletions src/coreclr/vm/comutilnative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1639,13 +1639,13 @@ static BOOL HasOverriddenMethod(MethodTable* mt, MethodTable* classMT, WORD meth
return TRUE;
}

static BOOL CanCompareBitsOrUseFastGetHashCode(MethodTable* mt)
BOOL CanCompareBitsOrUseFastGetHashCode(MethodTable* mt)
{
CONTRACTL
{
THROWS;
GC_TRIGGERS;
MODE_COOPERATIVE;
MODE_ANY;
} CONTRACTL_END;

_ASSERTE(mt != NULL);
Expand Down
2 changes: 2 additions & 0 deletions src/coreclr/vm/comutilnative.h
Original file line number Diff line number Diff line change
Expand Up @@ -243,4 +243,6 @@ class StreamNative {
static FCDECL1(FC_BOOL_RET, HasOverriddenBeginEndWrite, Object *stream);
};

BOOL CanCompareBitsOrUseFastGetHashCode(MethodTable* mt);

#endif // _COMUTILNATIVE_H_
27 changes: 24 additions & 3 deletions src/coreclr/vm/jitinterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7142,6 +7142,25 @@ bool getILIntrinsicImplementationForInterlocked(MethodDesc * ftn,
return true;
}

bool IsBitwiseEquatable(TypeHandle typeHandle, MethodTable * methodTable)
{
if (!methodTable->IsValueType() ||
!CanCompareBitsOrUseFastGetHashCode(methodTable))
{
return false;
}

// CanCompareBitsOrUseFastGetHashCode checks for an object.Equals override.
// We also need to check for an IEquatable<T> implementation.
Instantiation inst(&typeHandle, 1);
if (typeHandle.CanCastTo(TypeHandle(CoreLibBinder::GetClass(CLASS__IEQUATABLEGENERIC)).Instantiate(inst)))
{
return false;
}

return true;
}

bool getILIntrinsicImplementationForRuntimeHelpers(MethodDesc * ftn,
CORINFO_METHOD_INFO * methInfo)
{
Expand Down Expand Up @@ -7192,8 +7211,9 @@ bool getILIntrinsicImplementationForRuntimeHelpers(MethodDesc * ftn,
static const BYTE returnFalse[] = { CEE_LDC_I4_0, CEE_RET };

// Ideally we could detect automatically whether a type is trivially equatable
// (i.e., its operator == could be implemented via memcmp). But for now we'll
// do the simple thing and hardcode the list of types we know fulfill this contract.
// (i.e., its operator == could be implemented via memcmp). The best we can do
// for now is hardcode a list of known supported types and then also include anything
// that doesn't provide its own object.Equals override / IEquatable<T> implementation.
// n.b. This doesn't imply that the type's CompareTo method can be memcmp-implemented,
// as a method like CompareTo may need to take a type's signedness into account.

Expand All @@ -7210,7 +7230,8 @@ bool getILIntrinsicImplementationForRuntimeHelpers(MethodDesc * ftn,
|| methodTable == CoreLibBinder::GetClass(CLASS__INTPTR)
|| methodTable == CoreLibBinder::GetClass(CLASS__UINTPTR)
|| methodTable == CoreLibBinder::GetClass(CLASS__RUNE)
|| methodTable->IsEnum())
|| methodTable->IsEnum()
|| IsBitwiseEquatable(typeHandle, methodTable))
{
methInfo->ILCode = const_cast<BYTE*>(returnTrue);
}
Expand Down
Loading

0 comments on commit c10520d

Please sign in to comment.