Skip to content

Commit

Permalink
#871 fix 2d plots
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Mar 20, 2020
1 parent 8532c55 commit d95909b
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 17 deletions.
1 change: 1 addition & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ exclude=
share,
pyvenv.cfg,
third-party,
sundials-5.0.0,
ignore=
# False positive for white space before ':' on list slice
# black should format these correctly
Expand Down
45 changes: 28 additions & 17 deletions pybamm/quick_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,7 @@ def __init__(
else:
variable = list(variable_tuple)
try:
self.variable_limits[variable_tuple] = variable_limits[
variable
]
self.variable_limits[variable_tuple] = variable_limits[variable]
except KeyError:
# if variable_tuple is not provided, default to "fixed"
self.variable_limits[variable_tuple] = "fixed"
Expand Down Expand Up @@ -402,7 +400,6 @@ def reset_axis(self):
variables in each subplot
"""
self.axis_limits = {}
self.var_limits = {}
for key, variable_lists in self.variables.items():
if variable_lists[0][0].dimensions == 0:
x_min = self.min_t
Expand Down Expand Up @@ -457,7 +454,7 @@ def reset_axis(self):
if variable_lists[0][0].dimensions in [0, 1]:
self.axis_limits[key] = [x_min, x_max, var_min, var_max]
else:
self.var_limits[key] = (var_min, var_max)
self.variable_limits[key] = (var_min, var_max)

def plot(self, t):
"""Produces a quick plot with the internal states at time t.
Expand Down Expand Up @@ -587,14 +584,19 @@ def plot(self, t):
ax.set_ylabel(
"{} [{}]".format(y_name, self.spatial_unit), fontsize=fontsize
)
vmin, vmax = self.var_limits[key]
contour = ax.contourf(x, y, var, levels=100, vmin=vmin, vmax=vmax)
vmin, vmax = self.variable_limits[key]
ax.contourf(
x, y, var, levels=100, vmin=vmin, vmax=vmax, cmap="coolwarm"
)
if vmin is None and vmax is None:
self.colorbars[key] = self.fig.colorbar(contour, ax=ax)
else:
self.fig.colorbar(
cm.ScalarMappable(colors.Normalize(vmin=vmin, vmax=vmax)), ax=ax
)
vmin = ax_min(var)
vmax = ax_max(var)
self.colorbars[key] = self.fig.colorbar(
cm.ScalarMappable(
colors.Normalize(vmin=vmin, vmax=vmax), cmap="coolwarm"
),
ax=ax,
)
# Set either y label or legend entries
if len(key) == 1:
title = split_long_string(key[0])
Expand Down Expand Up @@ -661,6 +663,8 @@ def slider_update(self, t):
"""
Update the plot in self.plot() with values at new time
"""
from matplotlib import cm, colors

t_dimensionless = t / self.time_scale
for k, (key, plot) in enumerate(self.plots.items()):
ax = self.axes[k]
Expand Down Expand Up @@ -699,7 +703,7 @@ def slider_update(self, t):
spatial_vars = self.spatial_variable_dict[key]
# there can only be one entry in the variable list
variable = self.variables[key][0][0]
vmin, vmax = self.var_limits[key]
vmin, vmax = self.variable_limits[key]
if self.is_x_r[key] is True:
x = self.second_dimensional_spatial_variable[key]
y = self.first_dimensional_spatial_variable[key]
Expand All @@ -708,10 +712,17 @@ def slider_update(self, t):
x = self.first_dimensional_spatial_variable[key]
y = self.second_dimensional_spatial_variable[key]
var = variable(t_dimensionless, **spatial_vars, warn=False).T
ax.contourf(x, y, var, levels=100, vmin=vmin, vmax=vmax)
if self.variable_limits[key] == "tight":
ax.contourf(
x, y, var, levels=100, vmin=vmin, vmax=vmax, cmap="coolwarm"
)
if (vmin, vmax) == (None, None):
vmin = ax_min(var)
vmax = ax_max(var)
cb = self.colorbars[key]
cb.set_clim(vmin=np.nanmin(var), vmax=np.nanmax(var))
cb.draw_all()
cb.update_bruteforce(
cm.ScalarMappable(
colors.Normalize(vmin=vmin, vmax=vmax), cmap="coolwarm"
)
)

self.fig.canvas.draw_idle()

0 comments on commit d95909b

Please sign in to comment.