Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pytest for unit testing #3857

Merged
merged 20 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,7 @@ def run_coverage(session):
session.install("-e", ".[all,dev]", silent=False)
else:
session.install("-e", ".[all,dev,jax]", silent=False)
session.run("coverage", "run", "run-tests.py", "--nosub")
session.run("coverage", "combine")
session.run("coverage", "xml")
session.run("pytest", "--cov=pybamm", "--cov-report=xml", "tests/unit")


@nox.session(name="integration")
Expand Down Expand Up @@ -115,7 +113,7 @@ def run_integration(session):
@nox.session(name="doctests")
def run_doctests(session):
"""Run the doctests and generate the output(s) in the docs/build/ directory."""
session.install("-e", ".[all,docs]", silent=False)
session.install("-e", ".[all,dev,docs]", silent=False)
session.run("python", "run-tests.py", "--doctest")


Expand Down Expand Up @@ -162,7 +160,7 @@ def run_scripts(session):
# https://bitbucket.org/pybtex-devs/pybtex/issues/169/replace-pkg_resources-with
# is fixed
session.install("setuptools", silent=False)
session.install("-e", ".[all]", silent=False)
session.install("-e", ".[all,dev]", silent=False)
session.run("python", "run-tests.py", "--scripts")


Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/operations/latexify.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def _get_geometry_displays(self, var):
for _, rng in self.model.default_geometry[var.domain[-1]].items():
rng_max = get_rng_min_max_name(rng, "max")

geo_latex = f"\quad {rng_min} < {name} < {rng_max}"
geo_latex = rf"\quad {rng_min} < {name} < {rng_max}"
geo.append(geo_latex)

return geo
Expand Down
2 changes: 1 addition & 1 deletion pybamm/expression_tree/printing/sympy_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def _print_Derivative(self, expr):
eqn = super()._print_Derivative(expr)
if getattr(expr, "force_partial", False) and "partial" not in eqn:
var1, var2 = re.findall(r"^\\frac{(\w+)}{(\w+) .+", eqn)[0]
eqn = eqn.replace(var1, "\partial").replace(var2, "\partial")
eqn = eqn.replace(var1, r"\partial").replace(var2, r"\partial")

return eqn

Expand Down
2 changes: 1 addition & 1 deletion pybamm/plotting/quick_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def __init__(
self.spatial_unit = "mm"
elif spatial_unit == "um": # micrometers
self.spatial_factor = 1e6
self.spatial_unit = "$\mu$m"
self.spatial_unit = r"$\mu$m"
else:
raise ValueError(f"spatial unit '{spatial_unit}' not recognized")

Expand Down
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ dev = [
# For running testing sessions
"nox",
# For coverage
"coverage[toml]",
"pytest-cov",
# For test parameterization
"parameterized>=0.9",
# For testing Jupyter notebooks
Expand Down Expand Up @@ -249,6 +249,10 @@ filterwarnings = [
# ignore internal nbmake warnings
'ignore:unclosed \<socket.socket:ResourceWarning',
'ignore:unclosed event loop \<:ResourceWarning',
# ignore warnings generated while running tests
"ignore::DeprecationWarning",
agriyakhetarpal marked this conversation as resolved.
Show resolved Hide resolved
"ignore::UserWarning",
"ignore::RuntimeWarning",
agriyakhetarpal marked this conversation as resolved.
Show resolved Hide resolved
]

# Logging configuration
Expand Down
20 changes: 13 additions & 7 deletions run-tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
import pybamm
import sys
import argparse
import unittest
import subprocess
import pytest
import unittest
agriyakhetarpal marked this conversation as resolved.
Show resolved Hide resolved


def run_code_tests(executable=False, folder: str = "unit", interpreter="python"):
Expand All @@ -36,12 +37,18 @@ def run_code_tests(executable=False, folder: str = "unit", interpreter="python")
# currently activated virtual environment
interpreter = sys.executable
if executable is False:
suite = unittest.defaultTestLoader.discover(tests, pattern="test*.py")
result = unittest.TextTestRunner(verbosity=2).run(suite)
ret = int(not result.wasSuccessful())
if tests == "tests/unit":
ret = pytest.main(["-v", tests])
else:
suite = unittest.defaultTestLoader.discover(tests, pattern="test*.py")
result = unittest.TextTestRunner(verbosity=2).run(suite)
ret = int(not result.wasSuccessful())
else:
print(f"Running {folder} tests with executable '{interpreter}'")
cmd = [interpreter, "-m", "unittest", "discover", "-v", tests]
print(f"Running {folder} tests with executable {interpreter}")
if tests == "tests/unit":
cmd = [interpreter, "-m", "pytest", "-v", tests]
else:
cmd = [interpreter, "-m", "unittest", "discover", "-v", tests]
p = subprocess.Popen(cmd)
try:
ret = p.wait()
Expand Down Expand Up @@ -243,7 +250,6 @@ def test_script(path, executable="python"):
metavar="python",
help="Give the name of the Python interpreter if it is not 'python'",
)

# Parse!
args = parser.parse_args()

Expand Down
3 changes: 3 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
get_discretisation_for_testing,
get_p2d_discretisation_for_testing,
get_size_distribution_disc_for_testing,
function_test,
multi_var_function_test,
multi_var_function_cube_test,
get_1p1d_discretisation_for_testing,
get_2p1d_discretisation_for_testing,
get_unit_2p1D_mesh_for_testing,
Expand Down
12 changes: 12 additions & 0 deletions tests/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,18 @@ def get_size_distribution_disc_for_testing(xpts=None, rpts=10, Rpts=10, zpts=15)
)


