From c73d49f8135d341b051234f7aa0b257fc87e8965 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 5 Aug 2021 17:03:16 +0200 Subject: [PATCH 1/4] add notagnent --- src/finite_difference_calls.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/finite_difference_calls.jl b/src/finite_difference_calls.jl index 4aa29a15..09b81091 100644 --- a/src/finite_difference_calls.jl +++ b/src/finite_difference_calls.jl @@ -19,7 +19,7 @@ function _make_jvp_call(fdm, f, y, xs, ẋs, ignores) f2 = _wrap_function(f, xs, ignores) ignores = collect(ignores) - all(ignores) && return ntuple(_ -> nothing, length(xs)) + all(ignores) && return NoTangent() sigargs = zip(xs[.!ignores], ẋs[.!ignores]) return _maybe_fix_to_composite(y, jvp(fdm, f2, sigargs...)) end @@ -45,7 +45,7 @@ function _make_j′vp_call(fdm, f, ȳ, xs, ignores) ignores = collect(ignores) args = Any[nothing for _ in 1:length(xs)] - all(ignores) && return (args...,) + all(ignores) && return NoTangent() sigargs = xs[.!ignores] arginds = (1:length(xs))[.!ignores] fd = j′vp(fdm, f2, ȳ, sigargs...) From 2df1434f2a997540218f58a40778c58dd90f0fe5 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 5 Aug 2021 21:36:41 +0200 Subject: [PATCH 2/4] use NoTangent instead of nothing --- src/finite_difference_calls.jl | 6 +++--- src/testers.jl | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/finite_difference_calls.jl b/src/finite_difference_calls.jl index 09b81091..0027383e 100644 --- a/src/finite_difference_calls.jl +++ b/src/finite_difference_calls.jl @@ -35,7 +35,7 @@ Call `FiniteDifferences.j′vp`, with the option to ignore certain `xs`. - `ȳ`: The adjoint w.r.t. output of `f`. - `xs`: Inputs to `f`, such that `y = f(xs...)`. - `ignores`: Collection of `Bool`s, the same length as `xs`. - If `ignores[i] === true`, then `xs[i]` is ignored; `∂xs[i] === nothing`. + If `ignores[i] === true`, then `xs[i]` is ignored; `∂xs[i] === NoTangent()`. # Returns - `∂xs::Tuple`: Derivatives estimated by finite differencing. @@ -44,7 +44,7 @@ function _make_j′vp_call(fdm, f, ȳ, xs, ignores) f2 = _wrap_function(f, xs, ignores) ignores = collect(ignores) - args = Any[nothing for _ in 1:length(xs)] + args = Any[NoTangent() for _ in 1:length(xs)] all(ignores) && return NoTangent() sigargs = xs[.!ignores] arginds = (1:length(xs))[.!ignores] @@ -66,7 +66,7 @@ Return a new version of `f`, `fnew`, that ignores some of the arguments `xs`. - `f`: The function to be wrapped. - `xs`: Inputs to `f`, such that `y = f(xs...)`. - `ignores`: Collection of `Bool`s, the same length as `xs`. - If `ignores[i] === true`, then `xs[i]` is ignored; `∂xs[i] === nothing`. + If `ignores[i] === true`, then `xs[i]` is ignored; `∂xs[i] === NoTangent()`. """ function _wrap_function(f, xs, ignores) function fnew(sigargs...) diff --git a/src/testers.jl b/src/testers.jl index 1bce4ceb..91f21b3d 100644 --- a/src/testers.jl +++ b/src/testers.jl @@ -225,7 +225,7 @@ function test_rrule( accum_cotangents, ad_cotangents, fd_cotangents ) if accum_cotangent isa NoTangent # then we marked this argument as not differentiable - @assert fd_cotangent === nothing # this is how `_make_j′vp_call` works + @assert fd_cotangent === NoTangent() ad_cotangent isa ZeroTangent && error( "The pullback in the rrule should use NoTangent()" * " rather than ZeroTangent() for non-perturbable arguments.", From daea9aaa11dd4f55920590468be5765d57a6eba9 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 5 Aug 2021 21:37:52 +0200 Subject: [PATCH 3/4] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d9e826ac..83ba2681 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesTestUtils" uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" -version = "1.0.0" +version = "1.1.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From be9dbdb3cfeec094246d19fe51011189815c84aa Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 6 Aug 2021 11:48:43 +0200 Subject: [PATCH 4/4] return all the args --- src/finite_difference_calls.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/finite_difference_calls.jl b/src/finite_difference_calls.jl index 0027383e..72cb7aea 100644 --- a/src/finite_difference_calls.jl +++ b/src/finite_difference_calls.jl @@ -19,7 +19,7 @@ function _make_jvp_call(fdm, f, y, xs, ẋs, ignores) f2 = _wrap_function(f, xs, ignores) ignores = collect(ignores) - all(ignores) && return NoTangent() + all(ignores) && return ntuple(_ -> NoTangent(), length(xs)) sigargs = zip(xs[.!ignores], ẋs[.!ignores]) return _maybe_fix_to_composite(y, jvp(fdm, f2, sigargs...)) end @@ -45,7 +45,7 @@ function _make_j′vp_call(fdm, f, ȳ, xs, ignores) ignores = collect(ignores) args = Any[NoTangent() for _ in 1:length(xs)] - all(ignores) && return NoTangent() + all(ignores) && return (args...,) sigargs = xs[.!ignores] arginds = (1:length(xs))[.!ignores] fd = j′vp(fdm, f2, ȳ, sigargs...)