
calculating 2nd order differentials of a function containing a NN w.r.t. to parameters #911

Open
stash-196 opened this issue Feb 25, 2021 · 4 comments
Labels
second order zygote over zygote, or otherwise

Comments

stash-196 commented Feb 25, 2021

Do you support any way to calculate the 2nd order differential of a function containing a Neural Network, w.r.t. the parameters?

For instance, consider a loss function defined as below.

using Flux: Chain, Dense, σ, crossentropy, params
using Zygote
model = Chain(
    x -> reshape(x, :, size(x, 4)),
    Dense(2, 5),
    Dense(5, 1),
    x -> σ.(x)
)
n_data = 5
input = randn(2, 1, 1, n_data)
target = randn(1, n_data)
loss = model -> crossentropy(model(input), target)

My goal is to obtain the hessian of loss(model) w.r.t. Flux.params(model). Ideally I would like to have a function like hessian(loss, model), which is currently not supported in Zygote.
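(For context: Zygote.hessian does exist for scalar functions of a flat array argument; the obstacle here is only that `model` is a nested struct rather than a vector. A minimal sketch of the already-supported case, with a function chosen so the Hessian is known:)

```julia
using LinearAlgebra: I
using Zygote

# Supported case: a scalar function of a flat vector.
f(v) = sum(abs2, v) / 2          # f(v) = ‖v‖²/2, so the Hessian is the identity
h = Zygote.hessian(f, [1.0, 2.0, 3.0])
h ≈ I(3)                         # true
```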

To construct this, I tried some approaches combining Zygote.gradient and Zygote.jacobian (added in #890).

  1. Of course, simply combining the two did not work:
zygrad = model -> Zygote.gradient(loss, model)
zyjacob = model -> Zygote.jacobian(zygrad, model)
zyjacob(model)    # ERROR: ArgumentError: jacobian expected a function which returns an array, or a scalar, got Tuple{NamedTuple{(:layers,),Tuple{Tuple{Nothing,NamedTuple{(:W, :b, :σ),Tuple{Array{Float64,2},Array{Float64,1},Nothing}},NamedTuple{(:W, :b, :σ),Tuple{Array{Float64,2},Array{Float64,1},Nothing}},Nothing}}}}

# or more explicitly, obtaining a jacobian of one of the weight matrices
zygrad_w1 = model -> Zygote.gradient(loss, model)[1][1][2][1]
zyjacob_w1 = model -> Zygote.jacobian(zygrad_w1, model)
zyjacob_w1(model)    # ERROR: Can't differentiate foreigncall expression
  2. ...or combining with an implicit gradient:
zygrad_implicit = θ -> Zygote.gradient(() -> loss(model), θ)[θ[1]]
zyjacob = θ -> Zygote.jacobian(() -> zygrad_implicit(θ), θ)
zyjacob(params(model))    # ERROR: Can't differentiate foreigncall expression

# or a different combination...
zyjacob = θ -> Zygote.jacobian(zygrad_implicit, θ)
zyjacob(params(model))    # ERROR: MethodError: no method matching (::var"#75#77")()
  3. So I tried building my own Jacobian from element-wise 2nd-order gradients, but that didn't work either:
zygrad_p1 = model -> Zygote.gradient(loss, model)[1][1][2][1][1]    # First element
zygrad2 = model -> Zygote.gradient(zygrad_p1, model)

I'm still learning the framework, but after working on this for days I've exhausted my ideas for a workaround.

Calculating 2nd-order differentials is essential in my field of work, and I would like to know whether Zygote will support this in the near future, and whether there are any known workarounds.

@stash-196 (Author)

For reference, until now I've been using a hessian_wrt_all_params(func, model) implemented in PyTorch.

I just wanted to implement the same thing in Julia.

@mcabbott (Member)

I don't think there's a way to do this with Zygote's implicit parameter dictionary, params(model). Or at least, nobody has written one. But if I'm reading correctly, your Python code steps through the parameters but builds up the entire Hessian matrix, which isn't something it makes sense to store per-parameter, since it has many off-diagonal blocks. If you need this whole matrix, you can get it by doing something like this:

# here x, y are the data, e.g. x, y = input, target from above
v, re = Flux.destructure(model)        # length(v) == sum(length, params(model))
loss(model, x, y) = sum(abs2.(model(x) .- y))
g = Zygote.gradient(v -> loss(re(v), x, y), v)[1]  # length(g) == length(v)
h = Zygote.hessian(v -> loss(re(v), x, y), v)      # size(h) == (length(v), length(v))
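(To relate the flat Hessian back to individual parameters, one can slice it using each parameter's offset inside v. A hedged sketch: the small model, data shapes, and the `block` helper below are made up for illustration, and it is assumed that destructure flattens parameters in the same order that params iterates them:)

```julia
using Flux, Zygote

# A small stand-in model and data, mirroring the shapes in this thread.
model = Chain(Dense(2, 5), Dense(5, 1, σ))
x, y = randn(2, 5), randn(1, 5)
sqloss(m) = sum(abs2.(m(x) .- y))

v, re = Flux.destructure(model)              # flat parameter vector + rebuilder
h = Zygote.hessian(w -> sqloss(re(w)), v)    # full (length(v), length(v)) Hessian

# Offsets of each parameter within v (assumes destructure/params agree on order).
lens = [length(p) for p in Flux.params(model)]
offsets = cumsum([0; lens])
# Diagonal block of the Hessian for the i-th parameter:
block(i) = h[offsets[i]+1:offsets[i+1], offsets[i]+1:offsets[i+1]]
size(block(1))    # (10, 10), since the first Dense weight is 5×2
```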

@DhairyaLGandhi (Member)

Ref #823


stash-196 commented May 20, 2021

@mcabbott Great! The entire Hessian matrix is exactly what I needed in the end.

The Python code looks messy because every gradient method I could find computed per layer (resulting in blocks), and whenever I tried flattening the parameters first it would break. Once I got it to work, I didn't dare touch it.

I never found Flux.destructure() during my investigation. This changes everything. I will try it. Thank you! :)

@mcabbott mcabbott added the second order zygote over zygote, or otherwise label Jul 22, 2022