diff --git a/becquerel/core/calibration.py b/becquerel/core/calibration.py index d3ee4891..fb222469 100644 --- a/becquerel/core/calibration.py +++ b/becquerel/core/calibration.py @@ -240,9 +240,8 @@ def _validate_expression( raise CalibrationError(f"Independent variable {ind_var} must be 'x' or 'y'") ind_var_appears = False for node in ast.walk(ast.parse(expression)): - if type(node) is ast.Name: - if node.id == ind_var: - ind_var_appears = True + if type(node) is ast.Name and node.id == ind_var: + ind_var_appears = True if not ind_var_appears: raise CalibrationError( f'Independent variable "{ind_var}" must appear in the expression:\n' @@ -267,12 +266,11 @@ def _validate_expression( "Parameter indices in expression are not contiguous:\n" f"{expression}\n{param_indices}" ) - if params is not None: - if len(param_indices) != len(params): - raise CalibrationError( - "Not enough parameter indices in expression:\n" - f"{expression}\n{param_indices}" - ) + if params is not None and len(param_indices) != len(params): + raise CalibrationError( + "Not enough parameter indices in expression:\n" + f"{expression}\n{param_indices}" + ) # make sure the expression can be evaluated if params is not None: diff --git a/becquerel/core/fitting.py b/becquerel/core/fitting.py index 50a6109c..346dd84e 100644 --- a/becquerel/core/fitting.py +++ b/becquerel/core/fitting.py @@ -693,7 +693,7 @@ def _translate_model(self, m): raise FittingError(f"Unknown model type: {m}") def _make_model(self, model): - if isinstance(model, str) or isinstance(model, Model): + if isinstance(model, (str, Model)): model = [model] # Convert the model(s) to a list of Model classes / Model instancess self._model_cls_cnt = {} @@ -750,10 +750,9 @@ def _guess_param_defaults(self, **kwargs): def guess_param_defaults(self, update=False, **kwargs): defaults = self._guess_param_defaults(**kwargs) - if update: - if defaults is not None: - for dp in defaults: - self.set_param(*dp) + if update and defaults is not None: + for dp in defaults: + self.set_param(*dp) return defaults def fit(self, backend="lmfit", guess=None, limits=None): @@ -871,9 +870,7 @@ def model_loss(*args): min_vals[lim[0]] = lim[2] elif lim[1] == "max": max_vals[lim[0]] = lim[2] - limits_i = { - p: (min_vals.get(p, None), max_vals.get(p, None)) for p in free_vars - } + limits_i = {p: (min_vals.get(p), max_vals.get(p)) for p in free_vars} except NotImplementedError: # If the model/component does not have a guess() method limits_i = {} diff --git a/becquerel/core/plotting.py b/becquerel/core/plotting.py index df626e99..11a2652f 100644 --- a/becquerel/core/plotting.py +++ b/becquerel/core/plotting.py @@ -94,18 +94,17 @@ def xmode(self, mode): self._xmode = "energy" else: self._xmode = "channel" + elif mode.lower() in ("kev", "energy"): + if not self.spec.is_calibrated: + raise PlottingError( + "Spectrum is not calibrated, however " + "x axis was requested as energy" + ) + self._xmode = "energy" + elif mode.lower() in ("channel", "channels", "chn", "chns"): + self._xmode = "channel" else: - if mode.lower() in ("kev", "energy"): - if not self.spec.is_calibrated: - raise PlottingError( - "Spectrum is not calibrated, however " - "x axis was requested as energy" - ) - self._xmode = "energy" - elif mode.lower() in ("channel", "channels", "chn", "chns"): - self._xmode = "channel" - else: - raise PlottingError(f"Unknown x data mode: {mode}") + raise PlottingError(f"Unknown x data mode: {mode}") # Then, set the _xedges and _xlabel based on the _xmode xedges, xlabel = self.spec.parse_xmode(self._xmode) diff --git a/becquerel/core/spectrum.py b/becquerel/core/spectrum.py index 51f343a0..8e088109 100644 --- a/becquerel/core/spectrum.py +++ b/becquerel/core/spectrum.py @@ -176,12 +176,11 @@ def __init__( if realtime is not None: self.realtime = float(realtime) - if self.livetime is not None: - if self.livetime > self.realtime: - raise ValueError( - f"Livetime ({self.livetime}) cannot exceed realtime " - f"({self.realtime})" - ) + if self.livetime is not None and self.livetime > self.realtime: + raise ValueError( + f"Livetime ({self.livetime}) cannot exceed realtime " + f"({self.realtime})" + ) self.start_time = handle_datetime(start_time, "start_time", allow_none=True) self.stop_time = handle_datetime(stop_time, "stop_time", allow_none=True) @@ -903,20 +902,26 @@ def _add_sub_error_checking(self, other): "calibrated spectrum. If both have the same calibration, " 'please use the "calibrate_like" method' ) - if self.is_calibrated and other.is_calibrated: - if not np.all(self.bin_edges_kev == other.bin_edges_kev): - raise NotImplementedError( - "Addition/subtraction for arbitrary calibrated spectra " - "not implemented" - ) - # TODO: if both spectra are calibrated but with different - # calibrations, should one be rebinned to match? - if not self.is_calibrated and not other.is_calibrated: - if not np.all(self.bin_edges_raw == other.bin_edges_raw): - raise NotImplementedError( - "Addition/subtraction for arbitrary uncalibrated " - "spectra not implemented" - ) + if ( + self.is_calibrated + and other.is_calibrated + and not np.all(self.bin_edges_kev == other.bin_edges_kev) + ): + raise NotImplementedError( + "Addition/subtraction for arbitrary calibrated spectra " + "not implemented" + ) + # TODO: if both spectra are calibrated but with different + # calibrations, should one be rebinned to match? + if ( + not self.is_calibrated + and not other.is_calibrated + and not np.all(self.bin_edges_raw == other.bin_edges_raw) + ): + raise NotImplementedError( + "Addition/subtraction for arbitrary uncalibrated " + "spectra not implemented" + ) def __mul__(self, other): """Return a new Spectrum object with counts (or CPS) scaled up. @@ -937,7 +942,7 @@ def __mul__(self, other): # This line adds the right multiplication __rmul__ = __mul__ - def __div__(self, other): + def __truediv__(self, other): """Return a new Spectrum object with counts (or CPS) scaled down. Args: @@ -953,9 +958,6 @@ def __div__(self, other): return self._mul_div(other, div=True) - # This line adds true division - __truediv__ = __div__ - def _mul_div(self, scaling_factor: float, div=False): """Multiply or divide a spectrum by a scalar. Handle errors. @@ -980,13 +982,12 @@ def _mul_div(self, scaling_factor: float, div=False): or np.isnan(scaling_factor) ): raise ValueError("Scaling factor must be nonzero and finite") - else: - if ( - scaling_factor.nominal_value == 0 - or np.isinf(scaling_factor.nominal_value) - or np.isnan(scaling_factor.nominal_value) - ): - raise ValueError("Scaling factor must be nonzero and finite") + elif ( + scaling_factor.nominal_value == 0 + or np.isinf(scaling_factor.nominal_value) + or np.isnan(scaling_factor.nominal_value) + ): + raise ValueError("Scaling factor must be nonzero and finite") if div: multiplier = 1 / scaling_factor else: @@ -1146,10 +1147,7 @@ def has_uniform_bins(self, use_kev=None, rtol=None) -> bool: # first non-uniform bin. iterator = iter(bin_widths) x0 = next(iterator, None) - for x in iterator: - if abs(x / x0 - 1.0) > rtol: - return False - return True + return all(abs(x / x0 - 1.0) <= rtol for x in iterator) def find_bin_index(self, x: float, use_kev=None) -> int: """Find the Spectrum bin index or indices containing x-axis value(s) x. @@ -1348,13 +1346,16 @@ def rebin( "Cannot rebin spectrum without energy calibration" ) # TODO: why not? in_spec = self.counts_vals - if method.lower() == "listmode": - if (self._counts is None) and (self.livetime is not None): - warnings.warn( - "Rebinning by listmode method without explicit counts " - "provided in Spectrum object", - SpectrumWarning, - ) + if ( + method.lower() == "listmode" + and (self._counts is None) + and (self.livetime is not None) + ): + warnings.warn( + "Rebinning by listmode method without explicit counts " + "provided in Spectrum object", + SpectrumWarning, + ) out_spec = rebin( in_spec, self.bin_edges_kev, @@ -1476,7 +1477,7 @@ def plot(self, *fmt, **kwargs): color = ax.get_lines()[-1].get_color() if emode == "band": plotter.errorband(color=color, alpha=alpha * 0.5, label="_nolegend_") - elif emode == "bars" or emode == "bar": + elif emode in ("bars", "bar"): plotter.errorbar(color=color, label="_nolegend_") elif emode != "none": raise SpectrumError(f"Unknown error mode '{emode}', use 'bars' or 'band'") diff --git a/becquerel/io/h5.py b/becquerel/io/h5.py index 8c573bf7..b8a0bb82 100644 --- a/becquerel/io/h5.py +++ b/becquerel/io/h5.py @@ -101,7 +101,7 @@ def write_h5(name: str | h5py.File | h5py.Group, dsets: dict, attrs: dict) -> No """ with open_h5(name, "w") as file: # write the datasets - for key in dsets.keys(): + for key in dsets: try: file.create_dataset( key, @@ -137,7 +137,7 @@ def read_h5(name: str | h5py.File | h5py.Group) -> tuple[dict, dict, list]: skipped = [] with open_h5(name, "r") as file: # read the datasets - for key in file.keys(): + for key in file: # skip any non-datasets if not isinstance(file[key], h5py.Dataset): skipped.append(str(key)) diff --git a/becquerel/parsers/cnf.py b/becquerel/parsers/cnf.py index 6c45dd0c..a87704c0 100644 --- a/becquerel/parsers/cnf.py +++ b/becquerel/parsers/cnf.py @@ -296,7 +296,7 @@ def read(filename, verbose=False, cal_kwargs=None): data["counts"] = counts # clean up null characters in any strings - for key in data.keys(): + for key in data: if isinstance(data[key], str): data[key] = data[key].replace("\x00", " ") data[key] = data[key].replace("\x01", " ") diff --git a/becquerel/parsers/spc.py b/becquerel/parsers/spc.py index c319be9e..7a21deec 100644 --- a/becquerel/parsers/spc.py +++ b/becquerel/parsers/spc.py @@ -392,7 +392,7 @@ def read(filename, verbose=False, cal_kwargs=None): raise BecquerelParserError("Calibration parameters not found") from exc # clean up null characters in any strings - for key in data.keys(): + for key in data: if isinstance(data[key], str): data[key] = data[key].replace("\x00", " ") data[key] = data[key].replace("\x01", " ") diff --git a/becquerel/parsers/spe.py b/becquerel/parsers/spe.py index 1f4ebffb..75433350 100644 --- a/becquerel/parsers/spe.py +++ b/becquerel/parsers/spe.py @@ -97,9 +97,8 @@ def read(filename, verbose=False, cal_kwargs=None): while i < len(lines) and not lines[i].startswith("$"): values.append(lines[i]) i += 1 - if i < len(lines): - if lines[i].startswith("$"): - i -= 1 + if i < len(lines) and lines[i].startswith("$"): + i -= 1 if len(values) == 1: values = values[0] data[key] = values diff --git a/becquerel/tools/isotope.py b/becquerel/tools/isotope.py index 6ba615a1..6b041397 100644 --- a/becquerel/tools/isotope.py +++ b/becquerel/tools/isotope.py @@ -235,35 +235,31 @@ def _init_m(self, arg): if arg == "" or arg is None or arg == 0: self.m = "" self.M = 0 - else: - if isinstance(arg, int): - if arg == 1: - self.m = "m" - self.M = 1 - elif arg >= 2: - self.M = arg - self.m = f"m{self.M}" - else: - raise IsotopeError(f"Metastable level must be >= 0: {arg}") - elif isinstance(arg, str): - self.m = arg.lower() - if self.m[0] != "m": + elif isinstance(arg, int): + if arg == 1: + self.m = "m" + self.M = 1 + elif arg >= 2: + self.M = arg + self.m = f"m{self.M}" + else: + raise IsotopeError(f"Metastable level must be >= 0: {arg}") + elif isinstance(arg, str): + self.m = arg.lower() + if self.m[0] != "m": + raise IsotopeError(f'Metastable level must start with "m": {self.m}') + if len(self.m) > 1: + if not self.m[1:].isdigit(): raise IsotopeError( - f'Metastable level must start with "m": {self.m}' + f"Metastable level must be numeric: {self.m[0]} {self.m[1:]}" ) - if len(self.m) > 1: - if not self.m[1:].isdigit(): - raise IsotopeError( - "Metastable level must be numeric: " - f"{self.m[0]} {self.m[1:]}" - ) - self.M = int(self.m[1:]) - else: - self.M = 1 + self.M = int(self.m[1:]) else: - raise IsotopeError( - f"Metastable level must be integer or string: {arg} {type(arg)}" - ) + self.M = 1 + else: + raise IsotopeError( + f"Metastable level must be integer or string: {arg} {type(arg)}" + ) def __str__(self): """Define behavior of str() on Isotope.""" @@ -370,9 +366,8 @@ def abundance(self): df = self._wallet_card() data = df["Abundance (%)"].tolist() - if not isinstance(data[0], uncertainties.core.Variable): - if np.isnan(data[0]): - return None + if not isinstance(data[0], uncertainties.core.Variable) and np.isnan(data[0]): + return None return data[0] @property @@ -411,9 +406,8 @@ def mass_excess(self): df = self._wallet_card() data = df["Mass Excess (MeV)"].tolist() - if not isinstance(data[0], uncertainties.core.Variable): - if np.isnan(data[0]): - return None + if not isinstance(data[0], uncertainties.core.Variable) and np.isnan(data[0]): + return None return data[0] @property diff --git a/becquerel/tools/isotope_qty.py b/becquerel/tools/isotope_qty.py index 24be1575..fd9efa86 100644 --- a/becquerel/tools/isotope_qty.py +++ b/becquerel/tools/isotope_qty.py @@ -512,13 +512,8 @@ def __mul__(self, other): return self._mul_div(other, div=False) - def __div__(self, other): - """Divide the quantity""" - - return self._mul_div(other, div=True) - def __truediv__(self, other): - """Divide the quantity (python 3)""" + """Divide the quantity""" return self._mul_div(other, div=True) diff --git a/becquerel/tools/materials.py b/becquerel/tools/materials.py index 430755fa..ba84c6a5 100644 --- a/becquerel/tools/materials.py +++ b/becquerel/tools/materials.py @@ -35,12 +35,11 @@ def _load_and_compile_materials(): rho2 = data_elem["Density"][data_elem["Element"] == name].to_numpy()[0] elif name in data_mat["Material"].to_numpy(): rho2 = data_mat["Density"][data_mat["Material"] == name].to_numpy()[0] - if rho2: - if not np.isclose(rho1, rho2, atol=2e-2): - raise MaterialsError( - f"Material {name} densities do not match between different " - f"data sources: {rho1:.6f} {rho2:.6f}" - ) + if rho2 and not np.isclose(rho1, rho2, atol=2e-2): + raise MaterialsError( + f"Material {name} densities do not match between different " + f"data sources: {rho1:.6f} {rho2:.6f}" + ) for j in range(len(data_comp)): name = data_comp["Material"].to_numpy()[j] diff --git a/becquerel/tools/materials_compendium.py b/becquerel/tools/materials_compendium.py index e0b337f5..20678737 100644 --- a/becquerel/tools/materials_compendium.py +++ b/becquerel/tools/materials_compendium.py @@ -66,7 +66,7 @@ def fetch_compendium_data(): print("Pre-March 2022 JSON detected") elif isinstance(data, dict): print("Post-March 2022 JSON detected") - if "siteVersion" not in data.keys() or "data" not in data.keys(): + if "siteVersion" not in data or "data" not in data: raise MaterialsError( "Attempt to read Compendium JSON failed; " "dictionary must have keys 'siteVersion' " @@ -80,7 +80,7 @@ def fetch_compendium_data(): "object must be a list or dict but is a " + str(type(data)) ) names = [datum["Name"] for datum in data] - formulae = [datum["Formula"] if "Formula" in datum else "-" for datum in data] + formulae = [datum.get("Formula", "-") for datum in data] densities = [datum["Density"] for datum in data] weight_fracs = [ json_elements_to_weight_fractions(datum["Elements"]) for datum in data diff --git a/becquerel/tools/nndc.py b/becquerel/tools/nndc.py index d2481589..52d5f7c2 100644 --- a/becquerel/tools/nndc.py +++ b/becquerel/tools/nndc.py @@ -122,7 +122,7 @@ def _parse_headers(headers): # reformat column headers if needed for j, hd in enumerate(headers): # rename so always have T1/2 (s) - if hd == "T1/2 (num)" or hd == "T1/2 (seconds)": + if hd in ("T1/2 (num)", "T1/2 (seconds)"): hd = "T1/2 (s)" # for uncertainties, add previous column header to it if j > 0 and "Unc" in hd: @@ -260,11 +260,9 @@ def _parse_float_uncertainty(x, dx): if "8 .0E-E5" in x: x = x.replace("8 .0E-E5", "8.0E-5") # handle blank or missing data - if x == "" or x == " ": + if x in ("", " "): return None - if "****" in dx: - dx = "" - elif dx in ["LT", "GT", "LE", "GE", "AP", "CA", "SY"]: + if "****" in dx or dx in ["LT", "GT", "LE", "GE", "AP", "CA", "SY"]: dx = "" try: x2 = float(x) @@ -408,9 +406,7 @@ def __init__(self, **kwargs): def __len__(self): """Length of any one of the data lists.""" - if self.df is None: - return 0 - elif len(self.df.keys()) == 0: + if self.df is None or len(self.df.keys()) == 0: return 0 else: return len(self.df[self.df.keys()[0]]) diff --git a/becquerel/tools/xcom.py b/becquerel/tools/xcom.py index 3fa9ea03..6cf5cad7 100644 --- a/becquerel/tools/xcom.py +++ b/becquerel/tools/xcom.py @@ -178,9 +178,7 @@ def __init__(self, arg, **kwargs): def __len__(self): """Pass-through to use DataFrame len().""" - if self.df is None: - return 0 - elif len(self.df.keys()) == 0: + if self.df is None or len(self.df.keys()) == 0: return 0 else: return len(self.df[self.df.keys()[0]]) diff --git a/examples/nndc_chart_of_nuclides.ipynb b/examples/nndc_chart_of_nuclides.ipynb index 1787a409..7a31d805 100644 --- a/examples/nndc_chart_of_nuclides.ipynb +++ b/examples/nndc_chart_of_nuclides.ipynb @@ -17,9 +17,9 @@ "metadata": {}, "outputs": [], "source": [ - "import matplotlib.patches as patches\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", + "from matplotlib import patches\n", "\n", "import becquerel as bq" ] diff --git a/pyproject.toml b/pyproject.toml index 17a0c2c1..a8e69fdd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,7 @@ select = [ "RSE", # flake8-raise # "RET", # flake8-return # "SLF", # flake8-self - # "SIM", # flake8-simplify + "SIM", # flake8-simplify # "TID", # flake8-tidy-imports "TCH", # flake8-type-checking "INT", # flake8-gettext @@ -69,7 +69,7 @@ select = [ # "ERA", # eradicate "PD", # pandas-vet "PGH", # pygrep-hooks - # "PL", # pylint + "PL", # pylint "TRY", # tryceratops "FLY", # flynt "NPY", # NumPy-specific rules @@ -83,7 +83,12 @@ ignore = [ "B015", # Pointless comparison. Did you mean to assign a value? Otherwise, prepend `assert` or remove it. "B018", # Found useless expression. Either assign it to a variable or remove it. "B028", # No explicit `stacklevel` keyword argument found + "SIM105", # Use `contextlib.suppress(Exception)` instead of `try`-`except`-`pass` + "SIM108", # Use ternary operator instead of `if`-`else`-block + "SIM300", # Yoda conditions are discouraged "PD901", # Avoid using the generic variable name `df` for DataFrames + "PLW2901", # `for` loop variable overwritten by assignment target + "PLR2004", # Magic value used in comparison, consider replacing with a constant variable "TRY003", # Avoid specifying long messages outside the exception class "NPY002", # Replace legacy `np.random.poisson` call with `np.random.Generator` "PERF203", # `try`-`except` within a loop incurs performance overhead @@ -111,7 +116,8 @@ convention = "google" [tool.ruff.lint.pylint] max-args = 15 -max-branches = 15 +max-branches = 60 max-locals = 25 max-nested-blocks = 15 -max-statements = 100 +max-returns = 20 +max-statements = 150 diff --git a/tests/h5_tools_test.py b/tests/h5_tools_test.py index a74c68ba..f744f594 100644 --- a/tests/h5_tools_test.py +++ b/tests/h5_tools_test.py @@ -126,12 +126,12 @@ def check_dsets_attrs(dsets1, attrs1, dsets2, attrs2): """Check that the dataset and attribute dicts are identical.""" assert set(dsets1.keys()) == set(dsets2.keys()) assert set(attrs1.keys()) == set(attrs2.keys()) - for key in dsets1.keys(): + for key in dsets1: if "str" in key: assert ensure_string(dsets1[key]) == ensure_string(dsets2[key]) else: assert np.allclose(dsets1[key], dsets2[key]) - for key in attrs1.keys(): + for key in attrs1: if "str" in key: assert ensure_string(attrs1[key]) == ensure_string(attrs2[key]) else: diff --git a/tests/materials_test.py b/tests/materials_test.py index e595b14b..a99166a7 100644 --- a/tests/materials_test.py +++ b/tests/materials_test.py @@ -7,12 +7,12 @@ import pytest from utils import xcom_is_up -import becquerel.tools.materials as materials -import becquerel.tools.materials_compendium as materials_compendium from becquerel.tools import ( MaterialsError, MaterialsWarning, fetch_materials, + materials, + materials_compendium, remove_materials_csv, ) from becquerel.tools.materials_nist import convert_composition diff --git a/tests/wallet_cache_test.py b/tests/wallet_cache_test.py index 1b4ece39..501e8c79 100644 --- a/tests/wallet_cache_test.py +++ b/tests/wallet_cache_test.py @@ -5,7 +5,7 @@ import uncertainties from utils import nndc_is_up -import becquerel.tools.wallet_cache as wallet_cache +from becquerel.tools import wallet_cache @pytest.mark.parametrize(