diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py index ea422b1..bfee6a8 100644 --- a/src/enzyme_ad/jax/primitives.py +++ b/src/enzyme_ad/jax/primitives.py @@ -1621,7 +1621,7 @@ def zero_like(i, arg): lowered_func = lower(jitres, avals_in) kept = lowered_func.compile()._executable._kept_var_idx args_flat = [ - arg if i in kept else zero_like(arg) + arg if i in kept else zero_like(i, arg) for (i, arg) in enumerate(args_flat) ]