# [SPARK-49531][PYTHON][CONNECT] Support line plot with plotly backend
### What changes were proposed in this pull request?

Support line plot with plotly backend on both Spark Connect and Spark classic.

### Why are the changes needed?

While Pandas on Spark supports plotting, PySpark currently lacks this feature. The proposed API will enable users to generate visualizations, such as line plots, by leveraging libraries like Plotly. This will provide users with an intuitive, interactive way to explore and understand large datasets directly from PySpark DataFrames, streamlining the data analysis workflow in distributed environments. See more at the in-progress [PySpark Plotting API Specification](https://docs.google.com/document/d/1IjOEzC8zcetG86WDvqkereQPj_NGLNW7Bdu910g30Dg/edit?usp=sharing).

Part of https://issues.apache.org/jira/browse/SPARK-49530.

### Does this PR introduce _any_ user-facing change?

Yes.

```python
>>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
>>> columns = ["category", "int_val", "float_val"]
>>> sdf = spark.createDataFrame(data, columns)
>>> sdf.show()
+--------+-------+---------+
|category|int_val|float_val|
+--------+-------+---------+
|       A|     10|      1.5|
|       B|     30|      2.5|
|       C|     20|      3.5|
+--------+-------+---------+
>>> f = sdf.plot(kind="line", x="category", y="int_val")
>>> f.show()  # see below
>>> g = sdf.plot.line(x="category", y=["int_val", "float_val"])
>>> g.show()  # see below
```

`f.show()`:
![newplot](https://github.com/user-attachments/assets/ebd50bbc-0dd1-437f-ae0c-0b4de8f3c722)

`g.show()`:
![newplot (1)](https://github.com/user-attachments/assets/46d28840-a147-428f-8d88-d424aa76ad06)

### How was this patch tested?

Unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #48139 from xinrong-meng/plot_line_w_dep.

Authored-by: Xinrong Meng <xinrong@apache.org>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
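The two call forms above (`sdf.plot(kind="line", ...)` and `sdf.plot.line(...)`) route through the same accessor object. A minimal pure-Python sketch of that dual-call pattern, with toy names and no Spark or plotly dependency (illustrative only, not the actual PySpark implementation):

```python
class PlotAccessor:
    """Toy accessor usable both as plot(kind=...) and plot.line(...)."""

    def __init__(self, data):
        self.data = data

    def __call__(self, kind="line", **kwargs):
        # In PySpark this step dispatches to the plotting backend
        # (plotly by default); here we just return a descriptive string.
        return f"{kind} plot of {self.data} with {sorted(kwargs)}"

    def line(self, x, y, **kwargs):
        # The named method is sugar over __call__ with kind="line".
        return self(kind="line", x=x, y=y, **kwargs)


# Both spellings produce the same result.
acc = PlotAccessor("sdf")
assert acc(kind="line", x="category", y="int_val") == acc.line(x="category", y="int_val")
```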
Commit `22a7edc` · 1 parent `3d8c078` · 21 changed files with 529 additions and 2 deletions.
New file, +21 lines (the `pyspark.sql.plot` package `__init__.py`, per its import):

```python
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This package includes the plotting APIs for PySpark DataFrame.
"""
from pyspark.sql.plot.core import *  # noqa: F403, F401
```
New file, +135 lines (the `pyspark.sql.plot.core` module, per the package import above):

```python
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, TYPE_CHECKING, Optional, Union
from types import ModuleType
from pyspark.errors import PySparkRuntimeError, PySparkValueError
from pyspark.sql.utils import require_minimum_plotly_version


if TYPE_CHECKING:
    from pyspark.sql import DataFrame
    import pandas as pd
    from plotly.graph_objs import Figure


class PySparkTopNPlotBase:
    def get_top_n(self, sdf: "DataFrame") -> "pd.DataFrame":
        from pyspark.sql import SparkSession

        session = SparkSession.getActiveSession()
        if session is None:
            raise PySparkRuntimeError(errorClass="NO_ACTIVE_SESSION", messageParameters=dict())

        max_rows = int(
            session.conf.get("spark.sql.pyspark.plotting.max_rows")  # type: ignore[arg-type]
        )
        pdf = sdf.limit(max_rows + 1).toPandas()

        self.partial = False
        if len(pdf) > max_rows:
            self.partial = True
            pdf = pdf.iloc[:max_rows]

        return pdf


class PySparkSampledPlotBase:
    def get_sampled(self, sdf: "DataFrame") -> "pd.DataFrame":
        from pyspark.sql import SparkSession

        session = SparkSession.getActiveSession()
        if session is None:
            raise PySparkRuntimeError(errorClass="NO_ACTIVE_SESSION", messageParameters=dict())

        sample_ratio = session.conf.get("spark.sql.pyspark.plotting.sample_ratio")
        max_rows = int(
            session.conf.get("spark.sql.pyspark.plotting.max_rows")  # type: ignore[arg-type]
        )

        if sample_ratio is None:
            fraction = 1 / (sdf.count() / max_rows)
            fraction = min(1.0, fraction)
        else:
            fraction = float(sample_ratio)

        sampled_sdf = sdf.sample(fraction=fraction)
        pdf = sampled_sdf.toPandas()

        return pdf


class PySparkPlotAccessor:
    plot_data_map = {
        "line": PySparkSampledPlotBase().get_sampled,
    }
    _backends = {}  # type: ignore[var-annotated]

    def __init__(self, data: "DataFrame"):
        self.data = data

    def __call__(
        self, kind: str = "line", backend: Optional[str] = None, **kwargs: Any
    ) -> "Figure":
        plot_backend = PySparkPlotAccessor._get_plot_backend(backend)

        return plot_backend.plot_pyspark(self.data, kind=kind, **kwargs)

    @staticmethod
    def _get_plot_backend(backend: Optional[str] = None) -> ModuleType:
        backend = backend or "plotly"

        if backend in PySparkPlotAccessor._backends:
            return PySparkPlotAccessor._backends[backend]

        if backend == "plotly":
            require_minimum_plotly_version()
        else:
            raise PySparkValueError(
                errorClass="UNSUPPORTED_PLOT_BACKEND",
                messageParameters={"backend": backend, "supported_backends": ", ".join(["plotly"])},
            )
        from pyspark.sql.plot import plotly as module

        return module

    def line(self, x: str, y: Union[str, list[str]], **kwargs: Any) -> "Figure":
        """
        Plot DataFrame as lines.

        Parameters
        ----------
        x : str
            Name of column to use for the horizontal axis.
        y : str or list of str
            Name(s) of the column(s) to use for the vertical axis. Multiple columns can be plotted.
        **kwargs : optional
            Additional keyword arguments.

        Returns
        -------
        :class:`plotly.graph_objs.Figure`

        Examples
        --------
        >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
        >>> columns = ["category", "int_val", "float_val"]
        >>> df = spark.createDataFrame(data, columns)
        >>> df.plot.line(x="category", y="int_val")  # doctest: +SKIP
        >>> df.plot.line(x="category", y=["int_val", "float_val"])  # doctest: +SKIP
        """
        return self(kind="line", x=x, y=y, **kwargs)
```
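The sampling logic in `PySparkSampledPlotBase.get_sampled` reduces to: use the configured `spark.sql.pyspark.plotting.sample_ratio` if set, otherwise `max_rows / count`, capped at 1.0. A standalone sketch of just that arithmetic, with a hypothetical helper name and no Spark session required:

```python
def plot_sample_fraction(row_count: int, max_rows: int, sample_ratio=None) -> float:
    """Mirror of the fraction computation in get_sampled.

    Illustrative helper only, not part of the PySpark API.
    """
    if sample_ratio is not None:
        # An explicitly configured ratio wins.
        return float(sample_ratio)
    # 1 / (count / max_rows) == max_rows / count, capped at 1.0
    # so small DataFrames are kept whole.
    return min(1.0, 1 / (row_count / max_rows))


assert plot_sample_fraction(10_000, 1_000) == 0.1
assert plot_sample_fraction(500, 1_000) == 1.0          # small frames: no downsampling
assert plot_sample_fraction(10_000, 1_000, "0.25") == 0.25
```

Note that `sdf.sample(fraction=...)` is probabilistic, so the sampled row count only approximates `max_rows`.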
New file, +30 lines (the `pyspark.sql.plot.plotly` backend module, per the import in `_get_plot_backend`):

```python
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import TYPE_CHECKING, Any

from pyspark.sql.plot import PySparkPlotAccessor

if TYPE_CHECKING:
    from pyspark.sql import DataFrame
    from plotly.graph_objs import Figure


def plot_pyspark(data: "DataFrame", kind: str, **kwargs: Any) -> "Figure":
    import plotly

    return plotly.plot(PySparkPlotAccessor.plot_data_map[kind](data), kind, **kwargs)
```
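`plot_pyspark` above is a thin dispatcher: `plot_data_map[kind]` converts the Spark DataFrame to pandas, then plotly renders it. The registry pattern itself can be sketched without Spark or plotly; all names below are toy stand-ins:

```python
def prepare_line(data):
    # Stand-in for get_sampled(): returns data ready for the plotting library.
    return list(data)


# Toy kind -> preparer registry mirroring PySparkPlotAccessor.plot_data_map.
PLOT_DATA_MAP = {"line": prepare_line}


def plot(data, kind, **kwargs):
    try:
        preparer = PLOT_DATA_MAP[kind]
    except KeyError:
        # Unknown kinds fail fast, analogous to UNSUPPORTED_PLOT_BACKEND
        # (which guards backends rather than kinds in the real code).
        raise ValueError(f"unsupported plot kind: {kind!r}") from None
    # A real backend would return a Figure; a tuple stands in for it here.
    return ("figure", kind, preparer(data))


assert plot(range(3), "line") == ("figure", "line", [0, 1, 2])
```

Adding a new plot kind then only requires registering another preparer, which is how `plot_data_map` is set up to grow.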
New file, +36 lines: `python/pyspark/sql/tests/connect/test_parity_frame_plot.py`

```python
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.sql.tests.plot.test_frame_plot import DataFramePlotTestsMixin


class FramePlotParityTests(DataFramePlotTestsMixin, ReusedConnectTestCase):
    pass


if __name__ == "__main__":
    import unittest
    from pyspark.sql.tests.connect.test_parity_frame_plot import *  # noqa: F401

    try:
        import xmlrunner  # type: ignore[import]

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)
```