Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correctly detect number of results colors #789

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions deepforest/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ def plot_results(results,
savedir: optional path to save the figure. If None (default), the figure will be interactively plotted.
height: height of the image in pixels. Required if the geometry type is 'polygon'.
width: width of the image in pixels. Required if the geometry type is 'polygon'.
results_color (list): color of the results annotations as a tuple of RGB color, e.g. orange annotations is [245, 135, 66]
results_color (list or sv.ColorPalette): color of the results annotations as a tuple of RGB color (if a single color), e.g. orange annotations is [245, 135, 66], or an supervision.ColorPalette if multiple labels and specifying colors for each label
ground_truth_color (list): color of the ground truth annotations as a tuple of RGB color, e.g. blue annotations is [0, 165, 255]
thickness: thickness of the rectangle border line in px
basename: optional basename for the saved figure. If None (default), the basename will be extracted from the image path.
Expand All @@ -385,17 +385,30 @@ def plot_results(results,
None
"""
# Convert colors, check for multi-class labels
results_color_sv = sv.Color(results_color[0], results_color[1], results_color[2])
if isinstance(results_color, list) and len(results_color) == 3:
results_color_sv = sv.Color(results_color[0], results_color[1], results_color[2])
elif isinstance(results_color, sv.draw.color.ColorPalette):
results_color_sv = results_color
elif isinstance(results_color, list):
raise ValueError(
"results_color must be either a 3 item list containing RGB values or an sv.ColorPalette instance"
)
else:
raise TypeError(
"results_color must be either a list of RGB values or an sv.ColorPalette instance"
)

ground_truth_color_sv = sv.Color(ground_truth_color[0], ground_truth_color[1],
ground_truth_color[2])

num_labels = len(results.label.unique())
if num_labels > 1 and len(results_color) != num_labels:
if num_labels > 1 and len(results_color_sv) != num_labels:
warnings.warn(
"Multiple labels detected in the results. Each label will be plotted with a different color using a color ramp, results color argument is ignored."
"""Multiple labels detected in the results and results_color argument provides a single color.
Each label will be plotted with a different color using a built-in color ramp.
If you want to customize colors with multiple labels pass a supervision.ColorPalette object to results_color with the appropriate number of labels"""
)
if num_labels > 1:
results_color_sv = sv.ColorPalette.from_matplotlib('viridis', num_labels)
results_color_sv = sv.ColorPalette.from_matplotlib('viridis', num_labels)

# Read images
root_dir = results.root_dir
Expand Down