Skip to content

Commit

Permalink
Artifacts: Improve type-stability (#55707)
Browse files Browse the repository at this point in the history
This improves Artifacts.jl to make `artifact"..."` fully type-stable, so
that it can be used with `--trim`.

This is a requirement for JLL support with trimmed executables.

Dependent on #55016

---------

Co-authored-by: Gabriel Baraldi <baraldigabriel@gmail.com>
  • Loading branch information
topolarity and gbaraldi committed Sep 7, 2024
1 parent fa1c6b2 commit e95eedd
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 65 deletions.
1 change: 0 additions & 1 deletion base/loading.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,6 @@ struct TOMLCache{Dates}
d::Dict{String, CachedTOMLDict}
end
TOMLCache(p::TOML.Parser) = TOMLCache(p, Dict{String, CachedTOMLDict}())
# TODO: Delete this converting constructor once Pkg stops using it
TOMLCache(p::TOML.Parser, d::Dict{String, Dict{String, Any}}) = TOMLCache(p, convert(Dict{String, CachedTOMLDict}, d))

const TOML_CACHE = TOMLCache(TOML.Parser{nothing}())
Expand Down
100 changes: 53 additions & 47 deletions base/toml_parser.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,6 @@ mutable struct Parser{Dates}

# Filled in in case we are parsing a file to improve error messages
filepath::Union{String, Nothing}

# Optionally populate with the Dates stdlib to change the type of Date types returned
Dates::Union{Module, Nothing} # TODO: remove once Pkg is updated
end

function Parser{Dates}(str::String; filepath=nothing) where {Dates}
Expand All @@ -106,8 +103,7 @@ function Parser{Dates}(str::String; filepath=nothing) where {Dates}
IdSet{Any}(), # static_arrays
IdSet{TOMLDict}(), # defined_tables
root,
filepath,
nothing
filepath
)
startup(l)
return l
Expand Down Expand Up @@ -495,8 +491,10 @@ function recurse_dict!(l::Parser, d::Dict, dotted_keys::AbstractVector{String},
d = d::TOMLDict
key = dotted_keys[i]
d = get!(TOMLDict, d, key)
if d isa Vector
if d isa Vector{Any}
d = d[end]
elseif d isa Vector
return ParserError(ErrKeyAlreadyHasValue)
end
check && @try check_allowed_add_key(l, d, i == length(dotted_keys))
end
Expand Down Expand Up @@ -537,7 +535,7 @@ function parse_array_table(l)::Union{Nothing, ParserError}
end
d = @try recurse_dict!(l, l.root, @view(table_key[1:end-1]), false)
k = table_key[end]
old = get!(() -> [], d, k)
old = get!(() -> Any[], d, k)
if old isa Vector
if old in l.static_arrays
return ParserError(ErrAddArrayToStaticArray)
Expand All @@ -546,7 +544,7 @@ function parse_array_table(l)::Union{Nothing, ParserError}
return ParserError(ErrArrayTreatedAsDictionary)
end
d_new = TOMLDict()
push!(old, d_new)
push!(old::Vector{Any}, d_new)
push!(l.defined_tables, d_new)
l.active_table = d_new

Expand Down Expand Up @@ -668,41 +666,20 @@ end
# Array #
#########

function push!!(v::Vector, el)
# Since these types are typically non-inferable, they are a big invalidation risk,
# and since it's used by the package-loading infrastructure the cost of invalidation
# is high. Therefore, this is written to reduce the "exposed surface area": e.g., rather
# than writing `T[el]` we write it as `push!(Vector{T}(undef, 1), el)` so that there
# is no ambiguity about what types of objects will be created.
T = eltype(v)
t = typeof(el)
if el isa T || t === T
push!(v, el::T)
return v
elseif T === Union{}
out = Vector{t}(undef, 1)
out[1] = el
return out
else
if T isa Union
newT = Any
else
newT = Union{T, typeof(el)}
end
new = Array{newT}(undef, length(v))
copy!(new, v)
return push!(new, el)
"""
    copyto_typed!(a::Vector{T}, b::Vector) where T

Copy each element of `b` into the same position of `a`, asserting that the
element is a `T` before it is stored.

Throws a `TypeError` if an element of `b` is not a `T`, and a `BoundsError` if
`a` is shorter than `b`. Elements of `a` beyond `length(b)` are left untouched.
Always returns `nothing`.
"""
function copyto_typed!(a::Vector{T}, b::Vector) where T
    for (idx, el) in enumerate(b)
        # Narrow the element to `T` before storing it in the typed destination.
        a[idx] = el::T
    end
    return nothing
end

function parse_array(l::Parser)::Err{Vector}
function parse_array(l::Parser{Dates})::Err{Vector} where Dates
skip_ws_nl(l)
array = Vector{Union{}}()
array = Vector{Any}()
empty_array = accept(l, ']')
while !empty_array
v = @try parse_value(l)
# TODO: Worth to function barrier this?
array = push!!(array, v)
array = push!(array, v)
# There can be an arbitrary number of newlines and comments before a value and before the closing bracket.
skip_ws_nl(l)
comma = accept(l, ',')
Expand All @@ -712,8 +689,40 @@ function parse_array(l::Parser)::Err{Vector}
return ParserError(ErrExpectedCommaBetweenItemsArray)
end
end
push!(l.static_arrays, array)
return array
# check for static type throughout array
T = !isempty(array) ? typeof(array[1]) : Union{}
for el in array
if typeof(el) != T
T = Any
break
end
end
if T === Any
new = array
elseif T === String
new = Array{T}(undef, length(array))
copyto_typed!(new, array)
elseif T === Bool
new = Array{T}(undef, length(array))
copyto_typed!(new, array)
elseif T === Int64
new = Array{T}(undef, length(array))
copyto_typed!(new, array)
elseif T === UInt64
new = Array{T}(undef, length(array))
copyto_typed!(new, array)
elseif T === Float64
new = Array{T}(undef, length(array))
copyto_typed!(new, array)
elseif T === Union{}
new = Any[]
elseif (T === TOMLDict) || (T == BigInt) || (T === UInt128) || (T === Int128) || (T <: Vector) ||
(T === Dates.Date) || (T === Dates.Time) || (T === Dates.DateTime)
# do nothing, leave as Vector{Any}
new = array
else @assert false end
push!(l.static_arrays, new)
return new
end


Expand Down Expand Up @@ -1025,10 +1034,9 @@ function parse_datetime(l)
end

function try_return_datetime(p::Parser{Dates}, year, month, day, h, m, s, ms) where Dates
if Dates !== nothing || p.Dates !== nothing
mod = Dates !== nothing ? Dates : p.Dates
if Dates !== nothing
try
return mod.DateTime(year, month, day, h, m, s, ms)
return Dates.DateTime(year, month, day, h, m, s, ms)
catch ex
ex isa ArgumentError && return ParserError(ErrParsingDateTime)
rethrow()
Expand All @@ -1039,10 +1047,9 @@ function try_return_datetime(p::Parser{Dates}, year, month, day, h, m, s, ms) wh
end

function try_return_date(p::Parser{Dates}, year, month, day) where Dates
if Dates !== nothing || p.Dates !== nothing
mod = Dates !== nothing ? Dates : p.Dates
if Dates !== nothing
try
return mod.Date(year, month, day)
return Dates.Date(year, month, day)
catch ex
ex isa ArgumentError && return ParserError(ErrParsingDateTime)
rethrow()
Expand All @@ -1062,10 +1069,9 @@ function parse_local_time(l::Parser)
end

function try_return_time(p::Parser{Dates}, h, m, s, ms) where Dates
if Dates !== nothing || p.Dates !== nothing
mod = Dates !== nothing ? Dates : p.Dates
if Dates !== nothing
try
return mod.Time(h, m, s, ms)
return Dates.Time(h, m, s, ms)
catch ex
ex isa ArgumentError && return ParserError(ErrParsingDateTime)
rethrow()
Expand Down
30 changes: 14 additions & 16 deletions stdlib/Artifacts/src/Artifacts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,13 +175,11 @@ function load_overrides(;force::Bool = false)::Dict{Symbol, Any}
end
end

overrides = Dict{Symbol,Any}(
# Overrides by UUID
:UUID => overrides_uuid,

# Overrides by hash
:hash => overrides_hash
)
overrides = Dict{Symbol,Any}()
# Overrides by UUID
overrides[:UUID] = overrides_uuid
# Overrides by hash
overrides[:hash] = overrides_hash

ARTIFACT_OVERRIDES[] = overrides
return overrides
Expand Down Expand Up @@ -351,7 +349,7 @@ function process_overrides(artifact_dict::Dict, pkg_uuid::Base.UUID)

# If we've got a platform-specific friend, override all hashes:
artifact_dict_name = artifact_dict[name]
if isa(artifact_dict_name, Array)
if isa(artifact_dict_name, Vector{Any})
for entry in artifact_dict_name
entry = entry::Dict{String,Any}
hash = SHA1(entry["git-tree-sha1"]::String)
Expand Down Expand Up @@ -544,7 +542,7 @@ function jointail(dir, tail)
end
end

function _artifact_str(__module__, artifacts_toml, name, path_tail, artifact_dict, hash, platform, @nospecialize(lazyartifacts))
function _artifact_str(__module__, artifacts_toml, name, path_tail, artifact_dict, hash, platform, ::Val{LazyArtifacts}) where LazyArtifacts
pkg = Base.PkgId(__module__)
if pkg.uuid !== nothing
# Process overrides for this UUID, if we know what it is
Expand All @@ -563,11 +561,11 @@ function _artifact_str(__module__, artifacts_toml, name, path_tail, artifact_dic
# If not, try determining what went wrong:
meta = artifact_meta(name, artifact_dict, artifacts_toml; platform)
if meta !== nothing && get(meta, "lazy", false)
if lazyartifacts isa Module && isdefined(lazyartifacts, :ensure_artifact_installed)
if nameof(lazyartifacts) in (:Pkg, :Artifacts)
if LazyArtifacts isa Module && isdefined(LazyArtifacts, :ensure_artifact_installed)
if nameof(LazyArtifacts) in (:Pkg, :Artifacts)
Base.depwarn("using Pkg instead of using LazyArtifacts is deprecated", :var"@artifact_str", force=true)
end
return jointail(lazyartifacts.ensure_artifact_installed(string(name), meta, artifacts_toml; platform), path_tail)
return jointail(LazyArtifacts.ensure_artifact_installed(string(name), meta, artifacts_toml; platform), path_tail)
end
error("Artifact $(repr(name)) is a lazy artifact; package developers must call `using LazyArtifacts` in $(__module__) before using lazy artifacts.")
end
Expand Down Expand Up @@ -699,10 +697,10 @@ macro artifact_str(name, platform=nothing)

# Check if the user has provided `LazyArtifacts`, and thus supports lazy artifacts
# If not, check to see if `Pkg` or `Pkg.Artifacts` has been imported.
lazyartifacts = nothing
LazyArtifacts = nothing
for module_name in (:LazyArtifacts, :Pkg, :Artifacts)
if isdefined(__module__, module_name)
lazyartifacts = GlobalRef(__module__, module_name)
LazyArtifacts = GlobalRef(__module__, module_name)
break
end
end
Expand All @@ -714,7 +712,7 @@ macro artifact_str(name, platform=nothing)
platform = HostPlatform()
artifact_name, artifact_path_tail, hash = artifact_slash_lookup(name, artifact_dict, artifacts_toml, platform)
return quote
Base.invokelatest(_artifact_str, $(__module__), $(artifacts_toml), $(artifact_name), $(artifact_path_tail), $(artifact_dict), $(hash), $(platform), $(lazyartifacts))::String
Base.invokelatest(_artifact_str, $(__module__), $(artifacts_toml), $(artifact_name), $(artifact_path_tail), $(artifact_dict), $(hash), $(platform), Val($(LazyArtifacts)))::String
end
else
if platform === nothing
Expand All @@ -723,7 +721,7 @@ macro artifact_str(name, platform=nothing)
return quote
local platform = $(esc(platform))
local artifact_name, artifact_path_tail, hash = artifact_slash_lookup($(esc(name)), $(artifact_dict), $(artifacts_toml), platform)
Base.invokelatest(_artifact_str, $(__module__), $(artifacts_toml), artifact_name, artifact_path_tail, $(artifact_dict), hash, platform, $(lazyartifacts))::String
Base.invokelatest(_artifact_str, $(__module__), $(artifacts_toml), artifact_name, artifact_path_tail, $(artifact_dict), hash, platform, Val($(LazyArtifacts)))::String
end
end
end
Expand Down
2 changes: 1 addition & 1 deletion stdlib/TOML/test/values.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,6 @@ end
@testset "Array" begin
@test testval("[1,2,3]", Int64[1,2,3])
@test testval("[1.0, 2.0, 3.0]", Float64[1.0, 2.0, 3.0])
@test testval("[1.0, 2.0, 3]", Union{Int64, Float64}[1.0, 2.0, Int64(3)])
@test testval("[1.0, 2.0, 3]", Any[1.0, 2.0, Int64(3)])
@test testval("[1.0, 2, \"foo\"]", Any[1.0, Int64(2), "foo"])
end

0 comments on commit e95eedd

Please sign in to comment.