Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sunburst improvements #2133

Merged
merged 2 commits into from
Feb 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions packages/python/plotly/plotly/express/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,7 @@ def build_dataframe(args, attrables, array_attrables):
def _check_dataframe_all_leaves(df):
df_sorted = df.sort_values(by=list(df.columns))
null_mask = df_sorted.isnull()
df_sorted = df_sorted.astype(str)
null_indices = np.nonzero(null_mask.any(axis=1).values)[0]
for null_row_index in null_indices:
row = null_mask.iloc[null_row_index]
Expand Down Expand Up @@ -1043,8 +1044,9 @@ def process_dataframe_hierarchy(args):

if args["color"] and args["color"] in path:
series_to_copy = df[args["color"]]
args["color"] = str(args["color"]) + "additional_col_for_px"
df[args["color"]] = series_to_copy
new_col_name = args["color"] + "additional_col_for_color"
path = [new_col_name if x == args["color"] else x for x in path]
df[new_col_name] = series_to_copy
if args["hover_data"]:
for col_name in args["hover_data"]:
if col_name == args["color"]:
Expand Down Expand Up @@ -1147,6 +1149,11 @@ def aggfunc_continuous(x):
args["ids"] = "id"
args["names"] = "labels"
args["parents"] = "parent"
if args["color"]:
if not args["hover_data"]:
args["hover_data"] = [args["color"]]
else:
args["hover_data"].append(args["color"])
return args


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,22 @@ def test_sunburst_treemap_with_path_color():
# Hover info
df["hover"] = [el.lower() for el in vendors]
fig = px.sunburst(df, path=path, color="calls", hover_data=["hover"])
custom = fig.data[0].customdata.ravel()
assert np.all(custom[:8] == df["hover"])
assert np.all(custom[8:] == "(?)")
custom = fig.data[0].customdata
assert np.all(custom[:8, 0] == df["hover"])
assert np.all(custom[8:, 0] == "(?)")
assert np.all(custom[:8, 1] == df["calls"])

# Discrete color
fig = px.sunburst(df, path=path, color="vendors")
assert len(np.unique(fig.data[0].marker.colors)) == 9

# Numerical column in path
df["regions"] = df["regions"].map({"North": 1, "South": 2})
path = ["total", "regions", "sectors", "vendors"]
fig = px.sunburst(df, path=path, values="values", color="calls")
colors = fig.data[0].marker.colors
assert np.all(np.array(colors[:8]) == np.array(calls))


def test_sunburst_treemap_with_path_non_rectangular():
vendors = ["A", "B", "C", "D", None, "E", "F", "G", "H", None]
Expand Down