diff --git a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py
index b1f5d709920..dc37ff7d9c7 100644
--- a/torch_xla/experimental/spmd_fully_sharded_data_parallel.py
+++ b/torch_xla/experimental/spmd_fully_sharded_data_parallel.py
@@ -1,17 +1,21 @@
-from typing import (Any, Callable, Dict, Optional, Union)
 import warnings
+from collections import OrderedDict
+from typing import (Any, Callable, Dict, Generator, List, Optional, Set, Tuple,
+                    Union, cast)
 
+import numpy as np
 import torch
 import torch.nn as nn
 from torch._prims_common import TensorLike, TensorSequenceType
-
-import numpy as np
+from torch.nn.utils.rnn import PackedSequence
 
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.spmd as spmd
-from torch_xla.distributed.fsdp.wrap import recursive_wrap
 from torch_xla.distributed.fsdp._init_utils import _materialize_module
+from torch_xla.distributed.fsdp.wrap import recursive_wrap
+
+FLOAT_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
 
 
 def _prepare_spmd_partition_spec(param):
@@ -40,6 +44,10 @@ class SpmdFullyShardedDataParallel(nn.Module):
       The callable should have the signature (output, mesh) -> None.
       If None, the default implementation will shard the first tensor in the output.
       If the output is a tuple, only the first tensor will be sharded.
+    compute_dtype (torch.dtype, Optional):
+      dtype for full parameters for computation. This defaults to
+      ``torch.float32`` but can be set to ``torch.float16`` or
+      ``torch.bfloat16``. The sharded parameters will always be in FP32.
   """
 
   def __init__(
@@ -47,6 +55,7 @@ def __init__(
       module: nn.Module,
       mesh: Optional[spmd.Mesh] = None,
       shard_output: Optional[Callable] = None,
+      compute_dtype: Optional[torch.dtype] = None,
       auto_wrap_policy: Optional[Callable] = None,
       auto_wrapper_callable: Optional[Callable] = None,
   ):
@@ -96,6 +105,11 @@ def __init__(
     )
     self._auto_wrap(auto_wrap_kwargs, fsdp_kwargs)
 
+    if compute_dtype is not None and compute_dtype not in FLOAT_DTYPES:
+      raise ValueError(
+          f"compute_dtype must be one of {FLOAT_DTYPES}, not {compute_dtype}")
+    self.compute_dtype = compute_dtype or torch.float32
+
     _materialize_module(
         module,
         None, [],
@@ -150,6 +164,9 @@ def module(self) -> nn.Module:
     return self._orig_module
 
   def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
+    if self.compute_dtype != torch.float32:
+      # Cast the input float tensors to the specified compute_dtype
+      args, kwargs = _cast_floats_tensors(self.compute_dtype, *args, **kwargs)
     output = self.module(*args, **kwargs)
     # Need to shard the output of the forward to instruct the compiler
     # to enforce the FSDP algorithm.
@@ -192,3 +209,49 @@ def _auto_wrap(
                          "if using an `auto_wrap_policy`")
 
     recursive_wrap(**auto_wrap_kwargs, **fsdp_kwargs)
+
+
+def _cast_floats_tensors(dtype: torch.dtype, *args: Any,
+                         **kwargs: Any) -> Tuple[Any, Any]:
+  """
+  Cast floating point tensors in *args and **kwargs to dtype if they are not already.
+  """
+
+  def fn(t):
+    if t.dtype != dtype and torch.is_floating_point(t):
+      t = t.to(dtype)
+    return t
+
+  return apply_to_tensors(fn, args), apply_to_tensors(fn, kwargs)
+
+
+def apply_to_tensors(
+    fn: Callable, container: Union[torch.Tensor, Dict, List, Tuple, Set]
+) -> Union[torch.Tensor, Dict, List, Tuple, Set]:
+  """Recursively apply fn to all tensors in different kinds of container types."""
+
+  def _apply(
+      x: Union[torch.Tensor, Dict, List, Tuple, Set]
+  ) -> Union[torch.Tensor, Dict, List, Tuple, Set]:
+    if torch.is_tensor(x):
+      return fn(x)
+    elif isinstance(x, OrderedDict):
+      od = x.__class__()
+      for key, value in x.items():
+        od[key] = _apply(value)
+      return od
+    elif isinstance(x, PackedSequence):
+      _apply(x.data)
+      return x
+    elif isinstance(x, dict):
+      return {key: _apply(value) for key, value in x.items()}
+    elif isinstance(x, list):
+      return [_apply(x) for x in x]
+    elif isinstance(x, tuple):
+      return tuple(_apply(x) for x in x)
+    elif isinstance(x, set):
+      return {_apply(x) for x in x}
+    else:
+      return x
+
+  return _apply(container)