
Wrong gradient involving splatting of kwargs #1284

Closed
niklasschmitz opened this issue Aug 7, 2022 · 7 comments · Fixed by #1286
@niklasschmitz commented Aug 7, 2022

Zygote (v0.6.43) currently gives a wrong gradient involving kwargs splatting. It seems to double-count a gradient contribution through implicit and explicit kwargs. Here's a small example:

using FiniteDiff, ForwardDiff, Zygote

f1(; kwargs...) = kwargs[:x]
f2(; kwargs...) = f1(; kwargs..., x=kwargs[:x])
f3(x) = f2(; x)
FiniteDiff.finite_difference_derivative(f3, 0.0) # 1.0
ForwardDiff.derivative(f3, 0.0) # 1.0
Zygote.gradient(f3, 0.0) # (2.0,)
@mcabbott added the bug label on Aug 7, 2022
@oxinabox (Member) commented Aug 7, 2022

I wonder if this is to do with kwargs[:x] showing up both in the splatted kwargs... and in the explicit x=kwargs[:x]?

@DhairyaLGandhi (Member) commented Aug 7, 2022 via email

@niklasschmitz (Author)

Interestingly, the issue persists even after eliminating this redundancy:

using FiniteDiff, ForwardDiff, Zygote

f1(; kwargs...) = kwargs[:x]
f2(; kwargs...) = f1(; x=kwargs[:x])
f3(x) = f2(; x)
FiniteDiff.finite_difference_derivative(f3, 0.0) # 1.0
ForwardDiff.derivative(f3, 0.0) # 1.0
Zygote.gradient(f3, 0.0) # (2.0,)

@ToucheSir (Member) commented Aug 8, 2022

This turned out to be a fun issue 😱 . In short, kwargs are represented as Pairs{..., NamedTuple}. This is an immutable type, but because Pairs <: AbstractDict it triggers the getindex adjoint for mutable dicts. Since objectid works by value and not by reference on immutable types, that means any set of keyword arguments with the same structure + arg types would accumulate to the same gradient.
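
To make that concrete, here is a minimal sketch of the value-based identity described above, outside of Zygote (the kw helper is just for illustration, not part of the issue):

kw(; kwargs...) = kwargs          # returns the underlying Base.Pairs object

a = kw(x = 0.0)
b = kw(x = 0.0)                   # built in a separate call, same keys and values

a isa Base.Pairs                  # true
Base.Pairs <: AbstractDict        # true, so the mutable-Dict getindex adjoint applies
objectid(a) == objectid(b)        # true: objectid is by value for immutable types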

Given all that, I can't help but wonder if this has been causing other mysterious bugs in the wild. Working on a PR that should hopefully be up soon.

@ChrisRackauckas (Member)
This seems to cause downstream failures. See

function multiple_shoot(
    p::AbstractArray,
    ode_data::AbstractArray,
    tsteps::AbstractArray,
    ensembleprob::EnsembleProblem,
    ensemblealg::SciMLBase.BasicEnsembleAlgorithm,
    loss_function,
    continuity_loss,
    solver::DiffEqBase.AbstractODEAlgorithm,
    group_size::Integer;
    continuity_term::Real=100,
    kwargs...
)
    datasize = size(ode_data, 2)
    prob = ensembleprob.prob

    if group_size < 2 || group_size > datasize
        throw(DomainError(group_size, "group_size can't be < 2 or > number of data points"))
    end

    @assert ndims(ode_data) == 3 "ode_data must have three dimension: `size(ode_data) = (problem_dimension,length(tsteps),trajectories)"
    @assert size(ode_data,2) == length(tsteps)
    @show kwargs
    @assert size(ode_data,3) == kwargs[:trajectories]

This is then called like:

function loss_multiple_shooting_ens(p)
    return multiple_shoot(p, ode_data_ensemble, tsteps, ensemble_prob, ensemble_alg,
                          loss_function, Tsit5(),
                          group_size; continuity_term,
                          trajectories,
                          abstol=1e-8, reltol=1e-6) # test solver kwargs
end
kwargs = Base.Pairs{Symbol, Real, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:trajectories, :abstol, :reltol), Tuple{Int64, Float64, Float64}}}(:trajectories => 2, :abstol => 1.0e-8, :reltol => 1.0e-6)
ERROR: MethodError: no method matching getindex(::Nothing, ::Int64)
Stacktrace:
  [1] (::Zygote.var"#kwargs_literal_getindex_pullback#326"{Zygote.var"#1925#back#218"{Zygote.var"#back#217"{:trajectories, Zygote.Context{false}, NamedTuple{(:trajectories, :abstol, :reltol), Tuple{Int64, Float64, Float64}}, Int64}}})(Δ::Nothing)
    @ Zygote C:\Users\accou\.julia\packages\Zygote\qGFGD\src\lib\base.jl:165
  [2] Pullback
    @ c:\Users\accou\.julia\packages\DiffEqFlux\Em1Aj\src\multiple_shooting.jl:185 [inlined]

The line that errors is:

@assert size(ode_data,3) == kwargs[:trajectories]

My Symbol seems to be transformed into an integer, and the kwargs into nothing?
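
For what it's worth, a stripped-down probe of the same pattern is sketched below: differentiating through a function whose only use of its kwargs is a literal kwargs[:trajectories] lookup. The g here is a hypothetical stand-in, not the DiffEqFlux code, and whether it hits the same MethodError will depend on the Zygote version:

using Zygote

# g reads the non-differentiable :trajectories kwarg and otherwise computes
# a smooth function of p, mirroring the multiple_shoot pattern above.
g(p; kwargs...) = kwargs[:trajectories] == 2 ? sum(abs2, p) : zero(eltype(p))

# Differentiate with respect to p only; the kwarg should contribute no gradient.
Zygote.gradient(p -> g(p; trajectories = 2), [1.0, 2.0])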

https://github.com/SciML/SciMLSensitivity.jl/runs/8028243222?check_suite_focus=true

using DiffEqFlux, OrdinaryDiffEq, Test

datasize = 30
u0 = Float32[2.0, 0.0]
tspan = (0.0f0, 5.0f0)
tsteps = range(tspan[1], tspan[2], length=datasize)

function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u .^ 3)'true_A)'
end
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat=tsteps))


nn = FastChain((x, p) -> x .^ 3,
    FastDense(2, 16, tanh),
    FastDense(16, 2))
p_init = initial_params(nn)

neuralode = NeuralODE(nn, tspan, Tsit5(), saveat=tsteps)
prob_node = ODEProblem((u, p, t) -> nn(u, p), u0, tspan, p_init)

function loss_function(data, pred)
    return sum(abs2, data - pred)
end

u0s = [Float32[2.0, 0.0], Float32[3.0, 1.0]]
function prob_func(prob, i, repeat)
    remake(prob, u0=u0s[i])
end
ensemble_prob = EnsembleProblem(prob_node, prob_func=prob_func)
ensemble_prob_trueODE = EnsembleProblem(prob_trueode, prob_func=prob_func)
ensemble_alg = EnsembleThreads()
trajectories = 2
ode_data_ensemble = Array(solve(ensemble_prob_trueODE, Tsit5(), ensemble_alg, trajectories=trajectories, saveat=tsteps))

group_size = 3
continuity_term = 200
function loss_multiple_shooting_ens(p)
    return multiple_shoot(p, ode_data_ensemble, tsteps, ensemble_prob, ensemble_alg,
        loss_function, Tsit5(),
        group_size; continuity_term,
        trajectories,
        abstol=1e-8, reltol=1e-6) # test solver kwargs
end

res_ms_ensembles = DiffEqFlux.sciml_train(loss_multiple_shooting_ens, neuralode.p,
    ADAM(0.05), maxiters=300)

@ToucheSir (Member)
Try #1295 on for size.

@ChrisRackauckas (Member)
That fixes it.
