Skip to content

Commit

Permalink
refactor new trendlines
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolaskruchten committed Aug 2, 2021
1 parent 28f8e1f commit b0ed4cf
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 29 deletions.
6 changes: 4 additions & 2 deletions packages/python/plotly/plotly/express/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import plotly.io as pio
from collections import namedtuple, OrderedDict
from ._special_inputs import IdentityMap, Constant, Range
from .trendline_functions import ols, lowess, ma, ewma
from .trendline_functions import ols, lowess, rolling, expanding, ewm

from _plotly_utils.basevalidators import ColorscaleValidator
from plotly.colors import qualitative, sequential
Expand All @@ -17,7 +17,9 @@
)

NO_COLOR = "px_no_color_constant"
trendline_functions = dict(lowess=lowess, ma=ma, ewma=ewma, ols=ols)
trendline_functions = dict(
lowess=lowess, rolling=rolling, ewm=ewm, expanding=expanding, ols=ols
)

# Declare all supported attributes, across all plot types
direct_attrables = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ def ols(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
fit_results.params[0],
)
elif not add_constant:
hover_header += "%s = %g * %s<br>" % (y_label, fit_results.params[0], x_label,)
hover_header += "%s = %g * %s<br>" % (y_label, fit_results.params[0], x_label)
else:
hover_header += "%s = %g<br>" % (y_label, fit_results.params[0],)
hover_header += "%s = %g<br>" % (y_label, fit_results.params[0])
hover_header += "R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
return y_out, hover_header, fit_results

Expand All @@ -91,27 +91,48 @@ def lowess(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
return y_out, hover_header, None


def ma(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
"""Moving Average (MA) trendline function
def _pandas(mode, trendline_options, x_raw, y, non_missing):
modes = dict(rolling="Rolling", ewm="Exponentially Weighted", expanding="Expanding")
function_name = trendline_options.pop("function", "mean")
function_args = trendline_options.pop("function_args", dict())
series = pd.Series(y, index=x_raw)
agg = getattr(series, mode) # e.g. series.rolling
agg_obj = agg(**trendline_options) # e.g. series.rolling(**opts)
function = getattr(agg_obj, function_name) # e.g. series.rolling(**opts).mean
y_out = function(**function_args) # e.g. series.rolling(**opts).mean(**opts)
y_out = y_out[non_missing]
hover_header = "<b>%s %s trendline</b><br><br>" % (modes[mode], function_name)
return y_out, hover_header, None


Requires `pandas` to be installed.
def rolling(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
"""Rolling trendline function
The `trendline_options` dict is passed as keyword arguments into the
`pandas.Series.rolling` function.
The value of the `function` key of the `trendline_options` dict is the function to
use (defaults to `mean`) and the value of the `function_args` key are taken to be
its arguments as a dict. The remainder of the `trendline_options` dict is passed as
keyword arguments into the `pandas.Series.rolling` function.
"""
y_out = pd.Series(y, index=x_raw).rolling(**trendline_options).mean()[non_missing]
hover_header = "<b>MA trendline</b><br><br>"
return y_out, hover_header, None
return _pandas("rolling", trendline_options, x_raw, y, non_missing)


def ewma(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
"""Exponentially Weighted Moving Average (EWMA) trendline function
def expanding(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
"""Expanding trendline function
Requires `pandas` to be installed.
The value of the `function` key of the `trendline_options` dict is the function to
use (defaults to `mean`) and the value of the `function_args` key are taken to be
its arguments as a dict. The remainder of the `trendline_options` dict is passed as
keyword arguments into the `pandas.Series.expanding` function.
"""
return _pandas("expanding", trendline_options, x_raw, y, non_missing)

The `trendline_options` dict is passed as keyword arguments into the
`pandas.Series.ewma` function.

def ewm(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
"""Exponentially weighted trendline function
The value of the `function` key of the `trendline_options` dict is the function to
use (defaults to `mean`) and the value of the `function_args` key are taken to be
its arguments as a dict. The remainder of the `trendline_options` dict is passed as
keyword arguments into the `pandas.Series.ewm` function.
"""
y_out = pd.Series(y, index=x_raw).ewm(**trendline_options).mean()[non_missing]
hover_header = "<b>EWMA trendline</b><br><br>"
return y_out, hover_header, None
return _pandas("ewm", trendline_options, x_raw, y, non_missing)
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
("ols", None),
("lowess", None),
("lowess", dict(frac=0.3)),
("ma", dict(window=2)),
("ewma", dict(alpha=0.5)),
("rolling", dict(window=2)),
("expanding", None),
("ewm", dict(alpha=0.5)),
],
)
def test_trendline_results_passthrough(mode, options):
Expand Down Expand Up @@ -48,8 +49,9 @@ def test_trendline_results_passthrough(mode, options):
("ols", None),
("lowess", None),
("lowess", dict(frac=0.3)),
("ma", dict(window=2)),
("ewma", dict(alpha=0.5)),
("rolling", dict(window=2)),
("expanding", None),
("ewm", dict(alpha=0.5)),
],
)
def test_trendline_enough_values(mode, options):
Expand Down Expand Up @@ -102,8 +104,9 @@ def test_trendline_enough_values(mode, options):
("ols", dict(add_constant=False, log_x=True, log_y=True)),
("lowess", None),
("lowess", dict(frac=0.3)),
("ma", dict(window=2)),
("ewma", dict(alpha=0.5)),
("rolling", dict(window=2)),
("expanding", None),
("ewm", dict(alpha=0.5)),
],
)
def test_trendline_nan_values(mode, options):
Expand Down Expand Up @@ -173,9 +176,10 @@ def test_ols_trendline_slopes():
("ols", None),
("lowess", None),
("lowess", dict(frac=0.3)),
("ma", dict(window=2)),
("ma", dict(window="10d")),
("ewma", dict(alpha=0.5)),
("rolling", dict(window=2)),
("rolling", dict(window="10d")),
("expanding", None),
("ewm", dict(alpha=0.5)),
],
)
def test_trendline_on_timeseries(mode, options):
Expand Down

0 comments on commit b0ed4cf

Please sign in to comment.