Skip to content

Commit

Permalink
add support to pandas nullable dtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
Cheukting committed Apr 25, 2023
1 parent 154577c commit 54333d1
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 13 deletions.
5 changes: 5 additions & 0 deletions hypothesis-python/RELEASE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
RELEASE_TYPE: minor

This release adds support for `nullable pandas dtypes <https://pandas.pydata.org/docs/user_guide/integer_na.html>`__
in :func:`~hypothesis.extra.pandas` (:issue:`3604`).
Thanks to Cheuk Ting Ho for implementing this at the PyCon sprints!
55 changes: 47 additions & 8 deletions hypothesis-python/src/hypothesis/extra/pandas/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ def is_categorical_dtype(dt):
return dt == "category"


try:
from pandas.core.arrays.integer import IntegerDtype
except ImportError:
IntegerDtype = ()


def dtype_for_elements_strategy(s):
return st.shared(
s.map(lambda x: pandas.Series([x]).dtype),
Expand Down Expand Up @@ -79,6 +85,12 @@ def elements_and_dtype(elements, dtype, source=None):
f"{prefix}dtype is categorical, which is currently unsupported"
)

if isinstance(dtype, type) and issubclass(dtype, IntegerDtype):
raise InvalidArgument(
f"Passed dtype={dtype!r} is a dtype class, please pass in an instance of this class."
"Otherwise it would be treated as dtype=object"
)

if isinstance(dtype, type) and np.dtype(dtype).kind == "O" and dtype is not object:
note_deprecation(
f"Passed dtype={dtype!r} is not a valid Pandas dtype. We'll treat it as "
Expand All @@ -92,13 +104,31 @@ def elements_and_dtype(elements, dtype, source=None):
f"Passed dtype={dtype!r} is a strategy, but we require a concrete dtype "
"here. See https://stackoverflow.com/q/74355937 for workaround patterns."
)
dtype = try_convert(np.dtype, dtype, "dtype")

pd_dtype_map = {
t.name: t for t in getattr(IntegerDtype, "__subclasses__", lambda: [])()
}

dtype = pd_dtype_map.get(dtype, dtype)

if isinstance(dtype, IntegerDtype):
is_na_dtype = True
dtype = np.dtype(dtype.name.lower())
elif dtype is not None:
is_na_dtype = False
dtype = try_convert(np.dtype, dtype, "dtype")
else:
is_na_dtype = False

if elements is None:
elements = npst.from_dtype(dtype)
if is_na_dtype:
elements = st.none() | elements
elif dtype is not None:

def convert_element(value):
if value is None:
return None
name = f"draw({prefix}elements)"
try:
return np.array([value], dtype=dtype)[0]
Expand Down Expand Up @@ -282,9 +312,17 @@ def series(
else:
check_strategy(index, "index")

elements, dtype = elements_and_dtype(elements, dtype)
elements, np_dtype = elements_and_dtype(elements, dtype)
index_strategy = index

# if it is converted to an object, use object for series type
if (
np_dtype is not None
and np_dtype.kind == "O"
and not isinstance(dtype, IntegerDtype)
):
dtype = np_dtype

@st.composite
def result(draw):
index = draw(index_strategy)
Expand All @@ -293,13 +331,13 @@ def result(draw):
if dtype is not None:
result_data = draw(
npst.arrays(
dtype=dtype,
dtype=object,
elements=elements,
shape=len(index),
fill=fill,
unique=unique,
)
)
).tolist()
else:
result_data = list(
draw(
Expand All @@ -310,9 +348,8 @@ def result(draw):
fill=fill,
unique=unique,
)
)
).tolist()
)

