From f26771db9c4fd98d6a815c8af8b9100560bace67 Mon Sep 17 00:00:00 2001 From: Ayaka Date: Mon, 7 Oct 2024 06:31:04 -0700 Subject: [PATCH] [Pallas TPU] Add lowerings for `lax.population_count_p` and `lax.clz_p` PiperOrigin-RevId: 683159569 --- jax/_src/pallas/mosaic/lowering.py | 15 +++++++++++++++ tests/pallas/ops_test.py | 3 +-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index e3bfbf6d10fa..c9dee8137161 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2086,6 +2086,21 @@ def _floor_lowering_rule(ctx: LoweringRuleContext, x): lowering_rules[lax.floor_p] = _floor_lowering_rule +def _clz_lowering_rule(ctx: LoweringRuleContext, x): + return math.CountLeadingZerosOp(x).result + +lowering_rules[lax.clz_p] = _clz_lowering_rule + + +def _population_count_lowering_rule(ctx: LoweringRuleContext, x): + aval_out = ctx.avals_out[0] + if aval_out.shape == (): + raise ValueError("Population count is not supported on scalars") + return math.CtPopOp(x).result + +lowering_rules[lax.population_count_p] = _population_count_lowering_rule + + # Mapping for signed integer comparisons. _cmpsi_lowering_types = { lax.eq_p: arith.CmpIPredicate.eq, diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 979870ff351f..96671c25ece4 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -773,8 +773,7 @@ def test_elementwise(self, fn, dtype): # TODO(b/370578663): implement these lowerings on TPU if jtu.test_device_matches(["tpu"]) and fn in ( jnp.acos, jnp.acosh, jnp.asin, jnp.asinh, jnp.atan, jnp.atanh, - jnp.cbrt, jnp.cosh, lax.clz, jnp.expm1, - lax.population_count, jnp.sinh, + jnp.cbrt, jnp.cosh, jnp.expm1, jnp.sinh, ): self.skipTest(f"{fn.__name__} not implemented on TPU")