
Commit

address TODOs as 2D recompiles are fixed
ghstack-source-id: 2927f0a8082171da3e9f59a5d04f8325cbdf3653
Pull Request resolved: #508
tianyu-l committed Aug 7, 2024
1 parent 05db84d commit 6e7a183
Showing 1 changed file with 5 additions and 11 deletions.
16 changes: 5 additions & 11 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -185,9 +185,6 @@ def apply_tp(
     if enable_async_tp:
         from torch.distributed._symmetric_memory import enable_symm_mem_for_group
 
-        # TODO: remove cache_size_limit adjustment after 2D compile is fixed
-        torch._dynamo.config.cache_size_limit = 10000
-
         torch._inductor.config._micro_pipeline_tp = True
         enable_symm_mem_for_group(tp_mesh.get_group().group_name)
 
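For context, the lines deleted above raised Dynamo's recompile cache limit as a workaround while 2D compile still triggered excess recompiles. Below is a minimal sketch of that workaround, assuming a hypothetical helper name; the 10000 threshold mirrors the deleted line.

import torch._dynamo  # binds torch and loads the Dynamo config namespace


def raise_dynamo_cache_limit(limit: int = 10000) -> None:
    # cache_size_limit caps how many compiled variants Dynamo keeps per code
    # object before it stops recompiling and falls back to eager; the deleted
    # workaround raised it so 2D-parallel recompiles would not hit the cap.
    torch._dynamo.config.cache_size_limit = limit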
@@ -280,18 +277,15 @@ def apply_ac(model: nn.Module, ac_config):
 
 
 def apply_compile(model: nn.Module):
-    """Apply torch.compile to each transformer block."""
-
-    # the following flag can be used to to accelarate per-TransformerBlock compilation
-    # TODO(bdhirsh): turning it off because it's currently not working with 2D
-    # TODO(anijain): remove it after it's enabled in pytorch by default
-    # torch._dynamo.config.inline_inbuilt_nn_modules = True
-
+    """
+    Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
+    repeated structure. Alternatively one can compile the whole model (after applying DP).
+    """
     for layer_id, transformer_block in model.layers.named_children():
         transformer_block = torch.compile(transformer_block, fullgraph=True)
         model.layers.register_module(layer_id, transformer_block)
 
-    logger.info("Compiled each TransformerBlock with torch.compile")
+    logger.info("Compiling each TransformerBlock with torch.compile")
     return model
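The new docstring also mentions compiling the whole model (after applying DP) as an alternative to the per-TransformerBlock loop above. A minimal sketch of that alternative, assuming a hypothetical helper name not present in this commit:

import torch
import torch.nn as nn


def apply_compile_whole_model(model: nn.Module) -> nn.Module:
    # Single torch.compile call over the full model, as opposed to the
    # per-block loop in apply_compile above, which the docstring notes is
    # efficient due to the repeated TransformerBlock structure.
    return torch.compile(model)

As in the diff, fullgraph=True could also be passed here to fail loudly on graph breaks instead of silently falling back to eager for the broken segments.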


