Skip to content

Commit

Permalink
#858 make_semi_explicit updates initial conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Mar 6, 2020
1 parent d330e67 commit eb1b381
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 21 deletions.
8 changes: 3 additions & 5 deletions pybamm/solvers/casadi_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ def get_integrator(self, model, t_eval, inputs):
# Only set up problem once
if model not in self.problems:
y0 = model.y0
ydot0 = model.ydot0
rhs = model.casadi_rhs
algebraic = model.casadi_algebraic
u_stacked = casadi.vertcat(*[x for x in inputs.values()])
Expand All @@ -191,22 +190,21 @@ def get_integrator(self, model, t_eval, inputs):
u = casadi.MX.sym("u", u_stacked.shape[0])
y_diff = casadi.MX.sym("y_diff", rhs(t_eval[0], y0, u).shape[0])
problem = {"t": t, "x": y_diff, "p": u}
if algebraic(t_eval[0], y0, ydot0, u).is_empty():
if algebraic(t_eval[0], y0, u).is_empty():
method = "cvodes"
problem.update({"ode": rhs(t, y_diff, u)})
else:
options["calc_ic"] = True
method = "idas"
y_alg = casadi.MX.sym("y_alg", algebraic(t_eval[0], y0, ydot0, u).shape[0])
y_alg = casadi.MX.sym("y_alg", algebraic(t_eval[0], y0, u).shape[0])
y_full = casadi.vertcat(y_diff, y_alg)
problem.update(
{
"z": y_alg,
"ode": rhs(t, y_full, u),
"alg": algebraic(t, y_full, ydot0, u),
"alg": algebraic(t, y_full, u),
}
)
print('using method' ,method)
self.problems[model] = problem
self.options[model] = options
self.methods[model] = method
Expand Down
18 changes: 2 additions & 16 deletions tests/unit/test_solvers/test_casadi_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,17 +316,7 @@ def test_model_solver_with_dvdt(self):
model.rhs = {var1: -2 * var1 * pybamm.t}
model.algebraic = {var2: var2 - pybamm.d_dt(var1)}
model.initial_conditions = {var1: 1, var2: 0}
print('before semi explicit')
for key,value in model.rhs.items():
print('{}: {}'.format(key, value))
for key,value in model.algebraic.items():
print('{}: {}'.format(key, value))
pybamm.make_semi_explicit(model)
print('after semi explicit')
for key,value in model.rhs.items():
print('{}: {}'.format(key, value))
for key,value in model.algebraic.items():
print('{}: {}'.format(key, value))
disc = get_discretisation_for_testing()
disc.process_model(model)

Expand All @@ -335,13 +325,9 @@ def test_model_solver_with_dvdt(self):
t_eval = np.linspace(0, 1, 100)
solution = solver.solve(model, t_eval)
np.testing.assert_array_equal(solution.t, t_eval)
import matplotlib.pyplot as plt
plt.plot(solution.y[0])
plt.plot(np.exp(-solution.t**2))
plt.show()
np.testing.assert_allclose(solution.y[0], np.exp(-solution.t ** 2), rtol=1e-06)
np.testing.assert_allclose(solution.y[0], np.exp(-solution.t**2), rtol=1e-06)
np.testing.assert_allclose(solution.y[-1],
-2 * solution.t * np.exp(-solution.t**2))
-2 * solution.t * np.exp(-solution.t**2), rtol=1e-06)


if __name__ == "__main__":
Expand Down

0 comments on commit eb1b381

Please sign in to comment.