Skip to content

Commit

Permalink
[BE][1/n] simplify train.py
Browse files Browse the repository at this point in the history
ghstack-source-id: 3879e764e7b33afde5d778810c71d1d2a8f82f6d
Pull Request resolved: #494
  • Loading branch information
tianyu-l committed Aug 1, 2024
1 parent b069f70 commit 3ddce59
Show file tree
Hide file tree
Showing 21 changed files with 231 additions and 240 deletions.
20 changes: 10 additions & 10 deletions estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,22 @@
import os

import torch
import torch.nn.functional as F
from torch._guards import active_fake_mode
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed import destroy_process_group
from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
from torch.testing._internal.distributed.fake_pg import FakeStore

from torchtitan.config_manager import JobConfig
from torchtitan.datasets import create_tokenizer
from torchtitan.datasets import build_tokenizer
from torchtitan.float8_linear import (
maybe_build_fp8_linear,
maybe_precompute_fp8_dynamic_scale_for_fsdp,
)
from torchtitan.logging_utils import init_logger, logger
from torchtitan.lr_scheduling import get_lr_schedulers
from torchtitan.logging import init_logger, logger
from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.parallelisms import models_parallelize_fns, ParallelDims
from train import build_optimizers, get_train_context
from train import get_train_context


def estimate_memory(job_config: JobConfig):
Expand Down Expand Up @@ -97,7 +95,7 @@ def estimate_memory(job_config: JobConfig):

# build tokenizer
tokenizer_type = model_name_to_tokenizer[model_name]
tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path)
tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path)

train_context = get_train_context(
parallel_dims.loss_parallel_enabled,
Expand All @@ -106,7 +104,9 @@ def estimate_memory(job_config: JobConfig):

# loss fn can be shared by pipeline-parallel or non-pp execution
def loss_fn(pred, labels):
return F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
return torch.nn.functional.cross_entropy(
pred.flatten(0, 1), labels.flatten(0, 1)
)

# build model (using meta init)
model_cls = model_name_to_cls[model_name]
Expand Down Expand Up @@ -146,7 +146,7 @@ def loss_fn(pred, labels):

# build optimizer after applying parallelisms to the model
optimizers = build_optimizers(model_parts, job_config)
lr_schedulers = get_lr_schedulers(optimizers.optimizers, job_config)
lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)

for model in model_parts:
model.train()
Expand Down Expand Up @@ -224,4 +224,4 @@ def loss_fn(pred, labels):
try:
estimate_memory(config)
finally:
destroy_process_group()
torch.distributed.destroy_process_group()
4 changes: 2 additions & 2 deletions test/datasets/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import torch
from torchtitan.datasets.hf_datasets import build_hf_data_loader
from torchtitan.datasets.tokenizer import create_tokenizer
from torchtitan.datasets.tokenizer import build_tokenizer


