[SPARK-49531][PYTHON][CONNECT] Support line plot with plotly backend
### What changes were proposed in this pull request?
Support line plot with plotly backend on both Spark Connect and Spark classic.

### Why are the changes needed?
While Pandas on Spark supports plotting, PySpark currently lacks this feature. The proposed API will enable users to generate visualizations, such as line plots, by leveraging libraries like Plotly. This will provide users with an intuitive, interactive way to explore and understand large datasets directly from PySpark DataFrames, streamlining the data analysis workflow in distributed environments.

See the in-progress [PySpark Plotting API Specification](https://docs.google.com/document/d/1IjOEzC8zcetG86WDvqkereQPj_NGLNW7Bdu910g30Dg/edit?usp=sharing) for more details.

Part of https://issues.apache.org/jira/browse/SPARK-49530.

### Does this PR introduce _any_ user-facing change?
Yes.

```python
>>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
>>> columns = ["category", "int_val", "float_val"]
>>> sdf = spark.createDataFrame(data, columns)
>>> sdf.show()
+--------+-------+---------+
|category|int_val|float_val|
+--------+-------+---------+
|       A|     10|      1.5|
|       B|     30|      2.5|
|       C|     20|      3.5|
+--------+-------+---------+

>>> f = sdf.plot(kind="line", x="category", y="int_val")
>>> f.show()  # see below
>>> g = sdf.plot.line(x="category", y=["int_val", "float_val"])
>>> g.show()  # see below
```
`f.show()`:
![newplot](https://github.com/user-attachments/assets/ebd50bbc-0dd1-437f-ae0c-0b4de8f3c722)

`g.show()`:
![newplot (1)](https://github.com/user-attachments/assets/46d28840-a147-428f-8d88-d424aa76ad06)

### How was this patch tested?
Unit tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48139 from xinrong-meng/plot_line_w_dep.

Authored-by: Xinrong Meng <xinrong@apache.org>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
xinrong-meng authored and dongjoon-hyun committed Sep 20, 2024
1 parent 3d8c078 commit 22a7edc
Showing 21 changed files with 529 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build_python_connect.yml
@@ -71,7 +71,7 @@ jobs:
 python packaging/connect/setup.py sdist
 cd dist
 pip install pyspark*connect-*.tar.gz
-pip install 'six==1.16.0' 'pandas<=2.2.2' scipy 'plotly>=4.8' 'mlflow>=2.8.1' coverage matplotlib openpyxl 'memory-profiler>=0.61.0' 'scikit-learn>=1.3.2' 'graphviz==0.20.3' torch torchvision torcheval deepspeed unittest-xml-reporting
+pip install 'six==1.16.0' 'pandas<=2.2.2' scipy 'plotly>=4.8' 'mlflow>=2.8.1' coverage matplotlib openpyxl 'memory-profiler>=0.61.0' 'scikit-learn>=1.3.2' 'graphviz==0.20.3' torch torchvision torcheval deepspeed unittest-xml-reporting 'plotly>=4.8'
 - name: Run tests
   env:
     SPARK_TESTING: 1
2 changes: 1 addition & 1 deletion dev/requirements.txt
@@ -7,7 +7,7 @@ pyarrow>=10.0.0
 six==1.16.0
 pandas>=2.0.0
 scipy
-plotly
+plotly>=4.8
 mlflow>=2.3.1
 scikit-learn
 matplotlib
4 changes: 4 additions & 0 deletions dev/sparktestsupport/modules.py
@@ -548,6 +548,8 @@ def __hash__(self):
         "pyspark.sql.tests.test_udtf",
         "pyspark.sql.tests.test_utils",
         "pyspark.sql.tests.test_resources",
+        "pyspark.sql.tests.plot.test_frame_plot",
+        "pyspark.sql.tests.plot.test_frame_plot_plotly",
     ],
 )

@@ -1051,6 +1053,8 @@ def __hash__(self):
         "pyspark.sql.tests.connect.test_parity_arrow_cogrouped_map",
         "pyspark.sql.tests.connect.test_parity_python_datasource",
         "pyspark.sql.tests.connect.test_parity_python_streaming_datasource",
+        "pyspark.sql.tests.connect.test_parity_frame_plot",
+        "pyspark.sql.tests.connect.test_parity_frame_plot_plotly",
         "pyspark.sql.tests.connect.test_utils",
         "pyspark.sql.tests.connect.client.test_artifact",
         "pyspark.sql.tests.connect.client.test_artifact_localcluster",
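
The modules registered above are what CI picks up; locally, a new module can also be exercised directly by mirroring the `__main__` pattern the new test files themselves use (see `test_parity_frame_plot.py` at the bottom of this commit). A minimal sketch, assuming PySpark and plotly are importable in the current environment:

```python
# Sketch: run one of the newly registered plot test modules directly.
import unittest

# The star-import pulls the test cases into __main__, which unittest.main() scans.
from pyspark.sql.tests.plot.test_frame_plot_plotly import *  # noqa: F401,F403

if __name__ == "__main__":
    unittest.main(verbosity=2)
```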
1 change: 1 addition & 0 deletions python/docs/source/getting_started/install.rst
@@ -183,6 +183,7 @@ Package Supported version Note
 Additional libraries that enhance functionality but are not included in the installation packages:

 - **memory-profiler**: Used for PySpark UDF memory profiling, ``spark.profile.show(...)`` and ``spark.sql.pyspark.udf.profiler``.
+- **plotly**: Used for PySpark plotting, ``DataFrame.plot``.

 Note that PySpark requires Java 17 or later with ``JAVA_HOME`` properly set and refer to |downloading|_.
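
Since plotly remains an optional dependency, the new plotting code validates it lazily at call time rather than at import time. A minimal sketch of checking availability up front, using the `require_minimum_plotly_version` helper that `pyspark.sql.plot.core` (below) imports from `pyspark.sql.utils` — the exact exception type it raises is not shown in this commit, so a broad `except` is used here:

```python
from pyspark.sql.utils import require_minimum_plotly_version

try:
    require_minimum_plotly_version()  # raises if plotly is missing or too old
    print("plotly is available; DataFrame.plot will work")
except Exception as exc:  # assumption: exact error type is not shown in this diff
    print(f"plotting unavailable: {exc}")
```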
1 change: 1 addition & 0 deletions python/packaging/classic/setup.py
@@ -288,6 +288,7 @@ def run(self):
         "pyspark.sql.connect.streaming.worker",
         "pyspark.sql.functions",
         "pyspark.sql.pandas",
+        "pyspark.sql.plot",
         "pyspark.sql.protobuf",
         "pyspark.sql.streaming",
         "pyspark.sql.worker",
2 changes: 2 additions & 0 deletions python/packaging/connect/setup.py
@@ -77,6 +77,7 @@
         "pyspark.sql.tests.connect.client",
         "pyspark.sql.tests.connect.shell",
         "pyspark.sql.tests.pandas",
+        "pyspark.sql.tests.plot",
         "pyspark.sql.tests.streaming",
         "pyspark.ml.tests.connect",
         "pyspark.pandas.tests",

@@ -161,6 +162,7 @@
         "pyspark.sql.connect.streaming.worker",
         "pyspark.sql.functions",
         "pyspark.sql.pandas",
+        "pyspark.sql.plot",
         "pyspark.sql.protobuf",
         "pyspark.sql.streaming",
         "pyspark.sql.worker",
5 changes: 5 additions & 0 deletions python/pyspark/errors/error-conditions.json
@@ -1088,6 +1088,11 @@
       "Function `<func_name>` should use only POSITIONAL or POSITIONAL OR KEYWORD arguments."
     ]
   },
+  "UNSUPPORTED_PLOT_BACKEND": {
+    "message": [
+      "`<backend>` is not supported, it should be one of the values from <supported_backends>"
+    ]
+  },
   "UNSUPPORTED_SIGNATURE": {
     "message": [
       "Unsupported signature: <signature>."
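
To sketch when this new error condition fires (per the backend check added in `python/pyspark/sql/plot/core.py` below): any backend other than `plotly` is rejected. Using the `sdf` from the example above — the traceback text here is illustrative, not captured output:

```python
>>> sdf.plot(kind="line", x="category", y="int_val", backend="matplotlib")  # doctest: +SKIP
Traceback (most recent call last):
  ...
PySparkValueError: [UNSUPPORTED_PLOT_BACKEND] `matplotlib` is not supported,
it should be one of the values from plotly
```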
9 changes: 9 additions & 0 deletions python/pyspark/sql/classic/dataframe.py
@@ -73,6 +73,11 @@
 from pyspark.sql.pandas.conversion import PandasConversionMixin
 from pyspark.sql.pandas.map_ops import PandasMapOpsMixin

+try:
+    from pyspark.sql.plot import PySparkPlotAccessor
+except ImportError:
+    PySparkPlotAccessor = None  # type: ignore
+
 if TYPE_CHECKING:
     from py4j.java_gateway import JavaObject
     import pyarrow as pa

@@ -1862,6 +1867,10 @@ def executionInfo(self) -> Optional["ExecutionInfo"]:
             messageParameters={"member": "queryExecution"},
         )

+    @property
+    def plot(self) -> PySparkPlotAccessor:
+        return PySparkPlotAccessor(self)
+

 class DataFrameNaFunctions(ParentDataFrameNaFunctions):
     def __init__(self, df: ParentDataFrame):
8 changes: 8 additions & 0 deletions python/pyspark/sql/connect/dataframe.py
@@ -86,6 +86,10 @@
 from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema
 from pyspark.sql.pandas.functions import _validate_pandas_udf  # type: ignore[attr-defined]

+try:
+    from pyspark.sql.plot import PySparkPlotAccessor
+except ImportError:
+    PySparkPlotAccessor = None  # type: ignore

 if TYPE_CHECKING:
     from pyspark.sql.connect._typing import (

@@ -2239,6 +2243,10 @@ def rdd(self) -> "RDD[Row]":
     def executionInfo(self) -> Optional["ExecutionInfo"]:
         return self._execution_info

+    @property
+    def plot(self) -> PySparkPlotAccessor:
+        return PySparkPlotAccessor(self)
+

 class DataFrameNaFunctions(ParentDataFrameNaFunctions):
     def __init__(self, df: ParentDataFrame):
28 changes: 28 additions & 0 deletions python/pyspark/sql/dataframe.py
@@ -43,6 +43,7 @@
 from pyspark.sql.types import StructType, Row
 from pyspark.sql.utils import dispatch_df_method

+
 if TYPE_CHECKING:
     from py4j.java_gateway import JavaObject
     import pyarrow as pa

@@ -65,6 +66,7 @@
         ArrowMapIterFunction,
         DataFrameLike as PandasDataFrameLike,
     )
+    from pyspark.sql.plot import PySparkPlotAccessor
     from pyspark.sql.metrics import ExecutionInfo


@@ -6394,6 +6396,32 @@ def executionInfo(self) -> Optional["ExecutionInfo"]:
         """
         ...

+    @property
+    def plot(self) -> "PySparkPlotAccessor":
+        """
+        Returns a :class:`PySparkPlotAccessor` for plotting functions.
+
+        .. versionadded:: 4.0.0
+
+        Returns
+        -------
+        :class:`PySparkPlotAccessor`
+
+        Notes
+        -----
+        This API is experimental.
+
+        Examples
+        --------
+        >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
+        >>> columns = ["category", "int_val", "float_val"]
+        >>> df = spark.createDataFrame(data, columns)
+        >>> type(df.plot)
+        <class 'pyspark.sql.plot.core.PySparkPlotAccessor'>
+        >>> df.plot.line(x="category", y=["int_val", "float_val"])  # doctest: +SKIP
+        """
+        ...
+

 class DataFrameNaFunctions:
     """Functionality for working with missing data in :class:`DataFrame`.
21 changes: 21 additions & 0 deletions python/pyspark/sql/plot/__init__.py
@@ -0,0 +1,21 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This package includes the plotting APIs for PySpark DataFrame.
"""
from pyspark.sql.plot.core import * # noqa: F403, F401
135 changes: 135 additions & 0 deletions python/pyspark/sql/plot/core.py
@@ -0,0 +1,135 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, TYPE_CHECKING, Optional, Union
from types import ModuleType
from pyspark.errors import PySparkRuntimeError, PySparkValueError
from pyspark.sql.utils import require_minimum_plotly_version


if TYPE_CHECKING:
    from pyspark.sql import DataFrame
    import pandas as pd
    from plotly.graph_objs import Figure


class PySparkTopNPlotBase:
    def get_top_n(self, sdf: "DataFrame") -> "pd.DataFrame":
        from pyspark.sql import SparkSession

        session = SparkSession.getActiveSession()
        if session is None:
            raise PySparkRuntimeError(errorClass="NO_ACTIVE_SESSION", messageParameters=dict())

        max_rows = int(
            session.conf.get("spark.sql.pyspark.plotting.max_rows")  # type: ignore[arg-type]
        )
        pdf = sdf.limit(max_rows + 1).toPandas()

        self.partial = False
        if len(pdf) > max_rows:
            self.partial = True
            pdf = pdf.iloc[:max_rows]

        return pdf


class PySparkSampledPlotBase:
    def get_sampled(self, sdf: "DataFrame") -> "pd.DataFrame":
        from pyspark.sql import SparkSession

        session = SparkSession.getActiveSession()
        if session is None:
            raise PySparkRuntimeError(errorClass="NO_ACTIVE_SESSION", messageParameters=dict())

        sample_ratio = session.conf.get("spark.sql.pyspark.plotting.sample_ratio")
        max_rows = int(
            session.conf.get("spark.sql.pyspark.plotting.max_rows")  # type: ignore[arg-type]
        )

        if sample_ratio is None:
            fraction = 1 / (sdf.count() / max_rows)
            fraction = min(1.0, fraction)
        else:
            fraction = float(sample_ratio)

        sampled_sdf = sdf.sample(fraction=fraction)
        pdf = sampled_sdf.toPandas()

        return pdf


class PySparkPlotAccessor:
    plot_data_map = {
        "line": PySparkSampledPlotBase().get_sampled,
    }
    _backends = {}  # type: ignore[var-annotated]

    def __init__(self, data: "DataFrame"):
        self.data = data

    def __call__(
        self, kind: str = "line", backend: Optional[str] = None, **kwargs: Any
    ) -> "Figure":
        plot_backend = PySparkPlotAccessor._get_plot_backend(backend)

        return plot_backend.plot_pyspark(self.data, kind=kind, **kwargs)

    @staticmethod
    def _get_plot_backend(backend: Optional[str] = None) -> ModuleType:
        backend = backend or "plotly"

        if backend in PySparkPlotAccessor._backends:
            return PySparkPlotAccessor._backends[backend]

        if backend == "plotly":
            require_minimum_plotly_version()
        else:
            raise PySparkValueError(
                errorClass="UNSUPPORTED_PLOT_BACKEND",
                messageParameters={"backend": backend, "supported_backends": ", ".join(["plotly"])},
            )
        from pyspark.sql.plot import plotly as module

        return module

    def line(self, x: str, y: Union[str, list[str]], **kwargs: Any) -> "Figure":
        """
        Plot DataFrame as lines.

        Parameters
        ----------
        x : str
            Name of column to use for the horizontal axis.
        y : str or list of str
            Name(s) of the column(s) to use for the vertical axis. Multiple columns can be plotted.
        **kwargs : optional
            Additional keyword arguments.

        Returns
        -------
        :class:`plotly.graph_objs.Figure`

        Examples
        --------
        >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
        >>> columns = ["category", "int_val", "float_val"]
        >>> df = spark.createDataFrame(data, columns)
        >>> df.plot.line(x="category", y="int_val")  # doctest: +SKIP
        >>> df.plot.line(x="category", y=["int_val", "float_val"])  # doctest: +SKIP
        """
        return self(kind="line", x=x, y=y, **kwargs)
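
Two session configs drive `PySparkSampledPlotBase.get_sampled` above. When `spark.sql.pyspark.plotting.sample_ratio` is unset, the fraction works out to `min(1.0, max_rows / sdf.count())`; when set, it is used as-is. A usage sketch — the config keys are taken from the code above, while their default values are defined elsewhere in this PR and not shown here:

```python
# Cap the rows pulled to the driver for plotting when no explicit ratio is set.
spark.conf.set("spark.sql.pyspark.plotting.max_rows", "10000")

# Or force a fixed 5% sample for every sampled plot.
spark.conf.set("spark.sql.pyspark.plotting.sample_ratio", "0.05")

fig = sdf.plot.line(x="category", y="int_val")  # plots the sampled data
```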
30 changes: 30 additions & 0 deletions python/pyspark/sql/plot/plotly.py
@@ -0,0 +1,30 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import TYPE_CHECKING, Any

from pyspark.sql.plot import PySparkPlotAccessor

if TYPE_CHECKING:
    from pyspark.sql import DataFrame
    from plotly.graph_objs import Figure


def plot_pyspark(data: "DataFrame", kind: str, **kwargs: Any) -> "Figure":
    import plotly

    return plotly.plot(PySparkPlotAccessor.plot_data_map[kind](data), kind, **kwargs)
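
Here `plotly.plot` is the module-level entry point plotly exposes as a pandas-style plotting backend. For `kind="line"` the call is therefore roughly equivalent to the hand-written sketch below — an illustration under the assumption that plotly routes `kind="line"` to `plotly.express.line`, not the committed code path:

```python
import plotly.express as px

from pyspark.sql.plot import PySparkPlotAccessor

# First reduce the Spark DataFrame to pandas via the registered strategy
# (sampling, for line plots), then build the figure as plotly's backend would.
pdf = PySparkPlotAccessor.plot_data_map["line"](sdf)
fig = px.line(pdf, x="category", y="int_val")
fig.show()
```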
36 changes: 36 additions & 0 deletions python/pyspark/sql/tests/connect/test_parity_frame_plot.py
@@ -0,0 +1,36 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.sql.tests.plot.test_frame_plot import DataFramePlotTestsMixin


class FramePlotParityTests(DataFramePlotTestsMixin, ReusedConnectTestCase):
    pass


if __name__ == "__main__":
    import unittest
    from pyspark.sql.tests.connect.test_parity_frame_plot import *  # noqa: F401

    try:
        import xmlrunner  # type: ignore[import]

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)