Skip to content

Commit

Permalink
#1863 get algebraic solver to work
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Jan 21, 2022
1 parent fda76c9 commit 76df2cc
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 10 deletions.
11 changes: 9 additions & 2 deletions pybamm/solvers/algebraic_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,14 @@ def _integrate(self, model, t_eval, inputs_dict=None):
len_rhs = model.rhs_eval(t_eval[0], y0, inputs).shape[0]
y0_diff, y0_alg = np.split(y0, [len_rhs])

algebraic = model.algebraic_eval
if model.convert_to_format == 'casadi':
def algebraic(t, y):
result = model.algebraic_eval(t, y, inputs)
return result.full().flatten()
else:
def algebraic(t, y):
result = model.algebraic_eval(t, y, inputs)
return result.flatten()

y_alg = np.empty((len(y0_alg), len(t_eval)))

Expand All @@ -91,7 +98,7 @@ def _integrate(self, model, t_eval, inputs_dict=None):
def root_fun(y_alg):
"Evaluates algebraic using y"
y = np.concatenate([y0_diff, y_alg])
out = algebraic(t, y, inputs)
out = algebraic(t, y)
pybamm.logger.debug(
"Evaluating algebraic equations at t={}, L2-norm is {}".format(
t * model.timescale_eval, np.linalg.norm(out)
Expand Down
14 changes: 9 additions & 5 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,11 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
t_eval : numeric type, optional
The times (in seconds) at which to compute the solution
"""
pybamm.logger.info("Start solver set-up")

if ics_only:
pybamm.logger.info("Start solver set-up, initial_conditions only")
else:
pybamm.logger.info("Start solver set-up")

# Check model.algebraic for ode solvers
if self.ode_solver is True and len(model.algebraic) > 0:
Expand Down Expand Up @@ -474,6 +478,7 @@ def jacp(*args, **kwargs):
"initial_conditions",
use_jacobian=False,
)[0]
model.initial_conditions_eval = initial_conditions

# evaluate initial condition
y0_total_size = (
Expand Down Expand Up @@ -640,12 +645,11 @@ def jacp(*args, **kwargs):
model.jac_rhs_eval = jac_rhs
model.jacp_rhs_eval = jacp_rhs

model.jacp_algebraic_eval = jacp_algebraic
model.jac_algebraic_eval = jac_algebraic
model.jacp_algebraic_eval = jacp_algebraic

model.jac_rhs_algebraic_eval = jac_rhs_algebraic
model.jacp_rhs_algebraic_eval = jacp_rhs_algebraic
model.initial_conditions_eval = initial_conditions

# Save CasADi functions for the CasADi solver
# Save CasADi functions for solvers that use CasADi
Expand Down Expand Up @@ -723,7 +727,7 @@ def _set_initial_conditions(self, model, inputs_dict, update_rhs):
model.y0 = casadi.vertcat(
y0_from_inputs[:len_rhs], y0_from_model[len_rhs:]
)
y0 = self.calculate_consistent_state(model, 0, inputs)
y0 = self.calculate_consistent_state(model, 0, inputs_dict)
# Make y0 a function of inputs if doing symbolic with casadi
model.y0 = y0

Expand All @@ -738,7 +742,7 @@ def calculate_consistent_state(self, model, time=0, inputs=None):
The model for which to calculate initial conditions.
time : float
The time at which to calculate the states
inputs_dict : dict, optional
inputs: dict, optional
Any input parameters to pass to the model when solving
Returns
Expand Down
15 changes: 12 additions & 3 deletions tests/unit/test_solvers/test_algebraic_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,12 @@ class Model(pybamm.BaseModel):
timescale_eval = 1
length_scales = {}
jac_algebraic_eval = None
convert_to_format = "python"
len_rhs_and_alg = 1

def __init__(self):
super().__init__()
self.convert_to_format = "python"

def algebraic_eval(self, t, y, inputs):
return y + 2

Expand All @@ -68,9 +71,12 @@ class Model(pybamm.BaseModel):
timescale_eval = 1
length_scales = {}
jac_algebraic_eval = None
convert_to_format = "casadi"
len_rhs_and_alg = 1

def __init__(self):
super().__init__()
self.convert_to_format = "python"

def algebraic_eval(self, t, y, inputs):
# algebraic equation has no real root
return y ** 2 + 1
Expand Down Expand Up @@ -99,9 +105,12 @@ class Model(pybamm.BaseModel):
rhs = {}
timescale_eval = 1
length_scales = {}
convert_to_format = "python"
len_rhs_and_alg = 2

def __init__(self):
super().__init__()
self.convert_to_format = "python"

def algebraic_eval(self, t, y, inputs):
return A @ y - b

Expand Down

0 comments on commit 76df2cc

Please sign in to comment.