Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add indirection in == fallback #44678

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open

Add indirection in == fallback #44678

wants to merge 1 commit into from

Conversation

timholy
Copy link
Sponsor Member

@timholy timholy commented Mar 19, 2022

The fallback for == just calls ===. However, we also have types
like WeakRef that add three specializations of == that just
unwrap the WeakRef. It turns out this pattern is also used in
packages, and one of them, ChainRulesCore, has >2300 dependent
packages as of now. Unfortunately, when ChainRulesCore is loaded,
it invalidates code that calls == on WeakRef. This includes
much of the Serialization stdlib.

Since I have a personal goal of getting Makie's TTFP to less than one second (see #44527), and I recently learned from @SimonDanisch that Makie uses the Serialization stdlib, I will need to find a way to fix this.

This addresses the problem by defining the notion of "unwrapping
for ==" by introducing the non-exported function unwrap_isequal.
The fallback is just

unwrap_isequal(x) = x

but one can add specializations.

Now, the obvious concern about this implementation is that it may hurt performance. In cases of concrete inference it might have no impact, as something that uses the fallback like

julia> struct M
           x::Int
       end

julia> @code_llvm M(2) == M(2)

produces the same code before and after. However, @code_typed shows that the extra steps are not entirely DCEd:

julia> @code_typed M(2) == M(2)
CodeInfo(
1%1 = (x === x)::Bool
└──      goto #5 if not %1
2%3 = (y === y)::Bool
└──      goto #4 if not %3
3%5 = (x === y)::Bool
└──      return %5
4nothing::Nothing
5%8 = invoke Base.:(==)(x::M, y::M)::Bool
└──      return %8
) => Bool

and so this might affect inlining. Moreover, there is also a chance for a performance regression in poorly-inferred code. So benchmarks are crucial: @nanosoldier runbenchmarks(ALL, vs=":master").

To show that this fixes the problem, on master:

julia> using SnoopCompileCore

julia> invs = @snoopr using ChainRulesCore;

