Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: project implementation #380

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMod
export frule_via_ad, rrule_via_ad
# definition helper macros
export @non_differentiable, @scalar_rule, @thunk, @not_implemented
export canonicalize, extern, unthunk # differential operations
export canonicalize, extern, unthunk, project # differential operations
export add!! # gradient accumulation operations
# differentials
export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk
Expand All @@ -26,6 +26,7 @@ include("differentials/notimplemented.jl")

include("differential_arithmetic.jl")
include("accumulation.jl")
include("projection.jl")

include("config.jl")
include("rules.jl")
Expand Down
62 changes: 62 additions & 0 deletions src/projection.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
using LinearAlgebra: Diagonal, diag

"""
project([T::Type], x, dx)

"project" `dx` onto type `T` such that it is the same size as `x`. If `T` is not provided,
it is assumed to be the type of `x`.
Comment on lines +4 to +7
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the point of T in addition to x? It sounds like you have something in mind but I don't see what. In what directions may they differ? May T be abstract?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh T could be Tangent as in line 60

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But why? And when, I mean for which x would this be chosen somehow? Sounds a bit like an orthogonal feature.

Copy link
Member Author

@mzgubic mzgubic Jun 23, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My initial thought was to ditch T altogether, and privilege some tangent (probably Tangent) representation as canonical.

However, somewhere in the many threads about this, @oxinabox argued that we need to go both ways (to tangent and to natural differential say). e.g. if a rule author writes a rule only for Diagonal and wants to make sure that the Tangent is being transformed right

Edit: sorry, the other way, writes a rule for Tangent and wants to make sure all the tangents it gets are Tangents

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, but this still seems a bit like it's rolling two things together. One is a projection to preserve the input's x's real-ness / Diagon-ality etc. This would be applied automatically to every single rule.

The other is a way to massage certain dxs to a form digestible by certain rules. It's possible that there is some efficiency gained by fusing these processes into one function, but... for a first stab, it seems that combing them is adding complication?

Copy link
Member

@oxinabox oxinabox Jun 23, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be applied automatically to every single rule.

No, I don't think so.
As i see it:

In that use, project is one way that lets us write things that can run on anything, while getting out the type (and associated basis) that we want.
But it's not the only way, and its not always the best way.
In fact i think it rarely is, since it often results in allocating and computing things that we are going going to throw away.
Its just that it is the most generic way.
Alternatives for example, are things like structure preserving maps and broadcasts.
Which avoid allocations and computing things they are going to throw away.
Though i guess in those cases you can still project after, but it will be a identity operation.
Like convert(Array, [1,2,3,]) is an identity operation.

Which brings me around to the general thing:
This is like convert except that it is for vector spaces and is aware of vector space concerns, where are your zeros (which is same as basis)
Which is the same convert-like operation one want to massage on the way in.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in those cases you can still project after, but it will be a identity operation

Yes, that's my thinking. In my list of examples somewhere, any operation mixing real & complex arrays is something which ought to project. But making every rule do the ideal thing would be a lot of work for little benefit. Whereas projecting back to a real array for hcat([1,2], [3+im, 4+im])'s gradient at least means it propagates no further.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, anyway, point is it is our vector space aware convert so it is the right name for both the thing you (often) want to do at the end.
But it is also the one you might want to do at the start


It's necessary to have `x` to ensure that it's possible to project e.g. `AbstractZero`s
onto `Array`s -- this wouldn't be possible with type information alone because the neither
`AbstractZero`s nor `T` know what size of `Array` to produce.
"""
function project end

project(x, dx) = project(typeof(x), x, dx)

# identity
project(::Type{T}, x::T, dx::T) where T = dx

### AbstractZero
project(::Type{T}, x::T, dx::AbstractZero) where T = zero(x)

### AbstractThunk
project(::Type{T}, x::T, dx::AbstractThunk) where T = project(x, unthunk(dx))


### Number-types
project(::Type{T}, x::T, dx::T2) where {T<:Number, T2<:Number} = T(dx)
project(::Type{T}, x::T, dx::Complex) where {T<:Real} = T(real(dx))


# Arrays
Copy link
Member

@oxinabox oxinabox Jun 23, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a feeling we might need to support the same kwargs that similar does.
which iirc are size and eltype.
So you can say this AbstractArray type, except the the size should be this and the eltype that.

E.g.
for primal being WrapperArray{Tuple{Float64,Float64}}

