-
Notifications
You must be signed in to change notification settings - Fork 7
/
utils.py
33 lines (24 loc) · 870 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
def kaiming_init(module, nonlinearity='relu'):
    """Re-initialise a convolution's weight in place with Kaiming-normal values.

    Intended for use with ``torch.nn.Module.apply``. Any module whose class
    name contains ``"Conv"`` has its ``weight`` filled by
    ``torch.nn.init.kaiming_normal_``. Every other module is left untouched,
    including the conv's own ``bias``.

    Args:
        module: module to (possibly) initialise.
        nonlinearity: forwarded to ``kaiming_normal_`` to select the gain.
            Defaults to ``'relu'``, as before.
    """
    if 'Conv' not in module.__class__.__name__:
        return
    # Name-based matching also catches containers such as "ConvBlock" that own
    # no ``weight`` tensor; skip those instead of raising AttributeError.
    weight = getattr(module, 'weight', None)
    if isinstance(weight, torch.Tensor):
        torch.nn.init.kaiming_normal_(weight, nonlinearity=nonlinearity)
def weights_init(m):
    """Initialise the parameters of a single module in place.

    Intended for use with ``torch.nn.Module.apply``. Dispatch is on the
    module's class name:

    * contains ``"Conv"``: ``weight`` is drawn from N(0.0, 0.02).
    * contains ``"BatchNorm"``: ``weight`` is drawn from N(1.0, 0.02) and
      ``bias`` is set to 0.

    Parameters that are absent or ``None`` are skipped. Examples are
    name-matched containers without a ``weight``, or ``BatchNorm(affine=False)``.
    All other modules are left untouched.

    Args:
        m: module to (possibly) initialise.
    """
    classname = m.__class__.__name__
    if "Conv" in classname:
        # Previously a bare ``except: pass``; the only expected failure is a
        # name-matched module with no weight tensor, so test for that directly.
        weight = getattr(m, "weight", None)
        if isinstance(weight, torch.Tensor):
            torch.nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm" in classname:
        weight = getattr(m, "weight", None)
        bias = getattr(m, "bias", None)
        # affine=False batch norms carry weight/bias = None.
        if weight is not None:
            torch.nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            torch.nn.init.zeros_(bias)
def load_checkpoint(net, checkpoint):
    """Strictly load ``checkpoint`` into ``net``, reconciling the ``'module.'`` prefix.

    Wrapped models (e.g. ``torch.nn.DataParallel``) name every entry
    ``'module.<name>'``, so a checkpoint saved from a wrapped model does not
    match an unwrapped one, and vice versa. This helper inspects
    ``net.state_dict()`` and handles the checkpoint keys as follows:

    * If every key of the target starts with ``'module.'``, or the target
      has no keys, the prefix is added to checkpoint keys that lack it.
    * Otherwise the prefix is stripped from checkpoint keys that carry it.

    Args:
        net: module to load into (modified in place).
        checkpoint: a state-dict mapping, or a mapping that holds one under
            the key ``'state_dict'``. The argument itself is not mutated.

    Returns:
        Whatever ``net.load_state_dict`` returns (missing/unexpected keys).

    Raises:
        RuntimeError: propagated from ``load_state_dict`` when keys or shapes
            still disagree after prefix reconciliation.
    """
    from collections import OrderedDict

    prefix = 'module.'
    if 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']

    # Follow the target's own naming convention instead of always prefixing.
    # all() over an empty key set is True, which keeps the old behaviour there.
    want_prefix = all(k.startswith(prefix) for k in net.state_dict())

    def _rename(key):
        # Map one checkpoint key onto the target's naming convention.
        has_prefix = key.startswith(prefix)
        if want_prefix and not has_prefix:
            return prefix + key
        if not want_prefix and has_prefix:
            return key[len(prefix):]
        return key

    remapped = OrderedDict((_rename(k), v) for k, v in checkpoint.items())
    return net.load_state_dict(remapped, strict=True)