Add PyTorch AMP check #7917

Merged: 11 commits, May 22, 2022
models/common.py (5 changes: 3 additions & 2 deletions)

@@ -524,9 +524,10 @@ class AutoShape(nn.Module):
     max_det = 1000  # maximum number of detections per image
     amp = False  # Automatic Mixed Precision (AMP) inference
 
-    def __init__(self, model):
+    def __init__(self, model, verbose=True):
         super().__init__()
-        LOGGER.info('Adding AutoShape... ')
+        if verbose:
+            LOGGER.info('Adding AutoShape... ')
         copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=())  # copy attributes
         self.dmb = isinstance(model, DetectMultiBackend)  # DetectMultiBackend() instance
         self.pt = not self.dmb or model.pt  # PyTorch model
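
The new `verbose` flag only gates the "Adding AutoShape..." log line; it exists so the AMP check added in utils/general.py can wrap the model silently. A minimal sketch of the pattern (a standalone toy class, not the real AutoShape, which also copies model attributes and handles pre/post-processing):

```python
import logging

LOGGER = logging.getLogger(__name__)

class QuietWrapper:
    # Toy stand-in for AutoShape: same verbose gate added in this PR, nothing else.
    def __init__(self, model, verbose=True):
        if verbose:
            LOGGER.info('Adding AutoShape... ')  # skipped when wrapping silently
        self.model = model

    def __call__(self, x):
        return self.model(x)

# check_amp() uses the silent form: m = AutoShape(model, verbose=False)
```
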
train.py (13 changes: 7 additions & 6 deletions)

@@ -46,10 +46,10 @@
 from utils.callbacks import Callbacks
 from utils.dataloaders import create_dataloader
 from utils.downloads import attempt_download
-from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
-                           check_suffix, check_version, check_yaml, colorstr, get_latest_run, increment_path,
-                           init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods,
-                           one_cycle, print_args, print_mutation, strip_optimizer)
+from utils.general import (LOGGER, check_amp, check_dataset, check_file, check_git_status, check_img_size,
+                           check_requirements, check_suffix, check_version, check_yaml, colorstr, get_latest_run,
+                           increment_path, init_seeds, intersect_dicts, labels_to_class_weights,
+                           labels_to_image_weights, methods, one_cycle, print_args, print_mutation, strip_optimizer)
 from utils.loggers import Loggers
 from utils.loggers.wandb.wandb_utils import check_wandb_resume
 from utils.loss import ComputeLoss
@@ -126,6 +126,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary
         LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}')  # report
     else:
         model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
+    amp_enabled = check_amp(model)  # check AMP
 
     # Freeze
     freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))]  # layers to freeze
@@ -293,7 +294,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary
     maps = np.zeros(nc)  # mAP per class
     results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
     scheduler.last_epoch = start_epoch - 1  # do not move
-    scaler = amp.GradScaler(enabled=cuda)
+    scaler = amp.GradScaler(enabled=amp_enabled)
     stopper = EarlyStopping(patience=opt.patience)
     compute_loss = ComputeLoss(model)  # init loss class
     callbacks.run('on_train_start')
@@ -348,7 +349,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary
                     imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
 
             # Forward
-            with amp.autocast(enabled=cuda):
+            with amp.autocast(enabled=amp_enabled):
                 pred = model(imgs)  # forward
                 loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
                 if RANK != -1:
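
`check_amp(model)` runs once before training, and its boolean result replaces the old `enabled=cuda` flags, so both the loss scaler and the autocast context fall back to full precision when the check fails. A rough sketch of that gating pattern with the same `torch.cuda.amp` API (the model, dataloader, and `compute_loss` below are placeholders, not YOLOv5 code):

```python
import torch
from torch.cuda import amp

def train_loop_sketch(model, loader, optimizer, compute_loss, amp_enabled):
    # Sketch of the AMP-gated forward/backward used in train.py; all helpers are placeholders.
    scaler = amp.GradScaler(enabled=amp_enabled)      # pass-through scaler when AMP is off
    for imgs, targets in loader:
        with amp.autocast(enabled=amp_enabled):       # plain FP32 context when disabled
            pred = model(imgs)
            loss, _ = compute_loss(pred, targets)
        scaler.scale(loss).backward()                 # loss scaling is a no-op when disabled
        scaler.step(optimizer)                        # reduces to optimizer.step() when disabled
        scaler.update()
        optimizer.zero_grad()
```
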
utils/general.py (28 changes: 27 additions & 1 deletion)

@@ -36,9 +36,11 @@
 from utils.downloads import gsutil_getsize
 from utils.metrics import box_iou, fitness
 
-# Settings
 FILE = Path(__file__).resolve()
 ROOT = FILE.parents[1]  # YOLOv5 root directory
+RANK = int(os.getenv('RANK', -1))
+
+# Settings
 DATASETS_DIR = ROOT.parent / 'datasets'  # YOLOv5 datasets directory
 NUM_THREADS = min(8, max(1, os.cpu_count() - 1))  # number of YOLOv5 multiprocessing threads
 AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true'  # global auto-install mode
@@ -505,6 +507,30 @@ def check_dataset(data, autodownload=True):
     return data  # dictionary
 
 
+def check_amp(model):
+    # Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation
+    from models.common import AutoShape
+
+    help_url = 'https://github.com/ultralytics/yolov5/issues/7908'
+    prefix = colorstr('AMP:')
+    device = next(model.parameters()).device  # get model device
+    if device.type == 'cpu':
+        return False  # AMP only relevant on CUDA devices
+    m = AutoShape(model, verbose=False)  # model
+    im = cv2.imread(ROOT / 'data' / 'images' / 'bus.jpg')[..., ::-1]  # OpenCV image (BGR to RGB)
+    y = m(im).xyxy[0]  # FP32 inference
+    with torch.cuda.amp.autocast(enabled=True):
+        y_amp = m(im).xyxy[0]  # AMP inference
+    if (y.shape == y_amp.shape) and torch.allclose(y, y_amp, atol=1.0):  # close to 1 pixel bounding box
+        if RANK in {-1, 0}:
+            LOGGER.info(emojis(f'{prefix} checks passed ✅'))
+        return True
+    else:
+        if RANK in {-1, 0}:
+            LOGGER.warning(emojis(f'{prefix} checks failed ❌, disabling Automatic Mixed Precision. See {help_url}'))
+        return False
+
+
 def url2file(url):
     # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
     url = str(Path(url)).replace(':/', '://')  # Pathlib turns :// -> :/
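
The acceptance rule in `check_amp()` is deliberately loose: the FP32 and autocast detections must have the same shape and agree within `atol=1.0`, roughly one pixel per box coordinate, while a NaN/Inf AMP output fails `torch.allclose` and disables AMP for the run. A small self-contained illustration of that comparison on synthetic tensors (not YOLOv5 code):

```python
import torch

def outputs_agree(y_fp32, y_amp, atol=1.0):
    # Same acceptance rule as check_amp(): identical shape, values within ~1 pixel.
    return (y_fp32.shape == y_amp.shape) and torch.allclose(y_fp32, y_amp, atol=atol)

# Synthetic detections: [x1, y1, x2, y2, conf, cls]
y = torch.tensor([[10.0, 20.0, 110.0, 220.0, 0.9, 0.0]])

print(outputs_agree(y, y + 0.3))           # True: 0.3 px shift is within tolerance
print(outputs_agree(y, y + 5.0))           # False: boxes drifted too far
print(outputs_agree(y, y * float('nan')))  # False: NaNs from a broken AMP path fail allclose
```
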