Skip to content

Commit

Permalink
A plot_results function that works across predict* functions and mult…
Browse files Browse the repository at this point in the history
…iple annotation types. (#768)

Change viz function to implement plot_results, 

Correctly detect number of results colors (#789)
Add geometry column to the test assert 
Also adds some type/value checks to the input.
Remove plotting from predict functions, closes #761 and update docs
Style changes and update image paths
Increase viz sizes

---------

Co-authored-by: Ethan White <ethan@weecology.org>
  • Loading branch information
bw4sz and ethanwhite committed Sep 25, 2024
1 parent 46bd2fe commit e14bc6d
Show file tree
Hide file tree
Showing 15 changed files with 452 additions and 266 deletions.
67 changes: 51 additions & 16 deletions deepforest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,16 +338,20 @@ def predict_image(self,
thickness: int = 1):
"""Predict a single image with a deepforest model.
Deprecation warning: The 'return_plot', and related 'color' and 'thickness' arguments are deprecated and will be removed in 2.0. Use visualize.plot_results on the result instead.
Args:
image: a float32 numpy array of a RGB with channels last format
path: optional path to read image from disk instead of passing image arg
return_plot: Return image with plotted detections
color: color of the bounding box as a tuple of BGR color, e.g. orange annotations is (0, 165, 255)
thickness: thickness of the rectangle border line in px
(deprecated) return_plot: return a plot of the image with predictions overlaid
(deprecated) color: color of the bounding box as a tuple of BGR color, e.g. orange annotations is (0, 165, 255)
(deprectaed) thickness: thickness of the rectangle border line in px
Returns:
result: A pandas dataframe of predictions (Default)
img: The input with predictions overlaid (Optional)
"""

# Ensure we are in eval mode
self.model.eval()

Expand Down Expand Up @@ -379,6 +383,11 @@ def predict_image(self,
color=color)

if return_plot:
# Add deprecated warning
warnings.warn(
"return_plot is deprecated and will be removed in 2.0. Use visualize.plot_results on the result instead."
)

return result
else:
#If there were no predictions, return None
Expand All @@ -388,6 +397,14 @@ def predict_image(self,
result["label"] = result.label.apply(
lambda x: self.numeric_to_label_dict[x])

result = utilities.read_file(result)
if path is None:
warnings.warn(
"An image was passed directly to predict_image, the root_dir will be None in the output dataframe, to use visualize.plot_results, please assign results.root_dir = <directory name>"
)
else:
result.root_dir = os.path.dirname(path)

return result

def predict_file(self, csv_file, root_dir, savedir=None, color=None, thickness=1):
Expand All @@ -397,15 +414,19 @@ def predict_file(self, csv_file, root_dir, savedir=None, color=None, thickness=1
relative filename, not absolute path, which is in the root_dir
directory. One bounding box per line.
Deprecation warning: The return_plot argument is deprecated and will be removed in 2.0. Use visualize.plot_results on the result instead.
Args:
csv_file: path to csv file
root_dir: directory of images. If none, uses "image_dir" in config
savedir: Optional. Directory to save image plots.
color: color of the bounding box as a tuple of BGR color, e.g. orange annotations is (0, 165, 255)
thickness: thickness of the rectangle border line in px
(deprecated) savedir: directory to save images with bounding boxes
(deprecated) color: color of the bounding box as a tuple of BGR color, e.g. orange annotations is (0, 165, 255)
(deprecated) thickness: thickness of the rectangle border line in px
Returns:
df: pandas dataframe with bounding boxes, label and scores for each image in the csv file
"""

df = utilities.read_file(csv_file)
ds = dataset.TreeDataset(csv_file=csv_file,
root_dir=root_dir,
Expand All @@ -419,10 +440,12 @@ def predict_file(self, csv_file, root_dir, savedir=None, color=None, thickness=1
dataloader=dataloader,
root_dir=root_dir,
nms_thresh=self.config["nms_thresh"],
savedir=savedir,
color=color,
savedir=savedir,
thickness=thickness)

results.root_dir = root_dir

return results

def predict_tile(self,
Expand Down Expand Up @@ -453,15 +476,17 @@ def predict_tile(self,
iou_threshold: Minimum iou overlap among predictions between
windows to be suppressed.
Lower values suppress more boxes at edges.
return_plot: Should the image be returned with the predictions drawn?
mosaic: Return a single prediction dataframe (True) or a tuple of image crops and predictions (False)
sigma: variance of Gaussian function used in Gaussian Soft NMS
thresh: the score thresh used to filter bboxes after soft-nms performed
color: color of the bounding box as a tuple of BGR color, e.g. orange annotations is (0, 165, 255)
thickness: thickness of the rectangle border line in px
cropModel: a deepforest.model.CropModel object to predict on crops
crop_transform: a torchvision.transforms object to apply to crops
crop_augment: a boolean to apply augmentations to crops
(deprecated) return_plot: return a plot of the image with predictions overlaid
(deprecated) color: color of the bounding box as a tuple of BGR color, e.g. orange annotations is (0, 165, 255)
(deprecated) thickness: thickness of the rectangle border line in px
Deprecation: The return_plot argument is deprecated and will be removed in 2.0. Use visualize.plot_results on the result instead.
Returns:
boxes (array): if return_plot, an image.
Expand Down Expand Up @@ -511,6 +536,10 @@ def predict_tile(self,
if raster_path:
results["image_path"] = os.path.basename(raster_path)
if return_plot:
# Add deprecated warning
warnings.warn(
"return_plot is deprecated and will be removed in 2.0. Use visualize.plot_results on the result instead."
)
# Draw predictions on BGR
if raster_path:
tile = rio.open(raster_path).read()
Expand Down Expand Up @@ -543,6 +572,15 @@ def predict_tile(self,
transform=crop_transform,
augment=crop_augment)

results = utilities.read_file(results)

if raster_path is None:
warnings.warn(
"An image was passed directly to predict_tile, the root_dir will be None in the output dataframe, to use visualize.plot_results, please assign results.root_dir = <directory name>"
)
else:
results.root_dir = os.path.dirname(raster_path)

return results

def training_step(self, batch, batch_idx):
Expand Down Expand Up @@ -740,22 +778,19 @@ def evaluate(self, csv_file, root_dir, iou_threshold=None, savedir=None):
csv_file: location of a csv file with columns "name","xmin","ymin","xmax","ymax","label", each box in a row
root_dir: location of files in the dataframe 'name' column.
iou_threshold: float [0,1] intersection-over-union union between annotation and prediction to be scored true positive
savedir: optional path dir to save evaluation images
savedir: location to save images with bounding boxes
Returns:
results: dict of ("results", "precision", "recall") for a given threshold
"""
ground_df = utilities.read_file(csv_file)
ground_df["label"] = ground_df.label.apply(lambda x: self.label_dict[x])
predictions = self.predict_file(csv_file=csv_file,
root_dir=root_dir,
savedir=savedir)
predictions = self.predict_file(csv_file=csv_file, root_dir=root_dir)

results = evaluate_iou.__evaluate_wrapper__(
predictions=predictions,
ground_df=ground_df,
root_dir=root_dir,
iou_threshold=iou_threshold,
numeric_to_label_dict=self.numeric_to_label_dict,
savedir=savedir)
numeric_to_label_dict=self.numeric_to_label_dict)

return results
3 changes: 3 additions & 0 deletions deepforest/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,9 @@ def read_file(input, root_dir=None):
# convert to geodataframe
df = gpd.GeoDataFrame(df, geometry='geometry')

# If root_dir is specified, add as attribute
df.root_dir = root_dir

return df


Expand Down
Loading

0 comments on commit e14bc6d

Please sign in to comment.