Skip to content

Commit

Permalink
unify compile-fn and compile-state-fn; primitive mode for compiled fns (
Browse files Browse the repository at this point in the history
  • Loading branch information
littleredcomputer committed Apr 5, 2023
1 parent fe0e8bb commit ac6e444
Show file tree
Hide file tree
Showing 10 changed files with 563 additions and 283 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ public/
/.calva
/.lsp
/.cpcache
/.vscode
6 changes: 3 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@

- refactor JS rendering to allow compiler to use it

- adjust meaning of :native and :source compilation modes now you get
what's compatible with your execution environment but you can also
ask for a specific language, allowing tests to be bilingual
- adjust meaning of :native and :source compilation modes: now you get
what's compatible with your execution environment. You can also
ask for a specific language, allowing tests to be bilingual.

- #100:

Expand Down
525 changes: 342 additions & 183 deletions src/emmy/expression/compile.cljc

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions src/emmy/expression/cse.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
;; The goal of this process is to split some symbolic expression into:
;;
;; - a map of symbol -> redundant subexpression
;; - a new expression with each redundant subexpr replaced with its
;; - a new expression with each redundant subexpression replaced with its
;; corresponding symbol.
;;
;; The invariant we want to achieve is that the new expression, rehydrated using
Expand All @@ -40,7 +40,7 @@
;; earlier. We want to be careful that we only generate and bind subexpressions
;; that are actually used in the final computation.
;;
;; `discard-unferenced-syms` ensures this by removing any entry from our
;; `discard-unreferenced-syms` ensures this by removing any entry from our
;; replacement map that doesn't appear in the expression it's passed, or any
;; subexpression referenced by a symbol in the expression, etc etc.
;;
Expand Down Expand Up @@ -159,7 +159,7 @@
`:gensym-fn`: side-effecting function that returns a new, unique
variable name prefixed by its argument on each invocation.
`monotonic-symbol-genπerator` by default.
`monotonic-symbol-generator` by default.
NOTE that the symbols should appear in sorted order! Otherwise we can't
guarantee that the binding sequence passed to `continue` won't contain entries
Expand Down
76 changes: 30 additions & 46 deletions src/emmy/numerical/ode.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
[emmy.expression.compile :as c]
[emmy.structure :as struct]
#?(:clj [emmy.util :as u])
[emmy.util.stopwatch :as us]
[emmy.value :as v]
[taoensso.timbre :as log])
#?(:clj
Expand All @@ -22,11 +21,10 @@

(defn- flatten-into-primitive-array
"Copy the sequence `xs` into the primitive double array `arr`."
[arr xs]
(let [ix (atom -1)
asetter #?(:clj aset-double :cljs aset)]
(r/reduce (fn [a x]
(asetter a (swap! ix unchecked-inc) x)
[xs ^doubles arr]
(let [ix (atom -1)]
(r/reduce (fn [^doubles a ^double x]
(aset a (swap! ix unchecked-inc) x)
a)
arr
(r/flatten xs))))
Expand Down Expand Up @@ -171,39 +169,33 @@
:rawFunction true})]
(comp js->clj (.integrate solver x0 (double-array y0)))))))



