Skip to content

Commit

Permalink
project ZeroTangent to natural tangent for some number types
Browse files Browse the repository at this point in the history
  • Loading branch information
ToucheSir committed Aug 10, 2022
1 parent fbb4936 commit 81736ce
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ end
# understands, including a mix of Zeros & reals. Other cases, we just let through:
(project::ProjectTo{<:Number})(dx::Tangent{<:Complex}) = project(Complex(dx.re, dx.im))
(::ProjectTo{<:Number})(dx::Tangent{<:Number}) = dx
(::ProjectTo{T})(::ZeroTangent) where {T<:Real} = zero(T)

# Arrays
# If we don't have a more specialized `ProjectTo` rule, we just assume that there is
Expand Down
8 changes: 4 additions & 4 deletions test/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ struct NoSuperType end
@test ProjectTo(1.0)(2) === 2.0

# Tangents
ProjectTo(1.0f0 + 2im)(Tangent{ComplexF64}(; re=1, im=NoTangent())) ===
1.0f0 + 0.0f0im

@test 1.0 === ProjectTo(1.0)(Tangent{ComplexF64}(; re=1, im=NoTangent()))
complex_tangent = Tangent{ComplexF64}(; re=1, im=NoTangent())
@test ProjectTo(1.0f0 + 2im)(complex_tangent) === 1.0f0 + 0.0f0im
@test ProjectTo(1.0)(complex_tangent) === 1.0
@test ProjectTo(1.0)(ZeroTangent()) === 0.0
end

@testset "Dual" begin # some weird Real subtype that we should basically leave alone
Expand Down

0 comments on commit 81736ce

Please sign in to comment.