Skip to content

Commit

Permalink
rewrite test not to include float0 broadcast
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Nov 30, 2023
1 parent 5b3fc1b commit 43ed74f
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5519,7 +5519,7 @@ def test_vjp_caching(self):
def test_vjp_caching_static_argnums(self):
identity = jax.remat(lambda x, y: jax.jit(lambda x: 2 * x if y else x)(x),
static_argnums=(1,))
_, f_vjp = jax.vjp(identity, 1., True)
_, f_vjp = jax.vjp(lambda x: identity(x, True), 1.)
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
for _ in range(20):
f_vjp(1.)[0].block_until_ready()
Expand Down

0 comments on commit 43ed74f

Please sign in to comment.