project(typeof(primal), dx; eltype=Tangent{Tuple{Float64,Float64}}

to get back a WrapperArray{Tangent{Tuple{Float64,Float64}}}

And size would let us just store the size, rather than the whole primal type for zeros.

We might need a structual_zeros_from set_structural_zeros_by function or strutural_zero_inds list of indices, for if structural zero locations are only known by value (like for SparseCSC)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can see where you're coming from, but is it not both cleaner and more general to do the kind of thing that @CarloLucibello suggests?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these seem orthogonal to me.
Or at least that the closure @CarloLucibello suggest would be implemented on top of this.
So that it extracts these bits of information from the primal and passes them into the kwargs.
Rather than closing over everything

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these seem orthogonal to me.

Hmmm maybe I've misunderstood your reasoning for needing eltype and size kwargs then.

Are you saying we might need it because we would want to apply project for some primal other than the primal that the tangent is for?

i.e. let x and y be primals of differing sizes. Say that I've got a tangent dx for x, but for some reason I've only got access to y. So I say project(y, dx; size=size_of_x).

Is this the kind of situation you're imagining?

project(::Type{Array{T, N}}, x::Array{T, N}, dx::Array{T, N}) where {T<:Real, N} = dx

# for project([Foo(0.0), Foo(0.0)], [ZeroTangent(), ZeroTangent()])
project(::Type{<:Array{T}}, x::Array, dx::Array) where {T} = project.(Ref(T), x, dx)

# for project(rand(2, 2), Diagonal(rand(2)))
function project(::Type{T}, x::Array, dx::AbstractArray) where {T<:Array}
return project(T, x, collect(dx))
end

# for project([Foo(0.0), Foo(0.0)], ZeroTangent())
function project(::Type{<:Array{T}}, x::Array, dx::AbstractZero) where {T}
return project.(Ref(T), x, Ref(dx))
end


## Diagonal
function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::AbstractMatrix) where {V}
return Diagonal(project(V, diag(x), diag(dx)))
end
function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::Tangent) where {V}
return Diagonal(project(V, diag(x), dx.diag))
end
function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::AbstractZero) where {V}
return Diagonal(project(V, diag(x), dx))
end

function project(::Type{<:Tangent}, x::Diagonal, dx::Diagonal)
return Tangent{typeof(x)}(diag=diag(dx))
end
81 changes: 81 additions & 0 deletions test/projection.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
struct Foo
a::Float64
end

Base.zero(::Foo) = Foo(0.0)
Base.zero(::Type{Foo}) = "F0"

@testset "projection" begin

#identity
@test Foo(1.2) == project(Foo(-0.2), Foo(1.2))
@test 3.2 == project(1.0, 3.2)
@test 2.0 + 0.0im == project(1.0im, 2.0)

@testset "From AbstractZero" begin
@testset "to numbers" begin
@test 0.0 == project(1.1, ZeroTangent())
@test 0.0f0 == project(1.1f0, ZeroTangent())
end

@testset "to arrays (dense and structured)" begin
@test zeros(2, 2) == project([1.0 2; 3 4], ZeroTangent())
@test Diagonal(zeros(2)) == project(Diagonal([1.0, 4]), ZeroTangent())
@test Diagonal(zeros(ComplexF64, 2)) == project(Diagonal([1.0 + 0im, 4]), ZeroTangent())
end

@testset "to structs" begin
@test Foo(0.0) == project(Foo(3.2), ZeroTangent())
end

@testset "to arrays of structs" begin
@test [Foo(0.0), Foo(0.0)] == project([Foo(0.0), Foo(0.0)], ZeroTangent())
@test Diagonal([Foo(0.0), Foo(0.0)]) == project(Diagonal([Foo(3.2,), Foo(4.2)]), ZeroTangent())
end
end

@testset "From AbstractThunk" begin
@test 3.2 == project(1.0, @thunk(3.2))
@test Foo(3.2) == project(Foo(-0.2), @thunk(Foo(3.2)))
@test zeros(2) == project([1.0, 2.0], @thunk(ZeroTangent()))
@test Diagonal([Foo(0.0), Foo(0.0)]) == project(Diagonal([Foo(3.2,), Foo(4.2)]), @thunk(ZeroTangent()))
end

@testset "To number types" begin
@testset "to subset" begin
@test 3.2 == project(1.0, 3.2 + 3im)
@test 3.2f0 == project(1.0f0, 3.2)
@test 3.2f0 == project(1.0f0, 3.2 - 3im)
end

@testset "to superset" begin
@test 2.0 + 0.0im == project(2.0 + 1.0im, 2.0)
@test 2.0 == project(2.0, 2.0f0)
end
end

@testset "To Arrays" begin
# change eltype
@test [1.0 2.0; 3.0 4.0] == project(zeros(2, 2), [1.0 2.0; 3.0 4.0])
@test [1.0f0 2; 3 4] == project(zeros(Float32, 2, 2), [1.0 2; 3 4])

# from a structured array
@test [1.0 0; 0 4] == project(zeros(2, 2), Diagonal([1.0, 4]))

# from an array of specials
@test [Foo(0.0), Foo(0.0)] == project([Foo(0.0), Foo(0.0)], [ZeroTangent(), ZeroTangent()])
end

@testset "Diagonal" begin
d = Diagonal([1.0, 4.0])
t = Tangent{Diagonal}(;diag=[1.0, 4.0])
@test d == project(d, [1.0 2; 3 4])
@test d == project(d, t)
@test project(Tangent, d, d) isa Tangent

@test Diagonal([Foo(0.0), Foo(0.0)]) == project(Diagonal([Foo(3.2,), Foo(4.2)]), Diagonal([ZeroTangent(), ZeroTangent()]))
@test Diagonal([Foo(0.0), Foo(0.0)]) == project(Diagonal([Foo(3.2,), Foo(4.2)]), @thunk(ZeroTangent()))
end

# how to project to Upper/Lower Symmetric
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ using Test
end

include("accumulation.jl")
include("projection.jl")

include("rules.jl")
include("rule_definition_tools.jl")
Expand Down