Skip to content

Commit

Permalink
[Pallas TPU] Add lowering for lax.tan_p
Browse files Browse the repository at this point in the history
This is a follow-up of #24028, which adds lowering for `lax.cos_p`

PiperOrigin-RevId: 681134519
  • Loading branch information
ayaka14732 authored and Google-ML-Automation committed Oct 1, 2024
1 parent 49ad220 commit 73eac8b
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
7 changes: 7 additions & 0 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2020,6 +2020,13 @@ def _cos_lowering_rule(ctx: LoweringRuleContext, x):
lowering_rules[lax.cos_p] = _cos_lowering_rule


def _tan_lowering_rule(ctx: LoweringRuleContext, x):
return math.TanOp(x).result


lowering_rules[lax.tan_p] = _tan_lowering_rule


def _tanh_lowering_rule(ctx: LoweringRuleContext, x):
return math.TanhOp(x).result

Expand Down
2 changes: 1 addition & 1 deletion tests/pallas/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ def test_elementwise(self, fn, dtype):
if jtu.test_device_matches(["tpu"]) and fn in (
jnp.acos, jnp.acosh, jnp.asin, jnp.asinh, jnp.atan, jnp.atanh,
jnp.cbrt, jnp.ceil, jnp.cosh, lax.clz, jnp.expm1,
jnp.floor, lax.population_count, jnp.sinh, jnp.tan,
jnp.floor, lax.population_count, jnp.sinh,
):
self.skipTest(f"{fn.__name__} not implemented on TPU")

Expand Down

0 comments on commit 73eac8b

Please sign in to comment.