Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 1031 jax #1038

Merged
merged 56 commits into from
Jul 7, 2020
Merged

Issue 1031 jax #1038

merged 56 commits into from
Jul 7, 2020

Conversation

martinjrobins
Copy link
Contributor

@martinjrobins martinjrobins commented Jun 3, 2020

Description

Adds support for evaluating expression trees using JAX, and adds a new solver JaxSolver with two methods: A RK4(5) method using jax.experimental.odeint, and a custom BDF method written using jax. This solver is currently only useful for ode models with no termination events, and since JAX does not support sparse matrices they are all converted to dense arrays in the model. The latter limitation should hopefully be addressed over time by JAX (see jax-ml/jax#765).

This solver would be useful for running on a GPU/TPU, and, for very small state vectors (due to the restriction to dense arrays), is also very quick on a CPU. The next step after this PR would be to implement the adjoint sensitivities for the new BDF solver (already there for the RK45) and expose the raw jax solvers so a user can calculate sensitivities for parameter estimation.

UPDATE: just saw that jax is not supported under windows :(, so this would be linux/mac only....

Fixes #1031

Type of change

  • New feature (non-breaking change which adds functionality)

Key checklist:

  • No style issues: $ flake8
  • All tests pass: $ python run-tests.py --unit
  • The documentation builds: $ cd docs and then $ make clean; make html

You can run all three at once, using $ python run-tests.py --quick.

Further checks:

  • Code is commented, particularly in hard-to-understand areas
  • Tests added that prove fix is effective or that feature works

@martinjrobins martinjrobins marked this pull request as draft June 16, 2020 09:55
@martinjrobins martinjrobins marked this pull request as ready for review June 26, 2020 12:11
@codecov
Copy link

codecov bot commented Jun 26, 2020

Codecov Report

Merging #1038 into develop will increase coverage by 0.04%.
The diff coverage is 98.51%.

Impacted file tree graph

@@             Coverage Diff             @@
##           develop    #1038      +/-   ##
===========================================
+ Coverage    97.76%   97.81%   +0.04%     
===========================================
  Files          243      245       +2     
  Lines        12667    13177     +510     
===========================================
+ Hits         12384    12889     +505     
- Misses         283      288       +5     
Impacted Files Coverage Δ
pybamm/solvers/jax_solver.py 96.42% <96.42%> (ø)
pybamm/expression_tree/operations/evaluate.py 97.14% <97.16%> (+1.30%) ⬆️
pybamm/solvers/jax_bdf_solver.py 99.39% <99.39%> (ø)
pybamm/solvers/base_solver.py 100.00% <100.00%> (ø)
pybamm/expression_tree/binary_operators.py 95.23% <0.00%> (+0.28%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 56c0383...a30ddf0. Read the comment docs.

Copy link
Member

@valentinsulzer valentinsulzer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @martinjrobins , I don't understand most of the evaluate/BDF code but since tests pass I'm sure it's all fine - and thanks for documenting well in case anyone wants to check it. Looking forward to seeing how this works and hopefully they add sparse matrices soon.
How much work is it to adapt BDF to work for DAEs?
Also, do you think it's possible to write this code in a way that we can generate a C function from the solver?

pybamm/solvers/jax_bdf_solver.py Outdated Show resolved Hide resolved
pybamm/solvers/jax_bdf_solver.py Show resolved Hide resolved
pybamm/solvers/jax_bdf_solver.py Outdated Show resolved Hide resolved
pybamm/solvers/jax_solver.py Outdated Show resolved Hide resolved
@martinjrobins
Copy link
Contributor Author

Thanks Tino. Yea, the BDF code was a pain to write as you can't use any flow control or classes, it does make it difficult to read!

I'm not sure what changes need to be made to solve a DAE as I've not looked into that yet. I've not done the solution of daes before, but I would suspect you would have to code up a separate algorithm.

You can't generate C code from Jax as it never uses C. JAX traces your python code and produces its own expression tree. It uses that to build a computational graph using XLA. The input to XLA is an intermediate representation language called HLO IR, so rather than C, the language that you could emit would be HLO, this could then be compiled and run on any machine using XLA. I'm not sure exactly how to get the HLO using JAX, but in theory it should be a matter of taking any jax function (e.g. the jax_bdf_integrate function in this case), compiling it with JAX and then telling it to give you the generated HLO.

@martinjrobins martinjrobins merged commit 5ff13ed into develop Jul 7, 2020
@martinjrobins martinjrobins deleted the issue-1031-jax branch July 7, 2020 09:52
@valentinsulzer
Copy link
Member

valentinsulzer commented Jul 7, 2020

Thanks for the info. It looks like it's possible to adapt BDF to solve DAEs: https://www.cs.usask.ca/~spiteri/M314/notes/AP/chap10.pdf

@martinjrobins
Copy link
Contributor Author

martinjrobins commented Jul 7, 2020 via email

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

investigate use of JAX to compile and autodiff expression trees
2 participants