diff --git a/pybamm/discretisations/discretisation.py b/pybamm/discretisations/discretisation.py index 81ba49d55b..a925b65af5 100644 --- a/pybamm/discretisations/discretisation.py +++ b/pybamm/discretisations/discretisation.py @@ -1036,13 +1036,7 @@ def _process_symbol(self, symbol): return new_symbol elif isinstance(symbol, pybamm.InputParameter): - # Return a new copy of the input parameter, but set the expected size - # according to the domain of the input parameter - expected_size = self._get_variable_size(symbol) - new_input_parameter = pybamm.InputParameter( - symbol.name, symbol.domain, expected_size - ) - return new_input_parameter + return symbol.create_copy() else: # Backup option: return the object diff --git a/pybamm/expression_tree/input_parameter.py b/pybamm/expression_tree/input_parameter.py index 6c79c782b1..2dea90bc5f 100644 --- a/pybamm/expression_tree/input_parameter.py +++ b/pybamm/expression_tree/input_parameter.py @@ -25,6 +25,7 @@ class InputParameter(pybamm.Symbol): """ def __init__(self, name, domain=None, expected_size=None): + print('creating InputParameter with expected_size', expected_size) # Expected size defaults to 1 if no domain else None (gets set later) if expected_size is None: if domain is None: @@ -36,6 +37,7 @@ def __init__(self, name, domain=None, expected_size=None): def create_copy(self): """See :meth:`pybamm.Symbol.new_copy()`.""" + print('create_copy', self._expected_size) new_input_parameter = InputParameter( self.name, self.domain, expected_size=self._expected_size ) @@ -46,6 +48,7 @@ def _evaluate_for_shape(self): Returns the scalar 'NaN' to represent the shape of a parameter. See :meth:`pybamm.Symbol.evaluate_for_shape()` """ + print('evaluate_for_shape', self, self._expected_size) if self._expected_size is None: return pybamm.evaluate_for_shape_using_domain(self.domains) elif self._expected_size == 1: diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index 28e691b78b..e4246bb25a 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -335,6 +335,7 @@ def _unary_new_copy(self, child): return new_index def _evaluate_for_shape(self): + print('evaluate_for_shape', self.children, self.children[0].evaluate_for_shape()) return self._unary_evaluate(self.children[0].evaluate_for_shape()) def _evaluates_on_edges(self, dimension): diff --git a/pybamm/solvers/idaklu_solver.py b/pybamm/solvers/idaklu_solver.py index 262357da60..d09a537155 100644 --- a/pybamm/solvers/idaklu_solver.py +++ b/pybamm/solvers/idaklu_solver.py @@ -14,8 +14,7 @@ try: idaklu = importlib.util.module_from_spec(idaklu_spec) idaklu_spec.loader.exec_module(idaklu) - except ImportError as e: # pragma: no cover - print('IDAKLUSolver import error:', e) + except ImportError: # pragma: no cover idaklu_spec = None diff --git a/tests/unit/test_solvers/test_idaklu_solver.py b/tests/unit/test_solvers/test_idaklu_solver.py index 919f9d2699..123c09a4ec 100644 --- a/tests/unit/test_solvers/test_idaklu_solver.py +++ b/tests/unit/test_solvers/test_idaklu_solver.py @@ -137,6 +137,53 @@ def test_model_events(self): solution.y[-1], 2 * np.exp(0.1 * solution.t), decimal=5 ) + def test_input_params(self): + # test a mix of scalar and vector input params + for form in ["python", "casadi", "jax"]: + if form == "jax" and not pybamm.have_jax(): + continue + if form == "casadi": + root_method = "casadi" + else: + root_method = "lm" + model = pybamm.BaseModel() + model.convert_to_format = form + u1 = pybamm.Variable("u1") + u2 = pybamm.Variable("u2") + u3 = pybamm.Variable("u3") + v = pybamm.Variable("v") + a = pybamm.InputParameter("a") + b = pybamm.InputParameter("b", expected_size=2) + model.rhs = {u1: a * v, u2: pybamm.Index(b, 0), u3: pybamm.Index(b, 1)} + model.algebraic = {v: 1 - v} + model.initial_conditions = {u1: 0, u2: 0, u3: 0, v: 1} + + disc = pybamm.Discretisation() + disc.process_model(model) + + solver = pybamm.IDAKLUSolver(root_method=root_method) + + t_eval = np.linspace(0, 3, 100) + a_value = 0.1 + b_value = np.array([[0.2], [0.3]]) + + sol = solver.solve( + model, t_eval, inputs={"a": a_value, "b": b_value}, + ) + + # test that y[3] remains constant + np.testing.assert_array_almost_equal( + sol.y[3, :], np.ones(sol.t.shape) + ) + + # test that y[0] = to true solution + true_solution = a_value * sol.t + np.testing.assert_array_almost_equal(sol.y[0, :], true_solution) + + # test that y[1:2] = to true solution + true_solution = b_value * sol.t + np.testing.assert_array_almost_equal(sol.y[1:2, :], true_solution) + def test_ida_roberts_klu_sensitivities(self): # this test implements a python version of the ida Roberts # example provided in sundials