Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support nullable pandas dtypes #3623

Merged
merged 7 commits into from
Apr 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions hypothesis-python/RELEASE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
RELEASE_TYPE: minor

This release adds support for `nullable pandas dtypes <https://pandas.pydata.org/docs/user_guide/integer_na.html>`__
in :func:`~hypothesis.extra.pandas` (:issue:`3604`).
Thanks to Cheuk Ting Ho for implementing this at the PyCon sprints!
58 changes: 45 additions & 13 deletions hypothesis-python/src/hypothesis/extra/pandas/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ def is_categorical_dtype(dt):
return dt == "category"


try:
from pandas.core.arrays.integer import IntegerDtype
except ImportError:
IntegerDtype = ()


def dtype_for_elements_strategy(s):
return st.shared(
s.map(lambda x: pandas.Series([x]).dtype),
Expand Down Expand Up @@ -79,6 +85,12 @@ def elements_and_dtype(elements, dtype, source=None):
f"{prefix}dtype is categorical, which is currently unsupported"
)

if isinstance(dtype, type) and issubclass(dtype, IntegerDtype):
raise InvalidArgument(
f"Passed dtype={dtype!r} is a dtype class, please pass in an instance of this class."
"Otherwise it would be treated as dtype=object"
)

if isinstance(dtype, type) and np.dtype(dtype).kind == "O" and dtype is not object:
note_deprecation(
f"Passed dtype={dtype!r} is not a valid Pandas dtype. We'll treat it as "
Expand All @@ -92,25 +104,36 @@ def elements_and_dtype(elements, dtype, source=None):
f"Passed dtype={dtype!r} is a strategy, but we require a concrete dtype "
"here. See https://stackoverflow.com/q/74355937 for workaround patterns."
)
dtype = try_convert(np.dtype, dtype, "dtype")

_get_subclasses = getattr(IntegerDtype, "__subclasses__", list)
dtype = {t.name: t() for t in _get_subclasses()}.get(dtype, dtype)

if isinstance(dtype, IntegerDtype):
is_na_dtype = True
dtype = np.dtype(dtype.name.lower())
elif dtype is not None:
is_na_dtype = False
dtype = try_convert(np.dtype, dtype, "dtype")
else:
is_na_dtype = False

if elements is None:
elements = npst.from_dtype(dtype)
if is_na_dtype:
elements = st.none() | elements
elif dtype is not None:

def convert_element(value):
if is_na_dtype and value is None:
return None
name = f"draw({prefix}elements)"
try:
return np.array([value], dtype=dtype)[0]
except TypeError:
except (TypeError, ValueError):
raise InvalidArgument(
"Cannot convert %s=%r of type %s to dtype %s"
% (name, value, type(value).__name__, dtype.str)
) from None
except ValueError:
raise InvalidArgument(
f"Cannot convert {name}={value!r} to type {dtype.str}"
) from None

elements = elements.map(convert_element)
assert elements is not None
Expand Down Expand Up @@ -282,9 +305,17 @@ def series(
else:
check_strategy(index, "index")

elements, dtype = elements_and_dtype(elements, dtype)
elements, np_dtype = elements_and_dtype(elements, dtype)
index_strategy = index

# if it is converted to an object, use object for series type
if (
np_dtype is not None
and np_dtype.kind == "O"
and not isinstance(dtype, IntegerDtype)
):
dtype = np_dtype

@st.composite
def result(draw):
index = draw(index_strategy)
Expand All @@ -293,13 +324,13 @@ def result(draw):
if dtype is not None:
result_data = draw(
npst.arrays(
dtype=dtype,
dtype=object,
elements=elements,
shape=len(index),
fill=fill,
unique=unique,
)
)
).tolist()
else:
result_data = list(
draw(
Expand All @@ -310,9 +341,8 @@ def result(draw):
fill=fill,
unique=unique,
)
)
).tolist()
)

