diff --git a/src/aac_datasets/datasets/base.py b/src/aac_datasets/datasets/base.py index 907b787..f601f49 100644 --- a/src/aac_datasets/datasets/base.py +++ b/src/aac_datasets/datasets/base.py @@ -55,7 +55,7 @@ def _is_index(index: Any) -> TypeGuard[IndexType]: def _is_column(column: Any) -> TypeGuard[ColumnType]: - return isinstance(column, str) or is_iterable_str(column) or column is None + return is_iterable_str(column, accept_str=True) or column is None class AACDataset(Generic[ItemType], Dataset[ItemType]): diff --git a/src/aac_datasets/utils/type_checks.py b/src/aac_datasets/utils/type_checks.py index 3bf8727..83ac86a 100644 --- a/src/aac_datasets/utils/type_checks.py +++ b/src/aac_datasets/utils/type_checks.py @@ -10,9 +10,11 @@ def is_iterable_int(x: Any) -> TypeGuard[Iterable[int]]: return isinstance(x, Iterable) and all(isinstance(xi, int) for xi in x) -def is_iterable_str(x: Any, accept_str: bool = False) -> TypeGuard[Iterable[str]]: - return (accept_str or not isinstance(x, str)) and ( - isinstance(x, Iterable) and all(isinstance(xi, str) for xi in x) +def is_iterable_str(x: Any, *, accept_str: bool) -> TypeGuard[Iterable[str]]: + return (accept_str and isinstance(x, str)) or ( + not isinstance(x, str) + and isinstance(x, Iterable) + and all(isinstance(xi, str) for xi in x) )