Skip to content

Commit

Permalink
[zero] fix missing hook removal (#5824)
Browse files Browse the repository at this point in the history
  • Loading branch information
botbw authored Jun 17, 2024
1 parent a10802e commit 4cd4a1f
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions colossalai/zero/low_level/low_level_strategy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
import weakref
from abc import ABC, abstractmethod
from copy import deepcopy
from functools import partial
Expand Down Expand Up @@ -94,20 +95,27 @@ def __init__(
# reduction hook is only used if overlapping communication
# or stage 2 is used
# if it is stage 1 without overlapping, no hook will be attached
self.grad_handles = []
if self._overlap_communication or self._partition_grad:
# we iterate over the working params
# on each param, we register a hook to its AccumulateGrad object
param_group = self.working_param_group
for param in param_group:
if param.requires_grad:
self_weak_proxy = weakref.proxy(self)
param_weak_proxy = weakref.proxy(param)

def _grad_handler(grad, param):
def _grad_handler(grad):
# if run with no_sync context, would not sync grad when backward
if self.require_grad_sync:
self._add_to_bucket(param)
if self_weak_proxy.require_grad_sync:
self_weak_proxy._add_to_bucket(param_weak_proxy)
return grad

param.register_hook(partial(_grad_handler, param=param))
self.grad_handles.append(param.register_post_accumulate_grad_hook(partial(_grad_handler)))

def __del__(self):
for handle in self.grad_handles:
handle.remove()

def _create_master_param_current_rank(self, param_list):
# split each param evenly by world size
Expand Down

0 comments on commit 4cd4a1f

Please sign in to comment.