Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix optional dependencies error detection and remove pybtex hidden dependency #3968

Merged
merged 14 commits into from
Apr 14, 2024
Merged
154 changes: 90 additions & 64 deletions pybamm/citations.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,29 +67,41 @@ def read_citations(self):
"""Reads the citations in `pybamm.CITATIONS.bib`. Other works can be cited
by passing a BibTeX citation to :meth:`register`.
"""
parse_file = import_optional_dependency("pybtex.database", "parse_file")
citations_file = os.path.join(pybamm.root_dir(), "pybamm", "CITATIONS.bib")
bib_data = parse_file(citations_file, bib_format="bibtex")
for key, entry in bib_data.entries.items():
self._add_citation(key, entry)
try:
parse_file = import_optional_dependency("pybtex.database", "parse_file")
citations_file = os.path.join(pybamm.root_dir(), "pybamm", "CITATIONS.bib")
bib_data = parse_file(citations_file, bib_format="bibtex")
for key, entry in bib_data.entries.items():
self._add_citation(key, entry)
except ModuleNotFoundError:
pybamm.logger.warning(
"Citations could not be read because the 'pybtex' library is not installed. "
"Install 'pybtex' to enable citation reading."
lorenzofavaro marked this conversation as resolved.
Show resolved Hide resolved
)

def _add_citation(self, key, entry):
"""Adds `entry` to `self._all_citations` under `key`, warning the user if a
previous entry is overwritten
"""

Entry = import_optional_dependency("pybtex.database", "Entry")
# Check input types are correct
if not isinstance(key, str) or not isinstance(entry, Entry):
raise TypeError()

# Warn if overwriting a previous citation
new_citation = entry.to_string("bibtex")
if key in self._all_citations and new_citation != self._all_citations[key]:
warnings.warn(f"Replacing citation for {key}", stacklevel=2)

# Add to database
self._all_citations[key] = new_citation
try:
Entry = import_optional_dependency("pybtex.database", "Entry")
# Check input types are correct
if not isinstance(key, str) or not isinstance(entry, Entry):
raise TypeError()

# Warn if overwriting a previous citation
new_citation = entry.to_string("bibtex")
if key in self._all_citations and new_citation != self._all_citations[key]:
warnings.warn(f"Replacing citation for {key}", stacklevel=2)

# Add to database
self._all_citations[key] = new_citation
except ModuleNotFoundError:
pybamm.logger.warning(
f"Could not add citation for '{key}' because the 'pybtex' library is not installed. "
"Install 'pybtex' to enable adding citations."
lorenzofavaro marked this conversation as resolved.
Show resolved Hide resolved
)

def _add_citation_tag(self, key, entry):
"""Adds a tag for a citation key in the dict, which represents the name of the
Expand Down Expand Up @@ -143,24 +155,32 @@ def _parse_citation(self, key):
key: str
A BibTeX formatted citation
"""
PybtexError = import_optional_dependency("pybtex.scanner", "PybtexError")
parse_string = import_optional_dependency("pybtex.database", "parse_string")
try:
# Parse string as a bibtex citation, and check that a citation was found
bib_data = parse_string(key, bib_format="bibtex")
if not bib_data.entries:
raise PybtexError("no entries found")

# Add and register all citations
for key, entry in bib_data.entries.items():
# Add to _all_citations dictionary
self._add_citation(key, entry)
# Add to _papers_to_cite set
self._papers_to_cite.add(key)
return
except PybtexError as error:
# Unable to parse / unknown key
raise KeyError(f"Not a bibtex citation or known citation: {key}") from error
PybtexError = import_optional_dependency("pybtex.scanner", "PybtexError")
parse_string = import_optional_dependency("pybtex.database", "parse_string")
try:
# Parse string as a bibtex citation, and check that a citation was found
bib_data = parse_string(key, bib_format="bibtex")
if not bib_data.entries:
raise PybtexError("no entries found")

# Add and register all citations
for key, entry in bib_data.entries.items():
# Add to _all_citations dictionary
self._add_citation(key, entry)
# Add to _papers_to_cite set
self._papers_to_cite.add(key)
return
except PybtexError as error:
# Unable to parse / unknown key
raise KeyError(
f"Not a bibtex citation or known citation: {key}"
) from error
except ModuleNotFoundError:
pybamm.logger.warning(
f"Could not parse citation for '{key}' because the 'pybtex' library is not installed. "
"Install 'pybtex' to enable citation parsing."
lorenzofavaro marked this conversation as resolved.
Show resolved Hide resolved
)

