
Fix RT-DETR weights initialization #31724

Merged
merged 7 commits into huggingface:main on Jul 3, 2024

Conversation

qubvel
Member

@qubvel qubvel commented Jul 1, 2024

What does this PR do?

Fix RT-DETR bbox and class head weight initialization.

  1. In the `_init_weights` method, the bbox and class heads are not reachable for initialization. This sometimes leads to unstable training and lower results (see experiments below).
  2. Added a config parameter for bias initialization. In the original code, the head biases are initialized with prior_prob=0.01, which is fine for training with 80 classes; however, this value should be adjusted when fine-tuning (see the sketch after this list).
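
For context, here is a minimal sketch of the prior-probability bias initialization referred to in point 2. It is illustrative only; the helper name, shapes, and usage are assumptions, not the exact Transformers code:

```python
import math
import torch.nn as nn

# Illustrative sketch (not the exact Transformers code): with sigmoid class
# outputs, setting bias = -log((1 - p) / p) makes the initial foreground
# probability of every class equal to `prior_prob`.
def init_class_head_bias(class_head: nn.Linear, prior_prob: float = 0.01) -> None:
    bias_value = -math.log((1 - prior_prob) / prior_prob)
    nn.init.constant_(class_head.bias, bias_value)

# prior_prob=0.01 suits 80-class COCO-style training; for fine-tuning on a
# small dataset (e.g. CPPE-5) a different value may be preferable, hence the
# new config parameter mentioned above.
init_class_head_bias(nn.Linear(256, 80), prior_prob=0.01)
```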

Results of the fine-tuning on main vs fix branches on CPPE-5 dataset (averaged for 6 runs each):

  • better convergence: +0.1–0.15 mAP50 on average on the eval and test sets
  • lower dispersion of results across runs
[Screenshots: fine-tuning results on main vs fix branches]

Who can review?

@amyeroberts

cc @SangbumChoi @NielsRogge

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Contributor

@NielsRogge NielsRogge left a comment


Thanks for improving this! Does the model converge as fast as the original implementation?

@qubvel
Member Author

qubvel commented Jul 1, 2024

I didn't have a chance to run the fine-tuning with the original code; maybe @SangbumChoi has a fine-tuning script to compare against. However, based on my previous experiments with other detection models in transformers, I would say that RT-DETR with this fix is the best in terms of compute, convergence speed, and results achieved on the CPPE-5 dataset.

@merveenoyan
Contributor

@qubvel is there anything else that needs to be done?

@qubvel
Member Author

qubvel commented Jul 1, 2024

@merve there is one more PR fixing the anchors-generation cache, but it's not that critical:
#31671

I also implemented SDPA attention, but didn't observe any speed-up in inference.

@SangbumChoi
Contributor

@qubvel Isn't SDPA the default operation in MHA?

> This function has already been incorporated into torch.nn.MultiheadAttention and torch.nn.TransformerEncoderLayer.

Since many of the FLOPs in the encoder are not in the attention module, I guess the speed-up from applying an attention-friendly library such as SDPA or xFormers might be marginal.

@qubvel @NielsRogge Thanks for this PR. (Good to hear that this is the best result by far.) Unfortunately, I don't have any results from fine-tuning the original RT-DETR repo (I have some test results with the Transformers RT-DETR).

@qubvel
Member Author

qubvel commented Jul 1, 2024

@SangbumChoi I'm talking about RTDetrMultiheadAttention here. I added SDPA support for this class but didn't observe any speed-up; I will open a separate PR to discuss it :)
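
For illustration, a minimal sketch of what an SDPA-based attention forward could look like (an assumption of how such a variant might be written, not the actual RTDetrMultiheadAttention code):

```python
import torch
import torch.nn.functional as F

# Minimal sketch of a multi-head attention forward that delegates the core
# computation to torch.nn.functional.scaled_dot_product_attention, which can
# dispatch to fused (flash / memory-efficient) kernels when the inputs allow it.
def sdpa_attention(query, key, value, num_heads, attn_mask=None):
    batch, tgt_len, embed_dim = query.shape
    head_dim = embed_dim // num_heads

    def split_heads(x):
        # (batch, seq, embed) -> (batch, heads, seq, head_dim)
        return x.view(batch, -1, num_heads, head_dim).transpose(1, 2)

    q, k, v = split_heads(query), split_heads(key), split_heads(value)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    # (batch, heads, seq, head_dim) -> (batch, seq, embed)
    return out.transpose(1, 2).reshape(batch, tgt_len, embed_dim)

# Example shapes: batch=2, 300 queries, 256-dim embeddings, 8 heads.
q = k = v = torch.randn(2, 300, 256)
print(sdpa_attention(q, k, v, num_heads=8).shape)  # torch.Size([2, 300, 256])
```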

@qubvel qubvel requested a review from amyeroberts July 1, 2024 14:36
@qubvel qubvel changed the title from Fix R-DETR weights initialization to Fix RT-DETR weights initialization on Jul 1, 2024
Collaborator

@amyeroberts amyeroberts left a comment


Thanks for fixing!

@qubvel qubvel merged commit 048f599 into huggingface:main Jul 3, 2024
23 checks passed