[Fix] fix the 2d ring attn when using multiple machines #6071

Open · wants to merge 6 commits into base: main
63 changes: 45 additions & 18 deletions colossalai/shardformer/layer/attn.py
@@ -431,7 +431,7 @@ class RingAttention(torch.autograd.Function):
     INTER_RING_GROUP_COPY: dist.ProcessGroup = None
 
     @staticmethod
-    def get_double_ring_groups(sp_group, inner_ring_size=None):
+    def get_double_ring_groups(sp_group, tp_group, inner_ring_size=None):
         """
         Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size
         shouldn't be larger than the number of NICs on each node.
@@ -442,7 +442,7 @@ def get_double_ring_groups(sp_group, inner_ring_size=None):
             Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group.
         """
         sp_size = dist.get_world_size(sp_group)
-        sp_rank = dist.get_rank(sp_group)
+        tp_size = dist.get_world_size(tp_group)
 
         if inner_ring_size is None:
             if torch.cuda.device_count() >= dist.get_world_size():
@@ -471,19 +471,42 @@ def get_double_ring_groups(sp_group, inner_ring_size=None):
         inner_ring_group = None
         inter_ring_group = None
 
-        # Create inner ring groups
-        for i in range(inner_ring_size):
-            ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
-            group = dist.new_group(ranks)
-            if sp_rank in ranks:
-                inner_ring_group = group
-
-        # Create inter ring groups
-        for i in range(num_rings):
-            ranks = list(range(i, sp_size, num_rings))
-            group = dist.new_group(ranks)
-            if sp_rank in ranks:
-                inter_ring_group = group
+        world_size = dist.get_world_size()
+        rank = dist.get_rank()
+        groups = int(world_size / sp_size)
Edenzzzz (Contributor) commented on Oct 4, 2024:

If you do world_size / sp_size, it will loop over all tp * pp * dp ranks instead of just tp ranks?

wangbluo (Contributor, Author) replied:

I don't think so: the tp axis comes first and the sp axis next (see the HybridParallelPlugin class). groups is the number of sp groups. If sp_size = 8, one sp group holds 8 GPUs; if you train on two machines (world_size = 16), then groups = 2. So the loop covers the pp * dp ranks.
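A minimal standalone sketch of the layout described above (sizes assumed, tp_size = 1 for simplicity, plain Python rather than ColossalAI code):

# Assumed example: tp is the fastest-varying axis, then sp; with tp_size = 1,
# each sp group is a contiguous block of ranks, and world_size // sp_size
# counts the remaining (pp * dp) groups.
world_size = 16  # two 8-GPU machines
sp_size = 8      # 8 GPUs per sp group
groups = world_size // sp_size  # -> 2

for g in range(groups):
    print(f"sp group {g}: {list(range(g * sp_size, (g + 1) * sp_size))}")
# sp group 0: [0, 1, 2, 3, 4, 5, 6, 7]
# sp group 1: [8, 9, 10, 11, 12, 13, 14, 15]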


+        if tp_size > 1:
+            for group_id in range(groups):
Edenzzzz (Contributor) commented on Oct 4, 2024:

Maybe it should look like this?
world_size = 16
sp_size = 4
outer_ring_size = 2
inner_ring_size = sp_size // outer_ring_size
tp_size = 2
total_inner_rings = world_size // (outer_ring_size * tp_size) # loop through groups of size (inner_ring * tp) 
total_inner_size = tp_size * inner_ring_size
for j in range(total_inner_rings):
    # inside each group, duplicate tp group inner_sp_size times
    for k in range(tp_size):
        print(f"inner ring ranks: {list(range(k + j * total_inner_size, (j + 1) * total_inner_size, tp_size))}")
print("---------------------------------")


sp_tp_size = total_inner_size * outer_ring_size
n_groups = world_size // sp_tp_size # dp * pp
print(f"n_groups: {n_groups}")
for i in range(0, n_groups):
    start = i * sp_tp_size
    end = (i + 1) * sp_tp_size
    for j in range(outer_ring_size):    
        for k in range(tp_size):
            print(f"inter ring ranks: {list(range(start + k + j * tp_size, end, total_inner_size))}")
[two screenshots: output of the snippet above, printing the inner and inter ring ranks]

wangbluo (Contributor, Author) replied:

Yes, in the example I use 4 GPUs on each of two nodes; see https://arxiv.org/pdf/2406.18485.
However, I believe there is an issue with the current algorithm: such a rank grouping cannot proceed to the next step of communication. Additionally, the CI tests only cover 4 GPUs, so this PR may need to be closed for now.

wangbluo (Contributor, Author) commented on Oct 7, 2024:

With this config:
inner groups: [0, 1], [2, 3], [4, 5], [6, 7]
inter groups: [0, 4], [1, 5], [2, 6], [3, 7]
communication proceeds and training works, so it seems the current algorithm requires the number of inner groups to equal the number of inter groups. That requirement is wrong, though: the paper does not impose it; for example, the inner ring size can be 4 while the inter ring size is 2.
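A small standalone sketch (assumed sp_size = 8, no process groups created) that enumerates exactly this configuration:

# Sketch of the grouping above: inner rings of size 2 within a node,
# inter rings striding across the sp group (values assumed, standalone).
sp_size = 8
inner_ring_size = 2
num_rings = sp_size // inner_ring_size  # 4

inner = [list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
         for i in range(num_rings)]
inter = [list(range(i, sp_size, num_rings)) for i in range(num_rings)]

print(inner)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(inter)  # [[0, 4], [1, 5], [2, 6], [3, 7]]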

+                for i in range(inner_ring_size):
+                    ranks = list(range(i + (group_id * sp_size), (1 + group_id) * sp_size, inner_ring_size))
+                    group = dist.new_group(ranks)
+                    if rank in ranks:
+                        inner_ring_group = group
+            for group_id in range(groups):
+                for i in range(num_rings):
+                    ranks = list(range(i + group_id * num_rings, world_size, sp_size))
+                    group = dist.new_group(ranks)
+                    if rank in ranks:
+                        inter_ring_group = group
+        else:
+            for i in range(sp_size // 2):
+                ranks = list(range((i) * num_rings, (i + 1) * num_rings, 1))
+                if rank in ranks:
+                    print(
+                        "rank:",
+                        rank,
+                        "inner ranks:",
+                        ranks,
+                    )
+                    group = dist.new_group(ranks)
+                    inner_ring_group = group
+            for group_id in range(num_rings):
+                for i in range(num_rings):
+                    ranks = list(range(i + group_id * num_rings, world_size, inner_ring_size))
+                    ranks = [0, 1, 4, 5] if rank == 0 or rank == 1 or rank == 4 or rank == 5 else [2, 3, 6, 7]
+                    if rank in ranks:
+                        group = dist.new_group(ranks)
+                        inter_ring_group = group
 
         return inner_ring_group, inter_ring_group
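For reference, a standalone trace (assumed sizes, prints instead of dist.new_group calls) of the ring enumeration in the tp_size > 1 branch above:

# Standalone trace of the tp_size > 1 branch (sizes assumed).
world_size, sp_size, inner_ring_size = 16, 8, 2
num_rings = sp_size // inner_ring_size  # 4
groups = world_size // sp_size          # 2

for group_id in range(groups):
    for i in range(inner_ring_size):
        print("inner ring:", list(range(i + group_id * sp_size,
                                        (1 + group_id) * sp_size,
                                        inner_ring_size)))
for group_id in range(groups):
    for i in range(num_rings):
        print("inter ring:", list(range(i + group_id * num_rings,
                                        world_size, sp_size)))
# inner rings: [0, 2, 4, 6], [1, 3, 5, 7], [8, 10, 12, 14], [9, 11, 13, 15]
# inter rings: [0, 8], [1, 9], ..., [7, 15]
# Note the strided inner rings; whether these are the intended groups is
# exactly what the review discussion above questions.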

@@ -502,6 +525,7 @@ def attention(
         deterministic=False,
         return_softmax=False,
         inner_ring_size=None,
+        tp_group=None,
         **kwargs,
     ):
         """
@@ -537,7 +561,6 @@ def attention(
             RingAttention.ATTN_DONE = torch.cuda.Event()
         if RingAttention.SP_STREAM is None:
             RingAttention.SP_STREAM = torch.cuda.Stream()
-
         assert (
             q.shape[2] == k.shape[2]
         ), "Q, K and V having different sequence lengths (inference or cross-attn)\
@@ -550,7 +573,9 @@ def attention(

         if RingAttention.SP_GROUP is not sp_group:
             RingAttention.SP_GROUP = sp_group
-            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size)
+            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(
+                sp_group, tp_group, inner_ring_size
+            )
             RingAttention.INNER_RING_GROUP = inner_ring_group
             RingAttention.INTER_RING_GROUP = inter_ring_group
         else:
@@ -597,6 +622,7 @@ def attention(
             attention_mask_type == AttnMaskType.PADDED_CAUSAL,
             inner_ring_group,
             inter_ring_group,
+            tp_group,
         )
 
         if attention_mask_type == AttnMaskType.PADDED_CAUSAL:
@@ -627,6 +653,7 @@ def forward(
         is_packed: Optional[bool] = False,
         inner_ring_group: Optional[dist.ProcessGroup] = None,
         inter_ring_group: Optional[dist.ProcessGroup] = None,
+        tp_group: Optional[dist.ProcessGroup] = None,
     ):
 
         cu_seqlens_q = cu_seqlens_kv = cu_seqlens
@@ -1123,7 +1150,7 @@ def _other_ring_backward(ring_num_idx, dq):
         if not is_packed:
             dq, dk, dv = [x.view(b, sq, *x.shape[-2:]) for x in (dq, dk, dv)]
 
-        return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None)
+        return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None)

     @staticmethod
     def prepare_varlen_batch(
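The extra None in the return tuple reflects the standard torch.autograd.Function contract: backward must return one value per forward() input, and the newly added tp_group needs no gradient. A generic illustration (not the PR's code):

import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, group):  # two inputs: a tensor and a non-tensor arg
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        # one return value per forward() input; None for the non-tensor arg
        return grad_out * 2, None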
3 changes: 3 additions & 0 deletions colossalai/shardformer/modeling/gpt2.py
@@ -858,6 +858,8 @@ def forward(

         sp_mode = shard_config.sequence_parallelism_mode
         sp_group = shard_config.sequence_parallel_process_group
+        tp_group = shard_config.tensor_parallel_process_group
+
         if sp_mode == "ring_attn":
             attn_output = RingAttention.attention(
                 query,
@@ -868,6 +870,7 @@ def forward(
                 dropout_p=dropout_p,
                 scale=scale,
                 inner_ring_size=shard_config.inner_ring_size,
+                tp_group=tp_group,
             )
         else:
             attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)
3 changes: 3 additions & 0 deletions colossalai/shardformer/modeling/llama.py
@@ -563,6 +563,8 @@ def forward(
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
+        tp_group = shard_config.tensor_parallel_process_group
+
         if sp_mode == "ring_attn":
             attn_output = RingAttention.attention(
                 query_states,
@@ -571,6 +573,7 @@ def forward(
                 sp_group,
                 **attention_mask,
                 inner_ring_size=shard_config.inner_ring_size,
+                tp_group=tp_group,
             )
 
         elif shard_config.enable_flash_attention: