diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c9bd978c32..33fb9e9106 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -126,26 +126,18 @@ This allows people to (1) use PyBaMM without importing optional dependencies by **Writing Tests for Optional Dependencies** -Whenever a new optional dependency is added for optional functionality, it is recommended to write a corresponding unit test in `test_util.py`. This ensures that an error is raised upon the absence of said dependency. Here's an example: +Below, we list the currently available test functions to provide an overview. If you find it useful to add new test cases please do so within `tests/unit/test_util.py`. -```python -from tests import TestCase -import pybamm - - -class TestUtil(TestCase): - def test_optional_dependency(self): - # Test that an error is raised when pybtex is not available - with self.assertRaisesRegex( - ModuleNotFoundError, "Optional dependency pybtex is not available" - ): - sys.modules["pybtex"] = None - pybamm.function_using_pybtex(x, y, z) - - # Test that the function works when pybtex is available - sys.modules["pybtex"] = pybamm.util.import_optional_dependency("pybtex") - pybamm.function_using_pybtex(x, y, z) -``` +Currently, there are three functions to test what concerns optional dependencies: +- `test_import_optional_dependency` +- `test_pybamm_import` +- `test_optional_dependencies` + +The `test_import_optional_dependency` function extracts the optional dependencies installed in the setup environment, makes them unimportable (by setting them to `None` among the `sys.modules`), and tests that the `pybamm.util.import_optional_dependency` function throws a `ModuleNotFoundError` exception when their import is attempted. + +The `test_pybamm_import` function extracts the optional dependencies installed in the setup environment and makes them unimportable (by setting them to `None` among the `sys.modules`), unloads `pybamm` and its sub-modules, and finally tests that `pybamm` can be imported successfully. In fact, it is essential that the `pybamm` package is importable with only the mandatory dependencies. + +The `test_optional_dependencies` function extracts `pybamm` mandatory distribution packages and verifies that they are not present in the optional distribution packages list in `pyproject.toml`. This test is crucial for ensuring the consistency of the released package information and potential updates to dependencies during development. ## Testing diff --git a/pybamm/citations.py b/pybamm/citations.py index ff1851bfaa..ae18f9adc7 100644 --- a/pybamm/citations.py +++ b/pybamm/citations.py @@ -67,29 +67,41 @@ def read_citations(self): """Reads the citations in `pybamm.CITATIONS.bib`. Other works can be cited by passing a BibTeX citation to :meth:`register`. """ - parse_file = import_optional_dependency("pybtex.database", "parse_file") - citations_file = os.path.join(pybamm.root_dir(), "pybamm", "CITATIONS.bib") - bib_data = parse_file(citations_file, bib_format="bibtex") - for key, entry in bib_data.entries.items(): - self._add_citation(key, entry) + try: + parse_file = import_optional_dependency("pybtex.database", "parse_file") + citations_file = os.path.join(pybamm.root_dir(), "pybamm", "CITATIONS.bib") + bib_data = parse_file(citations_file, bib_format="bibtex") + for key, entry in bib_data.entries.items(): + self._add_citation(key, entry) + except ModuleNotFoundError: # pragma: no cover + pybamm.logger.warning( + "Citations could not be read because the 'pybtex' library is not installed. " + "Install 'pybamm[cite]' to enable citation reading." + ) def _add_citation(self, key, entry): """Adds `entry` to `self._all_citations` under `key`, warning the user if a previous entry is overwritten """ - Entry = import_optional_dependency("pybtex.database", "Entry") - # Check input types are correct - if not isinstance(key, str) or not isinstance(entry, Entry): - raise TypeError() - - # Warn if overwriting a previous citation - new_citation = entry.to_string("bibtex") - if key in self._all_citations and new_citation != self._all_citations[key]: - warnings.warn(f"Replacing citation for {key}", stacklevel=2) - - # Add to database - self._all_citations[key] = new_citation + try: + Entry = import_optional_dependency("pybtex.database", "Entry") + # Check input types are correct + if not isinstance(key, str) or not isinstance(entry, Entry): + raise TypeError() + + # Warn if overwriting a previous citation + new_citation = entry.to_string("bibtex") + if key in self._all_citations and new_citation != self._all_citations[key]: + warnings.warn(f"Replacing citation for {key}", stacklevel=2) + + # Add to database + self._all_citations[key] = new_citation + except ModuleNotFoundError: # pragma: no cover + pybamm.logger.warning( + f"Could not add citation for '{key}' because the 'pybtex' library is not installed. " + "Install 'pybamm[cite]' to enable adding citations." + ) def _add_citation_tag(self, key, entry): """Adds a tag for a citation key in the dict, which represents the name of the @@ -143,24 +155,32 @@ def _parse_citation(self, key): key: str A BibTeX formatted citation """ - PybtexError = import_optional_dependency("pybtex.scanner", "PybtexError") - parse_string = import_optional_dependency("pybtex.database", "parse_string") try: - # Parse string as a bibtex citation, and check that a citation was found - bib_data = parse_string(key, bib_format="bibtex") - if not bib_data.entries: - raise PybtexError("no entries found") - - # Add and register all citations - for key, entry in bib_data.entries.items(): - # Add to _all_citations dictionary - self._add_citation(key, entry) - # Add to _papers_to_cite set - self._papers_to_cite.add(key) - return - except PybtexError as error: - # Unable to parse / unknown key - raise KeyError(f"Not a bibtex citation or known citation: {key}") from error + PybtexError = import_optional_dependency("pybtex.scanner", "PybtexError") + parse_string = import_optional_dependency("pybtex.database", "parse_string") + try: + # Parse string as a bibtex citation, and check that a citation was found + bib_data = parse_string(key, bib_format="bibtex") + if not bib_data.entries: + raise PybtexError("no entries found") + + # Add and register all citations + for key, entry in bib_data.entries.items(): + # Add to _all_citations dictionary + self._add_citation(key, entry) + # Add to _papers_to_cite set + self._papers_to_cite.add(key) + return + except PybtexError as error: + # Unable to parse / unknown key + raise KeyError( + f"Not a bibtex citation or known citation: {key}" + ) from error + except ModuleNotFoundError: # pragma: no cover + pybamm.logger.warning( + f"Could not parse citation for '{key}' because the 'pybtex' library is not installed. " + "Install 'pybamm[cite]' to enable citation parsing." + ) def _tag_citations(self): """Prints the citation tags for the citations that have been registered @@ -211,38 +231,44 @@ def print(self, filename=None, output_format="text", verbose=False): """ # Parse citations that were not known keys at registration, but do not # fail if they cannot be parsed - pybtex = import_optional_dependency("pybtex") try: - for key in self._unknown_citations: - self._parse_citation(key) - except KeyError: # pragma: no cover - warnings.warn( - message=f'\nCitation with key "{key}" is invalid. Please try again\n', - category=UserWarning, - stacklevel=2, - ) - # delete the invalid citation from the set - self._unknown_citations.remove(key) - - if output_format == "text": - citations = pybtex.format_from_strings( - self._cited, style="plain", output_backend="plaintext" + pybtex = import_optional_dependency("pybtex") + try: + for key in self._unknown_citations: + self._parse_citation(key) + except KeyError: # pragma: no cover + warnings.warn( + message=f'\nCitation with key "{key}" is invalid. Please try again\n', + category=UserWarning, + stacklevel=2, + ) + # delete the invalid citation from the set + self._unknown_citations.remove(key) + + if output_format == "text": + citations = pybtex.format_from_strings( + self._cited, style="plain", output_backend="plaintext" + ) + elif output_format == "bibtex": + citations = "\n".join(self._cited) + else: + raise pybamm.OptionError( + f"Output format {output_format} not recognised." + "It should be 'text' or 'bibtex'." + ) + + if filename is None: + print(citations) + if verbose: + self._tag_citations() # pragma: no cover + else: + with open(filename, "w") as f: + f.write(citations) + except ModuleNotFoundError: # pragma: no cover + pybamm.logger.warning( + "Could not print citations because the 'pybtex' library is not installed. " + "Please, install 'pybamm[cite]' to print citations." ) - elif output_format == "bibtex": - citations = "\n".join(self._cited) - else: - raise pybamm.OptionError( - f"Output format {output_format} not recognised." - "It should be 'text' or 'bibtex'." - ) - - if filename is None: - print(citations) - if verbose: - self._tag_citations() # pragma: no cover - else: - with open(filename, "w") as f: - f.write(citations) def print_citations(filename=None, output_format="text", verbose=False): diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py index 3d6fafdd6b..603af50560 100644 --- a/tests/unit/test_util.py +++ b/tests/unit/test_util.py @@ -93,7 +93,7 @@ def test_import_optional_dependency(self): "pybamm", optional_distribution_deps=optional_distribution_deps ) - # Save optional dependencies, then set to None + # Save optional dependencies, then make them not importable modules = {} for import_pkg in present_optional_import_deps: modules[import_pkg] = sys.modules.get(import_pkg) @@ -117,23 +117,31 @@ def test_pybamm_import(self): "pybamm", optional_distribution_deps=optional_distribution_deps ) - # Save optional dependencies, then set to None + # Save optional dependencies and their sub-modules, then make them not importable modules = {} - for import_pkg in present_optional_import_deps: - modules[import_pkg] = sys.modules.get(import_pkg) - sys.modules[import_pkg] = None + for module_name, module in sys.modules.items(): + base_module_name = module_name.split(".")[0] + if base_module_name in present_optional_import_deps: + modules[module_name] = module + sys.modules[module_name] = None + + # Unload pybamm and its sub-modules + for module_name in list(sys.modules.keys()): + base_module_name = module_name.split(".")[0] + if base_module_name == "pybamm": + sys.modules.pop(module_name) # Test pybamm is still importable try: - importlib.reload(importlib.import_module("pybamm")) + importlib.import_module("pybamm") except ModuleNotFoundError as error: self.fail( f"Import of 'pybamm' shouldn't require optional dependencies. Error: {error}" ) - - # Restore optional dependencies - for import_pkg in present_optional_import_deps: - sys.modules[import_pkg] = modules[import_pkg] + finally: + # Restore optional dependencies and their sub-modules + for module_name, module in modules.items(): + sys.modules[module_name] = module def test_optional_dependencies(self): optional_distribution_deps = get_optional_distribution_deps("pybamm")