Skip to content

Commit

Permalink
ProjectTo maps to projecting onto the inner val, cleaner NoTangents i…
Browse files Browse the repository at this point in the history
…n * pullback
  • Loading branch information
SBuercklin committed Dec 6, 2021
1 parent d4e0fc8 commit 1ea376a
Showing 1 changed file with 4 additions and 9 deletions.
13 changes: 4 additions & 9 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,22 @@ function rrule(UT::Type{Quantity{T,D,U}}, x::Number) where {T,D,U}
return unitful_x, uq_pullback
end

function ProjectTo(x::Quantity)
project_val = ProjectTo(x.val) # Project the literal number
return ProjectTo{typeof(x)}(; project_val = project_val)
end

function (projector::ProjectTo{<:Quantity})(x::Number)
new_val = projector.project_val(ustrip(x))
return new_val*unit(x)
end

# Project Unitful Quantities onto numerical types by projecting the value and carrying units
ProjectTo(x::Quantity) = ProjectTo(x.val)

(project::ProjectTo{<:Real})(dx::Quantity) = project(ustrip(dx))*unit(dx)
(project::ProjectTo{<:Complex})(dx::Quantity) = project(ustrip(dx))*unit(dx)

function rrule(::typeof(*), x::Quantity, y::Units, z::Units...)
Ω = *(x, y, z...)
project_x = ProjectTo(x)
function times_pb(Δ)
δ = project_x(Δ)
units = (y, z...)
return (NoTangent(), *(δ, y, z...), ntuple(_ -> NoTangent(), length(units))...)
nots = ntuple(_ -> NoTangent(), 1 + length(z))
return (NoTangent(), *(ProjectTo(x)(Δ), y, z...), nots...)
end
return Ω, times_pb
end
Expand Down

0 comments on commit 1ea376a

Please sign in to comment.