From f4e1d9ef6ec0994255afdbe546655f0923abdbd3 Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Thu, 3 Mar 2022 14:28:46 +0000 Subject: [PATCH] #1863 flake8 --- pybamm/solvers/idaklu_solver.py | 23 ++++++++--------------- pybamm/solvers/scikits_dae_solver.py | 7 ++++--- pybamm/solvers/scikits_ode_solver.py | 7 ++++--- 3 files changed, 16 insertions(+), 21 deletions(-) diff --git a/pybamm/solvers/idaklu_solver.py b/pybamm/solvers/idaklu_solver.py index 397f5fbb32..3a0753deb5 100644 --- a/pybamm/solvers/idaklu_solver.py +++ b/pybamm/solvers/idaklu_solver.py @@ -5,7 +5,6 @@ import pybamm import numpy as np import scipy.sparse as sparse -import numbers import importlib @@ -193,28 +192,22 @@ def _integrate(self, model, t_eval, inputs_dict=None): if model.convert_to_format == "jax": mass_matrix = model.mass_matrix.entries.toarray() - elif model.convert_to_format == "casadi": - #mass_matrix = casadi.DM(model.mass_matrix.entries) - mass_matrix = model.mass_matrix.entries else: mass_matrix = model.mass_matrix.entries # construct residuals function by binding inputs if model.convert_to_format == "casadi": - #y_casadi = casadi.MX.sym("y", model.len_rhs_and_alg) - #ydot_casadi = casadi.MX.sym("ydot", model.len_rhs_and_alg) - #t_casadi = casadi.MX.sym("t") - #casadi_resfn = casadi.Function( - # "residuals", - # [t_casadi, y_casadi, ydot_casadi], - # [model.rhs_algebraic_eval(t_casadi, y_casadi, inputs) - mass_matrix @ - # ydot_casadi] - #) def resfn(t, y, ydot): - return model.rhs_algebraic_eval(t, y, inputs).full().flatten() - mass_matrix @ ydot + return ( + model.rhs_algebraic_eval(t, y, inputs).full().flatten() + - mass_matrix @ ydot + ) else: def resfn(t, y, ydot): - return model.rhs_algebraic_eval(t, y, inputs).flatten() - mass_matrix @ ydot + return ( + model.rhs_algebraic_eval(t, y, inputs).flatten() + - mass_matrix @ ydot + ) jac_y0_t0 = model.jac_rhs_algebraic_eval(t_eval[0], y0, inputs) if sparse.issparse(jac_y0_t0): diff --git a/pybamm/solvers/scikits_dae_solver.py b/pybamm/solvers/scikits_dae_solver.py index eff037cc77..7d93b01c8e 100644 --- a/pybamm/solvers/scikits_dae_solver.py +++ b/pybamm/solvers/scikits_dae_solver.py @@ -94,11 +94,12 @@ def _integrate(self, model, t_eval, inputs_dict=None): y0 = model.y0 if isinstance(y0, casadi.DM): - y0 = y0.full().flatten() + y0 = y0.full() + y0 = y0.flatten() - residuals = model.residuals_eval + residuals = model.rhs_algebraic_eval events = model.terminate_events_eval - jacobian = model.jacobian_eval + jacobian = model.jac_rhs_algebraic_eval mass_matrix = model.mass_matrix.entries def eqsres(t, y, ydot, return_residuals): diff --git a/pybamm/solvers/scikits_ode_solver.py b/pybamm/solvers/scikits_ode_solver.py index 0c8913b9c8..847d44badc 100644 --- a/pybamm/solvers/scikits_ode_solver.py +++ b/pybamm/solvers/scikits_ode_solver.py @@ -84,18 +84,19 @@ def _integrate(self, model, t_eval, inputs_dict=None): """ inputs_dict = inputs_dict or {} - if model.rhs_eval.form == "casadi": + if model.convert_to_format == "casadi": inputs = casadi.vertcat(*[x for x in inputs_dict.values()]) else: inputs = inputs_dict y0 = model.y0 if isinstance(y0, casadi.DM): - y0 = y0.full().flatten() + y0 = y0.full() + y0 = y0.flatten() derivs = model.rhs_eval events = model.terminate_events_eval - jacobian = model.jacobian_eval + jacobian = model.jac_rhs_eval def eqsydot(t, y, return_ydot): return_ydot[:] = derivs(t, y, inputs)