Commit: Add transposition rules for + and *
dougalm committed Nov 19, 2019
1 parent 6427a9e commit 161a0ae
Showing 4 changed files with 42 additions and 62 deletions.
77 changes: 21 additions & 56 deletions examples/ad-tests.dx
@@ -1,62 +1,27 @@
-f = lam x. x * x * x
-
-:p jvp f 1.0 1.0
-> 3.0
-
-:p jvp (lam x. jvp f x 1.0) 1.0 1.0
-> 6.0
-
-:p grad f 1.0
-> 3.0
-
-_, Nx = unpack range 3
-
-g x = for i::Nx. 3.0 * x * x
-
-:p jvp g 2.0 1.0
-> [12.0, 12.0, 12.0]
-
-g2 (x, y) = x * y
-
-:p grad g2 (1.0, 2.0)
-> (2.0, 1.0)
-
-xs = for i::Nx. real iota.i * 1.0
-
-arrFun c = for i::Nx. c
-
-:p let (_, pullback) = vjp arrFun 2.0
-   in pullback xs
-> 3.0
-
-:p (transpose vsum 1.5) :: Nx=>Real
-> [1.5, 1.5, 1.5]
-
-:p jvp vsum xs xs
-> 3.0
-
-:p transpose (lam x. for i. x.i) xs
-> [0.0, 1.0, 2.0]
+f :: Real --o Real
+f = llam x. x
+:p tlinear f 2.0
+> 2.0
+
+f :: Real --o Real
+f = llam x. y = x; y
+:p tlinear f 2.0
+> 2.0
+
+f :: Real --o Real
+f = llam x. x + x
+:p tlinear f 2.0
+> 4.0
+
+f :: Real --o Real
+f = llam x. y = 2.0 * x
+            3.0 * y + x
+:p tlinear f 1.0
+> 7.0
+
+f :: Real --o Real
+f = llam x. (2.0 + 3.0) * x
+:p tlinear f 1.0
+> 5.0
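As a plain-Python sanity check of the expected outputs above (not Dex; the helper name is made up): a linear map R --o R is just multiplication by some constant c, and its transpose is again multiplication by c, so the transpose can be recovered by probing the function at 1.0.

```python
def tlinear_scalar(f, ct):
    # A linear map R --o R is multiplication by a constant c, which we can
    # recover by evaluating at 1.0; its transpose is the same multiplication.
    c = f(1.0)
    return c * ct

# The test cases above:
assert tlinear_scalar(lambda x: x, 2.0) == 2.0                    # llam x. x
assert tlinear_scalar(lambda x: x + x, 2.0) == 4.0                # llam x. x + x
assert tlinear_scalar(lambda x: 3.0 * (2.0 * x) + x, 1.0) == 7.0  # llam x. y = 2.0*x; 3.0*y + x
assert tlinear_scalar(lambda x: (2.0 + 3.0) * x, 1.0) == 5.0      # llam x. (2.0+3.0) * x
```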
2 changes: 1 addition & 1 deletion makefile
@@ -29,7 +29,7 @@ tests: quine-tests quine-tests-interp repl-test stack-tests
 
 quine-tests: $(quine-test-targets)
 
-quine-tests-interp: runinterp-eval-tests runinterp-interp-tests
+quine-tests-interp: runinterp-eval-tests runinterp-ad-tests runinterp-interp-tests
 
 run-%: examples/%.dx
 	misc/check-quine $^ $(dex) script --lit --allow-errors
6 changes: 3 additions & 3 deletions prelude.dx
@@ -225,12 +225,12 @@ linearize f x0 = %linearize(lam x. f x, x0)
 jvp :: (a -> b) -> a -> a --o b
 jvp f x = llam t. snd (linearize f x) t
 
-linearTranspose :: (a --o b) --o b --o a
-linearTranspose = llam f ct. %linearTranspose(llam t. f t, ct)
+tlinear :: (a --o b) --o b --o a
@jessebett commented Nov 19, 2019:

@dougalm I have a question about the return type for this linear transpose that came up when I was talking with @duvenaud.

Isn't it true that the tangent space of a function is not, in general, equivalent to the output space? While they might be copies of the same space, they aren't identical.

I'm having a hard time finding the precise terms for spaces that are identical copies but not equal; e.g. a function f: a --o b would have tangent space T b.

Further, these tangent spaces depend on the point of linearization. That means the tangent space of an f: R^m --o R^n linearized at a might be a copy of R^n whose origin is f(a), while the tangent space of f linearized at a' is also a copy of R^n but with origin at f(a'). These spaces are properly defined as the sets R^n_a = {(a, v) | v \in R^n}, which is **disjoint** from the tangent space R^n_{a'} = {(a', v) | v \in R^n}.

I'm curious if you think these aren't useful distinctions to make, or if there are some other facts about tangents that allow you to simplify the type in the way you've done above.

In particular, even though the tangent spaces in this example are both copies of R^n, their disjointness means that, e.g., operations between elements of the tangent space developed at a are not defined with elements of the tangent space developed at a'.

However, in the type you've given above you drop both the distinction between b and T b and further drop its dependence on the point of development.

I guess I'm wondering why the signature shouldn't be something like

tlinear :: (a --o b) --o T{a} b --o T{b} a

Where T{x} denotes the tangent (or cotangent) space developed at x. Note that I haven't thought ahead enough about how the cotangent space should behave in this type. Whether it should be dependent on just the b in T{a} b or both a and T{a} b explicitly.

Again, I'm not sure any of this is practically relevant. But if we're being very careful with types, does the type signature you've written above disallow, for instance, developing linearizations tangent to two separate points and then performing operations between their elements? Because that shouldn't be possible.

@jessebett commented Nov 19, 2019:

Possibly the cotangent dependence should correctly be:

tlinear :: (a --o b) --o T{a} b --o T{T{a} b} a

I'll have to chew more.

@dougalm (author) commented Nov 20, 2019:

Hey Jesse, thanks for taking an interest in this linear AD stuff. These are good questions!

It might help to distinguish a few different questions we could ask about types for linearization and transposition:

  1. Math - what are the mathematical structures we're modeling?
  2. Possible types - how could we encode these in the language(s) of types?
  3. Practical types - what's a good choice for a practical language?

Starting with (1), let's think about functions from manifolds to manifolds. Each point on a manifold has a corresponding tangent space. As you point out, the tangent spaces at each point are distinct, though they may all be isomorphic to each other. We can locally linearize the function at a particular point, yielding a linear function from the tangent space of the original function's domain to the tangent space of its codomain. This linear function has a transpose, or dual, from the dual vector space of the original linear function's codomain, to the dual vector space of the linear function's domain. Note there's one dual space for the whole vector space, not one for each point.

Now to questions (2) and (3). How could/should we write the types of these functions? As you point out, one challenge is that the type of the linear map depends on a value, the linearization point. To type this in general, you need dependent types. It's not just an academic problem. If your language treats shapes as values instead of types, then a function on vectors has a domain like R^0 + R^1 + R^2 + ..., and you really do need to be careful about using the same-length vector for your primals and your tangents. But I have no plans to pursue the path of value-dependent tangent spaces in Dex because (1) I'm trying to stick to a Hindley-Milner sweet spot and avoid dependent types and (2) Dex's shape types and non-emphasis on sum types make it less relevant.

Things are simpler if we limit ourselves to manifolds whose tangent spaces are all isomorphic to each other and we don't try to keep them distinct. Then we might write linearize like this.

linearize :: (Manifold a, Manifold b) => (a -> b) -> a -> (T a --o T b)

Where T a is a type-level function that gives the tangent space of a, which must satisfy the "is a manifold" constraint, Manifold a. Simplifying further, we could only consider manifolds that are also vector spaces, and identify a manifold with its tangent space. This is what we do in Autograd and JAX, and it's currently what we do on this branch (minus the explicit type class constraints, because the language doesn't have them yet) but that's only because we're starting simple. I think it's quite important to allow tangent spaces that are different from the primal space. It lets you do things like give a trivial (i.e. singleton set) tangent space to integers and bools. It's also important for treating function types, but that's its own can of worms.
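One way to picture the tangent-space-per-type idea (a hedged Python sketch with illustrative names, standing in for the type-class machinery described above): floats are their own tangent space, while integers and bools get the trivial singleton tangent space.

```python
def tangent_zero(x):
    # Floats: the tangent space is R, so the zero tangent is 0.0.
    if isinstance(x, float):
        return 0.0
    # Integers and bools: trivial singleton tangent space, represented as ().
    if isinstance(x, (bool, int)):
        return ()
    # Tuples: the tangent space is the product of componentwise tangent spaces.
    if isinstance(x, tuple):
        return tuple(tangent_zero(v) for v in x)
    raise TypeError(f"no tangent space defined for {type(x).__name__}")

assert tangent_zero(1.5) == 0.0
assert tangent_zero(3) == ()
assert tangent_zero((1.0, True)) == (0.0, ())
```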

Transpose is a little different. First, we're only dealing with linear functions, so there's no need to construct a tangent space since we already have a vector space. The thing we need to worry about for transposition is dual vector spaces. Mathematically, the dual of a vector space is just the space of linear maps from the vector space to reals. That suggests writing the type of transpose like this:

transpose :: (VectorSpace a, VectorSpace b) =>
    (a --o b) --o (b --o Real) --o (a --o Real)

It's even easy to implement:

transpose f ct = llam x. ct (f x)

But it misses the point, which is that it's much more efficient to represent covectors as vectors in the original space (or something similar). That's how we get the efficiency of reverse-mode differentiation. So we actually want a function that works on these efficient representations of covectors.

transpose :: (VectorSpace a, VectorSpace b) => (a --o b) --o CV b --o CV a

Where CV a means the type of the efficient representation of the covector a --o Real. For example, we might say CV Real = Real, CV (a, b) = (CV a, CV b) etc.
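To make the efficiency point concrete, here is a hedged Python sketch (illustrative names only, assuming CV R^n = R^n) contrasting the two representations: a covector can act as a linear functional via dot product, but transposing a matrix-vector map can stay entirely in the efficient vector representation.

```python
def as_functional(cv_vec):
    # The "honest" representation: CV b -> (b --o Real), acting by dot product.
    return lambda v: sum(c, * 0) if False else sum(c * x for c, x in zip(cv_vec, v))

def as_functional(cv_vec):
    # The "honest" representation: CV b -> (b --o Real), acting by dot product.
    return lambda v: sum(c * x for c, x in zip(cv_vec, v))

def transpose_matvec(matrix, cv_vec):
    # Transposing v -> A v on the efficient representation is just a
    # multiplication by the transposed matrix: CV b -> CV a.
    n_cols = len(matrix[0])
    return [sum(matrix[i][j] * cv_vec[i] for i in range(len(matrix)))
            for j in range(n_cols)]

A = [[1.0, 2.0], [3.0, 4.0]]
ct = [1.0, 1.0]
x = [5.0, 7.0]
Ax = [sum(A[i][j] * x[j] for j in range(2)) for i in range(2)]
# <A x, ct> == <x, A^T ct>: the two representations agree.
assert as_functional(ct)(Ax) == as_functional(transpose_matvec(A, ct))(x)
```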

Finally, even if we represent covectors efficiently by their corresponding vectors, we still might want to maintain a type-level distinction between vectors and covectors to avoid getting them confused. It's interesting that physics notation often goes to great lengths to keep vectors and covectors separate (bra-ket notation, upper/lower indices and so on) which suggests it's important. It's definitely worth thinking about.

@oxinabox (contributor) commented Nov 20, 2019:

This is a fantastic discussion.
My thoughts on this come in from ChainRules.jl and ChainRulesCore.jl.

> Simplifying further, we could only consider manifolds that are also vector spaces, and identify a manifold with its tangent space. This is what we do in Autograd and JAX, and it's currently what we do on this branch (minus the explicit type class constraints, because the language doesn't have them yet) but that's only because we're starting simple.

I need words for this, and for when it is possible. Is it literally just vector spaces? I guess it is. All the relaxing of those rules for cotangent types (differentials) that I have been thinking about lately for ChainRules.jl is predicated on the fact that operations on these types don't have to return the same type, just one that is also a valid differential type for the primal type. But if you insist that the cotangent type is always the primal type, that all goes away, and you end up back at vector spaces.

> I think it's quite important to allow tangent spaces that are different from the primal space. It lets you do things like give a trivial (i.e. singleton set) tangent space to integers and bools. It's also important for treating function types, but that's its own can of worms.

My pet example for this: if your primal type is DateTime, your differential needs to be a period type, e.g. Millisecond, Minute, etc.

The fact that the trivial differential type for non-differentiable types is a singleton set is excellent. How have I not run into this idea before? It's very natural, given that differentials need to be almost additive groups. (My latest refinement: the set of all differential types valid for a primal type needs to be an additive group, and that can be just the zero.)

> Finally, even if we represent covectors efficiently by their corresponding vectors, we still might want to maintain a type-level distinction between vectors and covectors to avoid getting them confused. It's interesting that physics notation often goes to great lengths to keep vectors and covectors separate (bra-ket notation, upper/lower indices and so on) which suggests it's important. It's definitely worth thinking about.

That would be cool and interesting to see.

@dougalm (author) commented Nov 20, 2019:

> if your primal type is DateTime your differential needs to be a period type, e.g. Millisecond, Minute, etc.

Yes, lovely example.

> The fact that the trivial differential type for non-differentiable types is a singleton set is excellent. How have I not run into this idea before?

I can't take credit for it! I think I first heard it from @axch. The Swift AD folks (@rxwei, @dan-zheng) are also doing some really nice work in this direction.

@jessebett commented Nov 20, 2019:

@dougalm thanks very much for this detailed reply. It clarifies a lot of things for me and points me in the direction to review some more things I still don't quite understand.

> Note there's one dual space for the whole vector space, not one for each point.

I suspect this is critical to my misunderstanding, so I will go and review dual spaces.

I agree with @oxinabox that I'd love to see some justification for whether the separation of vectors and covectors is a practical concern here (reminds me a bit of Julia's longest issue thread, JuliaLang/julia#4774. Maybe time for a Taking Function Transposes Seriously thread?).

I know @jrevels had some thoughts about the upper/lower indices in Ricci Calculus notation. And Computing Higher Order Derivatives of Matrix and Tensor Expressions from last NeurIPS suggests it might be practically important at least in higher derivatives (which as you know is interesting to me). Am I correct that these are related to the question of covector type?

@oxinabox I loved the DateTime example when you first described it on the Julia slack; very glad you brought it out here. Also, does the singleton set being the differential type for non-differentiables answer that discussion you were having before, where you thought it could be nothing (I actually don't remember where you landed)? Is that a contradiction?

@oxinabox (contributor) commented Nov 20, 2019:

Yeah, the singleton set thing is great.

  • ChainRules has two of them in that category: DNE (throws an error if it makes it to the gradient descent stage) and Zero (a hard zero).
  • Zygote calls all of these things nothing; going to change that with more ChainRules integration.
  • I need to write a bunch of this down and normalize terminology; it's mostly scattered through issues, and nowhere has it all written down.
+tlinear = llam f ct. %linearTranspose(llam t. f t, ct)
 
 vjp :: (a -> b) -> a -> (b, b --o a)
 vjp f x = (y, df) = linearize f x
-          (y, linearTranspose df)
+          (y, tlinear df)
 
 grad :: (a -> Real) -> a -> a
 grad f x = (_, pullback) = vjp f x
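The way vjp and grad are assembled from linearize and tlinear above can be mimicked numerically in plain Python (a sketch under loose assumptions: scalar functions only, finite differences standing in for %linearize, all names illustrative):

```python
def linearize(f, x, eps=1e-6):
    # Numeric stand-in for Dex's %linearize: returns (f x, a jvp at x).
    y = f(x)
    return y, lambda t: (f(x + eps * t) - y) / eps

def transpose_scalar(df, ct):
    # For a linear map R --o R, the transpose is the same scalar multiply.
    return df(1.0) * ct

def vjp(f, x):
    y, df = linearize(f, x)
    return y, lambda ct: transpose_scalar(df, ct)

def grad(f, x):
    _, pullback = vjp(f, x)
    return pullback(1.0)

# d/dx x^3 at 1.0 is 3.0 (up to finite-difference error).
assert abs(grad(lambda x: x * x * x, 1.0) - 3.0) < 1e-3
y, pullback = vjp(lambda x: x * x * x, 2.0)
assert y == 8.0 and abs(pullback(1.0) - 12.0) < 1e-3
```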
19 changes: 17 additions & 2 deletions src/lib/Interpreter.hs
@@ -198,13 +198,25 @@ transpose :: Val -> Expr -> CotangentVals
 transpose ct expr = case expr of
   Lit _ -> mempty
   Var v _ -> MonMap $ M.singleton v [ct]
-  PrimOp op ts xs -> error "todo"
+  PrimOp op ts xs -> transposeOp op ct ts xs
   Decl (LetMono p rhs) body
     | hasFVs rhs -> cts <> transpose ct' rhs
        where (ct', cts) = sepCotangents p $ transpose ct body
   App (Lam _ p body) e2 -> transpose ct (Decl (LetMono p e2) body)
+  App e1 e2
+    | hasFVs e2 -> cts <> transpose ct' e2
+       where
+         (Lam _ p body) = reduce e1
+         (ct', cts) = sepCotangents p $ transpose ct body
   _ -> error $ "Can't transpose in interpreter: " ++ pprint expr
 
+transposeOp :: Builtin -> Val -> [Type] -> [Val] -> CotangentVals
+transposeOp op ct ts xs = case (op, ts, xs) of
+  (FAdd, _, ~[x1, x2]) -> transpose ct x1 <> transpose ct x2
+  (FMul, _, ~[x1, x2]) | hasFVs x2 -> let ct' = mul ct (reduce x1)
+                                      in transpose ct' x2
+                       | otherwise -> let ct' = mul ct (reduce x2)
+                                      in transpose ct' x1
 
 hasFVs :: Expr -> Bool
 hasFVs expr = not $ null $ envNames $ freeVars expr
@@ -220,6 +232,9 @@ sepCotangents p vs = (recTreeToVal tree, cts)
     put s'
     return x
 
+mul :: Val -> Val -> Val
+mul x y = realBinOp (*) [x, y]
+
 recTreeToVal :: RecTree Val -> Val
 recTreeToVal (RecLeaf v) = v
 recTreeToVal (RecTree r) = RecCon Cart $ fmap recTreeToVal r
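The two rules transposeOp adds can be sketched in plain Python over a tiny expression language (a hedged stand-in for the interpreter, with made-up constructors): transposing `+` fans the cotangent out to both arguments, and transposing multiplication by a constant scales the cotangent by that constant.

```python
def transpose_expr(ct, expr):
    """Accumulate cotangents for the free variables of a linear expression.
    expr is ('var', name) | ('lit', c) | ('add', e1, e2) | ('mul', c, e)."""
    tag = expr[0]
    if tag == 'var':
        return {expr[1]: ct}          # a variable receives the cotangent
    if tag == 'lit':
        return {}                     # constants receive nothing
    if tag == 'add':
        # FAdd rule: fan the cotangent out to both arguments and merge.
        cts1 = transpose_expr(ct, expr[1])
        cts2 = transpose_expr(ct, expr[2])
        return {v: cts1.get(v, 0.0) + cts2.get(v, 0.0)
                for v in set(cts1) | set(cts2)}
    if tag == 'mul':
        # FMul rule (constant * linear-expr): scale the cotangent.
        return transpose_expr(ct * expr[1], expr[2])
    raise ValueError(tag)

# f = llam x. 3.0 * (2.0 * x) + x   ==>   tlinear f 1.0 == 7.0
f_body = ('add', ('mul', 3.0, ('mul', 2.0, ('var', 'x'))), ('var', 'x'))
assert transpose_expr(1.0, f_body)['x'] == 7.0
```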

0 comments on commit 161a0ae