return pandas.Series(result_data, index=index, dtype=dtype, name=draw(name))
else:
return pandas.Series(
Expand Down Expand Up @@ -549,7 +586,7 @@ def row():

column_names.add(c.name)

c.elements, c.dtype = elements_and_dtype(c.elements, c.dtype, label)
c.elements, _ = elements_and_dtype(c.elements, c.dtype, label)

if c.dtype is None and rows is not None:
raise InvalidArgument(
Expand Down Expand Up @@ -589,7 +626,9 @@ def just_draw_columns(draw):
if columns_without_fill:
for c in columns_without_fill:
data[c.name] = pandas.Series(
np.zeros(shape=len(index), dtype=c.dtype), index=index
np.zeros(shape=len(index), dtype=object),
index=index,
dtype=c.dtype,
)
seen = {c.name: set() for c in columns_without_fill if c.unique}

Expand Down
14 changes: 11 additions & 3 deletions hypothesis-python/tests/pandas/test_argument_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@
from datetime import datetime

import pandas as pd
import pytest

from hypothesis import given, strategies as st
from hypothesis.errors import InvalidArgument
from hypothesis.extra import pandas as pdst

from tests.common.arguments import argument_validation_test, e
from tests.common.debug import find_any
from tests.common.utils import checks_deprecated_behaviour

BAD_ARGS = [
Expand All @@ -30,7 +33,6 @@
e(pdst.data_frames, pdst.columns(1, dtype=float, elements=1)),
e(pdst.data_frames, pdst.columns(1, fill=1, dtype=float)),
e(pdst.data_frames, pdst.columns(["A", "A"], dtype=float)),
e(pdst.data_frames, pdst.columns(1, elements=st.none(), dtype=int)),
e(pdst.data_frames, 1),
e(pdst.data_frames, [1]),
e(pdst.data_frames, pdst.columns(1, dtype="category")),
Expand Down Expand Up @@ -64,7 +66,6 @@
e(pdst.indexes, dtype="not a dtype"),
e(pdst.indexes, elements="not a strategy"),
e(pdst.indexes, elements=st.text(), dtype=float),
e(pdst.indexes, elements=st.none(), dtype=int),
e(pdst.indexes, elements=st.integers(0, 10), dtype=st.sampled_from([int, float])),
e(pdst.indexes, dtype=int, max_size=0, min_size=1),
e(pdst.indexes, dtype=int, unique="true"),
Expand All @@ -77,7 +78,6 @@
e(pdst.series),
e(pdst.series, dtype="not a dtype"),
e(pdst.series, elements="not a strategy"),
e(pdst.series, elements=st.none(), dtype=int),
e(pdst.series, dtype="category"),
e(pdst.series, index="not a strategy"),
]
Expand All @@ -99,3 +99,11 @@ def test_timestamp_as_datetime_bounds(dt):
@checks_deprecated_behaviour
def test_confusing_object_dtype_aliases():
pdst.series(elements=st.tuples(st.integers()), dtype=tuple).example()


def test_pandas_nullable_types_class():
with pytest.raises(
InvalidArgument, match="Otherwise it would be treated as dtype=object"
):
st = pdst.series(dtype=pd.core.arrays.integer.Int8Dtype)
find_any(st, lambda s: s.isna().any())
8 changes: 8 additions & 0 deletions hypothesis-python/tests/pandas/test_data_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# obtain one at https://mozilla.org/MPL/2.0/.

import numpy as np
import pandas as pd
import pytest

from hypothesis import HealthCheck, given, reject, settings, strategies as st
Expand Down Expand Up @@ -267,3 +268,10 @@ def works_with_object_dtype(df):
assert dtype is None
with pytest.raises(ValueError, match="Maybe passing dtype=object would help"):
works_with_object_dtype()


def test_pandas_nullable_types():
st = pdst.data_frames(pdst.columns(2, dtype=pd.core.arrays.integer.Int8Dtype()))
df = find_any(st, lambda s: s.isna().any().any())
for s in df.columns:
assert type(df[s].dtype) == pd.core.arrays.integer.Int8Dtype
16 changes: 14 additions & 2 deletions hypothesis-python/tests/pandas/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# obtain one at https://mozilla.org/MPL/2.0/.

import numpy as np
import pandas
import pandas as pd

from hypothesis import assume, given, strategies as st
from hypothesis.extra import numpy as npst, pandas as pdst
Expand All @@ -25,7 +25,7 @@ def test_can_create_a_series_of_any_dtype(data):
# Use raw data to work around pandas bug in repr. See
# https://github.com/pandas-dev/pandas/issues/27484
series = data.conjecture_data.draw(pdst.series(dtype=dtype))
assert series.dtype == pandas.Series([], dtype=dtype).dtype
assert series.dtype == pd.Series([], dtype=dtype).dtype


@given(pdst.series(dtype=float, index=pdst.range_indexes(min_size=2, max_size=5)))
Expand Down Expand Up @@ -61,3 +61,15 @@ def test_unique_series_are_unique(s):
@given(pdst.series(dtype="int8", name=st.just("test_name")))
def test_name_passed_on(s):
assert s.name == "test_name"


def test_pandas_nullable_types():
st = pdst.series(dtype=pd.core.arrays.integer.Int8Dtype())
e = find_any(st, lambda s: s.isna().any())
assert type(e.dtype) == pd.core.arrays.integer.Int8Dtype


def test_pandas_nullable_types_in_str():
st = pdst.series(dtype="Int8")
e = find_any(st, lambda s: s.isna().any())
assert type(e.dtype) == pd.core.arrays.integer.Int8Dtype

0 comments on commit 54333d1

Please sign in to comment.