Skip to content

Commit

Permalink
Added checks when setting properties that if null, throws
Browse files Browse the repository at this point in the history
  • Loading branch information
EdwardCooke committed Jul 5, 2024
1 parent 7ae209b commit 08fbe44
Show file tree
Hide file tree
Showing 17 changed files with 197 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ public override void Write(ClassSyntaxReceiver classSyntaxReceiver)
Write("public bool CanWrite { get; }");
Write("public Type Type { get; }");
Write("public Type TypeOverride { get; set; }");
Write("public bool AllowNulls { get; set; }");
Write("public int Order { get; set; }");
Write("public YamlDotNet.Core.ScalarStyle ScalarStyle { get; set; }");
Write("public T GetCustomAttribute<T>() where T : Attribute");
Expand All @@ -61,7 +62,7 @@ public override void Write(ClassSyntaxReceiver classSyntaxReceiver)
Write("{"); Indent();
Write("_accessor.Set(Name, target, value);");
UnIndent(); Write("}");
Write("public StaticPropertyDescriptor(YamlDotNet.Serialization.ITypeResolver typeResolver, YamlDotNet.Serialization.IObjectAccessor accessor, string name, bool canWrite, Type type, Attribute[] attributes)");
Write("public StaticPropertyDescriptor(YamlDotNet.Serialization.ITypeResolver typeResolver, YamlDotNet.Serialization.IObjectAccessor accessor, string name, bool canWrite, Type type, Attribute[] attributes, bool allowNulls)");
Write("{"); Indent();
Write("this._typeResolver = typeResolver;");
Write("this._accessor = accessor;");
Expand All @@ -70,6 +71,7 @@ public override void Write(ClassSyntaxReceiver classSyntaxReceiver)
Write("this.CanWrite = canWrite;");
Write("this.Type = type;");
Write("this.ScalarStyle = YamlDotNet.Core.ScalarStyle.Any;");
Write("this.AllowNulls = allowNulls;");
UnIndent(); Write("}");
UnIndent(); Write("}");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,11 @@ namespace YamlDotNet.Analyzers.StaticGenerator
{
public class StaticTypeInspectorFile : File
{
private readonly GeneratorExecutionContext context;

public StaticTypeInspectorFile(Action<string, bool> Write, Action indent, Action unindent, GeneratorExecutionContext context) : base(Write, indent, unindent, context)
{
this.context = context;
}

public override void Write(ClassSyntaxReceiver classSyntaxReceiver)
Expand Down Expand Up @@ -102,6 +105,8 @@ public override void Write(ClassSyntaxReceiver classSyntaxReceiver)

private void WritePropertyDescriptor(string name, ITypeSymbol type, bool isReadonly, ImmutableArray<AttributeData> attributes, char finalChar)
{
var allowNulls = type.NullableAnnotation.HasFlag(NullableAnnotation.Annotated) && context.Compilation.Options.NullableContextOptions.AnnotationsEnabled();

Write($"new StaticPropertyDescriptor(_typeResolver, accessor, \"{name}\", {(!isReadonly).ToString().ToLower()}, typeof({type.GetFullName().Replace("?", string.Empty)}), new Attribute[] {{");
foreach (var attribute in attributes)
{
Expand Down Expand Up @@ -145,7 +150,7 @@ private void WritePropertyDescriptor(string name, ITypeSymbol type, bool isReado
break;
}
}
Write($"}}){finalChar}");
Write($"}}, {allowNulls.ToString().ToLower()}){finalChar}");

}
}
Expand Down
1 change: 1 addition & 0 deletions YamlDotNet.Core7AoTCompileTest.Model/ExternalModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
public class ExternalModel
{
public string? Text { get; set; }
public string NotNull { get; set; } = string.Empty;
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>net70</TargetFramework>
<TargetFramework>net7.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
</PropertyGroup>
Expand Down
33 changes: 33 additions & 0 deletions YamlDotNet.Core7AoTCompileTest/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Security.Cryptography.X509Certificates;
using YamlDotNet.Core;
using YamlDotNet.Core7AoTCompileTest.Model;
using YamlDotNet.Serialization;
Expand Down Expand Up @@ -171,6 +172,31 @@
Console.WriteLine("Items[0]: <{0}>", string.Join(',', o[0].myArray));
Console.WriteLine("Items[1]: <{0}>", string.Join(',', o[1].myArray));

deserializer = new StaticDeserializerBuilder(aotContext).WithEnforceNullability().Build();
yaml = "Nullable: null";
var nullable = deserializer.Deserialize<NullableTestClass>(yaml);
Console.WriteLine("Nullable Value (should be empty): <{0}>", nullable.Nullable);
yaml = "NotNullable: test";
nullable = deserializer.Deserialize<NullableTestClass>(yaml);
Console.WriteLine("NotNullable Value (should be test): <{0}>", nullable.NotNullable);
try
{
yaml = "NotNullable: null";
nullable = deserializer.Deserialize<NullableTestClass>(yaml);
throw new Exception("NotNullable should not be allowed to be set to null.");
}
catch (YamlException exception)
{
if (exception.InnerException is NullReferenceException)
{
Console.WriteLine("Exception thrown while setting non nullable value to null, as it should.");
}
else
{
throw new Exception("NotNullable should not be allowed to be set to null.");
}
}

[YamlSerializable]
public class MyArray
{
Expand All @@ -183,6 +209,13 @@ public class Inner
public string? Text { get; set; }
}

[YamlSerializable]
public class NullableTestClass
{
public string? Nullable { get; set; }
public string NotNullable { get; set; }
}

[YamlSerializable]
public class PrimitiveTypes
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>net70</TargetFramework>
<TargetFramework>net7.0</TargetFramework>
<PublishAot>true</PublishAot>
<EmitCompilerGeneratedFiles>true</EmitCompilerGeneratedFiles>
<Nullable>enable</Nullable>
Expand All @@ -20,9 +20,7 @@
</PropertyGroup>

<ItemGroup>
<ProjectReference Include="..\YamlDotNet.Analyzers.StaticGenerator\YamlDotNet.Analyzers.StaticGenerator.csproj"
OutputItemType="Analyzer"
ReferenceOutputAssembly="false" />
<ProjectReference Include="..\YamlDotNet.Analyzers.StaticGenerator\YamlDotNet.Analyzers.StaticGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
<ProjectReference Include="..\YamlDotNet.Core7AoTCompileTest.Model\YamlDotNet.Core7AoTCompileTest.Model.csproj" />
<ProjectReference Include="..\YamlDotNet\YamlDotNet.csproj" />
</ItemGroup>
Expand Down
39 changes: 39 additions & 0 deletions YamlDotNet.Test/Serialization/DeserializerTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,38 @@ public void DeserializeWithoutDuplicateKeyChecking_YamlWithDuplicateKeys_DoesNot
act.ShouldNotThrow<YamlException>("Because duplicate key checking is not enabled");
}

[Fact]
public void EnforceNulalbleTypesWhenNullThrowsException()
{
var deserializer = new DeserializerBuilder().WithEnforceNullability().Build();
var yaml = @"
Test: null
";
try
{
var o = deserializer.Deserialize<NonNullableClass>(yaml);
}
catch (YamlException e)
{
if (e.InnerException is NullReferenceException)
{
return;
}
}

throw new Exception("Non nullable property was set to null.");
}

[Fact]
public void EnforceNullableTypesWhenNotNullDoesNotThrowException()
{
var deserializer = new DeserializerBuilder().WithEnforceNullability().Build();
var yaml = @"
Test: test 123
";
var o = deserializer.Deserialize<NonNullableClass>(yaml);
}

[Fact]
public void SerializeStateMethodsGetCalledOnce()
{
Expand All @@ -344,6 +376,13 @@ public void SerializeStateMethodsGetCalledOnce()
Assert.Equal(1, test.OnDeserializingCallCount);
}

#nullable enable
public class NonNullableClass
{
public string Test { get; set; } = "Some default value";
}
#nullable disable

public class TestState
{
public int OnDeserializedCallCount { get; set; }
Expand Down
45 changes: 43 additions & 2 deletions YamlDotNet/ReflectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
// SOFTWARE.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
Expand Down Expand Up @@ -270,16 +271,56 @@ public static Attribute[] GetAllCustomAttributes<TAttribute>(this PropertyInfo m
// on netstandard1.3
var result = new List<Attribute>();
var type = member.DeclaringType;
var name = member.Name;

while (type != null)
{
type.GetPublicProperty(member.Name);
result.AddRange(member.GetCustomAttributes(typeof(TAttribute)));
var property = type.GetPublicProperty(name);

if (property != null)
{
result.AddRange(property.GetCustomAttributes(typeof(TAttribute)));
}

type = type.BaseType();
}

return result.ToArray();
}
private static readonly ConcurrentDictionary<Type, bool> typesHaveNullContext = new ConcurrentDictionary<Type, bool>();
public static bool AcceptsNull(this MemberInfo member)
{
var result = true; //default to allowing nulls, this will be set to false if there is a null context on the type
#if NET8_0_OR_GREATER
var typeHasNullContext = typesHaveNullContext.GetOrAdd(member.DeclaringType, (Type t) =>
{
var attributes = t.GetCustomAttributes(typeof(System.Runtime.CompilerServices.NullableContextAttribute), true);
return (attributes?.Length ?? 0) > 0;
});

if (typeHasNullContext)
{
// we have a nullable context on that type, only allow null if the NullableAttribute is on the member.
var memberAttributes = member.GetCustomAttributes(typeof(System.Runtime.CompilerServices.NullableAttribute), true);
result = (memberAttributes?.Length ?? 0) > 0;
}

return result;
#else
var typeHasNullContext = typesHaveNullContext.GetOrAdd(member.DeclaringType, (Type t) =>
{
var attributes = t.GetCustomAttributes(true);
return attributes.Any(x => x.GetType().FullName == "System.Runtime.CompilerServices.NullableContextAttribute");
});

if (typeHasNullContext)
{
var memberAttributes = member.GetCustomAttributes(true);
result = memberAttributes.Any(x => x.GetType().FullName == "System.Runtime.CompilerServices.NullableAttribute");
}

return result;
#endif
}
}
}
14 changes: 13 additions & 1 deletion YamlDotNet/Serialization/DeserializerBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ public sealed class DeserializerBuilder : BuilderSkeleton<DeserializerBuilder>
private bool ignoreUnmatched;
private bool duplicateKeyChecking;
private bool attemptUnknownTypeDeserialization;
private bool enforceNullability;

/// <summary>
/// Initializes a new <see cref="DeserializerBuilder" /> using the default component registrations.
Expand Down Expand Up @@ -103,7 +104,8 @@ public DeserializerBuilder()
ignoreUnmatched,
duplicateKeyChecking,
typeConverter,
enumNamingConvention)
enumNamingConvention,
enforceNullability)
}
};

Expand Down Expand Up @@ -333,6 +335,16 @@ Action<ITrackingRegistrationLocationSelectionSyntax<INodeTypeResolver>> where
return this;
}

/// <summary>
/// Enforce whether null values can be set on non-nullable properties and fields.
/// </summary>
/// <returns>This deserializer builder.</returns>
public DeserializerBuilder WithEnforceNullability()
{
enforceNullability = true;
return this;
}

/// <summary>
/// Unregisters an existing <see cref="INodeTypeResolver" /> of type <typeparam name="TNodeTypeResolver" />.
/// </summary>
Expand Down
1 change: 1 addition & 0 deletions YamlDotNet/Serialization/IPropertyDescriptor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ namespace YamlDotNet.Serialization
public interface IPropertyDescriptor
{
string Name { get; }
bool AllowNulls { get; }
bool CanWrite { get; }
Type Type { get; }
Type? TypeOverride { get; set; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Runtime.Serialization;
using YamlDotNet.Core;
using YamlDotNet.Core.Events;
Expand All @@ -36,20 +38,23 @@ public sealed class ObjectNodeDeserializer : INodeDeserializer
private readonly bool duplicateKeyChecking;
private readonly ITypeConverter typeConverter;
private readonly INamingConvention enumNamingConvention;
private readonly bool enforceNullability;

public ObjectNodeDeserializer(IObjectFactory objectFactory,
ITypeInspector typeDescriptor,
bool ignoreUnmatched,
bool duplicateKeyChecking,
ITypeConverter typeConverter,
INamingConvention enumNamingConvention)
INamingConvention enumNamingConvention,
bool enforceNullability)
{
this.objectFactory = objectFactory ?? throw new ArgumentNullException(nameof(objectFactory));
this.typeDescriptor = typeDescriptor ?? throw new ArgumentNullException(nameof(typeDescriptor));
this.ignoreUnmatched = ignoreUnmatched;
this.duplicateKeyChecking = duplicateKeyChecking;
this.typeConverter = typeConverter ?? throw new ArgumentNullException(nameof(typeConverter));
this.enumNamingConvention = enumNamingConvention ?? throw new ArgumentNullException(nameof(enumNamingConvention));
this.enforceNullability = enforceNullability;
}

public bool Deserialize(IParser parser, Type expectedType, Func<IParser, Type, object?> nestedObjectDeserializer, out object? value)
Expand All @@ -59,7 +64,6 @@ public bool Deserialize(IParser parser, Type expectedType, Func<IParser, Type, o
value = null;
return false;
}

// Strip off the nullable type, if present. This is needed for nullable structs.
var implementationType = Nullable.GetUnderlyingType(expectedType) ?? expectedType;

Expand Down Expand Up @@ -90,12 +94,18 @@ public bool Deserialize(IParser parser, Type expectedType, Func<IParser, Type, o
propertyValuePromise.ValueAvailable += v =>
{
var convertedValue = typeConverter.ChangeType(v, property.Type, enumNamingConvention);
NullCheck(convertedValue, property, propertyName);
property.Write(valueRef, convertedValue);
};
}
else
{
var convertedValue = typeConverter.ChangeType(propertyValue, property.Type, enumNamingConvention);

NullCheck(convertedValue, property, propertyName);

property.Write(value, convertedValue);
}
}
Expand All @@ -116,5 +126,15 @@ public bool Deserialize(IParser parser, Type expectedType, Func<IParser, Type, o
objectFactory.ExecuteOnDeserialized(value);
return true;
}

public void NullCheck(object value, IPropertyDescriptor property, Scalar propertyName)
{
if (enforceNullability &&
value == null &&
!property.AllowNulls)
{
throw new YamlException(propertyName.Start, propertyName.End, "Strict nullability enforcement error.", new NullReferenceException("Yaml value is null when target property requires non null values."));
}
}
}
}
Loading

0 comments on commit 08fbe44

Please sign in to comment.