diff --git a/projects/LayoutLMv3/datasets/transforms/formatting.py b/projects/LayoutLMv3/datasets/transforms/formatting.py index 36b0f3ee0..7dd4911f6 100644 --- a/projects/LayoutLMv3/datasets/transforms/formatting.py +++ b/projects/LayoutLMv3/datasets/transforms/formatting.py @@ -98,8 +98,7 @@ def transform(self, results: dict) -> dict: for key in self.ser_keys: if key not in results: continue - value = to_tensor(results[key]) - inputs[key] = value + inputs[key] = to_tensor(results[key]) packed_results['inputs'] = inputs # pack `data_samples` @@ -107,13 +106,15 @@ def transform(self, results: dict) -> dict: for truncation_idx in range(truncation_number): data_sample = SERDataSample() gt_label = LabelData() - assert 'labels' in results, 'key `labels` not in results.' - value = to_tensor(results['labels'][truncation_idx]) - gt_label.item = value + if results.get('labels', None): + gt_label.item = to_tensor(results['labels'][truncation_idx]) data_sample.gt_label = gt_label meta = {} for key in self.meta_keys: - meta[key] = results[key] + if key == 'truncation_word_ids': + meta[key] = results[key][truncation_idx] + else: + meta[key] = results[key] data_sample.set_metainfo(meta) data_samples.append(data_sample) packed_results['data_samples'] = data_samples diff --git a/projects/LayoutLMv3/models/ser_postprocessor.py b/projects/LayoutLMv3/models/ser_postprocessor.py index 370173301..a70c2ae82 100644 --- a/projects/LayoutLMv3/models/ser_postprocessor.py +++ b/projects/LayoutLMv3/models/ser_postprocessor.py @@ -16,13 +16,10 @@ class SERPostprocessor(nn.Module): """PostProcessor for SER.""" - def __init__(self, - classes: Union[tuple, list], - ignore_index: int = -100) -> None: + def __init__(self, classes: Union[tuple, list]) -> None: super().__init__() self.other_label_name = find_other_label_name_of_biolabel(classes) self.id2biolabel = self._generate_id2biolabel_map(classes) - self.ignore_index = ignore_index self.softmax = nn.Softmax(dim=-1) def 
_generate_id2biolabel_map(self, classes: Union[tuple, list]) -> Dict: @@ -43,42 +40,62 @@ def _generate_id2biolabel_map(self, classes: Union[tuple, list]) -> Dict: def __call__(self, outputs: torch.Tensor, data_samples: Sequence[SERDataSample] ) -> Sequence[SERDataSample]: + # merge several truncation data_sample to one data_sample + assert all('truncation_word_ids' in d for d in data_samples), \ + 'The key `truncation_word_ids` should be specified ' \ + 'in PackSERInputs.' + truncation_word_ids = [] + for data_sample in data_samples: + truncation_word_ids.append(data_sample.pop('truncation_word_ids')) + merged_data_sample = copy.deepcopy(data_samples[0]) + merged_data_sample.set_metainfo( + dict(truncation_word_ids=truncation_word_ids)) + flattened_word_ids = [ + word_id for word_ids in truncation_word_ids for word_id in word_ids + ] + # convert outputs dim from (truncation_num, max_length, label_num) # to (truncation_num * max_length, label_num) outputs = outputs.cpu().detach() - truncation_num = outputs.size(0) outputs = torch.reshape(outputs, (-1, outputs.size(-1))) - # merge gt label ids from data_samples - gt_label_ids = [ - data_samples[truncation_idx].gt_label.item - for truncation_idx in range(truncation_num) - ] - gt_label_ids = torch.cat(gt_label_ids, dim=0).cpu().detach().numpy() # get pred label ids/scores from outputs probs = self.softmax(outputs) max_value, max_idx = torch.max(probs, -1) pred_label_ids = max_idx.numpy() pred_label_scores = max_value.numpy() - # select valid token and convert iid to biolabel - gt_biolabels = [ - self.id2biolabel[g] for (g, p) in zip(gt_label_ids, pred_label_ids) - if g != self.ignore_index - ] + + # determine whether it is an inference process + if 'item' in data_samples[0].gt_label: + # merge gt label ids from data_samples + gt_label_ids = [ + data_sample.gt_label.item for data_sample in data_samples + ] + gt_label_ids = torch.cat( + gt_label_ids, dim=0).cpu().detach().numpy() + gt_biolabels = [ + self.id2biolabel[g] 
+ for (w, g) in zip(flattened_word_ids, gt_label_ids) + if w is not None + ] + # update merged gt_label + merged_data_sample.gt_label.item = gt_biolabels + + # inference process does not have item in gt_label, + # so select valid token with flattened_word_ids + # rather than with gt_label_ids like official code. pred_biolabels = [ - self.id2biolabel[p] for (g, p) in zip(gt_label_ids, pred_label_ids) - if g != self.ignore_index + self.id2biolabel[p] + for (w, p) in zip(flattened_word_ids, pred_label_ids) + if w is not None ] pred_biolabel_scores = [ - s for (g, s) in zip(gt_label_ids, pred_label_scores) - if g != self.ignore_index + s for (w, s) in zip(flattened_word_ids, pred_label_scores) + if w is not None ] # record pred_label pred_label = LabelData() pred_label.item = pred_biolabels pred_label.score = pred_biolabel_scores - # merge several truncation data_sample to one data_sample - merged_data_sample = copy.deepcopy(data_samples[0]) merged_data_sample.pred_label = pred_label - # update merged gt_label - merged_data_sample.gt_label.item = gt_biolabels + return [merged_data_sample] diff --git a/projects/LayoutLMv3/visualization/ser_visualizer.py b/projects/LayoutLMv3/visualization/ser_visualizer.py index e2e2834df..f0cdc3707 100644 --- a/projects/LayoutLMv3/visualization/ser_visualizer.py +++ b/projects/LayoutLMv3/visualization/ser_visualizer.py @@ -91,19 +91,13 @@ def _draw_instances(self, image: np.ndarray, bboxes: Union[np.ndarray, line_width=self.line_width, alpha=self.alpha) - # draw gt/pred labels - if gt_labels is not None and pred_labels is not None: + areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0]) + scales = _get_adaptive_scales(areas) + positions = (bboxes[:, :2] + bboxes[:, 2:]) // 2 + + if gt_labels is not None: gt_tokens_biolabel = gt_labels.item gt_words_label = [] - pred_tokens_biolabel = pred_labels.item - pred_words_label = [] - - if 'score' in pred_labels: - pred_tokens_biolabel_score = pred_labels.score - 
pred_words_label_score = [] - else: - pred_tokens_biolabel_score = None - pred_words_label_score = None pre_word_id = None for idx, cur_word_id in enumerate(word_ids): @@ -112,29 +106,32 @@ def _draw_instances(self, image: np.ndarray, bboxes: Union[np.ndarray, gt_words_label_name = gt_tokens_biolabel[idx][2:] \ if gt_tokens_biolabel[idx] != 'O' else 'other' gt_words_label.append(gt_words_label_name) + pre_word_id = cur_word_id + assert len(gt_words_label) == len(bboxes) + if pred_labels is not None: + pred_tokens_biolabel = pred_labels.item + pred_words_label = [] + pred_tokens_biolabel_score = pred_labels.score + pred_words_label_score = [] + + pre_word_id = None + for idx, cur_word_id in enumerate(word_ids): + if cur_word_id is not None: + if cur_word_id != pre_word_id: pred_words_label_name = pred_tokens_biolabel[idx][2:] \ if pred_tokens_biolabel[idx] != 'O' else 'other' pred_words_label.append(pred_words_label_name) - if pred_tokens_biolabel_score is not None: - pred_words_label_score.append( - pred_tokens_biolabel_score[idx]) + pred_words_label_score.append( + pred_tokens_biolabel_score[idx]) pre_word_id = cur_word_id - assert len(gt_words_label) == len(bboxes) assert len(pred_words_label) == len(bboxes) - areas = (bboxes[:, 3] - bboxes[:, 1]) * ( - bboxes[:, 2] - bboxes[:, 0]) - scales = _get_adaptive_scales(areas) - positions = (bboxes[:, :2] + bboxes[:, 2:]) // 2 - + # draw gt or pred labels + if gt_labels is not None and pred_labels is not None: for i, (pos, gt, pred) in enumerate( zip(positions, gt_words_label, pred_words_label)): - if pred_words_label_score is not None: - score = round(float(pred_words_label_score[i]) * 100, 1) - label_text = f'{gt} | {pred}({score})' - else: - label_text = f'{gt} | {pred}' - + score = round(float(pred_words_label_score[i]) * 100, 1) + label_text = f'{gt} | {pred}({score})' self.draw_texts( label_text, pos, @@ -142,6 +139,27 @@ def _draw_instances(self, image: np.ndarray, bboxes: Union[np.ndarray, font_sizes=int(13 * 
scales[i]), vertical_alignments='center', horizontal_alignments='center') + elif pred_labels is not None: + for i, (pos, pred) in enumerate(zip(positions, pred_words_label)): + score = round(float(pred_words_label_score[i]) * 100, 1) + label_text = f'Pred: {pred}({score})' + self.draw_texts( + label_text, + pos, + colors=self.label_color, + font_sizes=int(13 * scales[i]), + vertical_alignments='center', + horizontal_alignments='center') + elif gt_labels is not None: + for i, (pos, gt) in enumerate(zip(positions, gt_words_label)): + label_text = f'GT: {gt}' + self.draw_texts( + label_text, + pos, + colors=self.label_color, + font_sizes=int(13 * scales[i]), + vertical_alignments='center', + horizontal_alignments='center') return self.get_image()