Commit

Fix bugs in ser_postprocessor and ser_visualizer for the inference stage, where no gt_label is available.
KevinNuNu committed May 25, 2023
1 parent d9a3a5e commit b04e126
Showing 3 changed files with 92 additions and 56 deletions.
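
In short, the fix targets the difference between training and inference samples: during training each SERDataSample carries a gt_label whose `item` field holds the label ids, while at inference nothing is attached, so code that unconditionally reads gt_label.item (or filters tokens by gt label ids) breaks. A minimal standalone sketch of that distinction, using mmengine's LabelData the same way the diffs below do (this is an illustration, not code from the commit):

import torch
from mmengine.structures import LabelData

train_gt = LabelData()
train_gt.item = torch.tensor([0, 1, 2])  # training: label ids are packed

infer_gt = LabelData()                   # inference: nothing is attached

print('item' in train_gt)  # True  -> evaluation branch can use gt ids
print('item' in infer_gt)  # False -> inference-only branch is taken
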
13 changes: 7 additions & 6 deletions projects/LayoutLMv3/datasets/transforms/formatting.py
@@ -98,22 +98,23 @@ def transform(self, results: dict) -> dict:
         for key in self.ser_keys:
             if key not in results:
                 continue
-            value = to_tensor(results[key])
-            inputs[key] = value
+            inputs[key] = to_tensor(results[key])
         packed_results['inputs'] = inputs
 
         # pack `data_samples`
         data_samples = []
         for truncation_idx in range(truncation_number):
             data_sample = SERDataSample()
             gt_label = LabelData()
-            assert 'labels' in results, 'key `labels` not in results.'
-            value = to_tensor(results['labels'][truncation_idx])
-            gt_label.item = value
+            if results.get('labels', None):
+                gt_label.item = to_tensor(results['labels'][truncation_idx])
             data_sample.gt_label = gt_label
             meta = {}
             for key in self.meta_keys:
-                meta[key] = results[key]
+                if key == 'truncation_word_ids':
+                    meta[key] = results[key][truncation_idx]
+                else:
+                    meta[key] = results[key]
             data_sample.set_metainfo(meta)
             data_samples.append(data_sample)
         packed_results['data_samples'] = data_samples
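The second change above takes only the current truncation's slice of `truncation_word_ids` instead of copying the whole value into every data sample's metainfo. A small standalone sketch of that packing rule (plain dicts, made-up values; not code from the repository):

def pack_meta(results: dict, meta_keys: tuple, truncation_idx: int) -> dict:
    meta = {}
    for key in meta_keys:
        if key == 'truncation_word_ids':
            # one list of word ids per truncated sequence
            meta[key] = results[key][truncation_idx]
        else:
            meta[key] = results[key]
    return meta

results = {
    'img_path': 'demo.png',
    'truncation_word_ids': [[None, 0, 0, 1, None], [None, 1, 2, None]],
}
print(pack_meta(results, ('img_path', 'truncation_word_ids'), 1))
# {'img_path': 'demo.png', 'truncation_word_ids': [None, 1, 2, None]}
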
65 changes: 41 additions & 24 deletions projects/LayoutLMv3/models/ser_postprocessor.py
@@ -16,13 +16,10 @@
 class SERPostprocessor(nn.Module):
     """PostProcessor for SER."""
 
-    def __init__(self,
-                 classes: Union[tuple, list],
-                 ignore_index: int = -100) -> None:
+    def __init__(self, classes: Union[tuple, list]) -> None:
         super().__init__()
         self.other_label_name = find_other_label_name_of_biolabel(classes)
         self.id2biolabel = self._generate_id2biolabel_map(classes)
-        self.ignore_index = ignore_index
         self.softmax = nn.Softmax(dim=-1)
 
     def _generate_id2biolabel_map(self, classes: Union[tuple, list]) -> Dict:
@@ -43,42 +40,62 @@ def _generate_id2biolabel_map(self, classes: Union[tuple, list]) -> Dict:
     def __call__(self, outputs: torch.Tensor,
                  data_samples: Sequence[SERDataSample]
                  ) -> Sequence[SERDataSample]:
+        # merge several truncation data_sample to one data_sample
+        assert all('truncation_word_ids' in d for d in data_samples), \
+            'The key `truncation_word_ids` should be specified' \
+            'in PackSERInputs.'
+        truncation_word_ids = []
+        for data_sample in data_samples:
+            truncation_word_ids.append(data_sample.pop('truncation_word_ids'))
+        merged_data_sample = copy.deepcopy(data_samples[0])
+        merged_data_sample.set_metainfo(
+            dict(truncation_word_ids=truncation_word_ids))
+        flattened_word_ids = [
+            word_id for word_ids in truncation_word_ids for word_id in word_ids
+        ]
+
         # convert outputs dim from (truncation_num, max_length, label_num)
         # to (truncation_num * max_length, label_num)
         outputs = outputs.cpu().detach()
-        truncation_num = outputs.size(0)
         outputs = torch.reshape(outputs, (-1, outputs.size(-1)))
-        # merge gt label ids from data_samples
-        gt_label_ids = [
-            data_samples[truncation_idx].gt_label.item
-            for truncation_idx in range(truncation_num)
-        ]
-        gt_label_ids = torch.cat(gt_label_ids, dim=0).cpu().detach().numpy()
         # get pred label ids/scores from outputs
         probs = self.softmax(outputs)
         max_value, max_idx = torch.max(probs, -1)
         pred_label_ids = max_idx.numpy()
         pred_label_scores = max_value.numpy()
-        # select valid token and convert iid to biolabel
-        gt_biolabels = [
-            self.id2biolabel[g] for (g, p) in zip(gt_label_ids, pred_label_ids)
-            if g != self.ignore_index
-        ]
+
+        # determine whether it is an inference process
+        if 'item' in data_samples[0].gt_label:
+            # merge gt label ids from data_samples
+            gt_label_ids = [
+                data_sample.gt_label.item for data_sample in data_samples
+            ]
+            gt_label_ids = torch.cat(
+                gt_label_ids, dim=0).cpu().detach().numpy()
+            gt_biolabels = [
+                self.id2biolabel[g]
+                for (w, g) in zip(flattened_word_ids, gt_label_ids)
+                if w is not None
+            ]
+            # update merged gt_label
+            merged_data_sample.gt_label.item = gt_biolabels
+
+        # inference process do not have item in gt_label,
+        # so select valid token with flattened_word_ids
+        # rather than with gt_label_ids like official code.
         pred_biolabels = [
-            self.id2biolabel[p] for (g, p) in zip(gt_label_ids, pred_label_ids)
-            if g != self.ignore_index
+            self.id2biolabel[p]
+            for (w, p) in zip(flattened_word_ids, pred_label_ids)
+            if w is not None
         ]
         pred_biolabel_scores = [
-            s for (g, s) in zip(gt_label_ids, pred_label_scores)
-            if g != self.ignore_index
+            s for (w, s) in zip(flattened_word_ids, pred_label_scores)
+            if w is not None
         ]
         # record pred_label
         pred_label = LabelData()
         pred_label.item = pred_biolabels
         pred_label.score = pred_biolabel_scores
-        # merge several truncation data_sample to one data_sample
-        merged_data_sample = copy.deepcopy(data_samples[0])
         merged_data_sample.pred_label = pred_label
-        # update merged gt_label
-        merged_data_sample.gt_label.item = gt_biolabels
+
         return [merged_data_sample]
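
For reference, a standalone sketch of the selection rule introduced above (label names and tensor values are made up; not code from the repository): valid tokens are picked with the flattened word ids, where None marks special or padding tokens, so the same filter works whether or not ground-truth labels exist.

import torch

id2biolabel = {0: 'O', 1: 'B-header', 2: 'I-header'}

# word ids for one truncated sequence; None marks [CLS]/[SEP]/padding tokens
flattened_word_ids = [None, 0, 0, 1, None]
logits = torch.randn(len(flattened_word_ids), len(id2biolabel))

probs = torch.softmax(logits, dim=-1)
scores, pred_ids = torch.max(probs, dim=-1)

pred_biolabels = [
    id2biolabel[int(p)]
    for w, p in zip(flattened_word_ids, pred_ids)
    if w is not None
]
pred_scores = [
    float(s) for w, s in zip(flattened_word_ids, scores) if w is not None
]
print(pred_biolabels, pred_scores)  # only the three word tokens survive
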
70 changes: 44 additions & 26 deletions projects/LayoutLMv3/visualization/ser_visualizer.py
@@ -91,19 +91,13 @@ def _draw_instances(self, image: np.ndarray, bboxes: Union[np.ndarray,
                 line_width=self.line_width,
                 alpha=self.alpha)
 
-        # draw gt/pred labels
-        if gt_labels is not None and pred_labels is not None:
-            areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
-            scales = _get_adaptive_scales(areas)
-            positions = (bboxes[:, :2] + bboxes[:, 2:]) // 2
-
+        if gt_labels is not None:
             gt_tokens_biolabel = gt_labels.item
             gt_words_label = []
-            pred_tokens_biolabel = pred_labels.item
-            pred_words_label = []
-
-            if 'score' in pred_labels:
-                pred_tokens_biolabel_score = pred_labels.score
-                pred_words_label_score = []
-            else:
-                pred_tokens_biolabel_score = None
-                pred_words_label_score = None
 
             pre_word_id = None
             for idx, cur_word_id in enumerate(word_ids):
@@ -112,36 +106,60 @@ def _draw_instances(self, image: np.ndarray, bboxes: Union[np.ndarray,
                         gt_words_label_name = gt_tokens_biolabel[idx][2:] \
                             if gt_tokens_biolabel[idx] != 'O' else 'other'
                         gt_words_label.append(gt_words_label_name)
+                pre_word_id = cur_word_id
+            assert len(gt_words_label) == len(bboxes)
+        if pred_labels is not None:
+            pred_tokens_biolabel = pred_labels.item
+            pred_words_label = []
+            pred_tokens_biolabel_score = pred_labels.score
+            pred_words_label_score = []
+
+            pre_word_id = None
+            for idx, cur_word_id in enumerate(word_ids):
+                if cur_word_id is not None:
+                    if cur_word_id != pre_word_id:
                         pred_words_label_name = pred_tokens_biolabel[idx][2:] \
                             if pred_tokens_biolabel[idx] != 'O' else 'other'
                         pred_words_label.append(pred_words_label_name)
-                        if pred_tokens_biolabel_score is not None:
-                            pred_words_label_score.append(
-                                pred_tokens_biolabel_score[idx])
+                        pred_words_label_score.append(
+                            pred_tokens_biolabel_score[idx])
                 pre_word_id = cur_word_id
-            assert len(gt_words_label) == len(bboxes)
             assert len(pred_words_label) == len(bboxes)
 
+        areas = (bboxes[:, 3] - bboxes[:, 1]) * (
+            bboxes[:, 2] - bboxes[:, 0])
+        scales = _get_adaptive_scales(areas)
+        positions = (bboxes[:, :2] + bboxes[:, 2:]) // 2
+
+        # draw gt or pred labels
+        if gt_labels is not None and pred_labels is not None:
             for i, (pos, gt, pred) in enumerate(
                     zip(positions, gt_words_label, pred_words_label)):
-                if pred_words_label_score is not None:
-                    score = round(float(pred_words_label_score[i]) * 100, 1)
-                    label_text = f'{gt} | {pred}({score})'
-                else:
-                    label_text = f'{gt} | {pred}'
-
+                score = round(float(pred_words_label_score[i]) * 100, 1)
+                label_text = f'{gt} | {pred}({score})'
                 self.draw_texts(
                     label_text,
                     pos,
                     colors=self.label_color if gt == pred else 'r',
                     font_sizes=int(13 * scales[i]),
                     vertical_alignments='center',
                     horizontal_alignments='center')
+        elif pred_labels is not None:
+            for i, (pos, pred) in enumerate(zip(positions, pred_words_label)):
+                score = round(float(pred_words_label_score[i]) * 100, 1)
+                label_text = f'Pred: {pred}({score})'
+                self.draw_texts(
+                    label_text,
+                    pos,
+                    colors=self.label_color,
+                    font_sizes=int(13 * scales[i]),
+                    vertical_alignments='center',
+                    horizontal_alignments='center')
+        elif gt_labels is not None:
+            for i, (pos, gt) in enumerate(zip(positions, gt_words_label)):
+                label_text = f'GT: {gt}'
+                self.draw_texts(
+                    label_text,
+                    pos,
+                    colors=self.label_color,
+                    font_sizes=int(13 * scales[i]),
+                    vertical_alignments='center',
+                    horizontal_alignments='center')
 
         return self.get_image()

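Finally, a standalone sketch (not code from the repository) of the word-level aggregation that both the gt and the pred branches above rely on: one label per word/bbox is taken from the first sub-token of each word, and the BIO prefix is stripped.

def words_from_tokens(word_ids, token_biolabels):
    words = []
    pre_word_id = None
    for idx, cur_word_id in enumerate(word_ids):
        if cur_word_id is not None:
            if cur_word_id != pre_word_id:  # first sub-token of a new word
                bio = token_biolabels[idx]
                words.append(bio[2:] if bio != 'O' else 'other')
        pre_word_id = cur_word_id
    return words

word_ids = [None, 0, 0, 1, 2, 2, None]
tokens = ['O', 'B-header', 'I-header', 'O', 'B-answer', 'I-answer', 'O']
print(words_from_tokens(word_ids, tokens))  # ['header', 'other', 'answer']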
