diff --git a/src/Venflow/Venflow.Tests/SpecificTypes/CLREnumTests.cs b/src/Venflow/Venflow.Tests/SpecificTypes/CLREnumTests.cs index 460e409d..a0e2447b 100644 --- a/src/Venflow/Venflow.Tests/SpecificTypes/CLREnumTests.cs +++ b/src/Venflow/Venflow.Tests/SpecificTypes/CLREnumTests.cs @@ -6,6 +6,40 @@ namespace Venflow.Tests.SpecificTypes { public class CLREnumTests : TestBase { + [Fact] + public async Task Query() + { + var dummy = new UncommonType + { + CLREnum = DummyEnum.Foo + }; + + Assert.Equal(1, await Database.UncommonTypes.InsertAsync(dummy)); + + dummy = await Database.UncommonTypes.QueryInterpolatedSingle($@"SELECT * FROM ""UncommonTypes"" WHERE ""CLREnum"" = {dummy.CLREnum}").Build().QueryAsync(); + + Assert.Equal(DummyEnum.Foo, dummy.CLREnum); + + await Database.UncommonTypes.DeleteAsync(dummy); + } + + [Fact] + public async Task QueryNullableValue() + { + var dummy = new UncommonType + { + NCLREnum = DummyEnum.Foo + }; + + Assert.Equal(1, await Database.UncommonTypes.InsertAsync(dummy)); + + dummy = await Database.UncommonTypes.QueryInterpolatedSingle($@"SELECT * FROM ""UncommonTypes"" WHERE ""NCLREnum"" = {dummy.NCLREnum}").Build().QueryAsync(); + + Assert.Equal(DummyEnum.Foo, dummy.NCLREnum); + + await Database.UncommonTypes.DeleteAsync(dummy); + } + [Fact] public async Task Insert() { diff --git a/src/Venflow/Venflow/CastTypeHandler.cs b/src/Venflow/Venflow/CastTypeHandler.cs new file mode 100644 index 00000000..b6ac9025 --- /dev/null +++ b/src/Venflow/Venflow/CastTypeHandler.cs @@ -0,0 +1,10 @@ +using Npgsql; + +namespace Venflow +{ + internal class CastTypeHandler : IParameterTypeHandler + { + NpgsqlParameter IParameterTypeHandler.Handle(string name, object val) + => new NpgsqlParameter(name, (T)val); + } +} diff --git a/src/Venflow/Venflow/Database.cs b/src/Venflow/Venflow/Database.cs index dd68f24e..361e44f4 100644 --- a/src/Venflow/Venflow/Database.cs +++ b/src/Venflow/Venflow/Database.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Collections.Concurrent; using System.Collections.Generic; using System.Data; @@ -351,12 +351,12 @@ internal void ExecuteLoggers(IReadOnlyList loggers, NpgsqlComman private void Build() { - var type = this.GetType(); - - if (!DatabaseConfigurationCache.DatabaseConfigurations.TryGetValue(type, out var configuration)) + if (!DatabaseConfigurationCache.DatabaseConfigurations.TryGetValue(this.GetType(), out var configuration)) { lock (DatabaseConfigurationCache.BuildLocker) { + var type = this.GetType(); + if (!DatabaseConfigurationCache.DatabaseConfigurations.TryGetValue(type, out configuration)) { var dbConfigurator = new DatabaseConfigurationFactory(); diff --git a/src/Venflow/Venflow/DatabaseConfigurationOptionsBuilder.cs b/src/Venflow/Venflow/DatabaseConfigurationOptionsBuilder.cs index bf372751..656b82ae 100644 --- a/src/Venflow/Venflow/DatabaseConfigurationOptionsBuilder.cs +++ b/src/Venflow/Venflow/DatabaseConfigurationOptionsBuilder.cs @@ -1,6 +1,8 @@ using System; using System.Collections.Generic; using System.Reflection; +using System.Text; +using Npgsql; using Venflow.Modeling.Definitions; namespace Venflow @@ -59,5 +61,57 @@ public DatabaseConfigurationOptionsBuilder UseConfigurations(params Assembly[] a return this; } + + + /// + /// Maps a PostgreSQL enum to a CLR enum. + /// + /// The type of the enum. + /// The name of the enum in PostgreSQL, if none used it will try to convert the name of the CLR enum e.g. 'FooBar' to 'foo_bar' + /// A component which will be used to translate CLR names (e.g. SomeClass) into database names (e.g. some_class). Defaults to . + /// The same builder instance so that multiple calls can be chained. + public DatabaseConfigurationOptionsBuilder RegisterPostgresEnum(string? name = default, INpgsqlNameTranslator? npgsqlNameTranslator = default) where TEnum : struct, Enum + { + var type = typeof(TEnum); + + if (string.IsNullOrWhiteSpace(name)) + { + var underlyingType = Nullable.GetUnderlyingType(type); + + name = underlyingType is not null ? underlyingType.Name : type.Name; + + var nameBuilder = new StringBuilder(name.Length * 2 - 1); + + nameBuilder.Append(char.ToLowerInvariant(name[0])); + + var nameSpan = name.AsSpan(); + + for (int i = 1; i < nameSpan.Length; i++) + { + var c = nameSpan[i]; + + if (char.IsUpper(c)) + { + nameBuilder.Append('_'); + nameBuilder.Append(char.ToLowerInvariant(c)); + } + else + { + nameBuilder.Append(c); + } + } + + name = nameBuilder.ToString(); + } + + if (!ParameterTypeHandler.PostgreEnums.Contains(type)) + { + NpgsqlConnection.GlobalTypeMapper.MapEnum(name, npgsqlNameTranslator); + + ParameterTypeHandler.PostgreEnums.Add(type); + } + + return this; + } } -} +} \ No newline at end of file diff --git a/src/Venflow/Venflow/Modeling/Definitions/Builder/EntityBuilder.cs b/src/Venflow/Venflow/Modeling/Definitions/Builder/EntityBuilder.cs index a2f621e9..43640515 100644 --- a/src/Venflow/Venflow/Modeling/Definitions/Builder/EntityBuilder.cs +++ b/src/Venflow/Venflow/Modeling/Definitions/Builder/EntityBuilder.cs @@ -113,6 +113,7 @@ IEntityBuilder IEntityBuilder.MapId(Expression IEntityBuilder.MapPostgresEnum(Expression> propertySelector, string? name, INpgsqlNameTranslator? npgsqlNameTranslator) { var property = propertySelector.ValidatePropertySelector(); @@ -122,6 +123,7 @@ IEntityBuilder IEntityBuilder.MapPostgresEnum(Express return this; } + [Obsolete("This method will be removed in the next major version. Please instead use the DatabaseConfigurationOptionsBuilder.RegisterPostgresEnum method on the Database.Configure method.")] IEntityBuilder IEntityBuilder.MapPostgresEnum(Expression> propertySelector, string? name, INpgsqlNameTranslator? npgsqlNameTranslator) { var property = propertySelector.ValidatePropertySelector(); @@ -164,11 +166,11 @@ private void MapPostgresEnum(PropertyInfo property, string? name, INpgs name = nameBuilder.ToString(); } - if (!PostgreSQLEnums.Contains(name)) + if (!ParameterTypeHandler.PostgreEnums.Contains(property.PropertyType)) { NpgsqlConnection.GlobalTypeMapper.MapEnum(name, npgsqlNameTranslator); - PostgreSQLEnums.Add(name); + ParameterTypeHandler.PostgreEnums.Add(property.PropertyType); } ColumnDefinitions[property.Name].Options |= ColumnOptions.PostgreEnum; @@ -443,13 +445,6 @@ internal abstract class EntityBuilder { internal static uint RelationCounter { get; set; } - internal static ConcurrentBag PostgreSQLEnums { get; } - - static EntityBuilder() - { - PostgreSQLEnums = new ConcurrentBag(); - } - internal List Relations { get; } internal abstract Type Type { get; } @@ -508,6 +503,7 @@ protected EntityBuilder() /// The name of the enum in PostgreSQL, if none used it will try to convert the name of the CLR enum e.g. 'FooBar' to 'foo_bar' /// A component which will be used to translate CLR names (e.g. SomeClass) into database names (e.g. some_class). Defaults to . /// The same builder instance so that multiple calls can be chained. + [Obsolete("This method will be removed in the next major version. Please instead use the DatabaseConfigurationOptionsBuilder.RegisterPostgresEnum method on the Database.Configure method.")] IEntityBuilder MapPostgresEnum(Expression> propertySelector, string? name = default, INpgsqlNameTranslator? npgsqlNameTranslator = default) where TTarget : struct, Enum; @@ -519,6 +515,7 @@ IEntityBuilder MapPostgresEnum(ExpressionThe name of the enum in PostgreSQL, if none used it will try to convert the name of the CLR enum e.g. 'FooBar' to 'foo_bar' /// A component which will be used to translate CLR names (e.g. SomeClass) into database names (e.g. some_class). Defaults to . /// The same builder instance so that multiple calls can be chained. + [Obsolete("This method will be removed in the next major version. Please instead use the DatabaseConfigurationOptionsBuilder.RegisterPostgresEnum method on the Database.Configure method.")] IEntityBuilder MapPostgresEnum(Expression> propertySelector, string? name = default, INpgsqlNameTranslator? npgsqlNameTranslator = default) where TTarget : struct, Enum; } diff --git a/src/Venflow/Venflow/ParameterTypeHandler.cs b/src/Venflow/Venflow/ParameterTypeHandler.cs index 84d27a03..3ff440b8 100644 --- a/src/Venflow/Venflow/ParameterTypeHandler.cs +++ b/src/Venflow/Venflow/ParameterTypeHandler.cs @@ -1,5 +1,5 @@ using System; -using System.Runtime.CompilerServices; +using System.Collections.Generic; using Npgsql; namespace Venflow @@ -9,13 +9,19 @@ namespace Venflow /// public static class ParameterTypeHandler { - private readonly static ConditionalWeakTable _typeHandlers = new ConditionalWeakTable(); + internal static HashSet PostgreEnums => _postgreEnums; + + private readonly static Dictionary _typeHandlers = new Dictionary(); + private readonly static Dictionary _castHandlers = new Dictionary(1); + private readonly static HashSet _postgreEnums = new HashSet(0); static ParameterTypeHandler() { var uInt64Handler = new UInt64Handler(); AddTypeHandler(typeof(ulong), uInt64Handler); AddTypeHandler(typeof(ulong?), uInt64Handler); + + _castHandlers.Add(typeof(ulong), uInt64Handler); } /// @@ -28,6 +34,9 @@ public static void AddTypeHandler(Type type, IParameterTypeHandler typeHandler) internal static NpgsqlParameter HandleParameter(string name, object? val) { + Type? type = null; + IParameterTypeHandler? handler; + switch (val) { case null: @@ -35,9 +44,35 @@ internal static NpgsqlParameter HandleParameter(string name, object? val) case IKey key: val = key.BoxedValue; break; + case Enum: + type = val.GetType(); + + var tempType = Nullable.GetUnderlyingType(type) ?? type; + + if (!_postgreEnums.Contains(tempType)) + { + if (!_typeHandlers.TryGetValue(tempType, out handler)) + { + var underlyingType = tempType.GetEnumUnderlyingType(); + + if (!_castHandlers.TryGetValue(underlyingType, out handler)) + { + handler = (IParameterTypeHandler)Activator.CreateInstance(typeof(CastTypeHandler<>).MakeGenericType(underlyingType)); + + _castHandlers.Add(underlyingType, handler); + } + + _typeHandlers.Add(tempType, handler); + } + + return handler.Handle(name, val); + } + break; } - if (!_typeHandlers.TryGetValue(val.GetType(), out var handler)) + type ??= val.GetType(); + + if (!_typeHandlers.TryGetValue(type, out handler)) return new NpgsqlParameter(name, val); return handler.Handle(name, val); @@ -45,13 +80,13 @@ internal static NpgsqlParameter HandleParameter(string name, object? val) internal static NpgsqlParameter HandleParameter(string name, T? val) { + Type? type = null; IParameterTypeHandler? handler; switch (val) { case null: return new NpgsqlParameter(name, DBNull.Value); - case IKey key: var tempVal = key.BoxedValue; @@ -59,9 +94,35 @@ internal static NpgsqlParameter HandleParameter(string name, T? val) return new NpgsqlParameter(name, tempVal); return handler.Handle(name, tempVal); + case Enum: + type = val.GetType(); + + var tempType = Nullable.GetUnderlyingType(type) ?? type; + + if (!_postgreEnums.Contains(tempType)) + { + if (!_typeHandlers.TryGetValue(tempType, out handler)) + { + var underlyingType = tempType.GetEnumUnderlyingType(); + + if (!_castHandlers.TryGetValue(underlyingType, out handler)) + { + handler = (IParameterTypeHandler)Activator.CreateInstance(typeof(CastTypeHandler<>).MakeGenericType(underlyingType)); + + _castHandlers.Add(underlyingType, handler); + } + + _typeHandlers.Add(tempType, handler); + } + + return handler.Handle(name, val); + } + break; } - if (!_typeHandlers.TryGetValue(val.GetType(), out handler)) + type ??= val.GetType(); + + if (!_typeHandlers.TryGetValue(type, out handler)) return new NpgsqlParameter(name, val); return handler.Handle(name, val); @@ -69,6 +130,8 @@ internal static NpgsqlParameter HandleParameter(string name, T? val) internal static NpgsqlParameter HandleParameter(string name, Type type, object? val) { + IParameterTypeHandler? handler; + switch (val) { case null: @@ -76,9 +139,31 @@ internal static NpgsqlParameter HandleParameter(string name, Type type, object? case IKey key: val = key.BoxedValue; break; + case Enum: + var tempType = Nullable.GetUnderlyingType(type) ?? type; + + if (!_postgreEnums.Contains(tempType)) + { + if (!_typeHandlers.TryGetValue(tempType, out handler)) + { + var underlyingType = tempType.GetEnumUnderlyingType(); + + if (!_castHandlers.TryGetValue(underlyingType, out handler)) + { + handler = (IParameterTypeHandler)Activator.CreateInstance(typeof(CastTypeHandler<>).MakeGenericType(underlyingType)); + + _castHandlers.Add(underlyingType, handler); + } + + _typeHandlers.Add(tempType, handler); + } + + return handler.Handle(name, val); + } + break; } - if (!_typeHandlers.TryGetValue(type, out var handler)) + if (!_typeHandlers.TryGetValue(type, out handler)) return new NpgsqlParameter(name, val); return handler.Handle(name, val);