Skip to content

Commit

Permalink
Merge pull request #418 from lbl-anp/enable-ruff-pylint-take2
Browse files Browse the repository at this point in the history
Enable pylint and other linting rules in ruff
  • Loading branch information
markbandstra committed Jun 5, 2024
2 parents fc39bb4 + ef295c4 commit 71be85a
Show file tree
Hide file tree
Showing 19 changed files with 127 additions and 145 deletions.
16 changes: 7 additions & 9 deletions becquerel/core/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,8 @@ def _validate_expression(
raise CalibrationError(f"Independent variable {ind_var} must be 'x' or 'y'")
ind_var_appears = False
for node in ast.walk(ast.parse(expression)):
if type(node) is ast.Name:
if node.id == ind_var:
ind_var_appears = True
if type(node) is ast.Name and node.id == ind_var:
ind_var_appears = True
if not ind_var_appears:
raise CalibrationError(
f'Independent variable "{ind_var}" must appear in the expression:\n'
Expand All @@ -267,12 +266,11 @@ def _validate_expression(
"Parameter indices in expression are not contiguous:\n"
f"{expression}\n{param_indices}"
)
if params is not None:
if len(param_indices) != len(params):
raise CalibrationError(
"Not enough parameter indices in expression:\n"
f"{expression}\n{param_indices}"
)
if params is not None and len(param_indices) != len(params):
raise CalibrationError(
"Not enough parameter indices in expression:\n"
f"{expression}\n{param_indices}"
)

# make sure the expression can be evaluated
if params is not None:
Expand Down
13 changes: 5 additions & 8 deletions becquerel/core/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ def _translate_model(self, m):
raise FittingError(f"Unknown model type: {m}")

def _make_model(self, model):
if isinstance(model, str) or isinstance(model, Model):
if isinstance(model, (str, Model)):
model = [model]
# Convert the model(s) to a list of Model classes / Model instancess
self._model_cls_cnt = {}
Expand Down Expand Up @@ -750,10 +750,9 @@ def _guess_param_defaults(self, **kwargs):

def guess_param_defaults(self, update=False, **kwargs):
defaults = self._guess_param_defaults(**kwargs)
if update:
if defaults is not None:
for dp in defaults:
self.set_param(*dp)
if update and defaults is not None:
for dp in defaults:
self.set_param(*dp)
return defaults

def fit(self, backend="lmfit", guess=None, limits=None):
Expand Down Expand Up @@ -871,9 +870,7 @@ def model_loss(*args):
min_vals[lim[0]] = lim[2]
elif lim[1] == "max":
max_vals[lim[0]] = lim[2]
limits_i = {
p: (min_vals.get(p, None), max_vals.get(p, None)) for p in free_vars
}
limits_i = {p: (min_vals.get(p), max_vals.get(p)) for p in free_vars}
except NotImplementedError:
# If the model/component does not have a guess() method
limits_i = {}
Expand Down
21 changes: 10 additions & 11 deletions becquerel/core/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,18 +94,17 @@ def xmode(self, mode):
self._xmode = "energy"
else:
self._xmode = "channel"
elif mode.lower() in ("kev", "energy"):
if not self.spec.is_calibrated:
raise PlottingError(
"Spectrum is not calibrated, however "
"x axis was requested as energy"
)
self._xmode = "energy"
elif mode.lower() in ("channel", "channels", "chn", "chns"):
self._xmode = "channel"
else:
if mode.lower() in ("kev", "energy"):
if not self.spec.is_calibrated:
raise PlottingError(
"Spectrum is not calibrated, however "
"x axis was requested as energy"
)
self._xmode = "energy"
elif mode.lower() in ("channel", "channels", "chn", "chns"):
self._xmode = "channel"
else:
raise PlottingError(f"Unknown x data mode: {mode}")
raise PlottingError(f"Unknown x data mode: {mode}")

# Then, set the _xedges and _xlabel based on the _xmode
xedges, xlabel = self.spec.parse_xmode(self._xmode)
Expand Down
87 changes: 44 additions & 43 deletions becquerel/core/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,11 @@ def __init__(

if realtime is not None:
self.realtime = float(realtime)
if self.livetime is not None:
if self.livetime > self.realtime:
raise ValueError(
f"Livetime ({self.livetime}) cannot exceed realtime "
f"({self.realtime})"
)
if self.livetime is not None and self.livetime > self.realtime:
raise ValueError(
f"Livetime ({self.livetime}) cannot exceed realtime "
f"({self.realtime})"
)

self.start_time = handle_datetime(start_time, "start_time", allow_none=True)
self.stop_time = handle_datetime(stop_time, "stop_time", allow_none=True)
Expand Down Expand Up @@ -903,20 +902,26 @@ def _add_sub_error_checking(self, other):
"calibrated spectrum. If both have the same calibration, "
'please use the "calibrate_like" method'
)
if self.is_calibrated and other.is_calibrated:
if not np.all(self.bin_edges_kev == other.bin_edges_kev):
raise NotImplementedError(
"Addition/subtraction for arbitrary calibrated spectra "
"not implemented"
)
# TODO: if both spectra are calibrated but with different
# calibrations, should one be rebinned to match?
if not self.is_calibrated and not other.is_calibrated:
if not np.all(self.bin_edges_raw == other.bin_edges_raw):
raise NotImplementedError(
"Addition/subtraction for arbitrary uncalibrated "
"spectra not implemented"
)
if (
self.is_calibrated
and other.is_calibrated
and not np.all(self.bin_edges_kev == other.bin_edges_kev)
):
raise NotImplementedError(
"Addition/subtraction for arbitrary calibrated spectra "
"not implemented"
)
# TODO: if both spectra are calibrated but with different
# calibrations, should one be rebinned to match?
if (
not self.is_calibrated
and not other.is_calibrated
and not np.all(self.bin_edges_raw == other.bin_edges_raw)
):
raise NotImplementedError(
"Addition/subtraction for arbitrary uncalibrated "
"spectra not implemented"
)

def __mul__(self, other):
"""Return a new Spectrum object with counts (or CPS) scaled up.
Expand All @@ -937,7 +942,7 @@ def __mul__(self, other):
# This line adds the right multiplication
__rmul__ = __mul__

def __div__(self, other):
def __truediv__(self, other):
"""Return a new Spectrum object with counts (or CPS) scaled down.
Args:
Expand All @@ -953,9 +958,6 @@ def __div__(self, other):

return self._mul_div(other, div=True)

# This line adds true division
__truediv__ = __div__

def _mul_div(self, scaling_factor: float, div=False):
"""Multiply or divide a spectrum by a scalar. Handle errors.
Expand All @@ -980,13 +982,12 @@ def _mul_div(self, scaling_factor: float, div=False):
or np.isnan(scaling_factor)
):
raise ValueError("Scaling factor must be nonzero and finite")
else:
if (
scaling_factor.nominal_value == 0
or np.isinf(scaling_factor.nominal_value)
or np.isnan(scaling_factor.nominal_value)
):
raise ValueError("Scaling factor must be nonzero and finite")
elif (
scaling_factor.nominal_value == 0
or np.isinf(scaling_factor.nominal_value)
or np.isnan(scaling_factor.nominal_value)
):
raise ValueError("Scaling factor must be nonzero and finite")
if div:
multiplier = 1 / scaling_factor
else:
Expand Down Expand Up @@ -1146,10 +1147,7 @@ def has_uniform_bins(self, use_kev=None, rtol=None) -> bool:
# first non-uniform bin.
iterator = iter(bin_widths)
x0 = next(iterator, None)
for x in iterator:
if abs(x / x0 - 1.0) > rtol:
return False
return True
return all(abs(x / x0 - 1.0) <= rtol for x in iterator)

def find_bin_index(self, x: float, use_kev=None) -> int:
"""Find the Spectrum bin index or indices containing x-axis value(s) x.
Expand Down Expand Up @@ -1348,13 +1346,16 @@ def rebin(
"Cannot rebin spectrum without energy calibration"
) # TODO: why not?
in_spec = self.counts_vals
if method.lower() == "listmode":
if (self._counts is None) and (self.livetime is not None):
warnings.warn(
"Rebinning by listmode method without explicit counts "
"provided in Spectrum object",
SpectrumWarning,
)
if (
method.lower() == "listmode"
and (self._counts is None)
and (self.livetime is not None)
):
warnings.warn(
"Rebinning by listmode method without explicit counts "
"provided in Spectrum object",
SpectrumWarning,
)
out_spec = rebin(
in_spec,
self.bin_edges_kev,
Expand Down Expand Up @@ -1476,7 +1477,7 @@ def plot(self, *fmt, **kwargs):
color = ax.get_lines()[-1].get_color()
if emode == "band":
plotter.errorband(color=color, alpha=alpha * 0.5, label="_nolegend_")
elif emode == "bars" or emode == "bar":
elif emode in ("bars", "bar"):
plotter.errorbar(color=color, label="_nolegend_")
elif emode != "none":
raise SpectrumError(f"Unknown error mode '{emode}', use 'bars' or 'band'")
Expand Down
4 changes: 2 additions & 2 deletions becquerel/io/h5.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def write_h5(name: str | h5py.File | h5py.Group, dsets: dict, attrs: dict) -> No
"""
with open_h5(name, "w") as file:
# write the datasets
for key in dsets.keys():
for key in dsets:
try:
file.create_dataset(
key,
Expand Down Expand Up @@ -137,7 +137,7 @@ def read_h5(name: str | h5py.File | h5py.Group) -> tuple[dict, dict, list]:
skipped = []
with open_h5(name, "r") as file:
# read the datasets
for key in file.keys():
for key in file:
# skip any non-datasets
if not isinstance(file[key], h5py.Dataset):
skipped.append(str(key))
Expand Down
2 changes: 1 addition & 1 deletion becquerel/parsers/cnf.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def read(filename, verbose=False, cal_kwargs=None):
data["counts"] = counts

# clean up null characters in any strings
for key in data.keys():
for key in data:
if isinstance(data[key], str):
data[key] = data[key].replace("\x00", " ")
data[key] = data[key].replace("\x01", " ")
Expand Down
2 changes: 1 addition & 1 deletion becquerel/parsers/spc.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def read(filename, verbose=False, cal_kwargs=None):
raise BecquerelParserError("Calibration parameters not found") from exc

# clean up null characters in any strings
for key in data.keys():
for key in data:
if isinstance(data[key], str):
data[key] = data[key].replace("\x00", " ")
data[key] = data[key].replace("\x01", " ")
Expand Down
5 changes: 2 additions & 3 deletions becquerel/parsers/spe.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,8 @@ def read(filename, verbose=False, cal_kwargs=None):
while i < len(lines) and not lines[i].startswith("$"):
values.append(lines[i])
i += 1
if i < len(lines):
if lines[i].startswith("$"):
i -= 1
if i < len(lines) and lines[i].startswith("$"):
i -= 1
if len(values) == 1:
values = values[0]
data[key] = values
Expand Down
58 changes: 26 additions & 32 deletions becquerel/tools/isotope.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,35 +235,31 @@ def _init_m(self, arg):
if arg == "" or arg is None or arg == 0:
self.m = ""
self.M = 0
else:
if isinstance(arg, int):
if arg == 1:
self.m = "m"
self.M = 1
elif arg >= 2:
self.M = arg
self.m = f"m{self.M}"
else:
raise IsotopeError(f"Metastable level must be >= 0: {arg}")
elif isinstance(arg, str):
self.m = arg.lower()
if self.m[0] != "m":
elif isinstance(arg, int):
if arg == 1:
self.m = "m"
self.M = 1
elif arg >= 2:
self.M = arg
self.m = f"m{self.M}"
else:
raise IsotopeError(f"Metastable level must be >= 0: {arg}")
elif isinstance(arg, str):
self.m = arg.lower()
if self.m[0] != "m":
raise IsotopeError(f'Metastable level must start with "m": {self.m}')
if len(self.m) > 1:
if not self.m[1:].isdigit():
raise IsotopeError(
f'Metastable level must start with "m": {self.m}'
f"Metastable level must be numeric: {self.m[0]} {self.m[1:]}"
)
if len(self.m) > 1:
if not self.m[1:].isdigit():
raise IsotopeError(
"Metastable level must be numeric: "
f"{self.m[0]} {self.m[1:]}"
)
self.M = int(self.m[1:])
else:
self.M = 1
self.M = int(self.m[1:])
else:
raise IsotopeError(
f"Metastable level must be integer or string: {arg} {type(arg)}"
)
self.M = 1
else:
raise IsotopeError(
f"Metastable level must be integer or string: {arg} {type(arg)}"
)

def __str__(self):
"""Define behavior of str() on Isotope."""
Expand Down Expand Up @@ -370,9 +366,8 @@ def abundance(self):

df = self._wallet_card()
data = df["Abundance (%)"].tolist()
if not isinstance(data[0], uncertainties.core.Variable):
if np.isnan(data[0]):
return None
if not isinstance(data[0], uncertainties.core.Variable) and np.isnan(data[0]):
return None
return data[0]

@property
Expand Down Expand Up @@ -411,9 +406,8 @@ def mass_excess(self):

df = self._wallet_card()
data = df["Mass Excess (MeV)"].tolist()
if not isinstance(data[0], uncertainties.core.Variable):
if np.isnan(data[0]):
return None
if not isinstance(data[0], uncertainties.core.Variable) and np.isnan(data[0]):
return None
return data[0]

@property
Expand Down
7 changes: 1 addition & 6 deletions becquerel/tools/isotope_qty.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,13 +512,8 @@ def __mul__(self, other):

return self._mul_div(other, div=False)

def __div__(self, other):
"""Divide the quantity"""

return self._mul_div(other, div=True)

def __truediv__(self, other):
"""Divide the quantity (python 3)"""
"""Divide the quantity"""

return self._mul_div(other, div=True)

Expand Down
11 changes: 5 additions & 6 deletions becquerel/tools/materials.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,11 @@ def _load_and_compile_materials():
rho2 = data_elem["Density"][data_elem["Element"] == name].to_numpy()[0]
elif name in data_mat["Material"].to_numpy():
rho2 = data_mat["Density"][data_mat["Material"] == name].to_numpy()[0]
if rho2:
if not np.isclose(rho1, rho2, atol=2e-2):
raise MaterialsError(
f"Material {name} densities do not match between different "
f"data sources: {rho1:.6f} {rho2:.6f}"
)
if rho2 and not np.isclose(rho1, rho2, atol=2e-2):
raise MaterialsError(
f"Material {name} densities do not match between different "
f"data sources: {rho1:.6f} {rho2:.6f}"
)

for j in range(len(data_comp)):
name = data_comp["Material"].to_numpy()[j]
Expand Down
Loading

0 comments on commit 71be85a

Please sign in to comment.