Skip to content

Commit

Permalink
use torch.compile; add nsys
Browse files Browse the repository at this point in the history
  • Loading branch information
Edenzzzz committed Aug 8, 2024
1 parent e26c910 commit 02466c6
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ repos:
hooks:
- id: isort
name: sort all imports (python)
args: ["--profile", "black"] # avoid comflict with black
args: ["--profile", "black"] # avoid conflict with black

- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.4.2
Expand Down
10 changes: 10 additions & 0 deletions examples/language/llama/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
# Constants
# ==============================

# A selection of Llama model sizes to choose from.
MODEL_CONFIGS = {
"100m": LlamaConfig(
max_position_embeddings=4096,
Expand All @@ -36,6 +37,7 @@
intermediate_size=2048,
hidden_size=1024,
),
"5b": LlamaConfig(max_position_embeddings=4096, num_key_value_heads=8),
"7b": LlamaConfig(max_position_embeddings=4096),
"13b": LlamaConfig(
hidden_size=5120,
Expand Down Expand Up @@ -92,6 +94,13 @@ def main():
parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
parser.add_argument("--profile", action="store_true", help="Profile the code")
parser.add_argument(
"--nsys",
action="store_true",
help="Use nsys for profiling. \
You should put something like this before colossalai launch: \
nsys profile -w true -t cuda,cudnn,cublas -s cpu --capture-range=cudaProfilerApi --capture-range-end=stop --cudabacktrace=true -x true --python-backtrace=cuda -o prof_out",
)
parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
parser.add_argument("--no_cache", action="store_true")
Expand Down Expand Up @@ -298,6 +307,7 @@ def empty_init():
args.ignore_steps,
1, # avoid creating massive log files
save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
nsys=args.nsys,
) as prof:
if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
data_iter = iter(dataloader)
Expand Down

0 comments on commit 02466c6

Please sign in to comment.