From 82b04cb7e53c7d1f5316148460e2cbfd8922cf63 Mon Sep 17 00:00:00 2001
From: lc-maorenjie
Date: Fri, 22 Mar 2024 18:40:19 +0800
Subject: [PATCH] came zero test pass

---
 tests/test_optimizer/test_dist_came_optim.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/tests/test_optimizer/test_dist_came_optim.py b/tests/test_optimizer/test_dist_came_optim.py
index 5d69362c0b6d..e3e4621aac91 100644
--- a/tests/test_optimizer/test_dist_came_optim.py
+++ b/tests/test_optimizer/test_dist_came_optim.py
@@ -13,7 +13,7 @@
 from colossalai.zero.low_level.low_level_optim import LowLevelZeroOptimizer
 
 
-def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool, tp: int, zero: int, col: bool):
+def check_linear_1d_col(seq_parallel: bool, overlap: bool, tp: int, zero: int, col: bool, partition_grad: bool):
     rtol, atol = 1e-3, 1e-3
     # create shardformer
     # ranks: [0, 1, 2, 3]
@@ -73,7 +73,9 @@ def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool, tp:
     optim = CAME(linear.parameters(), lr=0.1)
     dist_optim = CAME(linear_col.parameters(), lr=0.1, tp_process_group=tp_group, zero_process_group=zero_group)
     if zero_group:
-        dist_optim = LowLevelZeroOptimizer(optimizer=dist_optim, partition_grad=False, dp_process_group=dp_group)
+        dist_optim = LowLevelZeroOptimizer(
+            optimizer=dist_optim, partition_grad=partition_grad, dp_process_group=dp_group
+        )
     clip_methods = copy.deepcopy(dist_optim.optim.clip_method)
     shape_dict = copy.deepcopy(dist_optim.optim.ori_shape)
     for key in clip_methods.keys():
@@ -147,14 +149,14 @@ def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool, tp:
         assert_close(torch.chunk(ori_state.clone(), tp, dim=clip_dim)[tp_rank], dist_state)
 
 
-@parameterize("lazy_init", [False])
 @parameterize("seq_parallel", [False])
 @parameterize("overlap", [True])
-@parameterize("tp", [2])
+@parameterize("tp", [1])
 @parameterize("zero", [2])  # zero parallel size
 @parameterize("col", [True, False])
-def run_dist_linear_test(lazy_init, seq_parallel, overlap, tp, zero, col):
-    check_linear_1d_col(lazy_init, seq_parallel, overlap, tp, zero, col)
+@parameterize("partition_grad", [False, True])
+def run_dist_linear_test(seq_parallel, overlap, tp, zero, col, partition_grad):
+    check_linear_1d_col(seq_parallel, overlap, tp, zero, col, partition_grad)
     # check_linear_1d_row(lazy_init, seq_parallel)
     # check_linear_col_plus_row(lazy_init, seq_parallel, overlap)
 
@@ -166,7 +168,7 @@ def check_dist_linear(rank, world_size, port=12256):
 
 @rerun_if_address_is_in_use()
 def test_linear():
-    spawn(check_dist_linear, nprocs=4)
+    spawn(check_dist_linear, nprocs=2)
 
 
 if __name__ == "__main__":
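
Note on the pattern under test: `partition_grad` selects the ZeRO stage for the
wrapped optimizer. With `partition_grad=False`, `LowLevelZeroOptimizer` shards only
optimizer states across the data-parallel group (ZeRO stage 1); with
`partition_grad=True` it shards gradients as well (ZeRO stage 2), which is why the
test now parameterizes over both values. The drop to `nprocs=2` matches the new
parallel layout of tp=1 x zero=2. The sketch below is a minimal illustration of the
wrapping the test exercises, not the test itself: the `LowLevelZeroOptimizer` import
and keyword arguments are taken from the diff, while the `CAME` import path, the toy
model, and the single-process gloo setup are assumptions made so the snippet is
self-contained (in the repo this code runs inside the workers spawned by
`spawn(check_dist_linear, nprocs=2)`, and may require a CUDA backend).

    # Sketch of the CAME + ZeRO wrapping exercised by the patched test.
    # The CAME import path is an assumption; it is not visible in the diff.
    import os

    import torch.distributed as dist
    import torch.nn as nn

    from colossalai.nn.optimizer import CAME  # assumed import path
    from colossalai.zero.low_level.low_level_optim import LowLevelZeroOptimizer

    # Single-process group so the sketch stands alone; the real test spawns
    # two workers and derives its tp/zero/dp groups from them.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12256")
    dist.init_process_group("gloo", rank=0, world_size=1)

    linear_col = nn.Linear(32, 128)  # stand-in for the column-sharded linear
    dist_optim = CAME(linear_col.parameters(), lr=0.1)

    # partition_grad=False -> ZeRO-1 (shard optimizer states only);
    # partition_grad=True  -> ZeRO-2 (shard gradients across dp ranks too).
    dist_optim = LowLevelZeroOptimizer(
        optimizer=dist_optim,
        partition_grad=True,
        dp_process_group=dist.group.WORLD,  # hypothetical dp group
    )

ZeRO-2 saves additional gradient memory per rank at the cost of extra communication
during the backward pass, so covering both settings guards the optimizer's sharded
state handling in both regimes.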