[Pallas TPU] Consolidate OpsExtraTest into OpsTest
Historically, tests that only ran on GPUs were placed in `OpsExtraTest`, while general tests were in `OpsTest`. However, this separation may cause us to miss issues that should be addressed on TPUs as well. Going forward, all tests will be unified in `OpsTest`, and any tests that fail on TPUs will be skipped individually using `skipTest`. This will help us better track and address TPU-specific failures.

PiperOrigin-RevId: 679493740
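
For context, the per-test skip convention described above looks roughly like the sketch below. This is an illustrative example rather than a test taken from the diff: `test_some_op` is a hypothetical name, and it assumes the imports (`functools`, `jax`, `jax.numpy as jnp`, `numpy as np`, `jax._src.test_util as jtu`, `parameterized`) and the `self.pallas_call` helper already used throughout `tests/pallas/ops_test.py`.

  # Hypothetical test following the skip-individually convention of this commit.
  @parameterized.parameters(jnp.float32, jnp.float16, jnp.int64)
  def test_some_op(self, dtype):
    # Skip 64-bit dtypes when x64 mode is off (applies on every backend).
    if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8:
      self.skipTest("64-bit types require x64_enabled")

    # Skip only the TPU-specific gap instead of skipping the whole class.
    if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
      self.skipTest("16-bit types are not supported on TPU")

    @functools.partial(
        self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), dtype), grid=1
    )
    def kernel(x_ref, o_ref):
      o_ref[...] = jnp.abs(x_ref[...])

    x = jnp.arange(-4, 4).astype(dtype)
    np.testing.assert_array_equal(kernel(x), jnp.abs(x))

On TPU this skips only the 16-bit cases; everything else still runs, which is what lets TPU-specific failures surface instead of being hidden behind a class-level skip.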
ayaka14732 authored and Google-ML-Automation committed Sep 30, 2024
1 parent ff1c2ac commit 66fb579
Showing 1 changed file with 130 additions and 47 deletions.
177 changes: 130 additions & 47 deletions tests/pallas/ops_test.py
@@ -707,16 +707,10 @@ def run(interpret=False):
for value in values
)
def test_sign(self, dtype, value):
-    if (
-        not jax.config.x64_enabled
-        and dtype in (jnp.uint64, jnp.int64, jnp.float64)
-    ):
+    if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8:
self.skipTest("64-bit types require x64_enabled")

-    if (
-        jtu.test_device_matches(["tpu"])
-        and dtype in (jnp.uint16, jnp.int16, jnp.bfloat16, jnp.float16)
-    ):
+    if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
self.skipTest("16-bit types are not supported on TPU")

@functools.partial(
@@ -753,37 +747,6 @@ def kernel(x_ref, o_ref):
expected = lax.erf_inv(x)
np.testing.assert_array_equal(out, expected)


-class OpsInterpretTest(OpsTest):
-  INTERPRET = True
-
-  def test_debug_print(self):
-    @functools.partial(
-        self.pallas_call,
-        out_shape=jax.ShapeDtypeStruct((2,), jnp.float32),
-        grid=1,
-    )
-    def kernel(x_ref, o_ref):
-      jax.debug.print("x = {}", x_ref)
-
-    x = jnp.array([4.2, 2.4]).astype(jnp.float32)
-    with jtu.capture_stdout() as output:
-      jax.block_until_ready(kernel(x))
-      jax.effects_barrier()
-
-    self.assertIn("x = [4.2 2.4]", output())
-
-
-class OpsExtraTest(PallasBaseTest):
-  """These are additional ops tests that have not been ported to TPU yet."""
-  # TODO: fix these for TPU and merge with OpsTest.
-
-  def setUp(self):
-    super().setUp()
-    if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
-      # TODO: most tests fail on TPU in non-interpret mode
-      self.skipTest("On TPU the test works only in interpret mode")

ELEMENTWISE_OPS = [
(
[jnp.abs, jnp.negative],
@@ -811,6 +774,18 @@ def setUp(self):
for fn, dtype in itertools.product(*args)
)
def test_elementwise(self, fn, dtype):
if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8:
self.skipTest("64-bit types require x64_enabled")

if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
self.skipTest("16-bit types are not supported on TPU")

# TODO(ayx): implement these lowerings on TPU
if jtu.test_device_matches(["tpu"]) and fn in (
jnp.acosh, jnp.asin, jnp.atanh, jnp.cbrt, jnp.cos, jnp.tan,
):
self.skipTest(f"{fn.__name__} not implemented on TPU")

@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), dtype), grid=1
)
@@ -841,6 +816,9 @@ def kernel(x_ref, o_ref):
("float64", "float64"),
)
def test_pow(self, x_dtype, y_dtype):
if not jax.config.x64_enabled and jnp.dtype(x_dtype).itemsize == 8:
self.skipTest("64-bit types require x64_enabled")

@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), x_dtype), grid=1
)
@@ -867,8 +845,13 @@ def kernel(x_ref, o_ref):

@parameterized.parameters("float32", "float64")
def test_nextafter(self, dtype):
-    if jtu.test_device_matches(["tpu"]) and dtype == "float64":
-      self.skipTest("float64 disabled on TPU.")
+    if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8:
+      self.skipTest("64-bit types require x64_enabled")
+
+    # TODO: implement this on TPU
+    if jtu.test_device_matches(["tpu"]):
+      self.skipTest("Not implemented: nextafter")

@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), dtype), grid=1
)
@@ -901,6 +884,13 @@ def test_comparison(self, fn, dtype):
if jtu.test_device_matches(["gpu"]) and dtype == "bool":
self.skipTest("Not implemented on GPU.")

if jtu.test_device_matches(["tpu"]) and dtype == "float16":
self.skipTest("float16 is not supported on TPU")

# TODO(ayx): skipped due to https://github.com/jax-ml/jax/issues/23972
if jtu.test_device_matches(["tpu"]) and dtype == "uint32":
self.skipTest("not supported on TPU")

@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.bool_),
grid=1)
@@ -979,14 +969,17 @@ def kernel(x_ref, y_ref, o_ref):
for fn, dtype in itertools.product(*args)
)
def test_binary(self, f, dtype):
if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
self.skipTest("16-bit types are not supported on TPU")

@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), dtype), grid=1
)
def kernel(x_ref, y_ref, o_ref):
o_ref[...] = f(x_ref[...], y_ref[...])

x = jnp.array([1, 3, -4, -6, 2, 5, 4, -7]).astype(dtype)
-    if (f == jnp.bitwise_left_shift):
+    if f == jnp.bitwise_left_shift:
y = jnp.array([3, 1, 4, 5, 2, 2, 2, 4]).astype(dtype)
else:
y = jnp.array([3, 1, -4, -5, 2, -2, 2, 4]).astype(dtype)
@@ -999,6 +992,9 @@ def kernel(x_ref, y_ref, o_ref):
((8, 16, 2), jnp.int8, 1),
)
def test_broadcasted_iota(self, shape, dtype, dimension):
if jtu.test_device_matches(["tpu"]):
self.skipTest("Only 32-bit integer iota supported")

f = lambda: jax.lax.broadcasted_iota(dtype, shape, dimension)

@functools.partial(
@@ -1011,8 +1007,12 @@ def kernel(o_ref):

@parameterized.parameters("float16", "bfloat16", "float32")
def test_approx_tanh(self, dtype):
if jtu.test_device_matches(["tpu"]):
self.skipTest("Not implemented on TPU")

if self.INTERPRET:
self.skipTest("approx_tanh is not supported in interpret mode")

if (dtype == "bfloat16" and
not jtu.is_cuda_compute_capability_at_least("9.0")):
self.skipTest("tanh.approx.bf16 requires a GPU with capability >= sm90")
@@ -1034,6 +1034,9 @@ def kernel(x_ref, o_ref):
)

def test_elementwise_inline_asm(self):
if jtu.test_device_matches(["tpu"]):
self.skipTest("Not implemented: elementwise_inline_asm_p")

if self.INTERPRET:
self.skipTest(
"elementwise_inline_asm is not supported in interpret mode"
@@ -1127,6 +1130,9 @@ def kernel(x_ref, o_ref):
((64,), (32, 2)),
)
def test_reshape(self, in_shape, out_shape):
if jtu.test_device_matches(["tpu"]):
self.skipTest("Not supported on TPU")

@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32),
@@ -1156,6 +1162,10 @@ def f(x_ref, o_ref):
# fmt: on
)
def test_reshape_noop_or_singleton_dims(self, in_shape, out_shape):
# Unsupported implicit dim change: from "32,{0,0},(2,128),-1" to none
if jtu.test_device_matches(["tpu"]):
self.skipTest("Not supported on TPU")

@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32),
@@ -1182,6 +1192,10 @@ def kernel(o_ref):
)

def test_where_broadcasting(self):
# The Pallas TPU lowering currently supports only blocks of rank >= 1
if jtu.test_device_matches(["tpu"]):
self.skipTest("Not supported on TPU")

@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct((4, 2, 2), floatx),
@@ -1207,6 +1221,10 @@ def copyitem(x_ref, in_idx_ref, out_idx_ref, o_ref):
((), (2, 2), ()),
)
def test_broadcast_in_dim(self, in_shape, out_shape, dims):
# The Pallas TPU lowering currently supports only blocks of rank >= 1
if jtu.test_device_matches(["tpu"]):
self.skipTest("Not supported on TPU")

@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32),
@@ -1227,6 +1245,12 @@ def f(x_ref, o_ref):
trans_y=[False, True],
)
def test_dot(self, size, dtype, trans_x, trans_y):
if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
self.skipTest("16-bit types are not supported on TPU")

if jtu.test_device_matches(["tpu"]):
self.skipTest("Not implemented: Transposed LHS")

@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct((size, size), dtype),
@@ -1249,6 +1273,9 @@ def dot(x_ref, y_ref, o_ref):
block_size=[1, 2, 32, 64, 128],
)
def test_masked_load_store(self, size, block_size):
if jtu.test_device_matches(["tpu"]):
self.skipTest("Not implemented")

@functools.partial(
self.pallas_call,
out_shape=(jax.ShapeDtypeStruct((size,), floatx)),
@@ -1290,15 +1317,18 @@ def test_strided_load(self):
# Reproducer from https://github.com/jax-ml/jax/issues/20895.
@functools.partial(
self.pallas_call,
-        out_shape=jax.ShapeDtypeStruct((4,), jnp.float32),
+        out_shape=jax.ShapeDtypeStruct((4, 4), jnp.float32),
)
def kernel(x_ref, o_ref):
o_ref[...] = x_ref[::4]

-    x = jnp.arange(16, dtype=jnp.float32)
+    x = jnp.arange(64, dtype=jnp.float32).reshape((16, 4))
np.testing.assert_array_equal(kernel(x), x[::4])

def test_broadcasted_load_store(self):
if jtu.test_device_matches(["tpu"]):
self.skipTest("Unimplemented primitive: broadcast_to")

m, n = 16, 32

@functools.partial(
@@ -1320,6 +1350,10 @@ def load(x_ref, o_ref):
((16, 32), (16, 16)),
)
def test_invalid_broadcasted_load(self, x_shape, mask_shape):
# The Pallas TPU lowering currently supports only blocks of rank >= 1
if jtu.test_device_matches(["tpu"]):
self.skipTest("Not supported on TPU")

if self.INTERPRET:
self.skipTest("No broadcasting checks in pl.load in interpret mode")

@@ -1342,6 +1376,10 @@ def kernel(x_ref, mask_ref, o_ref):
self.fail("Expected exception due to invalid broadcasting")

def test_swap(self):
# TODO: skipped due to https://github.com/jax-ml/jax/issues/24023
if jtu.test_device_matches(["tpu"]):
self.skipTest("Not supported on TPU")

m, n = 16, 32

@functools.partial(
@@ -1421,6 +1459,10 @@ def masked_oob_swap_slice(_, _2, mask_ref, start_idx_ref, x_ref, y_ref):
("min_f32", pl.atomic_min, np.array([1, 2, 3, 4], np.float32), np.min),
)
def test_scalar_atomic(self, op, value, numpy_op):
# The Pallas TPU lowering currently supports only blocks of rank >= 1
if jtu.test_device_matches(["tpu"]):
self.skipTest("Not supported on TPU")

@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct((), value.dtype),
@@ -1452,6 +1494,9 @@ def atomic_kernel(x_ref, _, o_ref):

@parameterized.parameters((0,), (1,))
def test_array_atomic_add(self, axis):
if jtu.test_device_matches(["tpu"]):
self.skipTest("Unimplemented primitive: broadcast_to")

m, n = 32, 8
if axis == 0:
grid = m
@@ -1489,6 +1534,10 @@ def reduce(x_ref, _, y_ref):
(2, 1, 1),
)
def test_atomic_cas(self, init_value, cmp, new_value):
# The Pallas TPU lowering currently supports only blocks of rank >= 1
if jtu.test_device_matches(["tpu"]):
self.skipTest("Not supported on TPU")

if jax.config.x64_enabled and jtu.test_device_matches(["gpu"]):
self.skipTest("Not supported on GPU in 64-bit mode")

@@ -1507,6 +1556,10 @@ def swap(_, lock_ref, out_ref):

@parameterized.parameters(1, 2, 3, 4, 8)
def test_atomic_counter(self, num_threads):
# The Pallas TPU lowering currently supports only blocks of rank >= 1
if jtu.test_device_matches(["tpu"]):
self.skipTest("Not supported on TPU")

if self.INTERPRET:
self.skipTest("While loop not supported in interpret mode.")

@@ -1532,6 +1585,10 @@ def _cond(_):

@parameterized.parameters(False, True)
def test_reduce_only_dim(self, use_store):
# The Pallas TPU lowering currently supports only blocks of rank >= 1
if jtu.test_device_matches(["tpu"]):
self.skipTest("Not supported on TPU")

m = 32
x = random.normal(random.key(0), (m,), dtype=jnp.float32)
out_shape = jax.ShapeDtypeStruct((), x.dtype)
@@ -1573,9 +1630,10 @@ def reduce(x_ref, y_ref):
if isinstance(axis, int) or "arg" not in op_name
])
def test_array_reduce(self, op, dtype, axis):
-    m, n = 32, 8
+    if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
+      self.skipTest("16-bit types are not supported on TPU")

-    if not jax.config.x64_enabled and dtype in ("float64", "int64", "uint64"):
+    if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8:
self.skipTest("64-bit types require x64_enabled")

# Skip argmin/argmax on GPU in 64-bit mode because Pallas expects
Expand All @@ -1587,6 +1645,12 @@ def test_array_reduce(self, op, dtype, axis):
):
self.skipTest("Not supported on GPU in 64-bit mode")

# The Pallas TPU lowering currently supports only blocks of rank >= 1
if jtu.test_device_matches(["tpu"]):
self.skipTest("Not supported on TPU")

m, n = 32, 8

def make_x(key):
if jnp.issubdtype(dtype, jnp.integer):
return random.permutation(
@@ -1623,6 +1687,9 @@ def reduce(x_ref, y_ref):
dtype=["float16", "float32", "int32", "uint32"],
)
def test_cumsum(self, dtype, axis):
if jtu.test_device_matches(["tpu"]):
self.skipTest("Not implemented on TPU")

m, n = 32, 8
out_dtype = dtype

@@ -1649,9 +1716,25 @@ def reduce(x_ref, y_ref):
np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2, err_msg=i)


-class OpsExtraInterpretTest(OpsExtraTest):
+class OpsInterpretTest(OpsTest):
INTERPRET = True

def test_debug_print(self):
@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct((2,), jnp.float32),
grid=1,
)
def kernel(x_ref, o_ref):
jax.debug.print("x = {}", x_ref)

x = jnp.array([4.2, 2.4]).astype(jnp.float32)
with jtu.capture_stdout() as output:
jax.block_until_ready(kernel(x))
jax.effects_barrier()

self.assertIn("x = [4.2 2.4]", output())


class PallasPrimitivesTest(PallasBaseTest):

