Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ChainRules rrule Integration for Unitful #504

Conversation

SBuercklin
Copy link
Contributor

@SBuercklin SBuercklin commented Dec 1, 2021

This intention of this PR is to implement the machinery within Unitful.jl to allow for autodiff over Unitful.Quantitys. Specifically, it should include the necessary ChainRulesCore.jl methods to provide some basic level of compatibility with ChainRules-based AD systems.

Before this PR, Quantitys would be reduced to Tangent{Any}(val = ...) which would break a lot of basic AD arithmetic. After this PR:

julia> Zygote.gradient((x,y) -> (x*W)/(y*μm)/ms, 3.0*W, 2.0*μm)
(0.5 W μm^-2 ms^-1, -0.75 W^2 μm^-3 ms^-1)

julia> Zygote.gradient((x,y) -> (x*ms + 9*y*ms)/μm, 2.0*W, 3.0*W)
(1.0 ms μm^-1, 9.0 ms μm^-1)

I've implemented an rrule for the Quantity constructor, ProjectTo{Quantity}, and arithmetic between Number/Quantity and Units, which are used to call the Quantity constructor. Generally speaking, projecting a Quantity to a Number involves projecting the value of the Quantity onto the number, and then propagating the units of the projecting Quantity onto the Number. This ensures the proper real/complex behavior is obeyed while maintaining correct units.

I wanted to get feedback before continuing to ensure that:

  1. This is something other people have an interest in
  2. It's being properly implemented and tested
  3. Any necessary changes elsewhere can also be implemented

Testing is difficult as ChainRulesTestUtils.jl does not play nicely with Unitful.jl at the moment. I can manually test the rrules and ProjectTo, but if there are any other testing approaches I'm open to ideas.

Regarding dependencies elsewhere, more full compatibility with ChainRules.jl needs this PR which relaxes the constraint over many of the rules from Union{Real, Complex} to just Number. This should give compatibility with Quantitys which subtype Number, but typically wrap <:Real, <:Complex.

Remaining work:

  • Fill out rules where it makes sense for operations between Quantitys and Units (right now only * and / are implemented)
  • Implement frules
  • Testing

Copy link

@oxinabox oxinabox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems fine to me

Comment on lines 8 to 11
function ProjectTo(x::Quantity)
project_val = ProjectTo(x.val) # Project the literal number
return ProjectTo{typeof(x)}(; project_val = project_val)
end
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't really matter but the convention ChainRulesCore uses is to match the field name, if in doubt.c

Suggested change
function ProjectTo(x::Quantity)
project_val = ProjectTo(x.val) # Project the literal number
return ProjectTo{typeof(x)}(; project_val = project_val)
end
function ProjectTo(x::Quantity)
val = ProjectTo(x.val) # Project the literal number
return ProjectTo{typeof(x)}(; val = val)
end

Comment on lines +13 to +9
function (projector::ProjectTo{<:Quantity})(x::Number)
new_val = projector.project_val(ustrip(x))
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

convention:

Suggested change
function (projector::ProjectTo{<:Quantity})(x::Number)
new_val = projector.project_val(ustrip(x))
function (project::ProjectTo{<:Quantity})(x::Number)
new_val = project.val(ustrip(x))


function ProjectTo(x::Quantity)
project_val = ProjectTo(x.val) # Project the literal number
return ProjectTo{typeof(x)}(; project_val = project_val)
Copy link
Contributor

@mcabbott mcabbott Dec 3, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This stores the complete type, and hence unit, of x in the type of the projector. But when this is applied, you use only unit(dx) and not this unit. That's mathematically correct, I think, since the gradient will typically have different units. But it also means that storing this is redundant.

It could just be ProjectTo{Quantity} -- many of them store just the top-level type. But could it just be ProjectTo(x.val)?

Right now ProjectTo{Float64} will allow through any dx with units, without changing them. To get the present behaviour of this PR, could you just define methods for (::ProjectTo{<:Number})(dx::Quantity) which un-wrap, adjust the precision etc if necessary, and re-wrap? Or does this land you in dispatch ambiguity hell?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh true, i missed that, my bad

Comment on lines 27 to 28
units = (y, z...)
return (NoTangent(), *(δ, y, z...), ntuple(_ -> NoTangent(), length(units))...)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not

Suggested change
units = (y, z...)
return (NoTangent(), *(δ, y, z...), ntuple(_ -> NoTangent(), length(units))...)
return (NoTangent(), *(δ, y, z...), ntuple(_ -> NoTangent(), 1+length(z))...)

Copy link
Contributor

@mcabbott mcabbott Dec 3, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was going to suggest something like:

        nots = ntuple(Returns(NoTangent()), 1 + length(z))
        return (NoTangent(), *(ProjectTo(x)(Δ), y, z...), nots...)

since I think there is little to gain by making the pullback close over the ProjectTo instead of over x. But something to gain in readability by needing fewer symbols. (But this is just style.)

Comment on lines +1 to +6
function rrule(UT::Type{Quantity{T,D,U}}, x::Number) where {T,D,U}
unitful_x = Quantity{T,D,U}(x)
projector_x = ProjectTo(x)
uq_pullback(Δx) = (NoTangent(), projector_x(Δx) * oneunit(UT))
return unitful_x, uq_pullback
end
Copy link
Contributor

@mcabbott mcabbott Dec 3, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When is this called?

If't it's used when attaching units to an initially plain number, x=1 -> unitful_x = 1m, then the thinking is that if the loss is a unitless scalar, the gradient for unitful_x will be d loss / d unitful_x = 100/m, and this will produce a gradient for x with no units (or units equivalent to 1)?

