Skip to content

Commit

Permalink
[plugin] add cast inputs option for zero (#6003) (#6022)
Browse files Browse the repository at this point in the history
  • Loading branch information
ver217 authored Aug 21, 2024
1 parent dcc44aa commit 0d3b0bd
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions colossalai/booster/plugin/low_level_zero_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ class OptimizerParamCheckState(enum.Enum):


class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None:
def __init__(
self, module: nn.Module, precision: str, overlap_allgather: bool = False, cast_inputs: bool = True
) -> None:
super().__init__(module)
self.dtype = None
if precision == "fp16":
Expand All @@ -73,7 +75,7 @@ def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool =
module = module.to(get_accelerator().get_current_device())
self.module = module
self.convert_fn = None
if self.dtype is not None:
if self.dtype is not None and cast_inputs:
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
self.overlap_allgather = overlap_allgather
if overlap_allgather:
Expand Down Expand Up @@ -334,6 +336,7 @@ def __init__(
cpu_offload: bool = False,
master_weights: bool = True,
verbose: bool = False,
cast_inputs: bool = True,
) -> None:
super().__init__()
assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
Expand Down Expand Up @@ -361,6 +364,8 @@ def __init__(
self.lora_enabled = False
self.verbose = verbose
self.logger = get_dist_logger()
self.cast_inputs = cast_inputs

# set class name with stage, for better error message
setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")

Expand Down Expand Up @@ -475,7 +480,10 @@ def configure(

if not isinstance(model, ModelWrapper):
model = LowLevelZeroModel(
model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"]
model,
self.precision,
overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
cast_inputs=self.cast_inputs,
)

# TODO: Support Galore + ZeRO
Expand Down

0 comments on commit 0d3b0bd

Please sign in to comment.