def _tag_citations(self):
"""Prints the citation tags for the citations that have been registered
Expand Down Expand Up @@ -211,38 +231,44 @@ def print(self, filename=None, output_format="text", verbose=False):
"""
# Parse citations that were not known keys at registration, but do not
# fail if they cannot be parsed
pybtex = import_optional_dependency("pybtex")
try:
for key in self._unknown_citations:
self._parse_citation(key)
except KeyError: # pragma: no cover
warnings.warn(
message=f'\nCitation with key "{key}" is invalid. Please try again\n',
category=UserWarning,
stacklevel=2,
)
# delete the invalid citation from the set
self._unknown_citations.remove(key)

if output_format == "text":
citations = pybtex.format_from_strings(
self._cited, style="plain", output_backend="plaintext"
pybtex = import_optional_dependency("pybtex")
try:
for key in self._unknown_citations:
self._parse_citation(key)
except KeyError: # pragma: no cover
warnings.warn(
message=f'\nCitation with key "{key}" is invalid. Please try again\n',
category=UserWarning,
stacklevel=2,
)
# delete the invalid citation from the set
self._unknown_citations.remove(key)

if output_format == "text":
citations = pybtex.format_from_strings(
self._cited, style="plain", output_backend="plaintext"
)
elif output_format == "bibtex":
citations = "\n".join(self._cited)
else:
raise pybamm.OptionError(
f"Output format {output_format} not recognised."
"It should be 'text' or 'bibtex'."
)

if filename is None:
print(citations)
if verbose:
self._tag_citations() # pragma: no cover
else:
with open(filename, "w") as f:
f.write(citations)
except ModuleNotFoundError:
pybamm.logger.warning(
"Could not print citations because the 'pybtex' library is not installed. "
"Please, install 'pybtex' to print citations."
lorenzofavaro marked this conversation as resolved.
Show resolved Hide resolved
)
elif output_format == "bibtex":
citations = "\n".join(self._cited)
else:
raise pybamm.OptionError(
f"Output format {output_format} not recognised."
"It should be 'text' or 'bibtex'."
)

if filename is None:
print(citations)
if verbose:
self._tag_citations() # pragma: no cover
else:
with open(filename, "w") as f:
f.write(citations)


def print_citations(filename=None, output_format="text", verbose=False):
Expand Down
28 changes: 18 additions & 10 deletions tests/unit/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def test_import_optional_dependency(self):
"pybamm", optional_distribution_deps=optional_distribution_deps
)

# Save optional dependencies, then set to None
# Save optional dependencies, then make them not importable
modules = {}
for import_pkg in present_optional_import_deps:
modules[import_pkg] = sys.modules.get(import_pkg)
Expand All @@ -132,23 +132,31 @@ def test_pybamm_import(self):
"pybamm", optional_distribution_deps=optional_distribution_deps
)

# Save optional dependencies, then set to None
# Save optional dependencies and their sub-modules, then make them not importable
modules = {}
for import_pkg in present_optional_import_deps:
modules[import_pkg] = sys.modules.get(import_pkg)
sys.modules[import_pkg] = None
for module_name, module in sys.modules.items():
base_module_name = module_name.split(".")[0]
if base_module_name in present_optional_import_deps:
modules[module_name] = module
sys.modules[module_name] = None

# Unload pybamm and its sub-modules
for module_name in list(sys.modules.keys()):
base_module_name = module_name.split(".")[0]
if base_module_name == "pybamm":
sys.modules.pop(module_name)

# Test pybamm is still importable
try:
importlib.reload(importlib.import_module("pybamm"))
importlib.import_module("pybamm")
except ModuleNotFoundError as error:
self.fail(
f"Import of 'pybamm' shouldn't require optional dependencies. Error: {error}"
)

# Restore optional dependencies
for import_pkg in present_optional_import_deps:
sys.modules[import_pkg] = modules[import_pkg]
finally:
# Restore optional dependencies and their sub-modules
for module_name, module in modules.items():
sys.modules[module_name] = module

def test_optional_dependencies(self):
optional_distribution_deps = get_optional_distribution_deps("pybamm")
Expand Down
Loading