def function_test(arg):
return arg + arg


def multi_var_function_test(arg1, arg2):
return arg1 + arg2


def multi_var_function_cube_test(arg1, arg2):
return arg1 + arg2**3


def get_1p1d_discretisation_for_testing(xpts=None, rpts=10, zpts=15):
return get_discretisation_for_testing(
mesh=get_1p1d_mesh_for_testing(xpts, rpts, zpts),
Expand Down
47 changes: 20 additions & 27 deletions tests/unit/test_expression_tree/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,11 @@

import pybamm
import sympy


def test_function(arg):
return arg + arg


def test_multi_var_function(arg1, arg2):
return arg1 + arg2


def test_multi_var_function_cube(arg1, arg2):
return arg1 + arg2**3
from tests import (
function_test,
multi_var_function_test,
multi_var_function_cube_test,
)


class TestFunction(TestCase):
Expand All @@ -31,16 +24,16 @@ def test_number_input(self):
self.assertIsInstance(log.children[0], pybamm.Scalar)
self.assertEqual(log.evaluate(), np.log(10))

summ = pybamm.Function(test_multi_var_function, 1, 2)
summ = pybamm.Function(multi_var_function_test, 1, 2)
self.assertIsInstance(summ.children[0], pybamm.Scalar)
self.assertIsInstance(summ.children[1], pybamm.Scalar)
self.assertEqual(summ.evaluate(), 3)

def test_function_of_one_variable(self):
a = pybamm.Symbol("a")
funca = pybamm.Function(test_function, a)
self.assertEqual(funca.name, "function (test_function)")
self.assertEqual(str(funca), "test_function(a)")
funca = pybamm.Function(function_test, a)
self.assertEqual(funca.name, "function (function_test)")
self.assertEqual(str(funca), "function_test(a)")
self.assertEqual(funca.children[0].name, a.name)

b = pybamm.Scalar(1)
Expand All @@ -61,7 +54,7 @@ def test_diff(self):
a = pybamm.StateVector(slice(0, 1))
b = pybamm.StateVector(slice(1, 2))
y = np.array([5])
func = pybamm.Function(test_function, a)
func = pybamm.Function(function_test, a)
self.assertEqual(func.diff(a).evaluate(y=y), 2)
self.assertEqual(func.diff(func).evaluate(), 1)
func = pybamm.sin(a)
Expand All @@ -72,38 +65,38 @@ def test_diff(self):
self.assertEqual(func.diff(a).evaluate(y=y), np.exp(a.evaluate(y=y)))

# multiple variables
func = pybamm.Function(test_multi_var_function, 4 * a, 3 * a)
func = pybamm.Function(multi_var_function_test, 4 * a, 3 * a)
self.assertEqual(func.diff(a).evaluate(y=y), 7)
func = pybamm.Function(test_multi_var_function, 4 * a, 3 * b)
func = pybamm.Function(multi_var_function_test, 4 * a, 3 * b)
self.assertEqual(func.diff(a).evaluate(y=np.array([5, 6])), 4)
self.assertEqual(func.diff(b).evaluate(y=np.array([5, 6])), 3)
func = pybamm.Function(test_multi_var_function_cube, 4 * a, 3 * b)
func = pybamm.Function(multi_var_function_cube_test, 4 * a, 3 * b)
self.assertEqual(func.diff(a).evaluate(y=np.array([5, 6])), 4)
self.assertEqual(
func.diff(b).evaluate(y=np.array([5, 6])), 3 * 3 * (3 * 6) ** 2
)

# exceptions
func = pybamm.Function(
test_multi_var_function_cube, 4 * a, 3 * b, derivative="derivative"
multi_var_function_cube_test, 4 * a, 3 * b, derivative="derivative"
)
with self.assertRaises(ValueError):
func.diff(a)

def test_function_of_multiple_variables(self):
a = pybamm.Variable("a")
b = pybamm.Parameter("b")
func = pybamm.Function(test_multi_var_function, a, b)
self.assertEqual(func.name, "function (test_multi_var_function)")
self.assertEqual(str(func), "test_multi_var_function(a, b)")
func = pybamm.Function(multi_var_function_test, a, b)
self.assertEqual(func.name, "function (multi_var_function_test)")
self.assertEqual(str(func), "multi_var_function_test(a, b)")
self.assertEqual(func.children[0].name, a.name)
self.assertEqual(func.children[1].name, b.name)

# test eval and diff
a = pybamm.StateVector(slice(0, 1))
b = pybamm.StateVector(slice(1, 2))
y = np.array([5, 2])
func = pybamm.Function(test_multi_var_function, a, b)
func = pybamm.Function(multi_var_function_test, a, b)

self.assertEqual(func.evaluate(y=y), 7)
self.assertEqual(func.diff(a).evaluate(y=y), 1)
Expand All @@ -114,7 +107,7 @@ def test_exceptions(self):
a = pybamm.Variable("a", domain="something")
b = pybamm.Variable("b", domain="something else")
with self.assertRaises(pybamm.DomainError):
pybamm.Function(test_multi_var_function, a, b)
pybamm.Function(multi_var_function_test, a, b)

def test_function_unnamed(self):
fun = pybamm.Function(np.cos, pybamm.t)
Expand Down Expand Up @@ -148,7 +141,7 @@ def test_to_equation(self):

def test_to_from_json_error(self):
a = pybamm.Symbol("a")
funca = pybamm.Function(test_function, a)
funca = pybamm.Function(function_test, a)

with self.assertRaises(NotImplementedError):
funca.to_json()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,10 @@

if pybamm.have_jax():
import jax


def test_function(arg):
return arg + arg


def test_function2(arg1, arg2):
return arg1 + arg2
from tests import (
function_test,
multi_var_function_test,
)


class TestEvaluate(TestCase):
Expand Down Expand Up @@ -93,10 +89,10 @@ def test_find_symbols(self):
# test function
constant_symbols = OrderedDict()
variable_symbols = OrderedDict()
expr = pybamm.Function(test_function, a)
expr = pybamm.Function(function_test, a)
pybamm.find_symbols(expr, constant_symbols, variable_symbols)
self.assertEqual(next(iter(constant_symbols.keys())), expr.id)
self.assertEqual(next(iter(constant_symbols.values())), test_function)
self.assertEqual(next(iter(constant_symbols.values())), function_test)
self.assertEqual(next(iter(variable_symbols.keys())), a.id)
self.assertEqual(list(variable_symbols.keys())[1], expr.id)
self.assertEqual(next(iter(variable_symbols.values())), "y[0:1]")
Expand Down Expand Up @@ -283,9 +279,9 @@ def test_to_python(self):
expr = a + b
constant_str, variable_str = pybamm.to_python(expr)
expected_str = (
"var_[0-9m]+ = y\[0:1\].*\\n"
"var_[0-9m]+ = y\[1:2\].*\\n"
"var_[0-9m]+ = var_[0-9m]+ \+ var_[0-9m]+"
r"var_[0-9m]+ = y\[0:1\].*\n"
r"var_[0-9m]+ = y\[1:2\].*\n"
r"var_[0-9m]+ = var_[0-9m]+ \+ var_[0-9m]+"
)

self.assertRegex(variable_str, expected_str)
Expand All @@ -306,12 +302,12 @@ def test_evaluator_python(self):
self.assertEqual(result, 3)

# test function(a*b)
expr = pybamm.Function(test_function, a * b)
expr = pybamm.Function(function_test, a * b)
evaluator = pybamm.EvaluatorPython(expr)
result = evaluator(t=None, y=np.array([[2], [3]]))
self.assertEqual(result, 12)

expr = pybamm.Function(test_function2, a, b)
expr = pybamm.Function(multi_var_function_test, a, b)
evaluator = pybamm.EvaluatorPython(expr)
result = evaluator(t=None, y=np.array([[2], [3]]))
self.assertEqual(result, 5)
Expand Down Expand Up @@ -486,7 +482,7 @@ def test_evaluator_jax(self):
self.assertEqual(result, 3)

# test function(a*b)
expr = pybamm.Function(test_function, a * b)
expr = pybamm.Function(function_test, a * b)
evaluator = pybamm.EvaluatorJax(expr)
result = evaluator(t=None, y=np.array([[2], [3]]))
self.assertEqual(result, 12)
Expand Down
7 changes: 2 additions & 5 deletions tests/unit/test_expression_tree/test_operations/test_jac.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@
import unittest
from scipy.sparse import eye
from tests import get_mesh_for_testing


def test_multi_var_function(arg1, arg2):
return arg1 + arg2
from tests import multi_var_function_test


class TestJacobian(TestCase):
Expand Down Expand Up @@ -217,7 +214,7 @@ def test_functions(self):
np.testing.assert_array_equal(0, dfunc_dy)

# several children
func = pybamm.Function(test_multi_var_function, 2 * y, 3 * y)
func = pybamm.Function(multi_var_function_test, 2 * y, 3 * y)
jacobian = np.diag(5 * np.ones(4))
dfunc_dy = func.jac(y).evaluate(y=y0)
np.testing.assert_array_equal(jacobian, dfunc_dy.toarray())
Expand Down
11 changes: 5 additions & 6 deletions tests/unit/test_expression_tree/test_operations/test_jac_2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
import numpy as np
import unittest
from scipy.sparse import eye
from tests import get_1p1d_discretisation_for_testing


def test_multi_var_function(arg1, arg2):
return arg1 + arg2
from tests import (
get_1p1d_discretisation_for_testing,
multi_var_function_test,
)


class TestJacobian(TestCase):
Expand Down Expand Up @@ -202,7 +201,7 @@ def test_functions(self):
np.testing.assert_array_equal(0, dfunc_dy)

# several children
func = pybamm.Function(test_multi_var_function, 2 * y, 3 * y)
func = pybamm.Function(multi_var_function_test, 2 * y, 3 * y)
jacobian = np.diag(5 * np.ones(8))
dfunc_dy = func.jac(y).evaluate(y=y0)
np.testing.assert_array_equal(jacobian, dfunc_dy.toarray())
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/test_plotting/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import numpy as np
from tests import TestCase
import matplotlib.pyplot as plt
from matplotlib import use

use("Agg")


class TestPlot(TestCase):
Expand Down
Loading
Loading