Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make jax optional #1767

Merged
merged 22 commits into from
Nov 10, 2021
Merged

Conversation

priyanshuone6
Copy link
Member

@priyanshuone6 priyanshuone6 commented Oct 29, 2021

Description

Make jax and jaxlib optional

Fixes #1701
Fixes #1775

Type of change

Please add a line in the relevant section of CHANGELOG.md to document the change (include PR #) - note reverse order of PR #s. If necessary, also add to the list of breaking changes.

  • New feature (non-breaking change which adds functionality)
  • Optimization (back-end change that speeds up the code)
  • Bug fix (non-breaking change which fixes an issue)

Key checklist:

  • No style issues: $ flake8
  • All tests pass: $ python run-tests.py --unit
  • The documentation builds: $ cd docs and then $ make clean; make html

You can run all three at once, using $ python run-tests.py --quick.

Further checks:

  • Code is commented, particularly in hard-to-understand areas
  • Tests added that prove fix is effective or that feature works

@valentinsulzer
Copy link
Member

Looks good, just a few more tests that need to be skipped if jax is not installed

system() == "Windows" or (system() == "Darwin" and "ARM64" in version()),
"JAX not supported on windows or Mac M1",
)
@unittest.skipIf(not(importlib.util.find_spec("jax")), "requires jax")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

create a pybamm.have_jax() function, for example in the jax solver, for this (see

def have_idaklu():
)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added the have_jax in util because it was causing circular imports

@valentinsulzer
Copy link
Member

Also, Jax should be installed at least for the ubuntu tests

@valentinsulzer
Copy link
Member

We should also say somewhere (maybe in the readme) which jax version should be installed for this to work.
Or, maybe, have a pybamm.install_jax() function that automatically installs the correct version?

@valentinsulzer
Copy link
Member

Re-running all the tests with and without jax is overkill (adds a lot of time to the CI). I think using the ubuntu tests to test with jax and the windows tests to test without jax is good enough. Otherwise we could specifically run only the jax tests separately, but I don't think that's necessary

@priyanshuone6
Copy link
Member Author

Looks good, just a few more tests that need to be skipped if jax is not installed

@tinosulzer Which tests are not yet covered

@valentinsulzer
Copy link
Member

These two:

======================================================================
ERROR: test_sensitivities (unit.test_solvers.test_base_solver.TestBaseSolver)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/runner/work/PyBaMM/PyBaMM/tests/unit/test_solvers/test_base_solver.py", line 352, in test_sensitivities
    solver.set_up(model, inputs={'a': 0, 'b': 0})
  File "/home/runner/work/PyBaMM/PyBaMM/pybamm/solvers/base_solver.py", line 441, in set_up
    use_jacobian=False,
  File "/home/runner/work/PyBaMM/PyBaMM/pybamm/solvers/base_solver.py", line 251, in process
    func = pybamm.EvaluatorJax(func)
AttributeError: module 'pybamm' has no attribute 'EvaluatorJax'

======================================================================
ERROR: test_ida_roberts_klu_sensitivities (unit.test_solvers.test_idaklu_solver.TestIDAKLUSolver)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/runner/work/PyBaMM/PyBaMM/tests/unit/test_solvers/test_idaklu_solver.py", line 74, in test_ida_roberts_klu_sensitivities
    model, t_eval, inputs={"a": a_value},
  File "/home/runner/work/PyBaMM/PyBaMM/pybamm/solvers/base_solver.py", line 908, in solve
    self.set_up(model, ext_and_inputs_list[0], t_eval)
  File "/home/runner/work/PyBaMM/PyBaMM/pybamm/solvers/base_solver.py", line 441, in set_up
    use_jacobian=False,
  File "/home/runner/work/PyBaMM/PyBaMM/pybamm/solvers/base_solver.py", line 251, in process
    func = pybamm.EvaluatorJax(func)
AttributeError: module 'pybamm' has no attribute 'EvaluatorJax'

@priyanshuone6 priyanshuone6 marked this pull request as ready for review November 7, 2021 13:09
Copy link
Member

@valentinsulzer valentinsulzer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good other than that docs error. Perhaps the JaxSolver class should exist and simply raise an error in __init__ if jax is not installed (so that JaxSolver() errors but the documentation builds). Alternatively, the simpler option is to install jax when building the docs

@priyanshuone6
Copy link
Member Author

Jax is used everywhere in pybamm/solvers/jax_bdf_solver.py and it is difficult to skip the entire module unless I add a bunch of if statements in all the methods. Can I instead remove from .solvers.jax_bdf_solver import jax_bdf_integrate from pybamm.__init__?

@codecov
Copy link

codecov bot commented Nov 9, 2021

Codecov Report

Merging #1767 (243ef8c) into develop (884d6e9) will decrease coverage by 0.04%.
The diff coverage is 97.25%.

Impacted file tree graph

@@             Coverage Diff             @@
##           develop    #1767      +/-   ##
===========================================
- Coverage    99.28%   99.24%   -0.05%     
===========================================
  Files          343      343              
  Lines        18945    18962      +17     
===========================================
+ Hits         18810    18818       +8     
- Misses         135      144       +9     
Impacted Files Coverage Δ
pybamm/util.py 97.12% <68.75%> (-2.88%) ⬇️
pybamm/solvers/jax_solver.py 96.77% <85.71%> (-1.54%) ⬇️
...bamm/expression_tree/operations/evaluate_python.py 97.58% <94.59%> (-0.68%) ⬇️
pybamm/solvers/jax_bdf_solver.py 98.04% <98.65%> (-0.21%) ⬇️
pybamm/__init__.py 94.78% <100.00%> (-0.14%) ⬇️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 884d6e9...243ef8c. Read the comment docs.

@valentinsulzer
Copy link
Member

Can you install jax for coverage?

@priyanshuone6
Copy link
Member Author

How do I test if not pybamm.have_jax()

@valentinsulzer
Copy link
Member

Windows tests should cover the not-jax case

@priyanshuone6 priyanshuone6 requested review from valentinsulzer and removed request for valentinsulzer November 10, 2021 14:10
Copy link
Member

@valentinsulzer valentinsulzer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good now, thanks

@valentinsulzer valentinsulzer merged commit 337e602 into pybamm-team:develop Nov 10, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Jax now supports mac with m1 chip make jax and the pybamm JaxSolver optional
2 participants