From b1377bcb9a6e6f7f6e149c315175d2a7de447db9 Mon Sep 17 00:00:00 2001 From: Valentin Sulzer Date: Tue, 10 Mar 2020 22:11:58 -0400 Subject: [PATCH] #871 more work on 3d --- examples/scripts/SPMe.py | 4 +- pybamm/quick_plot.py | 277 ++++++++++++++++++---------------- tests/unit/test_quick_plot.py | 94 +++++++----- 3 files changed, 211 insertions(+), 164 deletions(-) diff --git a/examples/scripts/SPMe.py b/examples/scripts/SPMe.py index 7f1615053c..7152f57f9a 100644 --- a/examples/scripts/SPMe.py +++ b/examples/scripts/SPMe.py @@ -31,5 +31,7 @@ solution = model.default_solver.solve(model, t_eval) # plot -plot = pybamm.QuickPlot(solution, ["Negative particle concentration"]) +plot = pybamm.QuickPlot( + solution, ["Negative particle concentration"], spatial_format="um" +) plot.dynamic_plot() diff --git a/pybamm/quick_plot.py b/pybamm/quick_plot.py index 1b803bfb66..b61ef9de89 100644 --- a/pybamm/quick_plot.py +++ b/pybamm/quick_plot.py @@ -37,19 +37,6 @@ def split_long_string(title, max_words=4): return first_line + "\n" + second_line -def get_spatial_scale(key, spatial_var_name, spatial_scales): - "Return the appropriate spatial scale" - if spatial_var_name == "r": - if "negative" in key[0].lower(): - spatial_scale = spatial_scales["r_n"] - elif "positive" in key[0].lower(): - spatial_scale = spatial_scales["r_p"] - else: - spatial_scale = spatial_scales[spatial_var_name] - - return spatial_scale - - class QuickPlot(object): """ Generates a quick plot of a subset of key outputs of the model so that the model @@ -75,6 +62,13 @@ class QuickPlot(object): ["r", "b", "k", "g", "m", "c"] linestyles : list of str, optional The linestyles to loop over when plotting. Defaults to ["-", ":", "--", "-."] + figsize : tuple of floats + The size of the figure to make + time_format : str + Format for the time output ("hours", "minutes" or "seconds") + spatial_format : str + Format for the spatial axes ("m", "mm" or "um") + """ def __init__( @@ -86,6 +80,7 @@ def __init__( linestyles=None, figsize=None, time_format=None, + spatial_format="m", ): if isinstance(solutions, pybamm.Solution): solutions = [solutions] @@ -112,28 +107,41 @@ def __init__( self.figsize = figsize or (15, 8) # Spatial scales (default to 1 if information not in model) + if spatial_format == "m": + spatial_factor = 1 + elif spatial_format == "mm": + spatial_factor = 1e3 + elif spatial_format == "um": # micrometers + spatial_factor = 1e6 + else: + raise ValueError( + "spatial format '{}' not recognized".format(spatial_format) + ) + + self.spatial_format = spatial_format + variables = models[0].variables self.spatial_scales = {"x": 1, "y": 1, "z": 1, "r_n": 1, "r_p": 1} if "x [m]" and "x" in variables: self.spatial_scales["x"] = (variables["x [m]"] / variables["x"]).evaluate()[ -1 - ] + ] * spatial_factor if "y [m]" and "y" in variables: self.spatial_scales["y"] = (variables["y [m]"] / variables["y"]).evaluate()[ -1 - ] + ] * spatial_factor if "z [m]" and "z" in variables: self.spatial_scales["z"] = (variables["z [m]"] / variables["z"]).evaluate()[ -1 - ] + ] * spatial_factor if "r_n [m]" and "r_n" in variables: self.spatial_scales["r_n"] = ( variables["r_n [m]"] / variables["r_n"] - ).evaluate()[-1] + ).evaluate()[-1] * spatial_factor if "r_p [m]" and "r_p" in variables: self.spatial_scales["r_p"] = ( variables["r_p [m]"] / variables["r_p"] - ).evaluate()[-1] + ).evaluate()[-1] * spatial_factor # Time parameters model_timescale_in_seconds = models[0].timescale_eval @@ -189,7 +197,9 @@ def __init__( def set_output_variables(self, output_variables, solutions): # Set up output variables self.variables = {} - self.spatial_variable = {} + self.spatial_variable_dict = {} + self.first_dimensional_spatial_variable = {} + self.second_dimensional_spatial_variable = {} # Calculate subplot positions based on number of variables supplied self.subplot_positions = {} @@ -214,10 +224,8 @@ def set_output_variables(self, output_variables, solutions): # process each variable in variable_list for each model for i, solution in enumerate(solutions): - # variables lists of lists + # variables lists of lists, so variables[i] is a list variables[i] = [] - # first index is the solution number - # second index is the variable number for var in variable_list: sol = solution[var] # Check variable isn't all-nan @@ -232,7 +240,7 @@ def set_output_variables(self, output_variables, solutions): first_solution = variables[0] first_variable = first_solution[0] domain = first_variable.domain - # check all other variables against the first variable + # check all other solutions against the first solution for idx, variable in enumerate(first_solution): if variable.domain != domain: raise ValueError( @@ -244,35 +252,80 @@ def set_output_variables(self, output_variables, solutions): # Set the x variable (i.e. "x" or "r" for any one-dimensional variables) if first_variable.dimensions == 1: - spatial_variable_key = first_variable.first_dimension - spatial_variable_value = first_variable.first_dim_pts - self.spatial_variable[key] = ( - spatial_variable_key, - spatial_variable_value, - ) + ( + spatial_var_name, + spatial_var_value, + spatial_var_value_dimensional, + ) = self.get_spatial_var(key, first_variable, "first") + self.spatial_variable_dict[key] = {spatial_var_name: spatial_var_value} + self.first_dimensional_spatial_variable[ + key + ] = spatial_var_value_dimensional - # Don't allow 2D variables if there are multiple solutions elif first_variable.dimensions == 2: + # Don't allow 2D variables if there are multiple solutions if len(variables) > 1: raise NotImplementedError( - "Cannot plot 3D variables when comparing multiple solutions, " - "but {} is 3D".format() + "Cannot plot 2D variables when comparing multiple solutions, " + "but {} is 2D".format(key[0]) ) + # But do allow if just a single solution else: # Add both spatial variables to the keys - first_spatial_variable_key = first_variable.first_dimension - first_spatial_variable_value = first_variable.first_dim_pts - second_spatial_variable_key = first_variable.second_dimension - second_spatial_variable_value = first_variable.second_dim_pts - self.spatial_variable[key] = ( - (first_spatial_variable_key, first_spatial_variable_value), - (second_spatial_variable_key, second_spatial_variable_value), - ) + ( + first_spatial_var_name, + first_spatial_var_value, + first_spatial_var_value_dimensional, + ) = self.get_spatial_var(key, first_variable, "first") + ( + second_spatial_var_name, + second_spatial_var_value, + second_spatial_var_value_dimensional, + ) = self.get_spatial_var(key, first_variable, "second") + self.spatial_variable_dict[key] = { + first_spatial_var_name: first_spatial_var_value, + second_spatial_var_name: second_spatial_var_value, + } + self.first_dimensional_spatial_variable[ + key + ] = first_spatial_var_value_dimensional + self.second_dimensional_spatial_variable[ + key + ] = second_spatial_var_value_dimensional # Store variables and subplot position self.variables[key] = variables self.subplot_positions[key] = (self.n_rows, self.n_cols, k + 1) + def get_spatial_var(self, key, variable, dimension): + "Return the appropriate spatial variable(s)" + + # Extract name and dimensionless value + if dimension == "first": + spatial_var_name = variable.first_dimension + spatial_var_value = variable.first_dim_pts + elif dimension == "second": + spatial_var_name = variable.second_dimension + spatial_var_value = variable.second_dim_pts + + # Get scale + if spatial_var_name == "r": + if "negative" in key[0].lower(): + spatial_scale = self.spatial_scales["r_n"] + elif "positive" in key[0].lower(): + spatial_scale = self.spatial_scales["r_p"] + else: + raise NotImplementedError( + "Cannot determine the spatial scale for '{}'".format(key[0]) + ) + else: + spatial_scale = self.spatial_scales[spatial_var_name] + + # Get dimensional variable + spatial_var_value_dim = spatial_var_value * spatial_scale + + return spatial_var_name, spatial_var_value, spatial_var_value_dim + def reset_axis(self): """ Reset the axis limits to the default values. @@ -282,70 +335,46 @@ def reset_axis(self): self.axis = {} for key, variable_lists in self.variables.items(): if variable_lists[0][0].dimensions == 0: - spatial_var_name, spatial_var_value = "x", None x_min = self.min_t x_max = self.max_t + spatial_vars = {} elif variable_lists[0][0].dimensions == 1: - spatial_var_name, spatial_var_value = self.spatial_variable[key] - spatial_scale = get_spatial_scale( - key, spatial_var_name, self.spatial_scales - ) - spatial_var_scaled = spatial_var_value * spatial_scale - x_min = spatial_var_scaled[0] - x_max = spatial_var_scaled[-1] + x_min = self.first_dimensional_spatial_variable[key][0] + x_max = self.first_dimensional_spatial_variable[key][-1] + # Read dictionary of spatial variables + spatial_vars = self.spatial_variable_dict[key] elif variable_lists[0][0].dimensions == 2: - spatial_vars = self.spatial_variable[key] # First spatial variable - first_spatial_var_name, first_spatial_var_value = spatial_vars[0] - first_spatial_scale = get_spatial_scale( - key, first_spatial_var_name, self.spatial_scales - ) - first_spatial_var_scaled = first_spatial_var_value * first_spatial_scale - x_min = first_spatial_var_scaled[0] - x_max = first_spatial_var_scaled[-1] - # Second spatial variable - second_spatial_var_name, second_spatial_var_value = spatial_vars[1] - second_spatial_scale = get_spatial_scale( - key, second_spatial_var_name, self.spatial_scales - ) - second_spatial_var_scaled = ( - second_spatial_var_value * second_spatial_scale - ) - y_min = second_spatial_var_scaled[0] - y_max = second_spatial_var_scaled[-1] + x_min = self.first_dimensional_spatial_variable[key][0] + x_max = self.first_dimensional_spatial_variable[key][-1] + y_min = self.second_dimensional_spatial_variable[key][0] + y_max = self.second_dimensional_spatial_variable[key][-1] + + # Read dictionary of spatial variables + spatial_vars = self.spatial_variable_dict[key] + + # Create axis for contour plot self.axis[key] = [x_min, x_max, y_min, y_max] # Get min and max variable values + var_min = np.min( + [ + ax_min(var(self.ts[i], **spatial_vars, warn=False)) + for i, variable_list in enumerate(variable_lists) + for var in variable_list + ] + ) + var_max = np.max( + [ + ax_max(var(self.ts[i], **spatial_vars, warn=False)) + for i, variable_list in enumerate(variable_lists) + for var in variable_list + ] + ) + if var_min == var_max: + var_min -= 1 + var_max += 1 if variable_lists[0][0].dimensions in [0, 1]: - var_min = np.min( - [ - ax_min( - var( - self.ts[i], - **{spatial_var_name: spatial_var_value}, - warn=False - ) - ) - for i, variable_list in enumerate(variable_lists) - for var in variable_list - ] - ) - var_max = np.max( - [ - ax_max( - var( - self.ts[i], - **{spatial_var_name: spatial_var_value}, - warn=False - ) - ) - for i, variable_list in enumerate(variable_lists) - for var in variable_list - ] - ) - if var_min == var_max: - var_min -= 1 - var_max += 1 self.axis[key] = [x_min, x_max, var_min, var_max] def plot(self, t): @@ -366,7 +395,10 @@ def plot(self, t): self.plots = {} self.time_lines = {} - fontsize = 42 // self.n_cols + if self.n_cols == 1: + fontsize = 30 + else: + fontsize = 42 // self.n_cols for k, (key, variable_lists) in enumerate(self.variables.items()): if len(self.variables) == 1: @@ -397,50 +429,40 @@ def plot(self, t): ) elif variable_lists[0][0].dimensions == 1: # 1D plot: plot as a function of x at time t - spatial_var_name, spatial_var_value = self.spatial_variable[key] - ax.set_xlabel(spatial_var_name + " [m]", fontsize=fontsize) + # Read dictionary of spatial variables + spatial_vars = self.spatial_variable_dict[key] + spatial_var_name = list(spatial_vars.keys())[0] + ax.set_xlabel( + "{} [{}]".format(spatial_var_name, self.spatial_format), + fontsize=fontsize, + ) for i, variable_list in enumerate(variable_lists): for j, variable in enumerate(variable_list): - spatial_scale = get_spatial_scale( - key, spatial_var_name, self.spatial_scales - ) (self.plots[key][i][j],) = ax.plot( - spatial_var_value * spatial_scale, - variable( - t, **{spatial_var_name: spatial_var_value}, warn=False - ), + self.first_dimensional_spatial_variable[key], + variable(t, **spatial_vars, warn=False), lw=2, color=self.colors[i], linestyle=self.linestyles[j], ) elif variable_lists[0][0].dimensions == 2: # 2D plot: plot as a function of x and y at time t - spatial_vars = self.spatial_variable[key] - # first spatial variable - first_spatial_var_name, first_spatial_var_value = spatial_vars[0] - ax.set_xlabel(first_spatial_var_name + " [m]", fontsize=fontsize) - first_spatial_scale = get_spatial_scale( - key, first_spatial_var_name, self.spatial_scales + # Read dictionary of spatial variables + spatial_vars = self.spatial_variable_dict[key] + x_name = list(spatial_vars.keys())[0][0] + y_name = list(spatial_vars.keys())[1][0] + ax.set_xlabel( + "{} [{}]".format(x_name, self.spatial_format), fontsize=fontsize ) - # second spatial variable - second_spatial_var_name, second_spatial_var_value = spatial_vars[1] - ax.set_ylabel(second_spatial_var_name + " [m]", fontsize=fontsize) - second_spatial_scale = get_spatial_scale( - key, second_spatial_var_name, self.spatial_scales + ax.set_ylabel( + "{} [{}]".format(y_name, self.spatial_format), fontsize=fontsize ) # there can only be one entry in the variable list variable = variable_lists[0][0] self.plots[key][0][0] = ax.contourf( - second_spatial_var_value * second_spatial_scale, - first_spatial_var_value * first_spatial_scale, - variable( - t, - **{ - first_spatial_var_name: first_spatial_var_value, - second_spatial_var_name: second_spatial_var_value, - }, - warn=False - ), + self.second_dimensional_spatial_variable[key], + self.first_dimensional_spatial_variable[key], + variable(t, **spatial_vars, warn=False), ) # Set either y label or legend entries @@ -493,13 +515,12 @@ def update(self, val): if self.variables[key][0][0].dimensions == 0: self.time_lines[key].set_xdata([t]) if self.variables[key][0][0].dimensions == 1: - spatial_var_name, spatial_var_value = self.spatial_variable[key] for i, variable_lists in enumerate(self.variables[key]): for j, variable in enumerate(variable_lists): plot[i][j].set_ydata( variable( t_dimensionless, - **{spatial_var_name: spatial_var_value}, + **self.spatial_variable_dict[key], warn=False ) ) diff --git a/tests/unit/test_quick_plot.py b/tests/unit/test_quick_plot.py index a94570f4b6..a21ee4f982 100644 --- a/tests/unit/test_quick_plot.py +++ b/tests/unit/test_quick_plot.py @@ -137,16 +137,38 @@ def test_simple_ode_model(self): quick_plot = pybamm.QuickPlot(solution_long) self.assertEqual(quick_plot.time_scale, 1 / 3600) - # Test errors - with self.assertRaisesRegex(ValueError, "Mismatching variable domains"): - pybamm.QuickPlot(solution, [["a", "b broadcasted"]]) - model.variables["3D variable"] = disc.process_symbol( + # Test different spatial formats + quick_plot = pybamm.QuickPlot(solution) + self.assertEqual(quick_plot.spatial_format, "m") + quick_plot = pybamm.QuickPlot(solution, spatial_format="m") + self.assertEqual(quick_plot.spatial_format, "m") + quick_plot = pybamm.QuickPlot(solution, spatial_format="mm") + self.assertEqual(quick_plot.spatial_format, "mm") + quick_plot = pybamm.QuickPlot(solution, spatial_format="um") + self.assertEqual(quick_plot.spatial_format, "um") + with self.assertRaisesRegex(ValueError, "spatial format"): + pybamm.QuickPlot(solution, spatial_format="bad format") + + # Test 2D variables + model.variables["2D variable"] = disc.process_symbol( pybamm.FullBroadcast( 1, "negative particle", {"secondary": "negative electrode"} ) ) - with self.assertRaisesRegex(NotImplementedError, "Cannot plot 3D variables"): - pybamm.QuickPlot([solution, solution], ["3D variable"]) + model.variables["Negative 2D variable"] = model.variables["2D variable"] + quick_plot = pybamm.QuickPlot(solution, ["Negative 2D variable"]) + quick_plot.plot(0) + + with self.assertRaisesRegex(NotImplementedError, "Cannot plot 2D variables"): + pybamm.QuickPlot([solution, solution], ["2D variable"]) + with self.assertRaisesRegex( + NotImplementedError, "Cannot determine the spatial scale" + ): + pybamm.QuickPlot(solution, ["2D variable"]) + + # Test errors + with self.assertRaisesRegex(ValueError, "Mismatching variable domains"): + pybamm.QuickPlot(solution, [["a", "b broadcasted"]]) with self.assertRaisesRegex(ValueError, "labels"): quick_plot = pybamm.QuickPlot( [solution, solution], labels=["sol 1", "sol 2", "sol 3"] @@ -159,35 +181,37 @@ def test_simple_ode_model(self): ): pybamm.QuickPlot(solution, ["NaN variable"]) - def test_loqs_spm_base(self): - t_eval = np.linspace(0, 10, 2) - - # SPM - for model in [pybamm.lithium_ion.SPM(), pybamm.lead_acid.LOQS()]: - geometry = model.default_geometry - param = model.default_parameter_values - param.process_model(model) - param.process_geometry(geometry) - mesh = pybamm.Mesh( - geometry, model.default_submesh_types, model.default_var_pts - ) - disc = pybamm.Discretisation(mesh, model.default_spatial_methods) - disc.process_model(model) - solver = model.default_solver - solution = solver.solve(model, t_eval) - pybamm.QuickPlot(solution) - - # test quick plot of particle for spm - if model.name == "Single Particle Model": - output_variables = [ - "X-averaged negative particle concentration [mol.m-3]", - "X-averaged positive particle concentration [mol.m-3]", - ] - pybamm.QuickPlot(solution, output_variables) - - def test_failure(self): - with self.assertRaisesRegex(TypeError, "'solutions' must be"): - pybamm.QuickPlot(1) + # def test_loqs_spm_base(self): + # t_eval = np.linspace(0, 10, 2) + + # # SPM + # for model in [pybamm.lithium_ion.SPM(), pybamm.lead_acid.LOQS()]: + # geometry = model.default_geometry + # param = model.default_parameter_values + # param.process_model(model) + # param.process_geometry(geometry) + # mesh = pybamm.Mesh( + # geometry, model.default_submesh_types, model.default_var_pts + # ) + # disc = pybamm.Discretisation(mesh, model.default_spatial_methods) + # disc.process_model(model) + # solver = model.default_solver + # solution = solver.solve(model, t_eval) + # pybamm.QuickPlot(solution) + + # # test quick plot of particle for spm + # if model.name == "Single Particle Model": + # output_variables = [ + # "X-averaged negative particle concentration [mol.m-3]", + # "X-averaged positive particle concentration [mol.m-3]", + # "Negative particle concentration [mol.m-3]", + # "Positive particle concentration [mol.m-3]", + # ] + # pybamm.QuickPlot(solution, output_variables) + + # def test_failure(self): + # with self.assertRaisesRegex(TypeError, "'solutions' must be"): + # pybamm.QuickPlot(1) if __name__ == "__main__":