Skip to content

Commit

Permalink
fix object_to_tensor usage when torch>=2.3.0 (#5820)
Browse files Browse the repository at this point in the history
  • Loading branch information
kurisusnowdeng authored and ver217 committed Jul 15, 2024
1 parent d46c7a6 commit 3e52e19
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions colossalai/pipeline/p2p.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,11 @@ def _broadcast_object_list(
my_rank = dist.get_rank()
# Serialize object_list elements to tensors on src rank.
if my_rank == src:
if Version(torch.__version__) >= Version("1.13.0"):
if Version(torch.__version__) >= Version("2.3.0"):
tensor_list, size_list = zip(
*[c10d._object_to_tensor(obj, device=current_device, group=group) for obj in object_list]
)
elif Version(torch.__version__) >= Version("1.13.0"):
tensor_list, size_list = zip(*[c10d._object_to_tensor(obj, device=current_device) for obj in object_list])
else:
tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list])
Expand Down Expand Up @@ -276,7 +280,11 @@ def _send_recv_serialization_object(
send_object_tensor = None
send_object_size_tensor = None
if object is not None and send_dst is not None:
if Version(torch.__version__) >= Version("1.13.0"):
if Version(torch.__version__) >= Version("2.3.0"):
send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(
object, device=current_device, group=send_group
)
elif Version(torch.__version__) >= Version("1.13.0"):
send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object, device=current_device)
else:
send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object)
Expand Down

0 comments on commit 3e52e19

Please sign in to comment.