Add missing overload for ZeroTangent #256

Closed

simsurace

This seems to be missing as the unthunked tangent can sometimes be a ZeroTangent. Does this make sense, @devmotion?

@codecov-commenter

codecov-commenter commented Oct 6, 2023

Codecov Report

Attention: 1 lines in your changes are missing coverage. Please review.

Files Coverage Δ
ext/DistancesChainRulesCoreExt.jl 98.75% <0.00%> (+98.75%) ⬆️

... and 9 files with indirect coverage changes

@devmotion
Member

It seems a bit restrictive, indeed. But did you run into an actual problem, or is it just a potential issue? In the latter case I would hold off on changes until someone encounters this case in an application. I also thought that the ZeroTangent case was supposed to be handled by the AD backend since it seems in this case you could always optimize away the whole pullback (but maybe I'm misremembering something? @oxinabox @sethaxen).

ChainRulesTestUtils passed, so if this is a bug, I think it would be good to check this ZeroTangent() case in the tests in general.

@simsurace
Author

I linked an issue in KernelFunctions.jl that hits this when the ZygoteDistancesExt is deactivated. It could be that the AD tests in KernelFunctions.jl are doing something nonstandard, though.

@sethaxen

sethaxen commented Oct 6, 2023

> I also thought that the ZeroTangent case was supposed to be handled by the AD backend since it seems in this case you could always optimize away the whole pullback (but maybe I'm misremembering something? @oxinabox @sethaxen).

Yes, the ChainRulesCore docs say so: https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/writing_good_rules.html#Ensure-your-pullback-can-accept-the-right-types. If Zygote is erroring, it is probably best to open an issue there, or to check whether one has already been opened.

@devmotion
Member

Based on the ChainRules docs, my impression is that there's no problem with the rules in Distances but possibly a Zygote issue.

@devmotion devmotion closed this Oct 6, 2023
@simsurace simsurace deleted the chainrules-zerotangent branch October 6, 2023 16:03
@simsurace
Author

Thanks for investigating this!

@ToucheSir

> If Zygote is erroring, perhaps best to open an issue there or see if one is already opened.

Zygote should currently be unthunking every thunk returned by an rrule before it reaches any other AD rules. It should also catch the case where a pullback is passed a zero type and skip calling that pullback, as the CRC docs say (https://github.com/FluxML/Zygote.jl/blob/cf7f7d08705d2787fa31bcf45bcca5447fd9a9a7/src/compiler/chainrules.jl#L214 handles this). Based on that, I would assume ΔΩ is never a thunk or a zero type in practice.

> Yes, the ChainRulesCore docs say so: ...

My interpretation of the CRC docs is that this may be a grey area. Yes, rules should not have to handle zeros passed in. But there is no reliable way to tell a priori whether a given thunk will evaluate to a zero type, so that needs to be handled somehow. Rules may also want to handle the case where they get a mix of zeros and non-zeros for a composite return type, e.g. ΔΩ isa Tuple{AbstractZero, Real}. For this particular case, it looks like some upstream rule is stuffing non-numeric zero types into ΔΩ (note the ::Matrix{Any} in (::DistancesChainRulesCoreExt.var"#pairwise_Euclidean_X_pullback#10"{Int64, Matrix{Float64}, Matrix{Float64}})(ΔΩ::Matrix{Any}) on FluxML/Zygote.jl#1460). Finding and solving that should be sufficient to fix this issue. I agree it's not something the Distances rrules should have to deal with, but the problem also doesn't appear to be Zygote ignoring zeros as speculated above.
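To illustrate the grey area, here is a minimal sketch of a pullback that unthunks its input and short-circuits when the result turns out to be a zero type. The function `sqnorm` and its rule are hypothetical and are not taken from Distances.jl; only `rrule`, `unthunk`, `@thunk`, `AbstractZero`, `ZeroTangent` and `NoTangent` come from ChainRulesCore.

```julia
using ChainRulesCore

# Hypothetical example function (an assumption for illustration, not part of Distances.jl).
sqnorm(x::AbstractVector) = sum(abs2, x)

function ChainRulesCore.rrule(::typeof(sqnorm), x::AbstractVector)
    Ω = sqnorm(x)
    function sqnorm_pullback(Δ)
        # A thunk cannot be inspected without evaluating it, so unthunk first.
        ΔΩ = unthunk(Δ)
        # Only after unthunking can we tell whether the cotangent is a zero type.
        ΔΩ isa AbstractZero && return (NoTangent(), ZeroTangent())
        return (NoTangent(), 2 .* ΔΩ .* x)
    end
    return Ω, sqnorm_pullback
end

Ω, pb = rrule(sqnorm, [1.0, 2.0])
pb(1.0)                     # numeric cotangent: regular code path
pb(@thunk(ZeroTangent()))   # thunk that evaluates to a zero: short-circuits
```

Whether a guard like this belongs in the rule or in the AD backend is exactly the open question in this thread; the sketch only shows what handling it in the rule could look like.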

@devmotion
Member

devmotion commented Oct 6, 2023

I suspect the problem here (and the Matrix{Any}) is caused by JuliaDiff/ChainRules.jl#726: JuliaGaussianProcesses/KernelFunctions.jl#528

@simsurace
Author

Yes, the issue is absent when adding ChainRules@1.52.1, as mentioned in FluxML/Zygote.jl#1460 (comment)

@simsurace
Author

simsurace commented Oct 6, 2023

Basically, with ChainRules > v1.53.0:

julia> f(x) = iszero(x) ? zero(x) : x
f (generic function with 1 method)

julia> using Zygote

julia> Zygote.gradient(f, 0.0)
(nothing,)

whereas with ChainRules v1.52.1:

julia> f(x) = iszero(x) ? zero(x) : x
f (generic function with 1 method)

julia> using Zygote

julia> Zygote.gradient(f, 0.0)
(0.0,)

To my non-expert eyes it looks as if ZeroTangent() is being converted to nothing, which seems to cause issues in some rules that expect numerical types. Is this the core of the problem?

@ToucheSir

ToucheSir commented Oct 6, 2023

Zygote has been converting ZeroTangent to nothing since ChainRules support was added, mainly because Zygote only understands how to work with nothing internally. However, https://github.com/FluxML/Zygote.jl/blob/v0.6.65/src/compiler/chainrules.jl#L146-L165 exists precisely to make sure rrules see CR zeros and not nothing. So while gradient always converts CR zeros -> nothing because it's a top-level API, that doesn't explain how an intermediate rrule pullback is getting nothing instead of ZeroTangent or NoTangent. I'm afraid more digging is required 😅.
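As a rough sketch of the two-way conversion described above: the helper names `to_zygote` and `to_chainrules` are made up for illustration and are not Zygote's actual internals (see the linked source for those). The idea is that ChainRules zeros become `nothing` on the way out of a rule, and `nothing` becomes a ChainRules zero again before it is handed to an rrule pullback.

```julia
using ChainRulesCore

# Hypothetical helpers (assumptions for illustration, not Zygote's real functions).

# ChainRules -> Zygote: every zero type collapses to `nothing`.
to_zygote(::AbstractZero) = nothing
to_zygote(x) = x
to_zygote(t::Tuple) = map(to_zygote, t)

# Zygote -> ChainRules: `nothing` is turned back into a zero tangent
# so that an rrule pullback never sees `nothing`.
to_chainrules(::Nothing) = ZeroTangent()
to_chainrules(x) = x

to_zygote((NoTangent(), 1.0, ZeroTangent()))
to_chainrules(nothing)
```

If every hand-off went through a pair of conversions like this, a pullback would only ever see ChainRules zeros, which is why a `nothing` showing up inside a rule suggests that some path skips the conversion.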

@devmotion
Member

The failing examples all involve a broadcasting path with ForwardDiff.Duals. I assume there has always been a bug there (maybe specifically about the CR types, or maybe more generally), and in this specific example it was never a problem because the Matrix returned by the pullbacks was a nice Matrix{Float64}. But with JuliaDiff/ChainRules.jl#726, the examples now return a Matrix{Union{Float64,Nothing}} with nothings/ZeroTangents on the diagonal. So for this specific example JuliaDiff/ChainRules.jl#726 seems to be responsible for the test failures, but it seems there's a more general Zygote issue that it unmasked.
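A small Base-only sketch of why such a matrix is a problem for code that expects numbers. The matrix below is constructed by hand to mimic the shape described above; it is not produced by Zygote or ChainRules.

```julia
# Hand-built stand-in for a cotangent matrix with `nothing` on the diagonal.
Δ = Union{Float64,Nothing}[i == j ? nothing : 1.0 for i in 1:2, j in 1:2]

# Arithmetic on it fails because `+` is not defined for `Nothing`.
err = try
    sum(Δ)
catch e
    e
end
err isa MethodError   # true

# Replacing `nothing` by a numeric zero restores a plain Matrix{Float64}.
Δ_numeric = map(x -> x === nothing ? 0.0 : x, Δ)
sum(Δ_numeric)        # 2.0
```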

@ToucheSir

I'm not sure if we're referring to the same set of examples, but I've been focusing on the third one in FluxML/Zygote.jl#1460 (comment). That one shouldn't hit the Zygote broadcasting path which uses ForwardDiff, because it hits the CR rule at https://github.com/JuliaDiff/ChainRules.jl/blob/v1.55.0/src/rulesets/Base/mapreduce.jl#L76.

> ...but it seems there's a more general Zygote issue that it unmasked.

I think it is well known that Zygote's own AD rules have a lot of edge cases and don't handle data types off the beaten path nicely. That's why they're slowly being removed in favour of better-written rrules in ChainRules or elsewhere.