julia> using SnoopCompile
[ Info: Precompiling SnoopCompile [aa65fe97-06da-5843-b5b1-d5d13cad87d2]

julia> trees = invalidation_trees(invs)
4-element Vector{SnoopCompile.MethodInvalidations}:
 inserting *(::Any, ::ZeroTangent) in ChainRulesCore at /home/tim/.julia/packages/ChainRulesCore/IzITE/src/tangent_arithmetic.jl:105 invalidated:
   mt_backedges: 1: signature Tuple{typeof(*), String, Any} triggered MethodInstance for (::Test.var"#7#9")(::Any) (0 children)
                 2: signature Tuple{typeof(*), String, Any} triggered MethodInstance for (::Test.var"#8#10")(::Any) (0 children)
   12 mt_cache

 inserting convert(::Type{<:Thunk}, a::AbstractZero) in ChainRulesCore at /home/tim/.julia/packages/ChainRulesCore/IzITE/src/tangent_types/thunks.jl:205 invalidated:
   backedges: 1: superseding convert(::Type{Union{}}, x) in Base at essentials.jl:213 with MethodInstance for convert(::Core.TypeofBottom, ::Any) (1 children)
   15 mt_cache

 inserting ==(a::AbstractThunk, b) in ChainRulesCore at /home/tim/.julia/packages/ChainRulesCore/IzITE/src/tangent_types/thunks.jl:28 invalidated:
   backedges: 1: superseding ==(x, y) in Base at Base.jl:119 with MethodInstance for ==(::Any, ::Task) (4 children)
              2: superseding ==(x, y) in Base at Base.jl:119 with MethodInstance for ==(::Any, ::Base.UUID) (9 children)
              3: superseding ==(x, y) in Base at Base.jl:119 with MethodInstance for ==(::Any, ::FileWatching._FDWatcher) (22 children)

 inserting ==(a, b::AbstractThunk) in ChainRulesCore at /home/tim/.julia/packages/ChainRulesCore/IzITE/src/tangent_types/thunks.jl:29 invalidated:
   backedges: 1: superseding ==(x, y) in Base at Base.jl:119 with MethodInstance for ==(::Base.UUID, ::Any) (6 children)
              2: superseding ==(x, y) in Base at Base.jl:119 with MethodInstance for ==(::Module, ::Any) (12 children)
              3: superseding ==(x, y) in Base at Base.jl:119 with MethodInstance for ==(::Method, ::Any) (24 children)
              4: superseding ==(x, y) in Base at Base.jl:119 with MethodInstance for ==(::Symbol, ::Any) (46 children)
              5: superseding ==(x, y) in Base at Base.jl:119 with MethodInstance for ==(::Core.TypeName, ::Any) (124 children)


julia> tree = trees[end]
inserting ==(a, b::AbstractThunk) in ChainRulesCore at /home/tim/.julia/packages/ChainRulesCore/IzITE/src/tangent_types/thunks.jl:29 invalidated:
   backedges: 1: superseding ==(x, y) in Base at Base.jl:119 with MethodInstance for ==(::Base.UUID, ::Any) (6 children)
              2: superseding ==(x, y) in Base at Base.jl:119 with MethodInstance for ==(::Module, ::Any) (12 children)
              3: superseding ==(x, y) in Base at Base.jl:119 with MethodInstance for ==(::Method, ::Any) (24 children)
              4: superseding ==(x, y) in Base at Base.jl:119 with MethodInstance for ==(::Symbol, ::Any) (46 children)
              5: superseding ==(x, y) in Base at Base.jl:119 with MethodInstance for ==(::Core.TypeName, ::Any) (124 children)

but with this diff in ChainRulesCore:

diff --git a/src/tangent_types/thunks.jl b/src/tangent_types/thunks.jl
index 82f2e415..ffd1bdb8 100644
--- a/src/tangent_types/thunks.jl
+++ b/src/tangent_types/thunks.jl
@@ -24,9 +24,13 @@ end
     return element, (underlying_object, new_state)
 end
 
-Base.:(==)(a::AbstractThunk, b::AbstractThunk) = unthunk(a) == unthunk(b)
-Base.:(==)(a::AbstractThunk, b) = unthunk(a) == b
-Base.:(==)(a, b::AbstractThunk) = a == unthunk(b)
+if isdefined(Base, :unwrap_isequal)
+    Base.unwrap_isequal(a::AbstractThunk) = unthunk(a)
+else
+    Base.:(==)(a::AbstractThunk, b::AbstractThunk) = unthunk(a) == unthunk(b)
+    Base.:(==)(a::AbstractThunk, b) = unthunk(a) == b
+    Base.:(==)(a, b::AbstractThunk) = a == unthunk(b)
+end
 
 Base.:(-)(a::AbstractThunk) = -unthunk(a)
 Base.:(-)(a::AbstractThunk, b) = unthunk(a) - b

one gets

julia> trees = invalidation_trees(invs)
2-element Vector{SnoopCompile.MethodInvalidations}:
 inserting *(::Any, ::ZeroTangent) in ChainRulesCore at /home/tim/.julia/dev/ChainRulesCore/src/tangent_arithmetic.jl:105 invalidated:
   mt_backedges: 1: signature Tuple{typeof(*), String, Any} triggered MethodInstance for (::Test.var"#7#9")(::Any) (0 children)
                 2: signature Tuple{typeof(*), String, Any} triggered MethodInstance for (::Test.var"#8#10")(::Any) (0 children)

 inserting convert(::Type{<:Number}, x::ChainRulesCore.NotImplemented) in ChainRulesCore at /home/tim/.julia/dev/ChainRulesCore/src/tangent_types/notimplemented.jl:63 invalidated:
   backedges: 1: superseding convert(::Type{Union{}}, x) in Base at essentials.jl:213 with MethodInstance for convert(::Core.TypeofBottom, ::Any) (1 children)
   24 mt_cache

The fallback for `==` just calls `===`. However, we also have types
like `WeakRef` that add three specializations of `==` that just
unwrap the `WeakRef`. It turns out this pattern is also used in
packages, and one of them, ChainRulesCore, has >2300 dependent
packages as of now. Unfortunately, when ChainRulesCore is loaded,
it invalidates code that calls `==` on `WeakRef`. This includes
much of the Serialization stdlib.

This addresses the problem by defining the notion of "unwrapping
for ==" by introducing the non-exported function `unwrap_isequal`.
The fallback is just

```
unwrap_isequal(x) = x
```

but one can add specializations.
@timholy timholy added the compiler:latency Compiler latency label Mar 19, 2022
@KristofferC
Copy link
Sponsor Member

Regarding ChainRulesCore, there is a PR to just remove them JuliaDiff/ChainRulesCore.jl#524 because they are "only used for tests".

@nanosoldier
Copy link
Collaborator

Something went wrong when running your job:

NanosoldierError: error when preparing/pushing to report repo: failed process: Process(setenv(`git push`; dir="/nanosoldier/workdir/NanosoldierReports"), ProcessExited(1)) [1]

Unfortunately, the logs could not be uploaded.

@@ -116,7 +116,13 @@ include("range.jl")
include("error.jl")

# core numeric operations & types
==(x, y) = x === y
unwrap_isequal(x) = x
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think, if we go with this, it should be used much more broadly:

Suggested change
unwrap_isequal(x) = x
unwrap_for_compare(x) = x

Then we need to go update ==, <, isless, isequal, hash etc.

But perhaps instead I should open an issue about eliminating WeakRef (in favor, similar to the atomic attribute, of being a runtime-settable attribute in Array flags or in the field properties of a struct). Unfortunately, we cannot delete WeakRef until v2 regardless, since we can give a replacement for it, but we cannot remove it, until then.

@Tokazama
Copy link
Contributor

Is the intention to have unwrap_isequal be completely internal or a general tool for reducing invalidations for ==?

@timholy
Copy link
Sponsor Member Author

timholy commented Mar 24, 2022

I was intending that it should be overloaded by packages. However, I do think this has some issues. One of them is that conceptually this seems like it would fail to prevent invalidation if two "confusable" types both specialize unwrap_isequal. I'm actually puzzled that this doesn't seem to happen, I wonder if it's an invalidation bug?

julia> abstract type AbstractMyType end

julia> struct MyType1 <: AbstractMyType
           val
       end

julia> Base.unwrap_isequal(x::MyType1) = x.val

julia> d = Dict{AbstractMyType,Int}(MyType1(1) => 1)
Dict{AbstractMyType, Int64} with 1 entry:
  MyType1(1) => 1

julia> f(x) = (d::Dict{AbstractMyType,Int})[AbstractMyType[x][1]]
f (generic function with 1 method)

julia> f(MyType1(1))
1

julia> code_typed(==, (AbstractMyType, AbstractMyType))
1-element Vector{Any}:
 CodeInfo(
1 ── %1  = (isa)(x, MyType1)::Bool
└───       goto #3 if not %1
2 ── %3  = π (x, MyType1)
│    %4  = Base.getfield(%3, :val)::Any
└───       goto #4
3 ── %6  = Base.unwrap_isequal(x)::Any
└───       goto #4
4 ┄─ %8  = φ (#2 => %4, #3 => %6)::Any%9  = (isa)(y, MyType1)::Bool
└───       goto #6 if not %9
5 ── %11 = π (y, MyType1)
│    %12 = Base.getfield(%11, :val)::Any
└───       goto #7
6 ── %14 = Base.unwrap_isequal(y)::Any
└───       goto #7
7 ┄─ %16 = φ (#5 => %12, #6 => %14)::Any%17 = (%8 === x)::Bool
└───       goto #11 if not %17
8 ── %19 = (%16 === y)::Bool
└───       goto #10 if not %19
9 ── %21 = (x === y)::Bool
└───       return %21
10nothing::Nothing
11%24 = (%8 == %16)::Any
└───       return %24
) => Any

julia> using SnoopCompileCore

julia> struct MyType2 <: AbstractMyType
           val
       end

julia> invs = @snoopr Base.unwrap_isequal(x::MyType2) = x.val
Any[]

julia> f(MyType2(1))
ERROR: KeyError: key MyType2(1) not found
Stacktrace:
 [1] getindex(h::Dict{AbstractMyType, Int64}, key::MyType2)
   @ Base ./dict.jl:516
 [2] f(x::MyType2)
   @ Main ./REPL[5]:1
 [3] top-level scope
   @ REPL[11]:1

julia> code_typed(==, (AbstractMyType, AbstractMyType))
1-element Vector{Any}:
 CodeInfo(
1 ── %1  = (isa)(x, MyType1)::Bool
└───       goto #3 if not %1
2 ── %3  = π (x, MyType1)
│    %4  = Base.getfield(%3, :val)::Any
└───       goto #6
3 ── %6  = (isa)(x, MyType2)::Bool
└───       goto #5 if not %6
4 ── %8  = π (x, MyType2)
│    %9  = Base.getfield(%8, :val)::Any
└───       goto #6
5 ── %11 = Base.unwrap_isequal(x)::Any
└───       goto #6
6 ┄─ %13 = φ (#2 => %4, #4 => %9, #5 => %11)::Any%14 = (isa)(y, MyType1)::Bool
└───       goto #8 if not %14
7 ── %16 = π (y, MyType1)
│    %17 = Base.getfield(%16, :val)::Any
└───       goto #11
8 ── %19 = (isa)(y, MyType2)::Bool
└───       goto #10 if not %19
9 ── %21 = π (y, MyType2)
│    %22 = Base.getfield(%21, :val)::Any
└───       goto #11
10%24 = Base.unwrap_isequal(y)::Any
└───       goto #11
11%26 = φ (#7 => %17, #9 => %22, #10 => %24)::Any%27 = (%13 === x)::Bool
└───       goto #15 if not %27
12%29 = (%26 === y)::Bool
└───       goto #14 if not %29
13%31 = (x === y)::Bool
└───       return %31
14nothing::Nothing
15%34 = (%13 == %26)::Any
└───       return %34
) => Any

julia> struct MyType3 <: AbstractMyType
           val
       end

julia> invs = @snoopr Base.unwrap_isequal(x::MyType3) = x.val
Any[]

julia> code_typed(==, (AbstractMyType, AbstractMyType))
1-element Vector{Any}:
 CodeInfo(
1%1  = Base.unwrap_isequal(x)::Any%2  = Base.unwrap_isequal(y)::Any%3  = (%1 === x)::Bool
└──       goto #5 if not %3
2%5  = (%2 === y)::Bool
└──       goto #4 if not %5
3%7  = (x === y)::Bool
└──       return %7
4nothing::Nothing
5%10 = (%1 == %2)::Any
└──       return %10
) => Any

Why aren't there any invalidations when you can verify that the implementation of == is changing? I guess they are now all runtime dispatches?

@vtjnash
Copy link
Sponsor Member

vtjnash commented Mar 24, 2022

By "confusable type" I assume you mean just Some with Some or nothing or missing? I assume that the Some will need to define Base.unwrap_isequal as well as all combinations of == between Some and a confusable type (e.g. the set (Some,Nothing) or (Nothing,Some) or (Some,Missing) or (Missing,Some) or (Some,Some))

The WeakRef and VecElement types shouldn't be confusable with anything

@vtjnash
Copy link
Sponsor Member

vtjnash commented Mar 24, 2022

The fallback is always a runtime dispatch that returns Any, so it does not need a backedge

@Tokazama
Copy link
Contributor

I like the idea. Could also do unwrap_when(::typeof(isequal), x) if you want to have variable support for other methods.

@Tokazama
Copy link
Contributor

Tokazama commented May 19, 2022

Is there an alternate solution in the works or does this PR just need some more work?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
compiler:latency Compiler latency
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants