Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix tangent type of arrays of (named) tuples from FD #224

Merged
merged 5 commits into from
Dec 8, 2021

Conversation

devmotion
Copy link
Member

@devmotion devmotion commented Nov 9, 2021

Currently, arrays of tuples and named tuples as a result from FiniteDifferences are not converted to the correct Tangent types. This causes the issue in JuliaMath/ChangesOfVariables.jl#4 (comment) and is fixed by this PR (confirmed locally).

Deeply nested arrays of (named) tuples will still have an incorrect, so maybe a (better?) alternative would be to always recurse arrays until one hits a AbstractArray{<:Real} (which would be returned unmodified, also to avoid copies in this case) or a non-array base case such as Tuple, NamedTuple, and Any.

The PR replaces the _maybe_fix_to_composite with ProjectTo.

It also seems that the function is not tested currently. Should I add some tests?

@codecov-commenter
Copy link

codecov-commenter commented Nov 9, 2021

Codecov Report

Merging #224 (8dc5f8d) into main (dd0e246) will increase coverage by 0.08%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #224      +/-   ##
==========================================
+ Coverage   90.75%   90.84%   +0.08%     
==========================================
  Files          11       11              
  Lines         303      295       -8     
==========================================
- Hits          275      268       -7     
+ Misses         28       27       -1     
Impacted Files Coverage Δ
src/finite_difference_calls.jl 100.00% <100.00%> (+2.77%) ⬆️
src/rand_tangent.jl 100.00% <100.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 394f539...8dc5f8d. Read the comment docs.

@oxinabox
Copy link
Member

oxinabox commented Nov 9, 2021

I wonder if we can do this with ProjectTo ?

@devmotion
Copy link
Member Author

I wonder if we can do this with ProjectTo ?

Nice, this seems to fix the issue in ChangesOfVariables as well. Great if we can use existing functionality. I'm currently checking if it breaks any tests in CRTestUtils (and if successful, I'll push and we can see if it breaks any downstream packages).

@devmotion
Copy link
Member Author

Hmm, tests passed locally but it seems doctests fail since ProjectTo(::Tuple{Float64,Float64,Float64}) is not defined 😢

@devmotion
Copy link
Member Author

@oxinabox would it be OK to add a ProjectTo(::Tuple) definition in ChainRulesCore? Or was it omitted on purpose?

@oxinabox
Copy link
Member

oxinabox commented Nov 9, 2021

@oxinabox would it be OK to add a ProjectTo(::Tuple) definition in ChainRulesCore? Or was it omitted on purpose?

it is here, no?
https://github.com/JuliaDiff/ChainRulesCore.jl/blob/99d56b145bb4829931c542e720a015d938efeee4/src/projection.jl#L291

@devmotion
Copy link
Member Author

Oh wow, now I'm confused. The error message clearly stated that it is missing: https://github.com/JuliaDiff/ChainRulesTestUtils.jl/runs/4153868239?check_suite_focus=true#step:4:214

@devmotion
Copy link
Member Author

Ah the docs use an old version of CRCore: https://github.com/JuliaDiff/ChainRulesTestUtils.jl/runs/4153868239?check_suite_focus=true#step:4:57

@oxinabox
Copy link
Member

oxinabox commented Nov 9, 2021

Looks good to me.
Assuming integration tests pass.
Once you add a test you can merge and tag when happy.

@devmotion
Copy link
Member Author

There's one test error in ChainRules: https://github.com/JuliaDiff/ChainRulesTestUtils.jl/runs/4154680040?check_suite_focus=true#step:6:162 It is caused by the test in https://github.com/JuliaDiff/ChainRules.jl/blob/edf3a1f48fb5c9af01820aeca6ced94d4f97fa1a/test/rulesets/Base/array.jl#L35

More concretely, the problem is that

julia> using ChainRulesCore # v 1.11.1

julia> ProjectTo((; y = randn(3)))
identity (generic function with 1 method)

I guess this is a bug and has to be fixed in ChainRulesCore since it differs e.g. from

julia> ProjectTo((randn(3),))
ProjectTo{Tangent{Tuple{Vector{Float64}}, T} where T}(elements = (ProjectTo{AbstractArray}(element = ProjectTo{Float64}(), axes = (Base.OneTo(3),)),),)

which projects to a Tangent.

@devmotion
Copy link
Member Author

I assume the ChainRules.jl test errors will be fixed by JuliaDiff/ChainRules.jl#550.

@oxinabox
Copy link
Member

oxinabox commented Dec 7, 2021

yep, merge when happy

@devmotion devmotion merged commit 63bbd48 into main Dec 8, 2021
@devmotion devmotion deleted the dw/fix_to_composite branch December 8, 2021 14:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants