diff --git a/pybamm/expression_tree/operations/evaluate_python.py b/pybamm/expression_tree/operations/evaluate_python.py index e4aea0a8ca..ac549d79f3 100644 --- a/pybamm/expression_tree/operations/evaluate_python.py +++ b/pybamm/expression_tree/operations/evaluate_python.py @@ -486,15 +486,21 @@ def __call__(self, t=None, y=None, inputs=None): y = y.reshape(-1, 1) if isinstance(inputs, list): - ny = y.shape[0] - ni = len(inputs) - result = np.zeros((ni * ny, 1)) - i = 0 - for input in inputs: - result[i : i + ny] += self._evaluate( - self._constants, t, y[i : i + ny], input - ) - i += ny + result = self._evaluate(self._constants, t, y, inputs[0]) + if len(inputs) > 1: + if isinstance(result, numbers.Number): + result = np.array([result]) + ny = result.shape[0] + ni = len(inputs) + results = np.zeros((ni * ny, 1)) + results[:ny] = result + i = ny + for input in inputs[1:]: + results[i : i + ny] += self._evaluate( + self._constants, t, y[i : i + ny], input + ) + i += ny + result = results else: result = self._evaluate(self._constants, t, y, inputs) diff --git a/pybamm/solvers/algebraic_solver.py b/pybamm/solvers/algebraic_solver.py index 2ede221012..12a0735ee6 100644 --- a/pybamm/solvers/algebraic_solver.py +++ b/pybamm/solvers/algebraic_solver.py @@ -59,7 +59,7 @@ def _integrate(self, model, t_eval, inputs_list=None): inputs_list: list of dict, optional Any input parameters to pass to the model when solving """ - inputs_list = inputs_list or {} + inputs_list = inputs_list or [{}] if model.convert_to_format == "casadi": inputs = casadi.vertcat( *[x for inputs in inputs_list for x in inputs.values()] @@ -149,7 +149,7 @@ def jac_fn(y_alg, jac=jac): if jac_fn is None: jac_fn = "2-point" timer.reset() - sol = optimize.least_squares( + solns = optimize.least_squares( root_fun, y0_alg, method=method, @@ -186,7 +186,7 @@ def jac_norm(y, jac_fn=jac_fn): ] extra_options["bounds"] = bounds timer.reset() - sol = optimize.minimize( + solns = optimize.minimize( root_norm, y0_alg, method=method, @@ -197,7 +197,7 @@ def jac_norm(y, jac_fn=jac_fn): integration_time += timer.time() else: timer.reset() - sol = optimize.root( + solns = optimize.root( root_fun, y0_alg, method=self.method, @@ -207,23 +207,23 @@ def jac_norm(y, jac_fn=jac_fn): ) integration_time += timer.time() - if sol.success and np.all(abs(sol.fun) < self.tol): + if solns.success and np.all(abs(solns.fun) < self.tol): # update initial guess for the next iteration - y0_alg = sol.x + y0_alg = solns.x # update solution array y_alg[:, idx] = y0_alg success = True - elif not sol.success: + elif not solns.success: raise pybamm.SolverError( - f"Could not find acceptable solution: {sol.message}" + f"Could not find acceptable solution: {solns.message}" ) else: - y0_alg = sol.x + y0_alg = solns.x if itr > maxiter: raise pybamm.SolverError( "Could not find acceptable solution: solver terminated " "successfully, but maximum solution error " - f"({np.max(abs(sol.fun))}) above tolerance ({self.tol})" + f"({np.max(abs(solns.fun))}) above tolerance ({self.tol})" ) itr += 1 @@ -231,8 +231,9 @@ def jac_norm(y, jac_fn=jac_fn): y_diff = np.r_[[y0_diff] * len(t_eval)].T y_sol = np.r_[y_diff, y_alg] # Return solution object (no events, so pass None to t_event, y_event) - sol = pybamm.Solution.from_concatenated_state( + solns = pybamm.Solution.from_concatenated_state( t_eval, y_sol, model, inputs_list, termination="final time" ) - sol.integration_time = integration_time - return sol + for s in solns: + s.integration_time = integration_time + return solns diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index 5055ae17c2..f42633cd1f 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -102,12 +102,16 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False): model : :class:`pybamm.BaseModel` The model whose solution to calculate. Must have attributes rhs and initial_conditions - inputs : list of dict, optional + inputs : dict or list of dict, optional Any input parameters to pass to the model when solving t_eval : numeric type, optional The times (in seconds) at which to compute the solution """ - inputs = inputs or [{}] + + if inputs is None: + inputs = [{}] + elif isinstance(inputs, dict): + inputs = [inputs] if ics_only: pybamm.logger.info("Start solver set-up, initial_conditions only") @@ -152,7 +156,7 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False): if model.convert_to_format == "casadi": # stack inputs inputs_casadi = casadi.vertcat( - *[x for inpt in inputs for x in inpt.items()] + *[x for inpt in inputs for x in inpt.values()] ) model.y0 = initial_conditions(0.0, y_zero, inputs_casadi) if jacp_ic is None: @@ -389,7 +393,7 @@ def _get_vars_for_processing(model, inputs, calculate_sensitivities_explicit): y_alg = casadi.MX.sym("y_alg", model.len_alg) y_casadi = casadi.vertcat(y_diff, y_alg) p_casadi = {} - for name, value in inputs.items(): + for name, value in inputs[0].items(): if isinstance(value, numbers.Number): p_casadi[name] = casadi.MX.sym(name) else: @@ -610,7 +614,7 @@ def _set_up_events(self, model, t_eval, inputs, vars_for_processing): discontinuity_events, ) - def _set_initial_conditions(self, model, time, inputs_dict, update_rhs): + def _set_initial_conditions(self, model, time, inputs_list, update_rhs): """ Set initial conditions for the model. This is skipped if the solver is an algebraic solver (since this would make the algebraic solver redundant), and if @@ -621,7 +625,7 @@ def _set_initial_conditions(self, model, time, inputs_dict, update_rhs): ---------- model : :class:`pybamm.BaseModel` The model for which to calculate initial conditions. - inputs_dict : dict + inputs_list: list of dict Any input parameters to pass to the model when solving update_rhs : bool Whether to update the rhs. True for 'solve', False for 'step'. @@ -635,9 +639,11 @@ def _set_initial_conditions(self, model, time, inputs_dict, update_rhs): if model.convert_to_format == "casadi": # stack inputs - inputs = casadi.vertcat(*[x for x in inputs_dict.values()]) + inputs = casadi.vertcat( + *[x for inpts in inputs_list for x in inpts.values()] + ) else: - inputs = inputs_dict + inputs = inputs_list if self.algebraic_solver is True: # Don't update model.y0 @@ -667,7 +673,7 @@ def _set_initial_conditions(self, model, time, inputs_dict, update_rhs): model.y0 = np.vstack( (y0_from_inputs[:len_rhs], y0_from_model[len_rhs:]) ) - y0 = self.calculate_consistent_state(model, time, inputs_dict) + y0 = self.calculate_consistent_state(model, time, inputs_list) # Make y0 a function of inputs if doing symbolic with casadi model.y0 = y0 @@ -696,15 +702,22 @@ def calculate_consistent_state(self, model, time=0, inputs=None): if self.root_method is None: return model.y0 try: - root_sol = self.root_method._integrate(model, np.array([time]), inputs) + root_sols = self.root_method._integrate(model, np.array([time]), inputs) except pybamm.SolverError as e: raise pybamm.SolverError( f"Could not find consistent states: {e.args[0]}" ) from e pybamm.logger.debug("Found consistent states") - self.check_extrapolation(root_sol, model.events) - y0 = root_sol.all_ys[0] + y0s = [] + for s in root_sols: + self.check_extrapolation(s, model.events) + y0s.append(s.all_ys[0]) + + if isinstance(y0s[0], casadi.DM): + y0 = casadi.horzcat(*y0s) + else: + y0 = np.hstack(y0s) return y0 def solve( @@ -960,7 +973,7 @@ def solve( ) # Return solution(s) - if len(inputs) == 1: + if len(inputs_list) == 1: return solutions[0] else: return solutions @@ -1026,20 +1039,22 @@ def _get_discontinuity_start_end_indices(model, inputs, t_eval): return start_indices, end_indices, t_eval @staticmethod - def _check_events_with_initial_conditions(t_eval, model, inputs_dict): + def _check_events_with_initial_conditions(t_eval, model, inputs_list): num_terminate_events = len(model.terminate_events_eval) if num_terminate_events == 0: return if model.convert_to_format == "casadi": - inputs = casadi.vertcat(*[x for x in inputs_dict.values()]) + inputs = casadi.vertcat( + *[x for inputs in inputs_list for x in inputs.values()] + ) events_eval = [None] * num_terminate_events for idx, event in enumerate(model.terminate_events_eval): if model.convert_to_format == "casadi": event_eval = event(t_eval[0], model.y0, inputs) elif model.convert_to_format in ["python", "jax"]: - event_eval = event(t=t_eval[0], y=model.y0, inputs=inputs_dict) + event_eval = event(t=t_eval[0], y=model.y0, inputs=inputs_list) events_eval[idx] = event_eval events_eval = np.array(events_eval) @@ -1071,7 +1086,7 @@ def step( Parameters ---------- - old_solution : :class:`pybamm.Solution` or None + old_solution : :class:`pybamm.Solution` or list of :class:`pybamm.Solution` or None The previous solution to be added to. If `None`, a new solution is created. model : :class:`pybamm.BaseModel` The model whose solution to calculate. Must have attributes rhs and @@ -1083,7 +1098,7 @@ def step( (Note: t_eval is the time measured from the start of the step, so should start at 0 and end at dt). By default, the solution is returned at t0 and t0 + dt. npts : deprecated - inputs : dict, optional + inputs : dict or list of dict, optional Any input parameters to pass to the model when solving save : bool, optional Save solution with all previous timesteps. Defaults to True. @@ -1094,17 +1109,23 @@ def step( `model.variables = {}`) """ + inputs_list = inputs if isinstance(inputs, list) else [inputs] if old_solution is None: - old_solution = pybamm.EmptySolution() + old_solutions = [pybamm.EmptySolution()] * len(inputs_list) + elif not isinstance(old_solution, list): + old_solutions = [old_solution] if not ( - isinstance(old_solution, pybamm.EmptySolution) - or old_solution.termination == "final time" - or "[experiment]" in old_solution.termination + isinstance(old_solutions[0], pybamm.EmptySolution) + or old_solutions[0].termination == "final time" + or "[experiment]" in old_solutions[0].termination ): # Return same solution as an event has already been triggered # With hack to allow stepping past experiment current / voltage cut-off - return old_solution + if len(old_solutions) == 1: + return old_solutions[0] + else: + return old_solutions # Make sure model isn't empty if len(model.rhs) == 0 and len(model.algebraic) == 0: @@ -1139,7 +1160,7 @@ def step( else: t_eval = np.array([0, dt]) - t_start = old_solution.t[-1] + t_start = old_solutions[0].t[-1] t_eval = t_start + t_eval t_end = t_start + dt @@ -1157,7 +1178,9 @@ def step( timer = pybamm.Timer() # Set up inputs - model_inputs = self._set_up_model_inputs(model, inputs) + model_inputs_list = [ + self._set_up_model_inputs(model, inputs) for inputs in inputs_list + ] first_step_this_model = False if model not in self._model_set_up: @@ -1169,71 +1192,81 @@ def step( f'"{existing_model.name}". Please create a separate ' "solver for this model" ) - self.set_up(model, model_inputs) + self.set_up(model, model_inputs_list) self._model_set_up.update( {model: {"initial conditions": model.concatenated_initial_conditions}} ) if ( - isinstance(old_solution, pybamm.EmptySolution) - and old_solution.termination is None + isinstance(old_solutions[0], pybamm.EmptySolution) + and old_solutions[0].termination is None ): pybamm.logger.verbose(f"Start stepping {model.name} with {self.name}") - if isinstance(old_solution, pybamm.EmptySolution): + if isinstance(old_solutions[0], pybamm.EmptySolution): if not first_step_this_model: # reset y0 to original initial conditions - self.set_up(model, model_inputs, ics_only=True) + self.set_up(model, model_inputs_list, ics_only=True) else: - if old_solution.all_models[-1] == model: + if old_solutions[0].all_models[-1] == model: # initialize with old solution - model.y0 = old_solution.all_ys[-1][:, -1] + y0s = [s.all_ys[-1][:, -1] for s in old_solutions] + else: - _, concatenated_initial_conditions = model.set_initial_conditions_from( - old_solution, return_type="ics" - ) - model.y0 = concatenated_initial_conditions.evaluate( - 0, inputs=model_inputs - ) + y0s = [] + for soln, inputs in zip(old_solutions, model_inputs_list): + _, concatenated_initial_conditions = ( + model.set_initial_conditions_from(soln, return_type="ics") + ) + y0s.append( + concatenated_initial_conditions.evaluate(0, inputs=inputs) + ) + + model.y0 = casadi.vertcat(*y0s) set_up_time = timer.time() # (Re-)calculate consistent initial conditions self._set_initial_conditions( - model, t_start_shifted, model_inputs, update_rhs=False + model, t_start_shifted, model_inputs_list, update_rhs=False ) # Check initial conditions don't violate events - self._check_events_with_initial_conditions(t_eval, model, model_inputs) + self._check_events_with_initial_conditions(t_eval, model, model_inputs_list) # Step pybamm.logger.verbose(f"Stepping for {t_start_shifted:.0f} < t < {t_end:.0f}") timer.reset() - solution = self._integrate(model, t_eval, model_inputs) - solution.solve_time = timer.time() + solutions = self._integrate(model, t_eval, model_inputs_list) + for i, s in enumerate(solutions): + solutions[i].solve_time = timer.time() - # Check if extrapolation occurred - self.check_extrapolation(solution, model.events) + # Check if extrapolation occurred + self.check_extrapolation(s, model.events) - # Identify the event that caused termination and update the solution to - # include the event time and state - solution, termination = self.get_termination_reason(solution, model.events) + # Identify the event that caused termination and update the solution to + # include the event time and state + solutions[i], termination = self.get_termination_reason(s, model.events) - # Assign setup time - solution.set_up_time = set_up_time + # Assign setup time + solutions[i].set_up_time = set_up_time # Report times pybamm.logger.verbose(f"Finish stepping {model.name} ({termination})") pybamm.logger.verbose( - f"Set-up time: {solution.set_up_time}, Step time: {solution.solve_time} (of which integration time: {solution.integration_time}), " - f"Total time: {solution.total_time}" + f"Set-up time: {solutions[0].set_up_time}, Step time: {solutions[0].solve_time} (of which integration time: {solutions[0].integration_time}), " + f"Total time: {solutions[0].total_time}" ) # Return solution if save is False: - return solution + ret = solutions + else: + ret = [old_s + s for (old_s, s) in zip(old_solutions, solutions)] + if len(ret) == 1: + return ret[0] else: - return old_solution + solution + return ret @staticmethod def get_termination_reason(solution, events): diff --git a/pybamm/solvers/casadi_algebraic_solver.py b/pybamm/solvers/casadi_algebraic_solver.py index 2a665f0773..17142428d8 100644 --- a/pybamm/solvers/casadi_algebraic_solver.py +++ b/pybamm/solvers/casadi_algebraic_solver.py @@ -51,7 +51,7 @@ def _integrate(self, model, t_eval, inputs_list=None): Any input parameters to pass to the model when solving. """ # Record whether there are any symbolic inputs - inputs_list = inputs_list or {} + inputs_list = inputs_list or [{}] # Create casadi objects for the root-finder inputs = casadi.vertcat(*[v for inputs in inputs_list for v in inputs.values()]) @@ -164,7 +164,7 @@ def _integrate(self, model, t_eval, inputs_list=None): except AttributeError: explicit_sensitivities = False - sol = pybamm.Solution.from_concatenated_state( + sols = pybamm.Solution.from_concatenated_state( [t_eval], y_sol, model, @@ -172,5 +172,6 @@ def _integrate(self, model, t_eval, inputs_list=None): termination="final time", sensitivities=explicit_sensitivities, ) - sol.integration_time = integration_time - return sol + for sol in sols: + sol.integration_time = integration_time + return sols diff --git a/pybamm/solvers/casadi_solver.py b/pybamm/solvers/casadi_solver.py index c658ec70ed..cf86c20821 100644 --- a/pybamm/solvers/casadi_solver.py +++ b/pybamm/solvers/casadi_solver.py @@ -148,7 +148,7 @@ def _integrate(self, model, t_eval, inputs_list=None): """ # Record whether there are any symbolic inputs - inputs_list = inputs_list or [] + inputs_list = inputs_list or [{}] # convert inputs to casadi format inputs = casadi.vertcat(*[x for inputs in inputs_list for x in inputs.values()]) @@ -294,13 +294,18 @@ def _integrate(self, model, t_eval, inputs_list=None): # termination point and exit # Note: this is only done for the first solution, is this correct? current_step_sols[0] = self._solve_for_event(current_step_sols[0]) - for soln, current_step_sol in zip(current_step_sols, solutions): - # assign temporary solve time - current_step_sol.solve_time = np.nan - # append solution from the current step to solution - soln = soln + current_step_sol + if solutions is None: + for s in current_step_sols: + s.solve_time = np.nan + solutions = current_step_sols + else: + for i, current_step_sol in enumerate(current_step_sols): + # assign temporary solve time + current_step_sol.solve_time = np.nan + # append solution from the current step to solution + solutions[i] = solutions[i] + current_step_sol - if current_step_sol[0].termination == "event": + if current_step_sols[0].termination == "event": break else: # update time as time @@ -450,10 +455,10 @@ def integer_bisect(): use_grid = True y0 = coarse_solution.y[:, event_idx_lower] - dense_step_sol = self._run_integrator( + [dense_step_sol] = self._run_integrator( model, y0, - inputs_dict, + [inputs_dict], inputs, t_window_event_dense, use_grid=use_grid, @@ -602,7 +607,7 @@ def _run_integrator( self, model, y0, - inputs_dict, + inputs_list, inputs, t_eval, use_grid=True, @@ -617,7 +622,7 @@ def _run_integrator( The model whose solution to calculate. y0: casadi vector of initial conditions - inputs_dict : dict, optional + inputs_list : list of dict, optional Any input parameters to pass to the model when solving inputs: Casadi vector of inputs @@ -701,7 +706,7 @@ def _run_integrator( t_eval, y_sol, model, - inputs_dict, + inputs_list, sensitivities=extract_sensitivities_in_solution, check_solution=False, ) @@ -742,7 +747,7 @@ def _run_integrator( t_eval, y_sol, model, - inputs_dict, + inputs_list, sensitivities=extract_sensitivities_in_solution, check_solution=False, ) diff --git a/pybamm/solvers/dummy_solver.py b/pybamm/solvers/dummy_solver.py index 47b80a9a09..f4b19ac428 100644 --- a/pybamm/solvers/dummy_solver.py +++ b/pybamm/solvers/dummy_solver.py @@ -37,4 +37,4 @@ def _integrate(self, model, t_eval, inputs_list=None): t_eval, y_sol, model, inputs_list, termination="final time" ) sol.integration_time = 0 - return sol + return [sol] diff --git a/pybamm/solvers/scipy_solver.py b/pybamm/solvers/scipy_solver.py index cc1b420925..5685967ffa 100644 --- a/pybamm/solvers/scipy_solver.py +++ b/pybamm/solvers/scipy_solver.py @@ -143,7 +143,7 @@ def event_fn(t, y): termination = "final time" t_event = None y_event = np.array(None) - sol = pybamm.Solution.from_concatenated_state( + solns = pybamm.Solution.from_concatenated_state( sol.t, sol.y, model, @@ -153,7 +153,8 @@ def event_fn(t, y): termination, sensitivities=bool(model.calculate_sensitivities), ) - sol.integration_time = integration_time - return sol + for s in solns: + s.integration_time = integration_time + return solns else: raise pybamm.SolverError(sol.message) diff --git a/pybamm/solvers/solution.py b/pybamm/solvers/solution.py index 281e265d2a..8b253dfdb2 100644 --- a/pybamm/solvers/solution.py +++ b/pybamm/solvers/solution.py @@ -172,7 +172,7 @@ def from_concatenated_state( return [ cls( t, - y[i * ny : (i + 1) * ny], + y[i * ny : (i + 1) * ny, :], model, input_list[i], t_event, diff --git a/tests/unit/test_solvers/test_algebraic_solver.py b/tests/unit/test_solvers/test_algebraic_solver.py index 22017f8cf8..ca5098c036 100644 --- a/tests/unit/test_solvers/test_algebraic_solver.py +++ b/tests/unit/test_solvers/test_algebraic_solver.py @@ -55,13 +55,13 @@ def algebraic_eval(self, t, y, inputs): # Try passing extra options to solver solver = pybamm.AlgebraicSolver(extra_options={"maxiter": 100}) model = Model() - solution = solver._integrate(model, np.array([0])) - np.testing.assert_array_equal(solution.y, -2) + solutions = solver._integrate(model, np.array([0])) + np.testing.assert_array_equal(solutions[0].y, -2) # Relax options and see worse results solver = pybamm.AlgebraicSolver(extra_options={"ftol": 1}) - solution = solver._integrate(model, np.array([0])) - self.assertNotEqual(solution.y, -2) + solutions = solver._integrate(model, np.array([0])) + self.assertNotEqual(solutions[0].y, -2) def test_root_find_fail(self): class Model(pybamm.BaseModel): @@ -116,8 +116,8 @@ def jac_algebraic_eval(self, t, y, inputs): sol = np.array([3, -4])[:, np.newaxis] solver = pybamm.AlgebraicSolver() - solution = solver._integrate(model, np.array([0])) - np.testing.assert_array_almost_equal(solution.y, sol) + solutions = solver._integrate(model, np.array([0])) + np.testing.assert_array_almost_equal(solutions[0].y, sol) def test_model_solver(self): # Create model