Skip to content

Commit

Permalink
Move additional CI enabled/disabled configurations into jax BUILD files.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 684457403
  • Loading branch information
hawkinsp authored and Google-ML-Automation committed Oct 10, 2024
1 parent aa3254d commit 19dbff5
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 4 deletions.
51 changes: 48 additions & 3 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ jax_generate_backend_suites()
jax_multiplatform_test(
name = "api_test",
srcs = ["api_test.py"],
enable_configs = ["tpu_v3_2x2"],
shard_count = 10,
)

Expand Down Expand Up @@ -70,6 +71,9 @@ jax_multiplatform_test(
"cpu",
"gpu",
],
enable_configs = [
"gpu_2gpu",
],
tags = ["multiaccelerator"],
deps = py_deps("tensorflow_core"),
)
Expand Down Expand Up @@ -214,6 +218,14 @@ jax_py_test(
jax_multiplatform_test(
name = "memories_test",
srcs = ["memories_test.py"],
enable_configs = [
"cpu",
"gpu_2gpu",
"tpu_v3_2x2",
"tpu_v4_2x2",
"tpu_v5p_2x2",
"tpu_v5e_4x2",
],
shard_count = {
"tpu": 5,
},
Expand All @@ -234,6 +246,8 @@ jax_multiplatform_test(
"gpu_2gpu_shardy",
"tpu_v3_2x2_shardy",
"tpu_v4_2x2_shardy",
"tpu_v3_2x2",
"gpu_2gpu",
],
shard_count = {
"cpu": 5,
Expand All @@ -258,6 +272,11 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "shard_alike_test",
srcs = ["shard_alike_test.py"],
enable_configs = [
"tpu_v3_2x2",
"tpu_v5e_4x2",
"tpu_v4_2x2",
],
deps = [
"//jax:experimental",
],
Expand Down Expand Up @@ -298,6 +317,9 @@ jax_multiplatform_test(
backend_tags = {
"tpu": ["requires-mem:16g"], # Under tsan on 2x2 this test exceeds the default 12G memory limit.
},
enable_configs = [
"tpu_v3_2x2",
],
tags = ["multiaccelerator"],
deps = [
"//jax:experimental",
Expand Down Expand Up @@ -644,6 +666,10 @@ jax_py_test(
jax_multiplatform_test(
name = "multibackend_test",
srcs = ["multibackend_test.py"],
enable_configs = [
"tpu_v3_2x2",
"gpu_2gpu",
],
)

jax_multiplatform_test(
Expand Down Expand Up @@ -693,6 +719,10 @@ jax_multiplatform_test(
"requires-mem:16g", # Under tsan on 2x2 this test exceeds the default 12G memory limit.
],
},
enable_configs = [
"gpu_v100",
"tpu_v3_2x2",
],
shard_count = {
"cpu": 30,
"gpu": 30,
Expand Down Expand Up @@ -1030,6 +1060,7 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "checkify_test",
srcs = ["checkify_test.py"],
enable_configs = ["tpu_v3_2x2"],
shard_count = {
"gpu": 2,
"tpu": 4,
Expand Down Expand Up @@ -1187,8 +1218,11 @@ jax_multiplatform_test(
"gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143
},
enable_configs = [
"gpu_h100",
"cpu",
"gpu_h100",
"tpu_v2_1x1",
"tpu_v3_2x2",
"tpu_v4_2x2",
],
tags = ["multiaccelerator"],
)
Expand All @@ -1197,8 +1231,11 @@ jax_multiplatform_test(
name = "debugging_primitives_test",
srcs = ["debugging_primitives_test.py"],
enable_configs = [
"gpu_h100",
"cpu",
"gpu_h100",
"tpu_v2_1x1",
"tpu_v3_2x2",
"tpu_v4_2x2",
],
)

Expand All @@ -1208,6 +1245,11 @@ jax_multiplatform_test(
backend_tags = {
"gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143
},
enable_configs = [
"tpu_v2_1x1",
"tpu_v3_2x2",
"tpu_v4_2x2",
],
tags = ["multiaccelerator"],
deps = [
"//jax:experimental",
Expand All @@ -1218,8 +1260,11 @@ jax_multiplatform_test(
name = "debugger_test",
srcs = ["debugger_test.py"],
enable_configs = [
"gpu_h100",
"cpu",
"gpu_h100",
"tpu_v2_1x1",
"tpu_v3_2x2",
"tpu_v4_2x2",
],
)

Expand Down
21 changes: 20 additions & 1 deletion tests/pallas/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,9 @@ jax_multiplatform_test(
"tpu_all_gather_test.py",
],
enable_backends = ["tpu"],
enable_configs = [
"tpu_v5e_4x2",
],
deps = [
"//jax:pallas_tpu_ops",
] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
Expand Down Expand Up @@ -277,6 +280,10 @@ jax_multiplatform_test(
# The flag is necessary for ``pl.debug_print`` tests to work on TPU.
args = ["--logtostderr"],
enable_backends = ["tpu"],
enable_configs = [
"tpu_v5e",
"tpu_v5p_1x1",
],
deps = [
"//jax:extend",
"//jax:pallas_tpu",
Expand Down Expand Up @@ -305,6 +312,12 @@ jax_multiplatform_test(
name = "tpu_pallas_distributed_test",
srcs = ["tpu_pallas_distributed_test.py"],
enable_backends = ["tpu"],
enable_configs = [
"tpu_v5e_4x2",
"tpu_v5p_2x2",
"tpu_v4_2x2",
"tpu_v3_2x2",
],
deps = [
"//jax:extend",
"//jax:pallas_tpu",
Expand All @@ -316,6 +329,10 @@ jax_multiplatform_test(
name = "tpu_pallas_pipeline_test",
srcs = ["tpu_pallas_pipeline_test.py"],
enable_backends = ["tpu"],
enable_configs = [
"tpu_v5e_4x2",
"tpu_v5p_1x1",
],
shard_count = 5,
tags = [
"noasan", # Times out.
Expand All @@ -333,7 +350,9 @@ jax_multiplatform_test(
name = "tpu_pallas_async_test",
srcs = ["tpu_pallas_async_test.py"],
enable_backends = ["tpu"],
tags = [
enable_configs = [
"tpu_v5e_4x2",
"tpu_v5p_1x1",
],
deps = [
"//jax:pallas_tpu",
Expand Down

0 comments on commit 19dbff5

Please sign in to comment.