Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support all output types in reverse-mode #169

Closed
wants to merge 18 commits into from
9 changes: 7 additions & 2 deletions .dir-locals.el
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,10 @@
(cider-clojure-cli-aliases . ":test:cljs:nextjournal/clerk:dev")

;; Custom indentation:
(eval . (put-clojure-indent 'sci-macro :defn))
(eval . (put-clojure-indent 'careful-def 1)))))
(eval . (progn
;; the require here avoids projectile errors:
;; "Symbol’s function definition is void: put-clojure-indent"
(require 'clojure-mode)
(require 'cider)
(put-clojure-indent 'sci-macro :defn)
(put-clojure-indent 'careful-def 1))))))
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,16 @@

## [unreleased]

- PR:

- all lifted function machinery and impls moves to `emmy.autodiff`
- `emmy.differential` moves to `emmy.dual`
- deletes perturbed?
- iperturbed impl for functions moves to `emmy.function`
- adds extract-id, makes reverse-mode work with lots of output types
- new `mode` arg for `extract-tangent`
- single `derivative` moves to `emmy.dual`

- #156:

- Makes forward- and reverse-mode automatic differentiation compatible with
Expand Down
23 changes: 11 additions & 12 deletions src/emmy/abstract/function.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
(:refer-clojure :exclude [name])
(:require #?(:clj [clojure.pprint :as pprint])
[emmy.abstract.number :as an]
[emmy.differential :as d]
[emmy.autodiff :as ad]
[emmy.dual :as dual]
[emmy.function :as f]
[emmy.generic :as g]
[emmy.matrix :as m]
Expand Down Expand Up @@ -260,17 +261,15 @@
- the `tag` of the innermost active derivative call

And returns a folding function (designed for use
with [[emmy.structure/fold-chain]]) that

generates a new [[emmy.differential/Dual]] by applying the chain rule and
summing the partial derivatives for each perturbed argument in the input
structure."
with [[emmy.structure/fold-chain]]) that generates a new [[emmy.dual/Dual]] by
applying the chain rule and summing the partial derivatives for each perturbed
argument in the input structure."
[f primal-s tag]
(fn
([] 0)
([tangent] (d/bundle-element (apply f primal-s) tangent tag))
([tangent] (dual/bundle-element (apply f primal-s) tangent tag))
([tangent [x path _]]
(let [dx (d/tangent x tag)]
(let [dx (dual/tangent x tag)]
(if (g/numeric-zero? dx)
tangent
(let [partial (literal-partial f path)]
Expand Down Expand Up @@ -309,7 +308,7 @@

and generates the proper return value for `((D f) xs)`.

In forward-mode AD this is a new [[emmy.differential/Dual]] generated by
In forward-mode AD this is a new [[emmy.dual/Dual]] generated by
applying the chain rule and summing the partial derivatives for each perturbed
argument in the input structure.

Expand All @@ -318,9 +317,9 @@
that input."
[f s tag dx]
(let [fold-fn (cond (tape/tape? dx) reverse-mode-fold
(d/dual? dx) forward-mode-fold
(dual/dual? dx) forward-mode-fold
:else (u/illegal "No tape or differential inputs."))
primal-s (s/mapr (fn [x] (tape/primal-of x tag)) s)]
primal-s (s/mapr (fn [x] (ad/primal-of x tag)) s)]
(s/fold-chain (fold-fn f primal-s tag) s)))

(defn- check-argument-type
Expand Down Expand Up @@ -355,7 +354,7 @@
(if-let [[tag dx] (s/fold-chain
(fn
([] [])
([acc] (apply tape/tag+perturbation acc))
([acc] (apply ad/tag+perturbation acc))
([acc [d]] (conj acc d)))
s)]
(literal-derivative f s tag dx)
Expand Down
10 changes: 5 additions & 5 deletions src/emmy/abstract/number.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,12 @@
(memoize g/simplify))

(defn ^:no-doc simplify-numerical-expression
"This function will only simplify instances of [[expression/Literal]]; if `x` is
of that type, [[simplify-numerical-expression]] acts as a memoized version
of [[generic/simplify]]. Else, acts as identity.
"This function will only simplify instances of [[emmy.expression/Literal]]; if
`x` is of that type, [[simplify-numerical-expression]] acts as a memoized
version of [[generic/simplify]]. Else, acts as identity.

This trick is used in [[emmy.calculus.manifold]] to memoize
simplification _only_ for non-[[differential/Differential]] types."
This trick is used in [[emmy.calculus.manifold]] to memoize simplification
_only_ for non-perturbed types."
[x]
(if (literal-number? x)
(memoized-simplify x)
Expand Down
Loading