Skip to content

Commit

Permalink
#759 discontinuity events seem to be working now
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Feb 5, 2020
1 parent 1b42969 commit 50ca0fc
Show file tree
Hide file tree
Showing 9 changed files with 224 additions and 57 deletions.
2 changes: 1 addition & 1 deletion pybamm/discretisations/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def process_model(self, model, inplace=True, check_model=True):
processed_events = []
pybamm.logger.info("Discretise events for {}".format(model.name))
for event in model.events:
pybamm.logger.debug("Discretise event '{}'".format(event))
pybamm.logger.debug("Discretise event '{}'".format(event.name))
processed_event = pybamm.Event(
event.name,
self.process_symbol(event.expression),
Expand Down
101 changes: 84 additions & 17 deletions pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
from scipy import optimize
from scipy.sparse import issparse
import sys


class BaseSolver(object):
Expand Down Expand Up @@ -218,13 +219,15 @@ def report(string):
)
terminate_events_eval = [
process(event.expression, "event", use_jacobian=False)[1]
for event in model.events
if events.type == pybamm.EventType.TERMINATION
for event in model.events
if event.event_type == pybamm.EventType.TERMINATION
]

# discontinuity events are evaluated before the solver is called, so don't need
# to process them
discontinuity_events_eval = [
process(event.expression, "event", use_jacobian=False)[1]
for event in model.events
if events.type == pybamm.EventType.DISCONTINUITY
event for event in model.events
if event.event_type == pybamm.EventType.DISCONTINUITY
]

# Add the solver attributes
Expand All @@ -243,7 +246,8 @@ def report(string):
residuals, residuals_eval, jacobian_eval = process(all_states, "residuals")
model.residuals_eval = residuals_eval
model.jacobian_eval = jacobian_eval
model.y0 = self.calculate_consistent_initial_conditions(model)
y0_guess = model.concatenated_initial_conditions.flatten()
model.y0 = self.calculate_consistent_state(model, 0, y0_guess)
else:
# can use DAE solver to solve ODE model
model.residuals_eval = Residuals(rhs, "residuals", model)
Expand Down Expand Up @@ -281,14 +285,12 @@ def set_inputs(self, model, ext_and_inputs):
model.residuals_eval.set_inputs(ext_and_inputs)
for evnt in model.terminate_events_eval:
evnt.set_inputs(ext_and_inputs)
for evnt in model.discontinuity_events_eval:
evnt.set_inputs(ext_and_inputs)
if model.jacobian_eval:
model.jacobian_eval.set_inputs(ext_and_inputs)

