Adding distillation loss functions from TinyBERT #1879

Merged
merged 18 commits on Dec 23, 2021
Changes from 12 commits
6 changes: 4 additions & 2 deletions docs/_src/api/api/reader.md
@@ -161,7 +161,7 @@ None
#### distil\_from

```python
| distil_from(teacher_model: "FARMReader", data_dir: str, train_filename: str, dev_filename: Optional[str] = None, test_filename: Optional[str] = None, use_gpu: Optional[bool] = None, student_batch_size: int = 10, teacher_batch_size: Optional[int] = None, n_epochs: int = 2, learning_rate: float = 1e-5, max_seq_len: Optional[int] = None, warmup_proportion: float = 0.2, dev_split: float = 0, evaluate_every: int = 300, save_dir: Optional[str] = None, num_processes: Optional[int] = None, use_amp: str = None, checkpoint_root_dir: Path = Path("model_checkpoints"), checkpoint_every: Optional[int] = None, checkpoints_to_keep: int = 3, caching: bool = False, cache_path: Path = Path("cache/data_silo"), distillation_loss_weight: float = 0.5, distillation_loss: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = "kl_div", temperature: float = 1.0)
| distil_from(teacher_model: "FARMReader", data_dir: str, train_filename: str, dev_filename: Optional[str] = None, test_filename: Optional[str] = None, use_gpu: Optional[bool] = None, student_batch_size: int = 10, teacher_batch_size: Optional[int] = None, n_epochs: int = 2, learning_rate: float = 1e-5, max_seq_len: Optional[int] = None, warmup_proportion: float = 0.2, dev_split: float = 0, evaluate_every: int = 300, save_dir: Optional[str] = None, num_processes: Optional[int] = None, use_amp: str = None, checkpoint_root_dir: Path = Path("model_checkpoints"), checkpoint_every: Optional[int] = None, checkpoints_to_keep: int = 3, caching: bool = False, cache_path: Path = Path("cache/data_silo"), distillation_loss_weight: float = 0.5, distillation_loss: Union[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = "kl_div", temperature: float = 1.0, tinybert_loss: bool = False, tinybert_epochs: int = 1)
```

Fine-tune a model on a QA dataset using distillation. You need to provide a teacher model that is already finetuned on the dataset
@@ -218,8 +218,10 @@ If any checkpoints are stored, a subsequent run of train() will resume training
:param caching whether or not to use caching for preprocessed dataset and teacher logits
- `cache_path`: Path to cache the preprocessed dataset and teacher logits
- `distillation_loss_weight`: The weight of the distillation loss. A higher weight means the teacher outputs are more important.
- `distillation_loss`: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for kl divergence loss) or a callable loss function (needs to have named paramters student_logits and teacher_logits)
- `distillation_loss`: Specifies how teacher and model logits should be compared. Can either be a string ("mse" for mean squared error or "kl_div" for kl divergence loss) or a callable loss function (needs to have named parameters student_logits and teacher_logits)
- `temperature`: The temperature for distillation. A higher temperature will result in less certainty of teacher outputs. A lower temperature means more certainty. A temperature of 1.0 does not change the certainty of the model.
- `tinybert_loss`: Whether to use the TinyBERT loss function for distillation. This requires the student to be a TinyBERT model and the teacher to be a finetuned version of bert-base-uncased.
- `tinybert_epochs`: Number of epochs to train the student model with the TinyBERT loss function. After this many epochs, the student model is trained with the regular distillation loss function.

**Returns**:

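For context, a minimal usage sketch of the new `tinybert_loss` / `tinybert_epochs` parameters documented above. The model names, paths, and hyperparameter values are placeholder assumptions, not taken from this PR:

```python
from haystack.nodes import FARMReader  # import path as of Haystack 1.x

# Placeholder checkpoints: any TinyBERT student and a teacher fine-tuned from
# bert-base-uncased should satisfy the requirement described in the docstring.
student = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_6L_768D")
teacher = FARMReader(model_name_or_path="my-org/bert-base-uncased-squad2")

student.distil_from(
    teacher_model=teacher,
    data_dir="data/squad20",            # assumed dataset location
    train_filename="train-v2.0.json",
    tinybert_loss=True,                 # first stage: TinyBERT intermediate-layer loss ...
    tinybert_epochs=1,                  # ... for this many epochs
    n_epochs=2,                         # epochs of regular distillation (runs after the TinyBERT stage)
    distillation_loss="kl_div",
    distillation_loss_weight=0.5,
    temperature=1.0,
)
```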
40 changes: 33 additions & 7 deletions haystack/modeling/data_handler/data_silo.py
@@ -743,19 +743,29 @@ def __init__(self, teacher_model: "FARMReader", teacher_batch_size: int, device:
super().__init__(max_processes=max_processes, processor=processor, batch_size=batch_size, eval_batch_size=eval_batch_size,
distributed=distributed, automatic_loading=automatic_loading, caching=caching, cache_path=cache_path)

def _run_teacher(self, batch: List[List[torch.Tensor]], corresponding_chunks: List[int],
def _run_teacher(self, batch: dict) -> List[torch.Tensor]:
"""
Run the teacher model on the given batch.
"""
return self.teacher.inferencer.model(**batch)

def _pass_batches(self, batch: List[List[torch.Tensor]], corresponding_chunks: List[int],
teacher_outputs: List[List[Tuple[torch.Tensor, ...]]], tensor_names: List[str]):
with torch.no_grad():
batch_transposed = zip(*batch) # transpose dimensions (from batch, features, ... to features, batch, ...)
batch_transposed_list = [torch.stack(b) for b in batch_transposed] # create tensors for each feature
batch_dict = {key: tensor.to(self.device) for key, tensor in zip(tensor_names, batch_transposed_list)} # create input dict
y = self.teacher.inferencer.model(**batch_dict)
y = self._run_teacher(batch=batch_dict) # run teacher model
y = [y.cpu() for y in y]
self.output_len = len(y)

# grouping by chunk
for i, data in zip(corresponding_chunks, zip(*y)): # transpose back
teacher_outputs[i].append(data)
return

def _teacher_output_names(self) -> List[str]:
return ["teacher_output_" + str(i) for i in range(self.output_len)]

def _get_dataset(self, filename: Optional[Union[str, Path]], dicts: Optional[List[Dict]] = None):
concat_datasets, tensor_names = super()._get_dataset(filename, dicts)
@@ -772,16 +782,16 @@ def _get_dataset(self, filename: Optional[Union[str, Path]], dicts: Optional[Lis
batch.append(x)
corresponding_chunks.append(i)
if len(batch) == self.teacher_batch_size:
self._run_teacher(batch, corresponding_chunks, teacher_outputs, tensor_names) # doing forward pass on teacher model
self._pass_batches(batch, corresponding_chunks, teacher_outputs, tensor_names) # doing forward pass on teacher model
batch = []
corresponding_chunks = []
if batch:
self._run_teacher(batch, corresponding_chunks, teacher_outputs, tensor_names)
self._pass_batches(batch, corresponding_chunks, teacher_outputs, tensor_names)

# appending teacher outputs to original dataset
for dataset, teacher_output in zip(concat_datasets.datasets, teacher_outputs):
dataset.tensors += tuple(torch.stack(tensors) for tensors in zip(*teacher_output))
tensor_names.extend(["teacher_output_" + str(i) for i, _ in enumerate(zip(*teacher_output))])
tensor_names += self._teacher_output_names()
concat_datasets = ConcatDataset(concat_datasets.datasets) # making sure metrics are updated
return concat_datasets, tensor_names
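
To illustrate the batching logic above with a toy example (not code from the PR): the teacher returns one tensor per output for the whole batch, `zip(*y)` regroups these per sample so each sample's outputs can be appended to the chunk it came from, and `torch.stack` later rebuilds one tensor per teacher output.

```python
import torch

# Pretend the teacher produced 2 outputs (self.output_len == 2) for a batch of 3 samples.
y = [torch.randn(3, 5), torch.randn(3, 5)]

per_sample = list(zip(*y))                        # 3 tuples, one per sample, 2 tensors each
assert len(per_sample) == 3 and len(per_sample[0]) == 2

# After sorting samples back into their chunks, _get_dataset stacks them again:
stacked = tuple(torch.stack(t) for t in zip(*per_sample))
assert stacked[0].shape == (3, 5)                 # one tensor per teacher output
```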

@@ -796,7 +806,23 @@ def _get_checksum(self):
"max_seq_len": self.processor.max_seq_len,
"dev_split": self.processor.dev_split,
"tasks": self.processor.tasks,
"teacher_name_or_path": self.teacher.pipeline_config["params"]["model_name_or_path"]
"teacher_name_or_path": self.teacher.pipeline_config["params"]["model_name_or_path"],
"data_silo_type": self.__class__.__name__,
}
checksum = get_dict_checksum(payload_dict)
return checksum
return checksum

class TinyBERTDistillationDataSilo(DistillationDataSilo):
def _run_teacher(self, batch: dict) -> List[torch.Tensor]:
"""
Run the teacher model on the given batch.
"""
model = self.teacher.inferencer.model
logits, hidden_states, attentions = model.forward(**batch, output_attentions=True, output_hidden_states=True)
self.hidden_states_len = len(hidden_states)
self.attentions_len = len(attentions)
return hidden_states + attentions

def _teacher_output_names(self) -> List[str]:
return [f"teacher_hidden_state_{i}" for i in range(self.hidden_states_len)] + \
[f"teacher_attention_{i}" for i in range(self.attentions_len)]
24 changes: 20 additions & 4 deletions haystack/modeling/model/adaptive_model.py
@@ -8,7 +8,7 @@

import numpy
import torch
from torch import nn
from torch import nn, set_warn_always
from transformers import AutoConfig
from transformers.convert_graph_to_onnx import convert, quantize as quantize_model

@@ -356,7 +356,7 @@ def prepare_labels(self, **kwargs):
all_labels.append(labels)
return all_labels

def forward(self, **kwargs):
def forward(self, output_hidden_states: bool = False, output_attentions: bool = False, **kwargs):
Member
Please add the doc strings for the new parameters here as well, e.g.:

 :param output_hidden_states: Whether to output hidden states
 :param output_attentions: Whether to output attentions

Contributor Author
I have added the doc strings.

"""
Push data through the whole model and returns logits. The data will
propagate through the language model and each of the attached prediction heads.
@@ -366,8 +366,17 @@ def forward(self, **kwargs):
:return: All logits as torch.tensor or multiple tensors.
"""
# Run forward pass of language model
sequence_output, pooled_output = self.forward_lm(**kwargs)

output_tuple = self.language_model.forward(**kwargs, output_hidden_states=output_hidden_states, output_attentions=output_attentions)
if output_hidden_states:
if output_attentions:
sequence_output, pooled_output, hidden_states, attentions = output_tuple
else:
sequence_output, pooled_output, hidden_states = output_tuple
else:
if output_attentions:
sequence_output, pooled_output, attentions = output_tuple
else:
sequence_output, pooled_output = output_tuple
# Run forward pass of (multiple) prediction heads using the output from above
all_logits = []
if len(self.prediction_heads) > 0:
@@ -392,6 +401,13 @@
# just return LM output (e.g. useful for extracting embeddings at inference time)
all_logits.append((sequence_output, pooled_output))

if output_hidden_states:
if output_attentions:
return all_logits, hidden_states, attentions
else:
return all_logits, hidden_states
elif output_attentions:
return all_logits, attentions
return all_logits

def forward_lm(self, **kwargs):
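The return type of `forward` now depends on the two flags. A hypothetical caller (the helper name and variable names below are illustrative, not taken from the PR) would consume it like this:

```python
from typing import Dict
import torch

def run_student(model, batch: Dict[str, torch.Tensor]):
    # Default behaviour is unchanged: only the prediction-head logits come back.
    all_logits = model.forward(**batch)

    # With both flags set, the intermediate representations needed for the
    # TinyBERT loss come back as well, in the order shown in the diff above.
    all_logits, hidden_states, attentions = model.forward(
        **batch, output_hidden_states=True, output_attentions=True
    )
    return all_logits, hidden_states, attentions
```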
18 changes: 12 additions & 6 deletions haystack/modeling/model/language_model.py
@@ -484,6 +484,8 @@ def forward(
input_ids: torch.Tensor,
segment_ids: torch.Tensor,
padding_mask: torch.Tensor,
output_hidden_states: bool = False,
output_attentions: bool = False,
Member
Please add docstrings for these new parameters

Contributor Author
Ok, I have added the docstrings.

**kwargs,
):
"""
@@ -501,13 +503,17 @@
input_ids,
token_type_ids=segment_ids,
attention_mask=padding_mask,
output_hidden_states=self.model.encoder.config.output_hidden_states or output_hidden_states,
output_attentions=self.model.encoder.config.output_attentions or output_attentions,
return_dict=False
)
if self.model.encoder.config.output_hidden_states == True:
sequence_output, pooled_output, all_hidden_states = output_tuple[0], output_tuple[1], output_tuple[2]
return sequence_output, pooled_output, all_hidden_states
else:
sequence_output, pooled_output = output_tuple[0], output_tuple[1]
return sequence_output, pooled_output
return output_tuple
# if self.model.encoder.config.output_hidden_states == True:
Member
Please check the commented code. :)

Contributor Author
I have now deleted the commented code. It is unnecessary as output tuple is now handled by HuggingFace transformers.

# sequence_output, pooled_output, all_hidden_states = output_tuple[0], output_tuple[1], output_tuple[2]
# return sequence_output, pooled_output, all_hidden_states
# else:
# sequence_output, pooled_output = output_tuple[0], output_tuple[1]
# return sequence_output, pooled_output

def enable_hidden_states_output(self):
self.model.encoder.config.output_hidden_states = True
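The simplified `return output_tuple` works because, with `return_dict=False`, Hugging Face models return their outputs positionally. A minimal standalone illustration (an assumption about the transformers API, not PR code):

```python
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer("distillation example", return_tensors="pt")

output_tuple = model(
    **inputs,
    output_hidden_states=True,
    output_attentions=True,
    return_dict=False,
)
# Positional order: sequence_output, pooled_output, then hidden_states and
# attentions when the corresponding flags are set.
sequence_output, pooled_output, hidden_states, attentions = output_tuple
```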
2 changes: 1 addition & 1 deletion haystack/modeling/training/__init__.py
@@ -1 +1 @@
from haystack.modeling.training.base import Trainer, DistillationTrainer
from haystack.modeling.training.base import Trainer, DistillationTrainer, TinyBERTDistillationTrainer
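
For orientation, a minimal sketch of the intermediate-layer loss that the new `TinyBERTDistillationTrainer` is built around, following the TinyBERT paper rather than the exact Haystack implementation. The helper name, the single shared projection layer, and the uniform layer mapping are assumptions:

```python
from typing import Sequence
import torch
import torch.nn.functional as F
from torch import nn

def tinybert_intermediate_loss(
    student_hidden: Sequence[torch.Tensor],  # S + 1 tensors (embeddings + S layers)
    teacher_hidden: Sequence[torch.Tensor],  # T + 1 tensors
    student_attn: Sequence[torch.Tensor],    # S attention maps
    teacher_attn: Sequence[torch.Tensor],    # T attention maps
    projection: nn.Linear,                   # student hidden size -> teacher hidden size
) -> torch.Tensor:
    s_layers, t_layers = len(student_attn), len(teacher_attn)
    step = t_layers // s_layers              # uniform layer mapping, e.g. 12 teacher -> 6 student layers
    losses = [F.mse_loss(projection(student_hidden[0]), teacher_hidden[0])]  # embedding layer
    for i in range(s_layers):
        j = (i + 1) * step - 1               # teacher layer matched to student layer i
        losses.append(F.mse_loss(student_attn[i], teacher_attn[j]))
        losses.append(F.mse_loss(projection(student_hidden[i + 1]), teacher_hidden[j + 1]))
    return torch.stack(losses).sum()
```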