Commit

Remove hook `torch.nan_to_num(x)` (ultralytics#8826)
* Remove hook `torch.nan_to_num(x)`

Observed erratic training behavior (green line) with the nan_to_num hook in the classifier branch. I'm going to remove it from master.

* Update train.py
glenn-jocher authored and Clay Januhowski committed Sep 8, 2022
1 parent df153a6 commit 47d9bf9
Showing 1 changed file with 1 addition and 1 deletion: train.py
```diff
@@ -131,7 +131,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary
     freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))]  # layers to freeze
     for k, v in model.named_parameters():
         v.requires_grad = True  # train all layers
-        v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0.0
+        # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0.0 (commented for erratic training results)
         if any(x in k for x in freeze):
             LOGGER.info(f'freezing {k}')
             v.requires_grad = False
```
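For context, the removed line registered a backward hook on every model parameter so that any NaN gradients were replaced with 0.0 before the optimizer step. A minimal sketch of that behavior, using a toy `torch.nn.Linear` model rather than the YOLOv5 training code itself:

```python
import torch

# Illustrative stand-in for the YOLOv5 model (assumption: any nn.Module works the same way).
model = torch.nn.Linear(4, 1)

# The hook this commit removes: sanitize NaN gradients to 0.0 during backward.
for v in model.parameters():
    v.register_hook(lambda g: torch.nan_to_num(g))

# Force NaNs into the backward pass; the hook zeroes them out.
x = torch.full((1, 4), float('nan'))
model(x).sum().backward()
print(torch.isnan(model.weight.grad).any())  # no NaNs survive in the gradient
```

The commit's rationale is that silently zeroing NaN gradients can mask a diverging loss and produce erratic training curves, so the hook was dropped rather than kept as a safety net.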
