
Commit

CAME ZeRO test pass
chongqichuizi875 committed Apr 7, 2024
1 parent 036ecb6 commit 82b04cb
Showing 1 changed file with 9 additions and 7 deletions.
tests/test_optimizer/test_dist_came_optim.py: 16 changes (9 additions & 7 deletions)
@@ -13,7 +13,7 @@
 from colossalai.zero.low_level.low_level_optim import LowLevelZeroOptimizer


-def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool, tp: int, zero: int, col: bool):
+def check_linear_1d_col(seq_parallel: bool, overlap: bool, tp: int, zero: int, col: bool, partition_grad: bool):
     rtol, atol = 1e-3, 1e-3
     # create shardformer
     # ranks: [0, 1, 2, 3]
@@ -73,7 +73,9 @@ def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool, tp:
     optim = CAME(linear.parameters(), lr=0.1)
     dist_optim = CAME(linear_col.parameters(), lr=0.1, tp_process_group=tp_group, zero_process_group=zero_group)
     if zero_group:
-        dist_optim = LowLevelZeroOptimizer(optimizer=dist_optim, partition_grad=False, dp_process_group=dp_group)
+        dist_optim = LowLevelZeroOptimizer(
+            optimizer=dist_optim, partition_grad=partition_grad, dp_process_group=dp_group
+        )
     clip_methods = copy.deepcopy(dist_optim.optim.clip_method)
     shape_dict = copy.deepcopy(dist_optim.optim.ori_shape)
     for key in clip_methods.keys():
@@ -147,14 +149,14 @@ def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool, tp:
         assert_close(torch.chunk(ori_state.clone(), tp, dim=clip_dim)[tp_rank], dist_state)


-@parameterize("lazy_init", [False])
 @parameterize("seq_parallel", [False])
 @parameterize("overlap", [True])
-@parameterize("tp", [2])
+@parameterize("tp", [1])
 @parameterize("zero", [2])  # zero parallel size
 @parameterize("col", [True, False])
-def run_dist_linear_test(lazy_init, seq_parallel, overlap, tp, zero, col):
-    check_linear_1d_col(lazy_init, seq_parallel, overlap, tp, zero, col)
+@parameterize("partition_grad", [False, True])
+def run_dist_linear_test(seq_parallel, overlap, tp, zero, col, partition_grad):
+    check_linear_1d_col(seq_parallel, overlap, tp, zero, col, partition_grad)
     # check_linear_1d_row(lazy_init, seq_parallel)
     # check_linear_col_plus_row(lazy_init, seq_parallel, overlap)

@@ -166,7 +168,7 @@ def check_dist_linear(rank, world_size, port=12256):

 @rerun_if_address_is_in_use()
 def test_linear():
-    spawn(check_dist_linear, nprocs=4)
+    spawn(check_dist_linear, nprocs=2)


 if __name__ == "__main__":
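For context, a minimal sketch of how the optimizer stack under test is driven once it is wrapped. This is an illustration, not part of the commit: the module, process group, and function name (came_zero_step) are hypothetical placeholders, and it assumes the ZeRO wrapper's usual backward/step flow. The partition_grad flag toggles gradient partitioning (ZeRO stage 2), which is the new axis this commit parameterizes.

import torch
from torch import nn

from colossalai.zero.low_level.low_level_optim import LowLevelZeroOptimizer


def came_zero_step(linear: nn.Module, dist_optim, partition_grad: bool, dp_group) -> None:
    # Wrap the distributed CAME optimizer in the ZeRO wrapper, as the test does;
    # partition_grad toggles gradient partitioning across the data-parallel ranks.
    wrapped = LowLevelZeroOptimizer(
        optimizer=dist_optim, partition_grad=partition_grad, dp_process_group=dp_group
    )
    device = next(linear.parameters()).device
    x = torch.randn(4, linear.in_features, device=device)
    loss = linear(x).sum()
    wrapped.backward(loss)  # gradients are reduced (and optionally partitioned) here
    wrapped.step()
    wrapped.zero_grad()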

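The stacked @parameterize decorators on run_dist_linear_test compose a cross product over their value lists, so after this commit the test body runs once per (col, partition_grad) pair while the single-valued axes stay fixed. A rough plain-Python equivalent of that sweep, assuming the decorator behaves as a simple keyword-argument sweep (the loop below is illustrative, not the decorator's actual implementation):

from itertools import product

# Expand the same grid the decorators declare: seq_parallel=False, overlap=True,
# tp=1, and zero=2 are fixed; col and partition_grad each take two values.
for col, partition_grad in product([True, False], [False, True]):
    check_linear_1d_col(
        seq_parallel=False, overlap=True, tp=1, zero=2, col=col, partition_grad=partition_grad
    )

The nprocs drop from 4 to 2 is consistent with the new grid: with tp=1 and a ZeRO parallel size of 2, the test only needs a world size of 2.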