feat: support compute dtype in spmd fsdp
lausannel committed Sep 14, 2024
1 parent a32543b commit db59a32
Showing 1 changed file with 67 additions and 4 deletions: torch_xla/experimental/spmd_fully_sharded_data_parallel.py
@@ -1,17 +1,21 @@
from typing import (Any, Callable, Dict, Optional, Union)
import warnings
from collections import OrderedDict
from typing import (Any, Callable, Dict, Generator, List, Optional, Set, Tuple,
Union, cast)

import numpy as np
import torch
import torch.nn as nn
from torch._prims_common import TensorLike, TensorSequenceType

import numpy as np
from torch.nn.utils.rnn import PackedSequence

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as spmd
from torch_xla.distributed.fsdp.wrap import recursive_wrap
from torch_xla.distributed.fsdp._init_utils import _materialize_module
from torch_xla.distributed.fsdp.wrap import recursive_wrap

FLOAT_DTYPES = [torch.float32, torch.float16, torch.bfloat16]


def _prepare_spmd_partition_spec(param):
@@ -40,13 +44,18 @@ class SpmdFullyShardedDataParallel(nn.Module):
The callable should have the signature (output, mesh) -> None.
If None, the default implementation will shard the first tensor in the output.
If the output is a tuple, only the first tensor will be sharded.
compute_dtype (torch.dtype, Optional):
dtype for full parameters for computation. This defaults to
``torch.float32`` but can be set to ``torch.float16`` or
``torch.bfloat16``. The sharded parameters will always be in FP32.
"""

def __init__(
self,
module: nn.Module,
mesh: Optional[spmd.Mesh] = None,
shard_output: Optional[Callable] = None,
compute_dtype: Optional[torch.dtype] = None,
auto_wrap_policy: Optional[Callable] = None,
auto_wrapper_callable: Optional[Callable] = None,
):
@@ -96,6 +105,11 @@ def __init__(
)
self._auto_wrap(auto_wrap_kwargs, fsdp_kwargs)

if compute_dtype is not None and compute_dtype not in FLOAT_DTYPES:
raise ValueError(
f"compute_dtype must be one of {FLOAT_DTYPES}, not {compute_dtype}")
self.compute_dtype = compute_dtype or torch.float32

_materialize_module(
module,
None, [],
@@ -150,6 +164,9 @@ def module(self) -> nn.Module:
return self._orig_module

def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
if self.compute_dtype != torch.float32:
# Cast the input float tensors to the specified compute_dtype
args, kwargs = _cast_floats_tensors(self.compute_dtype, *args, **kwargs)
output = self.module(*args, **kwargs)
# Need to shard the output of the forward to instruct the compiler
# to enforce the FSDP algorithm.
@@ -192,3 +209,49 @@ def _auto_wrap(
"if using an `auto_wrap_policy`")

recursive_wrap(**auto_wrap_kwargs, **fsdp_kwargs)


def _cast_floats_tensors(dtype: torch.dtype, *args: Any,
**kwargs: Any) -> Tuple[Any, Any]:
"""
  Cast floating point tensors in *args or **kwargs to dtype if they are not already.
"""

def fn(t):
if t.dtype != dtype and torch.is_floating_point(t):
t = t.to(dtype)
return t

return apply_to_tensors(fn, args), apply_to_tensors(fn, kwargs)


def apply_to_tensors(
fn: Callable, container: Union[torch.Tensor, Dict, List, Tuple, Set]
) -> Union[torch.Tensor, Dict, List, Tuple, Set]:
"""Recursively apply to all tensor in different kinds of container types."""

def _apply(
x: Union[torch.Tensor, Dict, List, Tuple, Set]
) -> Union[torch.Tensor, Dict, List, Tuple, Set]:
if torch.is_tensor(x):
return fn(x)
elif isinstance(x, OrderedDict):
od = x.__class__()
for key, value in x.items():
od[key] = _apply(value)
return od
elif isinstance(x, PackedSequence):
      _apply(x.data)
return x
elif isinstance(x, dict):
return {key: _apply(value) for key, value in x.items()}
elif isinstance(x, list):
return [_apply(x) for x in x]
elif isinstance(x, tuple):
return tuple(_apply(x) for x in x)
elif isinstance(x, set):
return {_apply(x) for x in x}
else:
return x

return _apply(container)
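As a side note, here is a small self-contained sketch of what the apply_to_tensors helper above does when paired with a casting function like the fn closure in _cast_floats_tensors. The batch contents are made-up for illustration and the import assumes a torch_xla installation; nothing in this sketch is part of the commit.

import torch

from torch_xla.experimental.spmd_fully_sharded_data_parallel import apply_to_tensors


def to_bf16(t):
  # Mirrors the fn closure in _cast_floats_tensors: only floating-point
  # tensors are converted; everything else is returned untouched.
  return t.to(torch.bfloat16) if torch.is_floating_point(t) else t


batch = {
    "tokens": torch.zeros(4, dtype=torch.int64),           # stays int64
    "features": [torch.randn(4), (torch.randn(2, 2), 7)],  # float tensors become bf16
}
out = apply_to_tensors(to_bf16, batch)
# The nested dict/list/tuple structure and any non-tensor values are preserved;
# only the floating-point tensors change dtype.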

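For reference, a minimal usage sketch of the new compute_dtype argument. Only the SpmdFullyShardedDataParallel constructor signature and the casting behavior in forward() come from the diff above; the mesh layout, the toy model, and the input shape are illustrative assumptions.

import numpy as np
import torch
import torch.nn as nn

import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as spmd
import torch_xla.runtime as xr
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2)

xr.use_spmd()  # SPMD execution mode must be enabled before wrapping

# Assumed 1D data-parallel mesh over all attached devices.
num_devices = xr.global_runtime_device_count()
mesh = spmd.Mesh(np.arange(num_devices), (num_devices, 1), ("fsdp", "model"))

# Toy model; any nn.Module would do.
model = nn.Linear(128, 128).to(xm.xla_device())

# Sharded parameters stay in float32 (per the docstring above), while
# forward() casts floating-point inputs to the requested compute_dtype.
model = FSDPv2(model, mesh=mesh, compute_dtype=torch.bfloat16)

x = torch.randn(16, 128).to(xm.xla_device())  # float32 input, cast in forward()
out = model(x)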