diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index ea744bfbf1ee..56392cf77046 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2020,6 +2020,13 @@ def _cos_lowering_rule(ctx: LoweringRuleContext, x): lowering_rules[lax.cos_p] = _cos_lowering_rule +def _tan_lowering_rule(ctx: LoweringRuleContext, x): + return math.TanOp(x).result + + +lowering_rules[lax.tan_p] = _tan_lowering_rule + + def _tanh_lowering_rule(ctx: LoweringRuleContext, x): return math.TanhOp(x).result diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 3d1d13f2baa4..e15db78740e5 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -776,7 +776,7 @@ def test_elementwise(self, fn, dtype): if jtu.test_device_matches(["tpu"]) and fn in ( jnp.acos, jnp.acosh, jnp.asin, jnp.asinh, jnp.atan, jnp.atanh, jnp.cbrt, jnp.ceil, jnp.cosh, lax.clz, jnp.expm1, - jnp.floor, lax.population_count, jnp.sinh, jnp.tan, + jnp.floor, lax.population_count, jnp.sinh, ): self.skipTest(f"{fn.__name__} not implemented on TPU")