Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support all output types in reverse-mode #169

Closed
wants to merge 18 commits into from
Closed

Conversation

sritchie
Copy link
Member

@sritchie sritchie commented Apr 23, 2024

This PR:

  • Adds a new mode argument to emmy.differential/extract-tangent, allowing us to use this for reverse mode instead of emmy.tape/->partials (which is now gone)
  • replaces emmy.tape/extract with emmy.differential/extract-id, now part of the IPerturbed protocol.

We need these two passes because:

  1. for structural outputs, we need to do a reverse-mode pass for each entry in the output. This creates a structure of Completed maps of node => sensitivity.
  2. then, for each entry in a structural INPUT, we need to replace that entry with a copy of this output with that input's ID selected out

@sritchie sritchie changed the base branch from main to sritchie/nested April 23, 2024 20:19
Base automatically changed from sritchie/nested to main April 23, 2024 22:22
Copy link

codecov bot commented Apr 23, 2024

Codecov Report

Attention: Patch coverage is 86.86869% with 52 lines in your changes are missing coverage. Please review.

Project coverage is 87.57%. Comparing base (f3ac544) to head (c14793b).

Files Patch % Lines
src/emmy/calculus/derivative.cljc 69.38% 11 Missing and 4 partials ⚠️
src/emmy/collection.cljc 7.69% 5 Missing and 7 partials ⚠️
src/emmy/autodiff.cljc 94.92% 5 Missing and 5 partials ⚠️
src/emmy/quaternion.cljc 60.00% 4 Missing ⚠️
src/emmy/dual.cljc 87.50% 2 Missing and 1 partial ⚠️
src/emmy/function.cljc 93.54% 2 Missing ⚠️
src/emmy/series.cljc 94.28% 2 Missing ⚠️
src/emmy/tape.cljc 90.90% 1 Missing and 1 partial ⚠️
src/emmy/matrix.cljc 50.00% 1 Missing ⚠️
src/emmy/polynomial.cljc 50.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #169      +/-   ##
==========================================
- Coverage   87.68%   87.57%   -0.12%     
==========================================
  Files          99      100       +1     
  Lines       15792    15796       +4     
  Branches      850      850              
==========================================
- Hits        13847    13833      -14     
- Misses       1095     1113      +18     
  Partials      850      850              

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@sritchie sritchie changed the title Better gradient protocol [in progress] Support all output types in reverse-mode Apr 24, 2024
@sritchie
Copy link
Member Author

sritchie commented Apr 24, 2024

Qs for @littleredcomputer:

  • what do you think of this protocol approach? I am not sure there is a way around doing two passes over the output structure, annoyingly... but wdyt of the switch? Another way to do it would be to have some sort of tag type where the tag carried around info about whether it was forward or reverse mode.

Looking forward to your feedback here...

@sritchie sritchie marked this pull request as ready for review April 24, 2024 12:27
MultiFn
(perturbed? [_] false)
(replace-tag [f old new] (replace-tag-fn f old new))
(extract-tangent [f tag] (extract-tangent-fn f tag))
(extract-tangent [f tag mode] (extract-tangent-fn f tag mode))
(extract-id [f id] (comp #(d/extract-id % id) f))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's kind of an fmap for most types... after your multimethod conversion, @littleredcomputer , do you feel like these should be multimethods too? Then we could have a ::functor type that supports fmap?

src/emmy/calculus/derivative.cljc Outdated Show resolved Hide resolved

(extract-id [this id]))

(defrecord Completed [v->partial]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

annoying that this now has to live in differential. It has to be here so that a Dual can return the correct default type if it's called with a different tag.

src/emmy/differential.cljc Outdated Show resolved Hide resolved
[IObj
(meta [_] m)
(withMeta [_ meta] (Series. xs meta))
(:clj
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

formatting change that snuck in

@@ -142,7 +140,12 @@
;; This implementation is called if a tape ever makes it out of
;; forward-mode-differentiated function. If this happens, a [[TapeCell]]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO update the docs here. We need a better narrative now that hints at reverse-mode in the differential.cljc exposition.

@@ -220,10 +220,11 @@

(checking "d/extract-tangent" 100 [tag gen/nat
tape (sg/tapecell gen/symbol)]
(is (zero? (d/extract-tangent tape tag))
;; TODO fix these for non-dual
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO add a test for the case where we pass a different mode

@sritchie
Copy link
Member Author

Superceded by #179, #182, #183, #185

@sritchie sritchie closed this Aug 13, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant