From c8a5c6f79c3d8bf1ac245cdc643e28c0bb03e8b9 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Thu, 5 Sep 2024 12:56:20 -0500 Subject: [PATCH] fix --- src/enzyme_ad/jax/primitives.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py index ea422b1..bfee6a8 100644 --- a/src/enzyme_ad/jax/primitives.py +++ b/src/enzyme_ad/jax/primitives.py @@ -1621,7 +1621,7 @@ def zero_like(i, arg): lowered_func = lower(jitres, avals_in) kept = lowered_func.compile()._executable._kept_var_idx args_flat = [ - arg if i in kept else zero_like(arg) + arg if i in kept else zero_like(i, arg) for (i, arg) in enumerate(args_flat) ]