Skip to content

Commit

Permalink
Correctly detect number of results colors (#789)
Browse files Browse the repository at this point in the history
We were checking results_color, which will typically be a length 3 list.
This would cause incorrect behavior in the presence of multiple labels.
This switches to always creating results_color_sv upfront and checking
it.

Also adds some type/value checks to the input.
  • Loading branch information
ethanwhite committed Sep 25, 2024
1 parent aac610d commit c4464d2
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions deepforest/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ def plot_results(results,
savedir: optional path to save the figure. If None (default), the figure will be interactively plotted.
height: height of the image in pixels. Required if the geometry type is 'polygon'.
width: width of the image in pixels. Required if the geometry type is 'polygon'.
results_color (list): color of the results annotations as a tuple of RGB color, e.g. orange annotations is [245, 135, 66]
results_color (list or sv.ColorPalette): color of the results annotations as a tuple of RGB color (if a single color), e.g. orange annotations is [245, 135, 66], or an supervision.ColorPalette if multiple labels and specifying colors for each label
ground_truth_color (list): color of the ground truth annotations as a tuple of RGB color, e.g. blue annotations is [0, 165, 255]
thickness: thickness of the rectangle border line in px
basename: optional basename for the saved figure. If None (default), the basename will be extracted from the image path.
Expand All @@ -385,17 +385,30 @@ def plot_results(results,
None
"""
# Convert colors, check for multi-class labels
results_color_sv = sv.Color(results_color[0], results_color[1], results_color[2])
if isinstance(results_color, list) and len(results_color) == 3:
results_color_sv = sv.Color(results_color[0], results_color[1], results_color[2])
elif isinstance(results_color, sv.draw.color.ColorPalette):
results_color_sv = results_color
elif isinstance(results_color, list):
raise ValueError(
"results_color must be either a 3 item list containing RGB values or an sv.ColorPalette instance"
)
else:
raise TypeError(
"results_color must be either a list of RGB values or an sv.ColorPalette instance"
)

ground_truth_color_sv = sv.Color(ground_truth_color[0], ground_truth_color[1],
ground_truth_color[2])

num_labels = len(results.label.unique())
if num_labels > 1 and len(results_color) != num_labels:
if num_labels > 1 and len(results_color_sv) != num_labels:
warnings.warn(
"Multiple labels detected in the results. Each label will be plotted with a different color using a color ramp, results color argument is ignored."
"""Multiple labels detected in the results and results_color argument provides a single color.
Each label will be plotted with a different color using a built-in color ramp.
If you want to customize colors with multiple labels pass a supervision.ColorPalette object to results_color with the appropriate number of labels"""
)
if num_labels > 1:
results_color_sv = sv.ColorPalette.from_matplotlib('viridis', num_labels)
results_color_sv = sv.ColorPalette.from_matplotlib('viridis', num_labels)

# Read images
root_dir = results.root_dir
Expand Down

0 comments on commit c4464d2

Please sign in to comment.