Skip to content

Commit

Permalink
Make jax optional
Browse files Browse the repository at this point in the history
  • Loading branch information
priyanshuone6 committed Nov 2, 2021
1 parent 269629f commit c9e1b59
Show file tree
Hide file tree
Showing 12 changed files with 32 additions and 83 deletions.
2 changes: 0 additions & 2 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ autograd >= 1.2
scikit-fem >= 0.2.0
casadi >= 3.5.0
imageio>=2.9.0
jax==0.2.12
jaxlib==0.1.70
jupyter # For example notebooks
pybtex
# Note: Matplotlib is loaded for debug plots but to ensure pybamm runs
Expand Down
12 changes: 3 additions & 9 deletions pybamm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import sys
import os
import platform
import importlib.util

#
# Version info
Expand Down Expand Up @@ -102,10 +103,7 @@ def version(formatted=False):
EvaluatorPython,
)

if not (
platform.system() == "Windows"
or (platform.system() == "Darwin" and "ARM64" in platform.version())
):
if importlib.util.find_spec("jax"):
from .expression_tree.operations.evaluate_python import EvaluatorJax
from .expression_tree.operations.evaluate_python import JaxCooMatrix

Expand Down Expand Up @@ -226,11 +224,7 @@ def version(formatted=False):
from .solvers.scikits_ode_solver import ScikitsOdeSolver, have_scikits_odes
from .solvers.scipy_solver import ScipySolver

# Jax not supported under windows
if not (
platform.system() == "Windows"
or (platform.system() == "Darwin" and "ARM64" in platform.version())
):
if importlib.util.find_spec("jax"):
from .solvers.jax_solver import JaxSolver
from .solvers.jax_bdf_solver import jax_bdf_integrate

Expand Down
11 changes: 2 additions & 9 deletions pybamm/expression_tree/operations/evaluate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@
import numpy as np
import scipy.sparse
from collections import OrderedDict

import importlib.util
import numbers
from platform import system, version

if not (system() == "Windows" or (system() == "Darwin" and "ARM64" in version())):
if importlib.util.find_spec("jax"):
import jax

from jax.config import config

config.update("jax_enable_x64", True)
Expand Down Expand Up @@ -104,12 +103,6 @@ def create_jax_coo_matrix(value):
return JaxCooMatrix(row, col, data, value.shape)


else:

def create_jax_coo_matrix(value): # pragma: no cover
raise NotImplementedError("Jax is not available on Windows")


def id_to_python_variable(symbol_id, constant=False):
"""
This function defines the format for the python variable names used in find_symbols
Expand Down
4 changes: 1 addition & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
numpy >= 1.16
numpy >= 1.16
scipy >= 1.3
pandas >= 0.24
anytree >= 2.4.3
autograd >= 1.2
scikit-fem >= 0.2.0
casadi >= 3.5.0
imageio>=2.9.0
jax==0.2.12
jaxlib==0.1.70
jupyter # For example notebooks
pybtex
sympy==1.8
Expand Down
9 changes: 2 additions & 7 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import subprocess
from pathlib import Path
from platform import system, version
from platform import system
import wheel.bdist_wheel as orig
import site
import shutil
Expand Down Expand Up @@ -162,11 +162,6 @@ def compile_KLU():
idaklu_ext = Extension("pybamm.solvers.idaklu", ["pybamm/solvers/c_solvers/idaklu.cpp"])
ext_modules = [idaklu_ext] if compile_KLU() else []

jax_dependencies = []
if not (system() == "Windows" or (system() == "Darwin" and "ARM64" in version())):
jax_dependencies = ["jax==0.2.12", "jaxlib==0.1.70"]


# Load text for description and license
with open("README.md", encoding="utf-8") as f:
readme = f.read()
Expand Down Expand Up @@ -198,7 +193,6 @@ def compile_KLU():
"scikit-fem>=0.2.0",
"casadi>=3.5.0",
"imageio>=2.9.0",
*jax_dependencies,
"jupyter", # For example notebooks
"pybtex",
"sympy==1.8",
Expand All @@ -209,6 +203,7 @@ def compile_KLU():
"matplotlib>=2.0",
],
extras_require={
"jax": ["jax", "jaxlib"],
"docs": ["sphinx>=1.5", "guzzle-sphinx-theme"], # For doc generation
"dev": [
"flake8>=3", # For code style checking
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import unittest
from platform import system, version

import importlib.util

class TestSPM(unittest.TestCase):
def test_basic_processing(self):
Expand Down Expand Up @@ -71,9 +71,7 @@ def test_optimisations(self):
np.testing.assert_array_almost_equal(original, using_known_evals)
np.testing.assert_array_almost_equal(original, to_python)

if not (
system() == "Windows" or (system() == "Darwin" and "ARM64" in version())
):
if importlib.util.find_spec("jax"):
to_jax = optimtest.evaluate_model(to_jax=True)
np.testing.assert_array_almost_equal(original, to_jax)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#
import pybamm
import tests

import importlib.util
import numpy as np
import unittest
from platform import system, version
Expand Down Expand Up @@ -79,9 +79,7 @@ def test_optimisations(self):
np.testing.assert_array_almost_equal(original, using_known_evals)
np.testing.assert_array_almost_equal(original, to_python)

if not (
system() == "Windows" or (system() == "Darwin" and "ARM64" in version())
):
if importlib.util.find_spec("jax"):
to_jax = optimtest.evaluate_model(to_jax=True)
np.testing.assert_array_almost_equal(original, to_jax)

Expand Down
7 changes: 2 additions & 5 deletions tests/unit/test_citations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pybamm
import unittest
from platform import system, version

import importlib.util

class TestCitations(unittest.TestCase):
def test_citations(self):
Expand Down Expand Up @@ -255,10 +255,7 @@ def test_solver_citations(self):
pybamm.IDAKLUSolver()
self.assertIn("Hindmarsh2005", citations._papers_to_cite)

@unittest.skipIf(
system() == "Windows" or (system() == "Darwin" and "ARM64" in version()),
"JAX not supported on windows or Mac M1",
)
@unittest.skipIf(not(importlib.util.find_spec("jax")), "requires jax")
def test_jax_citations(self):
citations = pybamm.citations
citations._reset()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
import scipy.sparse
from collections import OrderedDict
from platform import system, version
import importlib.util

if importlib.util.find_spec("jax"):
import jax

def test_function(arg):
return arg + arg
Expand Down Expand Up @@ -457,10 +460,7 @@ def test_evaluator_python(self):
result = evaluator.evaluate(t=t, y=y)
np.testing.assert_allclose(result, expr.evaluate(t=t, y=y))

@unittest.skipIf(
system() == "Windows" or (system() == "Darwin" and "ARM64" in version()),
"JAX not supported on windows or Mac M1",
)
@unittest.skipIf(not(importlib.util.find_spec("jax")), "requires jax")
def test_find_symbols_jax(self):
# test sparse conversion
constant_symbols = OrderedDict()
Expand All @@ -473,10 +473,7 @@ def test_find_symbols_jax(self):
list(constant_symbols.values())[0].toarray(), A.entries.toarray()
)

@unittest.skipIf(
system() == "Windows" or (system() == "Darwin" and "ARM64" in version()),
"JAX not supported on windows or Mac M1",
)
@unittest.skipIf(not(importlib.util.find_spec("jax")), "requires jax")
def test_evaluator_jax(self):
a = pybamm.StateVector(slice(0, 1))
b = pybamm.StateVector(slice(1, 2))
Expand Down Expand Up @@ -638,10 +635,7 @@ def test_evaluator_jax(self):
result = evaluator.evaluate(t=t, y=y)
np.testing.assert_allclose(result, expr.evaluate(t=t, y=y))

@unittest.skipIf(
system() == "Windows" or (system() == "Darwin" and "ARM64" in version()),
"JAX not supported on windows or Mac M1",
)
@unittest.skipIf(not(importlib.util.find_spec("jax")), "requires jax")
def test_evaluator_jax_jacobian(self):
a = pybamm.StateVector(slice(0, 1))
y_tests = [np.array([[2.0]]), np.array([[1.0]]), np.array([1.0])]
Expand All @@ -656,35 +650,24 @@ def test_evaluator_jax_jacobian(self):
result_true = evaluator_jac.evaluate(t=None, y=y)
np.testing.assert_allclose(result_test, result_true)

@unittest.skipIf(
system() == "Windows" or (system() == "Darwin" and "ARM64" in version()),
"JAX not supported on windows or Mac M1",
)
@unittest.skipIf(not(importlib.util.find_spec("jax")), "requires jax")
def test_evaluator_jax_debug(self):
a = pybamm.StateVector(slice(0, 1))
expr = a ** 2
y_test = np.array([[2.0], [3.0]])
evaluator = pybamm.EvaluatorJax(expr)
evaluator.debug(y=y_test)

@unittest.skipIf(
system() == "Windows" or (system() == "Darwin" and "ARM64" in version()),
"JAX not supported on windows or Mac M1",
)
@unittest.skipIf(not(importlib.util.find_spec("jax")), "requires jax")
def test_evaluator_jax_inputs(self):
a = pybamm.InputParameter("a")
expr = a ** 2
evaluator = pybamm.EvaluatorJax(expr)
result = evaluator.evaluate(inputs={"a": 2})
self.assertEqual(result, 4)

@unittest.skipIf(
system() == "Windows" or (system() == "Darwin" and "ARM64" in version()),
"JAX not supported on windows or Mac M1",
)
@unittest.skipIf(not(importlib.util.find_spec("jax")), "requires jax")
def test_jax_coo_matrix(self):
import jax

A = pybamm.JaxCooMatrix([0, 1], [0, 1], [1.0, 2.0], (2, 2))
Adense = jax.numpy.array([[1.0, 0], [0, 2.0]])
v = jax.numpy.array([[2.0], [1.0]])
Expand Down
8 changes: 3 additions & 5 deletions tests/unit/test_solvers/test_jax_bdf_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@
import time
import numpy as np
from platform import system, version
import importlib.util

if not (system() == "Windows" or (system() == "Darwin" and "ARM64" in version())):
if importlib.util.find_spec("jax"):
import jax


@unittest.skipIf(
system() == "Windows" or (system() == "Darwin" and "ARM64" in version()),
"JAX not supported on windows or Mac M1",
)
@unittest.skipIf(not(importlib.util.find_spec("jax")), "requires jax")
class TestJaxBDFSolver(unittest.TestCase):
def test_solver(self):
# Create model
Expand Down
9 changes: 4 additions & 5 deletions tests/unit/test_solvers/test_jax_solver.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
import importlib
import pybamm
import unittest
from tests import get_mesh_for_testing
import sys
import time
import numpy as np
from platform import system, version
import importlib.util

if not (system() == "Windows" or (system() == "Darwin" and "ARM64" in version())):
if importlib.util.find_spec("jax"):
import jax


@unittest.skipIf(
system() == "Windows" or (system() == "Darwin" and "ARM64" in version()),
"JAX not supported on windows or Mac M1",
)
@unittest.skipIf(not(importlib.util.find_spec("jax")), "requires jax")
class TestJaxSolver(unittest.TestCase):
def test_model_solver(self):
# Create model
Expand Down
6 changes: 2 additions & 4 deletions tests/unit/test_solvers/test_scipy_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@
import warnings
import sys
from platform import system, version

import importlib.util

class TestScipySolver(unittest.TestCase):
def test_model_solver_python_and_jax(self):

if not (
system() == "Windows" or (system() == "Darwin" and "ARM64" in version())
):
if importlib.util.find_spec("jax"):
formats = ["python", "jax"]
else:
formats = ["python"]
Expand Down

0 comments on commit c9e1b59

Please sign in to comment.