Skip to content

Commit

Permalink
[Pallas TPU] Add lowerings for lax.population_count_p and lax.clz_p
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 683159569
  • Loading branch information
ayaka14732 authored and Google-ML-Automation committed Oct 9, 2024
1 parent f52b016 commit f26771d
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
15 changes: 15 additions & 0 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2086,6 +2086,21 @@ def _floor_lowering_rule(ctx: LoweringRuleContext, x):
lowering_rules[lax.floor_p] = _floor_lowering_rule


def _clz_lowering_rule(ctx: LoweringRuleContext, x):
return math.CountLeadingZerosOp(x).result

lowering_rules[lax.clz_p] = _clz_lowering_rule


def _population_count_lowering_rule(ctx: LoweringRuleContext, x):
aval_out = ctx.avals_out[0]
if aval_out.shape == ():
raise ValueError("Population count is not supported on scalars")
return math.CtPopOp(x).result

lowering_rules[lax.population_count_p] = _population_count_lowering_rule


# Mapping for signed integer comparisons.
_cmpsi_lowering_types = {
lax.eq_p: arith.CmpIPredicate.eq,
Expand Down
3 changes: 1 addition & 2 deletions tests/pallas/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,8 +773,7 @@ def test_elementwise(self, fn, dtype):
# TODO(b/370578663): implement these lowerings on TPU
if jtu.test_device_matches(["tpu"]) and fn in (
jnp.acos, jnp.acosh, jnp.asin, jnp.asinh, jnp.atan, jnp.atanh,
jnp.cbrt, jnp.cosh, lax.clz, jnp.expm1,
lax.population_count, jnp.sinh,
jnp.cbrt, jnp.cosh, jnp.expm1, jnp.sinh,
):
self.skipTest(f"{fn.__name__} not implemented on TPU")

Expand Down

0 comments on commit f26771d

Please sign in to comment.