Skip to content

Commit

Permalink
Add get_parameter_info to BaseModel (#423)
Browse files Browse the repository at this point in the history
* Add get_parameter_info with test

* Add print_info option and test

* Test one without built pybamm model
  • Loading branch information
NicolaCourtier authored Aug 2, 2024
1 parent 5b163ae commit 8b09214
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Features

- [#418](https://github.com/pybop-team/PyBOP/issues/418) - Wraps the `get_parameter_info` method from PyBaMM to get a dictionary of parameter names and types.
- [#413](https://github.com/pybop-team/PyBOP/pull/413) - Adds `DesignCost` functionality to `WeightedCost` class with additional tests.
- [#357](https://github.com/pybop-team/PyBOP/pull/357) - Adds `Transformation()` class with `LogTransformation()`, `IdentityTransformation()`, and `ScaledTransformation()`, `ComposedTransformation()` implementations with corresponding examples and tests.
- [#427](https://github.com/pybop-team/PyBOP/issues/427) - Adds the nbstripout pre-commit hook to remove unnecessary metadata from notebooks.
Expand Down
20 changes: 20 additions & 0 deletions pybop/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,3 +718,23 @@ def solver(self):
@solver.setter
def solver(self, solver):
self._solver = solver.copy() if solver is not None else None

def get_parameter_info(self, print_info: bool = False):
"""
Extracts the parameter names and types and returns them as a dictionary.
"""
if not self.pybamm_model._built:
self.pybamm_model.build_model()

info = self.pybamm_model.get_parameter_info()

reduced_info = dict()
for param, param_type in info.values():
param_name = getattr(param, "name", str(param))
reduced_info[param_name] = param_type

if print_info:
for param, param_type in info.values():
print(param, " : ", param_type)

return reduced_info
24 changes: 24 additions & 0 deletions tests/unit/test_models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import sys
from io import StringIO

import numpy as np
import pybamm
import pytest
Expand Down Expand Up @@ -368,3 +371,24 @@ def test_non_converged_solution(self):
for key in problem.signal:
assert np.allclose(output.get(key, [])[0], output.get(key, []))
assert np.allclose(output_S1.get(key, [])[0], output_S1.get(key, []))

@pytest.mark.unit
def test_get_parameter_info(self, model):
if isinstance(model, pybop.empirical.Thevenin):
# Test at least one model without a built pybamm model
model = pybop.empirical.Thevenin(build=False)

parameter_info = model.get_parameter_info()
assert isinstance(parameter_info, dict)

captured_output = StringIO()
sys.stdout = captured_output

model.get_parameter_info(print_info=True)
sys.stdout = sys.__stdout__

printed_messaage = captured_output.getvalue().strip()

for key, value in parameter_info.items():
assert key in printed_messaage
assert value in printed_messaage

0 comments on commit 8b09214

Please sign in to comment.