
Adding distillation loss functions from TinyBERT #1879

Merged 18 commits on Dec 23, 2021

Conversation

@MichelBartels (Contributor) commented on Dec 13, 2021

Proposed changes:
This adds the distillation loss functions from TinyBERT as explained in #1873.

Status (please check what you already did):

  • First draft (up for discussions & feedback)
  • Final code
  • Added tests
  • Updated documentation

This adds two parameters to the distil_from method. Enabling the parameter tinybert_loss adds an additional distillation stage before the original one; tinybert_epochs specifies the number of epochs in this stage. The new stage is realised by a new TinyBERTDistillationTrainer that computes the teacher hidden states and attentions on the fly.
Caching the teacher outputs is not used, as this would take up too much memory (hundreds to thousands of gigabytes). The original stage that follows is unchanged and still uses the standard DistillationTrainer.
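
For illustration, a minimal sketch of how the new parameters could be used (the data paths and epoch count are placeholders, not part of this PR):

    from haystack.nodes import FARMReader

    student = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_4L_312D")
    teacher = FARMReader(model_name_or_path="bert-base-uncased")

    # Stage 1 (new): TinyBERT intermediate-layer distillation for tinybert_epochs epochs,
    # followed by stage 2: the already existing prediction-layer distillation.
    student.distil_from(
        teacher,
        data_dir="data/squad20",            # placeholder path
        train_filename="train-v2.0.json",   # placeholder file
        tinybert_loss=True,
        tinybert_epochs=1,
    )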

@MichelBartels marked this pull request as ready for review on December 22, 2021, 15:15
@julian-risch (Member) left a comment

Looks very good already. Just some smaller changes requested. Most interesting for you is the missing return keyword in haystack/nodes/reader/farm.py I guess. Happy to jump on a quick call in the afternoon if you want to discuss something.

def test_tinybert_distillation():
    student = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_4L_312D")
    teacher = FARMReader(model_name_or_path="bert-base-uncased")
julian-risch (Member):

Could we use a smaller teacher model here to speed up the test?

MichelBartels (Contributor, Author):

This would be theoretically possible, but the teacher model would need to have the exact same dimensions except for the number of layers, which would need to be a multiple of the number of student layers. This makes it quite hard to find a matching model. If it is a big performance issue, we could perhaps create our own "mock model" with the right parameters.
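
(For context, a rough sketch of the layer constraint described above, with purely illustrative numbers:)

    # TinyBERT-style uniform layer mapping: the teacher's layer count
    # must be a multiple of the student's layer count.
    teacher_layers, student_layers = 12, 4
    assert teacher_layers % student_layers == 0
    step = teacher_layers // student_layers
    # student layer i is distilled against teacher layer i * step
    mapping = {i: i * step for i in range(1, student_layers + 1)}
    print(mapping)  # {1: 3, 2: 6, 3: 9, 4: 12}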

julian-risch (Member):

Alright, let's keep it as it is for now. 👍

        :return: None
        """
        if tinybert_loss:
            self._training_procedure(data_dir=data_dir, train_filename=train_filename,
julian-risch (Member):

Is a return missing here in front of self._training_procedure(...)?

MichelBartels (Contributor, Author):

No, task-specific distillation for TinyBERT has two stages, and the second stage is the same as what we have already implemented. So calling _training_procedure with tinybert=True only executes the first stage; afterwards, the method continues with the second stage. I have added a short comment explaining that.
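
(Roughly, the control flow in question looks like this; a simplified sketch with hypothetical parameter names, not the actual implementation:)

    def distil_from(self, teacher_model, tinybert_loss=False, tinybert_epochs=1, **kwargs):
        if tinybert_loss:
            # Stage 1: intermediate-layer distillation (hidden states + attentions).
            # No "return" here on purpose: training falls through to stage 2 below.
            self._training_procedure(teacher_model=teacher_model, tinybert=True,
                                     n_epochs=tinybert_epochs, **kwargs)
        # Stage 2: the already existing prediction-layer distillation.
        self._training_procedure(teacher_model=teacher_model, **kwargs)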

@@ -1,9 +1,6 @@
 from typing import Optional, Union, Tuple, List, Callable
 
-from typing import TYPE_CHECKING
julian-risch (Member):

Could you please explain why we got rid of these lines, in particular the _LRScheduler import under TYPE_CHECKING, so that I understand it a bit better? A response to this comment would be fine. :)

MichelBartels (Contributor, Author):

The type hint for DistillationTrainer turned out to be wrong. Because of that, I don't need to import TYPE_CHECKING anymore, as it was only necessary for preventing the circular import of FARMReader. There was never really a reason to also use it for _LRScheduler, so _LRScheduler can just be imported normally.
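
(For reference, the TYPE_CHECKING pattern being removed here looks roughly like this; the import path is the farm.py module mentioned above:)

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Only evaluated by static type checkers, never at runtime,
        # which avoids the circular import between this module and farm.py.
        from haystack.nodes.reader.farm import FARMReader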

@@ -630,7 +627,7 @@ class DistillationTrainer(Trainer):
     """
     def __init__(
         self,
-        model: "FARMReader",
+        model: "AdaptiveModel",
julian-risch (Member):

Is the code in its current form ready to use models other than FARMReader?

MichelBartels (Contributor, Author):

It can basically train any AdaptiveModel with a QA prediction_head. I changed this line because I realised that _training_procedure only passes the AdaptiveModel. This behavior is exactly the same for the normal Trainer class.

@@ -484,6 +484,8 @@ def forward(
         input_ids: torch.Tensor,
         segment_ids: torch.Tensor,
         padding_mask: torch.Tensor,
+        output_hidden_states: bool = False,
+        output_attentions: bool = False,
julian-risch (Member):

Please add docstrings for these new parameters

MichelBartels (Contributor, Author):

Ok, I have added the docstrings.

-        sequence_output, pooled_output = output_tuple[0], output_tuple[1]
-        return sequence_output, pooled_output
+        return output_tuple
         # if self.model.encoder.config.output_hidden_states == True:
julian-risch (Member):

Please check the commented code. :)

MichelBartels (Contributor, Author):

I have now deleted the commented code. It is unnecessary, as the output tuple is now handled by Hugging Face Transformers.
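
(To illustrate the point: with return_dict=False, the Hugging Face model already returns everything in a single tuple, so it can simply be passed through. A rough stand-alone sketch, not the actual Haystack code:)

    from transformers import AutoModel, AutoTokenizer

    model = AutoModel.from_pretrained("bert-base-uncased")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    inputs = tokenizer("hello world", return_tensors="pt")

    # output_tuple == (sequence_output, pooled_output, hidden_states, attentions)
    output_tuple = model(**inputs,
                         output_hidden_states=True,
                         output_attentions=True,
                         return_dict=False)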

@@ -356,7 +356,7 @@ def prepare_labels(self, **kwargs):
             all_labels.append(labels)
         return all_labels
 
-    def forward(self, **kwargs):
+    def forward(self, output_hidden_states: bool = False, output_attentions: bool = False, **kwargs):
julian-risch (Member):

Please add the docstrings for the new parameters here as well, e.g.:

 :param output_hidden_states: Whether to output hidden states
 :param output_attentions: Whether to output attentions

MichelBartels (Contributor, Author):

I have added the docstrings.
