Skip to content

Commit

Permalink
Adds precision, recall, and f1 score to evaluate detections
Browse files Browse the repository at this point in the history
  • Loading branch information
danielgural committed Aug 8, 2024
1 parent 6380d25 commit eb19786
Showing 1 changed file with 79 additions and 10 deletions.
89 changes: 79 additions & 10 deletions fiftyone/utils/eval/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,26 +79,34 @@ def evaluate_detections(
object- and sample-level recording the results of the evaluation:
- True positive (TP), false positive (FP), and false negative (FN) counts
for each sample are saved in top-level fields of each sample::
as well as precision, recall, and F1-score, are saved in top-level
fields of each sample::
TP: sample.<eval_key>_tp
FP: sample.<eval_key>_fp
FN: sample.<eval_key>_fn
Precision: sample.<eval_key>_precision
Recall: sample.<eval_key>_recall
F1-Score: sample.<eval_key>_f1_score
In addition, when evaluating frame-level objects, TP/FP/FN counts are
recorded for each frame::
In addition, when evaluating frame-level objects, TP/FP/FN counts and
precision/recall/F1-score are recorded for each frame::
TP: frame.<eval_key>_tp
FP: frame.<eval_key>_fp
FN: frame.<eval_key>_fn
Precision: frame.<eval_key>_precision
Recall: frame.<eval_key>_recall
F1-Score: frame.<eval_key>_f1_score
- The fields listed below are populated on each individual object; these
fields tabulate the TP/FP/FN status of the object, the ID of the
matching object (if any), and the matching IoU::
fields tabulate the TP/FP/FN status of the object, the precision/recall/F1 scores,
the ID of the matching object (if any), and the matching IoU::
TP/FP/FN: object.<eval_key>
ID: object.<eval_key>_id
IoU: object.<eval_key>_iou
TP/FP/FN: object.<eval_key>
Precision/Recall/F1-score: object.<eval_key>
ID: object.<eval_key>_id
IoU: object.<eval_key>_iou
Args:
samples: a :class:`fiftyone.core.collections.SampleCollection`
Expand Down Expand Up @@ -180,6 +188,9 @@ def evaluate_detections(
tp_field = "%s_tp" % eval_key
fp_field = "%s_fp" % eval_key
fn_field = "%s_fn" % eval_key
precision_field = "%s_precision" % eval_key
recall_field = "%s_recall" % eval_key
f1_field = "%s_f1_score" % eval_key

if config.requires_additional_fields:
_samples = samples
Expand Down Expand Up @@ -209,12 +220,41 @@ def evaluate_detections(
doc[tp_field] = tp
doc[fp_field] = fp
doc[fn_field] = fn
# Derive precision, recall, and F1 for this `doc` from its TP/FP/FN
# counts. Each ratio falls back to 0 when its denominator is zero so
# that no ZeroDivisionError is raised.
if tp + fp != 0:
    doc[precision_field] = tp / (tp + fp)
else:
    doc[precision_field] = 0
if tp + fn != 0:
    doc[recall_field] = tp / (tp + fn)
else:
    # Must be an assignment: a bare `doc[recall_field]` expression stores
    # nothing, leaving the field unset for the F1 computation below.
    doc[recall_field] = 0
if doc[precision_field] + doc[recall_field] != 0:
    # F1 is the harmonic mean of precision and recall
    doc[f1_field] = (
        2 * doc[precision_field] * doc[recall_field]
    ) / (doc[precision_field] + doc[recall_field])
else:
    doc[f1_field] = 0

if save:
sample[tp_field] = sample_tp
sample[fp_field] = sample_fp
sample[fn_field] = sample_fn

if sample_tp + sample_fp != 0:
sample[precision_field] = sample_tp / (sample_tp + sample_fp)
else:
sample[precision_field] = 0
if (sample_tp + sample_fn) != 0:
sample[recall_field] = sample_tp / (sample_tp + sample_fn)
else:
sample[recall_field] = 0
if (sample[precision_field] + sample[recall_field]) != 0:
sample[f1_field] = (
2 * sample[precision_field] * sample[recall_field]
) / (sample[precision_field] + sample[recall_field])
else:
sample[f1_field] = 0

results = eval_method.generate_results(
samples,
matches,
Expand Down Expand Up @@ -310,15 +350,24 @@ def register_samples(self, samples, eval_key, dynamic=True):
tp_field = "%s_tp" % eval_key
fp_field = "%s_fp" % eval_key
fn_field = "%s_fn" % eval_key
precision_field = "%s_precision" % eval_key
recall_field = "%s_recall" % eval_key
f1_field = "%s_f1_score" % eval_key

dataset.add_sample_field(tp_field, fof.IntField)
dataset.add_sample_field(fp_field, fof.IntField)
dataset.add_sample_field(fn_field, fof.IntField)
dataset.add_sample_field(precision_field, fof.FloatField)
dataset.add_sample_field(recall_field, fof.FloatField)
dataset.add_sample_field(f1_field, fof.FloatField)

if processing_frames:
dataset.add_frame_field(tp_field, fof.IntField)
dataset.add_frame_field(fp_field, fof.IntField)
dataset.add_frame_field(fn_field, fof.IntField)
dataset.add_frame_field(precision_field, fof.FloatField)
dataset.add_frame_field(recall_field, fof.FloatField)
dataset.add_frame_field(f1_field, fof.FloatField)

if not dynamic:
return
Expand Down Expand Up @@ -428,6 +477,9 @@ def get_fields(self, samples, eval_key):
"%s_tp" % eval_key,
"%s_fp" % eval_key,
"%s_fn" % eval_key,
"%s_precision" % eval_key,
"%s_recall" % eval_key,
"%s_f1_score" % eval_key,
pred_key,
"%s_id" % pred_key,
"%s_iou" % pred_key,
Expand All @@ -439,7 +491,14 @@ def get_fields(self, samples, eval_key):
if samples._is_frame_field(gt_field):
prefix = samples._FRAMES_PREFIX + eval_key
fields.extend(
["%s_tp" % prefix, "%s_fp" % prefix, "%s_fn" % prefix]
[
"%s_tp" % prefix,
"%s_fp" % prefix,
"%s_fn" % prefix,
"%s_precision" % prefix,
"%s_recall" % prefix,
"%s_f1_score" % prefix,
]
)

return fields
Expand Down Expand Up @@ -470,6 +529,9 @@ def cleanup(self, samples, eval_key):
"%s_tp" % eval_key,
"%s_fp" % eval_key,
"%s_fn" % eval_key,
"%s_precision" % eval_key,
"%s_recall" % eval_key,
"%s_f1_score" % eval_key,
]

try:
Expand Down Expand Up @@ -500,7 +562,14 @@ def cleanup(self, samples, eval_key):

if dataset._is_frame_field(self.config.pred_field):
dataset.delete_sample_fields(
["%s_tp" % eval_key, "%s_fp" % eval_key, "%s_fn" % eval_key],
[
"%s_tp" % eval_key,
"%s_fp" % eval_key,
"%s_fn" % eval_key,
"%s_precision" % eval_key,
"%s_recall" % eval_key,
"%s_f1_score" % eval_key,
],
error_level=1,
)
dataset.delete_frame_fields(fields, error_level=1)
Expand Down

0 comments on commit eb19786

Please sign in to comment.