Training time #21

Open
NickCtrl opened this issue Dec 8, 2022 · 3 comments

Comments


NickCtrl commented Dec 8, 2022

Hello, I have started training on my own dataset, loading your best_model.pth as the pretrained model. The training set has 5,000 images and the validation set 1,000. On a single RTX A5000 with bs=32, one epoch takes about 15 minutes. Is that training time normal? It feels rather slow for such a small dataset. If I want to optimize it, where should I start? Both branches need to be trained. I also saw another issue that mentioned the training time of the loss. Could you help me with this? Thanks!

IrohXu (Owner) commented Dec 8, 2022

Monitor gpu-util and plot the curve to find the compute bottleneck. There are only two places in LaneNet worth optimizing: the Discriminative Loss and the Dataloader.
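
A quick way to check which of the two is the bottleneck is to time the data-loading phase and the compute phase of the training loop separately while watching GPU utilization with nvidia-smi. The snippet below is a minimal standalone sketch, not code from this repo: it uses a dummy TensorDataset as a stand-in for the real lane dataset, and the num_workers / pin_memory values are only starting points to tune.

import time
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for the real dataset; swap in the repo's dataset class here.
dataset = TensorDataset(torch.randn(128, 3, 256, 512),
                        torch.zeros(128, dtype=torch.long))
loader = DataLoader(dataset, batch_size=32, shuffle=True,
                    num_workers=4, pin_memory=True)  # values to tune

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

data_time, compute_time = 0.0, 0.0
t0 = time.time()
for images, labels in loader:
    t1 = time.time()
    data_time += t1 - t0                    # time spent waiting on the dataloader
    images = images.to(device, non_blocking=True)
    # ... forward pass, loss, and backward pass would go here ...
    if device.type == 'cuda':
        torch.cuda.synchronize()            # make the GPU timing meaningful
    t0 = time.time()
    compute_time += t0 - t1                 # time spent on transfer + compute

print(f"data loading: {data_time:.1f}s, compute: {compute_time:.1f}s")

If data loading dominates, raising num_workers (and keeping pin_memory=True) usually helps; if compute dominates, the discriminative loss is the place to look.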

NickCtrl (Author) commented Jan 4, 2023

@IrohXu I have recently been studying the implementation details of LaneNet and have read through your code carefully many times. In loss.py:
def _discriminative_loss(self, embedding, seg_gt):
    ...
    for lane_idx in labels:
        seg_mask_i = (seg_gt_b == lane_idx)

        if not seg_mask_i.any():
            continue

        embedding_i = embedding_b * seg_mask_i
        mean_i = torch.sum(embedding_i, dim=1) / torch.sum(seg_mask_i)
        centroid_mean.append(mean_i)
        # ---------- var_loss -------------
        var_loss = var_loss + torch.sum(F.relu(
            torch.norm(embedding_i[:, seg_mask_i] - mean_i.reshape(embed_dim, 1), dim=0) - self.delta_var) ** 2
        ) / torch.sum(seg_mask_i) / num_lanes

The computation of var_loss does not use the label values of the individual lanes in the ground truth at all; it only uses the positions of each lane in the gt and computes each lane's mean mean_i. Doesn't that make the initial value of this mean mean_i rather arbitrary? I still can't properly understand this part. Could you point me in the right direction? Thanks!

IrohXu (Owner) commented Jan 5, 2023

Reading this paper should make it clear: https://arxiv.org/abs/1708.02551
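
For anyone reading this later, the key point in that paper is that mean_i is not an initialized parameter at all: it is the centroid of the embedding vectors at the pixels of one lane instance, recomputed from the current network output on every forward pass. The variance term pulls each lane's embeddings toward its own centroid, and a separate distance term pushes different centroids apart, so the actual label values only matter for selecting which pixels belong together. Below is a minimal standalone sketch of the variance (pull) term, written for clarity; it is not the repo's implementation and uses a plain mean instead of the masked-sum formulation in loss.py.

import torch
import torch.nn.functional as F

def variance_term(embedding, seg_gt, delta_var=0.5):
    """Pull term of the discriminative loss for one image.

    embedding: (embed_dim, H*W) embedding produced by the network
    seg_gt:    (H*W,) instance labels, 0 = background, 1..N = lane ids
    """
    lane_ids = [l for l in torch.unique(seg_gt) if l != 0]
    var_loss = embedding.new_zeros(())

    for lane_id in lane_ids:
        mask = (seg_gt == lane_id)                  # pixels of this lane instance
        pixels = embedding[:, mask]                 # (embed_dim, n_pixels)
        mean_i = pixels.mean(dim=1, keepdim=True)   # centroid, recomputed every call
        # only embeddings farther than delta_var from the centroid are penalized
        dist = torch.norm(pixels - mean_i, dim=0)
        var_loss = var_loss + torch.mean(F.relu(dist - delta_var) ** 2)

    return var_loss / max(len(lane_ids), 1)

# Toy usage: 4-dim embeddings on a 6-pixel "image" containing two lanes.
emb = torch.randn(4, 6, requires_grad=True)
gt = torch.tensor([0, 1, 1, 2, 2, 2])
loss = variance_term(emb, gt)
loss.backward()   # gradients push each lane's embeddings toward its own centroid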
