diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 969c1f59a89e..4dd393c7c66e 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2148,15 +2148,22 @@ def _cmp_lowering_rule(prim, ctx: LoweringRuleContext, x, y): ) dtype = x_aval.dtype - # Handle bool comparisons by casting to int32. if jnp.issubdtype(dtype, jnp.bool_): - bool_cast_to = _dtype_to_ir_type(jnp.dtype("int32")) - true_ = ir_constant(1, mlir_type=bool_cast_to) - false_ = ir_constant(0, mlir_type=bool_cast_to) - - x = arith.SelectOp(x, true_, false_) - y = arith.SelectOp(y, true_, false_) - dtype = jnp.dtype("int32") + # convert boolean array to int32 for comparison + return lower_fun( + lambda x, y: { + lax.eq_p: lambda x, y: x == y, + lax.ne_p: lambda x, y: x != y, + lax.lt_p: lambda x, y: x < y, + lax.le_p: lambda x, y: x <= y, + lax.gt_p: lambda x, y: x > y, + lax.ge_p: lambda x, y: x >= y, + }[prim]( + jnp.where(x, 1, 0), + jnp.where(y, 1, 0), + ), + multiple_results=False, + )(ctx, x, y) if jnp.issubdtype(dtype, jnp.integer): is_uint = jnp.issubdtype(dtype, jnp.unsignedinteger) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 7eed08d0ef95..69718d9eb645 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -881,10 +881,6 @@ def test_comparison(self, fn, dtype): if jtu.test_device_matches(["tpu"]) and dtype == "float16": self.skipTest("float16 is not supported on TPU") - # TODO: skipped due to https://github.com/jax-ml/jax/issues/24030 - if jtu.test_device_matches(["tpu"]) and dtype == "bool": - self.skipTest("Not supported on TPU") - @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((8,), jnp.bool_), grid=1) @@ -966,14 +962,6 @@ def test_binary(self, f, dtype): if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2: self.skipTest("16-bit types are not supported on TPU") - # TODO: skipped due to https://github.com/jax-ml/jax/issues/24027 - if ( - jtu.test_device_matches(["tpu"]) - and f == jnp.remainder - and not self.INTERPRET - ): - self.skipTest("jnp.remainder on TPU is only supported in interpret mode") - # TODO(ayx): fix this on TPU if jtu.test_device_matches(["tpu"]) and dtype == "uint32": self.skipTest("Not supported on TPU")