-
Notifications
You must be signed in to change notification settings - Fork 62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP: project
implementation
#380
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
using LinearAlgebra: Diagonal, diag | ||
|
||
""" | ||
project([T::Type], x, dx) | ||
|
||
"project" `dx` onto type `T` such that it is the same size as `x`. If `T` is not provided, | ||
it is assumed to be the type of `x`. | ||
|
||
It's necessary to have `x` to ensure that it's possible to project e.g. `AbstractZero`s | ||
onto `Array`s -- this wouldn't be possible with type information alone because the neither | ||
`AbstractZero`s nor `T` know what size of `Array` to produce. | ||
""" | ||
function project end | ||
|
||
project(x, dx) = project(typeof(x), x, dx) | ||
|
||
# identity | ||
project(::Type{T}, x::T, dx::T) where T = dx | ||
|
||
### AbstractZero | ||
project(::Type{T}, x::T, dx::AbstractZero) where T = zero(x) | ||
|
||
### AbstractThunk | ||
project(::Type{T}, x::T, dx::AbstractThunk) where T = project(x, unthunk(dx)) | ||
|
||
|
||
### Number-types | ||
project(::Type{T}, x::T, dx::T2) where {T<:Number, T2<:Number} = T(dx) | ||
project(::Type{T}, x::T, dx::Complex) where {T<:Real} = T(real(dx)) | ||
|
||
|
||
# Arrays | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have a feeling we might need to support the same kwargs that E.g.
to get back a And We might need a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can see where you're coming from, but is it not both cleaner and more general to do the kind of thing that @CarloLucibello suggests? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. these seem orthogonal to me. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Hmmm maybe I've misunderstood your reasoning for needing Are you saying we might need it because we would want to apply i.e. let Is this the kind of situation you're imagining? |
||
project(::Type{Array{T, N}}, x::Array{T, N}, dx::Array{T, N}) where {T<:Real, N} = dx | ||
|
||
# for project([Foo(0.0), Foo(0.0)], [ZeroTangent(), ZeroTangent()]) | ||
project(::Type{<:Array{T}}, x::Array, dx::Array) where {T} = project.(Ref(T), x, dx) | ||
|
||
# for project(rand(2, 2), Diagonal(rand(2))) | ||
function project(::Type{T}, x::Array, dx::AbstractArray) where {T<:Array} | ||
return project(T, x, collect(dx)) | ||
end | ||
|
||
# for project([Foo(0.0), Foo(0.0)], ZeroTangent()) | ||
function project(::Type{<:Array{T}}, x::Array, dx::AbstractZero) where {T} | ||
return project.(Ref(T), x, Ref(dx)) | ||
end | ||
|
||
|
||
## Diagonal | ||
function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::AbstractMatrix) where {V} | ||
return Diagonal(project(V, diag(x), diag(dx))) | ||
end | ||
function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::Tangent) where {V} | ||
return Diagonal(project(V, diag(x), dx.diag)) | ||
end | ||
function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::AbstractZero) where {V} | ||
return Diagonal(project(V, diag(x), dx)) | ||
end | ||
|
||
function project(::Type{<:Tangent}, x::Diagonal, dx::Diagonal) | ||
return Tangent{typeof(x)}(diag=diag(dx)) | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
struct Foo | ||
a::Float64 | ||
end | ||
|
||
Base.zero(::Foo) = Foo(0.0) | ||
Base.zero(::Type{Foo}) = "F0" | ||
|
||
@testset "projection" begin | ||
|
||
#identity | ||
@test Foo(1.2) == project(Foo(-0.2), Foo(1.2)) | ||
@test 3.2 == project(1.0, 3.2) | ||
@test 2.0 + 0.0im == project(1.0im, 2.0) | ||
|
||
@testset "From AbstractZero" begin | ||
@testset "to numbers" begin | ||
@test 0.0 == project(1.1, ZeroTangent()) | ||
@test 0.0f0 == project(1.1f0, ZeroTangent()) | ||
end | ||
|
||
@testset "to arrays (dense and structured)" begin | ||
@test zeros(2, 2) == project([1.0 2; 3 4], ZeroTangent()) | ||
@test Diagonal(zeros(2)) == project(Diagonal([1.0, 4]), ZeroTangent()) | ||
@test Diagonal(zeros(ComplexF64, 2)) == project(Diagonal([1.0 + 0im, 4]), ZeroTangent()) | ||
end | ||
|
||
@testset "to structs" begin | ||
@test Foo(0.0) == project(Foo(3.2), ZeroTangent()) | ||
end | ||
|
||
@testset "to arrays of structs" begin | ||
@test [Foo(0.0), Foo(0.0)] == project([Foo(0.0), Foo(0.0)], ZeroTangent()) | ||
@test Diagonal([Foo(0.0), Foo(0.0)]) == project(Diagonal([Foo(3.2,), Foo(4.2)]), ZeroTangent()) | ||
end | ||
end | ||
|
||
@testset "From AbstractThunk" begin | ||
@test 3.2 == project(1.0, @thunk(3.2)) | ||
@test Foo(3.2) == project(Foo(-0.2), @thunk(Foo(3.2))) | ||
@test zeros(2) == project([1.0, 2.0], @thunk(ZeroTangent())) | ||
@test Diagonal([Foo(0.0), Foo(0.0)]) == project(Diagonal([Foo(3.2,), Foo(4.2)]), @thunk(ZeroTangent())) | ||
end | ||
|
||
@testset "To number types" begin | ||
@testset "to subset" begin | ||
@test 3.2 == project(1.0, 3.2 + 3im) | ||
@test 3.2f0 == project(1.0f0, 3.2) | ||
@test 3.2f0 == project(1.0f0, 3.2 - 3im) | ||
end | ||
|
||
@testset "to superset" begin | ||
@test 2.0 + 0.0im == project(2.0 + 1.0im, 2.0) | ||
@test 2.0 == project(2.0, 2.0f0) | ||
end | ||
end | ||
|
||
@testset "To Arrays" begin | ||
# change eltype | ||
@test [1.0 2.0; 3.0 4.0] == project(zeros(2, 2), [1.0 2.0; 3.0 4.0]) | ||
@test [1.0f0 2; 3 4] == project(zeros(Float32, 2, 2), [1.0 2; 3 4]) | ||
|
||
# from a structured array | ||
@test [1.0 0; 0 4] == project(zeros(2, 2), Diagonal([1.0, 4])) | ||
|
||
# from an array of specials | ||
@test [Foo(0.0), Foo(0.0)] == project([Foo(0.0), Foo(0.0)], [ZeroTangent(), ZeroTangent()]) | ||
end | ||
|
||
@testset "Diagonal" begin | ||
d = Diagonal([1.0, 4.0]) | ||
t = Tangent{Diagonal}(;diag=[1.0, 4.0]) | ||
@test d == project(d, [1.0 2; 3 4]) | ||
@test d == project(d, t) | ||
@test project(Tangent, d, d) isa Tangent | ||
|
||
@test Diagonal([Foo(0.0), Foo(0.0)]) == project(Diagonal([Foo(3.2,), Foo(4.2)]), Diagonal([ZeroTangent(), ZeroTangent()])) | ||
@test Diagonal([Foo(0.0), Foo(0.0)]) == project(Diagonal([Foo(3.2,), Foo(4.2)]), @thunk(ZeroTangent())) | ||
end | ||
|
||
# how to project to Upper/Lower Symmetric | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the point of
T
in addition tox
? It sounds like you have something in mind but I don't see what. In what directions may they differ? MayT
be abstract?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh T could be Tangent as in line 60
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But why? And when, I mean for which
x
would this be chosen somehow? Sounds a bit like an orthogonal feature.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My initial thought was to ditch
T
altogether, and privilege some tangent (probablyTangent
) representation as canonical.However, somewhere in the many threads about this, @oxinabox argued that we need to go both ways (to tangent and to natural differential say). e.g. if a rule author writes a rule only for
Diagonal
and wants to make sure that theTangent
is being transformed rightEdit: sorry, the other way, writes a rule for
Tangent
and wants to make sure all the tangents it gets areTangent
sThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, but this still seems a bit like it's rolling two things together. One is a projection to preserve the input's
x
's real-ness / Diagon-ality etc. This would be applied automatically to every single rule.The other is a way to massage certain
dx
s to a form digestible by certain rules. It's possible that there is some efficiency gained by fusing these processes into one function, but... for a first stab, it seems that combing them is adding complication?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, I don't think so.
As i see it:
In that use,
project
is one way that lets us write things that can run on anything, while getting out the type (and associated basis) that we want.But it's not the only way, and its not always the best way.
In fact i think it rarely is, since it often results in allocating and computing things that we are going going to throw away.
Its just that it is the most generic way.
Alternatives for example, are things like structure preserving maps and broadcasts.
Which avoid allocations and computing things they are going to throw away.
Though i guess in those cases you can still
project
after, but it will be a identity operation.Like
convert(Array, [1,2,3,])
is an identity operation.Which brings me around to the general thing:
This is like
convert
except that it is for vector spaces and is aware of vector space concerns, where are your zeros (which is same as basis)Which is the same convert-like operation one want to massage on the way in.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, that's my thinking. In my list of examples somewhere, any operation mixing real & complex arrays is something which ought to project. But making every rule do the ideal thing would be a lot of work for little benefit. Whereas projecting back to a real array for
hcat([1,2], [3+im, 4+im])
's gradient at least means it propagates no further.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah, anyway, point is it is our vector space aware
convert
so it is the right name for both the thing you (often) want to do at the end.But it is also the one you might want to do at the start