Skip to content

Commit

Permalink
lint fixes
Browse files Browse the repository at this point in the history
Signed-off-by: kalyanr <kalyan.ben10@live.com>
  • Loading branch information
rawwar committed Feb 12, 2024
1 parent f80995c commit d1912bf
Show file tree
Hide file tree
Showing 9 changed files with 29 additions and 24 deletions.
1 change: 0 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import os
import sys


sys.path.insert(0, os.path.abspath("../../"))

# -- Project information -----------------------------------------------------
Expand Down
6 changes: 4 additions & 2 deletions opensearch_py_ml/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,8 +978,10 @@ def _sizeof_fmt(num: float, size_qualifier: str) -> str:
elif verbose is False: # specifically set to False, not nesc None
_non_verbose_repr()
else:
_non_verbose_repr() if exceeds_info_cols else _verbose_repr(
number_of_columns
(
_non_verbose_repr()
if exceeds_info_cols
else _verbose_repr(number_of_columns)
)

# pandas 0.25.1 uses get_dtype_counts() here. This
Expand Down
6 changes: 3 additions & 3 deletions opensearch_py_ml/field_mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,9 +445,9 @@ def find_aggregatable(row, df):
try:
series = df.loc[df.os_field_name == os_field_name_keyword]
if not series.empty and series.is_aggregatable.squeeze():
row_as_dict[
"aggregatable_os_field_name"
] = os_field_name_keyword
row_as_dict["aggregatable_os_field_name"] = (
os_field_name_keyword
)
else:
row_as_dict["aggregatable_os_field_name"] = None
except KeyError:
Expand Down
6 changes: 3 additions & 3 deletions opensearch_py_ml/ml_commons/model_uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ def _register_model(
model_meta_json[TOTAL_CHUNKS_FIELD] = total_num_chunks

if MODEL_CONTENT_SIZE_IN_BYTES_FIELD not in model_meta_json:
model_meta_json[
MODEL_CONTENT_SIZE_IN_BYTES_FIELD
] = model_content_size_in_bytes
model_meta_json[MODEL_CONTENT_SIZE_IN_BYTES_FIELD] = (
model_content_size_in_bytes
)
if MODEL_CONTENT_HASH_VALUE not in model_meta_json:
# Generate the sha1 hash for the model zip file
hash_val_model_file = _generate_model_content_hash_value(model_path)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ def merge_events(
send = ends[ix]
sevents = candidates[ix, :]

merged: List[
Dict[str, torch.Tensor]
] = [] # merge in linear pass over time dimension
merged: List[Dict[str, torch.Tensor]] = (
[]
) # merge in linear pass over time dimension
currstart = torch.tensor([-1])
currend = torch.tensor([-1])
currevent = torch.ones(T) * -1.0
Expand Down
6 changes: 3 additions & 3 deletions opensearch_py_ml/ml_models/sentencetransformermodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1304,9 +1304,9 @@ def make_model_config_json(
model_config_content["model_content_size_in_bytes"] = os.stat(
model_zip_file_path
).st_size
model_config_content[
"model_content_hash_value"
] = _generate_model_content_hash_value(model_zip_file_path)
model_config_content["model_content_hash_value"] = (
_generate_model_content_hash_value(model_zip_file_path)
)

if verbose:
print("generating ml-commons_model_config.json file...\n")
Expand Down
8 changes: 5 additions & 3 deletions opensearch_py_ml/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1159,9 +1159,11 @@ def _map_pd_aggs_to_os_aggs(
# piggy-back on that single aggregation.
if extended_stats_calls >= 2:
os_aggs = [
("extended_stats", os_agg)
if os_agg in extended_stats_os_aggs
else os_agg
(
("extended_stats", os_agg)
if os_agg in extended_stats_os_aggs
else os_agg
)
for os_agg in os_aggs
]

Expand Down
8 changes: 5 additions & 3 deletions tests/dataframe/test_iterrows_itertuples_pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,11 @@ def assert_tuples_almost_equal(left, right):
# Shim which uses pytest.approx() for floating point values inside tuples.
assert len(left) == len(right)
assert all(
(lt == rt) # Not floats? Use ==
if not isinstance(lt, float) and not isinstance(rt, float)
else (lt == pytest.approx(rt)) # If both are floats use pytest.approx()
(
(lt == rt) # Not floats? Use ==
if not isinstance(lt, float) and not isinstance(rt, float)
else (lt == pytest.approx(rt))
) # If both are floats use pytest.approx()
for lt, rt in zip(left, right)
)

Expand Down
6 changes: 3 additions & 3 deletions utils/model_uploader/update_models_upload_history_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ def create_model_json_obj(
"Model ID": model_id,
"Model Version": model_version,
"Model Format": model_format,
"Embedding Dimension": str(embedding_dimension)
if embedding_dimension is not None
else "N/A",
"Embedding Dimension": (
str(embedding_dimension) if embedding_dimension is not None else "N/A"
),
"Pooling Mode": pooling_mode if pooling_mode is not None else "N/A",
"Workflow Run ID": workflow_id if workflow_id is not None else "-",
}
Expand Down

0 comments on commit d1912bf

Please sign in to comment.