enable TritonFusedRMSNorm with local_map annotation #404

Merged 15 commits on Jun 14, 2024

Conversation

XilunWu (Contributor) commented Jun 14, 2024

Summary
This PR enables the use of TritonFusedRMSNorm with Tensor Parallel, yielding a 7%-8% performance gain compared to RMSNorm with TP. #364
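
For context, the mechanism this relies on: `local_map` wraps a function written against plain `torch.Tensor` so it can be called with DTensors inside a TP model; DTensor inputs are unwrapped to their local shards and the raw output is rewrapped with the declared placements. A minimal sketch of the idea (assuming a 1-D TP mesh; `rms_norm`/`tp_rms_norm` are illustrative names, not the exact code in this PR):

```
import torch
from torch.distributed._tensor import Replicate, Shard
from torch.distributed._tensor.experimental import local_map


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Plain RMSNorm on a local shard: x / sqrt(mean(x^2) + eps) * weight.
    # In this PR the body is the Triton fused kernel; because the normalized
    # (last) dim is not sharded, per-shard math equals the global math.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight


# The annotation: the activation is sequence-sharded (Shard(1)) and the norm
# weight replicated, as in sequence-parallel RMSNorm; None marks the
# non-tensor eps argument. The wrapped function accepts and returns DTensors.
tp_rms_norm = local_map(
    rms_norm,
    out_placements=[Shard(1)],
    in_placements=([Shard(1)], [Replicate()], None),
)
```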

XilunWu added 15 commits June 8, 2024 17:40
…nnotation"


**Test Plan**
Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`):
1. with `norm_type = "rmsnorm"`
```
[rank0]:2024-06-05 11:57:35,505 - root - INFO - step:  1  loss: 12.2703  memory: 24.66GiB(31.15%)  wps: 143  mfu: 2.66%
[rank0]:2024-06-05 11:57:35,505 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-06-05 11:58:11,490 - root - INFO - step: 10  loss: 11.0446  memory: 31.96GiB(40.37%)  wps: 512  mfu: 9.51%
[rank0]:2024-06-05 11:58:46,488 - root - INFO - step: 20  loss:  9.2321  memory: 31.96GiB(40.37%)  wps: 586  mfu: 10.87%
[rank0]:2024-06-05 11:59:22,462 - root - INFO - step: 30  loss:  8.2184  memory: 31.96GiB(40.37%)  wps: 570  mfu: 10.58%
[rank0]:2024-06-05 11:59:57,301 - root - INFO - step: 40  loss:  7.6220  memory: 31.96GiB(40.37%)  wps: 589  mfu: 10.93%
[rank0]:2024-06-05 12:00:32,254 - root - INFO - step: 50  loss:  7.5399  memory: 31.96GiB(40.37%)  wps: 587  mfu: 10.89%
[rank0]:2024-06-05 12:01:07,155 - root - INFO - step: 60  loss:  7.3179  memory: 31.96GiB(40.37%)  wps: 588  mfu: 10.91%
[rank0]:2024-06-05 12:01:41,999 - root - INFO - step: 70  loss:  7.3508  memory: 31.96GiB(40.37%)  wps: 589  mfu: 10.92%
[rank0]:2024-06-05 12:02:17,093 - root - INFO - step: 80  loss:  7.2696  memory: 31.96GiB(40.37%)  wps: 584  mfu: 10.85%
[rank0]:2024-06-05 12:02:52,009 - root - INFO - step: 90  loss:  7.0481  memory: 31.96GiB(40.37%)  wps: 588  mfu: 10.91%
[rank0]:2024-06-05 12:03:27,715 - root - INFO - step: 100  loss:  6.9623  memory: 31.96GiB(40.37%)  wps: 575  mfu: 10.67%
```

2. with `norm_type = "fused_rmsnorm"`
```
[rank0]:2024-06-05 12:08:35,004 - root - INFO - step:  1  loss: 12.2422  memory: 24.62GiB(31.10%)  wps: 95  mfu: 1.76%
[rank0]:2024-06-05 12:08:35,004 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-06-05 12:09:12,401 - root - INFO - step: 10  loss: 11.0361  memory: 32.09GiB(40.54%)  wps: 493  mfu: 9.15%
[rank0]:2024-06-05 12:09:49,380 - root - INFO - step: 20  loss:  9.2725  memory: 32.09GiB(40.54%)  wps: 554  mfu: 10.29%
[rank0]:2024-06-05 12:10:26,645 - root - INFO - step: 30  loss:  8.2091  memory: 32.09GiB(40.54%)  wps: 550  mfu: 10.21%
[rank0]:2024-06-05 12:11:03,616 - root - INFO - step: 40  loss:  7.5601  memory: 32.09GiB(40.54%)  wps: 555  mfu: 10.30%
[rank0]:2024-06-05 12:11:40,625 - root - INFO - step: 50  loss:  7.5144  memory: 32.09GiB(40.54%)  wps: 554  mfu: 10.29%
[rank0]:2024-06-05 12:12:17,768 - root - INFO - step: 60  loss:  7.3869  memory: 32.09GiB(40.54%)  wps: 552  mfu: 10.25%
[rank0]:2024-06-05 12:12:54,820 - root - INFO - step: 70  loss:  7.3358  memory: 32.09GiB(40.54%)  wps: 553  mfu: 10.27%
[rank0]:2024-06-05 12:13:31,817 - root - INFO - step: 80  loss:  7.2085  memory: 32.09GiB(40.54%)  wps: 554  mfu: 10.29%
[rank0]:2024-06-05 12:14:09,156 - root - INFO - step: 90  loss:  7.0140  memory: 32.09GiB(40.54%)  wps: 549  mfu: 10.19%
[rank0]:2024-06-05 12:14:48,518 - root - INFO - step: 100  loss:  6.9507  memory: 32.09GiB(40.54%)  wps: 521  mfu: 9.67%
```

…nnotation"


**Test Plan**
Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`):
1. with `norm_type = "rmsnorm"`
```
[rank2]:2024-06-12 13:55:25,005 - root - INFO - step:  1  loss: 12.2971  memory: 23.68GiB(29.92%)  wps: 258  mfu: 4.79%
[rank2]:2024-06-12 13:55:43,082 - root - INFO - step:  5  loss: 11.6237  memory: 30.98GiB(39.14%)  wps: 453  mfu: 8.41%
[rank2]:2024-06-12 13:56:00,742 - root - INFO - step: 10  loss: 10.7210  memory: 30.98GiB(39.14%)  wps: 580  mfu: 10.77%
[rank2]:2024-06-12 13:56:18,274 - root - INFO - step: 15  loss:  9.4563  memory: 30.98GiB(39.14%)  wps: 585  mfu: 10.85%
[rank2]:2024-06-12 13:56:35,888 - root - INFO - step: 20  loss:  8.9246  memory: 30.98GiB(39.14%)  wps: 582  mfu: 10.80%
```

2. with `norm_type = "fused_rmsnorm"`
```
[rank2]:2024-06-12 13:52:48,671 - root - INFO - step:  1  loss: 12.2779  memory: 23.64GiB(29.86%)  wps: 186  mfu: 3.45%
[rank2]:2024-06-12 13:53:06,983 - root - INFO - step:  5  loss: 11.6073  memory: 31.11GiB(39.31%)  wps: 447  mfu: 8.30%
[rank2]:2024-06-12 13:53:23,895 - root - INFO - step: 10  loss: 10.6355  memory: 31.11GiB(39.31%)  wps: 606  mfu: 11.25%
[rank2]:2024-06-12 13:53:41,108 - root - INFO - step: 15  loss:  9.5591  memory: 31.11GiB(39.31%)  wps: 596  mfu: 11.05%
[rank2]:2024-06-12 13:53:58,045 - root - INFO - step: 20  loss:  9.0287  memory: 31.11GiB(39.31%)  wps: 605  mfu: 11.23%
```

…nnotation"


**Test Plan**
Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`; with `NGPU=8` and `data_parallel_degree = -1`, the remaining factor of 2 is used for data parallel). Detailed settings:
```
[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 5
enable_tensorboard = false
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm"  # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm]
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
batch_size = 4
seq_len = 8192
warmup_steps = 200  # lr scheduler warm up
max_norm = 1.0  # grad norm clipping
steps = 100
data_parallel_degree = -1
tensor_parallel_degree = 4
pipeline_parallel_degree = 1
fp8_linear = ""
compile = false
dataset = "c4_mini"

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'selective'  # ['none', 'selective', 'full']
selective_ac_option = 'op'  # 'int' = ac every positive int layer or 'op', ac based on ops policy
```
1. with `norm_type = "rmsnorm"`
```
[rank2]:2024-06-13 00:47:55,607 - root - INFO - step:  1  loss: 12.2262  memory: 57.70GiB(72.89%)  wps: 429  mfu: 7.96%
[rank2]:2024-06-13 00:48:57,536 - root - INFO - step:  5  loss: 11.4801  memory: 65.53GiB(82.78%)  wps: 529  mfu: 9.82%
[rank2]:2024-06-13 00:50:05,746 - root - INFO - step: 10  loss: 10.2305  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.15%
[rank2]:2024-06-13 00:51:14,343 - root - INFO - step: 15  loss:  9.3287  memory: 65.53GiB(82.78%)  wps: 597  mfu: 11.09%
[rank2]:2024-06-13 00:52:22,325 - root - INFO - step: 20  loss:  8.7126  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.19%
[rank2]:2024-06-13 00:53:31,605 - root - INFO - step: 25  loss:  8.2011  memory: 65.53GiB(82.78%)  wps: 591  mfu: 10.98%
[rank2]:2024-06-13 00:54:39,861 - root - INFO - step: 30  loss:  7.7424  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 00:55:47,782 - root - INFO - step: 35  loss:  7.4964  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.20%
[rank2]:2024-06-13 00:56:55,927 - root - INFO - step: 40  loss:  7.2799  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.16%
[rank2]:2024-06-13 00:58:04,445 - root - INFO - step: 45  loss:  7.2280  memory: 65.53GiB(82.78%)  wps: 598  mfu: 11.10%
[rank2]:2024-06-13 00:59:12,503 - root - INFO - step: 50  loss:  7.0669  memory: 65.53GiB(82.78%)  wps: 602  mfu: 11.17%
[rank2]:2024-06-13 01:00:21,395 - root - INFO - step: 55  loss:  6.9967  memory: 65.53GiB(82.78%)  wps: 595  mfu: 11.04%
[rank2]:2024-06-13 01:01:29,641 - root - INFO - step: 60  loss:  7.0763  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 01:02:37,572 - root - INFO - step: 65  loss:  6.9260  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.20%
[rank2]:2024-06-13 01:03:45,755 - root - INFO - step: 70  loss:  6.9757  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.15%
[rank2]:2024-06-13 01:04:54,015 - root - INFO - step: 75  loss:  6.8074  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 01:06:02,682 - root - INFO - step: 80  loss:  6.7362  memory: 65.53GiB(82.78%)  wps: 597  mfu: 11.08%
[rank2]:2024-06-13 01:07:11,232 - root - INFO - step: 85  loss:  6.7016  memory: 65.53GiB(82.78%)  wps: 598  mfu: 11.09%
[rank2]:2024-06-13 01:08:19,973 - root - INFO - step: 90  loss:  6.6640  memory: 65.53GiB(82.78%)  wps: 596  mfu: 11.06%
[rank2]:2024-06-13 01:09:27,858 - root - INFO - step: 95  loss:  6.7214  memory: 65.53GiB(82.78%)  wps: 604  mfu: 11.20%
[rank2]:2024-06-13 01:10:36,136 - root - INFO - step: 100  loss:  6.5953  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
```

2. with `norm_type = "fused_rmsnorm"`
```
[rank2]:2024-06-13 00:19:33,609 - root - INFO - step:  1  loss: 12.2194  memory: 57.31GiB(72.40%)  wps: 412  mfu: 7.64%
[rank2]:2024-06-13 00:20:29,175 - root - INFO - step:  5  loss: 11.4519  memory: 65.13GiB(82.29%)  wps: 590  mfu: 10.95%
[rank2]:2024-06-13 00:21:33,667 - root - INFO - step: 10  loss: 10.2199  memory: 65.13GiB(82.29%)  wps: 635  mfu: 11.79%
[rank2]:2024-06-13 00:22:37,492 - root - INFO - step: 15  loss:  9.3509  memory: 65.13GiB(82.29%)  wps: 642  mfu: 11.92%
[rank2]:2024-06-13 00:23:42,592 - root - INFO - step: 20  loss:  8.7972  memory: 65.13GiB(82.29%)  wps: 629  mfu: 11.68%
[rank2]:2024-06-13 00:24:46,466 - root - INFO - step: 25  loss:  8.2348  memory: 65.13GiB(82.29%)  wps: 642  mfu: 11.91%
[rank2]:2024-06-13 00:25:50,900 - root - INFO - step: 30  loss:  7.7037  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.80%
[rank2]:2024-06-13 00:26:54,794 - root - INFO - step: 35  loss:  7.4639  memory: 65.13GiB(82.29%)  wps: 641  mfu: 11.90%
[rank2]:2024-06-13 00:27:59,235 - root - INFO - step: 40  loss:  7.2406  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.80%
[rank2]:2024-06-13 00:29:03,304 - root - INFO - step: 45  loss:  7.1822  memory: 65.13GiB(82.29%)  wps: 640  mfu: 11.87%
[rank2]:2024-06-13 00:30:07,607 - root - INFO - step: 50  loss:  7.0580  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:31:11,764 - root - INFO - step: 55  loss:  6.9888  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:32:16,001 - root - INFO - step: 60  loss:  7.0387  memory: 65.13GiB(82.29%)  wps: 638  mfu: 11.84%
[rank2]:2024-06-13 00:33:20,137 - root - INFO - step: 65  loss:  6.9199  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:34:24,424 - root - INFO - step: 70  loss:  6.9503  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:35:28,722 - root - INFO - step: 75  loss:  6.7960  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:36:32,865 - root - INFO - step: 80  loss:  6.6798  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:37:36,981 - root - INFO - step: 85  loss:  6.6504  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:38:41,407 - root - INFO - step: 90  loss:  6.6655  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.81%
[rank2]:2024-06-13 00:39:45,981 - root - INFO - step: 95  loss:  6.7359  memory: 65.13GiB(82.29%)  wps: 635  mfu: 11.78%
[rank2]:2024-06-13 00:40:50,146 - root - INFO - step: 100  loss:  6.5410  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.85%
```
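
As a quick read on the gain from these logs (arithmetic on the step-100 readings above; this particular pair gives roughly 6.5%, in the same ballpark as the 7%-8% summary figure across runs):

```
rmsnorm_wps, fused_wps = 600, 639  # step-100 wps from the two runs above
print(f"gain: {(fused_wps - rmsnorm_wps) / rmsnorm_wps:.1%}")  # -> 6.5%
```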

[ghstack-poisoned]
…nnotation"


**Summary**
This PR enables the use of TritonFusedRMSNorm with Tensor Parallel with 7%-8% performance gain compared to RMSNorm with TP.

**Test Plan**
Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`). Detailed settings:
```
[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 5
enable_tensorboard = false
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm"  # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm]
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
batch_size = 4
seq_len = 8192
warmup_steps = 200  # lr scheduler warm up
max_norm = 1.0  # grad norm clipping
steps = 100
data_parallel_degree = -1
tensor_parallel_degree = 4
pipeline_parallel_degree = 1
fp8_linear = ""
compile = false
dataset = "c4_mini"

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'selective'  # ['none', 'selective', 'full']
selective_ac_option = 'op'  # 'int' = ac every positive int layer or 'op', ac based on ops policy
```
1. with `norm_type = "rmsnorm"`
```
[rank2]:2024-06-13 00:47:55,607 - root - INFO - step:  1  loss: 12.2262  memory: 57.70GiB(72.89%)  wps: 429  mfu: 7.96%
[rank2]:2024-06-13 00:48:57,536 - root - INFO - step:  5  loss: 11.4801  memory: 65.53GiB(82.78%)  wps: 529  mfu: 9.82%
[rank2]:2024-06-13 00:50:05,746 - root - INFO - step: 10  loss: 10.2305  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.15%
[rank2]:2024-06-13 00:51:14,343 - root - INFO - step: 15  loss:  9.3287  memory: 65.53GiB(82.78%)  wps: 597  mfu: 11.09%
[rank2]:2024-06-13 00:52:22,325 - root - INFO - step: 20  loss:  8.7126  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.19%
[rank2]:2024-06-13 00:53:31,605 - root - INFO - step: 25  loss:  8.2011  memory: 65.53GiB(82.78%)  wps: 591  mfu: 10.98%
[rank2]:2024-06-13 00:54:39,861 - root - INFO - step: 30  loss:  7.7424  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 00:55:47,782 - root - INFO - step: 35  loss:  7.4964  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.20%
[rank2]:2024-06-13 00:56:55,927 - root - INFO - step: 40  loss:  7.2799  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.16%
[rank2]:2024-06-13 00:58:04,445 - root - INFO - step: 45  loss:  7.2280  memory: 65.53GiB(82.78%)  wps: 598  mfu: 11.10%
[rank2]:2024-06-13 00:59:12,503 - root - INFO - step: 50  loss:  7.0669  memory: 65.53GiB(82.78%)  wps: 602  mfu: 11.17%
[rank2]:2024-06-13 01:00:21,395 - root - INFO - step: 55  loss:  6.9967  memory: 65.53GiB(82.78%)  wps: 595  mfu: 11.04%
[rank2]:2024-06-13 01:01:29,641 - root - INFO - step: 60  loss:  7.0763  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 01:02:37,572 - root - INFO - step: 65  loss:  6.9260  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.20%
[rank2]:2024-06-13 01:03:45,755 - root - INFO - step: 70  loss:  6.9757  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.15%
[rank2]:2024-06-13 01:04:54,015 - root - INFO - step: 75  loss:  6.8074  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 01:06:02,682 - root - INFO - step: 80  loss:  6.7362  memory: 65.53GiB(82.78%)  wps: 597  mfu: 11.08%
[rank2]:2024-06-13 01:07:11,232 - root - INFO - step: 85  loss:  6.7016  memory: 65.53GiB(82.78%)  wps: 598  mfu: 11.09%
[rank2]:2024-06-13 01:08:19,973 - root - INFO - step: 90  loss:  6.6640  memory: 65.53GiB(82.78%)  wps: 596  mfu: 11.06%
[rank2]:2024-06-13 01:09:27,858 - root - INFO - step: 95  loss:  6.7214  memory: 65.53GiB(82.78%)  wps: 604  mfu: 11.20%
[rank2]:2024-06-13 01:10:36,136 - root - INFO - step: 100  loss:  6.5953  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
```

2. with `norm_type = "fused_rmsnorm"`
```
[rank2]:2024-06-13 00:19:33,609 - root - INFO - step:  1  loss: 12.2194  memory: 57.31GiB(72.40%)  wps: 412  mfu: 7.64%
[rank2]:2024-06-13 00:20:29,175 - root - INFO - step:  5  loss: 11.4519  memory: 65.13GiB(82.29%)  wps: 590  mfu: 10.95%
[rank2]:2024-06-13 00:21:33,667 - root - INFO - step: 10  loss: 10.2199  memory: 65.13GiB(82.29%)  wps: 635  mfu: 11.79%
[rank2]:2024-06-13 00:22:37,492 - root - INFO - step: 15  loss:  9.3509  memory: 65.13GiB(82.29%)  wps: 642  mfu: 11.92%
[rank2]:2024-06-13 00:23:42,592 - root - INFO - step: 20  loss:  8.7972  memory: 65.13GiB(82.29%)  wps: 629  mfu: 11.68%
[rank2]:2024-06-13 00:24:46,466 - root - INFO - step: 25  loss:  8.2348  memory: 65.13GiB(82.29%)  wps: 642  mfu: 11.91%
[rank2]:2024-06-13 00:25:50,900 - root - INFO - step: 30  loss:  7.7037  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.80%
[rank2]:2024-06-13 00:26:54,794 - root - INFO - step: 35  loss:  7.4639  memory: 65.13GiB(82.29%)  wps: 641  mfu: 11.90%
[rank2]:2024-06-13 00:27:59,235 - root - INFO - step: 40  loss:  7.2406  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.80%
[rank2]:2024-06-13 00:29:03,304 - root - INFO - step: 45  loss:  7.1822  memory: 65.13GiB(82.29%)  wps: 640  mfu: 11.87%
[rank2]:2024-06-13 00:30:07,607 - root - INFO - step: 50  loss:  7.0580  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:31:11,764 - root - INFO - step: 55  loss:  6.9888  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:32:16,001 - root - INFO - step: 60  loss:  7.0387  memory: 65.13GiB(82.29%)  wps: 638  mfu: 11.84%
[rank2]:2024-06-13 00:33:20,137 - root - INFO - step: 65  loss:  6.9199  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:34:24,424 - root - INFO - step: 70  loss:  6.9503  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:35:28,722 - root - INFO - step: 75  loss:  6.7960  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:36:32,865 - root - INFO - step: 80  loss:  6.6798  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:37:36,981 - root - INFO - step: 85  loss:  6.6504  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:38:41,407 - root - INFO - step: 90  loss:  6.6655  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.81%
[rank2]:2024-06-13 00:39:45,981 - root - INFO - step: 95  loss:  6.7359  memory: 65.13GiB(82.29%)  wps: 635  mfu: 11.78%
[rank2]:2024-06-13 00:40:50,146 - root - INFO - step: 100  loss:  6.5410  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.85%
```

[ghstack-poisoned]
…nnotation"


**Summary**
This PR enables the use of TritonFusedRMSNorm with Tensor Parallel with 7%-8% performance gain compared to RMSNorm with TP.

**Test Plan**
Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`). Detailed settings:
```
[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 5
enable_tensorboard = false
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm"  # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm]
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
batch_size = 4
seq_len = 8192
warmup_steps = 200  # lr scheduler warm up
max_norm = 1.0  # grad norm clipping
steps = 100
data_parallel_degree = -1
tensor_parallel_degree = 4
pipeline_parallel_degree = 1
fp8_linear = ""
compile = false
dataset = "c4_mini"

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'selective'  # ['none', 'selective', 'full']
selective_ac_option = 'op'  # 'int' = ac every positive int layer or 'op', ac based on ops policy
```
1. with `norm_type = "rmsnorm"`
```
[rank2]:2024-06-13 00:47:55,607 - root - INFO - step:  1  loss: 12.2262  memory: 57.70GiB(72.89%)  wps: 429  mfu: 7.96%
[rank2]:2024-06-13 00:48:57,536 - root - INFO - step:  5  loss: 11.4801  memory: 65.53GiB(82.78%)  wps: 529  mfu: 9.82%
[rank2]:2024-06-13 00:50:05,746 - root - INFO - step: 10  loss: 10.2305  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.15%
[rank2]:2024-06-13 00:51:14,343 - root - INFO - step: 15  loss:  9.3287  memory: 65.53GiB(82.78%)  wps: 597  mfu: 11.09%
[rank2]:2024-06-13 00:52:22,325 - root - INFO - step: 20  loss:  8.7126  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.19%
[rank2]:2024-06-13 00:53:31,605 - root - INFO - step: 25  loss:  8.2011  memory: 65.53GiB(82.78%)  wps: 591  mfu: 10.98%
[rank2]:2024-06-13 00:54:39,861 - root - INFO - step: 30  loss:  7.7424  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 00:55:47,782 - root - INFO - step: 35  loss:  7.4964  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.20%
[rank2]:2024-06-13 00:56:55,927 - root - INFO - step: 40  loss:  7.2799  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.16%
[rank2]:2024-06-13 00:58:04,445 - root - INFO - step: 45  loss:  7.2280  memory: 65.53GiB(82.78%)  wps: 598  mfu: 11.10%
[rank2]:2024-06-13 00:59:12,503 - root - INFO - step: 50  loss:  7.0669  memory: 65.53GiB(82.78%)  wps: 602  mfu: 11.17%
[rank2]:2024-06-13 01:00:21,395 - root - INFO - step: 55  loss:  6.9967  memory: 65.53GiB(82.78%)  wps: 595  mfu: 11.04%
[rank2]:2024-06-13 01:01:29,641 - root - INFO - step: 60  loss:  7.0763  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 01:02:37,572 - root - INFO - step: 65  loss:  6.9260  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.20%
[rank2]:2024-06-13 01:03:45,755 - root - INFO - step: 70  loss:  6.9757  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.15%
[rank2]:2024-06-13 01:04:54,015 - root - INFO - step: 75  loss:  6.8074  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 01:06:02,682 - root - INFO - step: 80  loss:  6.7362  memory: 65.53GiB(82.78%)  wps: 597  mfu: 11.08%
[rank2]:2024-06-13 01:07:11,232 - root - INFO - step: 85  loss:  6.7016  memory: 65.53GiB(82.78%)  wps: 598  mfu: 11.09%
[rank2]:2024-06-13 01:08:19,973 - root - INFO - step: 90  loss:  6.6640  memory: 65.53GiB(82.78%)  wps: 596  mfu: 11.06%
[rank2]:2024-06-13 01:09:27,858 - root - INFO - step: 95  loss:  6.7214  memory: 65.53GiB(82.78%)  wps: 604  mfu: 11.20%
[rank2]:2024-06-13 01:10:36,136 - root - INFO - step: 100  loss:  6.5953  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
```

2. with `norm_type = "fused_rmsnorm"`
```
[rank2]:2024-06-13 00:19:33,609 - root - INFO - step:  1  loss: 12.2194  memory: 57.31GiB(72.40%)  wps: 412  mfu: 7.64%
[rank2]:2024-06-13 00:20:29,175 - root - INFO - step:  5  loss: 11.4519  memory: 65.13GiB(82.29%)  wps: 590  mfu: 10.95%
[rank2]:2024-06-13 00:21:33,667 - root - INFO - step: 10  loss: 10.2199  memory: 65.13GiB(82.29%)  wps: 635  mfu: 11.79%
[rank2]:2024-06-13 00:22:37,492 - root - INFO - step: 15  loss:  9.3509  memory: 65.13GiB(82.29%)  wps: 642  mfu: 11.92%
[rank2]:2024-06-13 00:23:42,592 - root - INFO - step: 20  loss:  8.7972  memory: 65.13GiB(82.29%)  wps: 629  mfu: 11.68%
[rank2]:2024-06-13 00:24:46,466 - root - INFO - step: 25  loss:  8.2348  memory: 65.13GiB(82.29%)  wps: 642  mfu: 11.91%
[rank2]:2024-06-13 00:25:50,900 - root - INFO - step: 30  loss:  7.7037  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.80%
[rank2]:2024-06-13 00:26:54,794 - root - INFO - step: 35  loss:  7.4639  memory: 65.13GiB(82.29%)  wps: 641  mfu: 11.90%
[rank2]:2024-06-13 00:27:59,235 - root - INFO - step: 40  loss:  7.2406  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.80%
[rank2]:2024-06-13 00:29:03,304 - root - INFO - step: 45  loss:  7.1822  memory: 65.13GiB(82.29%)  wps: 640  mfu: 11.87%
[rank2]:2024-06-13 00:30:07,607 - root - INFO - step: 50  loss:  7.0580  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:31:11,764 - root - INFO - step: 55  loss:  6.9888  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:32:16,001 - root - INFO - step: 60  loss:  7.0387  memory: 65.13GiB(82.29%)  wps: 638  mfu: 11.84%
[rank2]:2024-06-13 00:33:20,137 - root - INFO - step: 65  loss:  6.9199  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:34:24,424 - root - INFO - step: 70  loss:  6.9503  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:35:28,722 - root - INFO - step: 75  loss:  6.7960  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:36:32,865 - root - INFO - step: 80  loss:  6.6798  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:37:36,981 - root - INFO - step: 85  loss:  6.6504  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:38:41,407 - root - INFO - step: 90  loss:  6.6655  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.81%
[rank2]:2024-06-13 00:39:45,981 - root - INFO - step: 95  loss:  6.7359  memory: 65.13GiB(82.29%)  wps: 635  mfu: 11.78%
[rank2]:2024-06-13 00:40:50,146 - root - INFO - step: 100  loss:  6.5410  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.85%
```

[ghstack-poisoned]
…nnotation"


**Summary**
This PR enables the use of TritonFusedRMSNorm with Tensor Parallel with 7%-8% performance gain compared to RMSNorm with TP.

**Test Plan**
Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`). Detailed settings:
```
[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 5
enable_tensorboard = false
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm"  # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm]
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
batch_size = 4
seq_len = 8192
warmup_steps = 200  # lr scheduler warm up
max_norm = 1.0  # grad norm clipping
steps = 100
data_parallel_degree = -1
tensor_parallel_degree = 4
pipeline_parallel_degree = 1
fp8_linear = ""
compile = false
dataset = "c4_mini"

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'selective'  # ['none', 'selective', 'full']
selective_ac_option = 'op'  # 'int' = ac every positive int layer or 'op', ac based on ops policy
```
1. with `norm_type = "rmsnorm"`
```
[rank2]:2024-06-13 00:47:55,607 - root - INFO - step:  1  loss: 12.2262  memory: 57.70GiB(72.89%)  wps: 429  mfu: 7.96%
[rank2]:2024-06-13 00:48:57,536 - root - INFO - step:  5  loss: 11.4801  memory: 65.53GiB(82.78%)  wps: 529  mfu: 9.82%
[rank2]:2024-06-13 00:50:05,746 - root - INFO - step: 10  loss: 10.2305  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.15%
[rank2]:2024-06-13 00:51:14,343 - root - INFO - step: 15  loss:  9.3287  memory: 65.53GiB(82.78%)  wps: 597  mfu: 11.09%
[rank2]:2024-06-13 00:52:22,325 - root - INFO - step: 20  loss:  8.7126  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.19%
[rank2]:2024-06-13 00:53:31,605 - root - INFO - step: 25  loss:  8.2011  memory: 65.53GiB(82.78%)  wps: 591  mfu: 10.98%
[rank2]:2024-06-13 00:54:39,861 - root - INFO - step: 30  loss:  7.7424  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 00:55:47,782 - root - INFO - step: 35  loss:  7.4964  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.20%
[rank2]:2024-06-13 00:56:55,927 - root - INFO - step: 40  loss:  7.2799  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.16%
[rank2]:2024-06-13 00:58:04,445 - root - INFO - step: 45  loss:  7.2280  memory: 65.53GiB(82.78%)  wps: 598  mfu: 11.10%
[rank2]:2024-06-13 00:59:12,503 - root - INFO - step: 50  loss:  7.0669  memory: 65.53GiB(82.78%)  wps: 602  mfu: 11.17%
[rank2]:2024-06-13 01:00:21,395 - root - INFO - step: 55  loss:  6.9967  memory: 65.53GiB(82.78%)  wps: 595  mfu: 11.04%
[rank2]:2024-06-13 01:01:29,641 - root - INFO - step: 60  loss:  7.0763  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 01:02:37,572 - root - INFO - step: 65  loss:  6.9260  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.20%
[rank2]:2024-06-13 01:03:45,755 - root - INFO - step: 70  loss:  6.9757  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.15%
[rank2]:2024-06-13 01:04:54,015 - root - INFO - step: 75  loss:  6.8074  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 01:06:02,682 - root - INFO - step: 80  loss:  6.7362  memory: 65.53GiB(82.78%)  wps: 597  mfu: 11.08%
[rank2]:2024-06-13 01:07:11,232 - root - INFO - step: 85  loss:  6.7016  memory: 65.53GiB(82.78%)  wps: 598  mfu: 11.09%
[rank2]:2024-06-13 01:08:19,973 - root - INFO - step: 90  loss:  6.6640  memory: 65.53GiB(82.78%)  wps: 596  mfu: 11.06%
[rank2]:2024-06-13 01:09:27,858 - root - INFO - step: 95  loss:  6.7214  memory: 65.53GiB(82.78%)  wps: 604  mfu: 11.20%
[rank2]:2024-06-13 01:10:36,136 - root - INFO - step: 100  loss:  6.5953  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
```

2. with `norm_type = "fused_rmsnorm"`
```
[rank2]:2024-06-13 00:19:33,609 - root - INFO - step:  1  loss: 12.2194  memory: 57.31GiB(72.40%)  wps: 412  mfu: 7.64%
[rank2]:2024-06-13 00:20:29,175 - root - INFO - step:  5  loss: 11.4519  memory: 65.13GiB(82.29%)  wps: 590  mfu: 10.95%
[rank2]:2024-06-13 00:21:33,667 - root - INFO - step: 10  loss: 10.2199  memory: 65.13GiB(82.29%)  wps: 635  mfu: 11.79%
[rank2]:2024-06-13 00:22:37,492 - root - INFO - step: 15  loss:  9.3509  memory: 65.13GiB(82.29%)  wps: 642  mfu: 11.92%
[rank2]:2024-06-13 00:23:42,592 - root - INFO - step: 20  loss:  8.7972  memory: 65.13GiB(82.29%)  wps: 629  mfu: 11.68%
[rank2]:2024-06-13 00:24:46,466 - root - INFO - step: 25  loss:  8.2348  memory: 65.13GiB(82.29%)  wps: 642  mfu: 11.91%
[rank2]:2024-06-13 00:25:50,900 - root - INFO - step: 30  loss:  7.7037  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.80%
[rank2]:2024-06-13 00:26:54,794 - root - INFO - step: 35  loss:  7.4639  memory: 65.13GiB(82.29%)  wps: 641  mfu: 11.90%
[rank2]:2024-06-13 00:27:59,235 - root - INFO - step: 40  loss:  7.2406  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.80%
[rank2]:2024-06-13 00:29:03,304 - root - INFO - step: 45  loss:  7.1822  memory: 65.13GiB(82.29%)  wps: 640  mfu: 11.87%
[rank2]:2024-06-13 00:30:07,607 - root - INFO - step: 50  loss:  7.0580  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:31:11,764 - root - INFO - step: 55  loss:  6.9888  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:32:16,001 - root - INFO - step: 60  loss:  7.0387  memory: 65.13GiB(82.29%)  wps: 638  mfu: 11.84%
[rank2]:2024-06-13 00:33:20,137 - root - INFO - step: 65  loss:  6.9199  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:34:24,424 - root - INFO - step: 70  loss:  6.9503  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:35:28,722 - root - INFO - step: 75  loss:  6.7960  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:36:32,865 - root - INFO - step: 80  loss:  6.6798  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:37:36,981 - root - INFO - step: 85  loss:  6.6504  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:38:41,407 - root - INFO - step: 90  loss:  6.6655  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.81%
[rank2]:2024-06-13 00:39:45,981 - root - INFO - step: 95  loss:  6.7359  memory: 65.13GiB(82.29%)  wps: 635  mfu: 11.78%
[rank2]:2024-06-13 00:40:50,146 - root - INFO - step: 100  loss:  6.5410  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.85%
```

[ghstack-poisoned]
…nnotation"


**Summary**
This PR enables the use of TritonFusedRMSNorm with Tensor Parallel with 7%-8% performance gain compared to RMSNorm with TP.

**Test Plan**
Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`). Detailed settings:
```
[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 5
enable_tensorboard = false
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm"  # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm]
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
batch_size = 4
seq_len = 8192
warmup_steps = 200  # lr scheduler warm up
max_norm = 1.0  # grad norm clipping
steps = 100
data_parallel_degree = -1
tensor_parallel_degree = 4
pipeline_parallel_degree = 1
fp8_linear = ""
compile = false
dataset = "c4_mini"

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'selective'  # ['none', 'selective', 'full']
selective_ac_option = 'op'  # 'int' = ac every positive int layer or 'op', ac based on ops policy
```
1. with `norm_type = "rmsnorm"`
```
[rank2]:2024-06-13 00:47:55,607 - root - INFO - step:  1  loss: 12.2262  memory: 57.70GiB(72.89%)  wps: 429  mfu: 7.96%
[rank2]:2024-06-13 00:48:57,536 - root - INFO - step:  5  loss: 11.4801  memory: 65.53GiB(82.78%)  wps: 529  mfu: 9.82%
[rank2]:2024-06-13 00:50:05,746 - root - INFO - step: 10  loss: 10.2305  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.15%
[rank2]:2024-06-13 00:51:14,343 - root - INFO - step: 15  loss:  9.3287  memory: 65.53GiB(82.78%)  wps: 597  mfu: 11.09%
[rank2]:2024-06-13 00:52:22,325 - root - INFO - step: 20  loss:  8.7126  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.19%
[rank2]:2024-06-13 00:53:31,605 - root - INFO - step: 25  loss:  8.2011  memory: 65.53GiB(82.78%)  wps: 591  mfu: 10.98%
[rank2]:2024-06-13 00:54:39,861 - root - INFO - step: 30  loss:  7.7424  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 00:55:47,782 - root - INFO - step: 35  loss:  7.4964  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.20%
[rank2]:2024-06-13 00:56:55,927 - root - INFO - step: 40  loss:  7.2799  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.16%
[rank2]:2024-06-13 00:58:04,445 - root - INFO - step: 45  loss:  7.2280  memory: 65.53GiB(82.78%)  wps: 598  mfu: 11.10%
[rank2]:2024-06-13 00:59:12,503 - root - INFO - step: 50  loss:  7.0669  memory: 65.53GiB(82.78%)  wps: 602  mfu: 11.17%
[rank2]:2024-06-13 01:00:21,395 - root - INFO - step: 55  loss:  6.9967  memory: 65.53GiB(82.78%)  wps: 595  mfu: 11.04%
[rank2]:2024-06-13 01:01:29,641 - root - INFO - step: 60  loss:  7.0763  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 01:02:37,572 - root - INFO - step: 65  loss:  6.9260  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.20%
[rank2]:2024-06-13 01:03:45,755 - root - INFO - step: 70  loss:  6.9757  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.15%
[rank2]:2024-06-13 01:04:54,015 - root - INFO - step: 75  loss:  6.8074  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 01:06:02,682 - root - INFO - step: 80  loss:  6.7362  memory: 65.53GiB(82.78%)  wps: 597  mfu: 11.08%
[rank2]:2024-06-13 01:07:11,232 - root - INFO - step: 85  loss:  6.7016  memory: 65.53GiB(82.78%)  wps: 598  mfu: 11.09%
[rank2]:2024-06-13 01:08:19,973 - root - INFO - step: 90  loss:  6.6640  memory: 65.53GiB(82.78%)  wps: 596  mfu: 11.06%
[rank2]:2024-06-13 01:09:27,858 - root - INFO - step: 95  loss:  6.7214  memory: 65.53GiB(82.78%)  wps: 604  mfu: 11.20%
[rank2]:2024-06-13 01:10:36,136 - root - INFO - step: 100  loss:  6.5953  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
```

2. with `norm_type = "fused_rmsnorm"`
```
[rank2]:2024-06-13 00:19:33,609 - root - INFO - step:  1  loss: 12.2194  memory: 57.31GiB(72.40%)  wps: 412  mfu: 7.64%
[rank2]:2024-06-13 00:20:29,175 - root - INFO - step:  5  loss: 11.4519  memory: 65.13GiB(82.29%)  wps: 590  mfu: 10.95%
[rank2]:2024-06-13 00:21:33,667 - root - INFO - step: 10  loss: 10.2199  memory: 65.13GiB(82.29%)  wps: 635  mfu: 11.79%
[rank2]:2024-06-13 00:22:37,492 - root - INFO - step: 15  loss:  9.3509  memory: 65.13GiB(82.29%)  wps: 642  mfu: 11.92%
[rank2]:2024-06-13 00:23:42,592 - root - INFO - step: 20  loss:  8.7972  memory: 65.13GiB(82.29%)  wps: 629  mfu: 11.68%
[rank2]:2024-06-13 00:24:46,466 - root - INFO - step: 25  loss:  8.2348  memory: 65.13GiB(82.29%)  wps: 642  mfu: 11.91%
[rank2]:2024-06-13 00:25:50,900 - root - INFO - step: 30  loss:  7.7037  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.80%
[rank2]:2024-06-13 00:26:54,794 - root - INFO - step: 35  loss:  7.4639  memory: 65.13GiB(82.29%)  wps: 641  mfu: 11.90%
[rank2]:2024-06-13 00:27:59,235 - root - INFO - step: 40  loss:  7.2406  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.80%
[rank2]:2024-06-13 00:29:03,304 - root - INFO - step: 45  loss:  7.1822  memory: 65.13GiB(82.29%)  wps: 640  mfu: 11.87%
[rank2]:2024-06-13 00:30:07,607 - root - INFO - step: 50  loss:  7.0580  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:31:11,764 - root - INFO - step: 55  loss:  6.9888  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:32:16,001 - root - INFO - step: 60  loss:  7.0387  memory: 65.13GiB(82.29%)  wps: 638  mfu: 11.84%
[rank2]:2024-06-13 00:33:20,137 - root - INFO - step: 65  loss:  6.9199  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:34:24,424 - root - INFO - step: 70  loss:  6.9503  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:35:28,722 - root - INFO - step: 75  loss:  6.7960  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:36:32,865 - root - INFO - step: 80  loss:  6.6798  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:37:36,981 - root - INFO - step: 85  loss:  6.6504  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:38:41,407 - root - INFO - step: 90  loss:  6.6655  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.81%
[rank2]:2024-06-13 00:39:45,981 - root - INFO - step: 95  loss:  6.7359  memory: 65.13GiB(82.29%)  wps: 635  mfu: 11.78%
[rank2]:2024-06-13 00:40:50,146 - root - INFO - step: 100  loss:  6.5410  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.85%
```

[ghstack-poisoned]
…nnotation"


**Summary**
This PR enables the use of TritonFusedRMSNorm with Tensor Parallel with 7%-8% performance gain compared to RMSNorm with TP.

**Test Plan**
Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`). Detailed settings:
```
[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 5
enable_tensorboard = false
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm"  # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm]
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
batch_size = 4
seq_len = 8192
warmup_steps = 200  # lr scheduler warm up
max_norm = 1.0  # grad norm clipping
steps = 100
data_parallel_degree = -1
tensor_parallel_degree = 4
pipeline_parallel_degree = 1
fp8_linear = ""
compile = false
dataset = "c4_mini"

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'selective'  # ['none', 'selective', 'full']
selective_ac_option = 'op'  # 'int' = ac every positive int layer or 'op', ac based on ops policy
```
1. with `norm_type = "rmsnorm"`
```
[rank2]:2024-06-13 00:47:55,607 - root - INFO - step:  1  loss: 12.2262  memory: 57.70GiB(72.89%)  wps: 429  mfu: 7.96%
[rank2]:2024-06-13 00:48:57,536 - root - INFO - step:  5  loss: 11.4801  memory: 65.53GiB(82.78%)  wps: 529  mfu: 9.82%
[rank2]:2024-06-13 00:50:05,746 - root - INFO - step: 10  loss: 10.2305  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.15%
[rank2]:2024-06-13 00:51:14,343 - root - INFO - step: 15  loss:  9.3287  memory: 65.53GiB(82.78%)  wps: 597  mfu: 11.09%
[rank2]:2024-06-13 00:52:22,325 - root - INFO - step: 20  loss:  8.7126  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.19%
[rank2]:2024-06-13 00:53:31,605 - root - INFO - step: 25  loss:  8.2011  memory: 65.53GiB(82.78%)  wps: 591  mfu: 10.98%
[rank2]:2024-06-13 00:54:39,861 - root - INFO - step: 30  loss:  7.7424  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 00:55:47,782 - root - INFO - step: 35  loss:  7.4964  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.20%
[rank2]:2024-06-13 00:56:55,927 - root - INFO - step: 40  loss:  7.2799  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.16%
[rank2]:2024-06-13 00:58:04,445 - root - INFO - step: 45  loss:  7.2280  memory: 65.53GiB(82.78%)  wps: 598  mfu: 11.10%
[rank2]:2024-06-13 00:59:12,503 - root - INFO - step: 50  loss:  7.0669  memory: 65.53GiB(82.78%)  wps: 602  mfu: 11.17%
[rank2]:2024-06-13 01:00:21,395 - root - INFO - step: 55  loss:  6.9967  memory: 65.53GiB(82.78%)  wps: 595  mfu: 11.04%
[rank2]:2024-06-13 01:01:29,641 - root - INFO - step: 60  loss:  7.0763  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 01:02:37,572 - root - INFO - step: 65  loss:  6.9260  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.20%
[rank2]:2024-06-13 01:03:45,755 - root - INFO - step: 70  loss:  6.9757  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.15%
[rank2]:2024-06-13 01:04:54,015 - root - INFO - step: 75  loss:  6.8074  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 01:06:02,682 - root - INFO - step: 80  loss:  6.7362  memory: 65.53GiB(82.78%)  wps: 597  mfu: 11.08%
[rank2]:2024-06-13 01:07:11,232 - root - INFO - step: 85  loss:  6.7016  memory: 65.53GiB(82.78%)  wps: 598  mfu: 11.09%
[rank2]:2024-06-13 01:08:19,973 - root - INFO - step: 90  loss:  6.6640  memory: 65.53GiB(82.78%)  wps: 596  mfu: 11.06%
[rank2]:2024-06-13 01:09:27,858 - root - INFO - step: 95  loss:  6.7214  memory: 65.53GiB(82.78%)  wps: 604  mfu: 11.20%
[rank2]:2024-06-13 01:10:36,136 - root - INFO - step: 100  loss:  6.5953  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
```

2. with `norm_type = "fused_rmsnorm"`
```
[rank2]:2024-06-13 00:19:33,609 - root - INFO - step:  1  loss: 12.2194  memory: 57.31GiB(72.40%)  wps: 412  mfu: 7.64%
[rank2]:2024-06-13 00:20:29,175 - root - INFO - step:  5  loss: 11.4519  memory: 65.13GiB(82.29%)  wps: 590  mfu: 10.95%
[rank2]:2024-06-13 00:21:33,667 - root - INFO - step: 10  loss: 10.2199  memory: 65.13GiB(82.29%)  wps: 635  mfu: 11.79%
[rank2]:2024-06-13 00:22:37,492 - root - INFO - step: 15  loss:  9.3509  memory: 65.13GiB(82.29%)  wps: 642  mfu: 11.92%
[rank2]:2024-06-13 00:23:42,592 - root - INFO - step: 20  loss:  8.7972  memory: 65.13GiB(82.29%)  wps: 629  mfu: 11.68%
[rank2]:2024-06-13 00:24:46,466 - root - INFO - step: 25  loss:  8.2348  memory: 65.13GiB(82.29%)  wps: 642  mfu: 11.91%
[rank2]:2024-06-13 00:25:50,900 - root - INFO - step: 30  loss:  7.7037  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.80%
[rank2]:2024-06-13 00:26:54,794 - root - INFO - step: 35  loss:  7.4639  memory: 65.13GiB(82.29%)  wps: 641  mfu: 11.90%
[rank2]:2024-06-13 00:27:59,235 - root - INFO - step: 40  loss:  7.2406  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.80%
[rank2]:2024-06-13 00:29:03,304 - root - INFO - step: 45  loss:  7.1822  memory: 65.13GiB(82.29%)  wps: 640  mfu: 11.87%
[rank2]:2024-06-13 00:30:07,607 - root - INFO - step: 50  loss:  7.0580  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:31:11,764 - root - INFO - step: 55  loss:  6.9888  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:32:16,001 - root - INFO - step: 60  loss:  7.0387  memory: 65.13GiB(82.29%)  wps: 638  mfu: 11.84%
[rank2]:2024-06-13 00:33:20,137 - root - INFO - step: 65  loss:  6.9199  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:34:24,424 - root - INFO - step: 70  loss:  6.9503  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:35:28,722 - root - INFO - step: 75  loss:  6.7960  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:36:32,865 - root - INFO - step: 80  loss:  6.6798  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:37:36,981 - root - INFO - step: 85  loss:  6.6504  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:38:41,407 - root - INFO - step: 90  loss:  6.6655  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.81%
[rank2]:2024-06-13 00:39:45,981 - root - INFO - step: 95  loss:  6.7359  memory: 65.13GiB(82.29%)  wps: 635  mfu: 11.78%
[rank2]:2024-06-13 00:40:50,146 - root - INFO - step: 100  loss:  6.5410  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.85%
```

[ghstack-poisoned]
…nnotation"


**Summary**
This PR enables the use of TritonFusedRMSNorm with Tensor Parallel with 7%-8% performance gain compared to RMSNorm with TP.

**Test Plan**
Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`). Detailed settings:
```
[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 5
enable_tensorboard = false
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm"  # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm]
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
batch_size = 4
seq_len = 8192
warmup_steps = 200  # lr scheduler warm up
max_norm = 1.0  # grad norm clipping
steps = 100
data_parallel_degree = -1
tensor_parallel_degree = 4
pipeline_parallel_degree = 1
fp8_linear = ""
compile = false
dataset = "c4_mini"

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'selective'  # ['none', 'selective', 'full']
selective_ac_option = 'op'  # 'int' = ac every positive int layer or 'op', ac based on ops policy
```
1. with `norm_type = "rmsnorm"`
```
[rank2]:2024-06-13 00:47:55,607 - root - INFO - step:  1  loss: 12.2262  memory: 57.70GiB(72.89%)  wps: 429  mfu: 7.96%
[rank2]:2024-06-13 00:48:57,536 - root - INFO - step:  5  loss: 11.4801  memory: 65.53GiB(82.78%)  wps: 529  mfu: 9.82%
[rank2]:2024-06-13 00:50:05,746 - root - INFO - step: 10  loss: 10.2305  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.15%
[rank2]:2024-06-13 00:51:14,343 - root - INFO - step: 15  loss:  9.3287  memory: 65.53GiB(82.78%)  wps: 597  mfu: 11.09%
[rank2]:2024-06-13 00:52:22,325 - root - INFO - step: 20  loss:  8.7126  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.19%
[rank2]:2024-06-13 00:53:31,605 - root - INFO - step: 25  loss:  8.2011  memory: 65.53GiB(82.78%)  wps: 591  mfu: 10.98%
[rank2]:2024-06-13 00:54:39,861 - root - INFO - step: 30  loss:  7.7424  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 00:55:47,782 - root - INFO - step: 35  loss:  7.4964  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.20%
[rank2]:2024-06-13 00:56:55,927 - root - INFO - step: 40  loss:  7.2799  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.16%
[rank2]:2024-06-13 00:58:04,445 - root - INFO - step: 45  loss:  7.2280  memory: 65.53GiB(82.78%)  wps: 598  mfu: 11.10%
[rank2]:2024-06-13 00:59:12,503 - root - INFO - step: 50  loss:  7.0669  memory: 65.53GiB(82.78%)  wps: 602  mfu: 11.17%
[rank2]:2024-06-13 01:00:21,395 - root - INFO - step: 55  loss:  6.9967  memory: 65.53GiB(82.78%)  wps: 595  mfu: 11.04%
[rank2]:2024-06-13 01:01:29,641 - root - INFO - step: 60  loss:  7.0763  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 01:02:37,572 - root - INFO - step: 65  loss:  6.9260  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.20%
[rank2]:2024-06-13 01:03:45,755 - root - INFO - step: 70  loss:  6.9757  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.15%
[rank2]:2024-06-13 01:04:54,015 - root - INFO - step: 75  loss:  6.8074  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 01:06:02,682 - root - INFO - step: 80  loss:  6.7362  memory: 65.53GiB(82.78%)  wps: 597  mfu: 11.08%
[rank2]:2024-06-13 01:07:11,232 - root - INFO - step: 85  loss:  6.7016  memory: 65.53GiB(82.78%)  wps: 598  mfu: 11.09%
[rank2]:2024-06-13 01:08:19,973 - root - INFO - step: 90  loss:  6.6640  memory: 65.53GiB(82.78%)  wps: 596  mfu: 11.06%
[rank2]:2024-06-13 01:09:27,858 - root - INFO - step: 95  loss:  6.7214  memory: 65.53GiB(82.78%)  wps: 604  mfu: 11.20%
[rank2]:2024-06-13 01:10:36,136 - root - INFO - step: 100  loss:  6.5953  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
```

2. with `norm_type = "fused_rmsnorm"`
```
[rank2]:2024-06-13 00:19:33,609 - root - INFO - step:  1  loss: 12.2194  memory: 57.31GiB(72.40%)  wps: 412  mfu: 7.64%
[rank2]:2024-06-13 00:20:29,175 - root - INFO - step:  5  loss: 11.4519  memory: 65.13GiB(82.29%)  wps: 590  mfu: 10.95%
[rank2]:2024-06-13 00:21:33,667 - root - INFO - step: 10  loss: 10.2199  memory: 65.13GiB(82.29%)  wps: 635  mfu: 11.79%
[rank2]:2024-06-13 00:22:37,492 - root - INFO - step: 15  loss:  9.3509  memory: 65.13GiB(82.29%)  wps: 642  mfu: 11.92%
[rank2]:2024-06-13 00:23:42,592 - root - INFO - step: 20  loss:  8.7972  memory: 65.13GiB(82.29%)  wps: 629  mfu: 11.68%
[rank2]:2024-06-13 00:24:46,466 - root - INFO - step: 25  loss:  8.2348  memory: 65.13GiB(82.29%)  wps: 642  mfu: 11.91%
[rank2]:2024-06-13 00:25:50,900 - root - INFO - step: 30  loss:  7.7037  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.80%
[rank2]:2024-06-13 00:26:54,794 - root - INFO - step: 35  loss:  7.4639  memory: 65.13GiB(82.29%)  wps: 641  mfu: 11.90%
[rank2]:2024-06-13 00:27:59,235 - root - INFO - step: 40  loss:  7.2406  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.80%
[rank2]:2024-06-13 00:29:03,304 - root - INFO - step: 45  loss:  7.1822  memory: 65.13GiB(82.29%)  wps: 640  mfu: 11.87%
[rank2]:2024-06-13 00:30:07,607 - root - INFO - step: 50  loss:  7.0580  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:31:11,764 - root - INFO - step: 55  loss:  6.9888  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:32:16,001 - root - INFO - step: 60  loss:  7.0387  memory: 65.13GiB(82.29%)  wps: 638  mfu: 11.84%
[rank2]:2024-06-13 00:33:20,137 - root - INFO - step: 65  loss:  6.9199  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:34:24,424 - root - INFO - step: 70  loss:  6.9503  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:35:28,722 - root - INFO - step: 75  loss:  6.7960  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:36:32,865 - root - INFO - step: 80  loss:  6.6798  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:37:36,981 - root - INFO - step: 85  loss:  6.6504  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:38:41,407 - root - INFO - step: 90  loss:  6.6655  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.81%
[rank2]:2024-06-13 00:39:45,981 - root - INFO - step: 95  loss:  6.7359  memory: 65.13GiB(82.29%)  wps: 635  mfu: 11.78%
[rank2]:2024-06-13 00:40:50,146 - root - INFO - step: 100  loss:  6.5410  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.85%
```

[ghstack-poisoned]
…nnotation"


**Summary**
This PR enables the use of TritonFusedRMSNorm with Tensor Parallel with 7%-8% performance gain compared to RMSNorm with TP.

**Test Plan**
Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`). Detailed settings:
```
[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 5
enable_tensorboard = false
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm"  # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm]
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
batch_size = 4
seq_len = 8192
warmup_steps = 200  # lr scheduler warm up
max_norm = 1.0  # grad norm clipping
steps = 100
data_parallel_degree = -1
tensor_parallel_degree = 4
pipeline_parallel_degree = 1
fp8_linear = ""
compile = false
dataset = "c4_mini"

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'selective'  # ['none', 'selective', 'full']
selective_ac_option = 'op'  # 'int' = ac every positive int layer or 'op', ac based on ops policy
```
1. with `norm_type = "rmsnorm"`
```
[rank2]:2024-06-13 00:47:55,607 - root - INFO - step:  1  loss: 12.2262  memory: 57.70GiB(72.89%)  wps: 429  mfu: 7.96%
[rank2]:2024-06-13 00:48:57,536 - root - INFO - step:  5  loss: 11.4801  memory: 65.53GiB(82.78%)  wps: 529  mfu: 9.82%
[rank2]:2024-06-13 00:50:05,746 - root - INFO - step: 10  loss: 10.2305  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.15%
[rank2]:2024-06-13 00:51:14,343 - root - INFO - step: 15  loss:  9.3287  memory: 65.53GiB(82.78%)  wps: 597  mfu: 11.09%
[rank2]:2024-06-13 00:52:22,325 - root - INFO - step: 20  loss:  8.7126  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.19%
[rank2]:2024-06-13 00:53:31,605 - root - INFO - step: 25  loss:  8.2011  memory: 65.53GiB(82.78%)  wps: 591  mfu: 10.98%
[rank2]:2024-06-13 00:54:39,861 - root - INFO - step: 30  loss:  7.7424  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 00:55:47,782 - root - INFO - step: 35  loss:  7.4964  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.20%
[rank2]:2024-06-13 00:56:55,927 - root - INFO - step: 40  loss:  7.2799  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.16%
[rank2]:2024-06-13 00:58:04,445 - root - INFO - step: 45  loss:  7.2280  memory: 65.53GiB(82.78%)  wps: 598  mfu: 11.10%
[rank2]:2024-06-13 00:59:12,503 - root - INFO - step: 50  loss:  7.0669  memory: 65.53GiB(82.78%)  wps: 602  mfu: 11.17%
[rank2]:2024-06-13 01:00:21,395 - root - INFO - step: 55  loss:  6.9967  memory: 65.53GiB(82.78%)  wps: 595  mfu: 11.04%
[rank2]:2024-06-13 01:01:29,641 - root - INFO - step: 60  loss:  7.0763  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 01:02:37,572 - root - INFO - step: 65  loss:  6.9260  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.20%
[rank2]:2024-06-13 01:03:45,755 - root - INFO - step: 70  loss:  6.9757  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.15%
[rank2]:2024-06-13 01:04:54,015 - root - INFO - step: 75  loss:  6.8074  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 01:06:02,682 - root - INFO - step: 80  loss:  6.7362  memory: 65.53GiB(82.78%)  wps: 597  mfu: 11.08%
[rank2]:2024-06-13 01:07:11,232 - root - INFO - step: 85  loss:  6.7016  memory: 65.53GiB(82.78%)  wps: 598  mfu: 11.09%
[rank2]:2024-06-13 01:08:19,973 - root - INFO - step: 90  loss:  6.6640  memory: 65.53GiB(82.78%)  wps: 596  mfu: 11.06%
[rank2]:2024-06-13 01:09:27,858 - root - INFO - step: 95  loss:  6.7214  memory: 65.53GiB(82.78%)  wps: 604  mfu: 11.20%
[rank2]:2024-06-13 01:10:36,136 - root - INFO - step: 100  loss:  6.5953  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
```

2. with `norm_type = "fused_rmsnorm"`
```
[rank2]:2024-06-13 00:19:33,609 - root - INFO - step:  1  loss: 12.2194  memory: 57.31GiB(72.40%)  wps: 412  mfu: 7.64%
[rank2]:2024-06-13 00:20:29,175 - root - INFO - step:  5  loss: 11.4519  memory: 65.13GiB(82.29%)  wps: 590  mfu: 10.95%
[rank2]:2024-06-13 00:21:33,667 - root - INFO - step: 10  loss: 10.2199  memory: 65.13GiB(82.29%)  wps: 635  mfu: 11.79%
[rank2]:2024-06-13 00:22:37,492 - root - INFO - step: 15  loss:  9.3509  memory: 65.13GiB(82.29%)  wps: 642  mfu: 11.92%
[rank2]:2024-06-13 00:23:42,592 - root - INFO - step: 20  loss:  8.7972  memory: 65.13GiB(82.29%)  wps: 629  mfu: 11.68%
[rank2]:2024-06-13 00:24:46,466 - root - INFO - step: 25  loss:  8.2348  memory: 65.13GiB(82.29%)  wps: 642  mfu: 11.91%
[rank2]:2024-06-13 00:25:50,900 - root - INFO - step: 30  loss:  7.7037  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.80%
[rank2]:2024-06-13 00:26:54,794 - root - INFO - step: 35  loss:  7.4639  memory: 65.13GiB(82.29%)  wps: 641  mfu: 11.90%
[rank2]:2024-06-13 00:27:59,235 - root - INFO - step: 40  loss:  7.2406  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.80%
[rank2]:2024-06-13 00:29:03,304 - root - INFO - step: 45  loss:  7.1822  memory: 65.13GiB(82.29%)  wps: 640  mfu: 11.87%
[rank2]:2024-06-13 00:30:07,607 - root - INFO - step: 50  loss:  7.0580  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:31:11,764 - root - INFO - step: 55  loss:  6.9888  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:32:16,001 - root - INFO - step: 60  loss:  7.0387  memory: 65.13GiB(82.29%)  wps: 638  mfu: 11.84%
[rank2]:2024-06-13 00:33:20,137 - root - INFO - step: 65  loss:  6.9199  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:34:24,424 - root - INFO - step: 70  loss:  6.9503  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:35:28,722 - root - INFO - step: 75  loss:  6.7960  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:36:32,865 - root - INFO - step: 80  loss:  6.6798  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:37:36,981 - root - INFO - step: 85  loss:  6.6504  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:38:41,407 - root - INFO - step: 90  loss:  6.6655  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.81%
[rank2]:2024-06-13 00:39:45,981 - root - INFO - step: 95  loss:  6.7359  memory: 65.13GiB(82.29%)  wps: 635  mfu: 11.78%
[rank2]:2024-06-13 00:40:50,146 - root - INFO - step: 100  loss:  6.5410  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.85%
```

[ghstack-poisoned]
…nnotation"


**Summary**
This PR enables the use of TritonFusedRMSNorm with Tensor Parallel with 7%-8% performance gain compared to RMSNorm with TP.

**Test Plan**
Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`). Detailed settings:
```
[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 5
enable_tensorboard = false
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm"  # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm]
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
batch_size = 4
seq_len = 8192
warmup_steps = 200  # lr scheduler warm up
max_norm = 1.0  # grad norm clipping
steps = 100
data_parallel_degree = -1
tensor_parallel_degree = 4
pipeline_parallel_degree = 1
fp8_linear = ""
compile = false
dataset = "c4_mini"

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'selective'  # ['none', 'selective', 'full']
selective_ac_option = 'op'  # 'int' = ac every positive int layer or 'op', ac based on ops policy
```
1. with `norm_type = "rmsnorm"`
```
[rank2]:2024-06-13 00:47:55,607 - root - INFO - step:  1  loss: 12.2262  memory: 57.70GiB(72.89%)  wps: 429  mfu: 7.96%
[rank2]:2024-06-13 00:48:57,536 - root - INFO - step:  5  loss: 11.4801  memory: 65.53GiB(82.78%)  wps: 529  mfu: 9.82%
[rank2]:2024-06-13 00:50:05,746 - root - INFO - step: 10  loss: 10.2305  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.15%
[rank2]:2024-06-13 00:51:14,343 - root - INFO - step: 15  loss:  9.3287  memory: 65.53GiB(82.78%)  wps: 597  mfu: 11.09%
[rank2]:2024-06-13 00:52:22,325 - root - INFO - step: 20  loss:  8.7126  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.19%
[rank2]:2024-06-13 00:53:31,605 - root - INFO - step: 25  loss:  8.2011  memory: 65.53GiB(82.78%)  wps: 591  mfu: 10.98%
[rank2]:2024-06-13 00:54:39,861 - root - INFO - step: 30  loss:  7.7424  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 00:55:47,782 - root - INFO - step: 35  loss:  7.4964  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.20%
[rank2]:2024-06-13 00:56:55,927 - root - INFO - step: 40  loss:  7.2799  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.16%
[rank2]:2024-06-13 00:58:04,445 - root - INFO - step: 45  loss:  7.2280  memory: 65.53GiB(82.78%)  wps: 598  mfu: 11.10%
[rank2]:2024-06-13 00:59:12,503 - root - INFO - step: 50  loss:  7.0669  memory: 65.53GiB(82.78%)  wps: 602  mfu: 11.17%
[rank2]:2024-06-13 01:00:21,395 - root - INFO - step: 55  loss:  6.9967  memory: 65.53GiB(82.78%)  wps: 595  mfu: 11.04%
[rank2]:2024-06-13 01:01:29,641 - root - INFO - step: 60  loss:  7.0763  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 01:02:37,572 - root - INFO - step: 65  loss:  6.9260  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.20%
[rank2]:2024-06-13 01:03:45,755 - root - INFO - step: 70  loss:  6.9757  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.15%
[rank2]:2024-06-13 01:04:54,015 - root - INFO - step: 75  loss:  6.8074  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 01:06:02,682 - root - INFO - step: 80  loss:  6.7362  memory: 65.53GiB(82.78%)  wps: 597  mfu: 11.08%
[rank2]:2024-06-13 01:07:11,232 - root - INFO - step: 85  loss:  6.7016  memory: 65.53GiB(82.78%)  wps: 598  mfu: 11.09%
[rank2]:2024-06-13 01:08:19,973 - root - INFO - step: 90  loss:  6.6640  memory: 65.53GiB(82.78%)  wps: 596  mfu: 11.06%
[rank2]:2024-06-13 01:09:27,858 - root - INFO - step: 95  loss:  6.7214  memory: 65.53GiB(82.78%)  wps: 604  mfu: 11.20%
[rank2]:2024-06-13 01:10:36,136 - root - INFO - step: 100  loss:  6.5953  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
```

2. with `norm_type = "fused_rmsnorm"`
```
[rank2]:2024-06-13 00:19:33,609 - root - INFO - step:  1  loss: 12.2194  memory: 57.31GiB(72.40%)  wps: 412  mfu: 7.64%
[rank2]:2024-06-13 00:20:29,175 - root - INFO - step:  5  loss: 11.4519  memory: 65.13GiB(82.29%)  wps: 590  mfu: 10.95%
[rank2]:2024-06-13 00:21:33,667 - root - INFO - step: 10  loss: 10.2199  memory: 65.13GiB(82.29%)  wps: 635  mfu: 11.79%
[rank2]:2024-06-13 00:22:37,492 - root - INFO - step: 15  loss:  9.3509  memory: 65.13GiB(82.29%)  wps: 642  mfu: 11.92%
[rank2]:2024-06-13 00:23:42,592 - root - INFO - step: 20  loss:  8.7972  memory: 65.13GiB(82.29%)  wps: 629  mfu: 11.68%
[rank2]:2024-06-13 00:24:46,466 - root - INFO - step: 25  loss:  8.2348  memory: 65.13GiB(82.29%)  wps: 642  mfu: 11.91%
[rank2]:2024-06-13 00:25:50,900 - root - INFO - step: 30  loss:  7.7037  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.80%
[rank2]:2024-06-13 00:26:54,794 - root - INFO - step: 35  loss:  7.4639  memory: 65.13GiB(82.29%)  wps: 641  mfu: 11.90%
[rank2]:2024-06-13 00:27:59,235 - root - INFO - step: 40  loss:  7.2406  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.80%
[rank2]:2024-06-13 00:29:03,304 - root - INFO - step: 45  loss:  7.1822  memory: 65.13GiB(82.29%)  wps: 640  mfu: 11.87%
[rank2]:2024-06-13 00:30:07,607 - root - INFO - step: 50  loss:  7.0580  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:31:11,764 - root - INFO - step: 55  loss:  6.9888  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:32:16,001 - root - INFO - step: 60  loss:  7.0387  memory: 65.13GiB(82.29%)  wps: 638  mfu: 11.84%
[rank2]:2024-06-13 00:33:20,137 - root - INFO - step: 65  loss:  6.9199  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:34:24,424 - root - INFO - step: 70  loss:  6.9503  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:35:28,722 - root - INFO - step: 75  loss:  6.7960  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:36:32,865 - root - INFO - step: 80  loss:  6.6798  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:37:36,981 - root - INFO - step: 85  loss:  6.6504  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:38:41,407 - root - INFO - step: 90  loss:  6.6655  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.81%
[rank2]:2024-06-13 00:39:45,981 - root - INFO - step: 95  loss:  6.7359  memory: 65.13GiB(82.29%)  wps: 635  mfu: 11.78%
[rank2]:2024-06-13 00:40:50,146 - root - INFO - step: 100  loss:  6.5410  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.85%
```

[ghstack-poisoned]
…nnotation"


**Summary**
This PR enables the use of TritonFusedRMSNorm with Tensor Parallel with 7%-8% performance gain compared to RMSNorm with TP.

**Test Plan**
Here's the output of running `CONFIG_FILE=./train_configs/llama3_8b.toml NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` using 4-way Tensor Parallel (`tensor_parallel_degree = 4`). Detailed settings:
```
[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 5
enable_tensorboard = false
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm"  # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm]
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
batch_size = 4
seq_len = 8192
warmup_steps = 200  # lr scheduler warm up
max_norm = 1.0  # grad norm clipping
steps = 100
data_parallel_degree = -1
tensor_parallel_degree = 4
pipeline_parallel_degree = 1
fp8_linear = ""
compile = false
dataset = "c4_mini"

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'selective'  # ['none', 'selective', 'full']
selective_ac_option = 'op'  # 'int' = ac every positive int layer or 'op', ac based on ops policy
```
1. with `norm_type = "rmsnorm"`
```
[rank2]:2024-06-13 00:47:55,607 - root - INFO - step:  1  loss: 12.2262  memory: 57.70GiB(72.89%)  wps: 429  mfu: 7.96%
[rank2]:2024-06-13 00:48:57,536 - root - INFO - step:  5  loss: 11.4801  memory: 65.53GiB(82.78%)  wps: 529  mfu: 9.82%
[rank2]:2024-06-13 00:50:05,746 - root - INFO - step: 10  loss: 10.2305  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.15%
[rank2]:2024-06-13 00:51:14,343 - root - INFO - step: 15  loss:  9.3287  memory: 65.53GiB(82.78%)  wps: 597  mfu: 11.09%
[rank2]:2024-06-13 00:52:22,325 - root - INFO - step: 20  loss:  8.7126  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.19%
[rank2]:2024-06-13 00:53:31,605 - root - INFO - step: 25  loss:  8.2011  memory: 65.53GiB(82.78%)  wps: 591  mfu: 10.98%
[rank2]:2024-06-13 00:54:39,861 - root - INFO - step: 30  loss:  7.7424  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 00:55:47,782 - root - INFO - step: 35  loss:  7.4964  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.20%
[rank2]:2024-06-13 00:56:55,927 - root - INFO - step: 40  loss:  7.2799  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.16%
[rank2]:2024-06-13 00:58:04,445 - root - INFO - step: 45  loss:  7.2280  memory: 65.53GiB(82.78%)  wps: 598  mfu: 11.10%
[rank2]:2024-06-13 00:59:12,503 - root - INFO - step: 50  loss:  7.0669  memory: 65.53GiB(82.78%)  wps: 602  mfu: 11.17%
[rank2]:2024-06-13 01:00:21,395 - root - INFO - step: 55  loss:  6.9967  memory: 65.53GiB(82.78%)  wps: 595  mfu: 11.04%
[rank2]:2024-06-13 01:01:29,641 - root - INFO - step: 60  loss:  7.0763  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 01:02:37,572 - root - INFO - step: 65  loss:  6.9260  memory: 65.53GiB(82.78%)  wps: 603  mfu: 11.20%
[rank2]:2024-06-13 01:03:45,755 - root - INFO - step: 70  loss:  6.9757  memory: 65.53GiB(82.78%)  wps: 601  mfu: 11.15%
[rank2]:2024-06-13 01:04:54,015 - root - INFO - step: 75  loss:  6.8074  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
[rank2]:2024-06-13 01:06:02,682 - root - INFO - step: 80  loss:  6.7362  memory: 65.53GiB(82.78%)  wps: 597  mfu: 11.08%
[rank2]:2024-06-13 01:07:11,232 - root - INFO - step: 85  loss:  6.7016  memory: 65.53GiB(82.78%)  wps: 598  mfu: 11.09%
[rank2]:2024-06-13 01:08:19,973 - root - INFO - step: 90  loss:  6.6640  memory: 65.53GiB(82.78%)  wps: 596  mfu: 11.06%
[rank2]:2024-06-13 01:09:27,858 - root - INFO - step: 95  loss:  6.7214  memory: 65.53GiB(82.78%)  wps: 604  mfu: 11.20%
[rank2]:2024-06-13 01:10:36,136 - root - INFO - step: 100  loss:  6.5953  memory: 65.53GiB(82.78%)  wps: 600  mfu: 11.14%
```

2. with `norm_type = "fused_rmsnorm"`
```
[rank2]:2024-06-13 00:19:33,609 - root - INFO - step:  1  loss: 12.2194  memory: 57.31GiB(72.40%)  wps: 412  mfu: 7.64%
[rank2]:2024-06-13 00:20:29,175 - root - INFO - step:  5  loss: 11.4519  memory: 65.13GiB(82.29%)  wps: 590  mfu: 10.95%
[rank2]:2024-06-13 00:21:33,667 - root - INFO - step: 10  loss: 10.2199  memory: 65.13GiB(82.29%)  wps: 635  mfu: 11.79%
[rank2]:2024-06-13 00:22:37,492 - root - INFO - step: 15  loss:  9.3509  memory: 65.13GiB(82.29%)  wps: 642  mfu: 11.92%
[rank2]:2024-06-13 00:23:42,592 - root - INFO - step: 20  loss:  8.7972  memory: 65.13GiB(82.29%)  wps: 629  mfu: 11.68%
[rank2]:2024-06-13 00:24:46,466 - root - INFO - step: 25  loss:  8.2348  memory: 65.13GiB(82.29%)  wps: 642  mfu: 11.91%
[rank2]:2024-06-13 00:25:50,900 - root - INFO - step: 30  loss:  7.7037  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.80%
[rank2]:2024-06-13 00:26:54,794 - root - INFO - step: 35  loss:  7.4639  memory: 65.13GiB(82.29%)  wps: 641  mfu: 11.90%
[rank2]:2024-06-13 00:27:59,235 - root - INFO - step: 40  loss:  7.2406  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.80%
[rank2]:2024-06-13 00:29:03,304 - root - INFO - step: 45  loss:  7.1822  memory: 65.13GiB(82.29%)  wps: 640  mfu: 11.87%
[rank2]:2024-06-13 00:30:07,607 - root - INFO - step: 50  loss:  7.0580  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:31:11,764 - root - INFO - step: 55  loss:  6.9888  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:32:16,001 - root - INFO - step: 60  loss:  7.0387  memory: 65.13GiB(82.29%)  wps: 638  mfu: 11.84%
[rank2]:2024-06-13 00:33:20,137 - root - INFO - step: 65  loss:  6.9199  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:34:24,424 - root - INFO - step: 70  loss:  6.9503  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:35:28,722 - root - INFO - step: 75  loss:  6.7960  memory: 65.13GiB(82.29%)  wps: 637  mfu: 11.83%
[rank2]:2024-06-13 00:36:32,865 - root - INFO - step: 80  loss:  6.6798  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:37:36,981 - root - INFO - step: 85  loss:  6.6504  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.86%
[rank2]:2024-06-13 00:38:41,407 - root - INFO - step: 90  loss:  6.6655  memory: 65.13GiB(82.29%)  wps: 636  mfu: 11.81%
[rank2]:2024-06-13 00:39:45,981 - root - INFO - step: 95  loss:  6.7359  memory: 65.13GiB(82.29%)  wps: 635  mfu: 11.78%
[rank2]:2024-06-13 00:40:50,146 - root - INFO - step: 100  loss:  6.5410  memory: 65.13GiB(82.29%)  wps: 639  mfu: 11.85%
```

[ghstack-poisoned]
This PR enables the use of TritonFusedRMSNorm with Tensor Parallel, yielding a 7%-8% performance gain compared to RMSNorm with TP.
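For context, a minimal sketch of the idea behind the PR: wrapping a kernel that is not DTensor-aware with `local_map` so it can be called on DTensor inputs under Tensor Parallel. This is an illustration, not the PR's exact code; it assumes PyTorch >= 2.4, where `local_map` lives in `torch.distributed._tensor.experimental`, uses a plain-PyTorch RMSNorm in place of the Triton kernel, and picks placements matching sequence parallelism.
```
import torch
from torch.distributed._tensor import Replicate, Shard
from torch.distributed._tensor.experimental import local_map


def _rms_norm(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Plain per-shard RMSNorm; the actual PR dispatches to a Triton kernel here.
    rstd = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
    return x * rstd * weight


# local_map unwraps DTensor arguments to their local shards, runs the plain
# function on each rank, and re-wraps the result with the declared output
# placements. Under sequence parallelism the activation is sharded on the
# sequence dim (Shard(1)) and the norm weight is replicated; since RMSNorm
# reduces only over the hidden dim, the shard-local computation is exact
# and needs no communication.
rms_norm_tp = local_map(
    _rms_norm,
    out_placements=[Shard(1)],
    in_placements=([Shard(1)], [Replicate()]),
)
```
Calling `rms_norm_tp(x_dtensor, weight_dtensor)` then behaves like the eager RMSNorm while running the fused kernel on each rank's local shard, which is where the measured speedup comes from.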
@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) Jun 14, 2024
@XilunWu
Contributor Author

XilunWu commented Jun 14, 2024

#364 was created from ghstack. This PR was created to merge the base branch into main.

@XilunWu XilunWu merged commit d761994 into main Jun 14, 2024
6 checks passed
@fabianlim

I get an error `cannot import name 'Partial' from 'torch.distributed._tensor'` on torch 2.3.1. It seems `Partial` is available only in nightly torch?

@XilunWu
Contributor Author

XilunWu commented Jul 23, 2024

@fabianlim That's true. `Partial` was added in PyTorch 2.4.

@wanchaol
Contributor

I get an error `cannot import name 'Partial' from 'torch.distributed._tensor'` on torch 2.3.1. It seems `Partial` is available only in nightly torch?

@fabianlim Curious whether you are using fused_rmsnorm or not? torchtitan is currently evolving and depends on nightlies; we'll start committing to backward compatibility (BC) with code releases. If you want to unblock yourself, you can either use an older commit or make a simple change in this file: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/norms.py#L17

from
```
from torch.distributed._tensor import Partial, Replicate, Shard
```
to
```
from torch.distributed._tensor import Replicate, Shard
from torch.distributed._tensor.placement_types import _Partial
```
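If you want a single file that works on both versions, a hypothetical (untested) shim along the same lines could be:
```
# Version-compatibility sketch: Partial became a public export in torch 2.4;
# torch 2.3 only exposes the private _Partial placement type.
try:
    from torch.distributed._tensor import Partial  # torch >= 2.4
except ImportError:  # torch 2.3.x
    from torch.distributed._tensor.placement_types import _Partial as Partial
```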

@fabianlim

@XilunWu yea, I brought it up because it is not clear from the requirements file that it requires nightly torch: https://github.com/pytorch/torchtitan/blob/main/.ci/docker/requirements.txt#L1

@wanchaol yea, I downgraded and I'm fine now

tianyu-l pushed a commit to tianyu-l/torchtitan_intern24 that referenced this pull request Aug 16, 2024
Summary
This PR enables the use of TritonFusedRMSNorm with Tensor Parallel, yielding a 7%-8% performance gain compared to RMSNorm with TP. pytorch#364
philippguevorguian pushed a commit to YerevaNN/YNNtitan that referenced this pull request Aug 17, 2024
Summary
This PR enables the use of TritonFusedRMSNorm with Tensor Parallel, yielding a 7%-8% performance gain compared to RMSNorm with TP. pytorch#364