From 69cd3a172b22db40f8d3b9985714e2366cc8ac1c Mon Sep 17 00:00:00 2001
From: Zhicheng Yan
Date: Tue, 7 Apr 2020 21:27:28 -0700
Subject: [PATCH] mixup data augmentation

Summary:
This diff implements the mixup data augmentation described in the paper
`mixup: Beyond Empirical Risk Minimization` (https://arxiv.org/abs/1710.09412).

Differential Revision: D20911088

fbshipit-source-id: 36a958ef4f711d122064fae736fed7a7e91b81e8
---
 classy_vision/dataset/transforms/mixup.py  | 32 ++++++++++++++++++++++
 classy_vision/generic/util.py              |  9 +++---
 classy_vision/tasks/classification_task.py | 28 +++++++++++++++++++
 3 files changed, 64 insertions(+), 5 deletions(-)
 create mode 100644 classy_vision/dataset/transforms/mixup.py

diff --git a/classy_vision/dataset/transforms/mixup.py b/classy_vision/dataset/transforms/mixup.py
new file mode 100644
index 0000000000..2c08d8ffea
--- /dev/null
+++ b/classy_vision/dataset/transforms/mixup.py
@@ -0,0 +1,32 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from classy_vision.generic.util import convert_to_one_hot
+from torch.distributions.beta import Beta
+
+
+def mixup_transform(sample, num_classes, alpha):
+    """
+    This implements the mixup data augmentation described in the paper
+    "mixup: Beyond Empirical Risk Minimization" (https://arxiv.org/abs/1710.09412)
+
+    Args:
+        sample (Dict[str, Any]): the batch data
+        num_classes (int): the number of classes in the dataset
+        alpha (float): Beta distribution parameter used to sample the mixup coefficient
+    """
+    assert (
+        sample["target"].ndim == 1
+    ), "Currently mixup only supports single-label classification"
+    sample["target"] = convert_to_one_hot(sample["target"].view(-1, 1), num_classes)
+
+    c = Beta(torch.tensor([alpha]), torch.tensor([alpha])).sample()
+
+    for key in ["input", "target"]:
+        sample[key] = c * sample[key] + (1.0 - c) * sample[key].flip([0])
+
+    return sample
diff --git a/classy_vision/generic/util.py b/classy_vision/generic/util.py
index cf413e4420..6e48ecf086 100644
--- a/classy_vision/generic/util.py
+++ b/classy_vision/generic/util.py
@@ -736,11 +736,10 @@ def maybe_convert_to_one_hot(target, model_output):
     ):
         target = convert_to_one_hot(target.view(-1, 1), model_output.shape[1])
 
-    assert (target.shape == model_output.shape) and (
-        torch.min(target.eq(0) + target.eq(1)) == 1
-    ), (
-        "Target must be one-hot/multi-label encoded and of the "
-        "same shape as model_output."
+    # The target is not necessarily a hard 0/1 encoding. It can be soft
+    # (i.e. fractional), e.g. a mixup label.
+    assert target.shape == model_output.shape, (
+        "Target must be of the same shape as model_output."
     )
 
     return target
diff --git a/classy_vision/tasks/classification_task.py b/classy_vision/tasks/classification_task.py
index 8c8c089aa7..5b0c1b0812 100644
--- a/classy_vision/tasks/classification_task.py
+++ b/classy_vision/tasks/classification_task.py
@@ -13,6 +13,7 @@
 import torch
 import torch.nn as nn
 from classy_vision.dataset import ClassyDataset, build_dataset
+from classy_vision.dataset.transforms.mixup import mixup_transform
 from classy_vision.generic.distributed_util import (
     all_reduce_mean,
     barrier,
@@ -20,6 +21,7 @@
     is_distributed_training_run,
 )
 from classy_vision.generic.util import (
+    convert_to_one_hot,
     copy_model_to_gpu,
     recursive_copy_to_gpu,
     update_classy_state,
@@ -139,6 +141,7 @@ def __init__(self):
             BroadcastBuffersMode.DISABLED
         )
         self.amp_args = None
+        self.mixup_args = None
         self.perf_log = []
         self.last_batch = None
         self.batch_norm_sync_mode = BatchNormSyncMode.DISABLED
@@ -306,6 +309,20 @@ def set_amp_args(self, amp_args: Optional[Dict[str, Any]]):
             logging.info(f"AMP enabled with args {amp_args}")
         return self
 
+    def set_mixup_args(self, mixup_args: Optional[Dict[str, Any]]):
+        """Enable or disable mixup data augmentation.
+
+        Args:
+            mixup_args: expected to contain the following keys:
+                num_classes (int): number of dataset classes
+                alpha (float): the hyperparameter of the Beta distribution used
+                    to sample the mixup coefficient.
+        """
+        self.mixup_args = mixup_args
+        if mixup_args is None:
+            logging.info("mixup disabled")
+        return self
+
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
         """Instantiates a ClassificationTask from a configuration.
@@ -348,6 +365,7 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
             .set_optimizer(optimizer)
             .set_meters(meters)
             .set_amp_args(amp_args)
+            .set_mixup_args(config.get("mixup"))
             .set_distributed_options(
                 broadcast_buffers_mode=BroadcastBuffersMode[
                     config.get("broadcast_buffers", "disabled").upper()
@@ -697,6 +715,11 @@ def eval_step(self, use_gpu):
                 + "'target' keys"
             )
 
+        if self.mixup_args is not None:
+            sample["target"] = convert_to_one_hot(
+                sample["target"].view(-1, 1), self.mixup_args["num_classes"]
+            )
+
         # Copy sample to GPU
         target = sample["target"]
         if use_gpu:
@@ -743,6 +766,11 @@ def train_step(self, use_gpu):
                 + "'target' keys"
             )
 
+        if self.mixup_args is not None:
+            sample = mixup_transform(
+                sample, self.mixup_args["num_classes"], self.mixup_args["alpha"]
+            )
+
         # Copy sample to GPU
         target = sample["target"]
         if use_gpu:
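
Below is a minimal usage sketch of the new transform on a toy batch. The tensor shapes, the four-class setup, and alpha=0.2 are illustrative only; the `mixup_transform(sample, num_classes, alpha)` call and the `{"input", "target"}` batch layout are the ones introduced by the patch.

```python
import torch

from classy_vision.dataset.transforms.mixup import mixup_transform

# Toy batch in the {"input", "target"} layout that train_step operates on.
# Targets are single-label class indices, as the assert in mixup_transform requires.
batch = {
    "input": torch.randn(4, 3, 224, 224),
    "target": torch.tensor([0, 2, 1, 3]),
}

mixed = mixup_transform(batch, num_classes=4, alpha=0.2)

# "target" is now a soft label matrix of shape [4, 4]: row i is
# c * one_hot(y_i) + (1 - c) * one_hot(y_{3-i}), with a single coefficient c
# drawn from Beta(alpha, alpha) for the whole batch. "input" is mixed the
# same way, against the batch flipped along dimension 0.
print(mixed["target"].shape)  # torch.Size([4, 4])
```

The relaxed assert in `maybe_convert_to_one_hot` (shape check only) is what lets these fractional labels pass through code that previously required hard 0/1 targets.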
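
And a sketch of the task-level hook. The `"mixup"` config key and its `num_classes`/`alpha` fields are the ones read by `from_config`, `train_step`, and `eval_step` above; the concrete values, and calling the builder method directly instead of going through a full JSON config, are just for illustration.

```python
from classy_vision.tasks.classification_task import ClassificationTask

task = ClassificationTask()

# Equivalent to adding {"mixup": {"num_classes": 1000, "alpha": 0.2}} to the
# dict/JSON config passed to ClassificationTask.from_config. train_step then
# applies mixup_transform to every batch; eval_step only converts targets to
# one-hot (no mixing is applied at evaluation time).
task = task.set_mixup_args({"num_classes": 1000, "alpha": 0.2})

# Passing None keeps the previous behaviour: mixup stays disabled.
task = task.set_mixup_args(None)
```

Note that once mixup is enabled the labels are no longer hard 0/1, so the configured loss presumably needs to accept soft targets; the patch itself does not enforce this.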