From 02466c6954cdf4cc9d38eee86d6b936910badd37 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Thu, 8 Aug 2024 19:01:12 +0000 Subject: [PATCH] use torch.compile; add nsys --- .pre-commit-config.yaml | 2 +- examples/language/llama/benchmark.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e2a038e628d2..250a9b4077c1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ repos: hooks: - id: isort name: sort all imports (python) - args: ["--profile", "black"] # avoid comflict with black + args: ["--profile", "black"] # avoid conflict with black - repo: https://github.com/psf/black-pre-commit-mirror rev: 24.4.2 diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index e9a8a28980de..82335dc17ecc 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -28,6 +28,7 @@ # Constants # ============================== +# We have lots of llamas for your choice! MODEL_CONFIGS = { "100m": LlamaConfig( max_position_embeddings=4096, @@ -36,6 +37,7 @@ intermediate_size=2048, hidden_size=1024, ), + "5b": LlamaConfig(max_position_embeddings=4096, num_key_value_heads=8), "7b": LlamaConfig(max_position_embeddings=4096), "13b": LlamaConfig( hidden_size=5120, @@ -92,6 +94,13 @@ def main(): parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"]) parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval) parser.add_argument("--profile", action="store_true", help="Profile the code") + parser.add_argument( + "--nsys", + action="store_true", + help="Use nsys for profiling. \ + You should put something like this before colossalai launch: \ + nsys profile -w true -t cuda,cudnn,cublas -s cpu --capture-range=cudaProfilerApi --capture-range-end=stop --cudabacktrace=true -x true --python-backtrace=cuda -o prof_out", + ) parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation") parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number") parser.add_argument("--no_cache", action="store_true") @@ -298,6 +307,7 @@ def empty_init(): args.ignore_steps, 1, # avoid creating massive log files save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}", + nsys=args.nsys, ) as prof: if isinstance(plugin, HybridParallelPlugin) and args.pp > 1: data_iter = iter(dataloader)