[test] Fix/fix testcase (#5770)
* [fix] branch for fix testcase;

* [fix] fix test_analyzer & test_auto_parallel;

* [fix] remove local change about moe;

* [fix] rm local change moe;
duanjunwen authored Jun 3, 2024
1 parent 3f2be80 commit 1b76564
Showing 4 changed files with 10 additions and 7 deletions.
2 changes: 1 addition & 1 deletion colossalai/_analyzer/fx/codegen.py
@@ -469,4 +469,4 @@ def emit_node(node: Node, body):
{wrap_stmts}
{prologue}
{code}"""
-return PythonCode(fn_code, globals_)
+return PythonCode(fn_code, globals_, {})
2 changes: 1 addition & 1 deletion colossalai/fx/codegen/activation_checkpoint_codegen.py
@@ -859,7 +859,7 @@ def emit_node(node: Node, body):
{wrap_stmts}
{prologue}
{code}"""
-return PythonCode(fn_code, globals_)
+return PythonCode(fn_code, globals_, {})

else:

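Reviewer note (not part of the commit): both PythonCode(...) call sites above gain an empty dict as a third positional argument. The likely reason is that recent torch.fx releases added a third required field to the PythonCode dataclass (a line-number map), so the old two-argument call raises a TypeError. A minimal sketch, assuming a torch build whose PythonCode takes three fields:

# Minimal sketch (not from the commit); assumes a torch build whose
# torch.fx PythonCode dataclass has three fields: src, globals, and a
# line-number map. On older torch versions the three-argument call fails.
import inspect

from torch.fx.graph import PythonCode

# Check how many fields the installed torch expects before constructing one.
print(inspect.signature(PythonCode))

# Mirror the patched call sites: an empty dict stands in for the line-number
# map when the code generator does not track one.
code = PythonCode("def forward(self, x):\n    return x\n", {}, {})
print(code.src)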
6 changes: 3 additions & 3 deletions tests/test_auto_parallel/test_offload/model_utils.py
@@ -2,7 +2,7 @@
import torch.nn as nn
from transformers import BertConfig, BertLMHeadModel, GPT2Config, GPT2LMHeadModel

-from tests.components_to_test.registry import non_distributed_component_funcs
+# from tests.components_to_test.registry import non_distributed_component_funcs


class GPTLMModel(nn.Module):
@@ -55,7 +55,7 @@ def forward(self, input_ids, attention_mask):
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0]


-@non_distributed_component_funcs.register(name="bert_")
+# @non_distributed_component_funcs.register(name="bert_")
def get_bert_components():
vocab_size = 1024
seq_len = 64
@@ -74,7 +74,7 @@ def bert_data_gen(device="meta"):
return bert_model_builder, bert_data_gen


-@non_distributed_component_funcs.register(name="gpt2_")
+# @non_distributed_component_funcs.register(name="gpt2_")
def get_gpt2_components():
vocab_size = 1024
seq_len = 8
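Reviewer note (not part of the commit): with the registry import and both register decorators commented out, the factories above are no longer discoverable by name through non_distributed_component_funcs. A hypothetical sketch of how a test could resolve them directly; the get_components helper and _COMPONENTS map are illustration only, not code from the repository:

# Hypothetical helper (not in the commit): resolve the offload test factories
# directly instead of looking them up in the removed registry.
from tests.test_auto_parallel.test_offload.model_utils import (
    get_bert_components,
    get_gpt2_components,
)

# Name-to-factory map standing in for the old registry lookup by name.
_COMPONENTS = {
    "bert_": get_bert_components,
    "gpt2_": get_gpt2_components,
}

def get_components(model_name):
    # Each factory returns a (model_builder, data_gen) pair, matching how the
    # functions in model_utils.py are written.
    return _COMPONENTS[model_name]()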
7 changes: 5 additions & 2 deletions tests/test_auto_parallel/test_offload/test_perf.py
@@ -10,11 +10,14 @@
from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
from colossalai.auto_parallel.offload.solver import NOT_NVML
from colossalai.fx.profiler import parameter_size
+from colossalai.legacy.zero.gemini.colo_init_context import ColoInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
+from colossalai.utils import set_seed
+from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
from tests.test_auto_parallel.test_offload.model_utils import *
-from tests.test_tensor.common_utils import set_seed

+# from tests.test_tensor.common_utils import set_seed


@parameterize("model_name", ["gpt2_"])
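Reviewer note (not part of the commit): the test now imports ColoInitContext from the legacy package and set_seed from colossalai.utils rather than the test helpers. A usage sketch under the assumption that the legacy context keeps its device keyword and that set_seed takes an integer seed; the Linear module is a placeholder, not the model the test actually builds:

# Usage sketch (illustration only). Assumes the legacy ColoInitContext still
# accepts a device keyword and that colossalai.utils.set_seed takes an int.
import torch

from colossalai.legacy.zero.gemini.colo_init_context import ColoInitContext
from colossalai.utils import set_seed

set_seed(42)  # previously imported from tests.test_tensor.common_utils

# Materialize parameters on the chosen device under the legacy init context,
# the same pattern the offload performance test relies on.
with ColoInitContext(device=torch.device("cpu")):
    model = torch.nn.Linear(8, 8)  # placeholder module for illustration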