return pandas.Series(result_data, index=index, dtype=dtype, name=draw(name))
else:
return pandas.Series(
Expand Down Expand Up @@ -549,7 +579,7 @@ def row():

column_names.add(c.name)

c.elements, c.dtype = elements_and_dtype(c.elements, c.dtype, label)
c.elements, _ = elements_and_dtype(c.elements, c.dtype, label)

if c.dtype is None and rows is not None:
raise InvalidArgument(
Expand Down Expand Up @@ -589,7 +619,9 @@ def just_draw_columns(draw):
if columns_without_fill:
for c in columns_without_fill:
data[c.name] = pandas.Series(
np.zeros(shape=len(index), dtype=c.dtype), index=index
np.zeros(shape=len(index), dtype=object),
index=index,
dtype=c.dtype,
)
seen = {c.name: set() for c in columns_without_fill if c.unique}

Expand Down
2 changes: 1 addition & 1 deletion hypothesis-python/tests/common/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def e(a, *args, **kwargs):


def e_to_str(elt):
f, args, kwargs = elt
f, args, kwargs = getattr(elt, "values", elt)
bits = list(map(repr, args))
bits.extend(sorted(f"{k}={v!r}" for k, v in kwargs.items()))
return "{}({})".format(f.__name__, ", ".join(bits))
Expand Down
33 changes: 30 additions & 3 deletions hypothesis-python/tests/pandas/test_argument_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@
from datetime import datetime

import pandas as pd
import pytest

from hypothesis import given, strategies as st
from hypothesis.errors import InvalidArgument
from hypothesis.extra import pandas as pdst
from hypothesis.extra.pandas.impl import IntegerDtype

from tests.common.arguments import argument_validation_test, e
from tests.common.debug import find_any
from tests.common.utils import checks_deprecated_behaviour

BAD_ARGS = [
Expand All @@ -30,7 +34,11 @@
e(pdst.data_frames, pdst.columns(1, dtype=float, elements=1)),
e(pdst.data_frames, pdst.columns(1, fill=1, dtype=float)),
e(pdst.data_frames, pdst.columns(["A", "A"], dtype=float)),
e(pdst.data_frames, pdst.columns(1, elements=st.none(), dtype=int)),
pytest.param(
*e(pdst.data_frames, pdst.columns(1, elements=st.none(), dtype=int)),
marks=pytest.mark.skipif(IntegerDtype, reason="works with integer NA"),
),
e(pdst.data_frames, pdst.columns(1, elements=st.text(), dtype=int)),
e(pdst.data_frames, 1),
e(pdst.data_frames, [1]),
e(pdst.data_frames, pdst.columns(1, dtype="category")),
Expand Down Expand Up @@ -64,7 +72,11 @@
e(pdst.indexes, dtype="not a dtype"),
e(pdst.indexes, elements="not a strategy"),
e(pdst.indexes, elements=st.text(), dtype=float),
e(pdst.indexes, elements=st.none(), dtype=int),
pytest.param(
*e(pdst.indexes, elements=st.none(), dtype=int),
marks=pytest.mark.skipif(IntegerDtype, reason="works with integer NA"),
),
e(pdst.indexes, elements=st.text(), dtype=int),
e(pdst.indexes, elements=st.integers(0, 10), dtype=st.sampled_from([int, float])),
e(pdst.indexes, dtype=int, max_size=0, min_size=1),
e(pdst.indexes, dtype=int, unique="true"),
Expand All @@ -77,7 +89,11 @@
e(pdst.series),
e(pdst.series, dtype="not a dtype"),
e(pdst.series, elements="not a strategy"),
e(pdst.series, elements=st.none(), dtype=int),
pytest.param(
*e(pdst.series, elements=st.none(), dtype=int),
marks=pytest.mark.skipif(IntegerDtype, reason="works with integer NA"),
),
e(pdst.series, elements=st.text(), dtype=int),
e(pdst.series, dtype="category"),
e(pdst.series, index="not a strategy"),
]
Expand All @@ -99,3 +115,14 @@ def test_timestamp_as_datetime_bounds(dt):
@checks_deprecated_behaviour
def test_confusing_object_dtype_aliases():
pdst.series(elements=st.tuples(st.integers()), dtype=tuple).example()


@pytest.mark.skipif(
not IntegerDtype, reason="Nullable types not available in this version of Pandas"
)
def test_pandas_nullable_types_class():
with pytest.raises(
InvalidArgument, match="Otherwise it would be treated as dtype=object"
):
st = pdst.series(dtype=pd.core.arrays.integer.Int8Dtype)
find_any(st, lambda s: s.isna().any())
12 changes: 12 additions & 0 deletions hypothesis-python/tests/pandas/test_data_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
# obtain one at https://mozilla.org/MPL/2.0/.

import numpy as np
import pandas as pd
import pytest

from hypothesis import HealthCheck, given, reject, settings, strategies as st
from hypothesis.extra import numpy as npst, pandas as pdst
from hypothesis.extra.pandas.impl import IntegerDtype

from tests.common.debug import find_any
from tests.pandas.helpers import supported_by_pandas
Expand Down Expand Up @@ -267,3 +269,13 @@ def works_with_object_dtype(df):
assert dtype is None
with pytest.raises(ValueError, match="Maybe passing dtype=object would help"):
works_with_object_dtype()


@pytest.mark.skipif(
not IntegerDtype, reason="Nullable types not available in this version of Pandas"
)
def test_pandas_nullable_types():
st = pdst.data_frames(pdst.columns(2, dtype=pd.core.arrays.integer.Int8Dtype()))
df = find_any(st, lambda s: s.isna().any().any())
for s in df.columns:
assert type(df[s].dtype) == pd.core.arrays.integer.Int8Dtype
28 changes: 25 additions & 3 deletions hypothesis-python/tests/pandas/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
# obtain one at https://mozilla.org/MPL/2.0/.

import numpy as np
import pandas
import pandas as pd
import pytest

from hypothesis import assume, given, strategies as st
from hypothesis.extra import numpy as npst, pandas as pdst
from hypothesis.extra.pandas.impl import IntegerDtype

from tests.common.debug import find_any
from tests.common.debug import assert_all_examples, assert_no_examples, find_any
from tests.pandas.helpers import supported_by_pandas


Expand All @@ -25,7 +27,7 @@ def test_can_create_a_series_of_any_dtype(data):
# Use raw data to work around pandas bug in repr. See
# https://github.com/pandas-dev/pandas/issues/27484
series = data.conjecture_data.draw(pdst.series(dtype=dtype))
assert series.dtype == pandas.Series([], dtype=dtype).dtype
assert series.dtype == pd.Series([], dtype=dtype).dtype


@given(pdst.series(dtype=float, index=pdst.range_indexes(min_size=2, max_size=5)))
Expand Down Expand Up @@ -61,3 +63,23 @@ def test_unique_series_are_unique(s):
@given(pdst.series(dtype="int8", name=st.just("test_name")))
def test_name_passed_on(s):
assert s.name == "test_name"


@pytest.mark.skipif(
not IntegerDtype, reason="Nullable types not available in this version of Pandas"
)
@pytest.mark.parametrize(
"dtype", ["Int8", pd.core.arrays.integer.Int8Dtype() if IntegerDtype else None]
)
def test_pandas_nullable_types(dtype):
assert_no_examples(
pdst.series(dtype=dtype, elements=st.just(0)),
lambda s: s.isna().any(),
)
assert_all_examples(
pdst.series(dtype=dtype, elements=st.none()),
lambda s: s.isna().all(),
)
find_any(pdst.series(dtype=dtype), lambda s: not s.isna().any())
e = find_any(pdst.series(dtype=dtype), lambda s: s.isna().any())
assert type(e.dtype) == pd.core.arrays.integer.Int8Dtype