diff --git a/CHANGELOG.md b/CHANGELOG.md index 558ea6dcbc..0fbfc585f0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ ## Bug fixes +- Allowed for pybamm functions exp, sin, cos, sqrt to be used in expression trees that + are converted to casadi format ([#1067](https://github.com/pybamm-team/PyBaMM/pull/1067) - Fix a bug where variables that depend on y and z were transposed in `QuickPlot` ([#1055](https://github.com/pybamm-team/PyBaMM/pull/1055)) ## Breaking changes diff --git a/examples/scripts/compare_lithium_ion.py b/examples/scripts/compare_lithium_ion.py index c0ef3cd3ba..2644511824 100644 --- a/examples/scripts/compare_lithium_ion.py +++ b/examples/scripts/compare_lithium_ion.py @@ -3,7 +3,7 @@ # import pybamm -pybamm.set_logging_level("INFO") +# pybamm.set_logging_level("INFO") # load models models = [ diff --git a/pybamm/expression_tree/array.py b/pybamm/expression_tree/array.py index e2d9869b94..41bed0708d 100644 --- a/pybamm/expression_tree/array.py +++ b/pybamm/expression_tree/array.py @@ -74,7 +74,7 @@ def entries_string(self, value): if issparse(entries): self._entries_string = str(entries.__dict__) else: - self._entries_string = entries.tostring() + self._entries_string = entries.tobytes() def set_id(self): """ See :meth:`pybamm.Symbol.set_id()`. """ diff --git a/pybamm/expression_tree/interpolant.py b/pybamm/expression_tree/interpolant.py index de49da96f0..defb56d924 100644 --- a/pybamm/expression_tree/interpolant.py +++ b/pybamm/expression_tree/interpolant.py @@ -86,7 +86,7 @@ def entries_string(self, value): self._entries_string = value else: entries = self.data - self._entries_string = entries.tostring() + self._entries_string = entries.tobytes() def set_id(self): """ See :meth:`pybamm.Symbol.set_id()`. """ diff --git a/pybamm/expression_tree/operations/convert_to_casadi.py b/pybamm/expression_tree/operations/convert_to_casadi.py index 7a19c33d12..7ba7bc01b1 100644 --- a/pybamm/expression_tree/operations/convert_to_casadi.py +++ b/pybamm/expression_tree/operations/convert_to_casadi.py @@ -101,6 +101,28 @@ def _convert(self, symbol, t, y, y_dot, inputs): return casadi.mmax(*converted_children) elif symbol.function == np.abs: return casadi.fabs(*converted_children) + elif symbol.function == np.sqrt: + return casadi.sqrt(*converted_children) + elif symbol.function == np.sin: + return casadi.sin(*converted_children) + elif symbol.function == np.arcsinh: + return casadi.arcsinh(*converted_children) + elif symbol.function == np.arccosh: + return casadi.arccosh(*converted_children) + elif symbol.function == np.tanh: + return casadi.tanh(*converted_children) + elif symbol.function == np.cosh: + return casadi.cosh(*converted_children) + elif symbol.function == np.sinh: + return casadi.sinh(*converted_children) + elif symbol.function == np.cos: + return casadi.cos(*converted_children) + elif symbol.function == np.exp: + return casadi.exp(*converted_children) + elif symbol.function == np.log: + return casadi.log(*converted_children) + elif symbol.function == np.sign: + return casadi.sign(*converted_children) elif isinstance(symbol.function, (PchipInterpolator, CubicSpline)): return casadi.interpolant("LUT", "bspline", [symbol.x], symbol.y)( *converted_children diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index 6693d22acd..2e03d622a9 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -740,7 +740,7 @@ def shape(self): # Default behaviour is to try to evaluate the object directly # Try with some large y, to avoid having to unpack (slow) try: - y = np.linspace(0.1, 0.9, int(1e4)) + y = np.nan * np.ones((1000, 1)) evaluated_self = self.evaluate(0, y, y, inputs="shape test") # If that fails, fall back to calculating how big y should really be except ValueError: @@ -753,7 +753,7 @@ def shape(self): len(x._evaluation_array) for x in state_vectors_in_node ) # Pick a y that won't cause RuntimeWarnings - y = np.linspace(0.1, 0.9, min_y_size) + y = np.nan * np.ones((min_y_size, 1)) evaluated_self = self.evaluate(0, y, y, inputs="shape test") # Return shape of evaluated object diff --git a/tests/unit/test_expression_tree/test_operations/test_convert_to_casadi.py b/tests/unit/test_expression_tree/test_operations/test_convert_to_casadi.py index d54abb3b98..f51c2fbdd4 100644 --- a/tests/unit/test_expression_tree/test_operations/test_convert_to_casadi.py +++ b/tests/unit/test_expression_tree/test_operations/test_convert_to_casadi.py @@ -30,11 +30,11 @@ def test_convert_scalar_symbols(self): self.assertEqual(abs(c).to_casadi(), casadi.MX(1)) # function - def sin(x): - return np.sin(x) + def square_plus_one(x): + return x ** 2 + 1 - f = pybamm.Function(sin, b) - self.assertEqual(f.to_casadi(), casadi.MX(np.sin(1))) + f = pybamm.Function(square_plus_one, b) + self.assertEqual(f.to_casadi(), 2) def myfunction(x, y): return x + y @@ -95,6 +95,12 @@ def test_special_functions(self): self.assert_casadi_equal( pybamm.Function(np.abs, c).to_casadi(), casadi.MX(3), evalf=True ) + for np_fun in [np.sqrt, np.tanh, np.cosh, np.sinh, + np.exp, np.log, np.sign, np.sin, np.cos, + np.arccosh, np.arcsinh]: + self.assert_casadi_equal( + pybamm.Function(np_fun, c).to_casadi(), casadi.MX(np_fun(3)), evalf=True + ) def test_interpolation(self): x = np.linspace(0, 1)[:, np.newaxis]