
usage of isposdef leads to cryptic error when taking gradients #1240

Open
Maximilian-Stefan-Ernst opened this issue Jun 9, 2022 · 3 comments
Labels
needs adjoint missing rule

Comments

@Maximilian-Stefan-Ernst

Not sure if this belongs here or in ChainRules, but taking the gradient of any function that calls isposdef fails with ERROR: MethodError: no method matching iterate(::Nothing), even when the isposdef call does not depend on the argument being differentiated. See this MWE:

using Zygote, LinearAlgebra

# a = a*a' is symmetric and (almost surely) positive definite
a = rand(3,3); a = a*a'

# works fine
function f(x)
    return x
end

gradient(f, 1.0)

# errors
function f(x)
    isposdef(a)
    return x
end

gradient(f, 1.0)

ERROR: MethodError: no method matching iterate(::Nothing)
Closest candidates are:
  iterate(::Union{LinRange, StepRangeLen}) at /usr/share/julia/base/range.jl:826
  iterate(::Union{LinRange, StepRangeLen}, ::Integer) at /usr/share/julia/base/range.jl:826
  iterate(::T) where T<:Union{Base.KeySet{<:Any, <:Dict}, Base.ValueIterator{<:Dict}} at /usr/share/julia/base/dict.jl:695
  ...
Stacktrace:
  [1] indexed_iterate(I::Nothing, i::Int64)
    @ Base ./tuple.jl:92
  [2] chain_rrule_kw
    @ ~/.julia/packages/Zygote/DkIUK/src/compiler/chainrules.jl:229 [inlined]
  [3] macro expansion
    @ ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0 [inlined]
  [4] _pullback(::Zygote.Context, ::LinearAlgebra.var"#cholesky##kw", ::NamedTuple{(:check,), Tuple{Bool}}, ::typeof(cholesky), ::Hermitian{Float64, Matrix{Float64}})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:9
  [5] _pullback
    @ /usr/share/julia/stdlib/v1.7/LinearAlgebra/src/dense.jl:92 [inlined]
  [6] _pullback(ctx::Zygote.Context, f::typeof(isposdef), args::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
  [7] _pullback
    @ ~/Downloads/test_zygote.jl:151 [inlined]
  [8] _pullback(ctx::Zygote.Context, f::typeof(f), args::Float64)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
  [9] _pullback(f::Function, args::Float64)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:34
 [10] pullback(f::Function, args::Float64)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:40
 [11] gradient(f::Function, args::Float64)
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface.jl:75
 [12] top-level scope
    @ ~/Downloads/test_zygote.jl:155

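The problem seems to be the combination of isposdef and Hermitian: as frame [5] of the stacktrace shows, isposdef on a matrix calls cholesky(Hermitian(A); check = false), and this fails too: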

function f(x)
    isposdef(cholesky(Hermitian(a); check = false))
    return x
end

gradient(f, 1.0)
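To see why the failure reduces to cholesky(Hermitian(...)), the check below compares isposdef(a) with what it appears to call internally. This is a sketch that assumes the LinearAlgebra definition suggested by frames [4] and [5] (dense.jl); the name via_cholesky is purely illustrative, and no AD is involved here.

```julia
using LinearAlgebra

a = rand(3, 3); a = a * a'

# Assumed equivalent of isposdef(::AbstractMatrix): a symmetry check followed
# by a non-throwing Cholesky factorization of the Hermitian wrapper
via_cholesky = ishermitian(a) && isposdef(cholesky(Hermitian(a); check = false))

via_cholesky == isposdef(a)
```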

But interestingly enough, both of these work just fine:

function f(x)
    isposdef(cholesky(Symmetric(a); check = false))
    return x
end

gradient(f, 1.0)

function f(x)
    Hermitian(a)
    return x
end

gradient(f, 1.0)
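Until a rule exists, one possible user-side workaround is to hide the check from AD entirely, since isposdef only returns a Bool. This is a sketch that assumes a ChainRulesCore version providing ignore_derivatives and a Zygote version that honors it; it has not been checked against the exact package versions listed below.

```julia
using Zygote, LinearAlgebra
using ChainRulesCore: ignore_derivatives

a = rand(3, 3); a = a * a'

function f(x)
    # ignore_derivatives evaluates the closure normally but tells AD not to
    # trace into it, so Zygote never reaches cholesky(Hermitian(a); check = false)
    ignore_derivatives(() -> isposdef(a))
    return x
end

gradient(f, 1.0)
```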

Version info:

Julia Version 1.7.3
Commit 742b9abb4d (2022-05-06 12:58 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Core(TM) i5-8265U CPU @ 1.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-12.0.1 (ORCJIT, skylake)
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 1

Package versions:

  [082447d4] ChainRules v1.35.1
  [d360d2e6] ChainRulesCore v1.15.0
  [a93c6f00] DataFrames v1.3.4
  [31c24e10] Distributions v0.25.62
  [6a86dc24] FiniteDiff v2.12.1
  [f6369f11] ForwardDiff v0.10.30
  [d3d80556] LineSearches v7.1.1
  [d41bc354] NLSolversBase v7.8.2
  [76087f3c] NLopt v0.6.5
  [429524aa] Optim v1.7.0
  [08abe8d2] PrettyTables v1.3.1
  [2913bbd2] StatsBase v0.33.16
  [78862bba] StenoGraphs v0.2.0
  [0c5d862f] Symbolics v4.6.0
  [e88e6eb3] Zygote v0.6.40
  [8bb1440f] DelimitedFiles
  [4af54fe1] LazyArtifacts
  [37e2e46d] LinearAlgebra
  [44cfe95a] Pkg
  [9a3f8284] Random
  [2f01184e] SparseArrays
  [10745b16] Statistics
@ToucheSir
Member

I suspect a @non_differentiable isposdef(...) rule in ChainRules would be sufficient to resolve this. Certainly there's no point in AD digging into that function, since it only returns a Bool.
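A sketch of what such a rule could look like, declared from user code rather than inside ChainRules. Defining it outside ChainRules is type piracy, so this is only a stopgap; it assumes your installed ChainRules does not already ship this rule (if it does, the definition below simply overwrites it).

```julia
using Zygote, LinearAlgebra
using ChainRulesCore

# Stopgap rule (type piracy): mark isposdef as having no derivative, so AD
# calls it directly instead of tracing into cholesky(Hermitian(...))
@non_differentiable isposdef(::Any)

a = rand(3, 3); a = a * a'

function f(x)
    isposdef(a)   # Zygote now uses the rrule generated above
    return x
end

gradient(f, 1.0)
```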

@devmotion
Collaborator

The cholesky(Hermitian(...)) issue should probably be fixed by #1114 on the master branch.

@ToucheSir
Member

At least on nightly, the test for logdet(cholesky(Hermitian(...))) is still failing: https://github.com/FluxML/Zygote.jl/runs/6971643805?check_suite_focus=true
