test_add!!_behaviour has strong assumptions on fields #267
Comments
While broadcast would fix that specific case, I think a better option is a call to [...]. However, I am not 100% sure this is actually a problem.

    using ChainRulesCore, ChainRulesTestUtils, StaticArrays  # SArray is defined in StaticArrays

    function ChainRulesCore.rrule(::typeof(collect), x::X) where {S, T, N, L, X<:SArray{S, T, N, L}}
        # Structural tangent: the `data` field of an SArray is an NTuple{L, T}.
        collect_rrule(Δ::AbstractArray) = NoTangent(), Tangent{X}(data = Tangent{NTuple{L, T}}(ntuple(i -> Δ[i], Val(L))...))
        return collect(x), collect_rrule
    end

    A = SArray{Tuple{3, 1, 2}}(ntuple(i -> 3.5i, 6))
    test_rrule(collect, A)

Or, much simpler, use a natural tangent:

    using ChainRulesCore, ChainRulesTestUtils, StaticArrays

    function ChainRulesCore.rrule(::typeof(collect), x::X) where {S, T, N, L, X<:SArray{S, T, N, L}}
        # Natural tangent: pass the array cotangent straight through.
        collect_rrule(Δ::AbstractArray) = NoTangent(), Δ
        return collect(x), collect_rrule
    end

    A = SArray{Tuple{3, 1, 2}}(ntuple(i -> 3.5i, 6))
    test_rrule(collect, A)

Strictly speaking, one should use projection:

    using ChainRulesCore, ChainRulesTestUtils, StaticArrays

    function ChainRulesCore.rrule(::typeof(collect), x::X) where {S, T, N, L, X<:SArray{S, T, N, L}}
        y = collect(x)
        proj = ProjectTo(y)  # build the projector once, outside the pullback
        collect_rrule(Δ::AbstractArray) = NoTangent(), proj(Δ)
        return y, collect_rrule
    end

    A = SArray{Tuple{3, 1, 2}}(ntuple(i -> 3.5i, 6))
    test_rrule(collect, A)
Thanks for the investigation, it is really helpful! I have indeed started to look more into [...]
When calling test_rrule on a struct containing Tuple as its fields (e.g. StructArray or SArray), _test_cotangent will fail due to the impossibility of adding Tuple together.

Using broadcast on this line: https://github.com/JuliaDiff/ChainRulesCore.jl/blob/ed9a0073ff83cb3b1f4619303e41f4dd5d8c4825/src/tangent_types/tangent.jl#L301 would solve the issue I think.

Here is a MWE.
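Independently of that MWE, the underlying failure can be seen in plain Julia without any ChainRules packages. The sketch below is only an illustration of "adding Tuple together" versus broadcasting; the variable names `a`, `b`, `plus_fails`, and `summed` are made up for this example and do not come from the issue or from ChainRulesCore.

```julia
# Two tuples standing in for Tuple-valued struct fields (illustrative values).
a = (1.0, 2.0, 3.0)
b = (0.5, 0.5, 0.5)

# Base Julia defines no `+` method for two tuples, so `a + b` throws a MethodError.
plus_fails = try
    a + b
    false
catch err
    err isa MethodError
end

# Broadcasting applies `+` element by element and returns a tuple again.
summed = a .+ b

println(plus_fails)
println(summed)
```

This only shows why a broadcasted addition would get past the error for tuple fields; it says nothing about whether broadcasting is the right fix inside ChainRulesCore itself.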