Skip to content

Commit

Permalink
Add return type documentation for util.have_jax() and is_jax_compatib…
Browse files Browse the repository at this point in the history
…le() (#4005)

* add return type for util.have_jax() and is_jax_compatible()

* add description

* Update pybamm/util.py

Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>

---------

Co-authored-by: Eric G. Kratz <kratman@users.noreply.github.com>
Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>
  • Loading branch information
3 people authored Apr 12, 2024
1 parent ce602d6 commit 8cd1cec
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions pybamm/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,15 @@ def get_parameters_filepath(path):


def have_jax():
"""Check if jax and jaxlib are installed with the correct versions"""
"""
Check if jax and jaxlib are installed with the correct versions
Returns
-------
bool
True if jax and jaxlib are installed with the correct versions, False if otherwise
"""
return (
(importlib.util.find_spec("jax") is not None)
and (importlib.util.find_spec("jaxlib") is not None)
Expand All @@ -269,7 +277,14 @@ def have_jax():


def is_jax_compatible():
"""Check if the available version of jax and jaxlib are compatible with PyBaMM"""
"""
Check if the available versions of jax and jaxlib are compatible with PyBaMM
Returns
-------
bool
True if jax and jaxlib are compatible with PyBaMM, False if otherwise
"""
return importlib.metadata.distribution("jax").version.startswith(
JAX_VERSION
) and importlib.metadata.distribution("jaxlib").version.startswith(JAXLIB_VERSION)
Expand Down

0 comments on commit 8cd1cec

Please sign in to comment.