[shardformer] upgrade transformers to 4.39.3 (#5815)
* [shardformer] upgrade transformers for gpt2/gptj/whisper (#5807)

* [shardformer] fix modeling of gpt2 and gptj

* [shardformer] fix whisper modeling

* [misc] update requirements

---------

Co-authored-by: ver217 <lhx0217@gmail.com>

* [shardformer] upgrade transformers for mistral (#5808)

* upgrade transformers for mistral

* fix

* fix

* [shardformer] upgrade transformers for llama (#5809)

* update transformers

fix

* fix

* fix

* [inference] upgrade transformers (#5810)

* update transformers

fix

* fix

* fix

* fix

* fix

* [gemini] update transformers for gemini (#5814)

---------

Co-authored-by: ver217 <lhx0217@gmail.com>
flybird11111 and ver217 authored Jun 14, 2024
1 parent 3bcbba9 commit 2ddf624
Showing 12 changed files with 255 additions and 238 deletions.
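The modeling changes below assume transformers 4.39.3 (per the commit title). A minimal, hypothetical runtime guard for that assumption could look like the following; it is not part of the diff, and the actual requirements change lives in files not shown here.

# Hypothetical sketch: guard for the minimum transformers version these
# shardformer/inference changes rely on. The 4.39.3 floor is taken from the
# commit title; the real pin is in the (unshown) requirements update.
from packaging.version import Version

import transformers

REQUIRED = Version("4.39.3")  # assumed minimum

if Version(transformers.__version__) < REQUIRED:
    raise RuntimeError(
        f"transformers>={REQUIRED} is required for the updated modeling code, "
        f"found {transformers.__version__}"
    )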
2 changes: 1 addition & 1 deletion colossalai/checkpoint_io/utils.py
@@ -314,7 +314,7 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
use_safetensors (bool): whether to use safetensors to save the checkpoint.
"""
# Move all tensors in the state_dict to CPU before saving to avoid serialization issues
state_dict_cpu = tree_map(lambda x: x.cpu() if torch.is_tensor(x) else x, state_dict)
state_dict_cpu = tree_map(lambda x: x.data.cpu() if torch.is_tensor(x) else x, state_dict)

if use_safetensors:
assert is_safetensors_available(), "safetensors is not available."
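For context, the one-line change above swaps `x.cpu()` for `x.data.cpu()` so that parameters are copied as plain, autograd-free tensors before serialization. A self-contained sketch of the pattern, assuming `tree_map` here is `torch.utils._pytree.tree_map`; the example names are illustrative, not taken from the repository.

# Sketch of the pattern above: `.data` yields a plain tensor detached from
# autograd (and from the Parameter wrapper) before the CPU copy, so saving
# never drags along grad history.
import torch
from torch.utils._pytree import tree_map

state_dict = {"weight": torch.nn.Parameter(torch.randn(2, 2)), "step": 3}

state_dict_cpu = tree_map(lambda x: x.data.cpu() if torch.is_tensor(x) else x, state_dict)

print({k: type(v).__name__ for k, v in state_dict_cpu.items()})
# -> {'weight': 'Tensor', 'step': 'int'}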
74 changes: 30 additions & 44 deletions colossalai/inference/modeling/models/glide_llama.py
@@ -6,11 +6,7 @@

import torch
import torch.nn as nn
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import (
LlamaAttention,
@@ -137,6 +133,7 @@ def glide_llama_model_forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -147,57 +144,43 @@ def glide_llama_model_forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape[:2]
elif inputs_embeds is not None:
batch_size, seq_length = inputs_embeds.shape[:2]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")

past_key_values_length = 0
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)

if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
position_ids = position_ids.unsqueeze(0)

if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.")
use_cache = False

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
past_seen_tokens = 0
if use_cache: # kept for BC (cache positions)
if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()

if cache_position is None:
if isinstance(past_key_values, StaticCache):
raise ValueError("cache_position is a required argument when using StaticCache.")
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)

if position_ids is None:
position_ids = cache_position.unsqueeze(0)

attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

# embed positions
hidden_states = inputs_embeds

# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
next_decoder_cache = None

for decoder_layer in self.layers:
if output_hidden_states:
@@ -212,6 +195,7 @@ def glide_llama_model_forward(
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)

hidden_states = layer_outputs[0]
@@ -230,7 +214,9 @@ def glide_llama_model_forward(

next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
next_cache = (
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
)
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
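The rewritten forward follows the transformers 4.39 Llama cache protocol: legacy tuple caches are wrapped in DynamicCache, cache_position tracks the absolute indices of the newly fed tokens, and position_ids is derived from it when not supplied. A standalone sketch of that bookkeeping using only public cache utilities; shapes and values are illustrative, not taken from the model.

# Standalone sketch of the cache bookkeeping in the rewritten forward
# (transformers >= 4.39 API). One layer, 5 cached tokens, 3 new tokens.
import torch
from transformers.cache_utils import DynamicCache

batch, heads, head_dim = 1, 4, 8
legacy_past = (
    (torch.zeros(batch, heads, 5, head_dim), torch.zeros(batch, heads, 5, head_dim)),
)

past_key_values = DynamicCache.from_legacy_cache(legacy_past)
past_seen_tokens = past_key_values.get_seq_length()  # -> 5

inputs_embeds = torch.randn(batch, 3, heads * head_dim)  # 3 new tokens
cache_position = torch.arange(
    past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
position_ids = cache_position.unsqueeze(0)

print(cache_position.tolist())                        # [5, 6, 7]
print(past_key_values.to_legacy_cache()[0][0].shape)  # round-trip back to the legacy tuple format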
5 changes: 4 additions & 1 deletion colossalai/shardformer/modeling/gpt2.py
@@ -738,7 +738,10 @@ def gpt2_for_sequence_classification_forward(
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
logger.warning_once(
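The new sequence-length computation above finds the first pad token with argmax and folds the no-padding case back to the last position with a modulo, avoiding the negative indexing that ONNX export cannot handle. A tiny worked example; the pad id and inputs are made up.

# Worked example of the padding-aware sequence-length trick.
# argmax over `input_ids == pad` finds the first pad token; subtracting 1 gives
# the last real token, and the modulo maps the no-pad case (argmax == 0 -> -1)
# to the final position without negative indexing.
import torch

pad_token_id = 0
input_ids = torch.tensor(
    [
        [5, 6, 7, 0, 0],  # padded: last real token at index 2
        [5, 6, 7, 8, 9],  # no padding: last real token at index 4
    ]
)

sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
print(sequence_lengths.tolist())  # [2, 4]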
20 changes: 15 additions & 5 deletions colossalai/shardformer/modeling/gptj.py
@@ -32,6 +32,7 @@ def _get_attention_mask(
hidden_states: torch.Tensor,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]],
attention_mask: Optional[torch.FloatTensor],
use_flash_attention_2: bool = False,
) -> Optional[Union[torch.Tensor, dict]]:
batch_size, seq_len = hidden_states.shape[:2]
past_key_values_length = 0
@@ -47,7 +48,7 @@
attention_mask,
is_causal=True,
)
elif attention_mask is not None:
elif use_flash_attention_2 and attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
@@ -162,7 +163,9 @@ def gptj_model_forward(

output_shape = input_shape + (hidden_states.size(-1),)

attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask)
attention_mask = _get_attention_mask(
self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2
)

if self.gradient_checkpointing and self.training:
if use_cache:
@@ -419,7 +422,10 @@ def gptj_for_sequence_classification_forward(
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
logger.warning_once(
@@ -712,7 +718,9 @@ def forward(

hidden_states = self.drop(hidden_states)

attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask)
attention_mask = _get_attention_mask(
self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2
)

output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)

@@ -886,7 +894,9 @@ def forward(
hidden_states = self.drop(hidden_states)

output_shape = input_shape + (hidden_states.size(-1),)
attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask)
attention_mask = _get_attention_mask(
self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2
)

if self.gradient_checkpointing and self.training:
if use_cache:
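Threading self._use_flash_attention_2 into _get_attention_mask matters because FlashAttention-2 consumes the raw 2D padding mask, while the other attention paths expect an expanded 4D additive causal mask. A rough, self-contained illustration of that distinction; the mask construction below is illustrative, not the shardformer helper itself.

# Illustration of 2D vs 4D masks. FA2 path: keep the (batch, seq_len) padding
# mask. Non-FA2 path: combine padding and causality into an additive 4D mask,
# roughly what transformers' mask utilities produce.
import torch

batch_size, seq_len = 2, 4
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])  # 2D padding mask

fa2_mask = attention_mask.view(batch_size, -1)  # (batch, seq_len), passed through as-is

min_value = torch.finfo(torch.float32).min
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
full_mask = causal[None, None, :, :] & attention_mask[:, None, None, :].bool()
additive_4d = torch.zeros(batch_size, 1, seq_len, seq_len).masked_fill(~full_mask, min_value)

print(fa2_mask.shape, additive_4d.shape)  # torch.Size([2, 4]) torch.Size([2, 1, 4, 4])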