(defn integration-opts
"Returns a map with the following kv pairs:
- `:integrator` an instance of `GraggBulirschStoerIntegrator`
- `:stopwatch` [[IStopwatch]] instance that records total evaluation time inside
the derivative function
- `:counter` an atom containing a `Long` that increments every time derivative fn
is called."
(defn ^:no-doc make-integrator*
"Returns a stream integrator configured to integrate a SICM state function.
The function is compiled (unless `compile?` is falsy in the `opts` map) with
the primitive calling convention to allow efficient transition between the
flat representation preferred by integrators and the structured form used in
the book. If the function is not compiled, a wrapper function is created to
accomplish the same thing."
[state-derivative derivative-args initial-state
{:keys [compile?] :as opts}]
(let [evaluation-time (us/stopwatch :started? false)
evaluation-count (atom 0)
flat-initial-state (flatten initial-state)
(let [flat-initial-state (flatten initial-state)
primitive-params (double-array derivative-args)
derivative-fn (if compile?
(let [f' (c/compile-state-fn state-derivative derivative-args initial-state)]
(fn [y]
(f' y (or derivative-args []))))
(c/compile-state-fn state-derivative derivative-args initial-state
{:calling-convention :primitive})
(do (log/warn "Not compiling function for ODE analysis")
(let [d:dt (apply state-derivative derivative-args)
array->state #(struct/unflatten % initial-state)]
(comp d:dt array->state))))
equations (fn [_ y out]
(us/start evaluation-time)
(swap! evaluation-count inc)
(flatten-into-primitive-array out (derivative-fn y))
(us/stop evaluation-time))

integrator (stream-integrator equations 0 flat-initial-state opts)]
{:integrator integrator
:stopwatch evaluation-time
:counter evaluation-count}))
(let [f' (apply state-derivative derivative-args)]
(fn [ys yps _]
(-> ys
(struct/unflatten initial-state)
f'
(flatten-into-primitive-array yps))))))
equations (fn [_ ys yps]
;; TODO: should we consider allowing an option to add a dummy
;; x-parameter in the compiled code, which would allow unwrapping
;; this last layer?
(derivative-fn ys yps primitive-params))]
(stream-integrator equations 0 flat-initial-state opts)))

(defn make-integrator
"make-integrator takes a state derivative function (which in this
Expand All @@ -225,10 +217,8 @@
([initial-state step-size t]
(call initial-state step-size t {}))
([initial-state step-size t {:keys [observe] :as opts}]
(let [total-time (us/stopwatch :started? true)
latest (atom [0 nil])
{:keys [integrator stopwatch counter]}
(integration-opts state-derivative derivative-args initial-state opts)
(let [latest (atom [0 nil])
integrator (make-integrator* state-derivative derivative-args initial-state opts)
array->state #(struct/unflatten % initial-state)
step (fn [x]
(let [y (array->state (integrator x))]
Expand All @@ -240,8 +230,6 @@
(when (not (near? t (nth @latest 0)))
(step t))
(integrator)
(us/stop total-time)
(log/info "#" @counter "total" (us/repr total-time) "f" (us/repr stopwatch))
(nth @latest 1)))))

(defn state-advancer
Expand Down Expand Up @@ -287,11 +275,7 @@
state derivative (and its argument package) from [0 to t1] in steps
of size dt"
[state-derivative state-derivative-args initial-state t1 dt]
(let [opts (integration-opts state-derivative
state-derivative-args
initial-state
{})
f (:integrator opts)]
(let [f (make-integrator* state-derivative state-derivative-args initial-state {})]
(try
(mapv f (for [x (range 0 (+ t1 dt) dt)
:when (< x (+ t1 (/ dt 2)))]
Expand Down
22 changes: 12 additions & 10 deletions test/emmy/examples/double_pendulum_test.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,14 @@
" const _21 = Math.sin(_15);\n"
" const _22 = Math.pow(_21, 2.0);\n"
" return [1.0, [y04, y05], [(- l1 * m2 * _10 * _19 * _21 - l2 * m2 * _11 * _21 + g * m2 * _19 * _14 - g * m1 * _13 - g * m2 * _13) / (l1 * m2 * _22 + l1 * m1), (l2 * m2 * _11 * _19 * _21 + l1 * m1 * _10 * _21 + l1 * m2 * _10 * _21 + g * m1 * _19 * _13 + g * m2 * _19 * _13 - g * m1 * _14 - g * m2 * _14) / (l2 * m2 * _22 + l2 * m1)]];"))]
(c/compile-state-fn* double/state-derivative
'[m1 m2 l1 l2 g]
(up 't (up 'theta 'phi) (up 'thetadot 'phidot))
{:mode :js
:flatten? false
:generic-params? false
:gensym-fn (a/monotonic-symbol-generator 2)
:deterministic? true})))
(c/compile-state-fn double/state-derivative
'[m1 m2 l1 l2 g]
(up 't (up 'theta 'phi) (up 'thetadot 'phidot))
{:mode :js
:calling-convention :structure
:generic-params? false
:gensym-fn (a/monotonic-symbol-generator 2)
:deterministic? true})))


(is (= ["[y01, y02, y03, y04, y05]"
Expand All @@ -110,11 +110,12 @@
" const _26 = Math.sin(_20);\n"
" const _27 = Math.pow(_26, 2.0);\n"
" return [1.0, [y04, y05], [(- p07 * p08 * _15 * _24 * _26 - p07 * p09 * _16 * _26 + p07 * p10 * _24 * _19 - p06 * p10 * _18 - p07 * p10 * _18) / (p07 * p08 * _27 + p06 * p08), (p07 * p09 * _16 * _24 * _26 + p06 * p08 * _15 * _26 + p07 * p08 * _15 * _26 + p06 * p10 * _24 * _18 + p07 * p10 * _24 * _18 - p06 * p10 * _19 - p07 * p10 * _19) / (p07 * p09 * _27 + p06 * p09)]];"))]
(c/compile-state-fn*
(c/compile-state-fn
double/state-derivative
'[1 1 1 1 'g]
(up 't (up 'theta 'phi) (up 'thetadot 'phidot))
{:mode :js
:calling-convention :flat
:gensym-fn (a/monotonic-symbol-generator 2)
:deterministic? true})))

Expand Down Expand Up @@ -157,12 +158,13 @@
" const _60 = _18 * _20 * _23 * _56;\n"
" const _62 = _60 + _59 + _58 + _32 + _34;\n"
" return [1.0, [(- l_1 * y05 * _46 + l_2 * y04) / (_18 * l_2 * m_2 * _57 + _18 * l_2 * m_1), (- l_2 * m_2 * y04 * _46 + l_1 * m_1 * y05 + l_1 * m_2 * y05) / (l_1 * _20 * _23 * _57 + l_1 * _20 * m_1 * m_2)], [(- g * _19 * _20 * m_1 * _23 * _56 * _28 - g * _19 * _20 * _24 * _56 * _28 + 2.0 * g * _19 * _20 * _22 * m_2 * _55 * _28 + 4.0 * g * _19 * _20 * m_1 * _23 * _55 * _28 + 2.0 * g * _19 * _20 * _24 * _55 * _28 - g * _19 * _20 * Math.pow(m_1, 3.0) * _28 - 3.0 * g * _19 * _20 * _22 * m_2 * _28 - 3.0 * g * _19 * _20 * m_1 * _23 * _28 - g * _19 * _20 * _24 * _28 - l_1 * l_2 * m_2 * y04 * y05 * _55 * _50 + _18 * m_1 * _26 * _46 * _50 + _18 * m_2 * _26 * _46 * _50 + _20 * m_2 * _25 * _46 * _50 - l_1 * l_2 * m_1 * y04 * y05 * _50 - l_1 * l_2 * m_2 * y04 * y05 * _50) / _62, (- g * _18 * _21 * _24 * _56 * _29 - 2.0 * g * _18 * _21 * m_1 * _23 * _57 * _29 + 2.0 * g * _18 * _21 * _24 * _55 * _29 - g * _18 * _21 * _22 * m_2 * _29 - g * _18 * _21 * _24 * _29 + l_1 * l_2 * m_2 * y04 * y05 * _55 * _50 - _18 * m_1 * _26 * _46 * _50 - _18 * m_2 * _26 * _46 * _50 - _20 * m_2 * _25 * _46 * _50 + l_1 * l_2 * m_1 * y04 * y05 * _50 + l_1 * l_2 * m_2 * y04 * y05 * _50) / _62]];")]
(c/compile-state-fn*
(c/compile-state-fn
#(e/Hamiltonian->state-derivative
(e/Lagrangian->Hamiltonian
(double/L 'm_1 'm_2 'l_1 'l_2 'g)))
[]
(e/->H-state 't (up 'theta 'psi) (down 'p_theta 'p_psi))
{:mode :js
:calling-convention :flat
:gensym-fn (a/monotonic-symbol-generator 2)
:deterministic? true}))))))
76 changes: 73 additions & 3 deletions test/emmy/examples/driven_pendulum_test.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
[emmy.env :as e :refer [up /]]
[emmy.examples.driven-pendulum :as driven]
[emmy.expression.analyze :as a]
[emmy.expression.compile :refer [compile-state-fn*]]
[emmy.expression.compile :refer [compile-state-fn]]
[emmy.simplify :refer [hermetic-simplify-fixture]]))

(use-fixtures :each hermetic-simplify-fixture)
Expand Down Expand Up @@ -35,17 +35,87 @@
(driven/evolver {:t (/ 3 60) :dt (/ 1 60) :observe observe})
(is (= 4 (count @o))))))

