ChainRules rrule Integration for Unitful #504
Seems fine to me
src/chainrules.jl
Outdated
function ProjectTo(x::Quantity)
    project_val = ProjectTo(x.val) # Project the literal number
    return ProjectTo{typeof(x)}(; project_val = project_val)
end
Doesn't really matter but the convention ChainRulesCore uses is to match the field name, if in doubt.
Suggested change:
- function ProjectTo(x::Quantity)
-     project_val = ProjectTo(x.val) # Project the literal number
-     return ProjectTo{typeof(x)}(; project_val = project_val)
- end
+ function ProjectTo(x::Quantity)
+     val = ProjectTo(x.val) # Project the literal number
+     return ProjectTo{typeof(x)}(; val = val)
+ end
function (projector::ProjectTo{<:Quantity})(x::Number)
    new_val = projector.project_val(ustrip(x))
convention:
- function (projector::ProjectTo{<:Quantity})(x::Number)
-     new_val = projector.project_val(ustrip(x))
+ function (project::ProjectTo{<:Quantity})(x::Number)
+     new_val = project.val(ustrip(x))
src/chainrules.jl
Outdated
function ProjectTo(x::Quantity)
    project_val = ProjectTo(x.val) # Project the literal number
    return ProjectTo{typeof(x)}(; project_val = project_val)
This stores the complete type, and hence unit, of x in the type of the projector. But when this is applied, you use only unit(dx) and not this unit. That's mathematically correct, I think, since the gradient will typically have different units. But it also means that storing this is redundant.
It could just be ProjectTo{Quantity} -- many of them store just the top-level type. But could it just be ProjectTo(x.val)?
Right now ProjectTo{Float64} will allow through any dx with units, without changing them. To get the present behaviour of this PR, could you just define methods for (::ProjectTo{<:Number})(dx::Quantity) which un-wrap, adjust the precision etc if necessary, and re-wrap? Or does this land you in dispatch ambiguity hell?
oh true, i missed that, my bad
src/chainrules.jl
Outdated
units = (y, z...)
return (NoTangent(), *(δ, y, z...), ntuple(_ -> NoTangent(), length(units))...)
Why not
- units = (y, z...)
- return (NoTangent(), *(δ, y, z...), ntuple(_ -> NoTangent(), length(units))...)
+ return (NoTangent(), *(δ, y, z...), ntuple(_ -> NoTangent(), 1+length(z))...)
Was going to suggest something like:
nots = ntuple(Returns(NoTangent()), 1 + length(z))
return (NoTangent(), *(ProjectTo(x)(Δ), y, z...), nots...)
since I think there is little to gain by making the pullback close over the ProjectTo instead of over x. But something to gain in readability by needing fewer symbols. (But this is just style.)
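As a small check on the two ntuple spellings suggested above, the following sketch compares them directly. The tuple z is a hypothetical stand-in for the trailing unit arguments, and Returns requires Julia 1.7 or later; this only illustrates that the spellings build the same tuple, not how the PR's pullback is finally written.

```julia
using ChainRulesCore  # provides NoTangent

# Hypothetical stand-ins for the trailing arguments `z...` of the rule.
z = (:unit_b, :unit_c)

# Spelling from the first suggestion: an anonymous function that ignores its index.
a = ntuple(_ -> NoTangent(), 1 + length(z))

# Spelling from the second suggestion: `Returns` (Julia 1.7+) builds the same constant function.
b = ntuple(Returns(NoTangent()), 1 + length(z))

@assert a === b          # both are tuples of the same singleton value
@assert length(a) == 3   # one entry for y plus one per element of z
```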
function rrule(UT::Type{Quantity{T,D,U}}, x::Number) where {T,D,U}
    unitful_x = Quantity{T,D,U}(x)
    projector_x = ProjectTo(x)
    uq_pullback(Δx) = (NoTangent(), projector_x(Δx) * oneunit(UT))
    return unitful_x, uq_pullback
end
When is this called?
If it's used when attaching units to an initially plain number, x=1 -> unitful_x = 1m, then the thinking is that if the loss is a unitless scalar, the gradient for unitful_x will be d loss / d unitful_x = 100/m, and this will produce a gradient for x with no units (or units equivalent to 1)?
And does that work out in practice? With some Zygote.gradient(loss, 1u"m") ... must you ensure by hand that you remove the units within loss, or does Zygote.sensitivity do the right thing? Maybe that's a bigger question than this function... have not thought much about how this all ought to work.
Zygote.sensitivity returns the multiplicative identity which is usually 1.0, even for Unitful.Quantity. The example worked out in @oxinabox's comment matches how I've thought about this, so I think Zygote is correct here.
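To make the "multiplicative identity" point above concrete, here is a minimal sketch using only Unitful (it does not call Zygote, so it says nothing about Zygote's internals beyond what the comment states). It shows that one of a quantity carries no units, while oneunit does; the value 3.0u"m" is an arbitrary example.

```julia
using Unitful

q = 3.0u"m"  # arbitrary example quantity

# `one` is the multiplicative identity, so it carries no units.
@assert one(q) == 1.0
@assert one(q) * q == q

# `oneunit` is the unit-carrying counterpart.
@assert oneunit(q) == 1.0u"m"
```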
I am trying to think this one through. Consider:
function f(t)
    a = 5 Meter/Second
    x = a*t
    return x
end
So I initially assumed (incorrectly) that the seed co-tangent has the same units as a difference of primals, which is the same as the primal units in most (all?) cases. Then we would get ... So I guess the seed must have units of ... Which was what was wanted.
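As a bookkeeping aid for the example above (not a reconstruction of the expressions missing from the comment), the pullback of x = a t with respect to t multiplies the incoming cotangent by a, so the units of the result follow from whatever units the seed carries. The three seed choices below are assumptions for illustration: primal units, a dimensionless seed, and the gradient of a unitless loss with respect to x.

```latex
\bar{t} = a\,\bar{x}, \qquad [\bar{t}] = \frac{\mathrm{m}}{\mathrm{s}}\,[\bar{x}]

\begin{aligned}
[\bar{x}] = \mathrm{m}      &\implies [\bar{t}] = \mathrm{m}^{2}/\mathrm{s} \\
[\bar{x}] = 1               &\implies [\bar{t}] = \mathrm{m}/\mathrm{s} = [\mathrm{d}x/\mathrm{d}t] \\
[\bar{x}] = \mathrm{m}^{-1} &\implies [\bar{t}] = \mathrm{s}^{-1}
\end{aligned}
```

Only the dimensionless seed yields the units of dx/dt, and only the 1/m seed yields the units of a unitless loss differentiated with respect to t.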
It's above my abilities to review this but I just wanted to say that this would be a great addition to AD in Julia.
fbe6810 to 1ea376a
Codecov Report
@@            Coverage Diff             @@
##           master     #504      +/-   ##
==========================================
+ Coverage   84.94%   88.00%   +3.05%
==========================================
  Files          16       17       +1
  Lines        1448     1467      +19
==========================================
+ Hits         1230     1291      +61
+ Misses        218      176      -42
Are there any further steps required before this can get merged? Should there be a manual rrule implemented for the tests, maybe?
Added some tests for the I left the
Is there anything else that needs doing or can this be merged?
This only implements the Relatedly, @oxinabox do you think we need the
rrule Integration for Unitful
add endline Co-authored-by: Mosè Giordano <giordano@users.noreply.github.com>
bd88a9d to 5e24deb
Accidentally bumped the patch version before; should be good now
This is overall ok with me, but the question is do we want to add another dependency? This package has been traditionally rather conservative on taking on dependencies, however there is already a non-standard library and @ajkeller34 @sostock opinions?
I am awfully new to Julia and the SciML ecosystem. Would this addition make it possible to run
It should now be even lighter
Is there anything left to do on this PR?
Closing this PR. I turned this into its own package: https://github.com/SBuercklin/UnitfulChainRules.jl I just submitted the registration on the General registry, once that clears I'll submit a PR adding a link to the
The intention of this PR is to implement the machinery within Unitful.jl to allow for autodiff over Unitful.Quantitys. Specifically, it should include the necessary ChainRulesCore.jl methods to provide some basic level of compatibility with ChainRules-based AD systems.
Before this PR, Quantitys would be reduced to Tangent{Any}(val = ...) which would break a lot of basic AD arithmetic. After this PR:
I've implemented an rrule for the Quantity constructor, ProjectTo{Quantity}, and arithmetic between Number/Quantity and Units, which are used to call the Quantity constructor. Generally speaking, projecting a Quantity to a Number involves projecting the value of the Quantity onto the number, and then propagating the units of the projecting Quantity onto the Number. This ensures the proper real/complex behavior is obeyed while maintaining correct units.
I wanted to get feedback before continuing to ensure that:
Testing is difficult as ChainRulesTestUtils.jl does not play nicely with Unitful.jl at the moment. I can manually test the rrules and ProjectTo, but if there are any other testing approaches I'm open to ideas.
Regarding dependencies elsewhere, fuller compatibility with ChainRules.jl needs this PR, which relaxes the constraint over many of the rules from Union{Real, Complex} to just Number. This should give compatibility with Quantitys, which subtype Number but typically wrap <:Real, <:Complex.
Remaining work:
Quantitys and Units (right now only * and / are implemented)
frules
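The projection described in the PR description above can be sketched as follows. This is a minimal illustration, not the PR's implementation: project_like is a hypothetical helper (it is not part of Unitful.jl or ChainRulesCore.jl), written as a plain function to sidestep the dispatch questions raised in review. It follows the reviewer's reading of the code, in which the units re-attached are those of the incoming cotangent dx, and it assumes ChainRulesCore's projector for a Float64 primal drops imaginary parts and converts integers to floats.

```julia
using ChainRulesCore, Unitful

# Hypothetical helper: project a cotangent `dx` for a primal quantity `x`.
function project_like(x::Quantity, dx::Number)
    # Handle real/complex behaviour and precision on the bare numbers.
    val = ProjectTo(ustrip(x))(ustrip(dx))
    # Re-attach the units carried by the cotangent (NoUnits for a plain number).
    return val * unit(dx)
end

x = 1.0u"m"                  # real-valued primal
dx = (2 + 3im) * u"s^-1"     # complex cotangent with its own units

# The imaginary part is dropped because the primal is real; the units of dx are kept.
@assert project_like(x, dx) == 2.0u"s^-1"

# A plain-number cotangent passes through without units.
@assert project_like(x, 2) == 2.0
```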