diff --git a/.flake8 b/.flake8 index 5915461f29..bad0d2ef96 100644 --- a/.flake8 +++ b/.flake8 @@ -12,6 +12,7 @@ exclude= share, pyvenv.cfg, third-party, + sundials-5.0.0, ignore= # False positive for white space before ':' on list slice # black should format these correctly diff --git a/pybamm/quick_plot.py b/pybamm/quick_plot.py index a801120baf..901e347e80 100644 --- a/pybamm/quick_plot.py +++ b/pybamm/quick_plot.py @@ -242,9 +242,7 @@ def __init__( else: variable = list(variable_tuple) try: - self.variable_limits[variable_tuple] = variable_limits[ - variable - ] + self.variable_limits[variable_tuple] = variable_limits[variable] except KeyError: # if variable_tuple is not provided, default to "fixed" self.variable_limits[variable_tuple] = "fixed" @@ -402,7 +400,6 @@ def reset_axis(self): variables in each subplot """ self.axis_limits = {} - self.var_limits = {} for key, variable_lists in self.variables.items(): if variable_lists[0][0].dimensions == 0: x_min = self.min_t @@ -457,7 +454,7 @@ def reset_axis(self): if variable_lists[0][0].dimensions in [0, 1]: self.axis_limits[key] = [x_min, x_max, var_min, var_max] else: - self.var_limits[key] = (var_min, var_max) + self.variable_limits[key] = (var_min, var_max) def plot(self, t): """Produces a quick plot with the internal states at time t. @@ -587,14 +584,19 @@ def plot(self, t): ax.set_ylabel( "{} [{}]".format(y_name, self.spatial_unit), fontsize=fontsize ) - vmin, vmax = self.var_limits[key] - contour = ax.contourf(x, y, var, levels=100, vmin=vmin, vmax=vmax) + vmin, vmax = self.variable_limits[key] + ax.contourf( + x, y, var, levels=100, vmin=vmin, vmax=vmax, cmap="coolwarm" + ) if vmin is None and vmax is None: - self.colorbars[key] = self.fig.colorbar(contour, ax=ax) - else: - self.fig.colorbar( - cm.ScalarMappable(colors.Normalize(vmin=vmin, vmax=vmax)), ax=ax - ) + vmin = ax_min(var) + vmax = ax_max(var) + self.colorbars[key] = self.fig.colorbar( + cm.ScalarMappable( + colors.Normalize(vmin=vmin, vmax=vmax), cmap="coolwarm" + ), + ax=ax, + ) # Set either y label or legend entries if len(key) == 1: title = split_long_string(key[0]) @@ -661,6 +663,8 @@ def slider_update(self, t): """ Update the plot in self.plot() with values at new time """ + from matplotlib import cm, colors + t_dimensionless = t / self.time_scale for k, (key, plot) in enumerate(self.plots.items()): ax = self.axes[k] @@ -699,7 +703,7 @@ def slider_update(self, t): spatial_vars = self.spatial_variable_dict[key] # there can only be one entry in the variable list variable = self.variables[key][0][0] - vmin, vmax = self.var_limits[key] + vmin, vmax = self.variable_limits[key] if self.is_x_r[key] is True: x = self.second_dimensional_spatial_variable[key] y = self.first_dimensional_spatial_variable[key] @@ -708,10 +712,17 @@ def slider_update(self, t): x = self.first_dimensional_spatial_variable[key] y = self.second_dimensional_spatial_variable[key] var = variable(t_dimensionless, **spatial_vars, warn=False).T - ax.contourf(x, y, var, levels=100, vmin=vmin, vmax=vmax) - if self.variable_limits[key] == "tight": + ax.contourf( + x, y, var, levels=100, vmin=vmin, vmax=vmax, cmap="coolwarm" + ) + if (vmin, vmax) == (None, None): + vmin = ax_min(var) + vmax = ax_max(var) cb = self.colorbars[key] - cb.set_clim(vmin=np.nanmin(var), vmax=np.nanmax(var)) - cb.draw_all() + cb.update_bruteforce( + cm.ScalarMappable( + colors.Normalize(vmin=vmin, vmax=vmax), cmap="coolwarm" + ) + ) self.fig.canvas.draw_idle()