-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[fix] merge upstream feature/dist-optim;
- Loading branch information
1 parent
a7790a9
commit 4e9a571
Showing
3 changed files
with
993 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,207 @@ | ||
# coding=utf-8 | ||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import math | ||
import os | ||
import torch | ||
from torch.optim import Optimizer | ||
|
||
|
||
__all__ = ["Adafactor"] | ||
# Adafactor | ||
class Adafactor(Optimizer): | ||
def __init__( | ||
self, | ||
params, | ||
lr=None, | ||
eps=(1e-30, 1e-3), | ||
clip_threshold=1.0, | ||
decay_rate=-0.8, | ||
beta1=None, | ||
weight_decay=0.0, | ||
scale_parameter=True, | ||
relative_step=True, | ||
warmup_init=False, | ||
): | ||
if lr is not None and relative_step: | ||
raise ValueError("Cannot combine manual `lr` and `relative_step=True` options") | ||
if warmup_init and not relative_step: | ||
raise ValueError("`warmup_init=True` requires `relative_step=True`") | ||
|
||
defaults = { | ||
"lr": lr, | ||
"eps": eps, | ||
"clip_threshold": clip_threshold, | ||
"decay_rate": decay_rate, | ||
"beta1": beta1, | ||
"weight_decay": weight_decay, | ||
"scale_parameter": scale_parameter, | ||
"relative_step": relative_step, | ||
"warmup_init": warmup_init, | ||
} | ||
super().__init__(params, defaults) | ||
|
||
@staticmethod | ||
def _get_lr(param_group, param_state): | ||
rel_step_sz = param_group["lr"] | ||
if param_group["relative_step"]: | ||
min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2 | ||
rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"])) | ||
param_scale = 1.0 | ||
if param_group["scale_parameter"]: | ||
param_scale = max(param_group["eps"][1], param_state["RMS"]) | ||
return param_scale * rel_step_sz | ||
|
||
@staticmethod | ||
def _get_options(param_group, param_shape): | ||
factored = len(param_shape) >= 2 | ||
use_first_moment = param_group["beta1"] is not None | ||
return factored, use_first_moment | ||
|
||
@staticmethod | ||
def _rms(tensor): | ||
return tensor.norm(2) / (tensor.numel() ** 0.5) | ||
|
||
@staticmethod | ||
def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col): | ||
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) | ||
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() | ||
return torch.mul(r_factor, c_factor) | ||
|
||
@torch.no_grad() | ||
def step(self, closure=None): | ||
""" | ||
Performs a single optimization step | ||
Arguments: | ||
closure (callable, optional): A closure that reevaluates the model | ||
and returns the loss. | ||
""" | ||
loss = None | ||
if closure is not None: | ||
loss = closure() | ||
|
||
""" | ||
param_groups: Dict | ||
{ | ||
"params":[weight, bias] | ||
"lr" | ||
"eps" | ||
"clip_threshold" | ||
"decay_rate" | ||
"beta1" | ||
"weight_decay" | ||
"scale_parameter" | ||
"relative_step" | ||
"warmup_init" | ||
} | ||
""" | ||
|
||
for group in self.param_groups: | ||
# update weight & bias | ||
for p in group["params"]: | ||
if p.grad is None: | ||
continue | ||
""" | ||
# grad shape is same as weigh / bias | ||
""" | ||
grad = p.grad | ||
if grad.dtype in {torch.float16, torch.bfloat16}: | ||
grad = grad.float() | ||
if grad.is_sparse: | ||
raise RuntimeError("Adafactor does not support sparse gradients.") | ||
|
||
""" | ||
p is weight | ||
state | ||
{'step', | ||
'exp_avg_sq_row', | ||
'exp_avg_sq_col', | ||
'RMS' | ||
} | ||
p is bias | ||
state | ||
{'step', | ||
'exp_avg_sq', | ||
'RMS' | ||
} | ||
""" | ||
|
||
state = self.state[p] | ||
grad_shape = grad.shape | ||
|
||
factored, use_first_moment = self._get_options(group, grad_shape) | ||
# State Initialization | ||
if len(state) == 0: | ||
state["step"] = 0 | ||
if use_first_moment: | ||
# Exponential moving average of gradient values | ||
state["exp_avg"] = torch.zeros_like(grad) | ||
if factored: | ||
state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1], device=grad.device) | ||
state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:], device=grad.device) | ||
else: | ||
state["exp_avg_sq"] = torch.zeros_like(grad) | ||
|
||
state["RMS"] = 0 | ||
else: | ||
if use_first_moment: | ||
state["exp_avg"] = state["exp_avg"].to(grad) | ||
if factored: | ||
state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) | ||
state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) | ||
else: | ||
state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) | ||
|
||
p_data_fp32 = p | ||
if p.dtype in {torch.float16, torch.bfloat16}: | ||
p_data_fp32 = p_data_fp32.float() | ||
|
||
state["step"] += 1 | ||
# state["RMS"] = self._rms(p_data_fp32) | ||
lr = self._get_lr(group, state) | ||
beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) | ||
update = (grad**2) + group["eps"][0] | ||
if factored: | ||
exp_avg_sq_row = state["exp_avg_sq_row"] | ||
exp_avg_sq_col = state["exp_avg_sq_col"] | ||
# Exponential average of row indexes | ||
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) | ||
# Exponential average of columns indexes | ||
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) | ||
# Approximation of exponential moving average of square of gradient | ||
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) | ||
update.mul_(grad) | ||
else: | ||
exp_avg_sq = state["exp_avg_sq"] | ||
exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) | ||
update = exp_avg_sq.rsqrt().mul_(grad) | ||
# RMS | ||
# update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) | ||
update.mul_(lr) | ||
|
||
if use_first_moment: | ||
exp_avg = state["exp_avg"] | ||
exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"])) | ||
update = exp_avg | ||
|
||
if group["weight_decay"] != 0: | ||
p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) | ||
p_data_fp32.add_(-update) | ||
if p.dtype in {torch.float16, torch.bfloat16}: | ||
p.copy_(p_data_fp32) | ||
|
||
return loss |
Oops, something went wrong.