class TestCheckpoint:
Expand Down Expand Up @@ -42,7 +42,7 @@ def _build_dataloader(
self, dataset_name, dataset_path, batch_size, seq_len, world_size, rank
):
tokenizer_type = "tiktoken"
tokenizer = create_tokenizer("tiktoken", "./test/assets/test_tiktoken.model")
tokenizer = build_tokenizer("tiktoken", "./test/assets/test_tiktoken.model")
return build_hf_data_loader(
dataset_name=dataset_name,
dataset_path=dataset_path,
Expand Down
45 changes: 42 additions & 3 deletions torchtitan/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import re
import shutil
import time
from dataclasses import dataclass, field
from io import BytesIO
from multiprocessing import get_context
from typing import Any, Dict, List, Union

Expand All @@ -27,7 +29,7 @@
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import DataLoader
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging_utils import init_logger, logger
from torchtitan.logging import init_logger, logger


class IntervalType(enum.Enum):
Expand All @@ -41,6 +43,43 @@ class AsyncMode(str, enum.Enum):
ASYNC_WITH_PINNED_MEM = "async_with_pinned_mem"


@dataclass
class TrainState(Stateful):
step: int = 0
global_avg_losses: List[float] = field(default_factory=list)
global_max_losses: List[float] = field(default_factory=list)
log_steps: List[int] = field(default_factory=list)

def state_dict(self) -> Dict[str, Any]:
# Only checkpoint global_avg_losses and global_max_losses per log frequency
# to avoid sync overhead in every iteration.
global_avg_losses_bytes = BytesIO()
torch.save(self.global_avg_losses, global_avg_losses_bytes)
global_max_losses_bytes = BytesIO()
torch.save(self.global_max_losses, global_max_losses_bytes)
log_steps_bytes = BytesIO()
torch.save(self.log_steps, log_steps_bytes)
return {
"step": torch.tensor(self.step, dtype=torch.int32),
"global_avg_losses": global_avg_losses_bytes,
"global_max_losses": global_max_losses_bytes,
"log_steps": log_steps_bytes,
}

def load_state_dict(self, state_dict) -> None:
self.step = state_dict["step"].item()
state_dict["global_avg_losses"].seek(0)
self.global_avg_losses = torch.load(
state_dict["global_avg_losses"], weights_only=False
)
state_dict["global_max_losses"].seek(0)
self.global_max_losses = torch.load(
state_dict["global_max_losses"], weights_only=False
)
state_dict["log_steps"].seek(0)
self.log_steps = torch.load(state_dict["log_steps"], weights_only=False)


class ModelWrapper(Stateful):
def __init__(self, model: Union[nn.Module, List[nn.Module]]) -> None:
self.model = [model] if isinstance(model, nn.Module) else model
Expand Down Expand Up @@ -124,10 +163,10 @@ def checkpoint_mp(recv, send):
class CheckpointManager:
def __init__(
self,
dataloader: DataLoader,
model_parts: List[nn.Module],
optimizers: List[torch.optim.Optimizer],
lr_schedulers: List[torch.optim.lr_scheduler.LRScheduler],
dataloader: DataLoader,
states: Dict[str, Any],
job_config: JobConfig,
) -> None:
Expand Down Expand Up @@ -390,7 +429,7 @@ def save(self, curr_step: int, force: bool = False) -> None:
f"in {time.monotonic() - begin:.2f} seconds."
)