And does that work out in practice? With some Zygote.gradient(loss, 1u"m")... must you ensure by hand that you remove the units within loss, or does Zygote.sensititvity do the right thing? Maybe that's a bigger question than this function... have not thought much about how this all ought to work.

Copy link
Contributor Author

@SBuercklin SBuercklin Dec 6, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Zygote.sensitivity returns the multiplicative identity which is usually 1.0, even for Unitful.Quantity. The example worked out in @oxinabox's comment matches how I've thought about this, so I think Zygote is correct here.

@oxinabox
Copy link

oxinabox commented Dec 3, 2021

I am trying to think this one though.

Consider: x = (5.0Meter/Second) * t)
or written long:

function f(t)
    a = 5 Meter/Second
    x = a*t
    return x
end

so dx/dt = 5.0 Meter/Second
this is known

I initially assumed (incorrectly) that the seed co-tangent has the same units as a difference of primals, which is same as the primal units in most (all?) cases
which would have t̄ = dt/dt = 1 Second

Then we would get
x̄ = a * t̄
x̄ = (5.0 Meter/Second) * 1 Second
x̄ = (5.0 Meter)
Which is wrong since x̄ = dx/dt

So I guess the seed must have units of Second/Second or just 1, makes sense since it is dt/dt

t̄ = dt/dt = 1 Second/Second
x̄ = a * t̄
x̄ = (5.0 Meter/Second) * 1 Second/Second
x̄ = (5.0 Meter/Second)

Which was what was wanted.

@jgreener64
Copy link

It's above my abilities to review this but I just wanted to say that this would be a great addition to AD in Julia.

@codecov-commenter
Copy link

codecov-commenter commented Dec 6, 2021

Codecov Report

Merging #504 (5f11a68) into master (f9992c0) will increase coverage by 3.05%.
The diff coverage is 73.68%.

@@            Coverage Diff             @@
##           master     #504      +/-   ##
==========================================
+ Coverage   84.94%   88.00%   +3.05%     
==========================================
  Files          16       17       +1     
  Lines        1448     1467      +19     
==========================================
+ Hits         1230     1291      +61     
+ Misses        218      176      -42     
Impacted Files Coverage Δ
src/Unitful.jl 100.00% <ø> (ø)
src/chainrules.jl 73.68% <73.68%> (ø)
src/dates.jl 97.22% <0.00%> (+1.38%) ⬆️
src/types.jl 91.83% <0.00%> (+2.04%) ⬆️
src/logarithm.jl 78.27% <0.00%> (+8.60%) ⬆️
src/promotion.jl 95.55% <0.00%> (+53.33%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update f9992c0...5f11a68. Read the comment docs.

@platawiec
Copy link

Are there any further steps required before this can get merged? Should there be a manual rrule implemented for the tests, maybe?

@SBuercklin SBuercklin changed the title WIP: ChainRules Integration for Unitful ChainRules Integration for Unitful Jan 15, 2022
@SBuercklin
Copy link
Contributor Author

Added some tests for the rrules.

I left the / rules untested since the calls themselves dispatch to *(::Number, ::Units) which in turn dispatch to the Quantity constructor rule. We could write these dispatches out manually and test them, but this seems like something the AD backend should be inferring.

frules aren't implemented still, but it seems fine to put that off to a later issue if we want to get this merged. I removed the WIP prefix from the PR title, and this seems good to go if we're happy with just rrules for now.

@jgreener64
Copy link

Is there anything else that needs doing or can this be merged?

@SBuercklin
Copy link
Contributor Author

This only implements the rrules right now, which I think should be sufficient for most AD usage. In my opinion it's fine to merge, but I'll update the PR title to emphasize that it only has rrules, not frules

Relatedly, @oxinabox do you think we need the frules right now?

@SBuercklin SBuercklin changed the title ChainRules Integration for Unitful ChainRules rrule Integration for Unitful Mar 14, 2022
Project.toml Show resolved Hide resolved
test/chainrules.jl Outdated Show resolved Hide resolved
@SBuercklin SBuercklin force-pushed the sbuercklin/chainrules-unitful branch from bd88a9d to 5e24deb Compare March 15, 2022 13:53
@SBuercklin
Copy link
Contributor Author

Accidentally bumped the patch version before; should be good now

@giordano
Copy link
Collaborator

This is overall ok with me, but the question is do we want to add another dependency? This package has been traditionally rather conservative on taking on dependencies, however there is already a non-standard library and ChainRulesCore is widely used with over 2k total dependents, so the probability someone would have already this package in an environment is non-negligible. Also, I verified loading time of Unitful.jl doesn't increase dramatically.

@ajkeller34 @sostock opinions?

@amostof
Copy link

amostof commented Mar 31, 2022

I am awfully new to Julia and SciML ecosystem. Would this addition make possible to run sciml_train from DiffEqFlux packages on DifferentialEquations models for parameter estimation?

@oxinabox
Copy link

Also, I verified loading time of Unitful.jl doesn't increase dramatically.

It should now be even lighter
JuliaDiff/ChainRulesCore.jl#524

@SBuercklin
Copy link
Contributor Author

Is there anything left to do on this PR?

@SBuercklin
Copy link
Contributor Author

Closing this PR. I turned this into its own package: https://github.com/SBuercklin/UnitfulChainRules.jl

I just submitted the registration on the General registry, once that clears I'll submit a PR adding a link to the README.md

@SBuercklin SBuercklin closed this Jun 21, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

8 participants