[Fix] fix the 2d ring attn when using multiple machines #6071

Open
wangbluo wants to merge 6 commits into main from ring_attention
Conversation

wangbluo (Contributor)
🚨 Issue number

fixed #6017

📝 What does this PR do?

The double_ring_groups need to take the tp groups into account, as the tp axis is the first axis. In addition, the ranks in double_ring_groups need to be transformed into global ranks.

For example, if we use the first four cards of each of two machines (eight cards in total) for ring attention, the inner ring groups would be [0, 2], [1, 3], [4, 6], [5, 7], while the inter ring groups would be [0, 4], [1, 5], [2, 6], [3, 7].
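As a minimal sketch (my own illustration, not the PR's code), that grouping can be reproduced with plain range arithmetic, assuming tp is the fastest-varying axis so sp neighbours sit tp_size global ranks apart; the parameter values below are assumptions chosen to match the listed groups:

# Illustrative sketch only; the variable names are assumptions, not the PR's API.
world_size = 8          # 2 machines x 4 cards
tp_size = 2
sp_size = 4
inner_ring_size = 2

inner_window = tp_size * inner_ring_size  # global ranks spanned by one inner ring, per tp rank
sp_window = tp_size * sp_size             # global ranks spanned by one full sp group, per tp rank

inner_groups = [list(range(base + k, base + inner_window, tp_size))
                for base in range(0, world_size, inner_window)
                for k in range(tp_size)]
inter_groups = [list(range(base + j * tp_size + k, base + sp_window, inner_window))
                for base in range(0, world_size, sp_window)
                for j in range(inner_ring_size)
                for k in range(tp_size)]

print(inner_groups)  # [[0, 2], [1, 3], [4, 6], [5, 7]]
print(inter_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]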

Results: [screenshot]

wangbluo requested a review from a team as a code owner on September 25, 2024
wangbluo closed this on September 26, 2024
wangbluo deleted the ring_attention branch on September 26, 2024
wangbluo restored the ring_attention branch on September 26, 2024
wangbluo reopened this on September 26, 2024
inter_ring_group = group
world_size = dist.get_world_size()
rank = dist.get_rank()
groups = int(world_size / sp_size)
Edenzzzz (Contributor) commented on Oct 4, 2024:
If you do world_size / sp_size, it will loop over all tp * pp * dp ranks instead of just tp ranks?

wangbluo (Contributor, Author) replied:

I don't think so: the tp axis is the first and the sp axis is the next, see the HybridParallelPlugin class. groups means the number of sp groups. If sp_size = 8, there are 8 GPUs in one sp group, and if you use two machines for training (world size = 16), then groups = 2. So it loops over the pp * dp ranks.
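A quick sanity check of that arithmetic (illustrative values only; tp_size = 1 is an assumption for this example):

# groups counts how many sp groups fit in the world, i.e. one per
# pp * dp replica when tp_size = 1 (assumed here for simplicity).
world_size = 16  # two 8-GPU machines
sp_size = 8      # 8 GPUs per sp group
groups = world_size // sp_size
print(groups)    # 2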

groups = int(world_size / sp_size)

if tp_size > 1:
    for group_id in range(groups):
Edenzzzz (Contributor) commented on Oct 4, 2024:

Maybe it should look like this?

world_size = 16
sp_size = 4
outer_ring_size = 2
inner_ring_size = sp_size // outer_ring_size
tp_size = 2
total_inner_rings = world_size // (outer_ring_size * tp_size) # loop through groups of size (inner_ring * tp) 
total_inner_size = tp_size * inner_ring_size
for j in range(total_inner_rings):
    # inside each group, duplicate tp group inner_sp_size times
    for k in range(tp_size):
        print(f"inner ring ranks: {list(range(k + j * total_inner_size, (j + 1) * total_inner_size, tp_size))}")
print("---------------------------------")


sp_tp_size = total_inner_size * outer_ring_size
n_groups = world_size // sp_tp_size # dp * pp
print(f"n_groups: {n_groups}")
for i in range(0, n_groups):
    start = i * sp_tp_size
    end = (i + 1) * sp_tp_size
    for j in range(outer_ring_size):    
        for k in range(tp_size):
            print(f"inter ring ranks: {list(range(start + k + j * tp_size, end, total_inner_size))}")

wangbluo (Contributor, Author) replied:

Yes, in the example I use 4 GPUs on each of two nodes; see https://arxiv.org/pdf/2406.18485.
However, I believe there is an issue with the current algorithm: such a rank grouping cannot proceed to the next communication step. Additionally, in the CI tests the author only tested with 4 GPUs, so this PR may need to be closed for now.

wangbluo (Contributor, Author) commented on Oct 7, 2024:

If we use this config:
inner groups: [0, 1], [2, 3], [4, 5], [6, 7]
inter groups: [0, 4], [1, 5], [2, 6], [3, 7]
the communication can proceed and training runs, so it seems the current algorithm requires the number of inner groups to equal the number of inter groups.
However, that is wrong: the paper does not impose such a requirement; the inner ring size can be set to 4 and the inter ring size to 2, as sketched below.
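A minimal sketch of that paper-style decomposition (my own illustration, not code from this PR; assumes tp_size = 1 and sp_size = world_size = 8):

# Hypothetical helper, not from this PR: every rank joins exactly one inner
# ring of size inner_ring_size and one inter ring of size sp_size // inner_ring_size.
def double_ring_groups(sp_size, inner_ring_size):
    inner = [list(range(b, b + inner_ring_size))
             for b in range(0, sp_size, inner_ring_size)]
    inter = [list(range(k, sp_size, inner_ring_size))
             for k in range(inner_ring_size)]
    return inner, inter

inner, inter = double_ring_groups(8, 4)
print(inner)  # [[0, 1, 2, 3], [4, 5, 6, 7]]    -> inner ring size 4
print(inter)  # [[0, 4], [1, 5], [2, 6], [3, 7]] -> inter ring size 2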