(deftest as-clojure
(let [compile (fn [calling-convention]
(compile-state-fn driven/state-derivative
'[m l g a omega]
(up 't 'theta 'thetadot)
{:mode :clj
:gensym-fn (a/monotonic-symbol-generator 2)
:deterministic? true
:calling-convention calling-convention}))]
(is (= `(fn [[~'y01 ~'y02 ~'y03] [~'p04 ~'p05 ~'p06 ~'p07 ~'p08]]
(let [~'_09 (~'Math/sin ~'y02)]
(clojure.core/vector 1.0
~'y03
(clojure.core//
(clojure.core/+
(clojure.core/* ~'p07 (~'Math/pow ~'p08 2.0) ~'_09 (~'Math/cos (clojure.core/* ~'p08 ~'y01)))
(clojure.core/* -1.0 ~'p06 ~'_09))
~'p05))))
(compile :structure)))
(is (= `(clojure.core/fn
[~'a09 ~'a10 ~'a11]
(clojure.core/let
[~'y01 (clojure.core/aget ~'a09 0)
~'y02 (clojure.core/aget ~'a09 1)
~'y03 (clojure.core/aget ~'a09 2)
~'p04 (clojure.core/aget ~'a11 0)
~'p05 (clojure.core/aget ~'a11 1)
~'p06 (clojure.core/aget ~'a11 2)
~'p07 (clojure.core/aget ~'a11 3)
~'p08 (clojure.core/aget ~'a11 4)
~'_12 (~'Math/sin ~'y02)]
(clojure.core/doto
~'a10
(clojure.core/aset 0 1.0)
(clojure.core/aset 1 ~'y03)
(clojure.core/aset
2
(clojure.core//
(clojure.core/+
(clojure.core/*
~'p07
(~'Math/pow ~'p08 2.0)
~'_12
(~'Math/cos (clojure.core/* ~'p08 ~'y01)))
(clojure.core/* -1.0 ~'p06 ~'_12))
~'p05)))))
(compile :primitive)))))

(deftest as-javascript
(is (= ["[y01, y02, y03]"
"[p04, p05, p06, p07, p08]"
(maybe-defloatify
(str
" const _09 = Math.sin(y02);\n"
" return [1.0, y03, (p07 * Math.pow(p08, 2.0) * _09 * Math.cos(p08 * y01) - p06 * _09) / p05];"))]
(compile-state-fn* driven/state-derivative
(compile-state-fn driven/state-derivative
'[m l g a omega]
(up 't 'theta 'thetadot)
{:mode :js
:gensym-fn (a/monotonic-symbol-generator 2)
:deterministic? true})))
(is (= ["a09" "a10" "a11"
(maybe-defloatify
(str
" const y01 = a09[0];\n"
" const y02 = a09[1];\n"
" const y03 = a09[2];\n"
" const p04 = a11[0];\n"
" const p05 = a11[1];\n"
" const p06 = a11[2];\n"
" const p07 = a11[3];\n"
" const p08 = a11[4];\n"
" const _12 = Math.sin(y02);\n"
" a10[0] = 1.0;\n"
" a10[1] = y03;\n"
" a10[2] = (p07 * Math.pow(p08, 2.0) * _12 * Math.cos(p08 * y01) - p06 * _12) / p05;"))]
(compile-state-fn driven/state-derivative
'[m l g a omega]
(up 't 'theta 'thetadot)
{:mode :js
:calling-convention :primitive
:gensym-fn (a/monotonic-symbol-generator 2)
:deterministic? true})))

Expand All @@ -59,7 +129,7 @@
" const _08 = Math.sin(y02);\n"
" const _09 = Math.sin(_04);\n"
" return [1.0, (a * l * m * omega * _08 * _09 + y03) / (_06 * m), (- Math.pow(a, 2.0) * l * m * Math.pow(omega, 2.0) * _08 * Math.pow(_09, 2.0) * _05 - a * omega * y03 * _09 * _05 - g * _06 * m * _08) / l];"))]
(compile-state-fn*
(compile-state-fn
#(e/Hamiltonian->state-derivative
(e/Lagrangian->Hamiltonian
(driven/L 'm 'l 'g 'a 'omega)))
Expand Down
Loading

0 comments on commit ac6e444

Please sign in to comment.