From c9e1b599073985ddf96fad6bb15e4f0d4fcd5620 Mon Sep 17 00:00:00 2001 From: Priyanshu Agarwal Date: Tue, 2 Nov 2021 23:43:43 +0530 Subject: [PATCH] Make jax optional --- docs/requirements.txt | 2 -- pybamm/__init__.py | 12 ++----- .../operations/evaluate_python.py | 11 ++---- requirements.txt | 4 +-- setup.py | 9 ++--- .../test_lithium_ion/test_spm.py | 6 ++-- .../test_lithium_ion/test_spme.py | 6 ++-- tests/unit/test_citations.py | 7 ++-- .../test_operations/test_evaluate_python.py | 35 +++++-------------- .../unit/test_solvers/test_jax_bdf_solver.py | 8 ++--- tests/unit/test_solvers/test_jax_solver.py | 9 +++-- tests/unit/test_solvers/test_scipy_solver.py | 6 ++-- 12 files changed, 32 insertions(+), 83 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index b2754b438d..9f8d290eee 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -7,8 +7,6 @@ autograd >= 1.2 scikit-fem >= 0.2.0 casadi >= 3.5.0 imageio>=2.9.0 -jax==0.2.12 -jaxlib==0.1.70 jupyter # For example notebooks pybtex # Note: Matplotlib is loaded for debug plots but to ensure pybamm runs diff --git a/pybamm/__init__.py b/pybamm/__init__.py index aab8a54c4a..78d5839f38 100644 --- a/pybamm/__init__.py +++ b/pybamm/__init__.py @@ -8,6 +8,7 @@ import sys import os import platform +import importlib.util # # Version info @@ -102,10 +103,7 @@ def version(formatted=False): EvaluatorPython, ) -if not ( - platform.system() == "Windows" - or (platform.system() == "Darwin" and "ARM64" in platform.version()) -): +if importlib.util.find_spec("jax"): from .expression_tree.operations.evaluate_python import EvaluatorJax from .expression_tree.operations.evaluate_python import JaxCooMatrix @@ -226,11 +224,7 @@ def version(formatted=False): from .solvers.scikits_ode_solver import ScikitsOdeSolver, have_scikits_odes from .solvers.scipy_solver import ScipySolver -# Jax not supported under windows -if not ( - platform.system() == "Windows" - or (platform.system() == "Darwin" and "ARM64" in platform.version()) -): +if importlib.util.find_spec("jax"): from .solvers.jax_solver import JaxSolver from .solvers.jax_bdf_solver import jax_bdf_integrate diff --git a/pybamm/expression_tree/operations/evaluate_python.py b/pybamm/expression_tree/operations/evaluate_python.py index ff8bf92853..4953244309 100644 --- a/pybamm/expression_tree/operations/evaluate_python.py +++ b/pybamm/expression_tree/operations/evaluate_python.py @@ -6,13 +6,12 @@ import numpy as np import scipy.sparse from collections import OrderedDict - +import importlib.util import numbers from platform import system, version -if not (system() == "Windows" or (system() == "Darwin" and "ARM64" in version())): +if importlib.util.find_spec("jax"): import jax - from jax.config import config config.update("jax_enable_x64", True) @@ -104,12 +103,6 @@ def create_jax_coo_matrix(value): return JaxCooMatrix(row, col, data, value.shape) -else: - - def create_jax_coo_matrix(value): # pragma: no cover - raise NotImplementedError("Jax is not available on Windows") - - def id_to_python_variable(symbol_id, constant=False): """ This function defines the format for the python variable names used in find_symbols diff --git a/requirements.txt b/requirements.txt index 7753d6ab6f..cbd88f3be0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -numpy >= 1.16 +numpy >= 1.16 scipy >= 1.3 pandas >= 0.24 anytree >= 2.4.3 @@ -6,8 +6,6 @@ autograd >= 1.2 scikit-fem >= 0.2.0 casadi >= 3.5.0 imageio>=2.9.0 -jax==0.2.12 -jaxlib==0.1.70 jupyter # For example notebooks pybtex sympy==1.8 diff --git a/setup.py b/setup.py index 90ba21485b..b63aece05b 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ import logging import subprocess from pathlib import Path -from platform import system, version +from platform import system import wheel.bdist_wheel as orig import site import shutil @@ -162,11 +162,6 @@ def compile_KLU(): idaklu_ext = Extension("pybamm.solvers.idaklu", ["pybamm/solvers/c_solvers/idaklu.cpp"]) ext_modules = [idaklu_ext] if compile_KLU() else [] -jax_dependencies = [] -if not (system() == "Windows" or (system() == "Darwin" and "ARM64" in version())): - jax_dependencies = ["jax==0.2.12", "jaxlib==0.1.70"] - - # Load text for description and license with open("README.md", encoding="utf-8") as f: readme = f.read() @@ -198,7 +193,6 @@ def compile_KLU(): "scikit-fem>=0.2.0", "casadi>=3.5.0", "imageio>=2.9.0", - *jax_dependencies, "jupyter", # For example notebooks "pybtex", "sympy==1.8", @@ -209,6 +203,7 @@ def compile_KLU(): "matplotlib>=2.0", ], extras_require={ + "jax": ["jax", "jaxlib"], "docs": ["sphinx>=1.5", "guzzle-sphinx-theme"], # For doc generation "dev": [ "flake8>=3", # For code style checking diff --git a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spm.py b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spm.py index a8bb0bd82b..de450f3f6f 100644 --- a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spm.py +++ b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spm.py @@ -6,7 +6,7 @@ import numpy as np import unittest from platform import system, version - +import importlib.util class TestSPM(unittest.TestCase): def test_basic_processing(self): @@ -71,9 +71,7 @@ def test_optimisations(self): np.testing.assert_array_almost_equal(original, using_known_evals) np.testing.assert_array_almost_equal(original, to_python) - if not ( - system() == "Windows" or (system() == "Darwin" and "ARM64" in version()) - ): + if importlib.util.find_spec("jax"): to_jax = optimtest.evaluate_model(to_jax=True) np.testing.assert_array_almost_equal(original, to_jax) diff --git a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spme.py b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spme.py index d1c88be1c1..93ef17dcda 100644 --- a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spme.py +++ b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_spme.py @@ -3,7 +3,7 @@ # import pybamm import tests - +import importlib.util import numpy as np import unittest from platform import system, version @@ -79,9 +79,7 @@ def test_optimisations(self): np.testing.assert_array_almost_equal(original, using_known_evals) np.testing.assert_array_almost_equal(original, to_python) - if not ( - system() == "Windows" or (system() == "Darwin" and "ARM64" in version()) - ): + if importlib.util.find_spec("jax"): to_jax = optimtest.evaluate_model(to_jax=True) np.testing.assert_array_almost_equal(original, to_jax) diff --git a/tests/unit/test_citations.py b/tests/unit/test_citations.py index 17c2ad2d5b..d83da96ca8 100644 --- a/tests/unit/test_citations.py +++ b/tests/unit/test_citations.py @@ -4,7 +4,7 @@ import pybamm import unittest from platform import system, version - +import importlib.util class TestCitations(unittest.TestCase): def test_citations(self): @@ -255,10 +255,7 @@ def test_solver_citations(self): pybamm.IDAKLUSolver() self.assertIn("Hindmarsh2005", citations._papers_to_cite) - @unittest.skipIf( - system() == "Windows" or (system() == "Darwin" and "ARM64" in version()), - "JAX not supported on windows or Mac M1", - ) + @unittest.skipIf(not(importlib.util.find_spec("jax")), "requires jax") def test_jax_citations(self): citations = pybamm.citations citations._reset() diff --git a/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py b/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py index 3cca9eef2d..0206e99728 100644 --- a/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py +++ b/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py @@ -9,7 +9,10 @@ import scipy.sparse from collections import OrderedDict from platform import system, version +import importlib.util +if importlib.util.find_spec("jax"): + import jax def test_function(arg): return arg + arg @@ -457,10 +460,7 @@ def test_evaluator_python(self): result = evaluator.evaluate(t=t, y=y) np.testing.assert_allclose(result, expr.evaluate(t=t, y=y)) - @unittest.skipIf( - system() == "Windows" or (system() == "Darwin" and "ARM64" in version()), - "JAX not supported on windows or Mac M1", - ) + @unittest.skipIf(not(importlib.util.find_spec("jax")), "requires jax") def test_find_symbols_jax(self): # test sparse conversion constant_symbols = OrderedDict() @@ -473,10 +473,7 @@ def test_find_symbols_jax(self): list(constant_symbols.values())[0].toarray(), A.entries.toarray() ) - @unittest.skipIf( - system() == "Windows" or (system() == "Darwin" and "ARM64" in version()), - "JAX not supported on windows or Mac M1", - ) + @unittest.skipIf(not(importlib.util.find_spec("jax")), "requires jax") def test_evaluator_jax(self): a = pybamm.StateVector(slice(0, 1)) b = pybamm.StateVector(slice(1, 2)) @@ -638,10 +635,7 @@ def test_evaluator_jax(self): result = evaluator.evaluate(t=t, y=y) np.testing.assert_allclose(result, expr.evaluate(t=t, y=y)) - @unittest.skipIf( - system() == "Windows" or (system() == "Darwin" and "ARM64" in version()), - "JAX not supported on windows or Mac M1", - ) + @unittest.skipIf(not(importlib.util.find_spec("jax")), "requires jax") def test_evaluator_jax_jacobian(self): a = pybamm.StateVector(slice(0, 1)) y_tests = [np.array([[2.0]]), np.array([[1.0]]), np.array([1.0])] @@ -656,10 +650,7 @@ def test_evaluator_jax_jacobian(self): result_true = evaluator_jac.evaluate(t=None, y=y) np.testing.assert_allclose(result_test, result_true) - @unittest.skipIf( - system() == "Windows" or (system() == "Darwin" and "ARM64" in version()), - "JAX not supported on windows or Mac M1", - ) + @unittest.skipIf(not(importlib.util.find_spec("jax")), "requires jax") def test_evaluator_jax_debug(self): a = pybamm.StateVector(slice(0, 1)) expr = a ** 2 @@ -667,10 +658,7 @@ def test_evaluator_jax_debug(self): evaluator = pybamm.EvaluatorJax(expr) evaluator.debug(y=y_test) - @unittest.skipIf( - system() == "Windows" or (system() == "Darwin" and "ARM64" in version()), - "JAX not supported on windows or Mac M1", - ) + @unittest.skipIf(not(importlib.util.find_spec("jax")), "requires jax") def test_evaluator_jax_inputs(self): a = pybamm.InputParameter("a") expr = a ** 2 @@ -678,13 +666,8 @@ def test_evaluator_jax_inputs(self): result = evaluator.evaluate(inputs={"a": 2}) self.assertEqual(result, 4) - @unittest.skipIf( - system() == "Windows" or (system() == "Darwin" and "ARM64" in version()), - "JAX not supported on windows or Mac M1", - ) + @unittest.skipIf(not(importlib.util.find_spec("jax")), "requires jax") def test_jax_coo_matrix(self): - import jax - A = pybamm.JaxCooMatrix([0, 1], [0, 1], [1.0, 2.0], (2, 2)) Adense = jax.numpy.array([[1.0, 0], [0, 2.0]]) v = jax.numpy.array([[2.0], [1.0]]) diff --git a/tests/unit/test_solvers/test_jax_bdf_solver.py b/tests/unit/test_solvers/test_jax_bdf_solver.py index 772bc937d0..c2aeab6dc9 100644 --- a/tests/unit/test_solvers/test_jax_bdf_solver.py +++ b/tests/unit/test_solvers/test_jax_bdf_solver.py @@ -5,15 +5,13 @@ import time import numpy as np from platform import system, version +import importlib.util -if not (system() == "Windows" or (system() == "Darwin" and "ARM64" in version())): +if importlib.util.find_spec("jax"): import jax -@unittest.skipIf( - system() == "Windows" or (system() == "Darwin" and "ARM64" in version()), - "JAX not supported on windows or Mac M1", -) +@unittest.skipIf(not(importlib.util.find_spec("jax")), "requires jax") class TestJaxBDFSolver(unittest.TestCase): def test_solver(self): # Create model diff --git a/tests/unit/test_solvers/test_jax_solver.py b/tests/unit/test_solvers/test_jax_solver.py index 74dccdaf99..255a9a5c46 100644 --- a/tests/unit/test_solvers/test_jax_solver.py +++ b/tests/unit/test_solvers/test_jax_solver.py @@ -1,3 +1,4 @@ +import importlib import pybamm import unittest from tests import get_mesh_for_testing @@ -5,15 +6,13 @@ import time import numpy as np from platform import system, version +import importlib.util -if not (system() == "Windows" or (system() == "Darwin" and "ARM64" in version())): +if importlib.util.find_spec("jax"): import jax -@unittest.skipIf( - system() == "Windows" or (system() == "Darwin" and "ARM64" in version()), - "JAX not supported on windows or Mac M1", -) +@unittest.skipIf(not(importlib.util.find_spec("jax")), "requires jax") class TestJaxSolver(unittest.TestCase): def test_model_solver(self): # Create model diff --git a/tests/unit/test_solvers/test_scipy_solver.py b/tests/unit/test_solvers/test_scipy_solver.py index b525d95eac..6dd4c95286 100644 --- a/tests/unit/test_solvers/test_scipy_solver.py +++ b/tests/unit/test_solvers/test_scipy_solver.py @@ -7,14 +7,12 @@ import warnings import sys from platform import system, version - +import importlib.util class TestScipySolver(unittest.TestCase): def test_model_solver_python_and_jax(self): - if not ( - system() == "Windows" or (system() == "Darwin" and "ARM64" in version()) - ): + if importlib.util.find_spec("jax"): formats = ["python", "jax"] else: formats = ["python"]