def calculate_consistent_initial_conditions(self, model):
def calculate_consistent_state(self, model, time=0, y0_guess=None):
"""
Calculate consistent initial conditions for the algebraic equations through
Calculate consistent state for the algebraic equations through
root-finding
Parameters
Expand All @@ -305,8 +307,9 @@ def calculate_consistent_initial_conditions(self, model):
pybamm.logger.info("Start calculating consistent initial conditions")
rhs = model.rhs_eval
algebraic = model.algebraic_eval
y0_guess = model.concatenated_initial_conditions.flatten()
jac = model.jac_algebraic_eval
if y0_guess is None:
y0_guess = model.concatenated_initial_conditions.flatten()

# Split y0_guess into differential and algebraic
len_rhs = rhs(0, y0_guess).shape[0]
Expand All @@ -315,7 +318,7 @@ def calculate_consistent_initial_conditions(self, model):
def root_fun(y0_alg):
"Evaluates algebraic using y0_diff (fixed) and y0_alg (changed by algo)"
y0 = np.concatenate([y0_diff, y0_alg])
out = algebraic(0, y0)
out = algebraic(time, y0)
pybamm.logger.debug(
"Evaluating algebraic equations at t=0, L2-norm is {}".format(
np.linalg.norm(out)
Expand Down Expand Up @@ -421,13 +424,77 @@ def solve(self, model, t_eval, external_variables=None, inputs=None):
# Set inputs and external
self.set_inputs(model, ext_and_inputs)

timer.reset()
pybamm.logger.info("Calling solver")
solution = self._integrate(model, t_eval, ext_and_inputs)
# Calculate discontinuities
discontinuities = [
event.expression.evaluate(u=inputs) for event in model.discontinuity_events_eval
]

# make sure they are increasing in time
discontinuities = sorted(discontinuities)
pybamm.logger.info(
'Discontinuity events found at t = {}'.format(discontinuities)
)
# remove any identical discontinuities
discontinuities = [
v for i, v in enumerate(discontinuities)
if i==len(discontinuities)-1 or discontinuities[i] < discontinuities[i+1]
]

# insert time points around discontinuities in t_eval
# keep track of sub sections to integrate by storing start and end indices
start_indices = [0]
end_indices = []
for dtime in discontinuities:
dindex = np.searchsorted(t_eval, dtime, side='left')
end_indices.append(dindex+1)
start_indices.append(dindex+1)
if t_eval[dindex] == dtime:
t_eval[dindex] += sys.float_info.epsilon
t_eval = np.insert(t_eval, dindex, dtime - sys.float_info.epsilon)
else:
t_eval = np.insert(t_eval, dindex,
[dtime - sys.float_info.epsilon, dtime + sys.float_info.epsilon])
end_indices.append(len(t_eval))

old_y0 = model.y0
solution = None
for start_index, end_index in zip(start_indices, end_indices):
pybamm.logger.info("Calling solver for {} < t < {}"
.format(t_eval[start_index], t_eval[end_index-1]))
timer.reset()
if solution is None:
solution = self._integrate(
model, t_eval[start_index:end_index], ext_and_inputs)
solution.solve_time = timer.time()
else:
new_solution = self._integrate(
model, t_eval[start_index:end_index], ext_and_inputs)
new_solution.solve_time = timer.time()
solution.append(new_solution, start_index=0)

if solution.termination != "final time":
break

if end_index != len(t_eval):
# setup for next integration subsection
y0_guess = solution.y[:, -1]
if model.algebraic:
model.y0 = self.calculate_consistent_state(model, t_eval[end_index], y0_guess)
else:
model.y0 = y0_guess

last_state = solution.y[:, -1]
if len(model.algebraic) > 0:
model.y0 = self.calculate_consistent_state(
model, t_eval[end_index], last_state)
else:
model.y0 = last_state

# restore old y0
model.y0 = old_y0

# Assign times
solution.set_up_time = set_up_time
solution.solve_time = timer.time()

# Add model and inputs to solution
solution.model = model
Expand Down Expand Up @@ -571,7 +638,7 @@ def get_termination_reason(self, solution, events):
final_event_values = {}

for event in events:
if event.type == pybamm.EventType.TERMINATION:
if event.event_type == pybamm.EventType.TERMINATION:
final_event_values[event.name] = abs(
event.expression.evaluate(
solution.t_event,
Expand Down
2 changes: 1 addition & 1 deletion pybamm/solvers/scikits_dae_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def _integrate(self, model, t_eval, inputs=None):
"""
residuals = model.residuals_eval
y0 = model.y0
events = model.events_eval
events = model.terminate_events_eval
jacobian = model.jacobian_eval
mass_matrix = model.mass_matrix.entries

