diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp b/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp
index 0ab250218da3..be9300c545c2 100644
--- a/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp
+++ b/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp
@@ -35,23 +35,19 @@ SOFTWARE
 void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
                             float *_exp_avg_sq, size_t _param_size,
                             bool param_half_precision, bool grad_half_precision,
-                            float loss_scale) {
-  size_t rounded_size = 0;
+                            bool momentum_half_precision,
+                            bool variance_half_precision, float loss_scale) {
+  size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);
   float betta1_minus1 = 1 - _betta1;
   float betta2_minus1 = 1 - _betta2;
 
   float step_size = -1 * _alpha / _bias_correction1;
   float w_decay = -1 * _alpha * _weight_decay;
 
-  __half *params_cast_h = NULL;
-  __half *grads_cast_h = NULL;
-
-  if (param_half_precision) {
-    params_cast_h = reinterpret_cast<__half *>(_params);
-  }
-  if (grad_half_precision) {
-    grads_cast_h = reinterpret_cast<__half *>(grads);
-  }
+  __half *params_cast_h = reinterpret_cast<__half *>(_params);
+  __half *grads_cast_h = reinterpret_cast<__half *>(grads);
+  __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
+  __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
 
 #if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
   AVX_Data betta1_4;
@@ -77,7 +73,6 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
   if (_weight_decay > 0)
     weight_decay_4.data =
         (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
-  rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);
 
   for (size_t t = 0; t < rounded_size; t += TILE) {
     size_t copy_size = TILE;
@@ -87,28 +82,23 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
 #pragma omp parallel for
     for (size_t i = t; i < offset; i += SIMD_WIDTH) {
       AVX_Data grad_4;
-      if (grad_half_precision) {
-        grad_4.data = SIMD_LOAD_HALF(grads_cast_h + i);
-      } else {
-        grad_4.data = SIMD_LOAD(grads + i);
-      }
+      this->simd_load(grad_half_precision, grads + i, grads_cast_h + i, grad_4);
       if (loss_scale > 0) {
         AVX_Data loss_scale_vec;
         loss_scale_vec.data = SIMD_SET(loss_scale);
         grad_4.data = SIMD_DIV(grad_4.data, loss_scale_vec.data);
       }
       AVX_Data momentum_4;
-      momentum_4.data = SIMD_LOAD(_exp_avg + i);
+      this->simd_load(momentum_half_precision, _exp_avg + i,
+                      momentum_cast_h + i, momentum_4);
       AVX_Data variance_4;
-      variance_4.data = SIMD_LOAD(_exp_avg_sq + i);
+      this->simd_load(variance_half_precision, _exp_avg_sq + i,
+                      variance_cast_h + i, variance_4);
       AVX_Data param_4;
-      if (param_half_precision) {
-        param_4.data = SIMD_LOAD_HALF(params_cast_h + i);
-      } else {
-        param_4.data = SIMD_LOAD(_params + i);
-      }
+      this->simd_load(param_half_precision, _params + i, params_cast_h + i,
+                      param_4);
 
       if (_weight_decay > 0 && !_adamw_mode) {
         grad_4.data = SIMD_FMA(param_4.data, weight_decay_4.data, grad_4.data);
@@ -130,13 +120,12 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
       }
 
       param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data);
-      if (param_half_precision) {
-        SIMD_STORE_HALF((float *)(params_cast_h + i), param_4.data);
-      } else {
-        SIMD_STORE(_params + i, param_4.data);
-      }
-      SIMD_STORE(_exp_avg + i, momentum_4.data);
-      SIMD_STORE(_exp_avg_sq + i, variance_4.data);
+      this->simd_store(param_half_precision, _params + i, params_cast_h + i,
+                       param_4);
+      this->simd_store(momentum_half_precision, _exp_avg + i,
+                       momentum_cast_h + i, momentum_4);
+      this->simd_store(variance_half_precision, _exp_avg_sq + i,
+                       variance_cast_h + i, variance_4);
     }
   }
 #endif
@@ -154,8 +143,10 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
       }
       float param =
           param_half_precision ? (float)params_cast_h[k] : _params[k];
-      float momentum = _exp_avg[k];
-      float variance = _exp_avg_sq[k];
+      float momentum =
+          momentum_half_precision ? (float)momentum_cast_h[k] : _exp_avg[k];
+      float variance = variance_half_precision ? (float)variance_cast_h[k]
+                                               : _exp_avg_sq[k];
       if (_weight_decay > 0 && !_adamw_mode) {
         grad = param * _weight_decay + grad;
       }
@@ -178,8 +169,14 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
         params_cast_h[k] = (__half)param;
       else
         _params[k] = param;
-      _exp_avg[k] = momentum;
-      _exp_avg_sq[k] = variance;
+      if (momentum_half_precision)
+        momentum_cast_h[k] = (__half)(momentum);
+      else
+        _exp_avg[k] = momentum;
+      if (variance_half_precision)
+        variance_cast_h[k] = (__half)(variance);
+      else
+        _exp_avg_sq[k] = variance;
     }
   }
 }
@@ -188,17 +185,14 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
 void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
                             float *_exp_avg_sq, size_t _param_size,
                             bool param_half_precision, bool grad_half_precision,
-                            float loss_scale) {
-  size_t rounded_size = 0;
+                            bool momentum_half_precision,
+                            bool variance_half_precision, float loss_scale) {
+  size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4);
 
-  __half *params_cast_h = NULL;
-  __half *grads_cast_h = NULL;
-  if (param_half_precision) {
-    params_cast_h = reinterpret_cast<__half *>(_params);
-  }
-  if (grad_half_precision) {
-    grads_cast_h = reinterpret_cast<__half *>(grads);
-  }
+  __half *params_cast_h = reinterpret_cast<__half *>(_params);
+  __half *grads_cast_h = reinterpret_cast<__half *>(grads);
+  __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
+  __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
 
 #if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
   AVX_Data betta1_4;
@@ -228,7 +222,6 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
   if (_weight_decay > 0)
     weight_decay_4.data =
         (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
-  rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4);
 
   for (size_t t = 0; t < rounded_size; t += TILE) {
     size_t copy_size = TILE;
@@ -243,26 +236,21 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
       AVX_Data param_4[4];
 #pragma unroll 4
       for (int j = 0; j < 4; j++) {
-        if (grad_half_precision) {
-          grad_4[j].data = SIMD_LOAD_HALF(grads_cast_h + i + SIMD_WIDTH * j);
-        } else {
-          grad_4[j].data = SIMD_LOAD(grads + i + SIMD_WIDTH * j);
-        }
+        this->simd_load(grad_half_precision, grads + i + SIMD_WIDTH * j,
+                        grads_cast_h + i + SIMD_WIDTH * j, grad_4[j]);
         if (loss_scale > 0) {
           AVX_Data loss_scale_vec;
           loss_scale_vec.data = SIMD_SET(loss_scale);
           grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
         }
-
-        momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j);
-        variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j);
-
-        if (param_half_precision) {
-          param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j);
-        } else {
-          param_4[j].data = SIMD_LOAD(_params + i + SIMD_WIDTH * j);
-        }
+        this->simd_load(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
+                        momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
+        this->simd_load(variance_half_precision,
+                        _exp_avg_sq + i + SIMD_WIDTH * j,
+                        variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
+        this->simd_load(param_half_precision, _params + i + SIMD_WIDTH * j,
+                        params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
 
         if (_weight_decay > 0 && !_adamw_mode) {
           grad_4[j].data =
@@ -285,14 +273,13 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
         }
         param_4[j].data =
             SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);
-        if (param_half_precision) {
-          SIMD_STORE_HALF((float *)(params_cast_h + i + SIMD_WIDTH * j),
-                          param_4[j].data);
-        } else {
-          SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data);
-        }
-        SIMD_STORE(_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data);
-        SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data);
+        this->simd_store(param_half_precision, _params + i + SIMD_WIDTH * j,
+                         params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
+        this->simd_store(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
+                         momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
+        this->simd_store(variance_half_precision,
+                         _exp_avg_sq + i + SIMD_WIDTH * j,
+                         variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
       }
     }
   }
@@ -302,24 +289,26 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
                                  : _params + rounded_size),
            (grad_half_precision ? (float *)(grads_cast_h + rounded_size)
                                 : grads + rounded_size),
-           (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size),
+           (momentum_half_precision ? (float *)(momentum_cast_h + rounded_size)
+                                    : _exp_avg + rounded_size),
+           (variance_half_precision ? (float *)(variance_cast_h + rounded_size)
+                                    : _exp_avg_sq + rounded_size),
            (_param_size - rounded_size), param_half_precision,
-           grad_half_precision, loss_scale);
+           grad_half_precision, momentum_half_precision,
+           variance_half_precision, loss_scale);
 }
 
 void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
                             float *_exp_avg_sq, size_t _param_size,
                             bool param_half_precision, bool grad_half_precision,
-                            float loss_scale) {
-  size_t rounded_size = 0;
-  __half *params_cast_h = NULL;
-  __half *grads_cast_h = NULL;
-  if (param_half_precision) {
-    params_cast_h = reinterpret_cast<__half *>(_params);
-  }
-  if (grad_half_precision) {
-    grads_cast_h = reinterpret_cast<__half *>(grads);
-  }
+                            bool momentum_half_precision,
+                            bool variance_half_precision, float loss_scale) {
+  size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8);
+  __half *params_cast_h = reinterpret_cast<__half *>(_params);
+  __half *grads_cast_h = reinterpret_cast<__half *>(grads);
+  __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
+  __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
+
 #if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
   AVX_Data betta1_4;
   betta1_4.data = SIMD_SET(_betta1);
@@ -348,7 +337,6 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
   if (_weight_decay > 0)
     weight_decay_4.data =
        (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
-  rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8);
 
   for (size_t t = 0; t < rounded_size; t += TILE) {
     size_t copy_size = TILE;
@@ -363,26 +351,21 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
       AVX_Data param_4[8];
 #pragma unroll 8
       for (int j = 0; j < 8; j++) {
-        if (grad_half_precision) {
-          grad_4[j].data = SIMD_LOAD_HALF(grads_cast_h + i + SIMD_WIDTH * j);
-        } else {
-          grad_4[j].data = SIMD_LOAD(grads + i + SIMD_WIDTH * j);
-        }
+        this->simd_load(grad_half_precision, grads + i + SIMD_WIDTH * j,
+                        grads_cast_h + i + SIMD_WIDTH * j, grad_4[j]);
        if (loss_scale > 0) {
          AVX_Data loss_scale_vec;
          loss_scale_vec.data = SIMD_SET(loss_scale);
          grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
        }
-
-        momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j);
-        variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j);
-
-        if (param_half_precision) {
-          param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j);
-        } else {
-          param_4[j].data = SIMD_LOAD(_params + i + SIMD_WIDTH * j);
-        }
+        this->simd_load(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
+                        momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
+        this->simd_load(variance_half_precision,
+                        _exp_avg_sq + i + SIMD_WIDTH * j,
+                        variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
+        this->simd_load(param_half_precision, _params + i + SIMD_WIDTH * j,
+                        params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
 
         if (_weight_decay > 0 && !_adamw_mode) {
           grad_4[j].data =
@@ -405,15 +388,13 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
 
         param_4[j].data =
             SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);
-        if (param_half_precision) {
-          SIMD_STORE_HALF((float *)(params_cast_h + i + SIMD_WIDTH * j),
-                          param_4[j].data);
-        } else {
-          SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data);
-        }
-
-        SIMD_STORE(_exp_avg + i + (SIMD_WIDTH * j), momentum_4[j].data);
-        SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH * j), variance_4[j].data);
+        this->simd_store(param_half_precision, _params + i + SIMD_WIDTH * j,
+                         params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
+        this->simd_store(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
+                         momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
+        this->simd_store(variance_half_precision,
+                         _exp_avg_sq + i + SIMD_WIDTH * j,
+                         variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
       }
     }
   }
@@ -423,9 +404,13 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
                                  : _params + rounded_size),
            (grad_half_precision ? (float *)(grads_cast_h + rounded_size)
                                 : grads + rounded_size),
-           (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size),
+           (momentum_half_precision ? (float *)(momentum_cast_h + rounded_size)
+                                    : _exp_avg + rounded_size),
+           (variance_half_precision ? (float *)(variance_cast_h + rounded_size)
+                                    : _exp_avg_sq + rounded_size),
            (_param_size - rounded_size), param_half_precision,
-           grad_half_precision, loss_scale);
+           grad_half_precision, momentum_half_precision,
+           variance_half_precision, loss_scale);
 }
 
 void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2,
@@ -447,7 +432,9 @@ void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2,
   this->update_state(lr, epsilon, weight_decay, bias_correction);
   this->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr,
                params_c.numel(), (params.options().dtype() == at::kHalf),
-               (grads.options().dtype() == at::kHalf), loss_scale);
+               (grads.options().dtype() == at::kHalf),
+               (exp_avg.options().dtype() == at::kHalf),
+               (exp_avg_sq.options().dtype() == at::kHalf), loss_scale);
 }
 
 namespace py = pybind11;
diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.h b/colossalai/kernel/cuda_native/csrc/cpu_adam.h
index 4247da942775..bf9b85997c78 100644
--- a/colossalai/kernel/cuda_native/csrc/cpu_adam.h
+++ b/colossalai/kernel/cuda_native/csrc/cpu_adam.h
@@ -50,9 +50,9 @@ SOFTWARE
 #define SIMD_DIV(x, y) _mm512_div_ps(x, y)
 #define SIMD_LOAD_HALF(x) \
   _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
-#define SIMD_STORE_HALF(x, d) \
-  _mm256_store_ps( \
-      x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
+#define SIMD_STORE_HALF(x, d) \
+  _mm256_storeu_ps((float *)(x), _mm256_castsi256_ps(_mm512_cvtps_ph( \
+                                     d, _MM_FROUND_TO_NEAREST_INT)))
 
 #elif defined(__AVX256__) or defined(__AVX2__)
 #define SIMD_WIDTH 8
@@ -66,9 +66,9 @@ SOFTWARE
 #define SIMD_SQRT(x) _mm256_sqrt_ps(x)
 #define SIMD_DIV(x, y) _mm256_div_ps(x, y)
 #define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
-#define SIMD_STORE_HALF(x, d) \
-  _mm_store_ps( \
-      x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
+#define SIMD_STORE_HALF(x, d) \
+  _mm_storeu_ps((float *)(x), _mm_castsi128_ps(_mm256_cvtps_ph( \
+                                  d, _MM_FROUND_TO_NEAREST_INT)))
 
 #endif
 
@@ -83,11 +83,12 @@ union AVX_Data {
 
 #endif
 
-#define STEP(SPAN) \
-  void Step_##SPAN(float *_params, float *grads, float *_exp_avg, \
-                   float *_exp_avg_sq, size_t _param_size, \
-                   bool param_half_precision = false, \
-                   bool grad_half_precision = false, float loss_scale = -1);
+#define STEP(SPAN) \
+  void Step_##SPAN( \
+      float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq, \
+      size_t _param_size, bool param_half_precision = false, \
+      bool grad_half_precision = false, bool momentum_half_precision = false, \
+      bool variance_half_precision = false, float loss_scale = -1);
 
 class Adam_Optimizer {
  public:
@@ -141,6 +142,24 @@ class Adam_Optimizer {
     }
   }
 
+  inline void simd_load(bool is_half, float *ptr, __half *h_ptr,
+                        AVX_Data &data) {
+    if (is_half) {
+      data.data = SIMD_LOAD_HALF(h_ptr);
+    } else {
+      data.data = SIMD_LOAD(ptr);
+    }
+  }
+
+  inline void simd_store(bool is_half, float *ptr, __half *h_ptr,
+                         AVX_Data &data) {
+    if (is_half) {
+      SIMD_STORE_HALF(h_ptr, data.data);
+    } else {
+      SIMD_STORE(ptr, data.data);
+    }
+  }
+
   void step(size_t step, float lr, float beta1, float beta2, float epsilon,
             float weight_decay, bool bias_correction, torch::Tensor &params,
             torch::Tensor &grads, torch::Tensor &exp_avg,
diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py
index 1bdb81e2d6ec..238ba366da43 100644
--- a/colossalai/nn/optimizer/cpu_adam.py
+++ b/colossalai/nn/optimizer/cpu_adam.py
@@ -146,8 +146,7 @@ def step(self, closure=None, div_scale: float = -1):
                     assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
                     assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
                     self._pre_update(p, "exp_avg", "exp_avg_sq")
-                    # FIXME(ver217): CPU adam kernel only supports fp32 states now
-                    if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float:
+                    if p.grad.dtype is torch.bfloat16:
                         # cpu adam kernel does not support bf16 now
                         bias_correction1 = 1 - beta1 ** state["step"]
                         bias_correction2 = 1 - beta2 ** state["step"]
diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py
index 7dc4590dc3f2..c7a309b872ce 100644
--- a/colossalai/nn/optimizer/hybrid_adam.py
+++ b/colossalai/nn/optimizer/hybrid_adam.py
@@ -122,8 +122,7 @@ def step(self, closure=None, div_scale: float = -1):
                     assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
                     assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
                     self._pre_update(p, "exp_avg", "exp_avg_sq")
-                    # FIXME(ver217): CPU adam kernel only supports fp32 states now
-                    if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float:
+                    if p.grad.dtype is torch.bfloat16:
                         # cpu adam kernel does not support bf16 now
                         bias_correction1 = 1 - beta1 ** state["step"]
                         bias_correction2 = 1 - beta2 ** state["step"]
diff --git a/tests/test_optimizer/test_adam_kernel.py b/tests/test_optimizer/test_adam_kernel.py
index 8131ea3234d8..6bbe3e4e8172 100644
--- a/tests/test_optimizer/test_adam_kernel.py
+++ b/tests/test_optimizer/test_adam_kernel.py
@@ -13,9 +13,7 @@
 _FUSED_ALLOWED_P_G_TYPES = [
     (torch.float, torch.half),
     (torch.float, torch.float),
-    (torch.half, torch.float),
     (torch.half, torch.half),
-    (torch.bfloat16, torch.float),
     (torch.float, torch.bfloat16),
     (torch.bfloat16, torch.bfloat16),
 ]
@@ -23,7 +21,6 @@
 _CPU_ALLOWED_P_G_TYPES = [
     (torch.float, torch.half),
     (torch.float, torch.float),
-    (torch.half, torch.float),
     (torch.half, torch.half),
 ]
 
@@ -138,8 +135,8 @@ def check_adam_kernel(
     master_exp_avg_sq = torch.zeros_like(master_p)
     p = master_p.clone().to(p_dtype)
     g = master_g.clone().to(g_dtype)
-    exp_avg = master_exp_avg.clone()
-    exp_avg_sq = master_exp_avg_sq.clone()
+    exp_avg = master_exp_avg.clone().to(p_dtype)
+    exp_avg_sq = master_exp_avg_sq.clone().to(p_dtype)
 
     for step in range(1, 1 + n_steps):
         torch_adam.update(step, master_p, master_g, master_exp_avg, master_exp_avg_sq)
diff --git a/tests/test_optimizer/test_adam_optim.py b/tests/test_optimizer/test_adam_optim.py
index 59b40a0afa3c..68d71e3c4194 100644
--- a/tests/test_optimizer/test_adam_optim.py
+++ b/tests/test_optimizer/test_adam_optim.py
@@ -21,8 +21,6 @@
     (torch.float, torch.float),  # pure fp32
     (torch.float, torch.half),  # fp16 amp
     (torch.float, torch.bfloat16),  # bfloat16 amp
-    # (torch.half, torch.half),  # FIXME(ver217): cpu adam kernel does not support pure fp16
-    # (torch.bfloat16, torch.bfloat16),  # FIXME(ver217): cpu adam kernel does not support pure bfloat16
 ]
 
 N_STEPS = 3
diff --git a/tests/test_zero/test_gemini/test_grad_clip.py b/tests/test_zero/test_gemini/test_grad_clip.py
index a3af81646a18..4c84e9e5a89a 100644
--- a/tests/test_zero/test_gemini/test_grad_clip.py
+++ b/tests/test_zero/test_gemini/test_grad_clip.py
@@ -52,7 +52,8 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
 
 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("model_name", ["gpt2"])
-def exam_grad_clipping(placement_config, model_name: str):
+@parameterize("master_weights", [True, False])
+def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
     set_seed(1912)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -82,6 +83,7 @@ def exam_grad_clipping(placement_config, model_name: str):
         chunk_config_dict=config_dict,
         chunk_init_device=init_device,
         pin_memory=True,
+        master_weights=master_weights,
         **placement_config,
     )
 
@@ -103,7 +105,10 @@ def exam_grad_clipping(placement_config, model_name: str):
 
         torch_loss = run_fwd_bwd(torch_model, data, label, criterion, torch_optim)
         loss = run_fwd_bwd(model, data, label, criterion, zero_optim)
-        assert_close(torch_loss, loss)
+
+        # as no master weights leads to error accumulation, we don't check the loss
+        if master_weights:
+            assert_close(torch_loss, loss)
 
         import apex.amp as apex_amp
 
@@ -111,7 +116,8 @@ def exam_grad_clipping(placement_config, model_name: str):
         torch_optim.step()
         zero_optim.step()
 
-        check_param(model, torch_model)
+        if master_weights:
+            check_param(model, torch_model)
 
 
 def run_dist(rank, world_size, port):
diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py
index 8e8e508ff483..9b84d68f3c7a 100644
--- a/tests/test_zero/test_gemini/test_optim.py
+++ b/tests/test_zero/test_gemini/test_optim.py
@@ -70,12 +70,14 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dtype):
 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("model_name", TEST_MODELS)
 @parameterize("mixed_precision", [torch.half, torch.bfloat16])
-def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype):
+@parameterize("master_weights", [True, False])
+def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool):
     set_seed(42)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
 
     torch_model = model_builder().cuda()
+    # apex no master weights leads to nan, so we don't use it
     amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=128)
     torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
     torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
@@ -90,7 +92,9 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype):
     config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]["chunk_size"] = 5000
     config_dict[world_size]["keep_gathered"] = False
-    model = GeminiDDP(model, config_dict, **placement_config, mixed_precision=mixed_precision)
+    model = GeminiDDP(
+        model, config_dict, **placement_config, mixed_precision=mixed_precision, master_weights=master_weights
+    )
 
     optimizer = HybridAdam(model.parameters(), lr=1e-3)
     zero_optim = GeminiOptimizer(optimizer, model, initial_scale=128)
@@ -109,12 +113,15 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype):
 
         torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
         loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
-        assert_close(torch_loss, loss, rtol=rtol, atol=atol)
+        # as no master weights leads to error accumulation, we don't check the loss
+        if master_weights:
+            assert_close(torch_loss, loss, rtol=rtol, atol=atol)
 
         zero_optim.step()
         torch_optim.step()
 
-        check_param(model, torch_model, mixed_precision)
+        if master_weights:
+            check_param(model, torch_model, mixed_precision)
 
 
 @parameterize("placement_config", PLACEMENT_CONFIGS)
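
Note on why the tests above stop comparing losses when master_weights=False: with the patch, the CPU Adam kernel can keep parameters and both Adam states (exp_avg, exp_avg_sq) in fp16, but without an fp32 master copy the per-step rounding on the stores accumulates against an fp32 reference run. The snippet below is not the ColossalAI kernel and not part of this patch; it is a minimal PyTorch-only sketch with made-up sizes, step count, and hyperparameters. It mirrors the kernel's behavior of computing the update in fp32 and storing results back in each tensor's own dtype, so the printed drift illustrates the error accumulation the tests choose not to assert on.

import torch


def adam_step(p, g, exp_avg, exp_avg_sq, step, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    # Do the Adam math in fp32, then store back in each tensor's own dtype
    # (fp32 tensors stay exact, fp16 tensors get rounded on every store).
    m = exp_avg.float() * beta1 + (1 - beta1) * g.float()
    v = exp_avg_sq.float() * beta2 + (1 - beta2) * g.float() ** 2
    bias_correction1 = 1 - beta1**step
    bias_correction2 = 1 - beta2**step
    update = (m / bias_correction1) / ((v / bias_correction2).sqrt() + eps)
    p.copy_(p.float() - lr * update)  # cast back to p.dtype on store
    exp_avg.copy_(m)
    exp_avg_sq.copy_(v)


torch.manual_seed(0)
master_p = torch.randn(1024)            # fp32 params with fp32 states
half_p = master_p.clone().half()        # fp16 params with fp16 states
m32, v32 = torch.zeros(1024), torch.zeros(1024)
m16, v16 = m32.half(), v32.half()

for step in range(1, 101):
    g = torch.randn(1024) * 0.1
    adam_step(master_p, g, m32, v32, step)
    adam_step(half_p, g.half(), m16, v16, step)

print("max param drift vs fp32 after 100 steps:", (master_p - half_p.float()).abs().max().item())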