diff --git a/examples/scripts/compare_lithium_ion_3D.py b/examples/scripts/compare_lithium_ion_3D.py index a28a5ebce3..1c2810003d 100644 --- a/examples/scripts/compare_lithium_ion_3D.py +++ b/examples/scripts/compare_lithium_ion_3D.py @@ -56,7 +56,6 @@ solutions[i] = solution # plot -# TO DO: plotting 3D variables output_variables = ["Terminal voltage [V]"] plot = pybamm.QuickPlot(solutions, output_variables) plot.dynamic_plot() diff --git a/pybamm/processed_variable.py b/pybamm/processed_variable.py index 88fbfb6cba..dc6a4aeff7 100644 --- a/pybamm/processed_variable.py +++ b/pybamm/processed_variable.py @@ -276,6 +276,8 @@ def initialise_2Dspace_scikit_fem(self): self.z_sol = z_sol self.first_dimension = "y" self.second_dimension = "z" + self.first_dim_pts = y_sol + self.second_dim_pts = z_sol # set up interpolation self._interpolation_function = interp.interp2d( @@ -313,6 +315,8 @@ def initialise_3D_scikit_fem(self): self.z_sol = z_sol self.first_dimension = "y" self.second_dimension = "z" + self.first_dim_pts = y_sol + self.second_dim_pts = z_sol # set up interpolation self._interpolation_function = interp.RegularGridInterpolator( diff --git a/pybamm/quick_plot.py b/pybamm/quick_plot.py index 779f502034..b0c6b6432d 100644 --- a/pybamm/quick_plot.py +++ b/pybamm/quick_plot.py @@ -138,25 +138,33 @@ def __init__( raise ValueError("spatial unit '{}' not recognized".format(spatial_unit)) variables = models[0].variables - self.spatial_scales = {"x": 1, "y": 1, "z": 1, "r_n": 1, "r_p": 1} + # empty spatial scales, will raise error later if can't find a particular one + self.spatial_scales = {} if "x [m]" and "x" in variables: - self.spatial_scales["x"] = (variables["x [m]"] / variables["x"]).evaluate()[ + x_scale = (variables["x [m]"] / variables["x"]).evaluate()[ -1 ] * spatial_factor + self.spatial_scales.update( + { + "negative electrode": x_scale, + "separator": x_scale, + "positive electrode": x_scale, + } + ) if "y [m]" and "y" in variables: - self.spatial_scales["y"] = (variables["y [m]"] / variables["y"]).evaluate()[ - -1 - ] * spatial_factor + self.spatial_scales["current collector y"] = ( + variables["y [m]"] / variables["y"] + ).evaluate()[-1] * spatial_factor if "z [m]" and "z" in variables: - self.spatial_scales["z"] = (variables["z [m]"] / variables["z"]).evaluate()[ - -1 - ] * spatial_factor + self.spatial_scales["current collector z"] = ( + variables["z [m]"] / variables["z"] + ).evaluate()[-1] * spatial_factor if "r_n [m]" and "r_n" in variables: - self.spatial_scales["r_n"] = ( + self.spatial_scales["negative particle"] = ( variables["r_n [m]"] / variables["r_n"] ).evaluate()[-1] * spatial_factor if "r_p [m]" and "r_p" in variables: - self.spatial_scales["r_p"] = ( + self.spatial_scales["positive particle"] = ( variables["r_p [m]"] / variables["r_p"] ).evaluate()[-1] * spatial_factor @@ -169,7 +177,7 @@ def __init__( # Set timescale if time_unit is None: # defaults depend on how long the simulation is - if self.max_t >= 3600: + if max_t >= 3600: time_scaling_factor = 3600 # time in hours self.time_unit = "h" else: @@ -212,9 +220,6 @@ def __init__( "Electrolyte potential [V]", "Terminal voltage [V]", ] - # else plot all variables in first model - else: - output_variables = models[0].variables self.set_output_variables(output_variables, solutions) self.reset_axis() @@ -285,7 +290,9 @@ def set_output_variables(self, output_variables, solutions): spatial_scale, ) = self.get_spatial_var(key, first_variable, "first") self.spatial_variable_dict[key] = {spatial_var_name: spatial_var_value} - self.first_dimensional_spatial_variable[key] = spatial_var_value * spatial_scale + self.first_dimensional_spatial_variable[key] = ( + spatial_var_value * spatial_scale + ) self.first_spatial_scale[key] = spatial_scale elif first_variable.dimensions == 2: @@ -312,8 +319,12 @@ def set_output_variables(self, output_variables, solutions): first_spatial_var_name: first_spatial_var_value, second_spatial_var_name: second_spatial_var_value, } - self.first_dimensional_spatial_variable[key] = first_spatial_var_value * first_spatial_scale - self.second_dimensional_spatial_variable[key] = second_spatial_var_value * second_spatial_scale + self.first_dimensional_spatial_variable[key] = ( + first_spatial_var_value * first_spatial_scale + ) + self.second_dimensional_spatial_variable[key] = ( + second_spatial_var_value * second_spatial_scale + ) # Store variables and subplot position self.variables[key] = variables @@ -323,25 +334,33 @@ def get_spatial_var(self, key, variable, dimension): "Return the appropriate spatial variable(s)" # Extract name and dimensionless value + # Special case for current collector, which is 2D but in a weird way (both + # first and second variables are in the same domain, not auxiliary domain) if dimension == "first": spatial_var_name = variable.first_dimension spatial_var_value = variable.first_dim_pts + domain = variable.domain[0] elif dimension == "second": spatial_var_name = variable.second_dimension spatial_var_value = variable.second_dim_pts + if variable.domain[0] == "current collector": + domain = "current collector" + else: + domain = variable.auxiliary_domains["secondary"][0] + + if domain == "current collector": + domain += " {}".format(spatial_var_name) # Get scale - if spatial_var_name == "r": - if "negative" in key[0].lower(): - spatial_scale = self.spatial_scales["r_n"] - elif "positive" in key[0].lower(): - spatial_scale = self.spatial_scales["r_p"] - else: - raise NotImplementedError( - "Cannot determine the spatial scale for '{}'".format(key[0]) - ) - else: - spatial_scale = self.spatial_scales[spatial_var_name] + try: + spatial_scale = self.spatial_scales[domain] + except KeyError: + raise KeyError( + ( + "Can't find spatial scale for '{}', make sure both '{} [m]' " + + "and '{}' are defined in the model variables" + ).format(domain, *[spatial_var_name] * 2) + ) return spatial_var_name, spatial_var_value, spatial_scale diff --git a/tests/integration/test_quick_plot.py b/tests/integration/test_quick_plot.py index 722c4cc8bc..bc82202b36 100644 --- a/tests/integration/test_quick_plot.py +++ b/tests/integration/test_quick_plot.py @@ -43,7 +43,7 @@ def test_plot_lithium_ion(self): # check dynamic plot loads quick_plot.dynamic_plot(testing=True) - quick_plot.update(0.01) + quick_plot.slider_update(0.01) # Test with different output variables output_vars = [ @@ -71,7 +71,7 @@ def test_plot_lithium_ion(self): # check dynamic plot loads quick_plot.dynamic_plot(testing=True) - quick_plot.update(0.01) + quick_plot.slider_update(0.01) def test_plot_lead_acid(self): loqs = pybamm.lead_acid.LOQS() @@ -87,6 +87,39 @@ def test_plot_lead_acid(self): pybamm.QuickPlot(solution_loqs) + def test_plot_2plus1D_spm(self): + spm = pybamm.lithium_ion.SPM( + {"current collector": "potential pair", "dimensionality": 2} + ) + geometry = spm.default_geometry + param = spm.default_parameter_values + param.process_model(spm) + param.process_geometry(geometry) + var = pybamm.standard_spatial_vars + var_pts = { + var.x_n: 5, + var.x_s: 5, + var.x_p: 5, + var.r_n: 5, + var.r_p: 5, + var.y: 5, + var.z: 5, + } + mesh = pybamm.Mesh(geometry, spm.default_submesh_types, var_pts) + disc_spm = pybamm.Discretisation(mesh, spm.default_spatial_methods) + disc_spm.process_model(spm) + t_eval = np.linspace(0, 3600, 100) + solution_spm = spm.default_solver.solve(spm, t_eval) + + pybamm.QuickPlot( + solution_spm, + [ + "Negative current collector potential [V]", + "Positive current collector potential [V]", + "Terminal voltage [V]", + ], + ) + if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_quick_plot.py b/tests/unit/test_quick_plot.py index aa139e1ae0..d6d5c7db3b 100644 --- a/tests/unit/test_quick_plot.py +++ b/tests/unit/test_quick_plot.py @@ -37,6 +37,12 @@ def test_simple_ode_model(self): "c broadcasted positive electrode": pybamm.PrimaryBroadcast( c, "positive particle" ), + "x [m]": pybamm.standard_spatial_vars.x, + "x": pybamm.standard_spatial_vars.x, + "r_n [m]": pybamm.standard_spatial_vars.r_n, + "r_n": pybamm.standard_spatial_vars.r_n, + "r_p [m]": pybamm.standard_spatial_vars.r_p, + "r_p": pybamm.standard_spatial_vars.r_p, } # ODEs only (don't use jacobian) @@ -53,7 +59,16 @@ def test_simple_ode_model(self): solver = model.default_solver t_eval = np.linspace(0, 2, 100) solution = solver.solve(model, t_eval) - quick_plot = pybamm.QuickPlot(solution) + quick_plot = pybamm.QuickPlot( + solution, + [ + "a", + "b broadcasted", + "c broadcasted", + "b broadcasted negative electrode", + "c broadcasted positive electrode", + ], + ) quick_plot.plot(0) # update the axis @@ -105,12 +120,13 @@ def test_simple_ode_model(self): # Test longer name model.variables["Variable with a very long name"] = model.variables["a"] - quick_plot = pybamm.QuickPlot(solution) + quick_plot = pybamm.QuickPlot(solution, ["Variable with a very long name"]) quick_plot.plot(0) # Test different inputs quick_plot = pybamm.QuickPlot( [solution, solution], + ["a"], colors=["r", "g", "b"], linestyles=["-", "--"], figsize=(1, 2), @@ -122,32 +138,32 @@ def test_simple_ode_model(self): self.assertEqual(quick_plot.labels, ["sol 1", "sol 2"]) # Test different time units - quick_plot = pybamm.QuickPlot(solution) + quick_plot = pybamm.QuickPlot(solution, ["a"]) self.assertEqual(quick_plot.time_scale, 1) - quick_plot = pybamm.QuickPlot(solution, time_unit="seconds") + quick_plot = pybamm.QuickPlot(solution, ["a"], time_unit="seconds") self.assertEqual(quick_plot.time_scale, 1) - quick_plot = pybamm.QuickPlot(solution, time_unit="minutes") + quick_plot = pybamm.QuickPlot(solution, ["a"], time_unit="minutes") self.assertEqual(quick_plot.time_scale, 1 / 60) - quick_plot = pybamm.QuickPlot(solution, time_unit="hours") + quick_plot = pybamm.QuickPlot(solution, ["a"], time_unit="hours") self.assertEqual(quick_plot.time_scale, 1 / 3600) with self.assertRaisesRegex(ValueError, "time unit"): - pybamm.QuickPlot(solution, time_unit="bad unit") + pybamm.QuickPlot(solution, ["a"], time_unit="bad unit") # long solution defaults to hours instead of seconds solution_long = solver.solve(model, np.linspace(0, 1e5)) - quick_plot = pybamm.QuickPlot(solution_long) + quick_plot = pybamm.QuickPlot(solution_long, ["a"]) self.assertEqual(quick_plot.time_scale, 1 / 3600) # Test different spatial units - quick_plot = pybamm.QuickPlot(solution) + quick_plot = pybamm.QuickPlot(solution, ["a"]) self.assertEqual(quick_plot.spatial_unit, "$\mu m$") - quick_plot = pybamm.QuickPlot(solution, spatial_unit="m") + quick_plot = pybamm.QuickPlot(solution, ["a"], spatial_unit="m") self.assertEqual(quick_plot.spatial_unit, "m") - quick_plot = pybamm.QuickPlot(solution, spatial_unit="mm") + quick_plot = pybamm.QuickPlot(solution, ["a"], spatial_unit="mm") self.assertEqual(quick_plot.spatial_unit, "mm") - quick_plot = pybamm.QuickPlot(solution, spatial_unit="um") + quick_plot = pybamm.QuickPlot(solution, ["a"], spatial_unit="um") self.assertEqual(quick_plot.spatial_unit, "$\mu m$") with self.assertRaisesRegex(ValueError, "spatial unit"): - pybamm.QuickPlot(solution, spatial_unit="bad unit") + pybamm.QuickPlot(solution, ["a"], spatial_unit="bad unit") # Test 2D variables model.variables["2D variable"] = disc.process_symbol( @@ -161,17 +177,13 @@ def test_simple_ode_model(self): with self.assertRaisesRegex(NotImplementedError, "Cannot plot 2D variables"): pybamm.QuickPlot([solution, solution], ["2D variable"]) - with self.assertRaisesRegex( - NotImplementedError, "Cannot determine the spatial scale" - ): - pybamm.QuickPlot(solution, ["2D variable"]) # Test errors with self.assertRaisesRegex(ValueError, "Mismatching variable domains"): pybamm.QuickPlot(solution, [["a", "b broadcasted"]]) with self.assertRaisesRegex(ValueError, "labels"): quick_plot = pybamm.QuickPlot( - [solution, solution], labels=["sol 1", "sol 2", "sol 3"] + [solution, solution], ["a"], labels=["sol 1", "sol 2", "sol 3"] ) # No variable can be NaN @@ -181,50 +193,50 @@ def test_simple_ode_model(self): ): pybamm.QuickPlot(solution, ["NaN variable"]) - # def test_spm_simulation(self): - # # SPM - # model = pybamm.lithium_ion.SPM() - # sim = pybamm.Simulation(model) - - # t_eval = np.linspace(0, 10, 2) - # sim.solve(t_eval) - - # # mixed simulation and solution input - # # solution should be extracted from the simulation - # quick_plot = pybamm.QuickPlot([sim, sim.solution]) - # quick_plot.plot(0) - # - # def test_loqs_spm_base(self): - # t_eval = np.linspace(0, 10, 2) - - # # SPM - # for model in [pybamm.lithium_ion.SPM(), pybamm.lead_acid.LOQS()]: - # geometry = model.default_geometry - # param = model.default_parameter_values - # param.process_model(model) - # param.process_geometry(geometry) - # mesh = pybamm.Mesh( - # geometry, model.default_submesh_types, model.default_var_pts - # ) - # disc = pybamm.Discretisation(mesh, model.default_spatial_methods) - # disc.process_model(model) - # solver = model.default_solver - # solution = solver.solve(model, t_eval) - # pybamm.QuickPlot(solution) - - # # test quick plot of particle for spm - # if model.name == "Single Particle Model": - # output_variables = [ - # "X-averaged negative particle concentration [mol.m-3]", - # "X-averaged positive particle concentration [mol.m-3]", - # "Negative particle concentration [mol.m-3]", - # "Positive particle concentration [mol.m-3]", - # ] - # pybamm.QuickPlot(solution, output_variables) - - # def test_failure(self): - # with self.assertRaisesRegex(TypeError, "'solutions' must be"): - # pybamm.QuickPlot(1) + def test_spm_simulation(self): + # SPM + model = pybamm.lithium_ion.SPM() + sim = pybamm.Simulation(model) + + t_eval = np.linspace(0, 10, 2) + sim.solve(t_eval) + + # mixed simulation and solution input + # solution should be extracted from the simulation + quick_plot = pybamm.QuickPlot([sim, sim.solution]) + quick_plot.plot(0) + + def test_loqs_spm_base(self): + t_eval = np.linspace(0, 10, 2) + + # SPM + for model in [pybamm.lithium_ion.SPM(), pybamm.lead_acid.LOQS()]: + geometry = model.default_geometry + param = model.default_parameter_values + param.process_model(model) + param.process_geometry(geometry) + mesh = pybamm.Mesh( + geometry, model.default_submesh_types, model.default_var_pts + ) + disc = pybamm.Discretisation(mesh, model.default_spatial_methods) + disc.process_model(model) + solver = model.default_solver + solution = solver.solve(model, t_eval) + pybamm.QuickPlot(solution) + + # test quick plot of particle for spm + if model.name == "Single Particle Model": + output_variables = [ + "X-averaged negative particle concentration [mol.m-3]", + "X-averaged positive particle concentration [mol.m-3]", + "Negative particle concentration [mol.m-3]", + "Positive particle concentration [mol.m-3]", + ] + pybamm.QuickPlot(solution, output_variables) + + def test_failure(self): + with self.assertRaisesRegex(TypeError, "solutions must be"): + pybamm.QuickPlot(1) if __name__ == "__main__":