Expand Down
2 changes: 1 addition & 1 deletion pybamm/solvers/scikits_ode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _integrate(self, model, t_eval, inputs=None):
"""
derivs = model.rhs_eval
y0 = model.y0
events = model.events_eval
events = model.terminate_events_eval
jacobian = model.jacobian_eval

def eqsydot(t, y, return_ydot):
Expand Down
6 changes: 3 additions & 3 deletions pybamm/solvers/scipy_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ def _integrate(self, model, t_eval, inputs=None):
extra_options.update({"jac": model.jacobian_eval})

# make events terminal so that the solver stops when they are reached
if model.events_eval:
for event in model.events_eval:
if model.terminate_events_eval:
for event in model.terminate_events_eval:
event.terminal = True
extra_options.update({"events": model.events_eval})
extra_options.update({"events": model.terminate_events_eval})

sol = it.solve_ivp(
model.rhs_eval,
Expand Down
17 changes: 10 additions & 7 deletions pybamm/solvers/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,22 +129,25 @@ def __add__(self, other):
self.append(other)
return self

def append(self, solution):
def append(self, solution, start_index=1):
"""
Appends solution.t and solution.y onto self.t and self.y.
Note: this process removes the initial time and state of solution to avoid
duplicate times and states being stored (self.t[-1] is equal to solution.t[0],
and self.y[:, -1] is equal to solution.y[:, 0]).
Note: by default this process removes the initial time and state of solution to
avoid duplicate times and states being stored (self.t[-1] is equal to
solution.t[0], and self.y[:, -1] is equal to solution.y[:, 0]). Set the optional
argument ``start_index`` to override this behavior
"""
# Update t, y and inputs
self.t = np.concatenate((self.t, solution.t[1:]))
self.y = np.concatenate((self.y, solution.y[:, 1:]), axis=1)
self.t = np.concatenate((self.t, solution.t[start_index:]))
self.y = np.concatenate((self.y, solution.y[:, start_index:]), axis=1)
for name, inp in self.inputs.items():
solution_inp = solution.inputs[name]
if isinstance(solution_inp, numbers.Number):
solution_inp = solution_inp * np.ones_like(solution.t)
self.inputs[name] = np.concatenate((inp, solution_inp[1:]))
self.inputs[name] = np.concatenate((inp, solution_inp[start_index:]))
# Update solution time
self.solve_time += solution.solve_time
# Update termination
Expand Down
13 changes: 7 additions & 6 deletions tests/unit/test_solvers/test_base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ def algebraic_eval(self, t, y):
return y + 2

solver = pybamm.BaseSolver()
init_cond = solver.calculate_consistent_initial_conditions(ScalarModel())
model = ScalarModel()
init_cond = solver.calculate_consistent_state(model)
np.testing.assert_array_equal(init_cond, -2)

# More complicated system
Expand All @@ -75,15 +76,15 @@ def algebraic_eval(self, t, y):
return (y[1:] - vec[1:]) ** 2

model = VectorModel()
init_cond = solver.calculate_consistent_initial_conditions(model)
init_cond = solver.calculate_consistent_state(model)
np.testing.assert_array_almost_equal(init_cond, vec)

# With jacobian
def jac_dense(t, y):
return 2 * np.hstack([np.zeros((3, 1)), np.diag(y[1:] - vec[1:])])

model.jac_algebraic_eval = jac_dense
init_cond = solver.calculate_consistent_initial_conditions(model)
init_cond = solver.calculate_consistent_state(model)
np.testing.assert_array_almost_equal(init_cond, vec)

# With sparse jacobian
Expand All @@ -93,7 +94,7 @@ def jac_sparse(t, y):
)

model.jac_algebraic_eval = jac_sparse
init_cond = solver.calculate_consistent_initial_conditions(model)
init_cond = solver.calculate_consistent_state(model)
np.testing.assert_array_almost_equal(init_cond, vec)

def test_fail_consistent_initial_conditions(self):
Expand All @@ -114,13 +115,13 @@ def algebraic_eval(self, t, y):
pybamm.SolverError,
"Could not find consistent initial conditions: The iteration is not making",
):
solver.calculate_consistent_initial_conditions(Model())
solver.calculate_consistent_state(Model())
solver = pybamm.BaseSolver()
with self.assertRaisesRegex(
pybamm.SolverError,
"Could not find consistent initial conditions: solver terminated",
):
solver.calculate_consistent_initial_conditions(Model())
solver.calculate_consistent_state(Model())


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 50ca0fc

Please sign in to comment.