Update TensorBoard #3669

Merged
merged 1 commit into from Jun 17, 2021
train.py: 33 changes (17 additions, 16 deletions)
@@ -42,7 +42,6 @@
 def train(hyp,  # path/to/hyp.yaml or hyp dictionary
           opt,
           device,
-          tb_writer=None
           ):
     save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \
         Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \
@@ -74,9 +73,16 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     with open(opt.data) as f:
         data_dict = yaml.safe_load(f)  # data dict

-    # Logging- Doing this before checking the dataset. Might update data_dict
-    loggers = {'wandb': None}  # loggers dict
+    # Loggers
+    loggers = {'wandb': None, 'tb': None}  # loggers dict
     if rank in [-1, 0]:
+        # TensorBoard
+        if not opt.evolve:
+            prefix = colorstr('tensorboard: ')
+            logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
+            loggers['tb'] = SummaryWriter(opt.save_dir)
+
+        # W&B
         opt.hyp = hyp  # add hyperparameters
         run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
         wandb_logger = WandbLogger(opt, save_dir.stem, run_id, data_dict)
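For context, the new initialization keeps every logging backend in a single loggers dict and only creates the TensorBoard writer on the main process and outside hyperparameter evolution. A minimal standalone sketch of that pattern, assuming a hypothetical make_loggers helper and a placeholder 'runs/train/exp' directory (neither is part of train.py):

from torch.utils.tensorboard import SummaryWriter  # same writer class used in the diff


def make_loggers(save_dir, rank=-1, evolve=False):
    # Hypothetical helper mirroring the loggers-dict pattern above (not part of train.py)
    loggers = {'wandb': None, 'tb': None}  # one entry per logging backend
    if rank in [-1, 0] and not evolve:  # main process only, skipped during hyp evolution
        loggers['tb'] = SummaryWriter(save_dir)  # event files are written under save_dir
    return loggers


loggers = make_loggers('runs/train/exp')  # placeholder directory
if loggers['tb']:
    loggers['tb'].add_scalar('demo/metric', 0.5, 0)  # tag, value, global_step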
@@ -219,8 +225,8 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             # model._initialize_biases(cf.to(device))
             if plots:
                 plot_labels(labels, names, save_dir, loggers)
-                if tb_writer:
-                    tb_writer.add_histogram('classes', c, 0)
+                if loggers['tb']:
+                    loggers['tb'].add_histogram('classes', c, 0)  # TensorBoard

             # Anchors
             if not opt.noautoanchor:
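The class-histogram call above now goes through loggers['tb'] rather than the removed tb_writer argument. A self-contained illustration of SummaryWriter.add_histogram, with random class indices standing in for the dataset labels c (the data and log directory are made up):

import numpy as np
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/demo')              # made-up log directory
class_ids = np.random.randint(0, 80, size=1000)  # fake class indices for an 80-class dataset
writer.add_histogram('classes', class_ids, 0)    # tag, values, global_step
writer.close()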
@@ -341,18 +347,18 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             if plots and ni < 3:
                 f = save_dir / f'train_batch{ni}.jpg'  # filename
                 Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
-                if tb_writer and ni == 0:
+                if loggers['tb'] and ni == 0:  # TensorBoard
                     with warnings.catch_warnings():
                         warnings.simplefilter('ignore')  # suppress jit trace warning
-                        tb_writer.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
+                        loggers['tb'].add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
             elif plots and ni == 10 and wandb_logger.wandb:
                 wandb_logger.log({'Mosaics': [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
                                               save_dir.glob('train*.jpg') if x.exists()]})

             # end batch ------------------------------------------------------------------------------------------------

         # Scheduler
-        lr = [x['lr'] for x in optimizer.param_groups]  # for tensorboard
+        lr = [x['lr'] for x in optimizer.param_groups]  # for loggers
         scheduler.step()

         # DDP process 0 or single-GPU
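The batch-plotting hunk above writes the model graph once, on the first batch, by passing an explicitly traced model to add_graph with trace warnings suppressed. A hedged standalone sketch of the same call pattern, with a toy network and dummy images in place of the YOLOv5 model and training batch (all names and shapes here are assumptions):

import warnings
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.AdaptiveAvgPool2d(1),
                      nn.Flatten(), nn.Linear(8, 10))  # toy stand-in for the detection model
imgs = torch.zeros(4, 3, 64, 64)                       # dummy batch of images
writer = SummaryWriter('runs/demo')                    # made-up log directory
with warnings.catch_warnings():
    warnings.simplefilter('ignore')  # suppress jit trace warnings, as in the diff
    writer.add_graph(torch.jit.trace(model, imgs[0:1], strict=False), [])
writer.close()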
@@ -385,8 +391,8 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                     'val/box_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
                     'x/lr0', 'x/lr1', 'x/lr2']  # params
             for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
-                if tb_writer:
-                    tb_writer.add_scalar(tag, x, epoch)  # tensorboard
+                if loggers['tb']:
+                    loggers['tb'].add_scalar(tag, x, epoch)  # TensorBoard
                 if wandb_logger.wandb:
                     wandb_logger.log({tag: x})  # W&B
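Here one stream of (tag, value) pairs is fanned out to each active backend at the end of every epoch. A small sketch of that loop with placeholder values and only a subset of the tags from the diff (no live W&B run is assumed):

from torch.utils.tensorboard import SummaryWriter

loggers = {'wandb': None, 'tb': SummaryWriter('runs/demo')}  # made-up log directory
epoch = 0
tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss']  # subset of the tags above
values = [0.05, 0.02, 0.01]                                    # placeholder metric values
for x, tag in zip(values, tags):
    if loggers['tb']:
        loggers['tb'].add_scalar(tag, x, epoch)  # TensorBoard
    # with an active W&B run, the same pair would also go to wandb_logger.log({tag: x})
loggers['tb'].close()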

@@ -537,12 +543,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     # Train
     logger.info(opt)
     if not opt.evolve:
-        tb_writer = None  # init loggers
-        if opt.global_rank in [-1, 0]:
-            prefix = colorstr('tensorboard: ')
-            logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
-            tb_writer = SummaryWriter(opt.save_dir)  # Tensorboard
-        train(opt.hyp, opt, device, tb_writer)
+        train(opt.hyp, opt, device)

     # Evolve hyperparameters (optional)
     else: