Skip to content

Commit

Permalink
delete perturbed
Browse files Browse the repository at this point in the history
  • Loading branch information
sritchie committed Aug 8, 2024
1 parent b80a0f4 commit 80c34fc
Show file tree
Hide file tree
Showing 18 changed files with 559 additions and 572 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@

## [unreleased]

- #181:

- moves the generic implementations for `TapeCell` and `Dual` to `emmy.autodiff`

- moves `emmy.calculus.derivative` to `emmy.dual/derivative`

- removes `emmy.dual/perturbed?` from `IPerturbed`, as this is no longer used.

- #180 renames `emmy.differential` to `emmy.dual`, since the file now contains a
proper dual number implementation, not a truncated multivariate power series.

Expand Down
13 changes: 7 additions & 6 deletions src/emmy/abstract/function.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
(:refer-clojure :exclude [name])
(:require #?(:clj [clojure.pprint :as pprint])
[emmy.abstract.number :as an]
[emmy.dual :as d]
[emmy.autodiff :as ad]
[emmy.dual :as dual]
[emmy.function :as f]
[emmy.generic :as g]
[emmy.matrix :as m]
Expand Down Expand Up @@ -268,9 +269,9 @@
[f primal-s tag]
(fn
([] 0)
([tangent] (d/bundle-element (apply f primal-s) tangent tag))
([tangent] (dual/bundle-element (apply f primal-s) tangent tag))
([tangent [x path _]]
(let [dx (d/tangent x tag)]
(let [dx (dual/tangent x tag)]
(if (g/numeric-zero? dx)
tangent
(let [partial (literal-partial f path)]
Expand Down Expand Up @@ -318,9 +319,9 @@
that input."
[f s tag dx]
(let [fold-fn (cond (tape/tape? dx) reverse-mode-fold
(d/dual? dx) forward-mode-fold
(dual/dual? dx) forward-mode-fold
:else (u/illegal "No tape or differential inputs."))
primal-s (s/mapr (fn [x] (tape/primal-of x tag)) s)]
primal-s (s/mapr (fn [x] (ad/primal-of x tag)) s)]
(s/fold-chain (fold-fn f primal-s tag) s)))

(defn- check-argument-type
Expand Down Expand Up @@ -355,7 +356,7 @@
(if-let [[tag dx] (s/fold-chain
(fn
([] [])
([acc] (apply tape/tag+perturbation acc))
([acc] (apply ad/tag+perturbation acc))
([acc [d]] (conj acc d)))
s)]
(literal-derivative f s tag dx)
Expand Down
Loading

0 comments on commit 80c34fc

Please sign in to comment.