Skip to content

Commit

Permalink
diff of vector expression wrt itself returns a vector rather than a s…
Browse files Browse the repository at this point in the history
…calar #4087
  • Loading branch information
martinjrobins committed Jul 12, 2024
1 parent 9b2b3d1 commit c4ed4ea
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 13 deletions.
8 changes: 4 additions & 4 deletions pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,10 +714,10 @@ def diff(self, variable: Symbol):
"""
if variable == self:
eval_shape = self.evaluate_for_shape()
if isinstance(eval_shape, numbers.Number):
if isinstance(eval_shape, numbers.Number) or len(eval_shape) == 1:
return pybamm.Scalar(1)
else:
return pybamm.Vector(np.ones_like(eval_shape), domain=self.domains)
return pybamm.Vector(np.ones_like(eval_shape))
elif any(variable == x for x in self.pre_order()):
return self._diff(variable)
elif variable == pybamm.t and self.has_symbol_of_classes(
Expand All @@ -726,10 +726,10 @@ def diff(self, variable: Symbol):
return self._diff(variable)
else:
eval_shape = self.evaluate_for_shape()
if isinstance(eval_shape, numbers.Number):
if isinstance(eval_shape, numbers.Number) or len(eval_shape) == 1:
return pybamm.Scalar(0)
else:
return pybamm.Vector(np.zeros_like(eval_shape), domain=self.domains)
return pybamm.Vector(np.zeros_like(eval_shape))

def _diff(self, variable):
"""
Expand Down
2 changes: 1 addition & 1 deletion pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -1254,7 +1254,7 @@ def _check_events_with_initial_conditions(t_eval, model, y0, inputs):
x for x in model.events if x.event_type == pybamm.EventType.TERMINATION
]
idxs = np.where(events_eval < 0)[0]
event_names = [termination_events[idx / len(inputs)].name for idx in idxs]
event_names = [termination_events[idx // len(inputs)].name for idx in idxs]
raise pybamm.SolverError(
f"Events {event_names} are non-positive at initial conditions"
)
Expand Down
8 changes: 0 additions & 8 deletions tests/unit/test_expression_tree/test_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,14 +237,6 @@ def test_multiple_symbols(self):
for node, expect in zip(exp.pre_order(), expected_preorder):
self.assertEqual(node.name, expect)

def test_symbol_diff(self):
a = pybamm.Symbol("a")
b = pybamm.Symbol("b")
self.assertIsInstance(a.diff(a), pybamm.Scalar)
self.assertEqual(a.diff(a).evaluate(), 1)
self.assertIsInstance(a.diff(b), pybamm.Scalar)
self.assertEqual(a.diff(b).evaluate(), 0)

def test_symbol_evaluation(self):
a = pybamm.Symbol("a")
with self.assertRaises(NotImplementedError):
Expand Down

0 comments on commit c4ed4ea

Please sign in to comment.