Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Sketch project implementation #306

Closed
wants to merge 1 commit into from
Closed

Conversation

willtebbutt
Copy link
Member

This is a sketch of my proposal in #286, so that we've got something concrete to base our discussions upon.

I'm not sure that project is a good name for the function, but I think that it might roughly correspond to a projection the usual sense, so maybe it's fine 🤷

It's slightly frustrating, but it appears that we'll need to carry around some additional data about the primal / a related tangent to ensure that e.g. size information is available when projecting onto an Array from an AbstractZero. We've seen this before under a different guise with to_vec.

You can potentially use this kind of functionality to

  1. make rule-implementer's lives easier, by enabling them to convert whatever tangent they get into their preferred type
  2. provide helper functionality to e.g. Flux to ensure that users get "the types that they expect"
  3. implement + between different representations of tangents, as shown at the bottom of src/projection.jl.

Note: I don't know how much time I'm going to have to do this properly, so if someone wants to pick this up, please feel free!

@codecov-io
Copy link

codecov-io commented Feb 24, 2021

Codecov Report

Merging #306 (8ae1457) into master (f81ccda) will decrease coverage by 2.93%.
The diff coverage is 0.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #306      +/-   ##
==========================================
- Coverage   89.40%   86.47%   -2.94%     
==========================================
  Files          13       14       +1     
  Lines         472      488      +16     
==========================================
  Hits          422      422              
- Misses         50       66      +16     
Impacted Files Coverage Δ
src/ChainRulesCore.jl 100.00% <ø> (ø)
src/projection.jl 0.00% <0.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update f81ccda...8ae1457. Read the comment docs.


# One use for this functionality is to make it easy to define addition between two different
# representations of the same tangent. This also makes it clear that the
Base.:(+)(x::Composite{<:Diagonal}, y::Diagonal) = x + project(typeof(x), x, y)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't primal + tangent return a primal?
I.e. as in

julia> struct Foo
           a::Float64
       end

julia> foo = Foo(2.3)
Foo(2.3)

julia> tfoo = Tangent{Foo}(;a=3.2)
Tangent{Foo}(a = 3.2,)

julia> foo + tfoo
Foo(5.5)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have we found a occurance where + between tangents, and expodential map (which we call primal + tangent)
might disagree?

I think they do actually agree, but I would need to think more to be sure.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that I don't think we've found a concrete example where the exponential map doesn't agree with + in our existing code, but it's sufficiently straightforward to concoct such examples, that I don't want to do it (really type whose instantiations live in a subset of R^D don't make sense).

My inclination would be instead to ensure that we never need to add a primal to a (co)tangent inside an AD by ensuring that tangents are "projected" (still not sure that's the best name) onto a common type before we try to add them.

Possibly the more compelling reason to avoid ever calling + on a primal and tangent is to do with @mzgubic 's point: it's unclear how to define it, because it's not clear what the output ought to be. The useful thing to do if we were "accumulating" inside reverse-mode would be to output another (co)tangent, but the useful thing to do for users who want to gradient-ascend is to output a primal.

So my instinct is to avoid defining + between primals and (co)tangents, and instead define the exponential map to transform between spaces.

using LinearAlgebra: Diagonal, diag

"""
project(T::Type, x, dx)
Copy link
Member

@mzgubic mzgubic Jun 16, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we make T optional to avoid duplication when it is the same as typeof(x)?

Perhaps we could drop T entirely and say that project will map the differential into a canonical form, and we decide what the canonical form is?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm yeah -- we'd need to privelege a particular representation (presumably Tangent), but I think that this would make sense.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it necessary to close over x? I was thinking elsewhere that the right pattern is

info = preproject(x)
...
dx = project(dx_raw, info...)

with

preproject(x) = ()
preproject(x::Number) = (typeof(x),)
preproject(x::AbstractArray) = (typeof(x), axes(x))

This lets you pass information besides the type when that's required. (I think the motivating case needed a reshape.)

Copy link
Member Author

@willtebbutt willtebbutt Jun 16, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that you could do this, but since we'll be getting opaque closures in 1.7, can we just assume that we'll get to use them, and therefore not have to worry about optimising this stuff away manually? (It's entirely possible that I'm overestimating the compiler's ability to optimise stuff away here though)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh that's true. What is the status BTW, I haven't kept track, had an idea they existed but were still slow?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pretty much any Tangent -> AbstractArray conversion needs size info I think, you probably need to know about the locations of non-zero elements in sparse arrays, that kind of thing. Off the top of my head, I can't think of others though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re sparse arrays... is it obvious how many structural zeros gradient(f, sinpi.(sparse(0:0.5:2))) should have? Maybe there's an issue discussing this but it doesn't super-obvious to me; perhaps there are more realistic examples than mine which would clarify what should happen. (I don't mean structurally sparse like Diagonal, obviously, I mean SparseArrays.)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My claim is that it should have precisely the same number as the primal -- i.e. zeros should be treated structurally (it's what I argue for in the Abstract Arrays Dilemma docs). This is something that we need to reach a consensus on though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, but a bigger question than this mini-thread. Are there other clear, concrete, examples?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed. Off the top of my head, I can't think of any other obvious ones. Possibly it's going to be limited to array-like things, because they're pretty much the only case in which we need to transform between Tangents and array-like representations of tangents.

@mzgubic mzgubic mentioned this pull request Jun 22, 2021
2 tasks
@mzgubic
Copy link
Member

mzgubic commented Jul 6, 2021

cosed in favour of #385

@mzgubic mzgubic closed this Jul 6, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants