Skip to content

Commit

Permalink
[Pallas TPU] Fix boolean comparison
Browse files Browse the repository at this point in the history
Fixes #24030 and #24027

PiperOrigin-RevId: 684198399
  • Loading branch information
ayaka14732 authored and Google-ML-Automation committed Oct 10, 2024
1 parent 351187d commit 8627b0c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 20 deletions.
23 changes: 15 additions & 8 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2148,15 +2148,22 @@ def _cmp_lowering_rule(prim, ctx: LoweringRuleContext, x, y):
)
dtype = x_aval.dtype

# Handle bool comparisons by casting to int32.
if jnp.issubdtype(dtype, jnp.bool_):
bool_cast_to = _dtype_to_ir_type(jnp.dtype("int32"))
true_ = ir_constant(1, mlir_type=bool_cast_to)
false_ = ir_constant(0, mlir_type=bool_cast_to)

x = arith.SelectOp(x, true_, false_)
y = arith.SelectOp(y, true_, false_)
dtype = jnp.dtype("int32")
# convert boolean array to int32 for comparison
return lower_fun(
lambda x, y: {
lax.eq_p: lambda x, y: x == y,
lax.ne_p: lambda x, y: x != y,
lax.lt_p: lambda x, y: x < y,
lax.le_p: lambda x, y: x <= y,
lax.gt_p: lambda x, y: x > y,
lax.ge_p: lambda x, y: x >= y,
}[prim](
jnp.where(x, 1, 0),
jnp.where(y, 1, 0),
),
multiple_results=False,
)(ctx, x, y)

if jnp.issubdtype(dtype, jnp.integer):
is_uint = jnp.issubdtype(dtype, jnp.unsignedinteger)
Expand Down
12 changes: 0 additions & 12 deletions tests/pallas/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,10 +881,6 @@ def test_comparison(self, fn, dtype):
if jtu.test_device_matches(["tpu"]) and dtype == "float16":
self.skipTest("float16 is not supported on TPU")

# TODO: skipped due to https://github.com/jax-ml/jax/issues/24030
if jtu.test_device_matches(["tpu"]) and dtype == "bool":
self.skipTest("Not supported on TPU")

@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.bool_),
grid=1)
Expand Down Expand Up @@ -966,14 +962,6 @@ def test_binary(self, f, dtype):
if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
self.skipTest("16-bit types are not supported on TPU")

# TODO: skipped due to https://github.com/jax-ml/jax/issues/24027
if (
jtu.test_device_matches(["tpu"])
and f == jnp.remainder
and not self.INTERPRET
):
self.skipTest("jnp.remainder on TPU is only supported in interpret mode")

# TODO(ayx): fix this on TPU
if jtu.test_device_matches(["tpu"]) and dtype == "uint32":
self.skipTest("Not supported on TPU")
Expand Down

0 comments on commit 8627b0c

Please sign in to comment.