Skip to content

Commit

Permalink
#1863 flake8
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Mar 3, 2022
1 parent cf53d1f commit f4e1d9e
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 21 deletions.
23 changes: 8 additions & 15 deletions pybamm/solvers/idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import pybamm
import numpy as np
import scipy.sparse as sparse
import numbers

import importlib

Expand Down Expand Up @@ -193,28 +192,22 @@ def _integrate(self, model, t_eval, inputs_dict=None):

if model.convert_to_format == "jax":
mass_matrix = model.mass_matrix.entries.toarray()
elif model.convert_to_format == "casadi":
#mass_matrix = casadi.DM(model.mass_matrix.entries)
mass_matrix = model.mass_matrix.entries
else:
mass_matrix = model.mass_matrix.entries

# construct residuals function by binding inputs
if model.convert_to_format == "casadi":
#y_casadi = casadi.MX.sym("y", model.len_rhs_and_alg)
#ydot_casadi = casadi.MX.sym("ydot", model.len_rhs_and_alg)
#t_casadi = casadi.MX.sym("t")
#casadi_resfn = casadi.Function(
# "residuals",
# [t_casadi, y_casadi, ydot_casadi],
# [model.rhs_algebraic_eval(t_casadi, y_casadi, inputs) - mass_matrix @
# ydot_casadi]
#)
def resfn(t, y, ydot):
return model.rhs_algebraic_eval(t, y, inputs).full().flatten() - mass_matrix @ ydot
return (
model.rhs_algebraic_eval(t, y, inputs).full().flatten()
- mass_matrix @ ydot
)
else:
def resfn(t, y, ydot):
return model.rhs_algebraic_eval(t, y, inputs).flatten() - mass_matrix @ ydot
return (
model.rhs_algebraic_eval(t, y, inputs).flatten()
- mass_matrix @ ydot
)

jac_y0_t0 = model.jac_rhs_algebraic_eval(t_eval[0], y0, inputs)
if sparse.issparse(jac_y0_t0):
Expand Down
7 changes: 4 additions & 3 deletions pybamm/solvers/scikits_dae_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,12 @@ def _integrate(self, model, t_eval, inputs_dict=None):

y0 = model.y0
if isinstance(y0, casadi.DM):
y0 = y0.full().flatten()
y0 = y0.full()
y0 = y0.flatten()

residuals = model.residuals_eval
residuals = model.rhs_algebraic_eval
events = model.terminate_events_eval
jacobian = model.jacobian_eval
jacobian = model.jac_rhs_algebraic_eval
mass_matrix = model.mass_matrix.entries

def eqsres(t, y, ydot, return_residuals):
Expand Down
7 changes: 4 additions & 3 deletions pybamm/solvers/scikits_ode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,18 +84,19 @@ def _integrate(self, model, t_eval, inputs_dict=None):
"""
inputs_dict = inputs_dict or {}
if model.rhs_eval.form == "casadi":
if model.convert_to_format == "casadi":
inputs = casadi.vertcat(*[x for x in inputs_dict.values()])
else:
inputs = inputs_dict

y0 = model.y0
if isinstance(y0, casadi.DM):
y0 = y0.full().flatten()
y0 = y0.full()
y0 = y0.flatten()

derivs = model.rhs_eval
events = model.terminate_events_eval
jacobian = model.jacobian_eval
jacobian = model.jac_rhs_eval

def eqsydot(t, y, return_ydot):
return_ydot[:] = derivs(t, y, inputs)
Expand Down

0 comments on commit f4e1d9e

Please sign in to comment.