From c1e615e836a10df5a065cffdfe5337ccaa6b7913 Mon Sep 17 00:00:00 2001 From: Valentin Sulzer Date: Thu, 5 Mar 2020 14:47:45 -0500 Subject: [PATCH] #871 start making changes to quickplot, still working on 3D --- examples/scripts/DFN.py | 14 +- examples/scripts/SPMe.py | 2 +- pybamm/parameters/parameter_values.py | 8 +- pybamm/processed_variable.py | 16 +- pybamm/quick_plot.py | 340 ++++++++++++------ .../test_models/standard_output_comparison.py | 4 +- tests/integration/test_quick_plot.py | 26 +- tests/unit/test_quick_plot.py | 46 ++- 8 files changed, 311 insertions(+), 145 deletions(-) diff --git a/examples/scripts/DFN.py b/examples/scripts/DFN.py index 9155d9eaa7..88c9319aed 100644 --- a/examples/scripts/DFN.py +++ b/examples/scripts/DFN.py @@ -36,5 +36,17 @@ solution = solver.solve(model, t_eval) # plot -plot = pybamm.QuickPlot(solution) +plot = pybamm.QuickPlot( + solution, + output_variables=[ + "Negative particle surface concentration [mol.m-3]", + "Electrolyte concentration [mol.m-3]", + "Positive particle surface concentration [mol.m-3]", + "Current [A]", + "Negative electrode potential [V]", + "Electrolyte potential [V]", + "Positive electrode potential [V]", + "Terminal voltage [V]", + ], +) plot.dynamic_plot() diff --git a/examples/scripts/SPMe.py b/examples/scripts/SPMe.py index 364c7c0508..7f1615053c 100644 --- a/examples/scripts/SPMe.py +++ b/examples/scripts/SPMe.py @@ -31,5 +31,5 @@ solution = model.default_solver.solve(model, t_eval) # plot -plot = pybamm.QuickPlot(solution) +plot = pybamm.QuickPlot(solution, ["Negative particle concentration"]) plot.dynamic_plot() diff --git a/pybamm/parameters/parameter_values.py b/pybamm/parameters/parameter_values.py index dd596e746f..292bdd269e 100644 --- a/pybamm/parameters/parameter_values.py +++ b/pybamm/parameters/parameter_values.py @@ -109,7 +109,9 @@ def update_from_chemistry(self, chemistry): """ base_chemistry = chemistry["chemistry"] # Create path to file - path = os.path.join("input", "parameters", base_chemistry) + path = os.path.join( + pybamm.root_dir(), "pybamm", "input", "parameters", base_chemistry + ) # Load each component name for component_group in [ "cell", @@ -227,7 +229,9 @@ def update(self, values, check_conflict=False, check_already_exists=True, path=" # Data is flagged with the string "[data]" or "[current data]" elif value.startswith("[current data]") or value.startswith("[data]"): if value.startswith("[current data]"): - data_path = os.path.join("input", "drive_cycles") + data_path = os.path.join( + pybamm.root_dir(), "pybamm", "input", "drive_cycles" + ) filename = os.path.join(data_path, value[14:] + ".csv") function_name = value[14:] else: diff --git a/pybamm/processed_variable.py b/pybamm/processed_variable.py index 23d74a461c..74090bd34e 100644 --- a/pybamm/processed_variable.py +++ b/pybamm/processed_variable.py @@ -107,7 +107,7 @@ def initialise_1D(self): ) self.entries = entries - self.dimensions = 1 + self.dimensions = 0 def initialise_2D(self): len_space = self.base_eval.shape[0] @@ -147,7 +147,7 @@ def initialise_2D(self): # assign attributes for reference (either x_sol or r_sol) self.entries = entries - self.dimensions = 2 + self.dimensions = 1 if self.domain[0] in ["negative particle", "positive particle"]: self.first_dimension = "r" self.r_sol = space @@ -243,7 +243,7 @@ def initialise_3D(self): # assign attributes for reference self.entries = entries - self.dimensions = 3 + self.dimensions = 2 self.first_dim_pts = first_dim_pts self.second_dim_pts = second_dim_pts @@ -307,7 +307,7 @@ def initialise_3D_scikit_fem(self): # assign attributes for reference self.entries = entries - self.dimensions = 3 + self.dimensions = 2 self.y_sol = y_sol self.z_sol = z_sol self.first_dimension = "y" @@ -322,15 +322,15 @@ def __call__(self, t=None, x=None, r=None, y=None, z=None, warn=True): """ Evaluate the variable at arbitrary t (and x, r, y and/or z), using interpolation """ - if self.dimensions == 1: + if self.dimensions == 0: out = self._interpolation_function(t) + elif self.dimensions == 1: + out = self.call_2D(t, x, r, z) elif self.dimensions == 2: if t is None: out = self._interpolation_function(y, z) else: - out = self.call_2D(t, x, r, z) - elif self.dimensions == 3: - out = self.call_3D(t, x, r, y, z) + out = self.call_3D(t, x, r, y, z) if warn is True and np.isnan(out).any(): pybamm.logger.warning( "Calling variable outside interpolation range (returns 'nan')" diff --git a/pybamm/quick_plot.py b/pybamm/quick_plot.py index 9c49794e20..b293da7296 100644 --- a/pybamm/quick_plot.py +++ b/pybamm/quick_plot.py @@ -37,6 +37,19 @@ def split_long_string(title, max_words=4): return first_line + "\n" + second_line +def get_spatial_scale(key, spatial_var_name, spatial_scales): + "Return the appropriate spatial scale" + if spatial_var_name == "r": + if "negative" in key[0].lower(): + spatial_scale = spatial_scales["r_n"] + elif "positive" in key[0].lower(): + spatial_scale = spatial_scales["r_p"] + else: + spatial_scale = spatial_scales[spatial_var_name] + + return spatial_scale + + class QuickPlot(object): """ Generates a quick plot of a subset of key outputs of the model so that the model @@ -71,6 +84,8 @@ def __init__( labels=None, colors=None, linestyles=None, + figsize=None, + time_format=None, ): if isinstance(solutions, pybamm.Solution): solutions = [solutions] @@ -80,14 +95,22 @@ def __init__( models = [solution.model for solution in solutions] # Set labels - self.labels = labels or [model.name for model in models] + if labels is None: + self.labels = [model.name for model in models] + else: + if len(labels) != len(models): + raise ValueError( + "labels '{}' have different length to models '{}'".format( + labels, [model.name for model in models] + ) + ) + self.labels = labels - # Set colors and linestyles - self.colors = colors - self.linestyles = linestyles + # Set colors, linestyles, figsize + self.colors = colors or ["r", "b", "k", "g", "m", "c"] + self.linestyles = linestyles or ["-", ":", "--", "-."] + self.figsize = figsize or (15, 8) - # Time scale in hours - self.time_scale = models[0].timescale_eval / 3600 # Spatial scales (default to 1 if information not in model) variables = models[0].variables self.spatial_scales = {"x": 1, "y": 1, "z": 1, "r_n": 1, "r_p": 1} @@ -113,17 +136,34 @@ def __init__( ).evaluate()[-1] # Time parameters + model_timescale_in_seconds = models[0].timescale_eval self.ts = [solution.t for solution in solutions] - self.min_t = np.min([t[0] for t in self.ts]) * self.time_scale - self.max_t = np.max([t[-1] for t in self.ts]) * self.time_scale + self.min_t = np.min([t[0] for t in self.ts]) * model_timescale_in_seconds + self.max_t = np.max([t[-1] for t in self.ts]) * model_timescale_in_seconds + + # Set timescale + if time_format is None: + # defaults depend on how long the simulation is + if self.max_t >= 3600: + self.time_scale = model_timescale_in_seconds / 3600 # time in hours + else: + self.time_scale = model_timescale_in_seconds # time in seconds + elif time_format == "seconds": + self.time_scale = model_timescale_in_seconds + elif time_format == "minutes": + self.time_scale = model_timescale_in_seconds / 60 + elif time_format == "hours": + self.time_scale = model_timescale_in_seconds / 3600 + else: + raise ValueError("time format '{}' not recognized".format(time_format)) # Default output variables for lead-acid and lithium-ion if output_variables is None: if isinstance(models[0], pybamm.lithium_ion.BaseModel): output_variables = [ - "Negative particle surface concentration", - "Electrolyte concentration", - "Positive particle surface concentration", + "Negative particle surface concentration [mol.m-3]", + "Electrolyte concentration [mol.m-3]", + "Positive particle surface concentration [mol.m-3]", "Current [A]", "Negative electrode potential [V]", "Electrolyte potential [V]", @@ -156,45 +196,54 @@ def set_output_variables(self, output_variables, solutions): self.n_rows = int(len(output_variables) // np.sqrt(len(output_variables))) self.n_cols = int(np.ceil(len(output_variables) / self.n_rows)) - # Process output variables into a form that can be plotted - processed_variables = {} - for solution in solutions: - processed_variables[solution] = {} - for variable_list in output_variables: - # Make sure we always have a list of lists of variables - if isinstance(variable_list, str): - variable_list = [variable_list] - # Add all variables to the list of variables that should be processed - processed_variables[solution].update( - {var: solution[var] for var in variable_list} - ) - # Prepare dictionary of variables + # output_variables is a list of strings or lists, e.g. + # ["var 1", ["variable 2", "var 3"]] for k, variable_list in enumerate(output_variables): - # Make sure we always have a list of lists of variables + # Make sure we always have a list of lists of variables, e.g. + # [["var 1"], ["variable 2", "var 3"]] if isinstance(variable_list, str): variable_list = [variable_list] - # Prepare list of variables + # Store the key as a tuple + # key is the variable names, e.g. ("var 1",) or ("var 2", "var 3") key = tuple(variable_list) - self.variables[key] = [None] * len(solutions) + + # Prepare list of variables + variables = [None] * len(solutions) # process each variable in variable_list for each model for i, solution in enumerate(solutions): - # self.variables is a dictionary of lists of lists - self.variables[key][i] = [ - processed_variables[solution][var] for var in variable_list - ] + # variables lists of lists + variables[i] = [] + # first index is the solution number + # second index is the variable number + for var in variable_list: + sol = solution[var] + # Check variable isn't all-nan + if np.all(np.isnan(sol.entries)): + raise ValueError("All-NaN variable '{}' provided".format(var)) + # If ok, add to the list of solutions + else: + variables[i].append(sol) # Make sure variables have the same dimensions and domain - first_variable = self.variables[key][0][0] + # just use the first solution to check this + first_solution = variables[0] + first_variable = first_solution[0] domain = first_variable.domain - for variable in self.variables[key][0]: + # check all other variables against the first variable + for idx, variable in enumerate(first_solution): if variable.domain != domain: - raise ValueError("mismatching variable domains") + raise ValueError( + "Mismatching variable domains. " + "'{}' has domain '{}', but '{}' has domain '{}'".format( + key[0], domain, key[idx], variable.domain + ) + ) - # Set the x variable for any two-dimensional variables - if first_variable.dimensions == 2: + # Set the x variable (i.e. "x" or "r" for any one-dimensional variables) + if first_variable.dimensions == 1: spatial_variable_key = first_variable.first_dimension spatial_variable_value = first_variable.first_dim_pts self.spatial_variable[key] = ( @@ -202,11 +251,26 @@ def set_output_variables(self, output_variables, solutions): spatial_variable_value, ) - # Don't allow 3D variables - elif any(var.dimensions == 3 for var in self.variables[key][0]): - raise NotImplementedError("cannot plot 3D variables") + # Don't allow 2D variables if there are multiple solutions + elif first_variable.dimensions == 2: + if len(variables) > 1: + raise NotImplementedError( + "Cannot plot 3D variables when comparing multiple solutions, " + "but {} is 3D".format() + ) + else: + # Add both spatial variables to the keys + first_spatial_variable_key = first_variable.first_dimension + first_spatial_variable_value = first_variable.first_dim_pts + second_spatial_variable_key = first_variable.second_dimension + second_spatial_variable_value = first_variable.second_dim_pts + self.spatial_variable[key] = ( + (first_spatial_variable_key, first_spatial_variable_value), + (second_spatial_variable_key, second_spatial_variable_value), + ) - # Define subplot position + # Store variables and subplot position + self.variables[key] = variables self.subplot_positions[key] = (self.n_rows, self.n_cols, k + 1) def reset_axis(self): @@ -217,59 +281,72 @@ def reset_axis(self): """ self.axis = {} for key, variable_lists in self.variables.items(): - if variable_lists[0][0].dimensions == 1: + if variable_lists[0][0].dimensions == 0: spatial_var_name, spatial_var_value = "x", None x_min = self.min_t x_max = self.max_t - elif variable_lists[0][0].dimensions == 2: + elif variable_lists[0][0].dimensions == 1: spatial_var_name, spatial_var_value = self.spatial_variable[key] - if spatial_var_name == "r": - if "negative" in key[0].lower(): - spatial_var_scaled = ( - spatial_var_value * self.spatial_scales["r_n"] - ) - elif "positive" in key[0].lower(): - spatial_var_scaled = ( - spatial_var_value * self.spatial_scales["r_p"] - ) - else: - spatial_var_scaled = ( - spatial_var_value * self.spatial_scales[spatial_var_name] - ) + spatial_scale = get_spatial_scale( + key, spatial_var_name, self.spatial_scales + ) + spatial_var_scaled = spatial_var_value * spatial_scale x_min = spatial_var_scaled[0] x_max = spatial_var_scaled[-1] - - # Get min and max y values - y_min = np.min( - [ - ax_min( - var( - self.ts[i], - **{spatial_var_name: spatial_var_value}, - warn=False + elif variable_lists[0][0].dimensions == 2: + spatial_vars = self.spatial_variable[key] + # First spatial variable + first_spatial_var_name, first_spatial_var_value = spatial_vars[0] + first_spatial_scale = get_spatial_scale( + key, first_spatial_var_name, self.spatial_scales + ) + first_spatial_var_scaled = first_spatial_var_value * first_spatial_scale + x_min = first_spatial_var_scaled[0] + x_max = first_spatial_var_scaled[-1] + # Second spatial variable + second_spatial_var_name, second_spatial_var_value = spatial_vars[1] + second_spatial_scale = get_spatial_scale( + key, second_spatial_var_name, self.spatial_scales + ) + second_spatial_var_scaled = ( + second_spatial_var_value * second_spatial_scale + ) + y_min = second_spatial_var_scaled[0] + y_max = second_spatial_var_scaled[-1] + self.axis[key] = [x_min, x_max, y_min, y_max] + + # Get min and max variable values + if variable_lists[0][0].dimensions in [0, 1]: + var_min = np.min( + [ + ax_min( + var( + self.ts[i], + **{spatial_var_name: spatial_var_value}, + warn=False + ) ) - ) - for i, variable_list in enumerate(variable_lists) - for var in variable_list - ] - ) - y_max = np.max( - [ - ax_max( - var( - self.ts[i], - **{spatial_var_name: spatial_var_value}, - warn=False + for i, variable_list in enumerate(variable_lists) + for var in variable_list + ] + ) + var_max = np.max( + [ + ax_max( + var( + self.ts[i], + **{spatial_var_name: spatial_var_value}, + warn=False + ) ) - ) - for i, variable_list in enumerate(variable_lists) - for var in variable_list - ] - ) - if y_min == y_max: - y_min -= 1 - y_max += 1 - self.axis[key] = [x_min, x_max, y_min, y_max] + for i, variable_list in enumerate(variable_lists) + for var in variable_list + ] + ) + if var_min == var_max: + var_min -= 1 + var_max += 1 + self.axis[key] = [x_min, x_max, var_min, var_max] def plot(self, t): """Produces a quick plot with the internal states at time t. @@ -283,14 +360,12 @@ def plot(self, t): import matplotlib.pyplot as plt t /= self.time_scale - self.fig, self.ax = plt.subplots(self.n_rows, self.n_cols, figsize=(15, 8)) + self.fig, self.ax = plt.subplots(self.n_rows, self.n_cols, figsize=self.figsize) plt.tight_layout() plt.subplots_adjust(left=-0.1) self.plots = {} self.time_lines = {} - colors = self.colors or ["r", "b", "k", "g", "m", "c"] - linestyles = self.linestyles or ["-", ":", "--", "-."] fontsize = 42 // self.n_cols for k, (key, variable_lists) in enumerate(self.variables.items()): @@ -303,30 +378,8 @@ def plot(self, t): ax.xaxis.set_major_locator(plt.MaxNLocator(3)) self.plots[key] = defaultdict(dict) # Set labels for the first subplot only (avoid repetition) - if variable_lists[0][0].dimensions == 2: - # 2D plot: plot as a function of x at time t - spatial_var_name, spatial_var_value = self.spatial_variable[key] - ax.set_xlabel(spatial_var_name + " [m]", fontsize=fontsize) - for i, variable_list in enumerate(variable_lists): - for j, variable in enumerate(variable_list): - if spatial_var_name == "r": - if "negative" in key[0].lower(): - spatial_scale = self.spatial_scales["r_n"] - elif "positive" in key[0].lower(): - spatial_scale = self.spatial_scales["r_p"] - else: - spatial_scale = self.spatial_scales[spatial_var_name] - (self.plots[key][i][j],) = ax.plot( - spatial_var_value * spatial_scale, - variable( - t, **{spatial_var_name: spatial_var_value}, warn=False - ), - lw=2, - color=colors[i], - linestyle=linestyles[j], - ) - else: - # 1D plot: plot as a function of time, indicating time t with a line + if variable_lists[0][0].dimensions == 0: + # 0D plot: plot as a function of time, indicating time t with a line ax.set_xlabel("Time [h]", fontsize=fontsize) for i, variable_list in enumerate(variable_lists): for j, variable in enumerate(variable_list): @@ -335,13 +388,61 @@ def plot(self, t): full_t * self.time_scale, variable(full_t, warn=False), lw=2, - color=colors[i], - linestyle=linestyles[j], + color=self.colors[i], + linestyle=self.linestyles[j], ) y_min, y_max = self.axis[key][2:] (self.time_lines[key],) = ax.plot( [t * self.time_scale, t * self.time_scale], [y_min, y_max], "k--" ) + elif variable_lists[0][0].dimensions == 1: + # 1D plot: plot as a function of x at time t + spatial_var_name, spatial_var_value = self.spatial_variable[key] + ax.set_xlabel(spatial_var_name + " [m]", fontsize=fontsize) + for i, variable_list in enumerate(variable_lists): + for j, variable in enumerate(variable_list): + spatial_scale = get_spatial_scale( + key, spatial_var_name, self.spatial_scales + ) + (self.plots[key][i][j],) = ax.plot( + spatial_var_value * spatial_scale, + variable( + t, **{spatial_var_name: spatial_var_value}, warn=False + ), + lw=2, + color=self.colors[i], + linestyle=self.linestyles[j], + ) + elif variable_lists[0][0].dimensions == 2: + # 2D plot: plot as a function of x and y at time t + spatial_vars = self.spatial_variable[key] + # first spatial variable + first_spatial_var_name, first_spatial_var_value = spatial_vars[0] + ax.set_xlabel(first_spatial_var_name + " [m]", fontsize=fontsize) + first_spatial_scale = get_spatial_scale( + key, first_spatial_var_name, self.spatial_scales + ) + # second spatial variable + second_spatial_var_name, second_spatial_var_value = spatial_vars[1] + ax.set_ylabel(second_spatial_var_name + " [m]", fontsize=fontsize) + second_spatial_scale = get_spatial_scale( + key, second_spatial_var_name, self.spatial_scales + ) + # there can only be one entry in the variable list + variable = variable_lists[0][0] + self.plots[key][0][0] = ax.contourf( + second_spatial_var_value * second_spatial_scale, + first_spatial_var_value * first_spatial_scale, + variable( + t, + **{ + first_spatial_var_name: first_spatial_var_value, + second_spatial_var_name: second_spatial_var_value, + }, + warn=False + ), + ) + # Set either y label or legend entries if len(key) == 1: title = split_long_string(key[0]) @@ -353,8 +454,9 @@ def plot(self, t): fontsize=8, loc="lower center", ) - if k == len(self.variables) - 1: - ax.legend(self.labels, loc="upper right", bbox_to_anchor=(1, -0.2)) + + # Set global legend + # self.fig.legend(self.labels, loc="lower right") def dynamic_plot(self, testing=False): """ @@ -388,7 +490,9 @@ def update(self, val): t = self.sfreq.val t_dimensionless = t / self.time_scale for key, plot in self.plots.items(): - if self.variables[key][0][0].dimensions == 2: + if self.variables[key][0][0].dimensions == 0: + self.time_lines[key].set_xdata([t]) + if self.variables[key][0][0].dimensions == 1: spatial_var_name, spatial_var_value = self.spatial_variable[key] for i, variable_lists in enumerate(self.variables[key]): for j, variable in enumerate(variable_lists): @@ -399,7 +503,5 @@ def update(self, val): warn=False ) ) - else: - self.time_lines[key].set_xdata([t]) self.fig.canvas.draw_idle() diff --git a/tests/integration/test_models/standard_output_comparison.py b/tests/integration/test_models/standard_output_comparison.py index b36ac1a7bc..e96a060669 100644 --- a/tests/integration/test_models/standard_output_comparison.py +++ b/tests/integration/test_models/standard_output_comparison.py @@ -68,9 +68,9 @@ def compare(self, var, tol=1e-2): var0 = model_variables[0] spatial_pts = {} - if var0.dimensions >= 2: + if var0.dimensions >= 1: spatial_pts[var0.first_dimension] = var0.first_dim_pts - if var0.dimensions >= 3: + if var0.dimensions >= 2: spatial_pts[var0.second_dimension] = var0.second_dim_pts # Calculate tolerance based on the value of var0 diff --git a/tests/integration/test_quick_plot.py b/tests/integration/test_quick_plot.py index 9be762f7b0..722c4cc8bc 100644 --- a/tests/integration/test_quick_plot.py +++ b/tests/integration/test_quick_plot.py @@ -29,12 +29,16 @@ def test_plot_lithium_ion(self): # update the axis new_axis = [0, 0.5, 0, 1] - quick_plot.axis.update({("Electrolyte concentration",): new_axis}) - self.assertEqual(quick_plot.axis[("Electrolyte concentration",)], new_axis) + quick_plot.axis.update({("Electrolyte concentration [mol.m-3]",): new_axis}) + self.assertEqual( + quick_plot.axis[("Electrolyte concentration [mol.m-3]",)], new_axis + ) # and now reset them quick_plot.reset_axis() - self.assertNotEqual(quick_plot.axis[("Electrolyte concentration",)], new_axis) + self.assertNotEqual( + quick_plot.axis[("Electrolyte concentration [mol.m-3]",)], new_axis + ) # check dynamic plot loads quick_plot.dynamic_plot(testing=True) @@ -43,9 +47,9 @@ def test_plot_lithium_ion(self): # Test with different output variables output_vars = [ - "Negative particle surface concentration", - "Electrolyte concentration", - "Positive particle surface concentration", + "Negative particle surface concentration [mol.m-3]", + "Electrolyte concentration [mol.m-3]", + "Positive particle surface concentration [mol.m-3]", ] quick_plot = pybamm.QuickPlot(solution_spm, output_vars) self.assertEqual(len(quick_plot.axis), 3) @@ -53,12 +57,16 @@ def test_plot_lithium_ion(self): # update the axis new_axis = [0, 0.5, 0, 1] - quick_plot.axis.update({("Electrolyte concentration",): new_axis}) - self.assertEqual(quick_plot.axis[("Electrolyte concentration",)], new_axis) + quick_plot.axis.update({("Electrolyte concentration [mol.m-3]",): new_axis}) + self.assertEqual( + quick_plot.axis[("Electrolyte concentration [mol.m-3]",)], new_axis + ) # and now reset them quick_plot.reset_axis() - self.assertNotEqual(quick_plot.axis[("Electrolyte concentration",)], new_axis) + self.assertNotEqual( + quick_plot.axis[("Electrolyte concentration [mol.m-3]",)], new_axis + ) # check dynamic plot loads quick_plot.dynamic_plot(testing=True) diff --git a/tests/unit/test_quick_plot.py b/tests/unit/test_quick_plot.py index 2abe645568..a94570f4b6 100644 --- a/tests/unit/test_quick_plot.py +++ b/tests/unit/test_quick_plot.py @@ -108,16 +108,56 @@ def test_simple_ode_model(self): quick_plot = pybamm.QuickPlot(solution) quick_plot.plot(0) + # Test different inputs + quick_plot = pybamm.QuickPlot( + [solution, solution], + colors=["r", "g", "b"], + linestyles=["-", "--"], + figsize=(1, 2), + labels=["sol 1", "sol 2"], + ) + self.assertEqual(quick_plot.colors, ["r", "g", "b"]) + self.assertEqual(quick_plot.linestyles, ["-", "--"]) + self.assertEqual(quick_plot.figsize, (1, 2)) + self.assertEqual(quick_plot.labels, ["sol 1", "sol 2"]) + + # Test different time formats + quick_plot = pybamm.QuickPlot(solution) + self.assertEqual(quick_plot.time_scale, 1) + quick_plot = pybamm.QuickPlot(solution, time_format="seconds") + self.assertEqual(quick_plot.time_scale, 1) + quick_plot = pybamm.QuickPlot(solution, time_format="minutes") + self.assertEqual(quick_plot.time_scale, 1 / 60) + quick_plot = pybamm.QuickPlot(solution, time_format="hours") + self.assertEqual(quick_plot.time_scale, 1 / 3600) + with self.assertRaisesRegex(ValueError, "time format"): + pybamm.QuickPlot(solution, time_format="bad format") + # long solution defaults to hours instead of seconds + solution_long = solver.solve(model, np.linspace(0, 1e5)) + quick_plot = pybamm.QuickPlot(solution_long) + self.assertEqual(quick_plot.time_scale, 1 / 3600) + # Test errors - with self.assertRaisesRegex(ValueError, "mismatching variable domains"): + with self.assertRaisesRegex(ValueError, "Mismatching variable domains"): pybamm.QuickPlot(solution, [["a", "b broadcasted"]]) model.variables["3D variable"] = disc.process_symbol( pybamm.FullBroadcast( 1, "negative particle", {"secondary": "negative electrode"} ) ) - with self.assertRaisesRegex(NotImplementedError, "cannot plot 3D variables"): - pybamm.QuickPlot(solution, ["3D variable"]) + with self.assertRaisesRegex(NotImplementedError, "Cannot plot 3D variables"): + pybamm.QuickPlot([solution, solution], ["3D variable"]) + with self.assertRaisesRegex(ValueError, "labels"): + quick_plot = pybamm.QuickPlot( + [solution, solution], labels=["sol 1", "sol 2", "sol 3"] + ) + + # No variable can be NaN + model.variables["NaN variable"] = disc.process_symbol(pybamm.Scalar(np.nan)) + with self.assertRaisesRegex( + ValueError, "All-NaN variable 'NaN variable' provided" + ): + pybamm.QuickPlot(solution, ["NaN variable"]) def test_loqs_spm_base(self): t_eval = np.linspace(0, 10, 2)