Skip to content

Commit

Permalink
#1863 #2008 use InputParameter.create_copy in Discretisation._process…
Browse files Browse the repository at this point in the history
…_symbol
  • Loading branch information
martinjrobins committed Apr 5, 2022
1 parent b0bd5bb commit 71ce5e6
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 9 deletions.
8 changes: 1 addition & 7 deletions pybamm/discretisations/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,13 +1036,7 @@ def _process_symbol(self, symbol):
return new_symbol

elif isinstance(symbol, pybamm.InputParameter):
# Return a new copy of the input parameter, but set the expected size
# according to the domain of the input parameter
expected_size = self._get_variable_size(symbol)
new_input_parameter = pybamm.InputParameter(
symbol.name, symbol.domain, expected_size
)
return new_input_parameter
return symbol.create_copy()

else:
# Backup option: return the object
Expand Down
3 changes: 3 additions & 0 deletions pybamm/expression_tree/input_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class InputParameter(pybamm.Symbol):
"""

def __init__(self, name, domain=None, expected_size=None):
print('creating InputParameter with expected_size', expected_size)
# Expected size defaults to 1 if no domain else None (gets set later)
if expected_size is None:
if domain is None:
Expand All @@ -36,6 +37,7 @@ def __init__(self, name, domain=None, expected_size=None):

def create_copy(self):
"""See :meth:`pybamm.Symbol.new_copy()`."""
print('create_copy', self._expected_size)
new_input_parameter = InputParameter(
self.name, self.domain, expected_size=self._expected_size
)
Expand All @@ -46,6 +48,7 @@ def _evaluate_for_shape(self):
Returns the scalar 'NaN' to represent the shape of a parameter.
See :meth:`pybamm.Symbol.evaluate_for_shape()`
"""
print('evaluate_for_shape', self, self._expected_size)
if self._expected_size is None:
return pybamm.evaluate_for_shape_using_domain(self.domains)
elif self._expected_size == 1:
Expand Down
1 change: 1 addition & 0 deletions pybamm/expression_tree/unary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ def _unary_new_copy(self, child):
return new_index

def _evaluate_for_shape(self):
print('evaluate_for_shape', self.children, self.children[0].evaluate_for_shape())
return self._unary_evaluate(self.children[0].evaluate_for_shape())

def _evaluates_on_edges(self, dimension):
Expand Down
3 changes: 1 addition & 2 deletions pybamm/solvers/idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
try:
idaklu = importlib.util.module_from_spec(idaklu_spec)
idaklu_spec.loader.exec_module(idaklu)
except ImportError as e: # pragma: no cover
print('IDAKLUSolver import error:', e)
except ImportError: # pragma: no cover
idaklu_spec = None


Expand Down
47 changes: 47 additions & 0 deletions tests/unit/test_solvers/test_idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,53 @@ def test_model_events(self):
solution.y[-1], 2 * np.exp(0.1 * solution.t), decimal=5
)

def test_input_params(self):
# test a mix of scalar and vector input params
for form in ["python", "casadi", "jax"]:
if form == "jax" and not pybamm.have_jax():
continue
if form == "casadi":
root_method = "casadi"
else:
root_method = "lm"
model = pybamm.BaseModel()
model.convert_to_format = form
u1 = pybamm.Variable("u1")
u2 = pybamm.Variable("u2")
u3 = pybamm.Variable("u3")
v = pybamm.Variable("v")
a = pybamm.InputParameter("a")
b = pybamm.InputParameter("b", expected_size=2)
model.rhs = {u1: a * v, u2: pybamm.Index(b, 0), u3: pybamm.Index(b, 1)}
model.algebraic = {v: 1 - v}
model.initial_conditions = {u1: 0, u2: 0, u3: 0, v: 1}

disc = pybamm.Discretisation()
disc.process_model(model)

solver = pybamm.IDAKLUSolver(root_method=root_method)

t_eval = np.linspace(0, 3, 100)
a_value = 0.1
b_value = np.array([[0.2], [0.3]])

sol = solver.solve(
model, t_eval, inputs={"a": a_value, "b": b_value},
)

# test that y[3] remains constant
np.testing.assert_array_almost_equal(
sol.y[3, :], np.ones(sol.t.shape)
)

# test that y[0] = to true solution
true_solution = a_value * sol.t
np.testing.assert_array_almost_equal(sol.y[0, :], true_solution)

# test that y[1:2] = to true solution
true_solution = b_value * sol.t
np.testing.assert_array_almost_equal(sol.y[1:2, :], true_solution)

def test_ida_roberts_klu_sensitivities(self):
# this test implements a python version of the ida Roberts
# example provided in sundials
Expand Down

0 comments on commit 71ce5e6

Please sign in to comment.