Skip to content

Commit

Permalink
change tester name
Browse files Browse the repository at this point in the history
  • Loading branch information
Edenzzzz committed Jul 22, 2024
1 parent 7551bc6 commit f326884
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/test_shardformer/test_layer/test_ring_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
@parameterize("nheads", [5])
@parameterize("d", [128])
@parameterize("dtype", [torch.bfloat16])
def test_ring_attn(seq_len, batch_size, nheads, d, dtype):
def check_ring_attn(seq_len, batch_size, nheads, d, dtype):
torch.cuda.manual_seed(2)
rank = dist.get_rank()
world_size = dist.get_world_size()
Expand Down Expand Up @@ -59,13 +59,13 @@ def test_ring_attn(seq_len, batch_size, nheads, d, dtype):

def launch(rank, world_size, port):
colossalai.launch(rank, world_size, "localhost", port)
test_ring_attn()
check_ring_attn()


@rerun_if_address_is_in_use()
def run_ring_attn():
def test_ring_attn():
spawn(launch, nprocs=8)


if __name__ == "__main__":
run_ring_attn()
test_ring_attn()

0 comments on commit f326884

Please sign in to comment.