-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Dict gradients leading to addition error after broadcasting change #662
Comments
Can reproduce the errors shown. Note BTW that Molly only seems to work (for me) on 1.7, not on Julia 1.8. (But maybe it's an Apple M1 problem.) Without the opt-out, (jl_GvojyS) pkg> st
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_GvojyS/Project.toml`
[082447d4] ChainRules v1.44.2
[aa0f7f06] Molly v0.13.0
[e88e6eb3] Zygote v0.6.43
julia> ENV["JULIA_DEBUG"] = ChainRules;
# This isn't in fact called:
julia> @eval Zygote function ChainRulesCore.rrule_via_ad(config::ZygoteRuleConfig, f_args...; kwargs...)
# first check whether there is an `rrule` which handles this directly
direct = rrule(config, f_args...; kwargs...)
f = f_args[1]
direct === nothing || (@info "rrule shortcut" f; return direct)
# create a closure to work around _pullback not accepting kwargs
# but avoid creating a closure unnecessarily (pullbacks of closures do not infer)
y, pb = if !isempty(kwargs)
kwf() = first(f_args)(Base.tail(f_args)...; kwargs...)
_y, _pb = _pullback(config.context, kwf)
_y, Δ -> first(_pb(Δ)).f_args # `first` should be `only`
else
_pullback(config.context, f_args...)
end
ad_pullback(Δ) = zygote2differential(pb(wrap_chainrules_output(Δ)), f_args)
return y, ad_pullback
end;
# Note that the T == Bool path is called many times, no @info here
julia> @eval Zygote @adjoint function broadcasted(::AbstractArrayStyle, f::F, args...) where {F}
T = Broadcast.combine_eltypes(f, args)
# Avoid generic broadcasting in two easy cases:
if T == Bool
return (f.(args...), _ -> nothing)
elseif T <: Real && isconcretetype(T) && _dual_purefun(F) && all(_dual_safearg, args) && !isderiving()
@info "Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> broadcast_forward" f
return broadcast_forward(f, args...)
end
len = inclen(args)
@info "Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks" f
y∂b = _broadcast((x...) -> _pullback(__context__, f, x...), args...)
y = map(first, y∂b)
function ∇broadcasted(ȳ)
dxs_zip = map(((_, pb), ȳ₁) -> pb(ȳ₁), y∂b, ȳ)
dxs = ntuple(len) do i
collapse_nothings(map(StaticGetter{i}(), dxs_zip))
end
(nothing, accum_sum(dxs[1]), map(unbroadcast, args, Base.tail(dxs))...)
end
return y, ∇broadcasted
end
# Random easy test
julia> gradient(xs -> sum((x -> sin(x)).(xs)), [1,2,3]/4)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> broadcast_forward
└ f = #6 (generic function with 1 method)
([0.9689124217106447, 0.8775825618903728, 0.7316888688738209],)
# MWE from above
julia> @time gradient(loss, params_dic)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_atom (generic function with 1 method)
┌ Debug: split broadcasting generic
│ f = inject_interaction (generic function with 7 methods)
│ N = 2
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:126
┌ Debug: split broadcasting generic
│ f = inject_interaction_list (generic function with 4 methods)
│ N = 3
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:126
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_interaction (generic function with 7 methods)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_interaction (generic function with 7 methods)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_interaction (generic function with 7 methods)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_interaction (generic function with 7 methods)
┌ Debug: split broadcasting generic
│ f = inject_interaction (generic function with 7 methods)
│ N = 3
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:126
ERROR: MethodError: no method matching +(::Dict{Any, Any}, ::Dict{Any, Any}) (Error as above.) With the opt-out, it's the second broadcast above, with julia> ChainRulesCore.@opt_out rrule(cfg::Zygote.ZygoteRuleConfig, ::typeof(Broadcast.broadcasted), ::Broadcast.BroadcastStyle, f::F, args::Vararg{Any, N}) where {F, N}
julia> gradient(xs -> sum((x -> sin(x)).(xs)), [1,2,3]/4)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> broadcast_forward
└ f = #18 (generic function with 1 method)
([0.9689124217106447, 0.8775825618903728, 0.7316888688738209],)
julia> gradient((xs, y) -> sum((x -> sin(x/y)).(xs)), [1,2,3]/4, 5/6)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = #22 (generic function with 1 method)
([1.146403786950727, 0.9904027378916139, 0.7459319619247974], -1.609501544552504)
julia> @time gradient(loss, params_dic)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_atom (generic function with 1 method)
ERROR: Need an adjoint for constructor Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}, Nothing, typeof(Molly.inject_interaction), Tuple{Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Tuple{Dict{String, Float64}}, Tuple{System{3, true, Float64, false, Vector{Atom{Float64, Float64, Float64, Float64}}, Vector{AtomData}, Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{InteractionList2Atoms{Vector{HarmonicBond{Float64, Float64}}}, InteractionList3Atoms{Vector{HarmonicAngle{Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}}, Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Vector{SVector{3, Float64}}, Vector{SVector{3, Float64}}, CubicBoundary{Float64}, DistanceVecNeighborFinder{Float64, BitMatrix, Matrix{Int64}}, Tuple{}, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}, Float64}}}}. Gradient is of type Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}}
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] (::Zygote.Jnew{Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}, Nothing, typeof(Molly.inject_interaction), Tuple{Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Tuple{Dict{String, Float64}}, Tuple{System{3, true, Float64, false, Vector{Atom{Float64, Float64, Float64, Float64}}, Vector{AtomData}, Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{InteractionList2Atoms{Vector{HarmonicBond{Float64, Float64}}}, InteractionList3Atoms{Vector{HarmonicAngle{Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}}, Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Vector{SVector{3, Float64}}, Vector{SVector{3, Float64}}, CubicBoundary{Float64}, DistanceVecNeighborFinder{Float64, BitMatrix, Matrix{Int64}}, Tuple{}, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}, Float64}}}}, Nothing, false})(Δ::Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}})
@ Zygote ~/.julia/packages/Zygote/D7j8v/src/lib/lib.jl:326 So plausibly it's a failure of the opt-out mechanism? Where above it uses the |
Now I see. The CR rule accepts any BroadcastStyle, to handle tuples too, while the Zygote one restricts to AbstractArrayStyle. The cases where the CR rule is called all have julia> @eval ChainRules function rrule(cfg::RCR, ::typeof(broadcasted), style::BroadcastStyle, f::F, args::Vararg{Any,N}) where {F,N}
@debug "called the rrule!" f style
T = Broadcast.combine_eltypes(f, args)
if T === Bool # TODO use nondifftype here
# 1: Trivial case: non-differentiable output, e.g. `x .> 0`
@debug("split broadcasting trivial", f, T)
bc_trivial_back(_) = (TRI_NO..., ntuple(Returns(ZeroTangent()), length(args))...)
return f.(args...), bc_trivial_back
elseif T <: Number && may_bc_derivatives(T, f, args...)
# 2: Fast path: use arguments & result to find derivatives.
return split_bc_derivatives(f, args...)
elseif T <: Number && may_bc_forwards(cfg, f, args...)
# 3: Future path: use `frule_via_ad`?
return split_bc_forwards(cfg, f, args...)
else
# 4: Slow path: collect all the pullbacks & apply them later.
return split_bc_pullbacks(cfg, f, args...)
end
end
rrule (generic function with 1065 methods)
julia> @time gradient(loss, params_dic)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_atom (generic function with 1 method)
┌ Debug: called the rrule!
│ f = inject_interaction (generic function with 7 methods)
│ style = Base.Broadcast.Style{Tuple}()
└ @ ChainRules REPL[37]:2
┌ Debug: split broadcasting generic
│ f = inject_interaction (generic function with 7 methods)
│ N = 2
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:126
┌ Debug: called the rrule!
│ f = inject_interaction_list (generic function with 4 methods)
│ style = Base.Broadcast.Style{Tuple}()
└ @ ChainRules REPL[37]:2
┌ Debug: split broadcasting generic
│ f = inject_interaction_list (generic function with 4 methods)
│ N = 3
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:126
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_interaction (generic function with 7 methods)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_interaction (generic function with 7 methods)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_interaction (generic function with 7 methods)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_interaction (generic function with 7 methods)
┌ Debug: called the rrule!
│ f = inject_interaction (generic function with 7 methods)
│ style = Base.Broadcast.Style{Tuple}()
└ @ ChainRules REPL[37]:2
┌ Debug: split broadcasting generic
│ f = inject_interaction (generic function with 7 methods)
│ N = 3
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:126
ERROR: MethodError: no method matching +(::Dict{Any, Any}, ::Dict{Any, Any}) |
This is not solved by changing the julia> @time gradient(loss, params_dic)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_atom (generic function with 1 method)
ERROR: Need an adjoint for constructor Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}, Nothing, typeof(Molly.inject_interaction), Tuple{Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Tuple{Dict{String, Float64}}, Tuple{System{3, true, Float64, false, Vector{Atom{Float64, Float64, Float64, Float64}}, Vector{AtomData}, Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{InteractionList2Atoms{Vector{HarmonicBond{Float64, Float64}}}, InteractionList3Atoms{Vector{HarmonicAngle{Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}}, Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Vector{SVector{3, Float64}}, Vector{SVector{3, Float64}}, CubicBoundary{Float64}, DistanceVecNeighborFinder{Float64, BitMatrix, Matrix{Int64}}, Tuple{}, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}, Float64}}}}. Gradient is of type Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}}
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] (::Zygote.Jnew{Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}, Nothing, typeof(Molly.inject_interaction), Tuple{Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Tuple{Dict{String, Float64}}, Tuple{System{3, true, Float64, false, Vector{Atom{Float64, Float64, Float64, Float64}}, Vector{AtomData}, Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{InteractionList2Atoms{Vector{HarmonicBond{Float64, Float64}}}, InteractionList3Atoms{Vector{HarmonicAngle{Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}}, Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Vector{SVector{3, Float64}}, Vector{SVector{3, Float64}}, CubicBoundary{Float64}, DistanceVecNeighborFinder{Float64, BitMatrix, Matrix{Int64}}, Tuple{}, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}, Float64}}}}, Nothing, false})(Δ::Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}}) |
This is a Zygote bug. I wish I could transfer this issue there |
As discussed on Slack, there is an issue with dictionaries appearing in gradients. The following is as minimal an example as I could make.
This requires Molly master, Zygote master, ChainRules 1.44.2 and I am using Julia 1.7.2. The file
ala5.pdb
should be put in the current directory and is pasted below. The
ala5.pdb
file: On ChainRules up to 1.42.0 this worked, on 1.43.0-1.44.1 it errors with a different error fixed by #661, and on 1.44.2 it errors as follows:
Adding
ChainRulesCore.@opt_out rrule(cfg::Zygote.ZygoteRuleConfig, ::typeof(Broadcast.broadcasted), ::Broadcast.BroadcastStyle, f::F, args::Vararg{Any, N}) where {F, N}
to Molly as suggested by @mcabbott gives a different error: Commenting out either the
"inter_LJ_weight_14" => 0.5,
or "inter_CO_weight_14" => 0.5,
lines makes it work, presumably because no dictionaries have to be added in the case of one gradient. The text was updated successfully, but these errors were encountered: