EqualityComparerBuilder<T>
Generates hash-code and equality-check functions for a particular type.
using DotNext.Collections.Generic;
using DotNext.Reflection;
using DotNext.Runtime;
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq.Expressions;
using System.Reflection;
using System.Runtime.CompilerServices;
#nullable enable
namespace DotNext
{
    /// <summary>
    /// Generates hash code and equality check functions for a particular type.
    /// </summary>
    /// <remarks>
    /// The generated functions inspect every declared instance field (public and non-public)
    /// of <typeparamref name="T"/> and of the types returned by <c>GetBaseTypes</c>, and compile
    /// them into delegates through expression trees. Primitive types fall back to
    /// <see cref="EqualityComparer{T}.Default"/>.
    /// </remarks>
    /// <typeparam name="T">The type whose equality semantics are generated.</typeparam>
    public readonly struct EqualityComparerBuilder<T>
    {
        // Adapts a pair of generated delegates to the IEqualityComparer<T> interface.
        private sealed class ConstructedEqualityComparer : IEqualityComparer<T>
        {
            private readonly Func<T?, T?, bool> equality;
            private readonly Func<T, int> hashCode;

            internal ConstructedEqualityComparer(Func<T?, T?, bool> equality, Func<T, int> hashCode)
            {
                this.equality = equality;
                this.hashCode = hashCode;
            }

            bool IEqualityComparer<T>.Equals(T? x, T? y) => equality(x, y);

            int IEqualityComparer<T>.GetHashCode([DisallowNull] T obj) => hashCode(obj);
        }

        private const BindingFlags PublicStaticFlags = BindingFlags.DeclaredOnly | BindingFlags.Static | BindingFlags.Public;

        private const BindingFlags DeclaredInstanceFieldFlags = BindingFlags.DeclaredOnly | BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic;

        // Multiplier applied to the running hash before each field's hash is added.
        private const int HashMultiplier = -1521134295;

        private readonly IReadOnlySet<string>? excludedFields;

        /// <summary>
        /// Sets the names of the fields that must not participate in equality and hash code computation.
        /// </summary>
        public string[] ExcludedFields
        {
            init => excludedFields = new HashSet<string>(value);
        }

        /// <summary>
        /// Gets or initializes the value forwarded as the <c>salted</c> argument to the
        /// bitwise and sequence hash code helpers used for value-type and array fields.
        /// </summary>
        public bool SaltedHashCode { get; init; }

        // A field participates unless its name was listed in ExcludedFields.
        private bool IsIncluded(FieldInfo field)
            => excludedFields is null || !excludedFields.Contains(field.Name);

        // Builds BitwiseComparer<TFirst>.Equals<TSecond>(first, second) for value-type fields.
        [RequiresUnreferencedCode("Dynamic code generation may be incompatible with IL trimming")]
        private static MethodCallExpression EqualsMethodForValueType(MemberExpression first, MemberExpression second)
        {
            MethodInfo? equals = typeof(BitwiseComparer<>)
                .MakeGenericType(first.Type)
                .GetMethod("Equals", PublicStaticFlags)
                ?.MakeGenericMethod(second.Type);
            return Expression.Call(equals!, first, second);
        }

        // Builds BitwiseComparer<TField>.GetHashCode(ref field, salted) for value-type fields.
        [RequiresUnreferencedCode("Dynamic code generation may be incompatible with IL trimming")]
        private static MethodCallExpression HashCodeMethodForValueType(Expression expr, ConstantExpression salted)
        {
            Type[] signature = { expr.Type.MakeByRefType(), typeof(bool) };
            MethodInfo? getHashCode = typeof(BitwiseComparer<>)
                .MakeGenericType(expr.Type)
                .GetMethod("GetHashCode", 0, signature);
            return Expression.Call(getHashCode!, expr, salted);
        }

        // Selects the array equality routine: element-wise sequence comparison for reference
        // elements, bitwise comparison for value-type elements.
        [RequiresUnreferencedCode("Dynamic code generation may be incompatible with IL trimming")]
        private static MethodInfo EqualsMethodForArrayElementType(Type itemType)
        {
            if (!itemType.IsValueType)
                return new Func<IEnumerable<object>, IEnumerable<object>, bool>(Sequence.SequenceEqual).Method;

            Type arrayOfT = Type.MakeGenericMethodParameter(0).MakeArrayType();
            return typeof(OneDimensionalArray)
                .GetMethod("BitwiseEquals", 1, PublicStaticFlags, null, new[] { arrayOfT, arrayOfT }, null)!
                .MakeGenericMethod(itemType);
        }

        [RequiresUnreferencedCode("Dynamic code generation may be incompatible with IL trimming")]
        private static MethodCallExpression EqualsMethodForArrayElementType(MemberExpression fieldX, MemberExpression fieldY)
            => Expression.Call(EqualsMethodForArrayElementType(fieldX.Type.GetElementType()!), fieldX, fieldY);

        // Selects the array hashing routine, mirroring EqualsMethodForArrayElementType.
        [RequiresUnreferencedCode("Dynamic code generation may be incompatible with IL trimming")]
        private static MethodInfo HashCodeMethodForArrayElementType(Type itemType)
        {
            if (!itemType.IsValueType)
                return typeof(Sequence).GetMethod("SequenceHashCode", new[] { typeof(IEnumerable<object>), typeof(bool) })!;

            Type arrayOfT = Type.MakeGenericMethodParameter(0).MakeArrayType();
            return typeof(OneDimensionalArray)
                .GetMethod("BitwiseHashCode", 1, PublicStaticFlags, null, new[] { arrayOfT, typeof(bool) }, null)!
                .MakeGenericMethod(itemType);
        }

        [RequiresUnreferencedCode("Dynamic code generation may be incompatible with IL trimming")]
        private static MethodCallExpression HashCodeMethodForArrayElementType(Expression expr, ConstantExpression salted)
            => Expression.Call(HashCodeMethodForArrayElementType(expr.Type.GetElementType()!), expr, salted);

        // Enumerates the declared instance fields of every type yielded by GetBaseTypes(true, false).
        // NOTE(review): the arguments presumably mean "include the type itself, exclude interfaces"; verify.
        private static IEnumerable<FieldInfo> GetAllFields(Type type)
        {
            foreach (Type declaringType in type.GetBaseTypes(true, false))
            {
                foreach (FieldInfo field in declaringType.GetFields(DeclaredInstanceFieldFlags))
                    yield return field;
            }
        }

        // Produces the equality test for a single field pair, dispatching on the field type.
        [RequiresUnreferencedCode("Dynamic code generation may be incompatible with IL trimming")]
        private static Expression FieldEquals(MemberExpression x, MemberExpression y)
        {
            Type fieldType = x.Type;
            if (fieldType.IsPointer || fieldType.IsPrimitive || fieldType.IsEnum)
                return Expression.Equal(x, y);

            if (fieldType.IsValueType)
                return EqualsMethodForValueType(x, y);

            if (fieldType.IsSZArray)
                return EqualsMethodForArrayElementType(x, y);

            // Any other reference type: static object.Equals handles nulls and virtual Equals.
            return Expression.Call(new Func<object?, object?, bool>(object.Equals).Method, x, y);
        }

        // Produces the hash code of a single field, dispatching on the field type.
        [RequiresUnreferencedCode("Dynamic code generation may be incompatible with IL trimming")]
        private static Expression FieldHashCode(MemberExpression field, ConstantExpression salted)
        {
            Type fieldType = field.Type;
            if (fieldType.IsPointer)
                return Expression.Call(typeof(Intrinsics).GetMethod("PointerHashCode", PublicStaticFlags)!, field);

            if (fieldType.IsPrimitive)
                return Expression.Call(field, "GetHashCode", Type.EmptyTypes);

            if (fieldType.IsValueType)
                return HashCodeMethodForValueType(field, salted);

            if (fieldType.IsSZArray)
                return HashCodeMethodForArrayElementType(field, salted);

            // Any other reference type: null hashes to 0, otherwise use the virtual GetHashCode.
            return Expression.Condition(
                Expression.ReferenceEqual(field, Expression.Constant(null, fieldType)),
                Expression.Constant(0, typeof(int)),
                Expression.Call(field, "GetHashCode", Type.EmptyTypes));
        }

        // Compiles an equality function over all included fields of T.
        [RequiresUnreferencedCode("Dynamic code generation may be incompatible with IL trimming")]
        private Func<T?, T?, bool> BuildEquals()
        {
            if (!RuntimeFeature.IsDynamicCodeSupported)
                throw new PlatformNotSupportedException();

            Type type = typeof(T);
            if (type.IsPrimitive)
                return EqualityComparer<T>.Default.Equals;

            if (type.IsSZArray)
                return EqualsMethodForArrayElementType(type.GetElementType()!).CreateDelegate<Func<T?, T?, bool>>();

            ParameterExpression x = Expression.Parameter(type);
            ParameterExpression y = Expression.Parameter(type);
            bool isClass = type.IsClass;

            // For reference types, fields may only be read when BOTH operands are non-null;
            // otherwise Equals(null, nonNull) would dereference a null reference.
            Expression? body = null;
            if (isClass)
            {
                ConstantExpression nullRef = Expression.Constant(null, type);
                body = Expression.AndAlso(
                    Expression.ReferenceNotEqual(x, nullRef),
                    Expression.ReferenceNotEqual(y, nullRef));
            }

            foreach (FieldInfo field in GetAllFields(type))
            {
                if (!IsIncluded(field))
                    continue;

                Expression fieldEquals = FieldEquals(Expression.Field(x, field), Expression.Field(y, field));
                body = body is null ? fieldEquals : Expression.AndAlso(body, fieldEquals);
            }

            if (isClass)
            {
                // Identical references (including two nulls) are equal without inspecting fields.
                body = Expression.OrElse(Expression.ReferenceEqual(x, y), body!);
            }
            else
            {
                // A value type with no included fields: all instances compare equal.
                body ??= Expression.Constant(true, typeof(bool));
            }

            return Expression.Lambda<Func<T?, T?, bool>>(body, false, x, y).Compile();
        }

        // Compiles a hash function combining the hashes of all included fields of T.
        [RequiresUnreferencedCode("Dynamic code generation may be incompatible with IL trimming")]
        private Func<T, int> BuildGetHashCode()
        {
            if (!RuntimeFeature.IsDynamicCodeSupported)
                throw new PlatformNotSupportedException();

            Type type = typeof(T);
            if (type.IsPrimitive)
                return EqualityComparer<T>.Default.GetHashCode;

            ParameterExpression obj = Expression.Parameter(type);
            ConstantExpression salted = Expression.Constant(SaltedHashCode);

            if (type.IsSZArray)
                return Expression.Lambda<Func<T, int>>(HashCodeMethodForArrayElementType(obj, salted), true, obj).Compile();

            // Block-local accumulator; starts at default(int) == 0.
            ParameterExpression hash = Expression.Variable(typeof(int));
            ConstantExpression multiplier = Expression.Constant(HashMultiplier);
            var statements = new List<Expression>();
            foreach (FieldInfo field in GetAllFields(type))
            {
                if (!IsIncluded(field))
                    continue;

                // hash = hash * HashMultiplier + hash(field); int arithmetic wraps on overflow.
                Expression fieldHash = FieldHashCode(Expression.Field(obj, field), salted);
                statements.Add(Expression.Assign(hash, Expression.Add(Expression.Multiply(hash, multiplier), fieldHash)));
            }

            statements.Add(hash);
            Expression body = Expression.Block(typeof(int), new[] { hash }, statements);

            if (type.IsClass)
            {
                // Null hashes to 0, consistent with how null reference fields are hashed above.
                body = Expression.Condition(
                    Expression.ReferenceEqual(obj, Expression.Constant(null, type)),
                    Expression.Constant(0, typeof(int)),
                    body);
            }

            return Expression.Lambda<Func<T, int>>(body, false, obj).Compile();
        }

        /// <summary>
        /// Generates the equality and hash code functions for <typeparamref name="T"/>.
        /// </summary>
        /// <param name="equals">The generated equality function.</param>
        /// <param name="hashCode">The generated hash code function.</param>
        /// <exception cref="PlatformNotSupportedException">Dynamic code generation is not supported.</exception>
        [RequiresUnreferencedCode("Dynamic code generation may be incompatible with IL trimming")]
        public void Build(out Func<T?, T?, bool> equals, out Func<T, int> hashCode)
        {
            equals = BuildEquals();
            hashCode = BuildGetHashCode();
        }

        /// <summary>
        /// Generates an equality comparer for <typeparamref name="T"/>.
        /// </summary>
        /// <returns>
        /// <see cref="EqualityComparer{T}.Default"/> for primitive types; otherwise a comparer
        /// backed by the generated functions.
        /// </returns>
        /// <exception cref="PlatformNotSupportedException">Dynamic code generation is not supported.</exception>
        [RequiresUnreferencedCode("Dynamic code generation may be incompatible with IL trimming")]
        public IEqualityComparer<T> Build()
        {
            if (typeof(T).IsPrimitive)
                return EqualityComparer<T>.Default;

            return new ConstructedEqualityComparer(BuildEquals(), BuildGetHashCode());
        }
    }
}