`nothing` in output of a pullback #1464
The problem can be fixed by modifying `src/lib/broadcast.jl`:

```diff
@@ -295,9 +295,14 @@ end
   y = broadcast(x -> value(x), out)
   function bc_fwd_back(ȳ)
     dargs = ntuple(valN) do i
-      unbroadcast(args[i], broadcast((y1, o1) -> y1 * partials(o1,i), ȳ, out))
+      unbroadcast(args[i], broadcast((y1, o1) -> y1 === nothing ? nothing : y1 * partials(o1,i), ȳ, out))
+    end
+    # Collapse all `nothing`
+    if dargs isa Tuple{Vararg{Nothing}}
+      return nothing
+    else
+      (nothing, nothing, dargs...) # nothings for broadcasted & f
     end
-    (nothing, nothing, dargs...) # nothings for broadcasted & f
   end
   return y, bc_fwd_back
 end
```

Similar fixes could (should?) be applied to many more functions that currently assume the input to the pullback is entirely numeric. That assumption breaks in cases such as the example above, when dealing with arrays where some elements are `nothing`.

Unfortunately, one additional fix is still required: summation of the broadcast results in …
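To make the two changes in the patch concrete outside of Zygote, here is a minimal, self-contained sketch in plain Julia. `scale_or_nothing` and `collapse_grads` are hypothetical helpers, not Zygote functions; they mirror the `y1 === nothing ? nothing : ...` guard and the `Tuple{Vararg{Nothing}}` check from the diff. The final reduction assumes the summation step is a plain `sum`, and only illustrates why `nothing` elements still need handling there.

```julia
# Hypothetical helpers (not Zygote's API) mirroring the patch above.
# Zygote uses `nothing` to represent a zero/absent gradient.

# Multiply a cotangent element by a partial, propagating `nothing`
# instead of calling `nothing * p`, which has no method.
scale_or_nothing(y1, p) = y1 === nothing ? nothing : y1 * p

# If every argument gradient is `nothing`, return a single `nothing`;
# otherwise prepend the two `nothing`s for `broadcasted` and `f`.
collapse_grads(dargs::Tuple{Vararg{Nothing}}) = nothing
collapse_grads(dargs::Tuple) = (nothing, nothing, dargs...)

# A cotangent array in which one element carries no gradient.
ȳ = Union{Nothing,Float64}[1.0, nothing, 3.0]
ps = [2.0, 2.0, 2.0]  # stand-in for the per-element partials

# `ȳ .* ps` would throw a MethodError; the guarded version does not.
scaled = scale_or_nothing.(ȳ, ps)  # [2.0, nothing, 6.0]

# A plain `sum(scaled)` still fails on `2.0 + nothing`,
# so treat `nothing` as zero when reducing.
total = sum(x -> x === nothing ? 0.0 : x, scaled)  # 8.0

collapse_grads((nothing, nothing))  # nothing
collapse_grads((nothing, [1.0]))    # (nothing, nothing, nothing, [1.0])
```

The guard only moves the problem one step downstream: any reduction over `scaled` has to be `nothing`-aware as well.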
One idea for a more general solution would be to add an overload in ZygoteRules here that collapses …
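The general idea can be sketched without touching Zygote: one helper maps an all-`nothing` tuple to `nothing` and is applied once around a pullback, rather than repeating the check inside every rule. This is a minimal illustration assuming gradients are returned as tuples; `collapse_nothing`, `with_collapse` and `mul_pullback` are hypothetical names, not existing ZygoteRules functions, and where such an overload would actually live in ZygoteRules is not specified here.

```julia
# Hypothetical sketch: none of these names exist in ZygoteRules.

# A tuple whose entries are all `nothing` carries no gradient information,
# so it is replaced by a single `nothing`; anything else passes through.
collapse_nothing(x::Tuple{Vararg{Nothing}}) = nothing
collapse_nothing(x) = x

# Wrap a pullback so the collapse happens in one place for every rule.
with_collapse(back) = Δ -> collapse_nothing(back(Δ))

# Example rule for f(a, b) = a * b whose raw pullback returns a tuple of
# `nothing`s when the incoming cotangent is `nothing`.
function mul_pullback(a, b)
    raw_back(Δ) = Δ === nothing ? (nothing, nothing) : (Δ * b, Δ * a)
    return with_collapse(raw_back)
end

back = mul_pullback(2.0, 3.0)
back(1.0)      # (3.0, 2.0)
back(nothing)  # nothing
```

Dispatching on `Tuple{Vararg{Nothing}}` keeps the non-`nothing` path untouched, so rules that already return numeric gradients behave exactly as before.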
I'm opening this to track an issue that was discussed across different repos.
According to @ToucheSir, this should not happen: JuliaStats/Distances.jl#256 (comment)
This problem was exposed by the ChainRules 1.53.0 update.
Simple reproducer: