From 501205dcb4a5056fa2c52ceedc9f38eb8b55904a Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 19 Jul 2024 00:42:46 +0000 Subject: [PATCH] ring attn + tp, pp tests passed; fix typos such as causal --- .pre-commit-config.yaml | 1 + .../booster/plugin/hybrid_parallel_plugin.py | 2 +- .../pipeline/schedule/interleaved_pp.py | 1 - colossalai/shardformer/layer/attn.py | 27 ++++++------ colossalai/shardformer/layer/utils.py | 5 ++- colossalai/shardformer/modeling/llama.py | 42 ++++++++++--------- colossalai/shardformer/policies/command.py | 4 +- colossalai/shardformer/policies/deepseek.py | 2 +- colossalai/shardformer/policies/llama.py | 2 +- colossalai/shardformer/policies/mistral.py | 2 +- colossalai/shardformer/policies/mixtral.py | 2 +- colossalai/shardformer/policies/qwen2.py | 2 +- examples/language/llama/benchmark.py | 3 +- .../language/openmoe/model/openmoe_policy.py | 2 +- examples/language/opt/README.md | 2 +- examples/tutorial/opt/opt/README.md | 2 +- tests/kit/model_zoo/__init__.py | 4 +- tests/kit/model_zoo/transformers/command.py | 12 +++--- tests/kit/model_zoo/transformers/llama.py | 12 +++--- tests/kit/model_zoo/transformers/mistral.py | 2 +- tests/kit/model_zoo/transformers/qwen2.py | 12 +++--- .../test_plugin/test_3d_plugin.py | 2 +- .../test_plugin/test_low_level_zero_plugin.py | 2 +- .../test_gemini_checkpoint_io.py | 2 +- .../test_gemini_torch_compability.py | 2 +- ...st_hybrid_parallel_plugin_checkpoint_io.py | 2 +- .../test_low_level_zero_checkpoint_io.py | 2 +- .../test_plugins_huggingface_compatibility.py | 2 +- tests/test_lora/test_lora.py | 2 +- tests/test_shardformer/test_model/_utils.py | 7 +++- .../test_model/test_shard_command.py | 4 +- .../test_model/test_shard_llama.py | 20 +++++++-- 32 files changed, 105 insertions(+), 85 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9088d0e1bb71..e2a038e628d2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,6 +12,7 @@ repos: hooks: - id: isort name: sort all imports (python) + args: ["--profile", "black"] # avoid comflict with black - repo: https://github.com/psf/black-pre-commit-mirror rev: 24.4.2 diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 057dad389a0b..7b58e60f2018 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -187,7 +187,7 @@ def sync_sp_grads(self, grads: Optional[List[torch.Tensor]] = None): """ if self.shard_config.enable_sequence_parallelism: - if self.shard_config.sequence_parallelism_mode == "all_to_all": + if self.shard_config.sequence_parallelism_mode in ["all_to_all", "ring_attn"]: return if self.shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index a21b45c44a2c..412f3896fb80 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -286,7 +286,6 @@ def forward_step( # for the first stage, input_obj is None # for other stages, input_obj is the output of the previous stage containing hidden_states etc. # Only attention_mask from micro_batch is used - with self.stage_manager.switch_model_chunk_id(model_chunk_id): if isinstance(model_chunk, ModuleList): output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 8aa1a37afbb2..cb429bcb8329 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -415,13 +415,12 @@ def _rescale_out_lse(out, block_out, lse, block_lse): # new_lse = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) new_block_lse = torch.exp(block_lse - new_lse) - assert _not_nan(new_lse), new_lse - # dist.barrier() - assert _not_nan(new_block_lse), new_block_lse out.copy_(torch.exp(lse - new_lse) * out + new_block_lse * block_out) lse.copy_(new_lse) - + assert _not_nan(new_lse), new_lse + assert _not_nan(new_block_lse), new_block_lse + assert _not_nan(out), out # block_out = block_out.float() # out.copy_(out - F.sigmoid(block_lse - lse) * (out - block_out)) # lse.copy_(lse - F.logsigmoid(lse - block_lse)) @@ -600,7 +599,8 @@ def forward( b, h, sq, d = q.shape # (B, H, Sq, D) -> (B, Sq, H, D) q, k, v = [x.transpose(1, 2) for x in (q, k, v)] - + assert _not_nan(q), q + assert _not_nan(k), k sp_size = dist.get_world_size(sp_group) sp_rank = dist.get_rank(sp_group) sp_global_ranks = dist.get_process_group_ranks(sp_group) @@ -626,7 +626,9 @@ def forward( with torch.cuda.stream(sp_streams[i % 2]): for req in p2p_reqs[(i + 1) % 2]: req.wait() - assert _not_nan(kv_buffers[i % 2]), kv_buffers[i % 2] + assert _not_nan( + kv_buffers[i % 2] + ), f"rank {dist.get_rank()} iter {i} kv buffer is nan: {kv_buffers[i % 2]}" if i < sp_size - 1: p2p_reqs[i % 2] = ring_attn_p2p_comm( @@ -674,7 +676,7 @@ def forward( kv_block = kv_buffers[i % 2] # (2, B * Sq // 2, H, D) kv_block = kv_block.view(2, b * sq, h, d)[:, : b * sq // 2].clone() - assert _not_nan(kv_block), f"rank {sp_rank} step {i} kv_block {kv_block}" + assert _not_nan(kv_block), f"rank {dist.get_rank()} step {i} kv_block {kv_block}" # actual_lse = (q_block.flatten(start_dim=1) @ kv_block[0].movedim(0, -1).flatten(end_dim=-2)).exp().sum(dim=-1).log() ( _, @@ -702,7 +704,7 @@ def forward( # Drop the first half of q q_block = q.view(b * sq, h, d)[b * sq // 2 :] kv_block = kv_buffers[i % 2].view(2, b * sq, h, d).clone() - assert _not_nan(kv_block), f"rank {sp_rank} step {i} kv_block {kv_block}" + assert _not_nan(kv_block), f"rank {dist.get_rank()} step {i} kv_block {kv_block}" # actual_lse = (q_block.flatten(start_dim=1) @ kv_block[0].movedim(0, -1).flatten(end_dim=-2)).exp().sum(dim=-1).log() ( @@ -919,9 +921,9 @@ def backward(ctx, dout): # Accumulate grads if i == 0: # TODO: use float() if precision goes wrong - dq = dq_block - dk_recv = dkv_buffers[(i + 1) % 2][0] = dk_block.clone() - dv_recv = dkv_buffers[(i + 1) % 2][1] = dv_block.clone() + dq = dq_block.float() + dk_recv = dkv_buffers[(i + 1) % 2][0] = dk_block.float() + dv_recv = dkv_buffers[(i + 1) % 2][1] = dv_block.float() else: # Accumulate local dq if i <= sp_rank: @@ -933,7 +935,8 @@ def backward(ctx, dout): # Wait for mobile kv grad accumulators for req in dkv_reqs: req.wait() - + assert _not_nan(dkv_buffers[(i + 1) % 2]), f"rank {dist.get_rank()} step {i} dkv_buffers is nan" + assert _not_nan(dq_block), f"rank {dist.get_rank()} step {i} dq_block is nan" if i <= sp_rank: # q blocks "surrounded" by kv blocks dk_recv = dkv_buffers[(i + 1) % 2][0] diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index a51f7d5d9ee8..7d78f4f19a88 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -291,7 +291,7 @@ def create_randomizer_with_offset( return Randomizer(seed=base_seed) -def zigzag_split_batch(batch: List[torch.Tensor], sp_group: ProcessGroup, varlen: bool = False): +def zigzag_split_batch(batch: List[torch.Tensor], sp_group: ProcessGroup, seq_dim=1, varlen: bool = False): """ Split the input along the sequence dimension for Ring Attention. As naively spliting sequence in the causual setting will result in the first ranks having much less workload than the last ranks, @@ -301,18 +301,19 @@ def zigzag_split_batch(batch: List[torch.Tensor], sp_group: ProcessGroup, varlen Args: batch (List[torch.Tensor]): The input tensors to split. sp_group (ProcessGroup): The process group for sequence parallelism. + seq_dim (int): The sequence dimension to split. varlen (bool): If the input is padded (aka "packing" mode), such that sequences in a batch have different lengths, and we need to unpad and split each sequence evenly by sp_size. """ sp_size = dist.get_world_size(sp_group) sp_rank = dist.get_rank(sp_group) - seq_dim = 1 if sp_size > 1: for idx, tensor in enumerate(batch): assert ( tensor.numel() // (sp_size * 2) > 1 ), f"Bro, the seq length for tensor {idx} in batch is too short to split!" + tensor = tensor.view( *tensor.shape[:seq_dim], 2 * sp_size, diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 7f88c0f94b8b..254c0996e115 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -135,7 +135,7 @@ def llama_model_forward( if shard_config.enable_flash_attention: # in this case, attention_mask is a dict rather than a tensor mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past) - attention_mask = ColoAttention.prepare_attn_kwargs( + attn_mask = ColoAttention.prepare_attn_kwargs( mask_shape, hidden_states.dtype, hidden_states.device, @@ -143,22 +143,24 @@ def llama_model_forward( is_causal=True, ) else: - attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position) + attn_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position) # Support SP + PP if stage_manager.is_first_stage(): + # Ring Attention zigzag batch processing if sp_mode == "ring_attn": - # NOTE: This will throw an error in KV Cache inference without replicating q in all ranks. - # Also, I don't see get_llama_flash_attention_forward supporting - # query_states and key_states with different seq_len. - batch = { - "input": inputs_embeds, - "attention_mask": attention_mask["attention_mask"], - "position": position_ids, - } - batch = zigzag_split_batch(batch, sp_group) - inputs_embeds, attention_mask["attention_mask"], position_ids = batch.values() - elif sp_mode in ["ring", "split_gather"]: + assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention." + if attn_mask["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL: + attn_mask["cu_seqlens"], attn_mask["max_seqlen"], attn_mask["indices"] = get_pad_info( + attn_mask["attention_mask"].squeeze(1).any(dim=-1) + ) # [B, 1, Sq, Skv] -> [B, Sq] + else: + attn_mask["cu_seqlens"] = attn_mask["max_seqlen"] = attn_mask["indices"] = None + batch = [hidden_states, position_ids] + # inputs_embeds, attention_mask["attention_mask"], position_ids = zigzag_split_batch(batch, sp_group) + hidden_states, position_ids = zigzag_split_batch(batch, sp_group) + + elif is_share_sp_tp(sp_mode): hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group) elif sp_mode == "all_to_all": hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size) @@ -193,12 +195,11 @@ def llama_model_forward( for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): if output_hidden_states: all_hidden_states += (hidden_states,) - if idx - start_idx < num_ckpt_layers: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + attn_mask, position_ids, past_key_values, output_attentions, @@ -208,14 +209,13 @@ def llama_model_forward( else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=attn_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, ) - hidden_states = layer_outputs[0] if use_cache: @@ -314,7 +314,7 @@ def llama_for_causal_lm_forward( if stage_manager.is_first_stage(): if shard_config.sequence_parallelism_mode == "ring_attn": - labels = zigzag_split_batch(labels, shard_config.sequence_parallel_process_group) + labels = zigzag_split_batch([labels], shard_config.sequence_parallel_process_group)[0] # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = LlamaPipelineForwards.llama_model_forward( @@ -500,7 +500,7 @@ def forward( bsz, q_len, _ = hidden_states.size() # sp: modify sp_len when sequence parallel mode is ring - if sp_mode in ["split_gather", "ring"]: + if is_share_sp_tp(sp_mode): q_len *= sp_size if self.config.pretraining_tp > 1: @@ -555,7 +555,9 @@ def forward( # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) + assert not self.q_proj.weight.isnan().any(), self.q_proj.weight + assert not query_states.isnan().any(), query_states if sp_mode == "ring_attn": attn_output = RingAttention.attention( query_states, @@ -701,7 +703,7 @@ def forward( # inputs_embeds, attention_mask["attention_mask"], position_ids = zigzag_split_batch(batch, sp_group) inputs_embeds, position_ids = zigzag_split_batch(batch, sp_group) - elif sp_mode in ["ring", "split_gather"]: + elif is_share_sp_tp(sp_mode): inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) elif sp_mode == "all_to_all": inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py index 95c3707f4024..1efd3d0179af 100644 --- a/colossalai/shardformer/policies/command.py +++ b/colossalai/shardformer/policies/command.py @@ -292,11 +292,11 @@ class CommandForCausalLMPolicy(CommandPolicy): def module_policy(self): from transformers import CohereForCausalLM - self.is_casual = True + self.is_causal = True policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: - # add a new item for casual lm + # add a new item for causal lm new_item = { CohereForCausalLM: ModulePolicyDescription( sub_module_replacement=[ diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py index 8ebda357b380..afd3d3b18ce9 100644 --- a/colossalai/shardformer/policies/deepseek.py +++ b/colossalai/shardformer/policies/deepseek.py @@ -163,7 +163,7 @@ def module_policy(self): policy = super().module_policy() # TODO: assign pg mesh from plugin to all modules if self.shard_config.enable_tensor_parallelism: - # add a new item for casual lm + # add a new item for causal lm new_item = { "DeepseekForCausalLM": ModulePolicyDescription( sub_module_replacement=[ diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 19f2accc381b..f72a72df0b1b 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -305,7 +305,7 @@ def module_policy(self): policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: - # add a new item for casual lm + # add a new item for causal lm new_item = { LlamaForCausalLM: ModulePolicyDescription( sub_module_replacement=[ diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py index c5a0277a5783..6ea27e210455 100644 --- a/colossalai/shardformer/policies/mistral.py +++ b/colossalai/shardformer/policies/mistral.py @@ -271,7 +271,7 @@ def module_policy(self): policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: - # add a new item for casual lm + # add a new item for causal lm new_item = { MistralForCausalLM: ModulePolicyDescription( sub_module_replacement=[ diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index ad93e94694c8..d65e90cd4d66 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -159,7 +159,7 @@ def module_policy(self): policy = super().module_policy() # TODO: assign pg mesh from plugin to all modules if self.shard_config.enable_tensor_parallelism: - # add a new item for casual lm + # add a new item for causal lm new_item = { MixtralForCausalLM: ModulePolicyDescription( sub_module_replacement=[ diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py index 362c14060fd9..235dc7d56a2d 100644 --- a/colossalai/shardformer/policies/qwen2.py +++ b/colossalai/shardformer/policies/qwen2.py @@ -313,7 +313,7 @@ def module_policy(self): setattr(self.shard_config, "causal_lm", True) if self.shard_config.enable_tensor_parallelism: - # add a new item for casual lm + # add a new item for causal lm new_item = { Qwen2ForCausalLM: ModulePolicyDescription( sub_module_replacement=[ diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index 97fb6adaaa14..796c4678748c 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -193,7 +193,7 @@ def empty_init(): num_model_chunks=args.n_chunks, zero_stage=args.zero, sp_size=args.sp, - sp_mode=args.sp_mode, + sequence_parallelism_mode=args.sp_mode, enable_sequence_parallelism=args.sp > 1, enable_fused_normalization=torch.cuda.is_available(), enable_flash_attention=args.xformers, @@ -324,7 +324,6 @@ def empty_init(): performance_evaluator.on_step_end(**batch) prof.step() - booster.save_model(model, "model.pt") performance_evaluator.on_fit_end() coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB") diff --git a/examples/language/openmoe/model/openmoe_policy.py b/examples/language/openmoe/model/openmoe_policy.py index f46062128563..051f7fc8004e 100644 --- a/examples/language/openmoe/model/openmoe_policy.py +++ b/examples/language/openmoe/model/openmoe_policy.py @@ -171,7 +171,7 @@ def module_policy(self): policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: - # add a new item for casual lm + # add a new item for causal lm # TODO: recursively assign ep group foe all modules new_item = { OpenMoeForCausalLM: ModulePolicyDescription( diff --git a/examples/language/opt/README.md b/examples/language/opt/README.md index af1e794374ed..694c5cf91acc 100644 --- a/examples/language/opt/README.md +++ b/examples/language/opt/README.md @@ -17,7 +17,7 @@ limitations under the License. ## OPT Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments. -The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost. +The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Causal Language Modelling at low cost. ## Our Modifications diff --git a/examples/tutorial/opt/opt/README.md b/examples/tutorial/opt/opt/README.md index a01209cbda0e..3776e0c64552 100644 --- a/examples/tutorial/opt/opt/README.md +++ b/examples/tutorial/opt/opt/README.md @@ -19,7 +19,7 @@ limitations under the License. ## OPT Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments. -The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost. +The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning causal Language Modelling at low cost. We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before the tokenization). This training script is adapted from the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling). diff --git a/tests/kit/model_zoo/__init__.py b/tests/kit/model_zoo/__init__.py index 66c794a7d891..9c1a11e7bc29 100644 --- a/tests/kit/model_zoo/__init__.py +++ b/tests/kit/model_zoo/__init__.py @@ -22,9 +22,9 @@ "transformers_bloom_for_causal_lm", "transformers_falcon_for_causal_lm", "transformers_chatglm_for_conditional_generation", - "transformers_llama_for_casual_lm", + "transformers_llama_for_causal_lm", "transformers_vit_for_masked_image_modeling", - "transformers_mistral_for_casual_lm", + "transformers_mistral_for_causal_lm", ] IS_FAST_TEST = os.environ.get("FAST_TEST", "0") == "1" diff --git a/tests/kit/model_zoo/transformers/command.py b/tests/kit/model_zoo/transformers/command.py index a8b8842c5907..3f4ea45838d7 100644 --- a/tests/kit/model_zoo/transformers/command.py +++ b/tests/kit/model_zoo/transformers/command.py @@ -32,8 +32,8 @@ def data_gen(): return dict(input_ids=input_ids, attention_mask=attention_mask) - # label is needed for casual lm - def data_gen_for_casual_lm(): + # label is needed for causal lm + def data_gen_for_causal_lm(): data = data_gen() labels = data["input_ids"].clone() data["labels"] = labels @@ -44,7 +44,7 @@ def data_gen_for_casual_lm(): # function to get the loss loss_fn = lambda output: output["last_hidden_state"].mean() - loss_fn_for_casual_lm = lambda output: output["loss"] + loss_fn_for_causal_lm = lambda output: output["loss"] loss_fn_for_seq_classification = lambda output: output["logits"].mean() config = CohereConfig( @@ -70,10 +70,10 @@ def data_gen_for_casual_lm(): model_attribute=ModelAttribute(has_control_flow=True), ) model_zoo.register( - name="transformers_command_for_casual_lm", + name="transformers_command_for_causal_lm", model_fn=lambda: transformers.CohereForCausalLM(config), - data_gen_fn=data_gen_for_casual_lm, + data_gen_fn=data_gen_for_causal_lm, output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_casual_lm, + loss_fn=loss_fn_for_causal_lm, model_attribute=ModelAttribute(has_control_flow=True), ) diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index a184c916e42a..9b3db7ca96eb 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -43,8 +43,8 @@ def data_gen(): return dict(input_ids=input_ids, attention_mask=attention_mask) - # label is needed for casual lm - def data_gen_for_casual_lm(): + # label is needed for causal lm + def data_gen_for_causal_lm(): data = data_gen() labels = data["input_ids"].clone() data["labels"] = labels @@ -55,7 +55,7 @@ def data_gen_for_casual_lm(): # function to get the loss loss_fn = lambda output: output["last_hidden_state"].mean() - loss_fn_for_casual_lm = lambda output: output["loss"] + loss_fn_for_causal_lm = lambda output: output["loss"] loss_fn_for_seq_classification = lambda output: output["logits"].mean() config = LlamaConfig( @@ -74,11 +74,11 @@ def data_gen_for_casual_lm(): # transformers.LlamaModel, # transformers.LlamaForSequenceClassification, model_zoo.register( - name="transformers_llama_for_casual_lm", + name="transformers_llama_for_causal_lm", model_fn=lambda: transformers.LlamaForCausalLM(config), - data_gen_fn=data_gen_for_casual_lm, + data_gen_fn=data_gen_for_causal_lm, output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_casual_lm, + loss_fn=loss_fn_for_causal_lm, model_attribute=ModelAttribute(has_control_flow=True), ) model_zoo.register( diff --git a/tests/kit/model_zoo/transformers/mistral.py b/tests/kit/model_zoo/transformers/mistral.py index ae5a9700240a..43fc662cc840 100644 --- a/tests/kit/model_zoo/transformers/mistral.py +++ b/tests/kit/model_zoo/transformers/mistral.py @@ -64,7 +64,7 @@ def data_gen_for_sequence_classification(): model_attribute=ModelAttribute(has_control_flow=True), ) model_zoo.register( - name="transformers_mistral_for_casual_lm", + name="transformers_mistral_for_causal_lm", model_fn=lambda: transformers.MistralForCausalLM(config), data_gen_fn=data_gen_for_lm, output_transform_fn=output_transform_fn, diff --git a/tests/kit/model_zoo/transformers/qwen2.py b/tests/kit/model_zoo/transformers/qwen2.py index 1c26af698497..83bc9f941be7 100644 --- a/tests/kit/model_zoo/transformers/qwen2.py +++ b/tests/kit/model_zoo/transformers/qwen2.py @@ -33,8 +33,8 @@ def data_gen(): attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long() return dict(input_ids=input_ids, attention_mask=attention_mask) - # label is needed for casual lm - def data_gen_for_casual_lm(): + # label is needed for causal lm + def data_gen_for_causal_lm(): data = data_gen() labels = data["input_ids"].clone() data["labels"] = labels @@ -45,7 +45,7 @@ def data_gen_for_casual_lm(): # function to get the loss loss_fn = lambda output: output["last_hidden_state"].mean() - loss_fn_for_casual_lm = lambda output: output["loss"] + loss_fn_for_causal_lm = lambda output: output["loss"] loss_fn_for_seq_classification = lambda output: output["logits"].mean() config = Qwen2Config( @@ -72,11 +72,11 @@ def data_gen_for_casual_lm(): model_attribute=ModelAttribute(has_control_flow=True), ) model_zoo.register( - name="transformers_qwen2_for_casual_lm", + name="transformers_qwen2_for_causal_lm", model_fn=lambda: transformers.Qwen2ForCausalLM(config), - data_gen_fn=data_gen_for_casual_lm, + data_gen_fn=data_gen_for_causal_lm, output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_casual_lm, + loss_fn=loss_fn_for_causal_lm, model_attribute=ModelAttribute(has_control_flow=True), ) model_zoo.register( diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py index e57cadfd8673..3e85329553e0 100644 --- a/tests/test_booster/test_plugin/test_3d_plugin.py +++ b/tests/test_booster/test_plugin/test_3d_plugin.py @@ -97,7 +97,7 @@ def check_3d_plugin(init_method: str = "none", early_stop: bool = True): # TODO(ver217): add more models for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.get_sub_registry( - "transformers_llama_for_casual_lm" + "transformers_llama_for_causal_lm" ).items(): err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index 8c59f430c2d9..c2a08a541bc7 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -105,7 +105,7 @@ def check_low_level_zero_lora(stage, model_name, early_stop: bool = True): sub_model_zoo = model_zoo.get_sub_registry(model_name) for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): task_type = None - if name == "transformers_llama_for_casual_lm": + if name == "transformers_llama_for_causal_lm": task_type = "CAUSAL_LM" if name == "transformers_llama_for_sequence_classification": task_type = "SEQ_CLS" diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index fd13ce0bfadc..b133be948c1e 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -74,7 +74,7 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b @clear_cache_before_run() @parameterize("placement_config", OPTIM_PLACEMENT_CONFIGS) @parameterize("shard", [True, False]) -@parameterize("model_name", ["transformers_llama_for_casual_lm"]) +@parameterize("model_name", ["transformers_llama_for_causal_lm"]) @parameterize("size_per_shard", [32]) @parameterize("tp_size", [1, 2]) @parameterize("zero_size", [2]) diff --git a/tests/test_checkpoint_io/test_gemini_torch_compability.py b/tests/test_checkpoint_io/test_gemini_torch_compability.py index 4897907ffc8a..ce4d10322ba5 100644 --- a/tests/test_checkpoint_io/test_gemini_torch_compability.py +++ b/tests/test_checkpoint_io/test_gemini_torch_compability.py @@ -20,7 +20,7 @@ @clear_cache_before_run() @parameterize("shard", [False, True]) -@parameterize("model_name", ["transformers_llama_for_casual_lm"]) +@parameterize("model_name", ["transformers_llama_for_causal_lm"]) def exam_torch_load_from_gemini(shard: bool, model_name: str): (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) criterion = lambda x: x.mean() diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index 4f8f260417a3..86d7924fb828 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -39,7 +39,7 @@ @parameterize("shard", [True, False]) -@parameterize("model_name", ["transformers_llama_for_casual_lm"]) +@parameterize("model_name", ["transformers_llama_for_causal_lm"]) @parameterize("size_per_shard", [32]) @parameterize("test_config", TEST_CONFIGS) @clear_cache_before_run() diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index ab48944d4eaa..a8e05a25ad28 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -149,7 +149,7 @@ def check_low_level_zero_lora_checkpointIO( if name != "transformers_llama": continue task_type = None - if name == "transformers_llama_for_casual_lm": + if name == "transformers_llama_for_causal_lm": task_type = "CAUSAL_LM" if name == "transformers_llama_for_sequence_classification": task_type = "SEQ_CLS" diff --git a/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py index df8636141e2a..6f8eb2ad26cd 100644 --- a/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py +++ b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py @@ -18,7 +18,7 @@ @clear_cache_before_run() -@parameterize("model_name", ["transformers_llama_for_casual_lm"]) +@parameterize("model_name", ["transformers_llama_for_causal_lm"]) @parameterize("plugin_type", ["ddp", "zero", "gemini"]) def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per_shard=32): (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next( diff --git a/tests/test_lora/test_lora.py b/tests/test_lora/test_lora.py index b8daf775db0e..71dd7863ab4c 100644 --- a/tests/test_lora/test_lora.py +++ b/tests/test_lora/test_lora.py @@ -88,7 +88,7 @@ def run_lora_test(): sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): task_type = None - if name == "transformers_llama_for_casual_lm": + if name == "transformers_llama_for_causal_lm": task_type = "CAUSAL_LM" if name == "transformers_llama_for_sequence_classification": task_type = "SEQ_CLS" diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index a20e2b51de5b..8e66d8b00fb5 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -376,8 +376,11 @@ def get_grad_tensors_for_check( shard_grad = torch.cat(shard_grad_list, dim=dim) # embedding may be resized when using tensor parallel - if shard_grad.shape[0] > org_grad.shape[0]: - shard_grad = shard_grad[: org_grad.shape[0], :] + try: + if shard_grad.shape[0] > org_grad.shape[0]: + shard_grad = shard_grad[: org_grad.shape[0], :] + except: + pass if verbose and dist.get_rank() == 0: print(f"'{suffix}' grad: {org_grad}, {shard_grad}") diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py index 3281b50e1d5d..efe5cee2a2b6 100644 --- a/tests/test_shardformer/test_model/test_shard_command.py +++ b/tests/test_shardformer/test_model/test_shard_command.py @@ -271,7 +271,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ], ) def run_command_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm") + sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_causal_lm") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -321,7 +321,7 @@ def run_command_test(test_config): ], ) def run_command_3d_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm") + sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_causal_lm") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 2609f420c715..eccad3979e5b 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -153,12 +153,25 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, @parameterize( "test_config", [ - # Zigzag Ring Attention + # Zigzag Ring Attention + PP + { + "tp_size": 1, + "pp_size": 2, + "sp_size": 2, + "num_microbatches": 2, + "enable_sequence_parallelism": True, + "sequence_parallelism_mode": "ring_attn", + "use_lazy_init": True, + "zero_stage": 1, + "precision": "bf16", + "initial_scale": 1, + }, + # Ring Attention + TP { "tp_size": 2, "pp_size": 1, "sp_size": 2, - "num_microbatches": 2, + "num_microbatches": 1, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "ring_attn", "use_lazy_init": True, @@ -170,7 +183,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "tp_size": 2, "pp_size": 1, "sp_size": 2, - "num_microbatches": 2, + "num_microbatches": 1, "enable_sequence_parallelism": True, "sequence_parallelism_mode": "all_to_all", "enable_all_optimization": True, @@ -262,7 +275,6 @@ def run_llama_test(test_config): for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name: continue - try: check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) except Exception as e: