diff --git a/CHANGELOG.md b/CHANGELOG.md index f2ba3f2b..863d4028 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## Features +- [#418](https://github.com/pybop-team/PyBOP/issues/418) - Wraps the `get_parameter_info` method from PyBaMM to get a dictionary of parameter names and types. - [#413](https://github.com/pybop-team/PyBOP/pull/413) - Adds `DesignCost` functionality to `WeightedCost` class with additional tests. - [#357](https://github.com/pybop-team/PyBOP/pull/357) - Adds `Transformation()` class with `LogTransformation()`, `IdentityTransformation()`, and `ScaledTransformation()`, `ComposedTransformation()` implementations with corresponding examples and tests. - [#427](https://github.com/pybop-team/PyBOP/issues/427) - Adds the nbstripout pre-commit hook to remove unnecessary metadata from notebooks. diff --git a/pybop/models/base_model.py b/pybop/models/base_model.py index a706923c..9e547b65 100644 --- a/pybop/models/base_model.py +++ b/pybop/models/base_model.py @@ -718,3 +718,23 @@ def solver(self): @solver.setter def solver(self, solver): self._solver = solver.copy() if solver is not None else None + + def get_parameter_info(self, print_info: bool = False): + """ + Extracts the parameter names and types and returns them as a dictionary. + """ + if not self.pybamm_model._built: + self.pybamm_model.build_model() + + info = self.pybamm_model.get_parameter_info() + + reduced_info = dict() + for param, param_type in info.values(): + param_name = getattr(param, "name", str(param)) + reduced_info[param_name] = param_type + + if print_info: + for param, param_type in info.values(): + print(param, " : ", param_type) + + return reduced_info diff --git a/tests/unit/test_models.py b/tests/unit/test_models.py index b12b3639..551379b7 100644 --- a/tests/unit/test_models.py +++ b/tests/unit/test_models.py @@ -1,3 +1,6 @@ +import sys +from io import StringIO + import numpy as np import pybamm import pytest @@ -368,3 +371,24 @@ def test_non_converged_solution(self): for key in problem.signal: assert np.allclose(output.get(key, [])[0], output.get(key, [])) assert np.allclose(output_S1.get(key, [])[0], output_S1.get(key, [])) + + @pytest.mark.unit + def test_get_parameter_info(self, model): + if isinstance(model, pybop.empirical.Thevenin): + # Test at least one model without a built pybamm model + model = pybop.empirical.Thevenin(build=False) + + parameter_info = model.get_parameter_info() + assert isinstance(parameter_info, dict) + + captured_output = StringIO() + sys.stdout = captured_output + + model.get_parameter_info(print_info=True) + sys.stdout = sys.__stdout__ + + printed_messaage = captured_output.getvalue().strip() + + for key, value in parameter_info.items(): + assert key in printed_messaage + assert value in printed_messaage