Inadequate memory consumption when using HSDP without gradient accumulation #24208

Open
qGentry opened this issue Oct 9, 2024 · 0 comments
Labels: bug (Something isn't working)

qGentry commented Oct 9, 2024

Description

Hi, I'm training a transformer model with Hybrid Sharded Data Parallelism (HSDP). This setup is similar to FSDP/ZeRO-3, where params are all-gathered for each layer's forward/backward pass and dropped afterwards. However, instead of sharding both model params and optimizer state over all GPUs in the cluster, I shard model params only over a subset of devices (usually within a single node, for fast all-gathers over NVLink) and shard the optimizer state over all GPUs (similar to FSDP/ZeRO-1/2/3).

Basically, I have a mesh (param_groups, model), and for each param tensor P of shape (X, Y) I shard the param tensor with partition spec (model, None), and the optimizer state P_o corresponding to P, of the same shape (X, Y), with partition spec (model, param_groups).
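
For concreteness, here's a minimal JAX sketch of this layout (the mesh shape and names are illustrative, not my actual training code):

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative HSDP mesh: 2 param groups x 4 model-shard devices = 8 GPUs.
devices = np.array(jax.devices()).reshape(2, 4)
mesh = Mesh(devices, axis_names=("param_groups", "model"))

# A param tensor P of shape (X, Y) is sharded over the "model" axis only.
param_sharding = NamedSharding(mesh, P("model", None))

# The corresponding optimizer state P_o of the same shape is additionally
# sharded over "param_groups", i.e. over all 8 GPUs.
opt_state_sharding = NamedSharding(mesh, P("model", "param_groups"))
```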

When the mesh (param_groups, model) size is:

  1. (1, N_GPUs) - this is basically FSDP/ZeRO-3.
  2. (N, N_GPUs / N), N > 1 - this is HSDP.

I also have gradient accumulation implemented: the input batch is split into chunks, the forward/backward pass is computed for each chunk independently, and the chunks' gradients are summed.
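
The accumulation itself is nothing unusual; roughly (a hypothetical sketch, not my actual training step):

```python
import jax
import jax.numpy as jnp

def accumulated_grads(loss_fn, params, batch, num_chunks):
    # Sketch with hypothetical names: split the batch into `num_chunks`
    # micro-batches along the batch axis, compute gradients for each
    # independently, and sum them.
    grad_fn = jax.grad(loss_fn)

    def take_chunk(x, i):
        size = x.shape[0] // num_chunks
        return x[i * size:(i + 1) * size]

    grads = None
    for i in range(num_chunks):
        micro_batch = jax.tree_util.tree_map(lambda x: take_chunk(x, i), batch)
        g = grad_fn(params, micro_batch)
        grads = g if grads is None else jax.tree_util.tree_map(jnp.add, grads, g)
    return grads
```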

When using gradient accumulation with a factor of N (the batch is split into N chunks processed independently) and sequence length S, peak memory usage should equal that of a setup with gradient accumulation factor 2 * N and sequence length 2 * S. This is because the resulting per-chunk input tensor of shape [B / 2, 2 * S] has the same numel as a tensor of shape [B, S].
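
For example, with the total batch size of 64 used below, the per-chunk input has the same number of elements in all three setups:

```python
B = 64
# (accumulation factor, seq_len) -> per-chunk input shape and numel
for factor, seq_len in [(1, 512), (2, 1024), (4, 2048)]:
    per_chunk_shape = (B // factor, seq_len)
    print(per_chunk_shape, per_chunk_shape[0] * per_chunk_shape[1])
# (64, 512) 32768
# (32, 1024) 32768
# (16, 2048) 32768
```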

This holds for the FSDP setup with mesh size (1, N_GPUs): for every gradient accumulation factor I've tested, peak memory usage is identical. But when I try to use HSDP, something weird happens.

With a gradient accumulation factor of N > 1, peak memory usage is exactly as expected, BUT as soon as I set it to 1, peak memory usage increases dramatically.

Here I have a toy model with mesh (2, 4), a total batch size of 64, and 3 setups:

  1. gradient accumulation factor = 1, seq_len = 512
  2. gradient accumulation factor = 2, seq_len = 1024
  3. gradient accumulation factor = 4, seq_len = 2048

The second and third setups consume a practically identical amount of memory (~50 GB on each GPU), while the first one consumes considerably more: 61 GB.

Here are the HLOs of the first and second setups:
compiled_train_fn_grad_accum=2.txt
compiled_train_fn_grad_accum=1.txt

System info (python version, jaxlib version, accelerator, etc.)

Python 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax; jax.print_environment_info()
jax:    0.4.33
jaxlib: 0.4.33
numpy:  1.24.3
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
jax.devices (8 total, 8 local): [CudaDevice(id=0) CudaDevice(id=1) ... CudaDevice(id=6) CudaDevice(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='computeinstance-e00xy41pgq1s49hjc5', release='5.15.0-118-generic', version='#128-Ubuntu SMP Fri Jul 5 09:28:59 UTC 2024', machine='x86_64')


$ nvidia-smi
Fri Oct  4 10:07:59 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.90.07              Driver Version: 550.90.07      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 80GB HBM3          On  |   00000000:8D:00.0 Off |                    0 |
| N/A   28C    P0            110W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  |   00000000:91:00.0 Off |                    0 |
| N/A   27C    P0            110W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          On  |   00000000:95:00.0 Off |                    0 |
| N/A   30C    P0            110W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          On  |   00000000:99:00.0 Off |                    0 |
| N/A   27C    P0            112W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          On  |   00000000:AB:00.0 Off |                    0 |
| N/A   28C    P0            109W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          On  |   00000000:AF:00.0 Off |                    0 |
| N/A   26C    P0            109W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          On  |   00000000:B3:00.0 Off |                    0 |
| N/A   29C    P0            112W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          On  |   00000000:B7:00.0 Off |                    0 |
| N/A   27C    P0            110W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
+-----------------------------------------------------------------------------------------+

XLA issue: openxla/xla#18090
