Skip to content

Commit

Permalink
#871 fix variable limits for 1D plots
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Mar 20, 2020
1 parent 0853d1b commit 8532c55
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 57 deletions.
1 change: 0 additions & 1 deletion examples/scripts/SPMe.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,5 @@
],
time_unit="seconds",
spatial_unit="um",
# axis_limits="tight",
)
plot.dynamic_plot()
157 changes: 101 additions & 56 deletions pybamm/quick_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,9 @@ class QuickPlot(object):
Parameters
----------
models: (iter of) :class:`pybamm.BaseModel`
The model(s) to plot the outputs of.
meshes: (iter of) :class:`pybamm.Mesh`
The mesh(es) on which the model(s) were solved.
solutions: (iter of) :class:`pybamm.Solver`
The numerical solution(s) for the model(s) which contained the solution to the
model(s).
solutions: (iter of) :class:`pybamm.Solution` or :class:`pybamm.Simulation`
The numerical solution(s) for the model(s), or the simulation object(s)
containing the solution(s).
output_variables : list of str, optional
List of variables to plot
labels : list of str, optional
Expand All @@ -74,7 +70,7 @@ class QuickPlot(object):
Format for the time output ("hours", "minutes" or "seconds")
spatial_unit : str, optional
Format for the spatial axes ("m", "mm" or "um")
axis_limits : str or dict of str, optional
variable_limits : str or dict of str, optional
How to set the axis limits (for 0D or 1D variables) or colorbar limits (for 2D
variables). Options are:
Expand All @@ -94,7 +90,7 @@ def __init__(
figsize=None,
time_unit=None,
spatial_unit="um",
axis_limits="fixed",
variable_limits="fixed",
):
if isinstance(solutions, (pybamm.Solution, pybamm.Simulation)):
solutions = [solutions]
Expand Down Expand Up @@ -225,7 +221,7 @@ def __init__(
# output_variables is a list of strings or lists, e.g.
# ["var 1", ["variable 2", "var 3"]]
output_variable_tuples = []
self.axis_limits = {}
self.variable_limits = {}
for variable_list in output_variables:
# Make sure we always have a list of lists of variables, e.g.
# [["var 1"], ["variable 2", "var 3"]]
Expand All @@ -237,16 +233,25 @@ def __init__(
output_variable_tuples.append(variable_tuple)

# axis limits
if axis_limits in ["fixed", "tight"]:
self.axis_limits[variable_tuple] = axis_limits
if variable_limits in ["fixed", "tight"]:
self.variable_limits[variable_tuple] = variable_limits
else:
# If there is only one variable, extract it
if len(variable_tuple) == 1:
variable_tuple = variable_tuple[0]
variable = variable_tuple[0]
else:
variable = list(variable_tuple)
try:
self.axis_limits[variable_tuple] = axis_limits[variable_tuple]
self.variable_limits[variable_tuple] = variable_limits[
variable
]
except KeyError:
# if variable_tuple is not provided, default to "fixed"
self.variable_limits[variable_tuple] = "fixed"
except TypeError:
raise TypeError("axis_limits must be 'fixed', 'tight', or a dict")
raise TypeError(
"variable_limits must be 'fixed', 'tight', or a dict"
)

self.set_output_variables(output_variable_tuples, solutions)
self.reset_axis()
Expand Down Expand Up @@ -396,7 +401,7 @@ def reset_axis(self):
These are calculated to fit around the minimum and maximum values of all the
variables in each subplot
"""
self.axis = {}
self.axis_limits = {}
self.var_limits = {}
for key, variable_lists in self.variables.items():
if variable_lists[0][0].dimensions == 0:
Expand All @@ -419,29 +424,38 @@ def reset_axis(self):
y_max = self.second_dimensional_spatial_variable[key][-1]

# Create axis for contour plot
self.axis[key] = [x_min, x_max, y_min, y_max]
self.axis_limits[key] = [x_min, x_max, y_min, y_max]

# Get min and max variable values
spatial_vars = self.spatial_variable_dict[key]
var_min = np.min(
[
ax_min(var(self.ts[i], **spatial_vars, warn=False))
for i, variable_list in enumerate(variable_lists)
for var in variable_list
]
)
var_max = np.max(
[
ax_max(var(self.ts[i], **spatial_vars, warn=False))
for i, variable_list in enumerate(variable_lists)
for var in variable_list
]
)
if var_min == var_max:
var_min -= 1
var_max += 1
if self.variable_limits[key] == "fixed":
# fixed variable limits: calculate "globlal" min and max
spatial_vars = self.spatial_variable_dict[key]
var_min = np.min(
[
ax_min(var(self.ts[i], **spatial_vars, warn=False))
for i, variable_list in enumerate(variable_lists)
for var in variable_list
]
)
var_max = np.max(
[
ax_max(var(self.ts[i], **spatial_vars, warn=False))
for i, variable_list in enumerate(variable_lists)
for var in variable_list
]
)
if var_min == var_max:
var_min -= 1
var_max += 1
elif self.variable_limits[key] == "tight":
# tight variable limits: axes will adjust each time
var_min, var_max = None, None
else:
# user-specified axis limits
var_min, var_max = self.variable_limits[key]

if variable_lists[0][0].dimensions in [0, 1]:
self.axis[key] = [x_min, x_max, var_min, var_max]
self.axis_limits[key] = [x_min, x_max, var_min, var_max]
else:
self.var_limits[key] = (var_min, var_max)

Expand All @@ -464,6 +478,7 @@ def plot(self, t):
self.gridspec = gridspec.GridSpec(self.n_rows, self.n_cols)
self.plots = {}
self.time_lines = {}
self.colorbars = {}
self.axes = []

# initialize empty handles, to be created only if the appropriate plots are made
Expand All @@ -477,8 +492,10 @@ def plot(self, t):
for k, (key, variable_lists) in enumerate(self.variables.items()):
ax = self.fig.add_subplot(self.gridspec[k])
self.axes.append(ax)
ax.set_xlim(self.axis[key][:2])
ax.set_ylim(self.axis[key][2:])
x_min, x_max, y_min, y_max = self.axis_limits[key]
ax.set_xlim(x_min, x_max)
if y_min is not None and y_max is not None:
ax.set_ylim(y_min, y_max)
ax.xaxis.set_major_locator(plt.MaxNLocator(3))
self.plots[key] = defaultdict(dict)
variable_handles = []
Expand All @@ -505,7 +522,8 @@ def plot(self, t):
)
variable_handles.append(self.plots[key][0][j])
solution_handles.append(self.plots[key][i][0])
y_min, y_max = self.axis[key][2:]
y_min, y_max = ax.get_ylim()
ax.set_ylim(y_min, y_max)
(self.time_lines[key],) = ax.plot(
[t * self.time_scale, t * self.time_scale], [y_min, y_max], "k--"
)
Expand All @@ -518,11 +536,6 @@ def plot(self, t):
"{} [{}]".format(spatial_var_name, self.spatial_unit),
fontsize=fontsize,
)
# add dashed lines for boundaries between subdomains
for bnd in variable_lists[0][0].internal_boundaries:
bnd_dim = bnd * self.first_spatial_scale[key]
y_min, y_max = self.axis[key][2:]
ax.plot([bnd_dim, bnd_dim], [y_min, y_max], color="0.5", lw=1)
for i, variable_list in enumerate(variable_lists):
for j, variable in enumerate(variable_list):
if len(variable_list) == 1:
Expand All @@ -538,9 +551,18 @@ def plot(self, t):
lw=2,
color=self.colors[i],
linestyle=linestyle,
zorder=10,
)
variable_handles.append(self.plots[key][0][j])
solution_handles.append(self.plots[key][i][0])
# add dashed lines for boundaries between subdomains
y_min, y_max = ax.get_ylim()
ax.set_ylim(y_min, y_max)
for bnd in variable_lists[0][0].internal_boundaries:
bnd_dim = bnd * self.first_spatial_scale[key]
ax.plot(
[bnd_dim, bnd_dim], [y_min, y_max], color="0.5", lw=1, zorder=0
)
elif variable_lists[0][0].dimensions == 2:
# Read dictionary of spatial variables
spatial_vars = self.spatial_variable_dict[key]
Expand All @@ -566,11 +588,13 @@ def plot(self, t):
"{} [{}]".format(y_name, self.spatial_unit), fontsize=fontsize
)
vmin, vmax = self.var_limits[key]
ax.contourf(x, y, var, levels=100, vmin=vmin, vmax=vmax)
self.fig.colorbar(
cm.ScalarMappable(colors.Normalize(vmin=vmin, vmax=vmax)), ax=ax
)

contour = ax.contourf(x, y, var, levels=100, vmin=vmin, vmax=vmax)
if vmin is None and vmax is None:
self.colorbars[key] = self.fig.colorbar(contour, ax=ax)
else:
self.fig.colorbar(
cm.ScalarMappable(colors.Normalize(vmin=vmin, vmax=vmax)), ax=ax
)
# Set either y label or legend entries
if len(key) == 1:
title = split_long_string(key[0])
Expand Down Expand Up @@ -639,20 +663,37 @@ def slider_update(self, t):
"""
t_dimensionless = t / self.time_scale
for k, (key, plot) in enumerate(self.plots.items()):
ax = self.axes[k]
if self.variables[key][0][0].dimensions == 0:
self.time_lines[key].set_xdata([t])
elif self.variables[key][0][0].dimensions == 1:
var_min = np.inf
var_max = -np.inf
for i, variable_lists in enumerate(self.variables[key]):
for j, variable in enumerate(variable_lists):
plot[i][j].set_ydata(
variable(
t_dimensionless,
**self.spatial_variable_dict[key],
warn=False
)
var = variable(
t_dimensionless,
**self.spatial_variable_dict[key],
warn=False
)
plot[i][j].set_ydata(var)
var_min = min(var_min, np.nanmin(var))
var_max = max(var_max, np.nanmax(var))
# add dashed lines for boundaries between subdomains
y_min, y_max = self.axis_limits[key][2:]
if y_min is None and y_max is None:
y_min, y_max = ax_min(var_min), ax_max(var_max)
ax.set_ylim(y_min, y_max)
for bnd in self.variables[key][0][0].internal_boundaries:
bnd_dim = bnd * self.first_spatial_scale[key]
ax.plot(
[bnd_dim, bnd_dim],
[y_min, y_max],
color="0.5",
lw=1,
zorder=0,
)
elif self.variables[key][0][0].dimensions == 2:
ax = self.axes[k]
# 2D plot: plot as a function of x and y at time t
# Read dictionary of spatial variables
spatial_vars = self.spatial_variable_dict[key]
Expand All @@ -668,5 +709,9 @@ def slider_update(self, t):
y = self.second_dimensional_spatial_variable[key]
var = variable(t_dimensionless, **spatial_vars, warn=False).T
ax.contourf(x, y, var, levels=100, vmin=vmin, vmax=vmax)
if self.variable_limits[key] == "tight":
cb = self.colorbars[key]
cb.set_clim(vmin=np.nanmin(var), vmax=np.nanmax(var))
cb.draw_all()

self.fig.canvas.draw_idle()

0 comments on commit 8532c55

Please sign in to comment.