Allow NotImplemented tangents for things that have a correct tangent of NoTangent #218
…eturns `NoTangent`
Codecov Report
@@ Coverage Diff @@
## master #218 +/- ##
==========================================
- Coverage 91.21% 90.75% -0.46%
==========================================
Files 11 11
Lines 296 303 +7
==========================================
+ Hits 270 275 +5
- Misses 26 28 +2
I am convinced that this is the best way.
Do you want to extract some of this code into a helper method? It's getting long.
Maybe a function for the content of the branch `accum_cotangent isa NoTangent`?
Or maybe a method that dispatches on `accum_cotangent` and `ad_cotangent`?
Regardless, change as you will, then merge and tag when happy.
Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
src/testers.jl
Outdated
# the `@test_broken` below should tell them that there is an easy
# implementation for this case of `NoTangent()`
# https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/217
@test_broken false
`@test_broken false` is much less useful than `@test_broken not_implemented === NoTangent()` because it doesn't give a message saying what the correct answer is.
Ah true, I thought I could simplify the code but I see that it was not a good idea 😄
Macros, they will do that to you.
# the `@test_broken` below should tell them that there is an easy implementation for
# this case of `NoTangent()` (`@test_broken false` would be less useful!)
# https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/217
@test_broken ad_cotangent isa NoTangent
"the `@test_broken` below should tell them that there is an easy implementation for this case of `NoTangent()`"
I don't understand what this is supposed to say. Could you explain, so we can make it a more helpful comment for future readers of the codebase?
From the comment by @oxinabox on the outdated diff:
The correct implementation of anything that `test_rrule` determined was a `NoTangent` is almost certainly `NoTangent()`, which is easy.
It might not generalize to cases that are not tested, but for the types passed into this test it is (almost certainly) correct.
But I'm still not clear what I'd be expected to do.
My understanding (based on the CRC docs) is: If I mark something as NotImplemented, that generally means I could implement it, but it's hard (and I can't be bothered, maybe it wouldn't be used anyways in practice). So what is the reference to an "easy implementation" intended to say?
If you end up here, you should have used a `NoTangent()` cotangent in your `rrule`. So the easy fix is to just replace whatever you did in your `rrule` with `NoTangent()`. `@test_broken ad_cotangent isa NoTangent` will display the broken expression, and hence it is easy to see that `ad_cotangent` should have been `NoTangent()`.
If the argument is actually differentiable and `NoTangent` is just caused by an incorrect `rand_tangent` default, then you should provide a correct tangent.
Consider the MWE from #217 (comment)
function ChainRulesCore.rrule(::typeof(foo), K, fun)
    f = foo(K, fun)
    function pullback(Δf)
        ∂self = NoTangent()
        ∂K = @thunk(2Δf)
        ∂fun = @not_implemented("does not work") # does not work
        return (∂self, ∂K, ∂fun)
    end
    return f, pullback
end
The correct implementation, for the types that were tested in this `test_rrule` (that is, for you to hit this `@test_broken`), is:
function ChainRulesCore.rrule(::typeof(foo), K, fun)
    f = foo(K, fun)
    function pullback(Δf)
        ∂self = NoTangent()
        ∂K = @thunk(2Δf)
        ∂fun = NoTangent()
        return (∂self, ∂K, ∂fun)
    end
    return f, pullback
end
Because you can only end up on this line if `fun` has no fields, and thus its tangent must be `NoTangent()`.
The implementation is not a mystery here: `NoTangent()` is the correct thing to write.
Now, if the rule author wanted to be generic over whether or not `fun` has fields, then they should probably test that case by using a closure or some functor.
And they might end up with something like:
∂fun = Base.issingletontype(typeof(fun)) ? NoTangent() : @not_implemented("Functors not supported")
per #217 (comment)
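To make the distinction above concrete, here is a minimal sketch of how `Base.issingletontype` separates a plain function from a functor or capturing closure. The helper name `fun_cotangent` and the `Scale` struct are hypothetical and exist only for illustration; they are not part of ChainRulesCore or of this PR.

```julia
using ChainRulesCore

# Hypothetical helper: choose the cotangent for a function-like argument
# depending on whether its type has any fields (i.e. is a singleton type).
function fun_cotangent(fun)
    if Base.issingletontype(typeof(fun))
        # No fields, so there is nothing to differentiate with respect to.
        return NoTangent()
    else
        # Has fields; a real derivative would exist but is not worked out here.
        return @not_implemented("Functors not supported")
    end
end

# Hypothetical functor with one field.
struct Scale
    a::Float64
end
(s::Scale)(x) = s.a * x

# A closure that captures a local variable also has a field.
closure = let a = 3.0
    x -> a * x
end

plain_result   = fun_cotangent(sin)          # plain function: singleton type
functor_result = fun_cotangent(Scale(2.0))   # functor with a field
closure_result = fun_cotangent(closure)      # capturing closure
```

Under these assumptions, only the plain function gets `NoTangent()`; the functor and the capturing closure both get a `NotImplemented` tangent, which is the case a rule author would want to cover with a separate test.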
Where would we have ended up if `fun` had fields? How would I tell the dispatcher that an `rrule` only applies to functions with no fields?
If my example
function ChainRulesCore.rrule(::typeof(foo), K, fun)
    f = foo(K, fun)
    function pullback(Δf)
        ∂self = NoTangent()
        ∂K = @thunk(2Δf)
        ∂fun = @not_implemented("would have to do more maths")
        return (∂self, ∂K, ∂fun)
    end
    return f, pullback
end
should be generic for `fun` (the derivative for fields of `fun` would exist, I just haven't worked it out yet), is `@not_implemented` not the right thing to do?
In general, it is not the right thing to do; the right thing is to implement the derivatives 😛 `@not_implemented` is a workaround until you've done this, and therefore all tests with `NotImplemented` use `@test_broken` to indicate that something is broken and should be fixed.
Co-authored-by: st-- <st--@users.noreply.github.com>
This is one possible approach to fix #217.
An alternative would be to tell users to define `rand_tangent` or, probably better, pass a suitable tangent of type `NotImplemented`. It is a bit inconvenient though, since currently we check for equality, including the message AND the `LineNumberNode`s. I.e., one can't just use `test_rrule(f, x ⊢ @not_implemented("does not work"))` since the `LineNumberNode` would be different from the one used inside the `rrule`. Maybe this should be changed and we should only check if the messages are equal?