Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Apr 11, 2024
1 parent 0fd62a0 commit 2dff732
Showing 4 changed files with 70 additions and 77 deletions.
3 changes: 1 addition & 2 deletions colossalai/nn/optimizer/adafactor.py
@@ -36,7 +36,7 @@ def __init__(
relative_step=True,
warmup_init=False,
):
lr=None
lr = None
if lr is not None and relative_step:
raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
if warmup_init and not relative_step:
@@ -198,5 +198,4 @@ def step(self, closure=None):
p.add_(p, alpha=(-group["weight_decay"] * lr))
p.add_(-update)


return loss
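
For context on the `lr`/`relative_step` check touched in the first hunk: when `relative_step=True`, Adafactor derives its learning rate from the step count instead of a manual `lr`. Below is a minimal sketch of that schedule; the constants (1e-6, 1e-2, 1e-3) follow the commonly used Adafactor recipe and are assumptions for illustration, not a quote of this file.

```python
import math

# Hedged sketch of the relative-step schedule enabled by relative_step=True.
# warmup_init ramps the rate up early; scale_parameter rescales by the
# parameter's RMS. Constants are the usual Adafactor defaults (assumed here).
def relative_step_lr(step: int, warmup_init: bool = False,
                     scale_parameter: bool = True, param_rms: float = 1.0) -> float:
    min_step = 1e-6 * step if warmup_init else 1e-2
    rel_step = min(min_step, 1.0 / math.sqrt(step))            # decays like 1/sqrt(t)
    scale = max(1e-3, param_rms) if scale_parameter else 1.0   # parameter-scale factor
    return scale * rel_step

# e.g. relative_step_lr(10_000) == 0.01 for a unit-RMS parameter
```
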
52 changes: 25 additions & 27 deletions colossalai/nn/optimizer/distributed_adafactor.py
@@ -3,9 +3,9 @@

import torch
import torch.distributed as dist

# from torch.optim import Optimizer
from colossalai.interface.optimizer import DistributedOptim

from colossalai.shardformer.layer._operation import _gather, _split
from colossalai.tensor.d_tensor import get_sharding_spec, is_distributed_tensor

@@ -50,14 +50,13 @@ def __init__(
self.data_parallel_group = None
self.shard_to_param = None # Dict{id:shape}, sample {id(param): torch.tensor}
self.use_zero = True
self.param_is_dtensor_dict = {} # {id(p): True/False}
self.grad_shape_dict = {} # {id(p): master param shape}
self.factored_dict = {} # {id(p): True/False}
self.use_first_moment_dict = {} # {id(p): True/False}
self.shard_spec_dict = {} # {id(p): ShardSpec}
super().__init__(params, defaults)


def setup_distributed(
self,
@@ -84,23 +83,25 @@ def setup_distributed(
if self.data_parallel_group is not None:
self.data_parallel_size = dist.get_world_size(self.data_parallel_group)
self.use_zero = use_zero

self.shard_to_param = shard_to_param if shard_to_param is not None else {}
# grad is None, cause we dont setup now
for group in self.param_groups:
for p in group["params"]:
self.param_is_dtensor_dict[id(p)] = is_distributed_tensor(self.shard_to_param.get(id(p)))
if self.param_is_dtensor_dict[id(p)]:
self.grad_shape_dict[id(p)] = self.shard_to_param.get(id(p)).shape
else:
self.grad_shape_dict[id(p)] = p.shape
self.factored_dict[id(p)], self.use_first_moment_dict[id(p)] = self._get_options(group, self.grad_shape_dict[id(p)])
self.factored_dict[id(p)], self.use_first_moment_dict[id(p)] = self._get_options(
group, self.grad_shape_dict[id(p)]
)
# if self.factored_dict[id(p)]:
if self.param_is_dtensor_dict[id(p)]:
self.shard_spec_dict[id(p)] = get_sharding_spec(self.shard_to_param.get(id(p)))
else:
self.shard_spec_dict[id(p)] = None

@staticmethod
def _get_lr(param_group, param_state):
rel_step_sz = param_group["lr"]
@@ -177,7 +178,7 @@ def step(self, closure=None):
grad = p.grad
if grad.is_sparse:
raise RuntimeError("Adafactor does not support sparse gradients.")

state = self.state[p]
grad_shape = self.grad_shape_dict[id(p)]
param_is_dtensor = self.param_is_dtensor_dict[id(p)]
@@ -203,16 +204,14 @@ def step(self, closure=None):
# Row indivisible shape situation
if grad_shape[0] % self.data_parallel_size != 0:
state["exp_avg_sq_row"] = torch.zeros(
grad_shape[0], device=p.device, dtype=p.dtype
) # [H/dp/Tp]
else:
state["exp_avg_sq_row"] = torch.zeros(
grad_shape[0] // self.data_parallel_size, device=p.device, dtype=p.dtype
) # [H/dp/Tp]

state["exp_avg_sq_col"] = torch.zeros(
grad_shape[1], device=p.device, dtype=p.dtype
) # [W]

state["exp_avg_sq_col"] = torch.zeros(grad_shape[1], device=p.device, dtype=p.dtype) # [W]
else:
state["exp_avg_sq"] = torch.zeros_like(p)
state["RMS"] = 0
@@ -258,15 +257,11 @@ def step(self, closure=None):
# Row Residual situation
if grad_shape[0] % self.data_parallel_size != 0:
# gather update[flatten] along dp group then reshape to [H/tp, W]
update = _gather(
input_=update, dim=-1, process_group=self.data_parallel_group
)
update = _gather(input_=update, dim=-1, process_group=self.data_parallel_group)
# view update to origin[tp] shape
update_reshape = update.view(-1, grad_shape[1])
# gather grad[flatten] along dp group then reshape to [H/tp, W]
grad = _gather(
input_=grad, dim=-1, process_group=self.data_parallel_group
)
grad = _gather(input_=grad, dim=-1, process_group=self.data_parallel_group)
grad_reshape = grad.view(-1, grad_shape[1])
exp_avg_sq_row = state["exp_avg_sq_row"] # [H/tp]
exp_avg_sq_col = state["exp_avg_sq_col"] # [W]
@@ -278,7 +273,9 @@ def step(self, closure=None):
update_reshape = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update_reshape.mul_(grad_reshape)
if self.use_zero:
update = _split(input_=update_reshape.view(-1), dim=-1, process_group=self.data_parallel_group)
update = _split(
input_=update_reshape.view(-1), dim=-1, process_group=self.data_parallel_group
)
else:
update = update_reshape
else:
Expand All @@ -296,7 +293,9 @@ def step(self, closure=None):
input_=exp_avg_sq_row, dim=-1, process_group=self.tensor_parallel_group
)
sq_row_meam = exp_avg_sq_row_gather.mean(dim=-1, keepdim=True)
update_reshape = self._approx_sq_grad_row_parallel(exp_avg_sq_row, exp_avg_sq_col, sq_row_meam)
update_reshape = self._approx_sq_grad_row_parallel(
exp_avg_sq_row, exp_avg_sq_col, sq_row_meam
)
update_reshape.mul_(grad_reshape)
if self.use_zero:
update = update_reshape.view(-1)
@@ -317,8 +316,7 @@ def step(self, closure=None):

if group["weight_decay"] != 0:
p.add_(p, alpha=(-group["weight_decay"] * lr))

p.add_(-update)

return loss
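
The hunks above mostly reflow the factored second-moment bookkeeping (`exp_avg_sq_row`, `exp_avg_sq_col`, `_approx_sq_grad`) without changing its behaviour. For reference, here is a minimal single-device sketch of that factored update with the sharding and `_gather`/`_split` calls stripped out; the names mirror the diff, but the function itself is an illustrative assumption, not the class's code.

```python
import torch

# Hedged sketch of Adafactor's factored second moment: keep a row statistic [H]
# and a column statistic [W] instead of a full [H, W] buffer, then rebuild a
# rank-1 preconditioner (the role _approx_sq_grad plays in the diff above).
def factored_update(grad: torch.Tensor,
                    exp_avg_sq_row: torch.Tensor,  # shape [H]
                    exp_avg_sq_col: torch.Tensor,  # shape [W]
                    beta2t: float = 0.999,
                    eps: float = 1e-30) -> torch.Tensor:
    sq = grad.pow(2) + eps
    exp_avg_sq_row.mul_(beta2t).add_(sq.mean(dim=-1), alpha=1.0 - beta2t)
    exp_avg_sq_col.mul_(beta2t).add_(sq.mean(dim=-2), alpha=1.0 - beta2t)
    r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt().unsqueeze(-1)
    c_factor = exp_avg_sq_col.rsqrt().unsqueeze(-2)
    return grad * r_factor * c_factor  # the `update` later scaled by lr

# Usage sketch:
# grad = torch.randn(8, 4)
# update = factored_update(grad, torch.zeros(8), torch.zeros(4))
```

The distributed version in the diff applies the same recipe, but first gathers row-sharded tensors along the data-parallel group when `grad_shape[0]` is not divisible by `data_parallel_size`, and splits the flattened update back across that group when ZeRO is enabled.
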
8 changes: 4 additions & 4 deletions docs/source/en/features/distributed_adafactor.md
@@ -1,20 +1,20 @@
# Distributed Adafactor

Author:

**Related Paper**
- [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235)

## Introduction

Distributed Adafactor is an optimiser that supports hybrid optimisation, combining 1D tensor parallelism with ZeRO. It makes full use of computational resources through sensible task parallelism, improves training efficiency and speed, and reduces the memory pressure on each device. It has a wide range of applications and currently supports a range of Transformer-based models; see [tests.kit.model_zoo](https://github.com/hpcaitech/ColossalAI/tree/main/tests/kit/model_zoo) for details.

### API Reference

{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }}

## Hands-On Practice
We now demonstrate how to use Distributed Adafactor.
### Step 1. Import libraries

```python
Expand Down Expand Up @@ -99,7 +99,7 @@ if zero_size > 1:
else:
out_tp.sum().backward()

# perform step for param and grad
dist_optim.step()
dist_optim.zero_grad()
```
