-
Notifications
You must be signed in to change notification settings - Fork 62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP: Sketch project implementation #306
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
using LinearAlgebra: Diagonal, diag | ||
|
||
""" | ||
project(T::Type, x, dx) | ||
|
||
"project" `dx` onto type `T` such that it is the same size as `x`. | ||
|
||
It's necessary to have `x` to ensure that it's possible to project e.g. `AbstractZero`s | ||
onto `Array`s -- this wouldn't be possible with type information alone because the neither | ||
`AbstractZero`s nor `T` know what size of `Array` to produce. | ||
""" | ||
function project end | ||
|
||
# Number-types | ||
project(::Type{T}, x::T, dx::T) where {T<:Real} = dx | ||
|
||
project(::Type{T}, x::T, dx::AbstractZero) where {T<:Real} = zero(x) | ||
|
||
project(::Type{T}, x::T, dx::AbstractThunk) where {T<:Real} = project(x, unthunk(dx)) | ||
|
||
|
||
|
||
# Arrays | ||
project(::Type{Array{T, N}}, x::Array{T, N}, dx::Array{T, N}) where {T<:Real, N} = dx | ||
|
||
project(::Type{<:Array{T}}, x::Array, dx::Array) where {T} = project.(Ref(T), x, dx) | ||
|
||
function project(::Type{T}, x::Array, dx::AbstractArray) where {T<:Array} | ||
return project(T, x, collect(dx)) | ||
end | ||
|
||
function project(::Type{<:Array{T}}, x::Array, dx::AbstractZero) where {T} | ||
return project.(Ref(T), x, Ref(dx)) | ||
end | ||
|
||
|
||
|
||
# Diagonal | ||
function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::AbstractMatrix) where {V} | ||
return Diagonal(project(V, diag(x), diag(dx))) | ||
end | ||
|
||
function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::Composite) where {V} | ||
return Diagonal(project(V, diag(x), dx.diag)) | ||
end | ||
|
||
function project(::Type{<:Composite}, x::Diagonal, dx::Diagonal) | ||
return Composite{typeof(x)}(diag=diag(dx)) | ||
end | ||
|
||
|
||
|
||
# One use for this functionality is to make it easy to define addition between two different | ||
# representations of the same tangent. This also makes it clear that the | ||
Base.:(+)(x::Composite{<:Diagonal}, y::Diagonal) = x + project(typeof(x), x, y) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't julia> struct Foo
a::Float64
end
julia> foo = Foo(2.3)
Foo(2.3)
julia> tfoo = Tangent{Foo}(;a=3.2)
Tangent{Foo}(a = 3.2,)
julia> foo + tfoo
Foo(5.5)
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Have we found a occurance where + between tangents, and expodential map (which we call primal + tangent) I think they do actually agree, but I would need to think more to be sure. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree that I don't think we've found a concrete example where the exponential map doesn't agree with My inclination would be instead to ensure that we never need to add a primal to a (co)tangent inside an AD by ensuring that tangents are "projected" (still not sure that's the best name) onto a common type before we try to add them. Possibly the more compelling reason to avoid ever calling So my instinct is to avoid defining |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
@testset "projection" begin | ||
|
||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we make
T
optional to avoid duplication when it is the same astypeof(x)
?Perhaps we could drop
T
entirely and say thatproject
will map the differential into a canonical form, and we decide what the canonical form is?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm yeah -- we'd need to privelege a particular representation (presumably
Tangent
), but I think that this would make sense.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it necessary to close over
x
? I was thinking elsewhere that the right pattern iswith
This lets you pass information besides the type when that's required. (I think the motivating case needed a reshape.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree that you could do this, but since we'll be getting opaque closures in 1.7, can we just assume that we'll get to use them, and therefore not have to worry about optimising this stuff away manually? (It's entirely possible that I'm overestimating the compiler's ability to optimise stuff away here though)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh that's true. What is the status BTW, I haven't kept track, had an idea they existed but were still slow?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pretty much any
Tangent -> AbstractArray
conversion needs size info I think, you probably need to know about the locations of non-zero elements in sparse arrays, that kind of thing. Off the top of my head, I can't think of others though.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Re sparse arrays... is it obvious how many structural zeros
gradient(f, sinpi.(sparse(0:0.5:2)))
should have? Maybe there's an issue discussing this but it doesn't super-obvious to me; perhaps there are more realistic examples than mine which would clarify what should happen. (I don't mean structurally sparse like Diagonal, obviously, I mean SparseArrays.)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My claim is that it should have precisely the same number as the primal -- i.e. zeros should be treated structurally (it's what I argue for in the Abstract Arrays Dilemma docs). This is something that we need to reach a consensus on though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, but a bigger question than this mini-thread. Are there other clear, concrete, examples?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indeed. Off the top of my head, I can't think of any other obvious ones. Possibly it's going to be limited to array-like things, because they're pretty much the only case in which we need to transform between
Tangent
s and array-like representations of tangents.