[Fix] fix the 2d ring attn when using multiple machine #6071
base: main
Conversation
```python
inter_ring_group = group
world_size = dist.get_world_size()
rank = dist.get_rank()
groups = int(world_size / sp_size)
```
If you do `world_size / sp_size`, it will loop over all tp * pp * dp ranks instead of just tp ranks?
I don't think so: the tp axis is first and the sp axis comes next, as in the HybridPlugin class. `groups` means the number of sp groups. If sp_size = 8, there are 8 GPUs in one sp group, and if you use two machines for training (world_size = 16), then groups = 2. So it loops over the pp * dp ranks.
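A minimal sketch of the arithmetic above (the numbers are the ones from the comment; names are illustrative, not from the PR's code):

```python
# With sp_size GPUs per sequence-parallel group, world_size // sp_size
# gives the number of sp groups, i.e. the remaining pp * dp ranks.
world_size = 16  # two machines, 8 GPUs each
sp_size = 8      # 8 GPUs in one sp group

groups = world_size // sp_size
print(groups)  # 2 sp groups, one per machine
```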
```python
groups = int(world_size / sp_size)

if tp_size > 1:
    for group_id in range(groups):
```
Maybe it should look like this?
```python
world_size = 16
sp_size = 4
outer_ring_size = 2
inner_ring_size = sp_size // outer_ring_size
tp_size = 2
total_inner_rings = world_size // (outer_ring_size * tp_size)  # loop through groups of size (inner_ring * tp)
total_inner_size = tp_size * inner_ring_size
for j in range(total_inner_rings):
    # inside each group, duplicate tp group inner_sp_size times
    for k in range(tp_size):
        print(f"inner ring ranks: {list(range(k + j * total_inner_size, (j + 1) * total_inner_size, tp_size))}")
    print("---------------------------------")

sp_tp_size = total_inner_size * outer_ring_size
n_groups = world_size // sp_tp_size  # dp * pp
print(f"n_groups: {n_groups}")
for i in range(n_groups):
    start = i * sp_tp_size
    end = (i + 1) * sp_tp_size
    for j in range(outer_ring_size):
        for k in range(tp_size):
            print(f"inter ring ranks: {list(range(start + k + j * tp_size, end, total_inner_size))}")
```
Yes, in the example I use 4 GPUs on each of two nodes; see https://arxiv.org/pdf/2406.18485.
However, I believe there is an issue with the current algorithm, as such a rank grouping cannot proceed to the next communication step. Additionally, the CI tests only cover 4 GPUs, so this PR may need to be closed for now.
If we use this config:
inner group: [0, 1], [2, 3], [4, 5], [6, 7]
inter group: [0, 4], [1, 5], [2, 6], [3, 7]
the communication can proceed and training works, so it seems the current algorithm requires the number of inner groups to equal the number of inter groups.
However, that is wrong: the paper does not impose such a requirement; the inner group size can be set to 4 and the inter group size to 2.
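The working config above can be sketched as follows (a hypothetical reconstruction of the grouping pattern, not the PR's actual code):

```python
# Inner groups are contiguous pairs of ranks; inter groups pair each
# rank r with r + world_size // 2 on the other node.
world_size = 8
inner_ring_size = 2

inner_groups = [list(range(s, s + inner_ring_size))
                for s in range(0, world_size, inner_ring_size)]
inter_groups = [[r, r + world_size // 2] for r in range(world_size // 2)]
print(inner_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(inter_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```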
🚨 Issue number
fixed #6017
📝 What does this PR do?
The double_ring_groups need to take the tp groups into account, since the tp axis is the first axis.
Also, the ranks in double_ring_groups need to be transformed into global ranks.
For example, when using the first four cards of two machines (eight cards in total) for ring attention, the ranks of the inner ring groups would be [0, 2], [1, 3], [4, 6], [5, 7], while the ranks of the inter ring groups would be [0, 4], [1, 5], [2, 6], [3, 7].
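The global ranks in this example can be reproduced with a short sketch (variable names are illustrative, chosen to match the 2 machines x 4 GPUs, tp_size = 2 setup described above):

```python
# 2 machines, 4 GPUs used per machine, tp axis first with tp_size = 2.
# Inner rings stay within a machine and stride over the tp axis;
# inter rings pair each rank with its counterpart on the other machine.
tp_size = 2
node_size = 4    # GPUs used per machine
num_nodes = 2
world_size = node_size * num_nodes

inner_groups = [[base + k, base + k + tp_size]
                for base in range(0, world_size, node_size)
                for k in range(tp_size)]
inter_groups = [[r, r + node_size] for r in range(node_size)]
print(inner_groups)  # [[0, 2], [1, 3], [4, 6], [5, 7]]
print(inter_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```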
Results: