inference: use the same method lookup cache across same inference trial
Previously the method lookup cache was created per frame, and so the
cache wasn't used very much. With this change the cache is created per
inference, so the cached result is reused whenever we have already seen
the same match within the same inference shot, which may speed up the
lookup time a bit.

This commit also sets up a new `AbstractInterpreter` interface, `get_method_lookup_cache`,
which specifies which method lookup cache is used by each `AbstractInterpreter`.
`NativeInterpreter` creates a cache per inference, which is valid since
lookups are done in the same world age within the same inference shot.
External `AbstractInterpreter`s don't opt into this cache by default,
so their behavior won't change in any way.
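
As a rough sketch, an external interpreter could opt in by keeping its own
cache and overloading the new interface. The `MyInterp` type below is
hypothetical, and `get_method_lookup_cache` / `MethodLookupCache` are
Compiler-internal names as of this commit, so details may differ across
versions:

```julia
const CC = Core.Compiler

# Hypothetical external interpreter opting into the shared lookup cache.
# `MyInterp` is made up for illustration and omits the other
# `AbstractInterpreter` interface methods it would need in practice;
# `get_method_lookup_cache` is the hook added by this commit.
struct MyInterp <: CC.AbstractInterpreter
    # IdDict{Any, Union{Missing, MethodLookupResult}} under the hood
    method_lookup_cache::CC.MethodLookupCache
    # ... world age, inference caches, params, etc. ...
end
MyInterp() = MyInterp(CC.MethodLookupCache())

# Reusing one cache across an inference trial is only valid because every
# lookup within that trial happens in the same world age.
CC.get_method_lookup_cache(interp::MyInterp) = interp.method_lookup_cache
```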
aviatesk committed Feb 19, 2022
1 parent f5d9b86 commit 38d8fa1
Showing 3 changed files with 32 additions and 24 deletions.
2 changes: 1 addition & 1 deletion base/compiler/inferencestate.jl
@@ -141,7 +141,7 @@ mutable struct InferenceState
cache === :global, false, false,
Effects(consistent, ALWAYS_TRUE, ALWAYS_TRUE, ALWAYS_TRUE,
inbounds_taints_consistency),
CachedMethodTable(method_table(interp)),
CachedMethodTable(get_method_lookup_cache(interp), method_table(interp)),
interp)
result.result = frame
cache !== :no && push!(get_inference_cache(interp), result)
29 changes: 8 additions & 21 deletions base/compiler/methodtable.jl
@@ -2,22 +2,6 @@

abstract type MethodTableView; end

struct MethodLookupResult
# Really Vector{Core.MethodMatch}, but it's easier to represent this as
# and work with Vector{Any} on the C side.
matches::Vector{Any}
valid_worlds::WorldRange
ambig::Bool
end
length(result::MethodLookupResult) = length(result.matches)
function iterate(result::MethodLookupResult, args...)
r = iterate(result.matches, args...)
r === nothing && return nothing
match, state = r
return (match::MethodMatch, state)
end
getindex(result::MethodLookupResult, idx::Int) = getindex(result.matches, idx)::MethodMatch

"""
struct InternalMethodTable <: MethodTableView
@@ -46,12 +30,11 @@ Overlays another method table view with an additional local fast path cache that
can respond to repeated, identical queries faster than the original method table.
"""
struct CachedMethodTable{T} <: MethodTableView
cache::IdDict{Any, Union{Missing, MethodLookupResult}}
cache::MethodLookupCache
table::T
CachedMethodTable(cache::MethodLookupCache, table::T) where T = new{T}(cache, table)
CachedMethodTable(::Nothing, table::T) where T = new{T}(MethodLookupCache(), table)
end
CachedMethodTable(table::T) where T =
CachedMethodTable{T}(IdDict{Any, Union{Missing, MethodLookupResult}}(),
table)

"""
findall(sig::Type, view::MethodTableView; limit=typemax(Int))
@@ -92,9 +75,13 @@ function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int
end

function findall(@nospecialize(sig::Type), table::CachedMethodTable; limit::Int=typemax(Int))
if isconcretetype(sig)
# we have equivalent cache in this concrete DataType's hash table, so don't bother to cache it here
return findall(sig, table.table; limit)
end
box = Core.Box(sig)
return get!(table.cache, sig) do
findall(box.contents, table.table; limit=limit)
findall(box.contents, table.table; limit)
end
end
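
The cached `findall` above is `get!`-with-a-do-block memoization keyed on the
(non-concrete) signature. As a minimal, self-contained sketch of the same
pattern outside the compiler, using the public `methods` function as the
stand-in for the expensive lookup (`toy_lookup_cache` and `cached_methods`
are made-up names for illustration):

```julia
# Memoize repeated, identical lookups, mirroring the get!-do-block pattern
# in `findall(::Type, ::CachedMethodTable)` above. The real code stores
# `MethodLookupResult`s keyed by the call signature instead.
const toy_lookup_cache = IdDict{Any,Vector{Method}}()

function cached_methods(@nospecialize(f), @nospecialize(argtypes::Type{<:Tuple}))
    return get!(toy_lookup_cache, (f, argtypes)) do
        # only runs on a cache miss; this is the "expensive" lookup we reuse
        collect(Method, methods(f, argtypes))
    end
end

cached_methods(sin, Tuple{Float64})  # miss: performs the lookup and stores it
cached_methods(sin, Tuple{Float64})  # hit: returns the stored vector
```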

25 changes: 23 additions & 2 deletions base/compiler/types.jl
@@ -111,6 +111,23 @@ decode_effects_override(e::UInt8) =
(e & 0x08) != 0x00,
(e & 0x10) != 0x00)

struct MethodLookupResult
# Really Vector{Core.MethodMatch}, but it's easier to represent this as
# and work with Vector{Any} on the C side.
matches::Vector{Any}
valid_worlds::WorldRange
ambig::Bool
end
length(result::MethodLookupResult) = length(result.matches)
function iterate(result::MethodLookupResult, args...)
r = iterate(result.matches, args...)
r === nothing && return nothing
match, state = r
return (match::MethodMatch, state)
end
getindex(result::MethodLookupResult, idx::Int) = getindex(result.matches, idx)::MethodMatch
const MethodLookupCache = IdDict{Any, Union{Missing, MethodLookupResult}}

"""
InferenceResult
@@ -242,6 +259,8 @@ It contains many parameters used by the compilation pipeline.
struct NativeInterpreter <: AbstractInterpreter
# Cache of inference results for this particular interpreter
cache::Vector{InferenceResult}
# cache of method lookup results
method_lookup_cache::MethodLookupCache
# The world age we're working inside of
world::UInt

@@ -263,10 +282,10 @@ struct NativeInterpreter <: AbstractInterpreter
# incorrect, fail out loudly.
@assert world <= get_world_counter()


return new(
# Initially empty cache
# Initially empty caches
Vector{InferenceResult}(),
MethodLookupCache(),

# world age counter
world,
Expand Down Expand Up @@ -316,6 +335,8 @@ may_discard_trees(::AbstractInterpreter) = true
verbose_stmt_info(::AbstractInterpreter) = false

method_table(interp::AbstractInterpreter) = InternalMethodTable(get_world_counter(interp))
get_method_lookup_cache(ni::NativeInterpreter) = ni.method_lookup_cache
get_method_lookup_cache(::AbstractInterpreter) = nothing

"""
By default `AbstractInterpreter` implements the following inference bail out logic:
