Skip to content

Commit

Permalink
[shardformer]upgrade transformers for gpt2/gptj/whisper (#5807)
Browse files Browse the repository at this point in the history
* [shardformer] fix modeling of gpt2 and gptj

* [shardformer] fix whisper modeling

* [misc] update requirements

---------

Co-authored-by: ver217 <lhx0217@gmail.com>
  • Loading branch information
flybird11111 and ver217 authored Jun 12, 2024
1 parent d9dddf5 commit e9f2fd9
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 18 deletions.
5 changes: 4 additions & 1 deletion colossalai/shardformer/modeling/gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,7 +738,10 @@ def gpt2_for_sequence_classification_forward(
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
logger.warning_once(
Expand Down
20 changes: 15 additions & 5 deletions colossalai/shardformer/modeling/gptj.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def _get_attention_mask(
hidden_states: torch.Tensor,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]],
attention_mask: Optional[torch.FloatTensor],
use_flash_attention_2: bool = False,
) -> Optional[Union[torch.Tensor, dict]]:
batch_size, seq_len = hidden_states.shape[:2]
past_key_values_length = 0
Expand All @@ -47,7 +48,7 @@ def _get_attention_mask(
attention_mask,
is_causal=True,
)
elif attention_mask is not None:
elif use_flash_attention_2 and attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
Expand Down Expand Up @@ -162,7 +163,9 @@ def gptj_model_forward(

output_shape = input_shape + (hidden_states.size(-1),)

attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask)
attention_mask = _get_attention_mask(
self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2
)

if self.gradient_checkpointing and self.training:
if use_cache:
Expand Down Expand Up @@ -419,7 +422,10 @@ def gptj_for_sequence_classification_forward(
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
logger.warning_once(
Expand Down Expand Up @@ -712,7 +718,9 @@ def forward(

hidden_states = self.drop(hidden_states)

attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask)
attention_mask = _get_attention_mask(
self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2
)

output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)

Expand Down Expand Up @@ -886,7 +894,9 @@ def forward(
hidden_states = self.drop(hidden_states)

output_shape = input_shape + (hidden_states.size(-1),)
attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask)
attention_mask = _get_attention_mask(
self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2
)

if self.gradient_checkpointing and self.training:
if use_cache:
Expand Down
32 changes: 27 additions & 5 deletions colossalai/shardformer/modeling/whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
SequenceClassifierOutput,
)
from transformers.models.whisper.modeling_whisper import (
_HIDDEN_STATES_START_POSITION,
WhisperDecoder,
WhisperEncoder,
WhisperForAudioClassification,
Expand Down Expand Up @@ -166,6 +167,7 @@ def forward(
cross_attn_head_mask=None,
past_key_values=None,
inputs_embeds=None,
position_ids=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
Expand Down Expand Up @@ -199,9 +201,13 @@ def forward(

# embed positions
if input_ids is not None:
positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
positions = self.embed_positions(
input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids
)
else:
positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
positions = self.embed_positions(
inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids
)

hidden_states = inputs_embeds + positions
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
Expand Down Expand Up @@ -599,6 +605,7 @@ def whisper_decoder_forward(
cross_attn_head_mask=None,
past_key_values=None,
inputs_embeds=None,
position_ids=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
Expand Down Expand Up @@ -716,9 +723,13 @@ def whisper_decoder_forward(

# embed positions
if input_ids is not None:
positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
positions = self.embed_positions(
input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids
)
else:
positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
positions = self.embed_positions(
inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids
)

hidden_states = inputs_embeds + positions
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
Expand Down Expand Up @@ -841,6 +852,7 @@ def whisper_model_forward(
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
Expand Down Expand Up @@ -944,6 +956,7 @@ def whisper_model_forward(
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
position_ids=decoder_position_ids,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
Expand Down Expand Up @@ -986,6 +999,7 @@ def whisper_for_conditional_generation_forward(
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
Expand Down Expand Up @@ -1048,6 +1062,7 @@ def whisper_for_conditional_generation_forward(
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
decoder_inputs_embeds=decoder_inputs_embeds,
decoder_position_ids=decoder_position_ids,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
Expand Down Expand Up @@ -1118,6 +1133,12 @@ def whisper_for_audio_classification_forward(
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

if self.config.use_weighted_layer_sum:
output_hidden_states = True
elif output_hidden_states is None:
output_hidden_states = self.config.output_hidden_states

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# audio_classification only holds encoder
Expand All @@ -1138,7 +1159,8 @@ def whisper_for_audio_classification_forward(
return encoder_outputs

if self.config.use_weighted_layer_sum:
hidden_states = torch.stack(encoder_outputs, dim=1)
hidden_states = encoder_outputs[_HIDDEN_STATES_START_POSITION]
hidden_states = torch.stack(hidden_states, dim=1)
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
else:
Expand Down
8 changes: 2 additions & 6 deletions colossalai/shardformer/policies/gptj.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,11 @@ def preprocess(self):
return self.model

def module_policy(self):
from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJModel

ATTN_IMPLEMENTATION = {
"eager": GPTJAttention,
}
from transformers.models.gptj.modeling_gptj import GPTJ_ATTENTION_CLASSES, GPTJBlock, GPTJModel

policy = {}

attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
attn_cls = GPTJ_ATTENTION_CLASSES[self.origin_attn_implement]

embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ ray
sentencepiece
google
protobuf
transformers>=4.36.2,<4.40.0
transformers==4.39.3
peft>=0.7.1
bitsandbytes>=0.39.0
rpyc==6.0.0
Expand Down

0 comments on commit e9f2fd9

Please sign in to comment.