Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Sketch project implementation #306

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ include("differentials/composite.jl")

include("differential_arithmetic.jl")
include("accumulation.jl")
include("projection.jl")

include("rules.jl")
include("rule_definition_tools.jl")
Expand Down
55 changes: 55 additions & 0 deletions src/projection.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
using LinearAlgebra: Diagonal, diag

"""
project(T::Type, x, dx)
Copy link
Member

@mzgubic mzgubic Jun 16, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we make T optional to avoid duplication when it is the same as typeof(x)?

Perhaps we could drop T entirely and say that project will map the differential into a canonical form, and we decide what the canonical form is?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm yeah -- we'd need to privelege a particular representation (presumably Tangent), but I think that this would make sense.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it necessary to close over x? I was thinking elsewhere that the right pattern is

info = preproject(x)
...
dx = project(dx_raw, info...)

with

preproject(x) = ()
preproject(x::Number) = (typeof(x),)
preproject(x::AbstractArray) = (typeof(x), axes(x))

This lets you pass information besides the type when that's required. (I think the motivating case needed a reshape.)

Copy link
Member Author

@willtebbutt willtebbutt Jun 16, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that you could do this, but since we'll be getting opaque closures in 1.7, can we just assume that we'll get to use them, and therefore not have to worry about optimising this stuff away manually? (It's entirely possible that I'm overestimating the compiler's ability to optimise stuff away here though)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh that's true. What is the status BTW, I haven't kept track, had an idea they existed but were still slow?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pretty much any Tangent -> AbstractArray conversion needs size info I think, you probably need to know about the locations of non-zero elements in sparse arrays, that kind of thing. Off the top of my head, I can't think of others though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re sparse arrays... is it obvious how many structural zeros gradient(f, sinpi.(sparse(0:0.5:2))) should have? Maybe there's an issue discussing this but it doesn't super-obvious to me; perhaps there are more realistic examples than mine which would clarify what should happen. (I don't mean structurally sparse like Diagonal, obviously, I mean SparseArrays.)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My claim is that it should have precisely the same number as the primal -- i.e. zeros should be treated structurally (it's what I argue for in the Abstract Arrays Dilemma docs). This is something that we need to reach a consensus on though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, but a bigger question than this mini-thread. Are there other clear, concrete, examples?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed. Off the top of my head, I can't think of any other obvious ones. Possibly it's going to be limited to array-like things, because they're pretty much the only case in which we need to transform between Tangents and array-like representations of tangents.


"project" `dx` onto type `T` such that it is the same size as `x`.

It's necessary to have `x` to ensure that it's possible to project e.g. `AbstractZero`s
onto `Array`s -- this wouldn't be possible with type information alone because the neither
`AbstractZero`s nor `T` know what size of `Array` to produce.
"""
function project end

# Number-types
project(::Type{T}, x::T, dx::T) where {T<:Real} = dx

project(::Type{T}, x::T, dx::AbstractZero) where {T<:Real} = zero(x)

project(::Type{T}, x::T, dx::AbstractThunk) where {T<:Real} = project(x, unthunk(dx))



# Arrays
project(::Type{Array{T, N}}, x::Array{T, N}, dx::Array{T, N}) where {T<:Real, N} = dx

project(::Type{<:Array{T}}, x::Array, dx::Array) where {T} = project.(Ref(T), x, dx)

function project(::Type{T}, x::Array, dx::AbstractArray) where {T<:Array}
return project(T, x, collect(dx))
end

function project(::Type{<:Array{T}}, x::Array, dx::AbstractZero) where {T}
return project.(Ref(T), x, Ref(dx))
end



# Diagonal
function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::AbstractMatrix) where {V}
return Diagonal(project(V, diag(x), diag(dx)))
end

function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::Composite) where {V}
return Diagonal(project(V, diag(x), dx.diag))
end

function project(::Type{<:Composite}, x::Diagonal, dx::Diagonal)
return Composite{typeof(x)}(diag=diag(dx))
end



# One use for this functionality is to make it easy to define addition between two different
# representations of the same tangent. This also makes it clear that the
Base.:(+)(x::Composite{<:Diagonal}, y::Diagonal) = x + project(typeof(x), x, y)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't primal + tangent return a primal?
I.e. as in

julia> struct Foo
           a::Float64
       end

julia> foo = Foo(2.3)
Foo(2.3)

julia> tfoo = Tangent{Foo}(;a=3.2)
Tangent{Foo}(a = 3.2,)

julia> foo + tfoo
Foo(5.5)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have we found a occurance where + between tangents, and expodential map (which we call primal + tangent)
might disagree?

I think they do actually agree, but I would need to think more to be sure.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that I don't think we've found a concrete example where the exponential map doesn't agree with + in our existing code, but it's sufficiently straightforward to concoct such examples, that I don't want to do it (really type whose instantiations live in a subset of R^D don't make sense).

My inclination would be instead to ensure that we never need to add a primal to a (co)tangent inside an AD by ensuring that tangents are "projected" (still not sure that's the best name) onto a common type before we try to add them.

Possibly the more compelling reason to avoid ever calling + on a primal and tangent is to do with @mzgubic 's point: it's unclear how to define it, because it's not clear what the output ought to be. The useful thing to do if we were "accumulating" inside reverse-mode would be to output another (co)tangent, but the useful thing to do for users who want to gradient-ascend is to output a primal.

So my instinct is to avoid defining + between primals and (co)tangents, and instead define the exponential map to transform between spaces.

3 changes: 3 additions & 0 deletions test/projection.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
@testset "projection" begin

end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ using Test
end

include("accumulation.jl")
include("projection.jl")

include("ruleset_loading.jl")
include("rules.jl")
Expand Down