Skip to content

Commit

Permalink
lax.zeros_like_shaped_array shouldn't apply lax.broadcast to float0-d…
Browse files Browse the repository at this point in the history
…typed array

instead, float0-dtyped arrays should just be handled by ordinary numpy

also, revise a weirdly-written test that revelaed this bug

(the test rewrite and the bug fix each independently make the test start passing, but i think it's best to include both)
  • Loading branch information
mattjj committed Nov 30, 2023
1 parent 53e66c1 commit e6a6b70
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 9 deletions.
14 changes: 8 additions & 6 deletions jax/_src/ad_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
from __future__ import annotations

import types
from typing import Any, Callable, TypeVar, cast
from typing import Any, Callable, TypeVar, Union

import numpy as np

from jax._src import core
from jax._src import traceback_util
Expand Down Expand Up @@ -48,15 +50,15 @@ def add_abstract(xs, ys):

jaxval_zeros_likers: dict[type, Callable[[Any], Array]] = {}

def instantiate(z: Zero | Array) -> Array:
if type(z) is Zero:
def instantiate(z: Zero | Array) -> Union[np.ndarray, Array]:
if isinstance(z, Zero):
return zeros_like_aval(z.aval)
return cast(Array, z)
return z

def zeros_like_aval(aval: core.AbstractValue) -> Array:
def zeros_like_aval(aval: core.AbstractValue) -> Union[np.ndarray, Array]:
return aval_zeros_likers[type(aval)](aval)

aval_zeros_likers: dict[type, Callable[[Any], Array]] = {}
aval_zeros_likers: dict[type, Callable[[Any], Union[np.ndarray, Array]]] = {}

def zeros_like_jaxval(val: ArrayLike) -> Array:
return zeros_like_p.bind(val)
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1219,10 +1219,10 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: Optional[DTypeLike] = None)
def zeros_like_shaped_array(aval: ShapedArray) -> Array:
assert isinstance(aval, ShapedArray)
if aval.dtype == dtypes.float0:
scalar_zero = np.zeros((), dtype=aval.dtype)
return np.zeros(aval.shape, dtype=aval.dtype) # type: ignore
else:
scalar_zero = _convert_element_type(0, aval.dtype, aval.weak_type)
return broadcast(scalar_zero, aval.shape)
return broadcast(scalar_zero, aval.shape)

ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array

Expand Down
2 changes: 1 addition & 1 deletion tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5540,7 +5540,7 @@ def test_vjp_caching(self):
def test_vjp_caching_static_argnums(self):
identity = jax.remat(lambda x, y: jax.jit(lambda x: 2 * x if y else x)(x),
static_argnums=(1,))
_, f_vjp = jax.vjp(identity, 1., True)
_, f_vjp = jax.vjp(lambda x: identity(x, True), 1.)
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
for _ in range(20):
f_vjp(1.)[0].block_until_ready()
Expand Down

0 comments on commit e6a6b70

Please sign in to comment.