Skip to content

Commit

Permalink
Update test_command.py to original
Browse files Browse the repository at this point in the history
  • Loading branch information
mag1cp1n committed Nov 16, 2023
1 parent 63b3053 commit fa2012a
Showing 1 changed file with 34 additions and 28 deletions.
62 changes: 34 additions & 28 deletions tests/unit/legate/driver/test_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,10 +1046,13 @@ def test_default_single_rank(self, genobjs: GenObjs) -> None:

assert result == ()

def test_utility_1_single_rank(self, genobjs: GenObjs) -> None:
def test_utility_1_single_rank_no_ucx(self, genobjs: GenObjs) -> None:
config, system, launcher = genobjs(["--utility", "1"])

networks_orig = list(install_info.networks)
install_info.networks = [x for x in networks_orig if x != "ucx"]
result = m.cmd_bgwork(config, system, launcher)
install_info.networks[:] = networks_orig[:]

assert result == ()

Expand All @@ -1064,12 +1067,15 @@ def test_utility_1_single_rank_and_ucx(self, genobjs: GenObjs) -> None:
assert result == ()

@pytest.mark.parametrize("value", ("2", "3", "10"))
def test_utiltity_n_single_rank(
def test_utiltity_n_single_rank_no_ucx(
self, genobjs: GenObjs, value: str
) -> None:
config, system, launcher = genobjs(["--utility", value])

networks_orig = list(install_info.networks)
install_info.networks = [x for x in networks_orig if x != "ucx"]
result = m.cmd_bgwork(config, system, launcher)
install_info.networks[:] = networks_orig[:]

assert result == ()

Expand All @@ -1088,19 +1094,19 @@ def test_utiltity_n_single_rank_and_ucx(

@pytest.mark.parametrize("rank_var", RANK_ENV_VARS)
@pytest.mark.parametrize("rank", ("0", "1", "2"))
def test_default_multi_rank(
def test_default_multi_rank_no_ucx(
self, genobjs: GenObjs, rank: str, rank_var: dict[str, str]
) -> None:
config, system, launcher = genobjs(
[], multi_rank=(2, 2), rank_env={rank_var: rank}
)

networks_orig = list(install_info.networks)
install_info.networks = [x for x in networks_orig if x != "ucx"]
result = m.cmd_bgwork(config, system, launcher)
install_info.networks[:] = networks_orig[:]

if "ucx" in install_info.networks:
assert result == ("-ll:bgwork", "2", "-ll:bgworkpin", "1")
else:
assert result == ("-ll:bgwork", "2")
assert result == ("-ll:bgwork", "2")

@pytest.mark.parametrize("rank_var", RANK_ENV_VARS)
@pytest.mark.parametrize("rank", ("0", "1", "2"))
Expand All @@ -1120,19 +1126,19 @@ def test_default_multi_rank_and_ucx(

@pytest.mark.parametrize("rank_var", RANK_ENV_VARS)
@pytest.mark.parametrize("rank", ("0", "1", "2"))
def test_utility_1_multi_rank_no_launcher(
def test_utility_1_multi_rank_no_launcher_no_ucx(
self, genobjs: GenObjs, rank: str, rank_var: dict[str, str]
) -> None:
config, system, launcher = genobjs(
["--utility", "1"], multi_rank=(2, 2), rank_env={rank_var: rank}
)

networks_orig = list(install_info.networks)
install_info.networks = [x for x in networks_orig if x != "ucx"]
result = m.cmd_bgwork(config, system, launcher)
install_info.networks[:] = networks_orig[:]

if "ucx" in install_info.networks:
assert result == ("-ll:bgwork", "2", "-ll:bgworkpin", "1")
else:
assert result == ("-ll:bgwork", "2")
assert result == ("-ll:bgwork", "2")

@pytest.mark.parametrize("rank_var", RANK_ENV_VARS)
@pytest.mark.parametrize("rank", ("0", "1", "2"))
Expand All @@ -1151,19 +1157,19 @@ def test_utility_1_multi_rank_no_launcher_and_ucx(
assert result == ("-ll:bgwork", "2", "-ll:bgworkpin", "1")

@pytest.mark.parametrize("launch", ("mpirun", "jsrun", "srun"))
def test_utility_1_multi_rank_with_launcher(
def test_utility_1_multi_rank_with_launcher_no_ucx(
self, genobjs: GenObjs, launch: str
) -> None:
config, system, launcher = genobjs(
["--utility", "1", "--launcher", launch], multi_rank=(2, 2)
)

networks_orig = list(install_info.networks)
install_info.networks = [x for x in networks_orig if x != "ucx"]
result = m.cmd_bgwork(config, system, launcher)
install_info.networks[:] = networks_orig[:]

if "ucx" in install_info.networks:
assert result == ("-ll:bgwork", "2", "-ll:bgworkpin", "1")
else:
assert result == ("-ll:bgwork", "2")
assert result == ("-ll:bgwork", "2")

@pytest.mark.parametrize("launch", ("mpirun", "jsrun", "srun"))
def test_utility_1_multi_rank_with_launcher_and_ucx(
Expand All @@ -1183,19 +1189,19 @@ def test_utility_1_multi_rank_with_launcher_and_ucx(
@pytest.mark.parametrize("rank_var", RANK_ENV_VARS)
@pytest.mark.parametrize("rank", ("0", "1", "2"))
@pytest.mark.parametrize("value", ("2", "3", "10"))
def test_utility_n_multi_rank_no_launcher(
def test_utility_n_multi_rank_no_launcher_no_ucx(
self, genobjs: GenObjs, value: str, rank: str, rank_var: dict[str, str]
) -> None:
config, system, launcher = genobjs(
["--utility", value], multi_rank=(2, 2), rank_env={rank_var: rank}
)

networks_orig = list(install_info.networks)
install_info.networks = [x for x in networks_orig if x != "ucx"]
result = m.cmd_bgwork(config, system, launcher)
install_info.networks[:] = networks_orig[:]

if "ucx" in install_info.networks:
assert result == ("-ll:bgwork", value, "-ll:bgworkpin", "1")
else:
assert result == ("-ll:bgwork", value)
assert result == ("-ll:bgwork", value)

@pytest.mark.parametrize("rank_var", RANK_ENV_VARS)
@pytest.mark.parametrize("rank", ("0", "1", "2"))
Expand All @@ -1216,19 +1222,19 @@ def test_utility_n_multi_rank_no_launcher_and_ucx(

@pytest.mark.parametrize("launch", ("mpirun", "jsrun", "srun"))
@pytest.mark.parametrize("value", ("2", "3", "10"))
def test_utility_n_multi_rank_with_launcher(
def test_utility_n_multi_rank_with_launcher_no_ucx(
self, genobjs: GenObjs, value: str, launch: str
) -> None:
config, system, launcher = genobjs(
["--utility", value, "--launcher", launch], multi_rank=(2, 2)
)

networks_orig = list(install_info.networks)
install_info.networks = [x for x in networks_orig if x != "ucx"]
result = m.cmd_bgwork(config, system, launcher)
install_info.networks[:] = networks_orig[:]

if "ucx" in install_info.networks:
assert result == ("-ll:bgwork", value, "-ll:bgworkpin", "1")
else:
assert result == ("-ll:bgwork", value)
assert result == ("-ll:bgwork", value)

@pytest.mark.parametrize("launch", ("mpirun", "jsrun", "srun"))
@pytest.mark.parametrize("value", ("2", "3", "10"))
Expand Down Expand Up @@ -1607,4 +1613,4 @@ def test_with_legate_opts(self, genobjs: GenObjs, opts: list[str]) -> None:


if __name__ == "__main__":
sys.exit(pytest.main(sys.argv))
sys.exit(pytest.main(sys.argv))

0 comments on commit fa2012a

Please sign in to comment.