def wait_for_staging(self) -> None:
def maybe_wait_for_staging(self) -> None:
if (
self.enable_checkpoint
and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
Expand Down
2 changes: 1 addition & 1 deletion torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
except ModuleNotFoundError:
import tomli as tomllib

from torchtitan.logging_utils import logger
from torchtitan.logging import logger

TORCH_DTYPE_MAP = {
"float16": torch.float16,
Expand Down
4 changes: 2 additions & 2 deletions torchtitan/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
# LICENSE file in the root directory of this source tree.

from torchtitan.datasets.hf_datasets import build_hf_data_loader
from torchtitan.datasets.tokenizer import create_tokenizer
from torchtitan.datasets.tokenizer import build_tokenizer

__all__ = [
"build_hf_data_loader",
"create_tokenizer",
"build_tokenizer",
]
2 changes: 1 addition & 1 deletion torchtitan/datasets/hf_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
) from e

from torchtitan.datasets.tokenizer import Tokenizer
from torchtitan.logging_utils import logger
from torchtitan.logging import logger

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
Expand Down
4 changes: 2 additions & 2 deletions torchtitan/datasets/tokenizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from torchtitan.datasets.tokenizer.tiktoken import TikTokenizer
from torchtitan.datasets.tokenizer.tokenizer import Tokenizer

from torchtitan.logging_utils import logger
from torchtitan.logging import logger


def create_tokenizer(tokenizer_type: str, tokenizer_path: str) -> Tokenizer:
def build_tokenizer(tokenizer_type: str, tokenizer_path: str) -> Tokenizer:
logger.info(f"Building {tokenizer_type} tokenizer locally from {tokenizer_path}")
if tokenizer_type == "sentencepiece":
return SentencePieceTokenizer(tokenizer_path)
Expand Down
2 changes: 1 addition & 1 deletion torchtitan/datasets/tokenizer/sentencepiece.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from sentencepiece import SentencePieceProcessor

from torchtitan.datasets.tokenizer.tokenizer import Tokenizer
from torchtitan.logging_utils import logger
from torchtitan.logging import logger


class SentencePieceTokenizer(Tokenizer):
Expand Down
2 changes: 1 addition & 1 deletion torchtitan/datasets/tokenizer/tiktoken.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from tiktoken.load import load_tiktoken_bpe

from torchtitan.datasets.tokenizer.tokenizer import Tokenizer
from torchtitan.logging_utils import logger
from torchtitan.logging import logger


class TikTokenizer(Tokenizer):
Expand Down
2 changes: 1 addition & 1 deletion torchtitan/float8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from torch._logging import warning_once

from torchtitan.config_manager import JobConfig
from torchtitan.logging_utils import logger
from torchtitan.logging import logger


@functools.lru_cache(None)
Expand Down
File renamed without changes.
32 changes: 23 additions & 9 deletions torchtitan/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
import torch
from torch.utils.tensorboard import SummaryWriter
from torchtitan.config_manager import JobConfig
from torchtitan.logging_utils import logger
from torchtitan.logging import logger
from torchtitan.parallelisms import ParallelDims

# named tuple for passing GPU memory stats for logging
GPUMemStats = namedtuple(
Expand Down Expand Up @@ -110,16 +111,29 @@ def close(self):
self.writer.close()


def _get_metrics_rank(parallel_dims: ParallelDims) -> int:
"""
Returns global rank 0 in non-pipeline-parallel configs, and returns the global
rank of the 0th rank in the last pipeline stage when pipeline parallelism is enabled.
"""
if parallel_dims.pp_enabled:
world_size = parallel_dims.world_size
pp_size = parallel_dims.pp
metrics_log_rank = (world_size // pp_size) * (pp_size - 1)
else:
metrics_log_rank = 0

return metrics_log_rank


def build_metric_logger(
config: JobConfig, metrics_log_rank: int = 0, tag: Optional[str] = None
config: JobConfig, parallel_dims: ParallelDims, tag: Optional[str] = None
):
"""
metrics_log_rank controls which rank acts as 'rank 0' for logging metrics.
If 'tb_config.rank_0_only' is set, then `metrics_log_rank` will be used as the rank to log metrics.
This is intended to allow logging from the 0th rank within the last pipeline stage group, in case pipeline
parallelism is enabled, without forcing logging from all ranks to capture loss information when using pipeline
parallelism.
parallel_dims is used to determine the rank to log metrics from if 'tb_config.rank_0_only=True'.
In that case, `_get_metrics_rank` will be used to calculate which rank acts as 'rank 0'. This is
intended to allow logging from the 0th rank within the last pipeline stage group, in case pipeline
parallelism is enabled, without forcing logging from all ranks to capture loss information.
"""
dump_dir = config.job.dump_folder
tb_config = config.metrics
Expand All @@ -134,7 +148,7 @@ def build_metric_logger(
f"Metrics logging active. Tensorboard logs will be saved at {log_dir}"
)
if tb_config.rank_0_only:
enable_tb = torch.distributed.get_rank() == metrics_log_rank
enable_tb = torch.distributed.get_rank() == _get_metrics_rank(parallel_dims)
else:
rank_str = f"rank_{torch.distributed.get_rank()}"
log_dir = os.path.join(log_dir, rank_str)
Expand Down
8 changes: 4 additions & 4 deletions torchtitan/models/llama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import torch
import torch.nn.functional as F
from torch import nn
from torchtitan.models.norms import create_norm
from torchtitan.models.norms import build_norm


@dataclass
Expand Down Expand Up @@ -291,10 +291,10 @@ def __init__(self, layer_id: int, model_args: ModelArgs):
self.layer_id = layer_id
self.num_layers = model_args.n_layers

self.attention_norm = create_norm(
self.attention_norm = build_norm(
model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
)
self.ffn_norm = create_norm(
self.ffn_norm = build_norm(
model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
)

Expand Down Expand Up @@ -370,7 +370,7 @@ def __init__(self, model_args: ModelArgs):
for layer_id in range(model_args.n_layers):
self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)

self.norm = create_norm(
self.norm = build_norm(
model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
)

Expand Down
8 changes: 4 additions & 4 deletions torchtitan/models/norms.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,18 @@
from torch.distributed._tensor.experimental import local_map


def create_norm(norm_type: str, dim: int, eps: float = 1e-6):
def build_norm(norm_type: str, dim: int, eps: float = 1e-6):
"""
Creates the specified normalization layer based on the norm_type.
Builds the specified normalization layer based on the norm_type.
Args:
norm_type (str): The type of normalization layer to create.
norm_type (str): The type of normalization layer to build.
Supported types: 1. rmsnorm 2. fused_rmsnorm 3. layernorm 4. np_layernorm
dim (int): The dimension of the normalization layer.
eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6.
Returns:
The created normalization layer.
The built normalization layer.
Raises:
NotImplementedError: If an unknown norm_type is provided.
Expand Down
Loading

0 comments on commit 3ddce59

Please sign in to comment.