Skip to content

Commit

Permalink
[Pallas TPU] Add lowering for lax.cos_p
Browse files Browse the repository at this point in the history
Fixes #24026

PiperOrigin-RevId: 680754948
  • Loading branch information
ayaka14732 authored and Google-ML-Automation committed Sep 30, 2024
1 parent 23ce5a1 commit a24420e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
7 changes: 7 additions & 0 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2001,6 +2001,13 @@ def _sin_lowering_rule(ctx: LoweringRuleContext, x):
lowering_rules[lax.sin_p] = _sin_lowering_rule


def _cos_lowering_rule(ctx: LoweringRuleContext, x):
return math.CosOp(x).result


lowering_rules[lax.cos_p] = _cos_lowering_rule


def _tanh_lowering_rule(ctx: LoweringRuleContext, x):
return math.TanhOp(x).result

Expand Down
2 changes: 1 addition & 1 deletion tests/pallas/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,7 @@ def test_elementwise(self, fn, dtype):
# TODO(b/370578663): implement these lowerings on TPU
if jtu.test_device_matches(["tpu"]) and fn in (
jnp.acos, jnp.acosh, jnp.asin, jnp.asinh, jnp.atan, jnp.atanh,
jnp.cbrt, jnp.ceil, jnp.cos, jnp.cosh, lax.clz, jnp.expm1,
jnp.cbrt, jnp.ceil, jnp.cosh, lax.clz, jnp.expm1,
jnp.floor, lax.population_count, jnp.sinh, jnp.tan,
):
self.skipTest(f"{fn.__name__} not implemented on TPU")
Expand Down

0 comments on commit a24420e

Please sign in to comment.