diff --git a/hypothesis-python/RELEASE.rst b/hypothesis-python/RELEASE.rst new file mode 100644 index 0000000000..014e1caa1d --- /dev/null +++ b/hypothesis-python/RELEASE.rst @@ -0,0 +1,5 @@ +RELEASE_TYPE: minor + +This release adds support for `nullable pandas dtypes `__ +in :func:`~hypothesis.extra.pandas` (:issue:`3604`). +Thanks to Cheuk Ting Ho for implementing this at the PyCon sprints! diff --git a/hypothesis-python/src/hypothesis/extra/pandas/impl.py b/hypothesis-python/src/hypothesis/extra/pandas/impl.py index 85fab8106e..3b4413820e 100644 --- a/hypothesis-python/src/hypothesis/extra/pandas/impl.py +++ b/hypothesis-python/src/hypothesis/extra/pandas/impl.py @@ -44,6 +44,12 @@ def is_categorical_dtype(dt): return dt == "category" +try: + from pandas.core.arrays.integer import IntegerDtype +except ImportError: + IntegerDtype = () + + def dtype_for_elements_strategy(s): return st.shared( s.map(lambda x: pandas.Series([x]).dtype), @@ -79,6 +85,12 @@ def elements_and_dtype(elements, dtype, source=None): f"{prefix}dtype is categorical, which is currently unsupported" ) + if isinstance(dtype, type) and issubclass(dtype, IntegerDtype): + raise InvalidArgument( + f"Passed dtype={dtype!r} is a dtype class, please pass in an instance of this class." + "Otherwise it would be treated as dtype=object" + ) + if isinstance(dtype, type) and np.dtype(dtype).kind == "O" and dtype is not object: note_deprecation( f"Passed dtype={dtype!r} is not a valid Pandas dtype. We'll treat it as " @@ -92,13 +104,31 @@ def elements_and_dtype(elements, dtype, source=None): f"Passed dtype={dtype!r} is a strategy, but we require a concrete dtype " "here. See https://stackoverflow.com/q/74355937 for workaround patterns." ) - dtype = try_convert(np.dtype, dtype, "dtype") + + pd_dtype_map = { + t.name: t for t in getattr(IntegerDtype, "__subclasses__", lambda: [])() + } + + dtype = pd_dtype_map.get(dtype, dtype) + + if isinstance(dtype, IntegerDtype): + is_na_dtype = True + dtype = np.dtype(dtype.name.lower()) + elif dtype is not None: + is_na_dtype = False + dtype = try_convert(np.dtype, dtype, "dtype") + else: + is_na_dtype = False if elements is None: elements = npst.from_dtype(dtype) + if is_na_dtype: + elements = st.none() | elements elif dtype is not None: def convert_element(value): + if value is None: + return None name = f"draw({prefix}elements)" try: return np.array([value], dtype=dtype)[0] @@ -282,9 +312,17 @@ def series( else: check_strategy(index, "index") - elements, dtype = elements_and_dtype(elements, dtype) + elements, np_dtype = elements_and_dtype(elements, dtype) index_strategy = index + # if it is converted to an object, use object for series type + if ( + np_dtype is not None + and np_dtype.kind == "O" + and not isinstance(dtype, IntegerDtype) + ): + dtype = np_dtype + @st.composite def result(draw): index = draw(index_strategy) @@ -293,13 +331,13 @@ def result(draw): if dtype is not None: result_data = draw( npst.arrays( - dtype=dtype, + dtype=object, elements=elements, shape=len(index), fill=fill, unique=unique, ) - ) + ).tolist() else: result_data = list( draw( @@ -310,9 +348,8 @@ def result(draw): fill=fill, unique=unique, ) - ) + ).tolist() ) - return pandas.Series(result_data, index=index, dtype=dtype, name=draw(name)) else: return pandas.Series( @@ -549,7 +586,7 @@ def row(): column_names.add(c.name) - c.elements, c.dtype = elements_and_dtype(c.elements, c.dtype, label) + c.elements, _ = elements_and_dtype(c.elements, c.dtype, label) if c.dtype is None and rows is not None: raise InvalidArgument( @@ -589,7 +626,9 @@ def just_draw_columns(draw): if columns_without_fill: for c in columns_without_fill: data[c.name] = pandas.Series( - np.zeros(shape=len(index), dtype=c.dtype), index=index + np.zeros(shape=len(index), dtype=object), + index=index, + dtype=c.dtype, ) seen = {c.name: set() for c in columns_without_fill if c.unique} diff --git a/hypothesis-python/tests/pandas/test_argument_validation.py b/hypothesis-python/tests/pandas/test_argument_validation.py index 3987377099..fc7227efb7 100644 --- a/hypothesis-python/tests/pandas/test_argument_validation.py +++ b/hypothesis-python/tests/pandas/test_argument_validation.py @@ -11,11 +11,14 @@ from datetime import datetime import pandas as pd +import pytest from hypothesis import given, strategies as st +from hypothesis.errors import InvalidArgument from hypothesis.extra import pandas as pdst from tests.common.arguments import argument_validation_test, e +from tests.common.debug import find_any from tests.common.utils import checks_deprecated_behaviour BAD_ARGS = [ @@ -30,7 +33,6 @@ e(pdst.data_frames, pdst.columns(1, dtype=float, elements=1)), e(pdst.data_frames, pdst.columns(1, fill=1, dtype=float)), e(pdst.data_frames, pdst.columns(["A", "A"], dtype=float)), - e(pdst.data_frames, pdst.columns(1, elements=st.none(), dtype=int)), e(pdst.data_frames, 1), e(pdst.data_frames, [1]), e(pdst.data_frames, pdst.columns(1, dtype="category")), @@ -64,7 +66,6 @@ e(pdst.indexes, dtype="not a dtype"), e(pdst.indexes, elements="not a strategy"), e(pdst.indexes, elements=st.text(), dtype=float), - e(pdst.indexes, elements=st.none(), dtype=int), e(pdst.indexes, elements=st.integers(0, 10), dtype=st.sampled_from([int, float])), e(pdst.indexes, dtype=int, max_size=0, min_size=1), e(pdst.indexes, dtype=int, unique="true"), @@ -77,7 +78,6 @@ e(pdst.series), e(pdst.series, dtype="not a dtype"), e(pdst.series, elements="not a strategy"), - e(pdst.series, elements=st.none(), dtype=int), e(pdst.series, dtype="category"), e(pdst.series, index="not a strategy"), ] @@ -99,3 +99,11 @@ def test_timestamp_as_datetime_bounds(dt): @checks_deprecated_behaviour def test_confusing_object_dtype_aliases(): pdst.series(elements=st.tuples(st.integers()), dtype=tuple).example() + + +def test_pandas_nullable_types_class(): + with pytest.raises( + InvalidArgument, match="Otherwise it would be treated as dtype=object" + ): + st = pdst.series(dtype=pd.core.arrays.integer.Int8Dtype) + find_any(st, lambda s: s.isna().any()) diff --git a/hypothesis-python/tests/pandas/test_data_frame.py b/hypothesis-python/tests/pandas/test_data_frame.py index 7949cc6fd1..221b4f50ae 100644 --- a/hypothesis-python/tests/pandas/test_data_frame.py +++ b/hypothesis-python/tests/pandas/test_data_frame.py @@ -9,6 +9,7 @@ # obtain one at https://mozilla.org/MPL/2.0/. import numpy as np +import pandas as pd import pytest from hypothesis import HealthCheck, given, reject, settings, strategies as st @@ -267,3 +268,10 @@ def works_with_object_dtype(df): assert dtype is None with pytest.raises(ValueError, match="Maybe passing dtype=object would help"): works_with_object_dtype() + + +def test_pandas_nullable_types(): + st = pdst.data_frames(pdst.columns(2, dtype=pd.core.arrays.integer.Int8Dtype())) + df = find_any(st, lambda s: s.isna().any().any()) + for s in df.columns: + assert type(df[s].dtype) == pd.core.arrays.integer.Int8Dtype diff --git a/hypothesis-python/tests/pandas/test_series.py b/hypothesis-python/tests/pandas/test_series.py index c796b8175d..a15b3b6a58 100644 --- a/hypothesis-python/tests/pandas/test_series.py +++ b/hypothesis-python/tests/pandas/test_series.py @@ -9,7 +9,7 @@ # obtain one at https://mozilla.org/MPL/2.0/. import numpy as np -import pandas +import pandas as pd from hypothesis import assume, given, strategies as st from hypothesis.extra import numpy as npst, pandas as pdst @@ -25,7 +25,7 @@ def test_can_create_a_series_of_any_dtype(data): # Use raw data to work around pandas bug in repr. See # https://github.com/pandas-dev/pandas/issues/27484 series = data.conjecture_data.draw(pdst.series(dtype=dtype)) - assert series.dtype == pandas.Series([], dtype=dtype).dtype + assert series.dtype == pd.Series([], dtype=dtype).dtype @given(pdst.series(dtype=float, index=pdst.range_indexes(min_size=2, max_size=5))) @@ -61,3 +61,15 @@ def test_unique_series_are_unique(s): @given(pdst.series(dtype="int8", name=st.just("test_name"))) def test_name_passed_on(s): assert s.name == "test_name" + + +def test_pandas_nullable_types(): + st = pdst.series(dtype=pd.core.arrays.integer.Int8Dtype()) + e = find_any(st, lambda s: s.isna().any()) + assert type(e.dtype) == pd.core.arrays.integer.Int8Dtype + + +def test_pandas_nullable_types_in_str(): + st = pdst.series(dtype="Int8") + e = find_any(st, lambda s: s.isna().any()) + assert type(e.dtype) == pd.core.arrays.integer.Int8Dtype