Fix finetuning notebook augmentation #2071

Merged: 9 commits, Jan 26, 2022
14 changes: 11 additions & 3 deletions docs/_src/tutorials/tutorials/2.md
@@ -104,8 +104,16 @@ To get the most out of model distillation, we recommend increasing the size of your training data by using data augmentation.
```python
# Downloading script
!wget https://raw.githubusercontent.com/deepset-ai/haystack/master/haystack/utils/augment_squad.py
# Just replace the path with your dataset and adjust the output
!python augment_squad.py --squad_path data/squad20/dev-v2.0.json --output_path augmented_dataset.json --multiplication_factor 2

# Downloading smaller glove vector file (only for demonstration purposes)
!wget https://nlp.stanford.edu/data/glove.6B.zip
!unzip glove.6B.zip

# Downloading very small dataset to make tutorial faster (please use a bigger dataset for real use cases)
!wget https://raw.githubusercontent.com/deepset-ai/haystack/master/test/samples/squad/small.json

# Just replace the path with your dataset and adjust the output (also please remove glove path to use bigger glove vector file)
!python augment_squad.py --squad_path small.json --output_path augmented_dataset.json --multiplication_factor 2 --glove_path glove.6B.300d.txt
```

In this case, we use a multiplication factor of 2 to keep this example lightweight. Usually you would use a factor like 20, depending on the size of your training data. Augmenting this small dataset with a multiplication factor of 2 should take about 5 to 10 minutes to run on one V100 GPU.
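For real training data you would typically pass a much larger factor. As a rough sketch of the same step done from Python rather than the shell (assuming the downloaded `augment_squad.py` is importable from the working directory and that its `main` function accepts the keyword arguments used later in this PR's `Tutorial2_Finetune_a_model_on_your_data.py`), it might look like this:

```python
from pathlib import Path

import augment_squad  # the script downloaded above, assumed to be on the Python path

# Hypothetical paths: point these at your own SQuAD-format dataset and GloVe file.
augment_squad.main(
    squad_path=Path("dataset.json"),
    output_path=Path("augmented_dataset.json"),
    multiplication_factor=20,  # a more realistic factor for real training data
    glove_path=Path("glove.6B.300d.txt"),  # drop this to fall back to the default, larger GloVe file
)
```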
@@ -124,7 +132,7 @@ teacher = FARMReader(model_name_or_path="my_model", use_gpu=True)
# The number of the layers in the teacher model also needs to be a multiple of the number of the layers in the student.
student = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_6L_768D", use_gpu=True)

student.distil_intermediate_layers_from(teacher, data_dir="data/squad20", train_filename="augmented_dataset.json", use_gpu=True)
student.distil_intermediate_layers_from(teacher, data_dir=".", train_filename="augmented_dataset.json", use_gpu=True)
student.distil_prediction_layer_from(teacher, data_dir="data/squad20", train_filename="dev-v2.0.json", use_gpu=True)

student.save(directory="my_distilled_model")
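Once saved, the distilled student can be loaded like any other reader. A minimal sketch, assuming the Haystack v1.x import path and that the directory passed to `student.save` can be fed back in as `model_name_or_path`:

```python
from haystack.nodes import FARMReader

# Reload the distilled student from the directory it was saved to
# and use it in place of the larger teacher model.
distilled_reader = FARMReader(model_name_or_path="my_distilled_model", use_gpu=True)
```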
138 changes: 89 additions & 49 deletions haystack/modeling/model/language_model.py
@@ -572,6 +572,8 @@ def forward(
input_ids: torch.Tensor,
segment_ids: torch.Tensor,
padding_mask: torch.Tensor,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
**kwargs,
):
"""
@@ -583,19 +585,23 @@ It is a tensor of shape [batch_size, max_seq_len]
It is a tensor of shape [batch_size, max_seq_len]
:param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
of shape [batch_size, max_seq_len]
:param output_hidden_states: Whether to output hidden states in addition to the embeddings
:param output_attentions: Whether to output attentions in addition to the embeddings
:return: Embeddings for each token in the input sequence.
"""
if output_hidden_states is None:
output_hidden_states = self.model.encoder.config.output_hidden_states
if output_attentions is None:
output_attentions = self.model.encoder.config.output_attentions

output_tuple = self.model(
input_ids,
token_type_ids=segment_ids,
attention_mask=padding_mask,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions
)
if self.model.encoder.config.output_hidden_states == True:
sequence_output, pooled_output, all_hidden_states = output_tuple[0], output_tuple[1], output_tuple[2]
return sequence_output, pooled_output, all_hidden_states
else:
sequence_output, pooled_output = output_tuple[0], output_tuple[1]
return sequence_output, pooled_output
return output_tuple

def enable_hidden_states_output(self):
self.model.encoder.config.output_hidden_states = True
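With this refactoring, `forward` reads its defaults for `output_hidden_states` and `output_attentions` from the wrapped Hugging Face config and returns the model's output tuple unchanged, so callers can override both per call instead of flipping the config via `enable_hidden_states_output`. A minimal sketch of a caller, assuming `LanguageModel.load` resolves to one of these BERT-style wrappers and that the underlying model returns a tuple of (sequence output, pooled output, hidden states, attentions):

```python
import torch
from haystack.modeling.model.language_model import LanguageModel

# Assumption: LanguageModel.load picks the matching wrapper class for this model name.
lm = LanguageModel.load("bert-base-uncased")

# Hypothetical toy batch: batch_size=2, max_seq_len=8.
input_ids = torch.randint(0, 1000, (2, 8))
segment_ids = torch.zeros(2, 8, dtype=torch.long)
padding_mask = torch.ones(2, 8, dtype=torch.long)

# Request hidden states and attentions for this call only,
# without touching lm.model.encoder.config.
outputs = lm(
    input_ids=input_ids,
    segment_ids=segment_ids,
    padding_mask=padding_mask,
    output_hidden_states=True,
    output_attentions=True,
)
sequence_output, pooled_output = outputs[0], outputs[1]
all_hidden_states, all_attentions = outputs[2], outputs[3]
```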
@@ -654,6 +660,8 @@ def forward(
input_ids: torch.Tensor,
segment_ids: torch.Tensor,
padding_mask: torch.Tensor,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
**kwargs,
):
"""
@@ -665,19 +673,23 @@ It is a tensor of shape [batch_size, max_seq_len]
It is a tensor of shape [batch_size, max_seq_len]
:param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
of shape [batch_size, max_seq_len]
:param output_hidden_states: Whether to output hidden states in addition to the embeddings
:param output_attentions: Whether to output attentions in addition to the embeddings
:return: Embeddings for each token in the input sequence.
"""
if output_hidden_states is None:
output_hidden_states = self.model.encoder.config.output_hidden_states
if output_attentions is None:
output_attentions = self.model.encoder.config.output_attentions

output_tuple = self.model(
input_ids,
token_type_ids=segment_ids,
attention_mask=padding_mask,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions
)
if self.model.encoder.config.output_hidden_states == True:
sequence_output, pooled_output, all_hidden_states = output_tuple[0], output_tuple[1], output_tuple[2]
return sequence_output, pooled_output, all_hidden_states
else:
sequence_output, pooled_output = output_tuple[0], output_tuple[1]
return sequence_output, pooled_output
return output_tuple

def enable_hidden_states_output(self):
self.model.encoder.config.output_hidden_states = True
@@ -736,6 +748,8 @@ def forward(
input_ids: torch.Tensor,
segment_ids: torch.Tensor,
padding_mask: torch.Tensor,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
**kwargs,
):
"""
@@ -747,19 +761,23 @@ It is a tensor of shape [batch_size, max_seq_len]
It is a tensor of shape [batch_size, max_seq_len]
:param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
of shape [batch_size, max_seq_len]
:param output_hidden_states: Whether to output hidden states in addition to the embeddings
:param output_attentions: Whether to output attentions in addition to the embeddings
:return: Embeddings for each token in the input sequence.
"""
if output_hidden_states is None:
output_hidden_states = self.model.encoder.config.output_hidden_states
if output_attentions is None:
output_attentions = self.model.encoder.config.output_attentions

output_tuple = self.model(
input_ids,
token_type_ids=segment_ids,
attention_mask=padding_mask,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions
)
if self.model.encoder.config.output_hidden_states == True:
sequence_output, pooled_output, all_hidden_states = output_tuple[0], output_tuple[1], output_tuple[2]
return sequence_output, pooled_output, all_hidden_states
else:
sequence_output, pooled_output = output_tuple[0], output_tuple[1]
return sequence_output, pooled_output
return output_tuple

def enable_hidden_states_output(self):
self.model.encoder.config.output_hidden_states = True
@@ -832,6 +850,8 @@ def forward(  # type: ignore
self,
input_ids: torch.Tensor,
padding_mask: torch.Tensor,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
**kwargs,
):
"""
@@ -840,20 +860,24 @@
:param input_ids: The ids of each token in the input sequence. Is a tensor of shape [batch_size, max_seq_len]
:param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
of shape [batch_size, max_seq_len]
:param output_hidden_states: Whether to output hidden states in addition to the embeddings
:param output_attentions: Whether to output attentions in addition to the embeddings
:return: Embeddings for each token in the input sequence.
"""
if output_hidden_states is None:
output_hidden_states = self.model.encoder.config.output_hidden_states
if output_attentions is None:
output_attentions = self.model.encoder.config.output_attentions

output_tuple = self.model(
input_ids,
attention_mask=padding_mask,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions
)
# We need to manually aggregate that to get a pooled output (one vec per seq)
pooled_output = self.pooler(output_tuple[0])
if self.model.config.output_hidden_states == True:
sequence_output, all_hidden_states = output_tuple[0], output_tuple[1]
return sequence_output, pooled_output
else:
sequence_output = output_tuple[0]
return sequence_output, pooled_output
return (output_tuple[0], pooled_output) + output_tuple[1:]

def enable_hidden_states_output(self):
self.model.config.output_hidden_states = True
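Models in this file that only return per-token embeddings (DistilBERT above, XLNet below) build the per-sequence vector manually through `self.pooler`, as the inline comments note. The exact pooler is configured elsewhere in these classes, but the idea can be illustrated with a simple first-token (or mean) pooling sketch:

```python
import torch

# Per-token output of shape [batch_size, max_seq_len, hidden_dim].
sequence_output = torch.randn(2, 8, 768)

# First-token pooling: one vector per sequence, shape [batch_size, hidden_dim].
pooled_first = sequence_output[:, 0]

# Mean pooling over tokens is another common aggregation.
pooled_mean = sequence_output.mean(dim=1)
```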
@@ -921,6 +945,8 @@ def forward(
input_ids: torch.Tensor,
segment_ids: torch.Tensor,
padding_mask: torch.Tensor,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
**kwargs,
):
"""
@@ -932,26 +958,28 @@
It is a tensor of shape [batch_size, max_seq_len]
:param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
of shape [batch_size, max_seq_len]
:param output_hidden_states: Whether to output hidden states in addition to the embeddings
:param output_attentions: Whether to output attentions in addition to the embeddings
:return: Embeddings for each token in the input sequence.
"""
if output_hidden_states is None:
output_hidden_states = self.model.encoder.config.output_hidden_states
if output_attentions is None:
output_attentions = self.model.encoder.config.output_attentions

# Note: XLNet has a couple of special input tensors for pretraining / text generation (perm_mask, target_mapping ...)
# We will need to implement them, if we wanna support LM adaptation
output_tuple = self.model(
input_ids,
token_type_ids=segment_ids,
attention_mask=padding_mask,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions
)
# XLNet also only returns the sequence_output (one vec per token)
# We need to manually aggregate that to get a pooled output (one vec per seq)
# TODO verify that this is really doing correct pooling
pooled_output = self.pooler(output_tuple[0])

if self.model.output_hidden_states == True:
sequence_output, all_hidden_states = output_tuple[0], output_tuple[1]
return sequence_output, pooled_output, all_hidden_states
else:
sequence_output = output_tuple[0]
return sequence_output, pooled_output
return (output_tuple[0], pooled_output) + output_tuple[1:]

def enable_hidden_states_output(self):
self.model.output_hidden_states = True
@@ -1030,6 +1058,8 @@ def forward(
input_ids: torch.Tensor,
segment_ids: torch.Tensor,
padding_mask: torch.Tensor,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
**kwargs,
):
"""
@@ -1038,6 +1068,8 @@
:param input_ids: The ids of each token in the input sequence. Is a tensor of shape [batch_size, max_seq_len]
:param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
of shape [batch_size, max_seq_len]
:param output_hidden_states: Whether to output hidden states in addition to the embeddings
:param output_attentions: Whether to output attentions in addition to the embeddings
:return: Embeddings for each token in the input sequence.
"""
output_tuple = self.model(
@@ -1046,18 +1078,20 @@
attention_mask=padding_mask,
)

if output_hidden_states is None:
output_hidden_states = self.model.encoder.config.output_hidden_states
if output_attentions is None:
output_attentions = self.model.encoder.config.output_attentions

output_tuple = self.model(
input_ids,
attention_mask=padding_mask,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions
)
# We need to manually aggregate that to get a pooled output (one vec per seq)
pooled_output = self.pooler(output_tuple[0])

if self.model.config.output_hidden_states == True:
sequence_output, all_hidden_states = output_tuple[0], output_tuple[1]
return sequence_output, pooled_output
else:
sequence_output = output_tuple[0]
return sequence_output, pooled_output

def enable_hidden_states_output(self):
self.model.config.output_hidden_states = True
return (output_tuple[0], pooled_output) + output_tuple[1:]

def disable_hidden_states_output(self):
self.model.config.output_hidden_states = False
@@ -1439,30 +1473,36 @@ def forward(
input_ids: torch.Tensor,
segment_ids: torch.Tensor,
padding_mask: torch.Tensor,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
**kwargs,
):
"""
Perform the forward pass of the BERT model.
Perform the forward pass of the BigBird model.

:param input_ids: The ids of each token in the input sequence. Is a tensor of shape [batch_size, max_seq_len]
:param segment_ids: The id of the segment. For example, in next sentence prediction, the tokens in the
first sentence are marked with 0 and those in the second are marked with 1.
It is a tensor of shape [batch_size, max_seq_len]
:param padding_mask: A mask that assigns a 1 to valid input tokens and 0 to padding tokens
of shape [batch_size, max_seq_len]
:param output_hidden_states: Whether to output hidden states in addition to the embeddings
:param output_attentions: Whether to output attentions in addition to the embeddings
:return: Embeddings for each token in the input sequence.
"""
if output_hidden_states is None:
output_hidden_states = self.model.encoder.config.output_hidden_states
if output_attentions is None:
output_attentions = self.model.encoder.config.output_attentions

output_tuple = self.model(
input_ids,
token_type_ids=segment_ids,
attention_mask=padding_mask,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions
)
if self.model.encoder.config.output_hidden_states == True:
sequence_output, pooled_output, all_hidden_states = output_tuple[0], output_tuple[1], output_tuple[2]
return sequence_output, pooled_output, all_hidden_states
else:
sequence_output, pooled_output = output_tuple[0], output_tuple[1]
return sequence_output, pooled_output
return output_tuple

def enable_hidden_states_output(self):
self.model.encoder.config.output_hidden_states = True
14 changes: 11 additions & 3 deletions tutorials/Tutorial2_Finetune_a_model_on_your_data.ipynb
@@ -188,8 +188,16 @@
"source": [
"# Downloading script\n",
"!wget https://raw.githubusercontent.com/deepset-ai/haystack/master/haystack/utils/augment_squad.py\n",
"# Just replace the path with your dataset and adjust the output\n",
"!python augment_squad.py --squad_path data/squad20/dev-v2.0.json --output_path augmented_dataset.json --multiplication_factor 2"
"\n",
"# Downloading smaller glove vector file (only for demonstration purposes)\n",
"!wget https://nlp.stanford.edu/data/glove.6B.zip\n",
"!unzip glove.6B.zip\n",
"\n",
"# Downloading very small dataset to make tutorial faster (please use a bigger dataset for real use cases)\n",
"!wget https://raw.githubusercontent.com/deepset-ai/haystack/master/test/samples/squad/small.json\n",
"\n",
"# Just replace the path with your dataset and adjust the output (also please remove glove path to use bigger glove vector file)\n",
"!python augment_squad.py --squad_path small.json --output_path augmented_dataset.json --multiplication_factor 2 --glove_path glove.6B.300d.txt"
]
},
{
@@ -217,7 +225,7 @@
"# The number of the layers in the teacher model also needs to be a multiple of the number of the layers in the student.\n",
"student = FARMReader(model_name_or_path=\"huawei-noah/TinyBERT_General_6L_768D\", use_gpu=True)\n",
"\n",
"student.distil_intermediate_layers_from(teacher, data_dir=\"data/squad20\", train_filename=\"augmented_dataset.json\", use_gpu=True)\n",
"student.distil_intermediate_layers_from(teacher, data_dir=\".\", train_filename=\"augmented_dataset.json\", use_gpu=True)\n",
"student.distil_prediction_layer_from(teacher, data_dir=\"data/squad20\", train_filename=\"dev-v2.0.json\", use_gpu=True)\n",
"\n",
"student.save(directory=\"my_distilled_model\")"
13 changes: 11 additions & 2 deletions tutorials/Tutorial2_Finetune_a_model_on_your_data.py
@@ -12,6 +12,7 @@

from pathlib import Path

import os

def tutorial2_finetune_a_model_on_your_data():
# ## Create Training Data
@@ -65,8 +66,16 @@ def distil():
# ### Augmenting your training data
# To get the most out of model distillation, we recommend increasing the size of your training data by using data augmentation.
# You can do this by running the [`augment_squad.py` script](https://github.com/deepset-ai/haystack/blob/master/haystack/utils/augment_squad.py):
# # Just replace dataset.json with the name of your dataset and adjust the output path
augment_squad.main(squad_path=Path("dataset.json"), output_path=Path("augmented_dataset.json"), multiplication_factor=2)

# Downloading smaller glove vector file (only for demonstration purposes)
os.system("wget https://nlp.stanford.edu/data/glove.6B.zip")
os.system("unzip glove.6B.zip")

# Downloading very small dataset to make tutorial faster (please use a bigger dataset in real use cases)
os.system("wget https://raw.githubusercontent.com/deepset-ai/haystack/master/test/samples/squad/small.json")

# Just replace dataset.json with the name of your dataset and adjust the output path
augment_squad.main(squad_path=Path("dataset.json"), output_path=Path("augmented_dataset.json"), multiplication_factor=2, glove_path=Path("glove.6B.300d.txt"))
# In this case, we use a multiplication factor of 2 to keep this example lightweight.
# Usually you would use a factor like 20 depending on the size of your training data.
# Augmenting this small dataset with a multiplication factor of 2, should take about 5 to 10 minutes to run on one V100 GPU.