Skip to content

Commit

Permalink
#1863 fix scikits solvers
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Mar 3, 2022
1 parent ea960eb commit e1fc87a
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 9 deletions.
3 changes: 2 additions & 1 deletion pybamm/solvers/algebraic_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ def _integrate(self, model, t_eval, inputs_dict=None):

y0 = model.y0
if isinstance(y0, casadi.DM):
y0 = y0.full().flatten()
y0 = y0.full()
y0 = y0.flatten()

# The casadi algebraic solver can read rhs equations, but leaves them unchanged
# i.e. the part of the solution vector that corresponds to the differential
Expand Down
8 changes: 6 additions & 2 deletions pybamm/solvers/scikits_ode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,12 @@ def _integrate(self, model, t_eval, inputs_dict=None):
events = model.terminate_events_eval
jacobian = model.jac_rhs_eval

def eqsydot(t, y, return_ydot):
return_ydot[:] = derivs(t, y, inputs)
if model.convert_to_format == "casadi":
def eqsydot(t, y, return_ydot):
return_ydot[:] = derivs(t, y, inputs).full().flatten()
else:
def eqsydot(t, y, return_ydot):
return_ydot[:] = derivs(t, y, inputs).flatten()

def rootfn(t, y, return_root):
return_root[:] = [event(t, y, inputs) for event in events]
Expand Down
12 changes: 6 additions & 6 deletions tests/unit/test_solvers/test_scikits_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ class Model:
length_scales = {}
convert_to_format = "python"

def residuals_eval(self, t, y, ydot, inputs):
return np.array([0.5 * np.ones_like(y[0]) - ydot[0], 2 * y[0] - y[1]])
def rhs_algebraic_eval(self, t, y, inputs):
return np.array([0.5 * np.ones_like(y[0]), 2 * y[0] - y[1]])

def jacobian_eval(self, t, y, inputs):
def jac_rhs_algebraic_eval(self, t, y, inputs):
return np.array([[0.0, 0.0], [2.0, -1.0]])

model = Model()
Expand Down Expand Up @@ -101,12 +101,12 @@ class Model:
convert_to_format = "python"
len_rhs_and_alg = 2

def residuals_eval(self, t, y, ydot, inputs):
def rhs_algebraic_eval(self, t, y, inputs):
return np.array(
[0.5 * np.ones_like(y[0]) - 4 * ydot[0], 2.0 * y[0] - y[1]]
[0.5 * np.ones_like(y[0]), 2.0 * y[0] - y[1]]
)

def jacobian_eval(self, t, y, inputs):
def jac_rhs_algebraic_eval(self, t, y, inputs):
return np.array([[0.0, 0.0], [2.0, -1.0]])

model = Model()
Expand Down

0 comments on commit e1fc87a

Please sign in to comment.