[kernel] support pure fp16 for cpu adam and update gemini optim tests (hpcaitech#4921)

* [kernel] support pure fp16 for cpu adam (hpcaitech#4896)

* [kernel] fix cpu adam kernel for pure fp16 and update tests (hpcaitech#4919)

* [kernel] fix cpu adam

* [test] update gemini optim test
ver217 authored and flybird11111 committed Nov 9, 2023
1 parent 52707c6 commit 61ec9f7
Showing 8 changed files with 148 additions and 136 deletions.
201 changes: 94 additions & 107 deletions colossalai/kernel/cuda_native/csrc/cpu_adam.cpp

Large diffs are not rendered by default.

41 changes: 30 additions & 11 deletions colossalai/kernel/cuda_native/csrc/cpu_adam.h
@@ -50,9 +50,9 @@ SOFTWARE
#define SIMD_DIV(x, y) _mm512_div_ps(x, y)
#define SIMD_LOAD_HALF(x) \
_mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
#define SIMD_STORE_HALF(x, d) \
_mm256_store_ps( \
x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
#define SIMD_STORE_HALF(x, d) \
_mm256_storeu_ps((float *)(x), _mm256_castsi256_ps(_mm512_cvtps_ph( \
d, _MM_FROUND_TO_NEAREST_INT)))

#elif defined(__AVX256__) or defined(__AVX2__)
#define SIMD_WIDTH 8
@@ -66,9 +66,9 @@ SOFTWARE
#define SIMD_SQRT(x) _mm256_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm256_div_ps(x, y)
#define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
#define SIMD_STORE_HALF(x, d) \
_mm_store_ps( \
x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
#define SIMD_STORE_HALF(x, d) \
_mm_storeu_ps((float *)(x), _mm_castsi128_ps(_mm256_cvtps_ph( \
d, _MM_FROUND_TO_NEAREST_INT)))

#endif

@@ -83,11 +83,12 @@ union AVX_Data {

#endif

#define STEP(SPAN) \
void Step_##SPAN(float *_params, float *grads, float *_exp_avg, \
float *_exp_avg_sq, size_t _param_size, \
bool param_half_precision = false, \
bool grad_half_precision = false, float loss_scale = -1);
#define STEP(SPAN) \
void Step_##SPAN( \
float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq, \
size_t _param_size, bool param_half_precision = false, \
bool grad_half_precision = false, bool momentum_half_precision = false, \
bool variance_half_precision = false, float loss_scale = -1);

class Adam_Optimizer {
public:
@@ -141,6 +142,24 @@ class Adam_Optimizer {
}
}

inline void simd_load(bool is_half, float *ptr, __half *h_ptr,
AVX_Data &data) {
if (is_half) {
data.data = SIMD_LOAD_HALF(h_ptr);
} else {
data.data = SIMD_LOAD(ptr);
}
}

inline void simd_store(bool is_half, float *ptr, __half *h_ptr,
AVX_Data &data) {
if (is_half) {
SIMD_STORE_HALF(h_ptr, data.data);
} else {
SIMD_STORE(ptr, data.data);
}
}

void step(size_t step, float lr, float beta1, float beta2, float epsilon,
float weight_decay, bool bias_correction, torch::Tensor &params,
torch::Tensor &grads, torch::Tensor &exp_avg,
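
The cpu_adam.cpp diff is not rendered above, and the header only shows the new per-tensor flags (param/grad/momentum/variance_half_precision) plus the simd_load/simd_store helpers, so the PyTorch sketch below is only an illustration of the update those pieces implement: each of params, grads, exp_avg and exp_avg_sq may independently be fp16, the arithmetic is done in fp32, the gradient is unscaled when a positive loss_scale is passed, and results are written back in each tensor's own dtype. The function name and the L2-style weight decay are assumptions for brevity, not taken from the kernel source.

import torch

def adam_step_mixed_precision(p, g, exp_avg, exp_avg_sq, step, lr, beta1, beta2,
                              eps, weight_decay=0.0, loss_scale=-1.0):
    # Promote everything to fp32 for the math; fp32 inputs pass through as-is.
    p32, g32 = p.float(), g.float()
    m32, v32 = exp_avg.float(), exp_avg_sq.float()
    if loss_scale > 0:
        g32 = g32 / loss_scale                           # undo loss scaling
    if weight_decay > 0:
        g32 = g32 + weight_decay * p32                   # L2-style decay (assumption)
    m32.mul_(beta1).add_(g32, alpha=1 - beta1)           # first moment
    v32.mul_(beta2).addcmul_(g32, g32, value=1 - beta2)  # second moment
    bias1 = 1 - beta1 ** step
    bias2 = 1 - beta2 ** step
    denom = (v32 / bias2).sqrt().add_(eps)
    p32.addcdiv_(m32, denom, value=-lr / bias1)
    # Write back in each tensor's original dtype; fp16 tensors stay fp16.
    p.copy_(p32)
    exp_avg.copy_(m32)
    exp_avg_sq.copy_(v32)

With all four tensors created as torch.half this keeps storage entirely in fp16, which is the (half, half) case the updated tests below now cover.
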
3 changes: 1 addition & 2 deletions colossalai/nn/optimizer/cpu_adam.py
@@ -146,8 +146,7 @@ def step(self, closure=None, div_scale: float = -1):
assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
self._pre_update(p, "exp_avg", "exp_avg_sq")
# FIXME(ver217): CPU adam kernel only supports fp32 states now
if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float:
if p.grad.dtype is torch.bfloat16:
# cpu adam kernel does not support bf16 now
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
3 changes: 1 addition & 2 deletions colossalai/nn/optimizer/hybrid_adam.py
@@ -122,8 +122,7 @@ def step(self, closure=None, div_scale: float = -1):
assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
self._pre_update(p, "exp_avg", "exp_avg_sq")
# FIXME(ver217): CPU adam kernel only supports fp32 states now
if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float:
if p.grad.dtype is torch.bfloat16:
# cpu adam kernel does not support bf16 now
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
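
The same one-line relaxation appears in cpu_adam.py above and hybrid_adam.py here: with the p.dtype is not torch.float clause dropped, only bf16 gradients still take the pure-PyTorch fallback, so a CPU model kept entirely in fp16 can be stepped by the C++ kernel. A minimal usage sketch under that assumption, with the usual colossalai.nn.optimizer import path and hand-written gradients to keep it self-contained:

import torch
from colossalai.nn.optimizer import CPUAdam   # assumed import path

model = torch.nn.Linear(64, 64).half()        # pure fp16 parameters on CPU
optim = CPUAdam(model.parameters(), lr=1e-3)

for p in model.parameters():
    p.grad = torch.randn_like(p) * 1e-2       # fp16 grads, stand-in for a real backward pass
optim.step()                                   # dispatched to the C++ kernel, no bf16 fallback
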
7 changes: 2 additions & 5 deletions tests/test_optimizer/test_adam_kernel.py
@@ -13,17 +13,14 @@
_FUSED_ALLOWED_P_G_TYPES = [
(torch.float, torch.half),
(torch.float, torch.float),
(torch.half, torch.float),
(torch.half, torch.half),
(torch.bfloat16, torch.float),
(torch.float, torch.bfloat16),
(torch.bfloat16, torch.bfloat16),
]

_CPU_ALLOWED_P_G_TYPES = [
(torch.float, torch.half),
(torch.float, torch.float),
(torch.half, torch.float),
(torch.half, torch.half),
]

@@ -138,8 +135,8 @@ def check_adam_kernel(
master_exp_avg_sq = torch.zeros_like(master_p)
p = master_p.clone().to(p_dtype)
g = master_g.clone().to(g_dtype)
exp_avg = master_exp_avg.clone()
exp_avg_sq = master_exp_avg_sq.clone()
exp_avg = master_exp_avg.clone().to(p_dtype)
exp_avg_sq = master_exp_avg_sq.clone().to(p_dtype)

for step in range(1, 1 + n_steps):
torch_adam.update(step, master_p, master_g, master_exp_avg, master_exp_avg_sq)
2 changes: 0 additions & 2 deletions tests/test_optimizer/test_adam_optim.py
@@ -21,8 +21,6 @@
(torch.float, torch.float), # pure fp32
(torch.float, torch.half), # fp16 amp
(torch.float, torch.bfloat16), # bfloat16 amp
# (torch.half, torch.half), # FIXME(ver217): cpu adam kernel does not support pure fp16
# (torch.bfloat16, torch.bfloat16), # FIXME(ver217): cpu adam kernel does not support pure bfloat16
]

N_STEPS = 3
12 changes: 9 additions & 3 deletions tests/test_zero/test_gemini/test_grad_clip.py
@@ -52,7 +52,8 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module):

@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("model_name", ["gpt2"])
def exam_grad_clipping(placement_config, model_name: str):
@parameterize("master_weights", [True, False])
def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
set_seed(1912)
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
Expand Down Expand Up @@ -82,6 +83,7 @@ def exam_grad_clipping(placement_config, model_name: str):
chunk_config_dict=config_dict,
chunk_init_device=init_device,
pin_memory=True,
master_weights=master_weights,
**placement_config,
)

@@ -103,15 +105,19 @@

torch_loss = run_fwd_bwd(torch_model, data, label, criterion, torch_optim)
loss = run_fwd_bwd(model, data, label, criterion, zero_optim)
assert_close(torch_loss, loss)

# as no master weights leads to error accumulation, we don't check the loss
if master_weights:
assert_close(torch_loss, loss)

import apex.amp as apex_amp

torch.nn.utils.clip_grad_norm_(apex_amp.master_params(torch_optim), 1.0)
torch_optim.step()
zero_optim.step()

check_param(model, torch_model)
if master_weights:
check_param(model, torch_model)


def run_dist(rank, world_size, port):
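
The relaxed assertions follow from the comment in the diff: with master_weights=False every update is applied directly to the fp16 parameters, so rounding error accumulates and the loss drifts from the fp32 baseline. A small self-contained illustration of that effect (not part of the test suite):

import torch

w_master = torch.tensor(1.0, dtype=torch.float32)  # fp32 master weight
w_fp16 = w_master.clone().half()                    # weight kept purely in fp16

for _ in range(1000):
    update = torch.tensor(1e-4, dtype=torch.float32)
    w_master -= update            # applied in fp32
    w_fp16 -= update.half()       # applied directly in fp16

print(w_master.item())  # ~0.9
print(w_fp16.item())    # still 1.0: each update is below half an fp16 ulp near 1.0 and rounds away
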
15 changes: 11 additions & 4 deletions tests/test_zero/test_gemini/test_optim.py
@@ -70,12 +70,14 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty
@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("model_name", TEST_MODELS)
@parameterize("mixed_precision", [torch.half, torch.bfloat16])
def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype):
@parameterize("master_weights", [True, False])
def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool):
set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

torch_model = model_builder().cuda()
# apex no master weights leads to nan, so we don't use it
amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=128)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
@@ -90,7 +92,9 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
config_dict[world_size]["chunk_size"] = 5000
config_dict[world_size]["keep_gathered"] = False
model = GeminiDDP(model, config_dict, **placement_config, mixed_precision=mixed_precision)
model = GeminiDDP(
model, config_dict, **placement_config, mixed_precision=mixed_precision, master_weights=master_weights
)

optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = GeminiOptimizer(optimizer, model, initial_scale=128)
@@ -109,12 +113,15 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt

torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
assert_close(torch_loss, loss, rtol=rtol, atol=atol)
# as no master weights leads to error accumulation, we don't check the loss
if master_weights:
assert_close(torch_loss, loss, rtol=rtol, atol=atol)

zero_optim.step()
torch_optim.step()

check_param(model, torch_model, mixed_precision)
if master_weights:
check_param(model, torch_model, mixed_precision)


@parameterize("placement_config", PLACEMENT_CONFIGS)
