diff --git a/examples/ad-tests.dx b/examples/ad-tests.dx index 1cb8a16a1..c2ac318a5 100644 --- a/examples/ad-tests.dx +++ b/examples/ad-tests.dx @@ -1,62 +1,27 @@ -f = lam x. x * x * x -:p jvp f 1.0 1.0 +:p f :: Real --o Real + f = llam x. x + tlinear f 2.0 +> 2.0 -> 3.0 +:p f :: Real --o Real + f = llam x. y = x; y + tlinear f 2.0 +> 2.0 +:p f :: Real --o Real + f = llam x. x + x + tlinear f 2.0 +> 4.0 -:p jvp (lam x. jvp f x 1.0) 1.0 1.0 +:p f :: Real --o Real + f = llam x. y = 2.0 * x + 3.0 * y + x + tlinear f 1.0 +> 7.0 -> 6.0 - - -:p grad f 1.0 - -> 3.0 - - -_, Nx = unpack range 3 - - -g x = for i::Nx. 3.0 * x * x - - -:p jvp g 2.0 1.0 - -> [12.0, 12.0, 12.0] - - -g2 (x, y) = x * y - - -:p grad g2 (1.0, 2.0) - -> (2.0, 1.0) - - -xs = for i::Nx. real iota.i * 1.0 - - -arrFun c = for i::Nx. c - - -:p let (_, pullback) = vjp arrFun 2.0 - in pullback xs - -> 3.0 - - -:p (transpose vsum 1.5) :: Nx=>Real - -> [1.5, 1.5, 1.5] - - -:p jvp vsum xs xs - -> 3.0 - - -:p transpose (lam x. for i. x.i) xs - -> [0.0, 1.0, 2.0] +:p f :: Real --o Real + f = llam x. (2.0 + 3.0) * x + tlinear f 1.0 +> 5.0 diff --git a/makefile b/makefile index 846e13072..e1a91dffe 100644 --- a/makefile +++ b/makefile @@ -29,7 +29,7 @@ tests: quine-tests quine-tests-interp repl-test stack-tests quine-tests: $(quine-test-targets) -quine-tests-interp: runinterp-eval-tests runinterp-interp-tests +quine-tests-interp: runinterp-eval-tests runinterp-ad-tests runinterp-interp-tests run-%: examples/%.dx misc/check-quine $^ $(dex) script --lit --allow-errors diff --git a/prelude.dx b/prelude.dx index d4caf1fa8..393a3ab3d 100644 --- a/prelude.dx +++ b/prelude.dx @@ -225,12 +225,12 @@ linearize f x0 = %linearize(lam x. f x, x0) jvp :: (a -> b) -> a -> a --o b jvp f x = llam t. snd (linearize f x) t -linearTranspose :: (a --o b) --o b --o a -linearTranspose = llam f ct. %linearTranspose(llam t. f t, ct) +tlinear :: (a --o b) --o b --o a +tlinear = llam f ct. %linearTranspose(llam t. f t, ct) vjp :: (a -> b) -> a -> (b, b --o a) vjp f x = (y, df) = linearize f x - (y, linearTranspose df) + (y, tlinear df) grad :: (a -> Real) -> a -> a grad f x = (_, pullback) = vjp f x diff --git a/src/lib/Interpreter.hs b/src/lib/Interpreter.hs index 4c97bca76..edeb2109d 100644 --- a/src/lib/Interpreter.hs +++ b/src/lib/Interpreter.hs @@ -198,13 +198,25 @@ transpose :: Val -> Expr -> CotangentVals transpose ct expr = case expr of Lit _ -> mempty Var v _ -> MonMap $ M.singleton v [ct] - PrimOp op ts xs -> error "todo" + PrimOp op ts xs -> transposeOp op ct ts xs Decl (LetMono p rhs) body | hasFVs rhs -> cts <> transpose ct' rhs where (ct', cts) = sepCotangents p $ transpose ct body - App (Lam _ p body) e2 -> transpose ct (Decl (LetMono p e2) body) + App e1 e2 + | hasFVs e2 -> cts <> transpose ct' e2 + where + (Lam _ p body) = reduce e1 + (ct', cts) = sepCotangents p $ transpose ct body _ -> error $ "Can't transpose in interpreter: " ++ pprint expr +transposeOp :: Builtin -> Val -> [Type] -> [Val] -> CotangentVals +transposeOp op ct ts xs = case (op, ts, xs) of + (FAdd, _, ~[x1, x2]) -> transpose ct x1 <> transpose ct x2 + (FMul, _, ~[x1, x2]) | hasFVs x2 -> let ct' = mul ct (reduce x1) + in transpose ct' x2 + | otherwise -> let ct' = mul ct (reduce x2) + in transpose ct' x1 + hasFVs :: Expr -> Bool hasFVs expr = not $ null $ envNames $ freeVars expr @@ -220,6 +232,9 @@ sepCotangents p vs = (recTreeToVal tree, cts) put s' return x +mul :: Val -> Val -> Val +mul x y = realBinOp (*) [x, y] + recTreeToVal :: RecTree Val -> Val recTreeToVal (RecLeaf v) = v recTreeToVal (RecTree r) = RecCon Cart $ fmap recTreeToVal r