Allow NotImplemented tangents for things that have a correct tangent of NoTangent #218

Merged
devmotion merged 8 commits into master from dw/notimplemented_notangent
Sep 21, 2021

Conversation

devmotion
Member

This is one possible approach to fix #217.

An alternative would be to tell users to define rand_tangent or, probably better, to pass a suitable tangent of type NotImplemented. That is a bit inconvenient, though, since we currently check for equality of both the message and the LineNumberNode. I.e., one can't just use test_rrule(f, x ⊢ @not_implemented("does not work")), since its LineNumberNode would differ from the one used inside the rrule. Maybe this should be changed so that we only check whether the messages are equal?
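To illustrate the equality problem described above, here is a minimal sketch in plain Julia. FakeNotImplemented and @fake_not_implemented are hypothetical stand-ins written for this example, not the ChainRulesCore types; the sketch only assumes that the tangent records the source location of the macro call, as the comment implies.

```julia
# Hypothetical stand-in for a "not implemented" tangent that remembers
# where the macro was called (an assumption made for this illustration).
struct FakeNotImplemented
    source::LineNumberNode
    info::String
end

# Hypothetical macro: captures the call site via `__source__`.
macro fake_not_implemented(info)
    return :(FakeNotImplemented($(QuoteNode(__source__)), $(esc(info))))
end

# Two calls with the same message, written on different lines,
# like one call inside an rrule and one inside a test.
a = @fake_not_implemented("does not work")
b = @fake_not_implemented("does not work")

# Strict comparison: message AND source location must match.
strict_equal(x, y) = x.source == y.source && x.info == y.info

# Relaxed comparison: only the message must match.
message_equal(x, y) = x.info == y.info

println(strict_equal(a, b))   # false, the line numbers differ
println(message_equal(a, b))  # true
```

Under the strict comparison a user can never construct a matching value outside the rule itself, which is the inconvenience raised here.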

codecov-commenter commented Sep 21, 2021

Codecov Report

Merging #218 (f57fc7e) into master (a46fbbc) will decrease coverage by 0.45%.
The diff coverage is 85.71%.

@@            Coverage Diff             @@
##           master     #218      +/-   ##
==========================================
- Coverage   91.21%   90.75%   -0.46%     
==========================================
  Files          11       11              
  Lines         296      303       +7     
==========================================
+ Hits          270      275       +5     
- Misses         26       28       +2     
Impacted Files                   Coverage Δ
src/testers.jl                   91.42% <85.71%> (-1.58%) ⬇️
src/finite_difference_calls.jl   97.22% <0.00%> (+0.07%) ⬆️
src/check_result.jl              89.70% <0.00%> (+0.15%) ⬆️

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

Member

@oxinabox oxinabox left a comment

I am convinced that this is the best way.

Do you want to extract some of this code into a helper method?
It's getting long.
Maybe a function for the content of the accum_cotangent isa NoTangent branch?
Or maybe a method that dispatches on accum_cotangent and ad_cotangent?

Regardless, change as you will then merge and tag when happy.

@oxinabox oxinabox changed the title from "Fix https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/217" to "Allow NotImplemented tangents for things that have a correct tangent of NoTangent" on Sep 21, 2021
Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
src/testers.jl Outdated
# the `@test_broken` below should tell them that there is an easy
# implementation for this case of `NoTangent()`
# https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/217
@test_broken false
Member

@test_broken false is much less useful than

Suggested change
- @test_broken false
+ @test_broken not_implemented === NoTangent()

because it doesn't give a message saying what the correct answer is.
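The difference can be seen with the Test standard library alone. This is a small sketch, not the code from the PR: FakeNotImplemented and expected are hypothetical stand-ins for the NotImplemented tangent and for NoTangent(), used so the example runs without ChainRulesCore.

```julia
using Test

# Hypothetical stand-ins, so that the sketch needs only the standard library.
struct FakeNotImplemented end
not_implemented = FakeNotImplemented()
expected = nothing  # plays the role of NoTangent() in this illustration

# Both tests are recorded as broken, but they carry different expressions.
uninformative = @test_broken false
informative = @test_broken not_implemented === expected

# Showing the results prints the expression that was tested. Only the second
# one tells the reader which value the rule should have returned.
println(uninformative)
println(informative)
```

The second form keeps the comparison in the reported expression, so the test log itself hints at the fix.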

Member Author

Ah true, I thought I could simplify the code but I see that it was not a good idea 😄

Member

Macros, they will do that to you

src/testers.jl
# the `@test_broken` below should tell them that there is an easy implementation for
# this case of `NoTangent()` (`@test_broken false` would be less useful!)
# https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/217
@test_broken ad_cotangent isa NoTangent
Contributor

the @test_broken below should tell them that there is an easy implementation for this case of NoTangent()

I don't understand what this is supposed to say, could you explain so we can make it a more helpful comment for future readers of the codebase?

Contributor

From the comment by @oxinabox on the outdated diff:

The correct implementation of anything that test_rrule determined was a NoTangent is almost certainly NoTangent()
Which is easy.
It might not generalize to cases that are not tested, but for the types passed into this test it is (almost certainly) correct.

But I'm still not clear what I'd be expected to do.

Contributor

My understanding (based on the CRC docs) is: If I mark something as NotImplemented, that generally means I could implement it, but it's hard (and I can't be bothered, maybe it wouldn't be used anyways in practice). So what is the reference to an "easy implementation" intended to say?

Member Author

If you end up here, you should have used a NoTangent() cotangent in your rrule. So the easy fix is to just replace whatever you did in your rrule with NoTangent(). @test_broken ad_cotangent isa NoTangent will display the broken expression, and hence it is easy to see that ad_cotangent should have been NoTangent().

Member Author

If the argument is actually differentiable and NoTangent is just caused by an incorrect rand_tangent default, then you should provide a correct tangent.

Member

Consider the MWE from #217 (comment)

function ChainRulesCore.rrule(::typeof(foo), K, fun)
    f = foo(K, fun)

    function pullback(Δf)
        ∂self = NoTangent()
        ∂K = @thunk(2Δf)
        ∂fun = @not_implemented("does not work")  # does not work
        return (∂self, ∂K, ∂fun)
    end

    return f, pullback
end

The correct implementation, for the types that were tested in this test_rrule (i.e. the ones for which you hit this @test_broken), is:

function ChainRulesCore.rrule(::typeof(foo), K, fun)
    f = foo(K, fun)

    function pullback(Δf)
        ∂self = NoTangent()
        ∂K = @thunk(2Δf)
        ∂fun = NoTangent()
        return (∂self, ∂K, ∂fun)
    end

    return f, pullback
end

Because you can only end up on this line if fun has no fields, and thus its tangent must be NoTangent().
The implementation is not a mystery here: NoTangent() is the correct thing to write.

Now, if the rule author wanted to be generic as to whether or not fun has fields, then they should probably test that case by using a closure or some functor.

And they might end up with something like:

∂fun = Base.issingletontype(typeof(fun)) ? NoTangent() : @not_implemented("Functors not supported") 

per #217 (comment)
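As a quick check of the distinction drawn above, the following sketch uses only Base. The names square, make_scaler and Scale are hypothetical examples, not taken from the issue; the point is just how Base.issingletontype separates a plain function from a closure or functor that carries fields.

```julia
# A plain top-level function has no fields, so its type is a singleton type.
square(x) = x^2

# A closure that captures a variable stores it as a field of its type.
make_scaler(a) = x -> a * x
scaler = make_scaler(3.0)

# A functor: a callable struct with an explicit field.
struct Scale
    a::Float64
end
(s::Scale)(x) = s.a * x

println(Base.issingletontype(typeof(square)))  # true
println(Base.issingletontype(typeof(scaler)))  # false
println(Base.issingletontype(Scale))           # false
println(fieldcount(typeof(scaler)))            # 1
```

Passing something like scaler or Scale(3.0) as fun in a test would exercise the branch with fields, whereas square would only exercise the singleton branch.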

Contributor

Where would we have ended up if fun had fields? How would I tell the dispatcher that an rrule should only apply to functions with no fields?

Contributor

@st-- st-- Sep 21, 2021

If my example

function ChainRulesCore.rrule(::typeof(foo), K, fun)
    f = foo(K, fun)

    function pullback(Δf)
        ∂self = NoTangent()
        ∂K = @thunk(2Δf)
        ∂fun = @not_implemented("would have to do more maths")
        return (∂self, ∂K, ∂fun)
    end

    return f, pullback
end

should be generic for fun (the derivative for fields of fun would exist, I just haven't worked it out yet), is @not_implemented not the right thing to do?

Member Author

In general, it is not the right thing to do; the right thing is to implement the derivatives 😛 @not_implemented is a workaround until you've done this, and therefore all tests with NotImplemented use @test_broken to indicate that something is broken and should be fixed.

Co-authored-by: st-- <st--@users.noreply.github.com>
@devmotion devmotion merged commit e3ff6b4 into master Sep 21, 2021
@devmotion devmotion deleted the dw/notimplemented_notangent branch September 21, 2021 18:11
Successfully merging this pull request may close these issues.

How to test rules with @not_implemented arguments?
4 participants