-
Notifications
You must be signed in to change notification settings - Fork 4.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Fix] fix the 2d ring attn when using multiple machine #6071
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -431,7 +431,7 @@ class RingAttention(torch.autograd.Function): | |
INTER_RING_GROUP_COPY: dist.ProcessGroup = None | ||
|
||
@staticmethod | ||
def get_double_ring_groups(sp_group, inner_ring_size=None): | ||
def get_double_ring_groups(sp_group, tp_group, inner_ring_size=None): | ||
""" | ||
Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size | ||
shouldn't be larger than the number of NICs on each node. | ||
|
@@ -442,7 +442,7 @@ def get_double_ring_groups(sp_group, inner_ring_size=None): | |
Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group. | ||
""" | ||
sp_size = dist.get_world_size(sp_group) | ||
sp_rank = dist.get_rank(sp_group) | ||
tp_size = dist.get_world_size(tp_group) | ||
|
||
if inner_ring_size is None: | ||
if torch.cuda.device_count() >= dist.get_world_size(): | ||
|
@@ -471,19 +471,42 @@ def get_double_ring_groups(sp_group, inner_ring_size=None): | |
inner_ring_group = None | ||
inter_ring_group = None | ||
|
||
# Create inner ring groups | ||
for i in range(inner_ring_size): | ||
ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size)) | ||
group = dist.new_group(ranks) | ||
if sp_rank in ranks: | ||
inner_ring_group = group | ||
|
||
# Create inter ring groups | ||
for i in range(num_rings): | ||
ranks = list(range(i, sp_size, num_rings)) | ||
group = dist.new_group(ranks) | ||
if sp_rank in ranks: | ||
inter_ring_group = group | ||
world_size = dist.get_world_size() | ||
rank = dist.get_rank() | ||
groups = int(world_size / sp_size) | ||
|
||
if tp_size > 1: | ||
for group_id in range(groups): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe it should look like this?
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, I use 4 GPUs in two nodes separately in the example; see https://arxiv.org/pdf/2406.18485. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If using this config, which is: |
||
for i in range(inner_ring_size): | ||
ranks = list(range(i + (group_id * sp_size), (1 + group_id) * sp_size, inner_ring_size)) | ||
group = dist.new_group(ranks) | ||
if rank in ranks: | ||
inner_ring_group = group | ||
for group_id in range(groups): | ||
for i in range(num_rings): | ||
ranks = list(range(i + group_id * num_rings, world_size, sp_size)) | ||
group = dist.new_group(ranks) | ||
if rank in ranks: | ||
inter_ring_group = group | ||
else: | ||
for i in range(sp_size // 2): | ||
ranks = list(range((i) * num_rings, (i + 1) * num_rings, 1)) | ||
if rank in ranks: | ||
print( | ||
"rank:", | ||
rank, | ||
"inner ranks:", | ||
ranks, | ||
) | ||
group = dist.new_group(ranks) | ||
inner_ring_group = group | ||
for group_id in range(num_rings): | ||
for i in range(num_rings): | ||
ranks = list(range(i + group_id * num_rings, world_size, inner_ring_size)) | ||
ranks = [0, 1, 4, 5] if rank == 0 or rank == 1 or rank == 4 or rank == 5 else [2, 3, 6, 7] | ||
if rank in ranks: | ||
group = dist.new_group(ranks) | ||
inter_ring_group = group | ||
|
||
return inner_ring_group, inter_ring_group | ||
|
||
|
@@ -502,6 +525,7 @@ def attention( | |
deterministic=False, | ||
return_softmax=False, | ||
inner_ring_size=None, | ||
tp_group=None, | ||
**kwargs, | ||
): | ||
""" | ||
|
@@ -537,7 +561,6 @@ def attention( | |
RingAttention.ATTN_DONE = torch.cuda.Event() | ||
if RingAttention.SP_STREAM is None: | ||
RingAttention.SP_STREAM = torch.cuda.Stream() | ||
|
||
assert ( | ||
q.shape[2] == k.shape[2] | ||
), "Q, K and V having different sequence lengths (inference or cross-attn)\ | ||
|
@@ -550,7 +573,9 @@ def attention( | |
|
||
if RingAttention.SP_GROUP is not sp_group: | ||
RingAttention.SP_GROUP = sp_group | ||
inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size) | ||
inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups( | ||
sp_group, tp_group, inner_ring_size | ||
) | ||
RingAttention.INNER_RING_GROUP = inner_ring_group | ||
RingAttention.INTER_RING_GROUP = inter_ring_group | ||
else: | ||
|
@@ -597,6 +622,7 @@ def attention( | |
attention_mask_type == AttnMaskType.PADDED_CAUSAL, | ||
inner_ring_group, | ||
inter_ring_group, | ||
tp_group, | ||
) | ||
|
||
if attention_mask_type == AttnMaskType.PADDED_CAUSAL: | ||
|
@@ -627,6 +653,7 @@ def forward( | |
is_packed: Optional[bool] = False, | ||
inner_ring_group: Optional[dist.ProcessGroup] = None, | ||
inter_ring_group: Optional[dist.ProcessGroup] = None, | ||
tp_group: Optional[dist.ProcessGroup] = None, | ||
): | ||
|
||
cu_seqlens_q = cu_seqlens_kv = cu_seqlens | ||
|
@@ -1123,7 +1150,7 @@ def _other_ring_backward(ring_num_idx, dq): | |
if not is_packed: | ||
dq, dk, dv = [x.view(b, sq, *x.shape[-2:]) for x in (dq, dk, dv)] | ||
|
||
return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None) | ||
return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None) | ||
|
||
@staticmethod | ||
def prepare_varlen_batch( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you do
world_size / sp_size
, it will loop over all tp * pp * dp ranks instead of just tp ranks? There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think so, as the tp axis is the first and the sp axis is the next; see the HybridPlugin class. `groups` means the number of sp groups. If sp_size = 8, that means there are 8 GPUs in one sp group, and if you use two machines to do the training (world size = 16), then groups = 2. So it should be pp * dp ranks.