From 829fb7cf58c379ba2d3821950fb3f79f3116d2d4 Mon Sep 17 00:00:00 2001 From: Cody Tapscott Date: Tue, 2 Jul 2024 16:31:10 -0400 Subject: [PATCH] TOML: Improve type-stability This changes the output of the TOML parser to specialize `Vector{T}` less aggressively, so that combinatorially expensive types like `Vector{Vector{Float64}}` or `Vector{Union{Float64,Int64}}` are instead returned as `Vector{Any}`. Vectors of homogeneous leaf types, like `Vector{Float64}`, are still supported as before. This change makes the TOML parser fully type-stable, except for its dynamic usage of Dates. Co-authored-by: Gabriel Baraldi --- base/loading.jl | 1 - base/toml_parser.jl | 100 ++++++++++++++++++++----------------- stdlib/TOML/test/values.jl | 2 +- 3 files changed, 54 insertions(+), 49 deletions(-) diff --git a/base/loading.jl b/base/loading.jl index 4dc735f0099d8..e0a2563dd0178 100644 --- a/base/loading.jl +++ b/base/loading.jl @@ -269,7 +269,6 @@ struct TOMLCache{Dates} d::Dict{String, CachedTOMLDict} end TOMLCache(p::TOML.Parser) = TOMLCache(p, Dict{String, CachedTOMLDict}()) -# TODO: Delete this converting constructor once Pkg stops using it TOMLCache(p::TOML.Parser, d::Dict{String, Dict{String, Any}}) = TOMLCache(p, convert(Dict{String, CachedTOMLDict}, d)) const TOML_CACHE = TOMLCache(TOML.Parser{nothing}()) diff --git a/base/toml_parser.jl b/base/toml_parser.jl index cc1455f61928b..4d07cfed05d8a 100644 --- a/base/toml_parser.jl +++ b/base/toml_parser.jl @@ -84,9 +84,6 @@ mutable struct Parser{Dates} # Filled in in case we are parsing a file to improve error messages filepath::Union{String, Nothing} - - # Optionally populate with the Dates stdlib to change the type of Date types returned - Dates::Union{Module, Nothing} # TODO: remove once Pkg is updated end function Parser{Dates}(str::String; filepath=nothing) where {Dates} @@ -106,8 +103,7 @@ function Parser{Dates}(str::String; filepath=nothing) where {Dates} IdSet{Any}(), # static_arrays IdSet{TOMLDict}(), # 
defined_tables root, - filepath, - nothing + filepath ) startup(l) return l @@ -495,8 +491,10 @@ function recurse_dict!(l::Parser, d::Dict, dotted_keys::AbstractVector{String}, d = d::TOMLDict key = dotted_keys[i] d = get!(TOMLDict, d, key) - if d isa Vector + if d isa Vector{Any} d = d[end] + elseif d isa Vector + return ParserError(ErrKeyAlreadyHasValue) end check && @try check_allowed_add_key(l, d, i == length(dotted_keys)) end @@ -537,7 +535,7 @@ function parse_array_table(l)::Union{Nothing, ParserError} end d = @try recurse_dict!(l, l.root, @view(table_key[1:end-1]), false) k = table_key[end] - old = get!(() -> [], d, k) + old = get!(() -> Any[], d, k) if old isa Vector if old in l.static_arrays return ParserError(ErrAddArrayToStaticArray) @@ -546,7 +544,7 @@ function parse_array_table(l)::Union{Nothing, ParserError} return ParserError(ErrArrayTreatedAsDictionary) end d_new = TOMLDict() - push!(old, d_new) + push!(old::Vector{Any}, d_new) push!(l.defined_tables, d_new) l.active_table = d_new @@ -668,41 +666,20 @@ end # Array # ######### -function push!!(v::Vector, el) - # Since these types are typically non-inferable, they are a big invalidation risk, - # and since it's used by the package-loading infrastructure the cost of invalidation - # is high. Therefore, this is written to reduce the "exposed surface area": e.g., rather - # than writing `T[el]` we write it as `push!(Vector{T}(undef, 1), el)` so that there - # is no ambiguity about what types of objects will be created. 
- T = eltype(v) - t = typeof(el) - if el isa T || t === T - push!(v, el::T) - return v - elseif T === Union{} - out = Vector{t}(undef, 1) - out[1] = el - return out - else - if T isa Union - newT = Any - else - newT = Union{T, typeof(el)} - end - new = Array{newT}(undef, length(v)) - copy!(new, v) - return push!(new, el) +function copyto_typed!(a::Vector{T}, b::Vector) where T + for i in 1:length(b) + a[i] = b[i]::T end + return nothing end -function parse_array(l::Parser)::Err{Vector} +function parse_array(l::Parser{Dates})::Err{Vector} where Dates skip_ws_nl(l) - array = Vector{Union{}}() + array = Vector{Any}() empty_array = accept(l, ']') while !empty_array v = @try parse_value(l) - # TODO: Worth to function barrier this? - array = push!!(array, v) + array = push!(array, v) # There can be an arbitrary number of newlines and comments before a value and before the closing bracket. skip_ws_nl(l) comma = accept(l, ',') @@ -712,8 +689,40 @@ function parse_array(l::Parser)::Err{Vector} return ParserError(ErrExpectedCommaBetweenItemsArray) end end - push!(l.static_arrays, array) - return array + # check for static type throughout array + T = !isempty(array) ? 
typeof(array[1]) : Union{} + for el in array + if typeof(el) != T + T = Any + break + end + end + if T === Any + new = array + elseif T === String + new = Array{T}(undef, length(array)) + copyto_typed!(new, array) + elseif T === Bool + new = Array{T}(undef, length(array)) + copyto_typed!(new, array) + elseif T === Int64 + new = Array{T}(undef, length(array)) + copyto_typed!(new, array) + elseif T === UInt64 + new = Array{T}(undef, length(array)) + copyto_typed!(new, array) + elseif T === Float64 + new = Array{T}(undef, length(array)) + copyto_typed!(new, array) + elseif T === Union{} + new = Any[] + elseif (T === TOMLDict) || (T == BigInt) || (T === UInt128) || (T === Int128) || (T <: Vector) || + (T === Dates.Date) || (T === Dates.Time) || (T === Dates.DateTime) + # do nothing, leave as Vector{Any} + new = array + else @assert false end + push!(l.static_arrays, new) + return new end @@ -1025,10 +1034,9 @@ function parse_datetime(l) end function try_return_datetime(p::Parser{Dates}, year, month, day, h, m, s, ms) where Dates - if Dates !== nothing || p.Dates !== nothing - mod = Dates !== nothing ? Dates : p.Dates + if Dates !== nothing try - return mod.DateTime(year, month, day, h, m, s, ms) + return Dates.DateTime(year, month, day, h, m, s, ms) catch ex ex isa ArgumentError && return ParserError(ErrParsingDateTime) rethrow() @@ -1039,10 +1047,9 @@ function try_return_datetime(p::Parser{Dates}, year, month, day, h, m, s, ms) wh end function try_return_date(p::Parser{Dates}, year, month, day) where Dates - if Dates !== nothing || p.Dates !== nothing - mod = Dates !== nothing ? 
Dates : p.Dates + if Dates !== nothing try - return mod.Date(year, month, day) + return Dates.Date(year, month, day) catch ex ex isa ArgumentError && return ParserError(ErrParsingDateTime) rethrow() @@ -1062,10 +1069,9 @@ function parse_local_time(l::Parser) end function try_return_time(p::Parser{Dates}, h, m, s, ms) where Dates - if Dates !== nothing || p.Dates !== nothing - mod = Dates !== nothing ? Dates : p.Dates + if Dates !== nothing try - return mod.Time(h, m, s, ms) + return Dates.Time(h, m, s, ms) catch ex ex isa ArgumentError && return ParserError(ErrParsingDateTime) rethrow() diff --git a/stdlib/TOML/test/values.jl b/stdlib/TOML/test/values.jl index 4fc49d47fc98d..53be1b04708b3 100644 --- a/stdlib/TOML/test/values.jl +++ b/stdlib/TOML/test/values.jl @@ -172,6 +172,6 @@ end @testset "Array" begin @test testval("[1,2,3]", Int64[1,2,3]) @test testval("[1.0, 2.0, 3.0]", Float64[1.0, 2.0, 3.0]) - @test testval("[1.0, 2.0, 3]", Union{Int64, Float64}[1.0, 2.0, Int64(3)]) + @test testval("[1.0, 2.0, 3]", Any[1.0, 2.0, Int64(3)]) @test testval("[1.0, 2, \"foo\"]", Any[1.0, Int64(2), "foo"]) end