
Discriminative Loss can not calculate dist loss... #22

Open
IrohXu opened this issue Jun 18, 2021 · 3 comments

Comments

IrohXu commented Jun 18, 2021

During training, I found that the instance branch of your program is not accurate: the instance loss converges to 0 too quickly. So I checked the loss.py file of your code and found that the parameter num_lanes inside class DiscriminativeLoss is always 1 during training. This is likely caused by the one-hot representation of the instance target.

Because num_lanes = 1, this part is never executed:

if num_lanes > 1:
  centroid_mean1 = centroid_mean.reshape(-1, 1, num_lanes)
  centroid_mean2 = centroid_mean.reshape(1, -1, num_lanes)
  dist = torch.norm(centroid_mean1 - centroid_mean2, dim=2)  # shape (num_lanes, num_lanes)
  dist = dist + torch.eye(embed_dim, dtype=dist.dtype, device=dist.device) * self.delta_dist  # diagonal elements are 0, now mask above delta_d
  dist_loss = dist_loss + torch.sum(F.relu(-dist + self.delta_dist) ** 2) / (
          num_lanes * (num_lanes - 1)) / 2

#################################################
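A minimal sketch of the symptom (toy, hypothetical shapes): when each lane arrives as its own binary channel, torch.unique on any single channel can only ever return {0, 1}, so num_lanes stays at 1 once the background value is dropped, and the dist loss above never fires:

import torch

# Hypothetical toy target: 3 lanes, each given as its own binary (one-hot) channel.
one_hot_target = torch.zeros(3, 4, 4)
one_hot_target[0, 1, :] = 1   # lane 1
one_hot_target[1, 2, :] = 1   # lane 2
one_hot_target[2, 3, :] = 1   # lane 3

for channel in one_hot_target:
    labels = torch.unique(channel)
    labels = labels[labels != 0]  # drop background
    print(len(labels))            # prints 1 for every channel -> dist loss skipped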

To solve this problem, I think we either need to fix DiscriminativeLoss or build a new dataloader.
I changed DiscriminativeLoss a little, and this version may work (I am not sure):

import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss


class DiscriminativeLoss(_Loss):

    def __init__(self, delta_var=0.5, delta_dist=1.5, norm=2, alpha=1.0, beta=1.0, gamma=0.001,
                 usegpu=False, size_average=True):
        super(DiscriminativeLoss, self).__init__(reduction='mean')
        self.delta_var = delta_var
        self.delta_dist = delta_dist
        self.norm = norm
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.usegpu = usegpu
        assert self.norm in [1, 2]

    def forward(self, input, target):

        return self._discriminative_loss(input, target)

    def _discriminative_loss(self, embedding, seg_gt):
        # assumes the one-hot instance target has embed_dim channels (one mask per possible lane)
        batch_size, embed_dim, H, W = embedding.shape
        embedding = embedding.reshape(batch_size, embed_dim, H * W)
        seg_gt = seg_gt.reshape(batch_size, embed_dim, H * W)

        var_loss = torch.tensor(0, dtype=embedding.dtype, device=embedding.device)
        dist_loss = torch.tensor(0, dtype=embedding.dtype, device=embedding.device)
        reg_loss = torch.tensor(0, dtype=embedding.dtype, device=embedding.device)

        for b in range(batch_size):
            embedding_b = embedding[b]  # (embed_dim, H*W)
            seg_gt_b = torch.zeros(H * W, device=embedding.device)

            # collapse the one-hot channels into a single integer label map:
            # pixels of channel j get label j + 1, background stays 0
            for j in range(embed_dim):
                seg_gt_b += seg_gt[b][j] * (j + 1)

            labels = torch.unique(seg_gt_b)
            labels = labels[labels != 0]  # drop the background label
            num_lanes = len(labels)
            if num_lanes == 0:
                # please refer to issue here: https://github.com/harryhan618/LaneNet/issues/12
                _nonsense = embedding.sum()
                _zero = torch.zeros_like(_nonsense)
                var_loss = var_loss + _nonsense * _zero
                dist_loss = dist_loss + _nonsense * _zero
                reg_loss = reg_loss + _nonsense * _zero
                continue

            centroid_mean = []
            for lane_idx in labels:
                seg_mask_i = (seg_gt_b == lane_idx)

                if not seg_mask_i.any():
                    continue
                
                embedding_i = embedding_b * seg_mask_i
                mean_i = torch.sum(embedding_i, dim=1) / torch.sum(seg_mask_i)

                centroid_mean.append(mean_i)

                # ---------- var_loss -------------
                var_loss = var_loss + torch.sum(F.relu(
                    torch.norm(embedding_i[:, seg_mask_i] - mean_i.reshape(embed_dim, 1), dim=0)
                    - self.delta_var) ** 2) / torch.sum(seg_mask_i) / num_lanes
            centroid_mean = torch.stack(centroid_mean)  # (n_lane, embed_dim)

            if num_lanes > 1:
                centroid_mean1 = centroid_mean.reshape(-1, 1, embed_dim)
                centroid_mean2 = centroid_mean.reshape(1, -1, embed_dim)

                dist = torch.norm(centroid_mean1 - centroid_mean2, dim=2)  # shape (num_lanes, num_lanes)
                dist = dist + torch.eye(num_lanes, dtype=dist.dtype,
                                        device=dist.device) * self.delta_dist  # diagonal elements are 0, now mask above delta_d

                # divided by two for double calculated loss above, for implementation convenience
                dist_loss = dist_loss + torch.sum(F.relu(-dist + self.delta_dist) ** 2) / (
                        num_lanes * (num_lanes - 1)) / 2

            # reg_loss is not used in original paper
            # reg_loss = reg_loss + torch.mean(torch.norm(centroid_mean, dim=1))

        var_loss = var_loss / batch_size
        dist_loss = dist_loss / batch_size
        reg_loss = reg_loss / batch_size

        return var_loss, dist_loss, reg_loss
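
A minimal smoke test (random tensors, hypothetical shapes; assumes the class above with embed_dim = 4 and a one-hot target with one binary channel per lane) to at least check that the forward pass runs:

import torch

# Smoke test only: checks that the three losses come out as scalars.
criterion = DiscriminativeLoss(delta_var=0.5, delta_dist=1.5)

embedding = torch.randn(2, 4, 32, 64)   # (batch, embed_dim, H, W)
seg_gt = torch.zeros(2, 4, 32, 64)      # one binary channel per lane, background = 0
seg_gt[:, 0, 10, :] = 1                 # lane 1
seg_gt[:, 1, 20, :] = 1                 # lane 2

var_loss, dist_loss, reg_loss = criterion(embedding, seg_gt)
print(var_loss.item(), dist_loss.item(), reg_loss.item())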

Who can help me test it? Thanks a lot.

@Rakuzan-Developer

When I try my own dataset, which has the same structure as your dataset, I get this error in loss.py:

lanenet-lane-detection-pytorch/model/lanenet/loss.py", line 64, in _discriminative_loss
    seg_gt = seg_gt.reshape(batch_size, H*W)
RuntimeError: shape '[4, 131072]' is invalid for input of size 1572864

How can I fix it?

klintan (Owner) commented Feb 24, 2022

Thanks for this (unfortunately I missed it last year :O). @IrohXu, it would be awesome if you created a PR for this, and I'll review it.

@qq852518421

@IrohXu what is the meaning of

for j in range(0, embed_dim):
    seg_gt_b += seg_gt[b][j] * (j+1)

in your code?
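
For reference, a toy sketch (hypothetical shapes) of what that loop does: it collapses the one-hot channels into a single integer label map, where pixels of channel j get label j + 1 and background stays 0:

import torch

# Toy example (hypothetical shapes): 2 one-hot channels over 6 pixels.
seg_gt = torch.zeros(1, 2, 6)      # (batch, channels, H*W)
seg_gt[0, 0, :2] = 1               # lane 1 occupies pixels 0 and 1
seg_gt[0, 1, 2:4] = 1              # lane 2 occupies pixels 2 and 3

seg_gt_b = torch.zeros(6)
for j in range(2):                 # 2 stands in for embed_dim here
    seg_gt_b += seg_gt[0][j] * (j + 1)

print(seg_gt_b)                    # tensor([1., 1., 2., 2., 0., 0.])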
