forked from yijingru/Vertebra-Landmark-Detection
-
Notifications
You must be signed in to change notification settings - Fork 0
/
loss.py
69 lines (56 loc) · 2.2 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
import torch.nn as nn
import torch.nn.functional as F
class RegL1Loss(nn.Module):
    """Masked L1 regression loss sampled at indexed feature-map locations.

    The prediction map is flattened over its last two dimensions, the
    entries addressed by ``ind`` are gathered, and the summed absolute error
    against ``target`` is divided by the number of active mask entries.
    """

    def __init__(self):
        super(RegL1Loss, self).__init__()

    def _gather_feat(self, feat, ind, mask=None):
        """Gather the rows of ``feat`` addressed by ``ind`` along dim 1.

        Args:
            feat: 3-D tensor; its last dimension is the per-location width.
            ind: 2-D integer tensor of positions into dim 1 of ``feat``
                (first dimension matches ``feat``'s first dimension).
            mask: optional 2-D tensor shaped like ``ind``. Non-zero / True
                entries mark the gathered rows to keep.

        Returns:
            Without ``mask``: tensor of shape
            ``(ind.size(0), ind.size(1), feat.size(2))``.
            With ``mask``: the kept rows stacked into a 2-D tensor of shape
            ``(num_kept, feat.size(2))``.
        """
        dim = feat.size(2)
        # Repeat every index across the last dimension so gather() pulls a
        # complete row per index.
        ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
        feat = feat.gather(1, ind)
        if mask is not None:
            # Cast to bool first: an integer 0/1 mask would otherwise be
            # interpreted as *indices* by feat[...] (and a float mask would
            # raise), while forward() already treats masks as numeric 0/1.
            mask = mask.bool().unsqueeze(2).expand_as(feat)
            feat = feat[mask].view(-1, dim)
        return feat

    def _tranpose_and_gather_feat(self, feat, ind):
        """Move dim 1 of a 4-D map last, flatten dims 2-3, then gather.

        Args:
            feat: 4-D tensor -- presumably (batch, channels, height, width);
                confirm against the caller.
            ind: 2-D integer tensor of flattened positions (row-major over
                the original dims 2 and 3).

        Returns:
            Tensor of shape ``(feat.size(0), ind.size(1), feat.size(1))``.
        """
        feat = feat.permute(0, 2, 3, 1).contiguous()
        feat = feat.view(feat.size(0), -1, feat.size(3))
        return self._gather_feat(feat, ind)

    def forward(self, output, mask, ind, target):
        """Compute the masked, normalised L1 loss.

        Args:
            output: 4-D prediction map (see ``_tranpose_and_gather_feat``).
            mask: 2-D tensor shaped like ``ind``; non-zero entries take part
                in the loss.
            ind: 2-D integer tensor of flattened positions to sample.
            target: tensor with the same shape as the gathered predictions.

        Returns:
            Scalar tensor: summed absolute error over the masked entries
            divided by the number of active (expanded) mask elements.
        """
        pred = self._tranpose_and_gather_feat(output, ind)
        mask = mask.unsqueeze(2).expand_as(pred).float()
        loss = F.l1_loss(pred * mask, target * mask, reduction='sum')
        # The small constant keeps the division finite when the mask is empty.
        return loss / (mask.sum() + 1e-4)
class FocalLoss(nn.Module):
    """Focal-style loss between a predicted probability map and a target map.

    Elements where ``gt == 1`` are positives and contribute
    ``log(p) * (1 - p)**2``. Elements where ``gt < 1`` are negatives and
    contribute ``log(1 - p) * p**2 * (1 - gt)**4``. The negated total is
    divided by the number of positives when there is at least one.

    Args:
        eps: predictions are clamped to ``[eps, 1 - eps]`` before the
            logarithms are taken.
    """

    def __init__(self, eps=1e-4):
        super(FocalLoss, self).__init__()
        self.eps = eps

    def forward(self, pred, gt):
        """Return the scalar loss for one batch.

        Args:
            pred: tensor of predicted probabilities, expected in ``[0, 1]``.
            gt: target tensor with the same shape as ``pred``; entries equal
                to 1 are positives, entries below 1 are negatives.

        Returns:
            Scalar tensor holding the loss.
        """
        # A saturated prediction (exactly 0 or 1) makes log() return -inf,
        # and -inf * 0 from the indicator masks is NaN, which would poison
        # the summed loss. Clamping keeps every term finite.
        pred = torch.clamp(pred, min=self.eps, max=1 - self.eps)

        pos_inds = gt.eq(1).float()
        neg_inds = gt.lt(1).float()
        # Negatives whose target value is close to 1 are down-weighted.
        neg_weights = torch.pow(1 - gt, 4)

        pos_loss = (torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds).sum()
        neg_loss = (
            torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds
        ).sum()

        num_pos = pos_inds.sum()
        if num_pos == 0:
            # No positives: skip the normalisation to avoid dividing by zero.
            return -neg_loss
        return -(pos_loss + neg_loss) / num_pos
class LossAll(torch.nn.Module):
    """Sum of the heatmap, offset and width/height losses.

    ``forward`` reads the ``'hm'``, ``'wh'`` and ``'reg'`` entries of the
    prediction dict and the ``'hm'``, ``'reg_mask'``, ``'ind'``, ``'wh'`` and
    ``'reg'`` entries of the ground-truth dict.
    """

    def __init__(self):
        super(LossAll, self).__init__()
        self.L_hm = FocalLoss()
        self.L_off = RegL1Loss()
        self.L_wh = RegL1Loss()

    def forward(self, pr_decs, gt_batch):
        """Return ``hm_loss + off_loss + wh_loss`` for one batch.

        Args:
            pr_decs: mapping of prediction tensors keyed by head name.
            gt_batch: mapping of ground-truth tensors for the same batch.
        """
        # Both regression heads share the same mask and sampling indices.
        reg_mask = gt_batch['reg_mask']
        sample_ind = gt_batch['ind']

        heatmap_term = self.L_hm(pr_decs['hm'], gt_batch['hm'])
        size_term = self.L_wh(pr_decs['wh'], reg_mask, sample_ind, gt_batch['wh'])
        offset_term = self.L_off(pr_decs['reg'], reg_mask, sample_ind, gt_batch['reg'])
        return heatmap_term + offset_term + size_term