RRHFTrainer.gather_logits_labels label in-place operation error #37

Open
asadfgglie opened this issue Jul 19, 2023 · 8 comments
Comments

@asadfgglie

The original implementation of RRHFTrainer.gather_logits_labels in train.py reads as follows:

mask = (labels != -100).float()
new_logits = logits.clone()  # Create a copy to avoid in-place modification
labels[labels == -100] = 0  # in-place error!
output = torch.gather(new_logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
output = output * mask # B * L
return output

This causes get_score at train.py line 274 to miscompute the length of each response:

def get_score(self, logit_label, labels):
        mask = (labels != -100).float() # all elements are True
        length = mask.sum(-1)
        scores = logit_label.sum(-1) / (length ** self.args.length_penalty)
        return scores
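
A minimal standalone sketch of the effect (shapes and values here are made up purely for illustration, not taken from the real data):

import torch

logits = torch.randn(2, 4, 10)                 # B * L * V, dummy values
labels = torch.tensor([[1, 2, -100, -100],
                       [3, 4, 5, 6]])          # B * L, first response padded with -100

# --- what gather_logits_labels currently does ---
mask = (labels != -100).float()
new_logits = logits.clone()
labels[labels == -100] = 0                     # mutates the caller's labels tensor in place
output = torch.gather(new_logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) * mask

# --- later, get_score recomputes the mask from the same labels tensor ---
length = (labels != -100).float().sum(-1)
print(length)                                  # tensor([4., 4.]) instead of the true tensor([2., 4.])
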
@GanjinZero
Owner

I haven't run into this bug before. Did you change the batch size or anything else?

@asadfgglie
Author

I didn't change the batch size.
The bug is in your latest commit, edf1764.

@GanjinZero
Owner

What is your Python version? And what is the error?

@asadfgglie
Author

python==3.9.6
torch==1.13.1+cu117

@asadfgglie
Author

No error is thrown during execution.
This is a silent runtime bug.
I found it while stepping through with the debugger.

@asadfgglie
Author

The main effect is that when get_score is computed, the denominator length becomes the length of the longest (padded) response for every sample,
instead of each response's own length.

@asadfgglie
Author

asadfgglie commented Jul 19, 2023

The original labels[labels == -100] = 0 modifies the values at the original memory address, so the values in inputs['labels'] get changed along with it, which makes the get_score computation wrong.
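
A possible fix (just a sketch of one way to do it, not necessarily the patch the repo will take) is to index through a copy of labels so that inputs['labels'] is left untouched:

def gather_logits_labels(self, logits, labels):
    mask = (labels != -100).float()
    new_logits = logits.clone()               # copy logits to avoid in-place modification
    safe_labels = labels.clone()              # work on a copy so the caller's labels stay intact
    safe_labels[safe_labels == -100] = 0      # -100 is not a valid vocab index, remap to 0 for gather
    output = torch.gather(new_logits, dim=-1, index=safe_labels.unsqueeze(-1)).squeeze(-1)
    output = output * mask                    # B * L, masked positions contribute 0
    return output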

@GanjinZero
Owner

Thanks for the feedback.

GanjinZero mentioned this issue Sep 12, 2023