diff --git a/tests/unit/test_expression_tree/test_operations/test_convert_to_casadi.py b/tests/unit/test_expression_tree/test_operations/test_convert_to_casadi.py index 0ebdb31ebf..fa9992ad24 100644 --- a/tests/unit/test_expression_tree/test_operations/test_convert_to_casadi.py +++ b/tests/unit/test_expression_tree/test_operations/test_convert_to_casadi.py @@ -66,9 +66,11 @@ def test_convert_array_symbols(self): casadi_t = casadi.MX.sym("t") casadi_y = casadi.MX.sym("y", 10) + casadi_y_dot = casadi.MX.sym("y_dot", 10) pybamm_t = pybamm.Time() pybamm_y = pybamm.StateVector(slice(0, 10)) + pybamm_y_dot = pybamm.StateVectorDot(slice(0, 10)) # Time self.assertEqual(pybamm_t.to_casadi(casadi_t, casadi_y), casadi_t) @@ -76,6 +78,12 @@ def test_convert_array_symbols(self): # State Vector self.assert_casadi_equal(pybamm_y.to_casadi(casadi_t, casadi_y), casadi_y) + # State Vector Dot + self.assert_casadi_equal( + pybamm_y_dot.to_casadi(casadi_t, casadi_y, casadi_y_dot), + casadi_y_dot + ) + def test_special_functions(self): a = pybamm.Array(np.array([1, 2, 3, 4, 5])) self.assert_casadi_equal(pybamm.max(a).to_casadi(), casadi.MX(5), evalf=True)