-
Notifications
You must be signed in to change notification settings - Fork 1.5k
/
instance_seg_metric.py
94 lines (77 loc) · 3.47 KB
/
instance_seg_metric.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Sequence
from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger
from mmdet3d.evaluation import instance_seg_eval
from mmdet3d.registry import METRICS
@METRICS.register_module()
class InstanceSegMetric(BaseMetric):
    """3D instance segmentation evaluation metric.

    Accumulates per-sample ground-truth annotations and CPU copies of the
    predicted instance masks during :meth:`process`, then delegates the
    actual evaluation to :func:`instance_seg_eval` in
    :meth:`compute_metrics`.

    Args:
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different
            evaluators. If prefix is not provided in the argument,
            self.default_prefix will be used instead. Defaults to None.
    """

    def __init__(self,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None,
                 **kwargs):
        # Python-3 style super() call (was the legacy two-argument form).
        # Extra **kwargs are accepted for config compatibility but are
        # deliberately not forwarded: BaseMetric.__init__ only takes
        # these two arguments.
        super().__init__(prefix=prefix, collect_device=collect_device)

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        """Process one batch of data samples and predictions.

        The processed results are stored in ``self.results``, which will be
        used to compute the metrics when all batches have been processed.

        Args:
            data_batch (dict): A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from
                the model.
        """
        for data_sample in data_samples:
            pred_3d = data_sample['pred_pts_seg']
            eval_ann_info = data_sample['eval_ann_info']
            # Move any tensor-like values (anything exposing ``.to``) to
            # CPU so the accumulated results do not pin GPU memory; plain
            # values are stored unchanged.
            cpu_pred_3d = dict()
            for k, v in pred_3d.items():
                if hasattr(v, 'to'):
                    cpu_pred_3d[k] = v.to('cpu')
                else:
                    cpu_pred_3d[k] = v
            self.results.append((eval_ann_info, cpu_pred_3d))

    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch, i.e. the
                ``(eval_ann_info, cpu_pred_3d)`` tuples appended by
                :meth:`process`.

        Returns:
            Dict[str, float]: The computed metrics. The keys are the names
            of the metrics, and the values are corresponding results.
        """
        logger: MMLogger = MMLogger.get_current_instance()

        self.classes = self.dataset_meta['CLASSES']
        self.valid_class_ids = self.dataset_meta['seg_valid_class_ids']

        gt_semantic_masks = []
        gt_instance_masks = []
        pred_instance_masks = []
        pred_instance_labels = []
        pred_instance_scores = []

        # NOTE: fixed local-variable typo ('sinlge' -> 'single').
        for eval_ann, single_pred_results in results:
            gt_semantic_masks.append(eval_ann['pts_semantic_mask'])
            gt_instance_masks.append(eval_ann['pts_instance_mask'])
            pred_instance_masks.append(
                single_pred_results['pts_instance_mask'])
            pred_instance_labels.append(single_pred_results['instance_labels'])
            pred_instance_scores.append(single_pred_results['instance_scores'])

        ret_dict = instance_seg_eval(
            gt_semantic_masks,
            gt_instance_masks,
            pred_instance_masks,
            pred_instance_labels,
            pred_instance_scores,
            valid_class_ids=self.valid_class_ids,
            class_labels=self.classes,
            logger=logger)

        return ret_dict