diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 3cfd82c3d3117..79c3f8f26b1b7 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -514,6 +514,7 @@ def __hash__(self):
     python_test_goals=[
         # doctests
         "pyspark.testing.utils",
+        "pyspark.testing.pandasutils",
     ],
 )
diff --git a/python/docs/source/reference/pyspark.testing.rst b/python/docs/source/reference/pyspark.testing.rst
index 7a6b6cc0d70ab..96b0c72a7bb4b 100644
--- a/python/docs/source/reference/pyspark.testing.rst
+++ b/python/docs/source/reference/pyspark.testing.rst
@@ -26,4 +26,5 @@ Testing
     :toctree: api/
 
     assertDataFrameEqual
+    assertPandasOnSparkEqual
     assertSchemaEqual
diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index f4b643f1d32b3..5ecba294d0c6a 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -164,6 +164,42 @@
       "Remote client cannot create a SparkContext. Create SparkSession instead."
     ]
   },
+  "DIFFERENT_PANDAS_DATAFRAME" : {
+    "message" : [
+      "DataFrames are not almost equal:",
+      "Left: <left>",
+      "<left_dtype>",
+      "Right: <right>",
+      "<right_dtype>"
+    ]
+  },
+  "DIFFERENT_PANDAS_INDEX" : {
+    "message" : [
+      "Indices are not almost equal:",
+      "Left: <left>",
+      "<left_dtype>",
+      "Right: <right>",
+      "<right_dtype>"
+    ]
+  },
+  "DIFFERENT_PANDAS_MULTIINDEX" : {
+    "message" : [
+      "MultiIndices are not almost equal:",
+      "Left: <left>",
+      "<left_dtype>",
+      "Right: <right>",
+      "<right_dtype>"
+    ]
+  },
+  "DIFFERENT_PANDAS_SERIES" : {
+    "message" : [
+      "Series are not almost equal:",
+      "Left: <left>",
+      "<left_dtype>",
+      "Right: <right>",
+      "<right_dtype>"
+    ]
+  },
   "DIFFERENT_ROWS" : {
     "message" : [
       "<error_msg>"
@@ -233,6 +269,12 @@
       "NumPy array input should be of <dimensions> dimensions."
     ]
   },
+  "INVALID_PANDAS_ON_SPARK_COMPARISON" : {
+    "message" : [
+      "Expected two pandas-on-Spark DataFrames",
+      "but got actual: <actual_type> and expected: <expected_type>"
+    ]
+  },
   "INVALID_PANDAS_UDF" : {
     "message" : [
       "Invalid function: <detail>"
diff --git a/python/pyspark/pandas/tests/test_utils.py b/python/pyspark/pandas/tests/test_utils.py
index 35ebcf17a0f72..de7b0449daefc 100644
--- a/python/pyspark/pandas/tests/test_utils.py
+++ b/python/pyspark/pandas/tests/test_utils.py
@@ -16,6 +16,7 @@
 #
 
 import pandas as pd
+from typing import Union
 
 from pyspark.pandas.indexes.base import Index
 from pyspark.pandas.utils import (
@@ -25,8 +26,14 @@
     validate_index_loc,
     validate_mode,
 )
-from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.pandasutils import (
+    PandasOnSparkTestCase,
+    assertPandasOnSparkEqual,
+    _assert_pandas_equal,
+    _assert_pandas_almost_equal,
+)
 from pyspark.testing.sqlutils import SQLTestUtils
+from pyspark.errors import PySparkAssertionError
 
 
 some_global_variable = 0
@@ -105,6 +112,168 @@ def test_validate_index_loc(self):
         with self.assertRaisesRegex(IndexError, err_msg):
             validate_index_loc(psidx, -4)
 
+    def test_assert_df_assertPandasOnSparkEqual(self):
+        import pyspark.pandas as ps
+
+        psdf1 = ps.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]})
+        psdf2 = ps.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]})
+
+        assertPandasOnSparkEqual(psdf1, psdf2, checkRowOrder=False)
+        assertPandasOnSparkEqual(psdf1, psdf2, checkRowOrder=True)
+
+    def test_assertPandasOnSparkEqual_ignoreOrder_default(self):
+        import pyspark.pandas as ps
+
+        psdf1 = ps.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]})
+        psdf2 = ps.DataFrame({"a": [2, 1, 3], "b": [5, 4, 6], "c": [8, 7, 9]})
+
+        assertPandasOnSparkEqual(psdf1, psdf2)
+
+    def test_assert_series_assertPandasOnSparkEqual(self):
+        import pyspark.pandas as ps
+
+        s1 = ps.Series([212.32, 
100.0001]) + s2 = ps.Series([212.32, 100.0001]) + + assertPandasOnSparkEqual(s1, s2, checkExact=False) + + def test_assert_index_assertPandasOnSparkEqual(self): + import pyspark.pandas as ps + + s1 = ps.Index([212.300001, 100.000]) + s2 = ps.Index([212.3, 100.0001]) + + assertPandasOnSparkEqual(s1, s2, almost=True) + + def test_assert_error_assertPandasOnSparkEqual(self): + import pyspark.pandas as ps + + list1 = [10, 20, 30] + list2 = [10, 20, 30] + + with self.assertRaises(PySparkAssertionError) as pe: + assertPandasOnSparkEqual(list1, list2) + + self.check_error( + exception=pe.exception, + error_class="INVALID_TYPE_DF_EQUALITY_ARG", + message_parameters={ + "expected_type": f"{ps.DataFrame.__name__}, " + f"{ps.Series.__name__}, " + f"{ps.Index.__name__}", + "arg_name": "actual", + "actual_type": type(list1), + }, + ) + + def test_assert_None_assertPandasOnSparkEqual(self): + psdf1 = None + psdf2 = None + + assertPandasOnSparkEqual(psdf1, psdf2) + + def test_assert_empty_assertPandasOnSparkEqual(self): + import pyspark.pandas as ps + + psdf1 = ps.DataFrame() + psdf2 = ps.DataFrame() + + assertPandasOnSparkEqual(psdf1, psdf2) + + def test_dataframe_error_assert_pandas_equal(self): + pdf1 = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, index=[0, 1, 3]) + pdf2 = pd.DataFrame({"a": [1, 3, 3], "b": [4, 5, 6]}, index=[0, 1, 3]) + + with self.assertRaises(PySparkAssertionError) as pe: + _assert_pandas_equal(pdf1, pdf2, True) + + self.check_error( + exception=pe.exception, + error_class="DIFFERENT_PANDAS_DATAFRAME", + message_parameters={ + "left": pdf1.to_string(), + "left_dtype": str(pdf1.dtypes), + "right": pdf2.to_string(), + "right_dtype": str(pdf2.dtypes), + }, + ) + + def test_series_error_assert_pandas_equal(self): + series1 = pd.Series([1, 2, 3]) + series2 = pd.Series([4, 5, 6]) + + with self.assertRaises(PySparkAssertionError) as pe: + _assert_pandas_equal(series1, series2, True) + + self.check_error( + exception=pe.exception, + error_class="DIFFERENT_PANDAS_SERIES", + message_parameters={ + "left": series1, + "left_dtype": series1.dtype, + "right": series2, + "right_dtype": series2.dtype, + }, + ) + + def test_index_error_assert_pandas_equal(self): + index1 = pd.Index([1, 2, 3]) + index2 = pd.Index([4, 5, 6]) + + with self.assertRaises(PySparkAssertionError) as pe: + _assert_pandas_equal(index1, index2, True) + + self.check_error( + exception=pe.exception, + error_class="DIFFERENT_PANDAS_INDEX", + message_parameters={ + "left": index1, + "left_dtype": index1.dtype, + "right": index2, + "right_dtype": index2.dtype, + }, + ) + + def test_multiindex_error_assert_pandas_almost_equal(self): + pdf1 = pd.DataFrame({"a": [1, 2], "b": [4, 10]}, index=[0, 1]) + pdf2 = pd.DataFrame({"a": [1, 5, 3], "b": [1, 5, 6]}, index=[0, 1, 3]) + multiindex1 = pd.MultiIndex.from_frame(pdf1) + multiindex2 = pd.MultiIndex.from_frame(pdf2) + + with self.assertRaises(PySparkAssertionError) as pe: + _assert_pandas_almost_equal(multiindex1, multiindex2) + + self.check_error( + exception=pe.exception, + error_class="DIFFERENT_PANDAS_MULTIINDEX", + message_parameters={ + "left": multiindex1, + "left_dtype": multiindex1.dtype, + "right": multiindex2, + "right_dtype": multiindex2.dtype, + }, + ) + + def test_dataframe_error_assert_pandas_on_spark_almost_equal(self): + import pyspark.pandas as ps + + psdf1 = ps.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) + psdf2 = ps.DataFrame({"a": [1, 2], "b": [4, 5], "c": [7, 8]}) + + with self.assertRaises(PySparkAssertionError) as pe: + 
assertPandasOnSparkEqual(psdf1, psdf2, almost=True)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="DIFFERENT_PANDAS_DATAFRAME",
+            message_parameters={
+                "left": psdf1.to_string(),
+                "left_dtype": str(psdf1.dtypes),
+                "right": psdf2.to_string(),
+                "right_dtype": str(psdf2.dtypes),
+            },
+        )
+
 
 class TestClassForLazyProp:
     def __init__(self):
diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py
index a1cefe7c840d6..500c314e449ba 100644
--- a/python/pyspark/sql/tests/test_utils.py
+++ b/python/pyspark/sql/tests/test_utils.py
@@ -623,22 +623,47 @@ def test_assert_equal_nulldf(self):
         assertDataFrameEqual(df1, df2, checkRowOrder=False)
         assertDataFrameEqual(df1, df2, checkRowOrder=True)
 
-    def test_assert_error_pandas_df(self):
-        import pandas as pd
+    def test_assert_equal_exact_pandas_df(self):
+        import pyspark.pandas as ps
 
-        df1 = pd.DataFrame(data=[10, 20, 30], columns=["Numbers"])
-        df2 = pd.DataFrame(data=[10, 20, 30], columns=["Numbers"])
+        df1 = ps.DataFrame(data=[10, 20, 30], columns=["Numbers"])
+        df2 = ps.DataFrame(data=[10, 20, 30], columns=["Numbers"])
+
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)
+
+    def test_assert_equal_ignore_row_order_pandas_df(self):
+        import pyspark.pandas as ps
+
+        df1 = ps.DataFrame(data=[10, 20, 30], columns=["Numbers"])
+        df2 = ps.DataFrame(data=[30, 20, 10], columns=["Numbers"])
+
+        assertDataFrameEqual(df1, df2)
+
+    def test_assert_equal_approx_pandas_df(self):
+        import pyspark.pandas as ps
+
+        df1 = ps.DataFrame(data=[10.0001, 20.32, 30.1], columns=["Numbers"])
+        df2 = ps.DataFrame(data=[10.0, 20.32, 30.1], columns=["Numbers"])
+
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)
+
+    def test_assert_error_pandas_pyspark_df(self):
+        import pyspark.pandas as ps
+
+        df1 = ps.DataFrame(data=[10, 20, 30], columns=["Numbers"])
+        df2 = self.spark.createDataFrame([(10,), (11,), (13,)], ["Numbers"])
 
         with self.assertRaises(PySparkAssertionError) as pe:
-            assertDataFrameEqual(df1, df2)
+            assertDataFrameEqual(df1, df2, checkRowOrder=False)
 
         self.check_error(
             exception=pe.exception,
-            error_class="INVALID_TYPE_DF_EQUALITY_ARG",
+            error_class="INVALID_PANDAS_ON_SPARK_COMPARISON",
             message_parameters={
-                "expected_type": DataFrame,
-                "arg_name": "df",
-                "actual_type": pd.DataFrame,
+                "actual_type": type(df1),
+                "expected_type": type(df2),
             },
         )
 
@@ -647,15 +672,16 @@ def test_assert_error_pandas_df(self):
 
         self.check_error(
             exception=pe.exception,
-            error_class="INVALID_TYPE_DF_EQUALITY_ARG",
+            error_class="INVALID_PANDAS_ON_SPARK_COMPARISON",
             message_parameters={
-                "expected_type": DataFrame,
-                "arg_name": "df",
-                "actual_type": pd.DataFrame,
+                "actual_type": type(df1),
+                "expected_type": type(df2),
             },
         )
 
     def test_assert_error_non_pyspark_df(self):
+        import pyspark.pandas as ps
+
         dict1 = {"a": 1, "b": 2}
         dict2 = {"a": 1, "b": 2}
 
@@ -666,8 +692,8 @@ def test_assert_error_non_pyspark_df(self):
             exception=pe.exception,
             error_class="INVALID_TYPE_DF_EQUALITY_ARG",
             message_parameters={
-                "expected_type": DataFrame,
-                "arg_name": "df",
+                "expected_type": f"{DataFrame.__name__}, {ps.DataFrame.__name__}",
+                "arg_name": "actual",
                 "actual_type": type(dict1),
             },
         )
@@ -679,8 +705,8 @@ def test_assert_error_non_pyspark_df(self):
             exception=pe.exception,
             error_class="INVALID_TYPE_DF_EQUALITY_ARG",
             message_parameters={
-                "expected_type": DataFrame,
-                "arg_name": "df",
+                "expected_type": f"{DataFrame.__name__}, 
{ps.DataFrame.__name__}", + "arg_name": "actual", "actual_type": type(dict1), }, ) diff --git a/python/pyspark/testing/__init__.py b/python/pyspark/testing/__init__.py index 88853e925f801..57c206629a80f 100644 --- a/python/pyspark/testing/__init__.py +++ b/python/pyspark/testing/__init__.py @@ -16,4 +16,6 @@ # from pyspark.testing.utils import assertDataFrameEqual, assertSchemaEqual -__all__ = ["assertDataFrameEqual", "assertSchemaEqual"] +from pyspark.testing.pandasutils import assertPandasOnSparkEqual + +__all__ = ["assertDataFrameEqual", "assertSchemaEqual", "assertPandasOnSparkEqual"] diff --git a/python/pyspark/testing/pandasutils.py b/python/pyspark/testing/pandasutils.py index 202603ca5c0a7..4ffe8858396e2 100644 --- a/python/pyspark/testing/pandasutils.py +++ b/python/pyspark/testing/pandasutils.py @@ -19,15 +19,19 @@ import shutil import tempfile import warnings +import pandas as pd from contextlib import contextmanager from distutils.version import LooseVersion +import decimal +from typing import Union -from pyspark import pandas as ps +import pyspark.pandas as ps from pyspark.pandas.frame import DataFrame from pyspark.pandas.indexes import Index from pyspark.pandas.series import Series from pyspark.pandas.utils import SPARK_CONF_ARROW_ENABLED from pyspark.testing.sqlutils import ReusedSQLTestCase +from pyspark.errors import PySparkAssertionError tabulate_requirement_message = None try: @@ -54,153 +58,381 @@ have_plotly = plotly_requirement_message is None -class PandasOnSparkTestUtils: - def convert_str_to_lambda(self, func): - """ - This function coverts `func` str to lambda call - """ - return lambda x: getattr(x, func)() +__all__ = ["assertPandasOnSparkEqual"] - def assertPandasEqual(self, left, right, check_exact=True): - import pandas as pd - from pandas.core.dtypes.common import is_numeric_dtype - from pandas.testing import assert_frame_equal, assert_index_equal, assert_series_equal - - if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame): - try: - if LooseVersion(pd.__version__) >= LooseVersion("1.1"): - kwargs = dict(check_freq=False) - else: - kwargs = dict() - - if LooseVersion(pd.__version__) < LooseVersion("1.1.1"): - # Due to https://github.com/pandas-dev/pandas/issues/35446 - check_exact = ( - check_exact - and all([is_numeric_dtype(dtype) for dtype in left.dtypes]) - and all([is_numeric_dtype(dtype) for dtype in right.dtypes]) - ) - assert_frame_equal( - left, - right, - check_index_type=("equiv" if len(left.index) > 0 else False), - check_column_type=("equiv" if len(left.columns) > 0 else False), - check_exact=check_exact, - **kwargs, +def _assert_pandas_equal( + left: Union[pd.DataFrame, pd.Series, pd.Index], + right: Union[pd.DataFrame, pd.Series, pd.Index], + checkExact: bool, +): + from pandas.core.dtypes.common import is_numeric_dtype + from pandas.testing import assert_frame_equal, assert_index_equal, assert_series_equal + + if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame): + try: + if LooseVersion(pd.__version__) >= LooseVersion("1.1"): + kwargs = dict(check_freq=False) + else: + kwargs = dict() + + if LooseVersion(pd.__version__) < LooseVersion("1.1.1"): + # Due to https://github.com/pandas-dev/pandas/issues/35446 + checkExact = ( + checkExact + and all([is_numeric_dtype(dtype) for dtype in left.dtypes]) + and all([is_numeric_dtype(dtype) for dtype in right.dtypes]) ) - except AssertionError as e: - msg = ( - str(e) - + "\n\nLeft:\n%s\n%s" % (left, left.dtypes) - + "\n\nRight:\n%s\n%s" % (right, 
right.dtypes)
+
+            assert_frame_equal(
+                left,
+                right,
+                check_index_type=("equiv" if len(left.index) > 0 else False),
+                check_column_type=("equiv" if len(left.columns) > 0 else False),
+                check_exact=checkExact,
+                **kwargs,
+            )
+        except AssertionError:
+            raise PySparkAssertionError(
+                error_class="DIFFERENT_PANDAS_DATAFRAME",
+                message_parameters={
+                    "left": left.to_string(),
+                    "left_dtype": str(left.dtypes),
+                    "right": right.to_string(),
+                    "right_dtype": str(right.dtypes),
+                },
+            )
+    elif isinstance(left, pd.Series) and isinstance(right, pd.Series):
+        try:
+            if LooseVersion(pd.__version__) >= LooseVersion("1.1"):
+                kwargs = dict(check_freq=False)
+            else:
+                kwargs = dict()
+            if LooseVersion(pd.__version__) < LooseVersion("1.1.1"):
+                # Due to https://github.com/pandas-dev/pandas/issues/35446
+                checkExact = (
+                    checkExact and is_numeric_dtype(left.dtype) and is_numeric_dtype(right.dtype)
+                )
-            raise AssertionError(msg) from e
-        elif isinstance(left, pd.Series) and isinstance(right, pd.Series):
-            try:
-                if LooseVersion(pd.__version__) >= LooseVersion("1.1"):
-                    kwargs = dict(check_freq=False)
-                else:
-                    kwargs = dict()
-                if LooseVersion(pd.__version__) < LooseVersion("1.1.1"):
-                    # Due to https://github.com/pandas-dev/pandas/issues/35446
-                    check_exact = (
-                        check_exact
-                        and is_numeric_dtype(left.dtype)
-                        and is_numeric_dtype(right.dtype)
-                    )
+            assert_series_equal(
+                left,
+                right,
+                check_index_type=("equiv" if len(left.index) > 0 else False),
+                check_exact=checkExact,
+                **kwargs,
+            )
+        except AssertionError:
+            raise PySparkAssertionError(
+                error_class="DIFFERENT_PANDAS_SERIES",
+                message_parameters={
+                    "left": left,
+                    "left_dtype": left.dtype,
+                    "right": right,
+                    "right_dtype": right.dtype,
+                },
+            )
+    elif isinstance(left, pd.Index) and isinstance(right, pd.Index):
+        try:
+            if LooseVersion(pd.__version__) < LooseVersion("1.1.1"):
+                # Due to https://github.com/pandas-dev/pandas/issues/35446
+                checkExact = (
+                    checkExact and is_numeric_dtype(left.dtype) and is_numeric_dtype(right.dtype)
+                )
-        except AssertionError as e:
-            msg = (
-                str(e)
-                + "\n\nLeft:\n%s\n%s" % (left, left.dtype)
-                + "\n\nRight:\n%s\n%s" % (right, right.dtype)
+            assert_index_equal(left, right, check_exact=checkExact)
+        except AssertionError:
+            raise PySparkAssertionError(
+                error_class="DIFFERENT_PANDAS_INDEX",
+                message_parameters={
+                    "left": left,
+                    "left_dtype": left.dtype,
+                    "right": right,
+                    "right_dtype": right.dtype,
+                },
+            )
-            raise AssertionError(msg) from e
-        elif isinstance(left, pd.Index) and isinstance(right, pd.Index):
-            try:
-                if LooseVersion(pd.__version__) < LooseVersion("1.1.1"):
-                    # Due to https://github.com/pandas-dev/pandas/issues/35446
-                    check_exact = (
-                        check_exact
-                        and is_numeric_dtype(left.dtype)
-                        and is_numeric_dtype(right.dtype)
+    else:
+        raise ValueError("Unexpected values: (%s, %s)" % (left, right))
+
+
+def _assert_pandas_almost_equal(
+    left: Union[pd.DataFrame, pd.Series, pd.Index], right: Union[pd.DataFrame, pd.Series, pd.Index]
+):
+    """
+    This function checks whether the given pandas objects are approximately equal,
+    meaning that the conditions below hold:
+    - Missing values (NaN, NaT, None) appear in the same positions in both objects
+    - After dropping missing values, floats are compared within relative tolerance
+      rtol=1e-5 and absolute tolerance atol=1e-8 (the pandas convention)
+    """
+    # following pandas convention, rtol=1e-5 and atol=1e-8
+    rtol = 1e-5
+    atol = 1e-8
+
+    if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame):
+        if left.shape != right.shape:
+            raise PySparkAssertionError(
+                error_class="DIFFERENT_PANDAS_DATAFRAME",
+                message_parameters={
+                    "left": left.to_string(),
+                    "left_dtype": str(left.dtypes),
+                    "right": right.to_string(),
+                    "right_dtype": str(right.dtypes),
+                },
+            )
+        for lcol, rcol in zip(left.columns, 
right.columns): + if lcol != rcol: + raise PySparkAssertionError( + error_class="DIFFERENT_PANDAS_DATAFRAME", + message_parameters={ + "left": left.to_string(), + "left_dtype": str(left.dtypes), + "right": right.to_string(), + "right_dtype": str(right.dtypes), + }, ) - raise AssertionError(msg) from e - elif isinstance(left, pd.Index) and isinstance(right, pd.Index): - try: - if LooseVersion(pd.__version__) < LooseVersion("1.1.1"): - # Due to https://github.com/pandas-dev/pandas/issues/35446 - check_exact = ( - check_exact - and is_numeric_dtype(left.dtype) - and is_numeric_dtype(right.dtype) + for lnull, rnull in zip(left[lcol].isnull(), right[rcol].isnull()): + if lnull != rnull: + raise PySparkAssertionError( + error_class="DIFFERENT_PANDAS_DATAFRAME", + message_parameters={ + "left": left.to_string(), + "left_dtype": str(left.dtypes), + "right": right.to_string(), + "right_dtype": str(right.dtypes), + }, ) - assert_index_equal(left, right, check_exact=check_exact) - except AssertionError as e: - msg = ( - str(e) - + "\n\nLeft:\n%s\n%s" % (left, left.dtype) - + "\n\nRight:\n%s\n%s" % (right, right.dtype) + for lval, rval in zip(left[lcol].dropna(), right[rcol].dropna()): + if (isinstance(lval, float) or isinstance(lval, decimal.Decimal)) and ( + isinstance(rval, float) or isinstance(rval, decimal.Decimal) + ): + if abs(float(lval) - float(rval)) > (atol + rtol * abs(float(rval))): + raise PySparkAssertionError( + error_class="DIFFERENT_PANDAS_DATAFRAME", + message_parameters={ + "left": left.to_string(), + "left_dtype": str(left.dtypes), + "right": right.to_string(), + "right_dtype": str(right.dtypes), + }, + ) + if left.columns.names != right.columns.names: + raise PySparkAssertionError( + error_class="DIFFERENT_PANDAS_DATAFRAME", + message_parameters={ + "left": left.to_string(), + "left_dtype": str(left.dtypes), + "right": right.to_string(), + "right_dtype": str(right.dtypes), + }, + ) + elif isinstance(left, pd.Series) and isinstance(right, pd.Series): + if left.name != right.name or len(left) != len(right): + raise PySparkAssertionError( + error_class="DIFFERENT_PANDAS_SERIES", + message_parameters={ + "left": left, + "left_dtype": left.dtype, + "right": right, + "right_dtype": right.dtype, + }, + ) + for lnull, rnull in zip(left.isnull(), right.isnull()): + if lnull != rnull: + raise PySparkAssertionError( + error_class="DIFFERENT_PANDAS_SERIES", + message_parameters={ + "left": left, + "left_dtype": left.dtype, + "right": right, + "right_dtype": right.dtype, + }, ) - raise AssertionError(msg) from e + for lval, rval in zip(left.dropna(), right.dropna()): + if (isinstance(lval, float) or isinstance(lval, decimal.Decimal)) and ( + isinstance(rval, float) or isinstance(rval, decimal.Decimal) + ): + if abs(float(lval) - float(rval)) > (atol + rtol * abs(float(rval))): + raise PySparkAssertionError( + error_class="DIFFERENT_PANDAS_SERIES", + message_parameters={ + "left": left, + "left_dtype": left.dtype, + "right": right, + "right_dtype": right.dtype, + }, + ) + elif isinstance(left, pd.MultiIndex) and isinstance(right, pd.MultiIndex): + if len(left) != len(right): + raise PySparkAssertionError( + error_class="DIFFERENT_PANDAS_MULTIINDEX", + message_parameters={ + "left": left, + "left_dtype": left.dtype, + "right": right, + "right_dtype": right.dtype, + }, + ) + for lval, rval in zip(left, right): + if (isinstance(lval, float) or isinstance(lval, decimal.Decimal)) and ( + isinstance(rval, float) or isinstance(rval, decimal.Decimal) + ): + if abs(float(lval) - float(rval)) > 
(atol + rtol * abs(float(rval))):
+                    raise PySparkAssertionError(
+                        error_class="DIFFERENT_PANDAS_MULTIINDEX",
+                        message_parameters={
+                            "left": left,
+                            "left_dtype": left.dtype,
+                            "right": right,
+                            "right_dtype": right.dtype,
+                        },
+                    )
+    elif isinstance(left, pd.Index) and isinstance(right, pd.Index):
+        if len(left) != len(right):
+            raise PySparkAssertionError(
+                error_class="DIFFERENT_PANDAS_INDEX",
+                message_parameters={
+                    "left": left,
+                    "left_dtype": left.dtype,
+                    "right": right,
+                    "right_dtype": right.dtype,
+                },
+            )
+        for lnull, rnull in zip(left.isnull(), right.isnull()):
+            if lnull != rnull:
+                raise PySparkAssertionError(
+                    error_class="DIFFERENT_PANDAS_INDEX",
+                    message_parameters={
+                        "left": left,
+                        "left_dtype": left.dtype,
+                        "right": right,
+                        "right_dtype": right.dtype,
+                    },
+                )
+        for lval, rval in zip(left.dropna(), right.dropna()):
+            if (isinstance(lval, float) or isinstance(lval, decimal.Decimal)) and (
+                isinstance(rval, float) or isinstance(rval, decimal.Decimal)
+            ):
+                if abs(float(lval) - float(rval)) > (atol + rtol * abs(float(rval))):
+                    raise PySparkAssertionError(
+                        error_class="DIFFERENT_PANDAS_INDEX",
+                        message_parameters={
+                            "left": left,
+                            "left_dtype": left.dtype,
+                            "right": right,
+                            "right_dtype": right.dtype,
+                        },
+                    )
+    else:
+        raise ValueError("Unexpected values: (%s, %s)" % (left, right))
+
+
+def assertPandasOnSparkEqual(
+    actual: Union[DataFrame, Series, Index],
+    expected: Union[DataFrame, pd.DataFrame, Series, Index],
+    checkExact: bool = True,
+    almost: bool = False,
+    checkRowOrder: bool = False,
+):
+    r"""
+    A util function to assert equality between `actual` (a pandas-on-Spark DataFrame,
+    Series, or Index) and `expected` (a pandas-on-Spark object of the same kind, or a
+    pandas DataFrame).
+
+    .. versionadded:: 3.5.0
+
+    Parameters
+    ----------
+    actual : pandas-on-Spark DataFrame, Series, or Index
+        The object that is being compared or tested.
+    expected : pandas-on-Spark DataFrame, Series, or Index, or pandas DataFrame
+        The expected result, for comparison with the actual result.
+    checkExact : bool, optional
+        A flag indicating whether to compare exact equality.
+        If set to 'True' (default), the data is compared exactly.
+        If set to 'False', the data is compared less precisely, following the pandas
+        assert_frame_equal approximate comparison (see the pandas documentation for
+        more details).
+    almost : bool, optional
+        A flag indicating whether to compare values approximately rather than exactly.
+        If set to 'True', two values are considered equal when their difference is
+        within relative tolerance rtol=1e-5 and absolute tolerance atol=1e-8 (the
+        pandas convention), after dropping missing values.
+        If set to 'False' (default), the data is compared exactly, subject to
+        `checkExact`.
+    checkRowOrder : bool, optional
+        A flag indicating whether the order of rows should be considered in the comparison.
+        If set to `False` (default), the row order is not taken into account.
+        If set to `True`, the order of rows is important and will be checked during comparison.
+        (See Notes)
+
+    Notes
+    -----
+    For `checkRowOrder`, note that pandas-on-Spark DataFrame ordering is non-deterministic, unless
+    explicitly sorted. 
+ + Examples + -------- + >>> import pyspark.pandas as ps + >>> psdf1 = ps.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]}) + >>> psdf2 = ps.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]}) + >>> assertPandasOnSparkEqual(psdf1, psdf2) # pass, ps.DataFrames are equal + >>> s1 = ps.Series([212.32, 100.0001]) + >>> s2 = ps.Series([212.32, 100.0]) + >>> assertPandasOnSparkEqual(s1, s2, checkExact=False) # pass, ps.Series are approx equal + >>> s1 = ps.Index([212.300001, 100.000]) + >>> s2 = ps.Index([212.3, 100.0001]) + >>> assertPandasOnSparkEqual(s1, s2, almost=True) # pass, ps.Index obj are almost equal + """ + if actual is None and expected is None: + return True + elif actual is None or expected is None: + return False + + if not isinstance(actual, (DataFrame, Series, Index)): + raise PySparkAssertionError( + error_class="INVALID_TYPE_DF_EQUALITY_ARG", + message_parameters={ + "expected_type": f"{DataFrame.__name__}, {Series.__name__}, {Index.__name__}", + "arg_name": "actual", + "actual_type": type(actual), + }, + ) + elif not isinstance(expected, (DataFrame, pd.DataFrame, Series, Index)): + raise PySparkAssertionError( + error_class="INVALID_TYPE_DF_EQUALITY_ARG", + message_parameters={ + "expected_type": f"{DataFrame.__name__}, " + f"{pd.DataFrame.__name__}, " + f"{Series.__name__}, " + f"{Index.__name__}", + "arg_name": "expected", + "actual_type": type(expected), + }, + ) + else: + actual = actual.to_pandas() + if not isinstance(expected, pd.DataFrame): + expected = expected.to_pandas() + + if not checkRowOrder: + if isinstance(actual, pd.DataFrame) and len(actual.columns) > 0: + actual = actual.sort_values(by=actual.columns[0], ignore_index=True) + if isinstance(expected, pd.DataFrame) and len(expected.columns) > 0: + expected = expected.sort_values(by=expected.columns[0], ignore_index=True) + + if almost: + _assert_pandas_almost_equal(actual, expected) else: - raise ValueError("Unexpected values: (%s, %s)" % (left, right)) + _assert_pandas_equal(actual, expected, checkExact=checkExact) - def assertPandasAlmostEqual(self, left, right): + +class PandasOnSparkTestUtils: + def convert_str_to_lambda(self, func): """ - This function checks if given pandas objects approximately same, - which means the conditions below: - - Both objects are nullable - - Compare floats rounding to the number of decimal places, 7 after - dropping missing values (NaN, NaT, None) + This function converts `func` str to lambda call """ - import pandas as pd + return lambda x: getattr(x, func)() - if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame): - msg = ( - "DataFrames are not almost equal: " - + "\n\nLeft:\n%s\n%s" % (left, left.dtypes) - + "\n\nRight:\n%s\n%s" % (right, right.dtypes) - ) - self.assertEqual(left.shape, right.shape, msg=msg) - for lcol, rcol in zip(left.columns, right.columns): - self.assertEqual(lcol, rcol, msg=msg) - for lnull, rnull in zip(left[lcol].isnull(), right[rcol].isnull()): - self.assertEqual(lnull, rnull, msg=msg) - for lval, rval in zip(left[lcol].dropna(), right[rcol].dropna()): - self.assertAlmostEqual(lval, rval, msg=msg) - self.assertEqual(left.columns.names, right.columns.names, msg=msg) - elif isinstance(left, pd.Series) and isinstance(right, pd.Series): - msg = ( - "Series are not almost equal: " - + "\n\nLeft:\n%s\n%s" % (left, left.dtype) - + "\n\nRight:\n%s\n%s" % (right, right.dtype) - ) - self.assertEqual(left.name, right.name, msg=msg) - self.assertEqual(len(left), len(right), msg=msg) - for lnull, rnull in zip(left.isnull(), 
right.isnull()): - self.assertEqual(lnull, rnull, msg=msg) - for lval, rval in zip(left.dropna(), right.dropna()): - self.assertAlmostEqual(lval, rval, msg=msg) - elif isinstance(left, pd.MultiIndex) and isinstance(right, pd.MultiIndex): - msg = ( - "MultiIndices are not almost equal: " - + "\n\nLeft:\n%s\n%s" % (left, left.dtype) - + "\n\nRight:\n%s\n%s" % (right, right.dtype) - ) - self.assertEqual(len(left), len(right), msg=msg) - for lval, rval in zip(left, right): - self.assertAlmostEqual(lval, rval, msg=msg) - elif isinstance(left, pd.Index) and isinstance(right, pd.Index): - msg = ( - "Indices are not almost equal: " - + "\n\nLeft:\n%s\n%s" % (left, left.dtype) - + "\n\nRight:\n%s\n%s" % (right, right.dtype) - ) - self.assertEqual(len(left), len(right), msg=msg) - for lnull, rnull in zip(left.isnull(), right.isnull()): - self.assertEqual(lnull, rnull, msg=msg) - for lval, rval in zip(left.dropna(), right.dropna()): - self.assertAlmostEqual(lval, rval, msg=msg) - else: - raise ValueError("Unexpected values: (%s, %s)" % (left, right)) + def assertPandasEqual(self, left, right, check_exact=True): + _assert_pandas_equal(left, right, check_exact) + + def assertPandasAlmostEqual(self, left, right): + _assert_pandas_almost_equal(left, right) def assert_eq(self, left, right, check_exact=True, almost=False): """ @@ -220,9 +452,9 @@ def assert_eq(self, left, right, check_exact=True, almost=False): robj = self._to_pandas(right) if isinstance(lobj, (pd.DataFrame, pd.Series, pd.Index)): if almost: - self.assertPandasAlmostEqual(lobj, robj) + _assert_pandas_almost_equal(lobj, robj) else: - self.assertPandasEqual(lobj, robj, check_exact=check_exact) + _assert_pandas_equal(lobj, robj, checkExact=check_exact) elif is_list_like(lobj) and is_list_like(robj): self.assertTrue(len(left) == len(right)) for litem, ritem in zip(left, right): diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py index b8977b6fffd79..3ba92017fc4c5 100644 --- a/python/pyspark/testing/utils.py +++ b/python/pyspark/testing/utils.py @@ -38,6 +38,7 @@ from pyspark.sql.dataframe import DataFrame from pyspark.sql import Row from pyspark.sql.types import StructType, AtomicType, StructField +import pyspark.pandas as ps have_scipy = False have_numpy = False @@ -314,8 +315,8 @@ def compare_datatypes_ignore_nullable(dt1: Any, dt2: Any): def assertDataFrameEqual( - actual: DataFrame, - expected: Union[DataFrame, List[Row]], + actual: Union[DataFrame, ps.DataFrame], + expected: Union[DataFrame, ps.DataFrame, List[Row]], checkRowOrder: bool = False, rtol: float = 1e-5, atol: float = 1e-8, @@ -324,13 +325,17 @@ def assertDataFrameEqual( A util function to assert equality between `actual` (DataFrame) and `expected` (DataFrame or list of Rows), with optional parameters `checkRowOrder`, `rtol`, and `atol`. + Supports Spark, Spark Connect, and pandas-on-Spark DataFrames. + For more information about pandas-on-Spark DataFrame equality, see the docs for + `assertPandasOnSparkEqual`. + .. versionadded:: 3.5.0 Parameters ---------- - actual : DataFrame + actual : DataFrame (Spark, Spark Connect, or pandas-on-Spark) The DataFrame that is being compared or tested. - expected : DataFrame or list of Rows + expected : DataFrame (Spark, Spark Connect, or pandas-on-Spark) or list of Rows The expected result of the operation, for comparison with the actual result. checkRowOrder : bool, optional A flag indicating whether the order of rows should be considered in the comparison. 
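For reference, the tolerance semantics behind the `rtol`/`atol` parameters in the signature above (and behind the approximate checks in `_assert_pandas_almost_equal` earlier in this patch) reduce to a single comparison: a difference is accepted when it is within `atol + rtol * abs(expected)`. A minimal sketch; the `almost_equal` helper name is illustrative and not part of the patch:

```python
# NumPy/pandas-style closeness check: the absolute difference must not
# exceed atol plus rtol scaled by the magnitude of the expected value.
def almost_equal(actual: float, expected: float,
                 rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    return abs(actual - expected) <= atol + rtol * abs(expected)


assert almost_equal(10.0001, 10.0)   # within tolerance: treated as equal
assert not almost_equal(10.1, 10.0)  # outside tolerance: flagged as different
```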
@@ -346,10 +351,10 @@
 
     Notes
     -----
-    When assertDataFrameEqual fails, the error message uses the Python `difflib` library to display
-    a diff log of each row that differs in `actual` and `expected`.
+    When `assertDataFrameEqual` fails, the error message uses the Python `difflib` library to
+    display a diff log of each row that differs in `actual` and `expected`.
 
-    For checkRowOrder, note that PySpark DataFrame ordering is non-deterministic, unless
+    For `checkRowOrder`, note that PySpark DataFrame ordering is non-deterministic, unless
     explicitly sorted.
 
     Note that schema equality is checked only when `expected` is a DataFrame (not a list of Rows).
@@ -369,7 +374,11 @@
     >>> assertDataFrameEqual(df1, df2, rtol=1e-1)  # pass, DataFrames are approx equal by rtol
     >>> df1 = spark.createDataFrame(data=[(1, 1000), (2, 3000)], schema=["id", "amount"])
     >>> list_of_rows = [Row(1, 1000), Row(2, 3000)]
-    >>> assertDataFrameEqual(df1, list_of_rows)  # pass, actual and expected are equal
+    >>> assertDataFrameEqual(df1, list_of_rows)  # pass, actual and expected data are equal
+    >>> import pyspark.pandas as ps
+    >>> df1 = ps.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]})
+    >>> df2 = ps.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]})
+    >>> assertDataFrameEqual(df1, df2)  # pass, pandas-on-Spark DataFrames are equal
     >>> df1 = spark.createDataFrame(
     ...     data=[("1", 1000.00), ("2", 3000.00), ("3", 2000.00)], schema=["id", "amount"])
     >>> df2 = spark.createDataFrame(
@@ -395,47 +404,76 @@
     elif actual is None or expected is None:
         return False
 
+    import pyspark.pandas as ps
+    from pyspark.testing.pandasutils import assertPandasOnSparkEqual
+
     try:
         # If Spark Connect dependencies are available, allow Spark Connect DataFrame
         from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
 
-        if not isinstance(actual, DataFrame) and not isinstance(actual, ConnectDataFrame):
+        if isinstance(actual, ps.DataFrame) or isinstance(expected, ps.DataFrame):
+            # handle pandas-on-Spark DataFrames
+            if not (isinstance(actual, ps.DataFrame) and isinstance(expected, ps.DataFrame)):
+                raise PySparkAssertionError(
+                    error_class="INVALID_PANDAS_ON_SPARK_COMPARISON",
+                    message_parameters={
+                        "actual_type": type(actual),
+                        "expected_type": type(expected),
+                    },
+                )
+            # assert approximate equality for float data
+            return assertPandasOnSparkEqual(
+                actual, expected, checkExact=False, checkRowOrder=checkRowOrder
+            )
+        elif not isinstance(actual, (DataFrame, ps.DataFrame, ConnectDataFrame)):
             raise PySparkAssertionError(
                 error_class="INVALID_TYPE_DF_EQUALITY_ARG",
                 message_parameters={
-                    "expected_type": DataFrame,
-                    "arg_name": "df",
+                    "expected_type": DataFrame.__name__,
+                    "arg_name": "actual",
                     "actual_type": type(actual),
                 },
             )
-        elif (
-            not isinstance(expected, DataFrame)
-            and not isinstance(expected, ConnectDataFrame)
-            and not isinstance(expected, List)
-        ):
+        elif not isinstance(expected, (DataFrame, ps.DataFrame, ConnectDataFrame, list)):
             raise PySparkAssertionError(
                 error_class="INVALID_TYPE_DF_EQUALITY_ARG",
                 message_parameters={
-                    "expected_type": Union[DataFrame, List[Row]],
+                    "expected_type": f"{DataFrame.__name__}, {List[Row].__name__}",
                     "arg_name": "expected",
                     "actual_type": type(expected),
                 },
             )
     except Exception:
-        if not isinstance(actual, DataFrame):
+        if isinstance(actual, ps.DataFrame) or isinstance(expected, ps.DataFrame):
+            # handle pandas-on-Spark DataFrames
+            if not (isinstance(actual, ps.DataFrame) and isinstance(expected, 
ps.DataFrame)): + raise PySparkAssertionError( + error_class="INVALID_PANDAS_ON_SPARK_COMPARISON", + message_parameters={ + "actual_type": type(actual), + "expected_type": type(expected), + }, + ) + # assert approximate equality for float data + return assertPandasOnSparkEqual( + actual, expected, checkExact=False, checkRowOrder=checkRowOrder + ) + elif not isinstance(actual, (DataFrame, ps.DataFrame)): raise PySparkAssertionError( error_class="INVALID_TYPE_DF_EQUALITY_ARG", message_parameters={ - "expected_type": DataFrame, - "arg_name": "df", + "expected_type": f"{DataFrame.__name__}, {ps.DataFrame.__name__}", + "arg_name": "actual", "actual_type": type(actual), }, ) - elif not isinstance(expected, DataFrame) and not isinstance(expected, List): + elif not isinstance(expected, (DataFrame, ps.DataFrame, list)): raise PySparkAssertionError( error_class="INVALID_TYPE_DF_EQUALITY_ARG", message_parameters={ - "expected_type": Union[DataFrame, List[Row]], + "expected_type": f"{DataFrame.__name__}, " + f"{ps.DataFrame.__name__}, " + f"{List[Row].__name__}", "arg_name": "expected", "actual_type": type(expected), },
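Taken together, the patch lets pandas-on-Spark objects flow through both the new `assertPandasOnSparkEqual` entry point and the existing `assertDataFrameEqual`. A minimal usage sketch, assuming a PySpark 3.5 build containing this patch; the data mirrors the doctests above:

```python
import pyspark.pandas as ps
from pyspark.testing import assertDataFrameEqual, assertPandasOnSparkEqual

psdf1 = ps.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]})
psdf2 = ps.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]})

# Exact comparison of two pandas-on-Spark DataFrames (row order is
# ignored unless checkRowOrder=True is passed).
assertPandasOnSparkEqual(psdf1, psdf2)

# assertDataFrameEqual now detects pandas-on-Spark inputs and delegates to
# the same check, using approximate comparison (checkExact=False) for floats.
assertDataFrameEqual(psdf1, psdf2)

# Approximate comparison of float data.
s1 = ps.Series([212.32, 100.0001])
s2 = ps.Series([212.32, 100.0])
assertPandasOnSparkEqual(s1, s2, checkExact=False)
```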