diff --git a/.gitignore b/.gitignore index 0a016e3..2c9fa01 100644 --- a/.gitignore +++ b/.gitignore @@ -121,6 +121,7 @@ workdir/ *.pth # local file +tmp.py *.txt train.sh ttt.py diff --git a/vedastr/dataloaders/samplers/balance_sampler.py b/vedastr/dataloaders/samplers/balance_sampler.py index ea7d424..1e2341f 100644 --- a/vedastr/dataloaders/samplers/balance_sampler.py +++ b/vedastr/dataloaders/samplers/balance_sampler.py @@ -1,6 +1,8 @@ import copy +import logging import random +import numpy as np from torch.utils.data import Sampler from .registry import SAMPLER @@ -21,19 +23,52 @@ class BalanceSampler(Sampler): downsample (bool): Set to True to downsample bigger sampler set. .. warning:: If both oversample and downsample is True, BanlanceSampler will do oversample first. That means downsample will do no effect. + + The last batch mey have different batch ratio, which means last batch may be + not balance. """ - def __init__(self, dataset, batch_size, shuffle, oversample=False, downsample=False): + def __init__(self, dataset, batch_size, shuffle, oversample=False, downsample=False, eps=0.1): assert hasattr(dataset, 'data_range') assert hasattr(dataset, 'batch_ratio') self.dataset = dataset self.samples_range = dataset.data_range - self.batch_ratio = dataset.batch_ratio + self.batch_ratio = np.array(dataset.batch_ratio) self.batch_size = batch_size + self.batch_sizes = self._compute_each_batch_size() + new_br = self.batch_sizes / self.batch_size + br_diffs = np.abs((new_br - self.batch_ratio)) + assert not np.sum(br_diffs > eps), "After computing the batch sizes of each dataset based on" \ + "given batch ratio, the max difference between new batch ratio " \ + "which compute based on the computed batch size and" \ + f" given batch ratio is large than the eps {eps}.\n" \ + "Please Considering increase the value of eps or batch size." \ + f"Current computed batch sizes are {self.batch_sizes}, new batch " \ + f"ratios are {new_br}, while give batch ratio" \ + f" are {self.batch_ratio}.\n" \ + f"The max difference between given batch ratio and new batch ratio" \ + f"is {np.max(np.array(br_diffs))}." + + assert 0 not in self.batch_sizes, "0 batch size is not supported, where batch " \ + "size is computed based on the batch ratio." \ + f" Computed batch size is {self.batch_sizes}." + + assert np.sum(self.batch_sizes) == self.batch_size + logging.info(f"The truly used batch ratios are {new_br}") + self.batch_ratio = new_br self.oversample = oversample self.downsample = downsample self.shuffle = shuffle - self._generate_indices_() + self._generate_indices() + + def _compute_each_batch_size(self): + batch_sizes = self.batch_ratio * self.batch_size + int_bs = batch_sizes.astype(np.int) + float_bs = (batch_sizes - int_bs) >= 0.5 + diff = self.batch_size - np.sum(int_bs) - np.sum(float_bs) + float_bs[np.where(float_bs == (diff < 0))[0][:int(abs(diff))]] = (diff >= 0) + + return (int_bs + float_bs).astype(np.int) @property def _num_samples(self): @@ -43,7 +78,7 @@ def _num_samples(self): def _num_samples(self, v): self.num_samples = v - def _generate_indices_(self): + def _generate_indices(self): self._num_samples = len(self.dataset) indices_ = [] # TODO, elegant @@ -58,31 +93,36 @@ def _generate_indices_(self): if self.shuffle: random.shuffle(temp) indices_.append(temp) + per_dataset_len = [len(index) for index in indices_] + pratios = [l / s for (l, s) in zip(per_dataset_len, self.batch_sizes)] if self.oversample: - indices_ = self._oversample(indices_) + need_len = [int(np.ceil(max(pratios) * size)) for size in self.batch_sizes] + indices_ = self._oversample(indices_, need_len) if self.downsample: - indices_ = self._downsample(indices_) + need_len = [int(np.ceil(min(pratios) * size)) for size in self.batch_sizes] + indices_ = self._downsample(indices_, need_len) return indices_ def __iter__(self): - indices_ = self._generate_indices_() + indices_ = self._generate_indices() total_nums = len(self) // self.batch_size - sizes = [int(self.batch_size * br) for br in self.batch_ratio] - final_index = [total_nums * size for size in sizes] + final_index = [total_nums * size for size in self.batch_sizes] indices = [] for idx2 in range(total_nums): - for idx3, size in enumerate(sizes): + for idx3, size in enumerate(self.batch_sizes): indices += indices_[idx3][idx2 * size:(idx2 + 1) * size] + # TODO, + # oversample or drop last. In current situation, + # the performance may drop a lot because the last batch may not balance for idx4, index in enumerate(final_index): indices += indices_[idx4][index:] return iter(indices) - def _oversample(self, indices): - max_len = max([len(index) for index in indices]) + def _oversample(self, indices, need_len): result_indices = [] for idx, index in enumerate(indices): current_nums = len(index) - need_num = max_len - current_nums + need_num = need_len[idx] - current_nums total_nums = need_num // current_nums mod_nums = need_num % current_nums init_index = copy.copy(index) @@ -93,16 +133,16 @@ def _oversample(self, indices): index += new_index index += random.sample(index, mod_nums) result_indices.append(index) - self._num_samples = max_len * len(indices) + self._num_samples = np.sum(need_len) + return result_indices - def _downsample(self, indices): - min_len = min([len(index) for index in indices]) + def _downsample(self, indices, need_len): result_indices = [] for idx, index in enumerate(indices): - index = random.sample(index, min_len) + index = random.sample(index, need_len[idx]) result_indices.append(index) - self._num_samples = min_len * len(indices) + self._num_samples = np.sum(need_len) return result_indices def __len__(self): diff --git a/vedastr/models/bodies/rectificators/spin.py b/vedastr/models/bodies/rectificators/spin.py index 09107b9..78fa3a3 100644 --- a/vedastr/models/bodies/rectificators/spin.py +++ b/vedastr/models/bodies/rectificators/spin.py @@ -1,9 +1,10 @@ +# [SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition](https://arxiv.org/abs/2005.13117) +# Not fully implemented yet. import copy import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F from .registry import RECTIFICATORS from vedastr.models.bodies.feature_extractors import build_feature_extractor @@ -47,11 +48,8 @@ def __init__(self, spin, k): super(SPIN, self).__init__() self.body = build_feature_extractor(spin['feature_extractor']) self.spn = SPN(spin['spn']) - self.ain = AIN(spin['ain']) self.betas = generate_beta(k) init_weights(self.modules()) - # self.spn.head[-1].fc.weight.data.fill_(0) - # self.spn.head[-1].fc.bias.data = torch.from_numpy(np.array([0, 0, 0, 0, 0,0,1,0,0, 0, 0, 0, 0, 0])).float() def forward(self, x): b, c, h, w = x.size() @@ -60,23 +58,14 @@ def forward(self, x): x = self.body(x) spn_out = self.spn(x) # 2k+2 - ain_out = self.ain(x) # activated by sigmoid - omega = spn_out[:, :-1] + g_out = init_img.requires_grad_(True) - alpha = F.sigmoid(spn_out[:, -1]) - - # offset - ain_out = F.interpolate(ain_out, size=(h, w), mode='bilinear') # noqa: F811 - g_out = alpha[:, None, None, None] * ain_out + (1 - alpha[:, None, None, None]) * init_img - # g_out = alpha * ain_out + (1 - alpha) * init_img - # g_out = init_img - # beta dist on g_out gamma_out = [g_out ** beta for beta in self.betas] gamma_out = torch.stack(gamma_out, axis=1).requires_grad_(True) + fusion_img = omega[:, :, None, None, None] * gamma_out fusion_img = torch.sigmoid(fusion_img.sum(dim=1)) - return fusion_img diff --git a/vedastr/runners/inference_runner.py b/vedastr/runners/inference_runner.py index d3c4557..2a3bb7a 100644 --- a/vedastr/runners/inference_runner.py +++ b/vedastr/runners/inference_runner.py @@ -53,7 +53,7 @@ def load_checkpoint(self, filename, map_location='default', strict=True): return load_checkpoint(self.model, filename, map_location, strict) - def postprocess(self, preds, cfg=None, label=None): + def postprocess(self, preds, cfg=None): if cfg is not None: sensitive = cfg.get('sensitive', True) character = cfg.get('character', '') @@ -65,7 +65,6 @@ def postprocess(self, preds, cfg=None, label=None): max_probs, indexes = probs.max(dim=2) preds_str = [] preds_prob = [] - labels = [] for i, pstr in enumerate(self.converter.decode(indexes)): str_len = len(pstr) if str_len == 0: @@ -73,19 +72,13 @@ def postprocess(self, preds, cfg=None, label=None): else: prob = max_probs[i, :str_len].cumprod(dim=0)[-1] preds_prob.append(prob) - if not sensitive: pstr = pstr.lower() - if label is not None: - tmp = label[i].lower() + if character: pstr = re.sub('[^{}]'.format(character), '', pstr) - if label is not None: - tmp = re.sub('[^{}]'.format(character), '', tmp) - labels.append(tmp) + preds_str.append(pstr) - if label is not None: - return preds_str, preds_prob, labels return preds_str, preds_prob diff --git a/vedastr/runners/test_runner.py b/vedastr/runners/test_runner.py index 107a0a4..ea7ffdb 100644 --- a/vedastr/runners/test_runner.py +++ b/vedastr/runners/test_runner.py @@ -25,7 +25,7 @@ def test_batch(self, img, label): else: pred = self.model((img,)) - pred, prob, label = self.postprocess(pred, self.postprocess_cfg, label) + pred, prob = self.postprocess(pred, self.postprocess_cfg) self.metric.measure(pred, prob, label) self.backup_metric.measure(pred, prob, label) @@ -41,4 +41,4 @@ def __call__(self): name, self.backup_metric.avg['acc']['true'], self.metric.avg['edit'] )) self.logger.info('Test, average acc %.4f, edit distance %s' % (self.metric.avg['acc']['true'], - self.metric.avg['edit'])) + self.metric.avg['edit'])) diff --git a/vedastr/runners/train_runner.py b/vedastr/runners/train_runner.py index 4f772a1..aa6d6b3 100644 --- a/vedastr/runners/train_runner.py +++ b/vedastr/runners/train_runner.py @@ -209,10 +209,11 @@ def save_model(self, meta=meta) def resume(self, checkpoint, resume_optimizer=False, - resume_lr_scheduler=False, resume_meta=False, + resume_lr_scheduler=False, resume_meta=False, strict=True, map_location='default'): checkpoint = self.load_checkpoint(checkpoint, - map_location=map_location) + map_location=map_location, + strict=strict) if resume_optimizer and 'optimizer' in checkpoint: self.logger.info('Resume optimizer')