Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,62 +1,27 @@
-f = lam x. x * x * x
-
-
-:p jvp f 1.0 1.0
+:p f :: Real --o Real
+f = llam x. x
+tlinear f 2.0
+> 2.0
 
-> 3.0
+:p f :: Real --o Real
+f = llam x. y = x; y
+tlinear f 2.0
+> 2.0
 
+:p f :: Real --o Real
+f = llam x. x + x
+tlinear f 2.0
+> 4.0
 
-:p jvp (lam x. jvp f x 1.0) 1.0 1.0
+:p f :: Real --o Real
+f = llam x. y = 2.0 * x
+            3.0 * y + x
+tlinear f 1.0
+> 7.0
 
-> 6.0
-
-
-:p grad f 1.0
-
-> 3.0
-
-
-_, Nx = unpack range 3
-
-
-g x = for i::Nx. 3.0 * x * x
-
-
-:p jvp g 2.0 1.0
-
-> [12.0, 12.0, 12.0]
-
-
-g2 (x, y) = x * y
-
-
-:p grad g2 (1.0, 2.0)
-
-> (2.0, 1.0)
-
-
-xs = for i::Nx. real iota.i * 1.0
-
-
-arrFun c = for i::Nx. c
-
-
-:p let (_, pullback) = vjp arrFun 2.0
-   in pullback xs
-
-> 3.0
-
-
-:p (transpose vsum 1.5) :: Nx=>Real
-
-> [1.5, 1.5, 1.5]
-
-
-:p jvp vsum xs xs
-
-> 3.0
-
-
-:p transpose (lam x. for i. x.i) xs
-
-> [0.0, 1.0, 2.0]
+:p f :: Real --o Real
+f = llam x. (2.0 + 3.0) * x
+tlinear f 1.0
+> 5.0
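The expected outputs in the `tlinear` tests can be checked by hand, because the transpose of a linear map is the map given by the transposed matrix. Below is a minimal Python sketch of that idea. It is not the Dex implementation: `linear_transpose` and its basis-probing strategy are assumptions made for illustration, and vectors are represented as plain lists.

```python
def linear_transpose(f, n_in):
    """Transpose a linear map f: R^n_in -> R^m by probing basis vectors.

    Hypothetical helper for illustration; it assumes f really is linear.
    """
    # Column i of the matrix of f is f(e_i), where e_i is the i-th basis vector.
    columns = []
    for i in range(n_in):
        e_i = [1.0 if j == i else 0.0 for j in range(n_in)]
        columns.append(f(e_i))

    def f_transposed(ct):
        # Entry i of the result is the dot product of column i with the cotangent.
        return [sum(c * t for c, t in zip(col, ct)) for col in columns]

    return f_transposed


def double(x):
    # Mirrors: f = llam x. x + x
    return [x[0] + x[0]]


def seven(x):
    # Mirrors: f = llam x. y = 2.0 * x; 3.0 * y + x
    y = 2.0 * x[0]
    return [3.0 * y + x[0]]


def vsum(xs):
    # Sums all entries; its transpose broadcasts the cotangent.
    return [sum(xs)]


print(linear_transpose(double, 1)([2.0]))  # [4.0]
print(linear_transpose(seven, 1)([1.0]))   # [7.0]
print(linear_transpose(vsum, 3)([1.5]))    # [1.5, 1.5, 1.5]
```

For scalar maps the matrix is 1x1, so the transpose coincides with the map itself, which is why `tlinear f 2.0` for `x + x` gives `4.0`.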
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -225,12 +225,12 @@ linearize f x0 = %linearize(lam x. f x, x0)
 jvp :: (a -> b) -> a -> a --o b
 jvp f x = llam t. snd (linearize f x) t
 
-linearTranspose :: (a --o b) --o b --o a
-linearTranspose = llam f ct. %linearTranspose(llam t. f t, ct)
+tlinear :: (a --o b) --o b --o a
+tlinear = llam f ct. %linearTranspose(llam t. f t, ct)
 
 vjp :: (a -> b) -> a -> (b, b --o a)
 vjp f x = (y, df) = linearize f x
-          (y, linearTranspose df)
+          (y, tlinear df)
 
 grad :: (a -> Real) -> a -> a
 grad f x = (_, pullback) = vjp f x

jessebett
@dougalm I have a question about the return type for this linear transpose that came up when I was talking with @duvenaud.

Isn't it true that the tangent space of a function is not, in general, equivalent to the output space? They might be copies of the same space, but they aren't equivalent.

I'm having a hard time finding the precise terms for "identical copies but not equivalent". E.g. a function `f: a --o b` would have tangent space `T b`.

Further, these tangent spaces depend on the point of linearization. This means the tangent space of an `f: R^m -o R^n` linearized at `a` might be a copy of R^n whose origin is `f(a)`, and the tangent space of `f` linearized at `a'` is also a copy of R^n but with origin at `f(a')`. These spaces are properly defined as the sets `R_a^n = {(a,v) | v \in R^n}` and `R_a'^n = {(a',v) | v \in R^n}`, which are **disjoint**.

I'm curious if you think these aren't useful distinctions to make, or if there are some other facts about tangents that allow you to simplify the type in the way you've done above.
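
For reference, the textbook definitions behind this point can be written out as below. The symbols $T_a$, $df_a$ and $J_f$ are standard notation introduced here for illustration and do not appear in the commit. Note that in this notation the output tangent space is indexed by $f(a)$ rather than by $a$.

```latex
\[
\begin{aligned}
% Tangent space at a: vectors tagged with their base point.
T_a\mathbb{R}^m &= \{(a, v) \mid v \in \mathbb{R}^m\},
  \qquad T_a\mathbb{R}^m \cap T_{a'}\mathbb{R}^m = \emptyset \ \text{ for } a \neq a', \\
% The derivative of f : R^m -> R^n at a maps tangents at a to tangents at f(a).
df_a &: T_a\mathbb{R}^m \to T_{f(a)}\mathbb{R}^n,
  \qquad df_a(a, v) = \bigl(f(a),\, J_f(a)\, v\bigr), \\
% Its transpose maps cotangents at f(a) back to cotangents at a.
df_a^{\top} &: T^{*}_{f(a)}\mathbb{R}^n \to T^{*}_{a}\mathbb{R}^m .
\end{aligned}
\]
```

Each space is linearly isomorphic to $\mathbb{R}^m$ or $\mathbb{R}^n$ by forgetting the base point, which is the identification a type such as `b --o a` makes implicitly.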
In particular, even though the tangent spaces in this example are both copies of R^n, their disjointness means that, e.g., operations between elements of the tangent space developed at `a` and elements of the tangent space developed at `a'` are not defined.

However, in the type you've given above you drop both the distinction between `b` and `T b` and its dependence on the point of development.

I guess I'm wondering why the signature shouldn't be something like

Where `T{x}` denotes the tangent (or cotangent) space developed at `x`. Note that I haven't thought ahead enough about how the cotangent space should behave in this type: whether it should depend on just the `b` in `T{a} b`, or on both `a` and `T{a} b` explicitly.

Again, not sure any of this is practically relevant. But if we're being very careful with types, does the type signature you've written above disallow, for instance, developing linearizations tangent to two separate points and then performing operations between their elements? Because that shouldn't be possible.