diff --git a/Project.toml b/Project.toml index c4189397..c74cac02 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesTestUtils" uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" -version = "0.6.15" +version = "0.7.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -11,7 +11,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -ChainRulesCore = "0.9.44" +ChainRulesCore = "0.10" Compat = "3" FiniteDifferences = "0.12" julia = "1" diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 8a223d01..29f1b06e 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -11,15 +11,15 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" [[ChainRulesCore]] deps = ["Compat", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "b391f22252b8754f4440de1f37ece49d8a7314bb" +git-tree-sha1 = "5d64be50ea9b43a89b476be773e125cef03c7cd5" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.44" +version = "0.10.1" [[ChainRulesTestUtils]] deps = ["ChainRulesCore", "Compat", "FiniteDifferences", "LinearAlgebra", "Random", "Test"] path = ".." uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" -version = "0.6.13" +version = "0.7.0" [[Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] @@ -57,9 +57,9 @@ uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" [[FiniteDifferences]] deps = ["ChainRulesCore", "LinearAlgebra", "Printf", "Random", "Richardson", "StaticArrays"] -git-tree-sha1 = "8662836e29702fdfdb1b90cbe4162e31b94f1e51" +git-tree-sha1 = "f8c8e287c1d68abc2719ad58fb39de9f6c0d71b1" uuid = "26cc04aa-876d-5657-8c51-4c34ba976000" -version = "0.12.7" +version = "0.12.10" [[IOCapture]] deps = ["Logging"] @@ -167,9 +167,9 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [[StaticArrays]] deps = ["LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "a1f226ebe197578c25fcf948bfff3d0d12f2ff20" +git-tree-sha1 = "42378d3bab8b4f57aa1ca443821b752850592668" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.2.1" +version = "1.2.2" [[Statistics]] deps = ["LinearAlgebra", "SparseArrays"] diff --git a/docs/src/index.md b/docs/src/index.md index 0b57aa85..a632f404 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -37,7 +37,7 @@ and `rrule` function ChainRulesCore.rrule(::typeof(two2three), x1, x2) y = two2three(x1, x2) function two2three_pullback(Ȳ) - return (NO_FIELDS, 2.0*Ȳ[2], 3.0*Ȳ[3]) + return (NoTangent(), 2.0*Ȳ[2], 3.0*Ȳ[3]) end return y, two2three_pullback end diff --git a/src/testers.jl b/src/testers.jl index b1fa69b1..c00eedd6 100644 --- a/src/testers.jl +++ b/src/testers.jl @@ -106,9 +106,9 @@ function test_frule( xs = primal.(xẋs) ẋs = tangent.(xẋs) if check_inferred && _is_inferrable(f, deepcopy(xs)...; deepcopy(fkwargs)...) - _test_inferred(frule, (NO_FIELDS, deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...) + _test_inferred(frule, (NoTangent(), deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...) end - res = frule((NO_FIELDS, deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...) + res = frule((NoTangent(), deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...) res === nothing && throw(MethodError(frule, typeof((f, xs...)))) res isa Tuple || error("The frule should return (y, ∂y), not $res.") Ω_ad, dΩ_ad = res @@ -190,7 +190,7 @@ function test_rrule( ∂s isa Tuple || error("The pullback must return (∂self, ∂args...), not $∂s.") ∂self = ∂s[1] x̄s_ad = ∂s[2:end] - @test ∂self === NO_FIELDS # No internal fields + @test ∂self === NoTangent() # No internal fields # Correctness testing via finite differencing. # TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113 diff --git a/test/deprecated.jl b/test/deprecated.jl index 62570170..ef6603dd 100644 --- a/test/deprecated.jl +++ b/test/deprecated.jl @@ -33,7 +33,7 @@ end end function ChainRulesCore.rrule(::typeof(identity), x) function identity_pullback(ȳ) - return (NO_FIELDS, ȳ) + return (NoTangent(), ȳ) end return x, identity_pullback end @@ -53,7 +53,7 @@ end # define rrule using ChainRulesCore's v0.9.0 convention, conjugating the derivative # in the rrule function ChainRulesCore.rrule(::typeof(sinconj), x) - sinconj_pullback(ΔΩ) = (NO_FIELDS, conj(cos(x)) * ΔΩ) + sinconj_pullback(ΔΩ) = (NoTangent(), conj(cos(x)) * ΔΩ) return sin(x), sinconj_pullback end @@ -66,7 +66,7 @@ end ChainRulesCore.frule((_, dx, dy), ::typeof(fst), x, y) = (x, dx) function ChainRulesCore.rrule(::typeof(fst), x, y) function fst_pullback(Δx) - return (NO_FIELDS, Δx, ZeroTangent()) + return (NoTangent(), Δx, ZeroTangent()) end return x, fst_pullback end @@ -83,7 +83,7 @@ end @testset "single input, multiple output" begin simo(x) = (x, 2x) function ChainRulesCore.rrule(simo, x) - simo_pullback((a, b)) = (NO_FIELDS, a .+ 2 .* b) + simo_pullback((a, b)) = (NoTangent(), a .+ 2 .* b) return simo(x), simo_pullback end function ChainRulesCore.frule((_, ẋ), simo, x) @@ -106,7 +106,7 @@ end ChainRulesCore.frule((_, dx), ::typeof(first), xs::Tuple) = (first(xs), first(dx)) function ChainRulesCore.rrule(::typeof(first), x::Tuple) function first_pullback(Δx) - return (NO_FIELDS, Tangent{typeof(x)}(Δx, falses(length(x) - 1)...)) + return (NoTangent(), Tangent{typeof(x)}(Δx, falses(length(x) - 1)...)) end return first(x), first_pullback end @@ -142,7 +142,7 @@ end ChainRulesCore.frule((_, Δx, _), ::typeof(fsymtest), x, s) = (x, Δx) function ChainRulesCore.rrule(::typeof(fsymtest), x, s) function fsymtest_pullback(Δx) - return NO_FIELDS, Δx, NoTangent() + return NoTangent(), Δx, NoTangent() end return x, fsymtest_pullback end @@ -164,7 +164,7 @@ end end function ChainRulesCore.rrule(::typeof(futestkws), x; err=true) function futestkws_pullback(Δx) - return (NO_FIELDS, Δx) + return (NoTangent(), Δx) end return futestkws(x; err=err), futestkws_pullback end @@ -198,7 +198,7 @@ end end function ChainRulesCore.rrule(::typeof(fbtestkws), x, y; err=true) function fbtestkws_pullback(Δx) - return (NO_FIELDS, Δx, ZeroTangent()) + return (NoTangent(), Δx, ZeroTangent()) end return fbtestkws(x, y; err=err), fbtestkws_pullback end diff --git a/test/testers.jl b/test/testers.jl index ac9a7ad1..7db1d8df 100644 --- a/test/testers.jl +++ b/test/testers.jl @@ -41,7 +41,7 @@ end end function ChainRulesCore.rrule(::typeof(identity), x) function identity_pullback(ȳ) - return (NO_FIELDS, ȳ) + return (NoTangent(), ȳ) end return x, identity_pullback end @@ -67,7 +67,7 @@ end x̄_ret = InplaceableThunk( @thunk(ȳ), ā -> (inplace_used = true; ā .+= ȳ) ) - return (NO_FIELDS, x̄_ret) + return (NoTangent(), x̄_ret) end return identity(x), identity_pullback end @@ -93,7 +93,7 @@ end function my_identity_pullback(ȳ) # only the in-place part is incorrect x̄_ret = InplaceableThunk(@thunk(ȳ), ā -> ā .+= 200 .* ȳ) - return (NO_FIELDS, x̄_ret) + return (NoTangent(), x̄_ret) end return my_identity(x), my_identity_pullback end @@ -106,7 +106,7 @@ end @testset "check inferred" begin ChainRulesCore.frule((_, Δx), ::typeof(f_inferrable), x) = (x, Δx) function ChainRulesCore.rrule(::typeof(f_inferrable), x) - f_inferrable_pullback(Δy) = (NO_FIELDS, Δy) + f_inferrable_pullback(Δy) = (NoTangent(), Δy) return x, f_inferrable_pullback end @@ -123,7 +123,7 @@ end return (x, x > 0 ? Float64(Δx) : Float32(Δx)) end function ChainRulesCore.rrule(::typeof(f_noninferrable_frule), x) - f_noninferrable_frule_pullback(Δy) = (NO_FIELDS, Δy) + f_noninferrable_frule_pullback(Δy) = (NoTangent(), Δy) return x, f_noninferrable_frule_pullback end @@ -144,10 +144,10 @@ end ChainRulesCore.frule((_, Δx), ::typeof(f_noninferrable_rrule), x) = (x, Δx) function ChainRulesCore.rrule(::typeof(f_noninferrable_rrule), x) if x > 0 - f_noninferrable_rrule_pullback(Δy) = (NO_FIELDS, Δy) + f_noninferrable_rrule_pullback(Δy) = (NoTangent(), Δy) return x, f_noninferrable_rrule_pullback else - return x, _ -> (NO_FIELDS, Δy) # this is not hit by the used point + return x, _ -> (NoTangent(), Δy) # this is not hit by the used point end end @@ -167,7 +167,7 @@ end @testset "check not inferred in pullback" begin function ChainRulesCore.rrule(::typeof(f_noninferrable_pullback), x) function f_noninferrable_pullback_pullback(Δy) - return (NO_FIELDS, x > 0 ? Float64(Δy) : Float32(Δy)) + return (NoTangent(), x > 0 ? Float64(Δy) : Float32(Δy)) end return x, f_noninferrable_pullback_pullback end @@ -182,7 +182,7 @@ end function ChainRulesCore.rrule(::typeof(f_noninferrable_thunk), x, y) function f_noninferrable_thunk_pullback(Δz) ∂x = @thunk(x > 0 ? Float64(Δz) : Float32(Δz)) - return (NO_FIELDS, ∂x, Δz) + return (NoTangent(), ∂x, Δz) end return x + y, f_noninferrable_thunk_pullback end @@ -198,7 +198,7 @@ end return (x > 0 ? Float64(x) : Float32(x), x > 0 ? Float64(Δx) : Float32(Δx)) end function ChainRulesCore.rrule(::typeof(f_inferrable_pullback_only), x) - f_inferrable_pullback_only_pullback(Δy) = (NO_FIELDS, oftype(x, Δy)) + f_inferrable_pullback_only_pullback(Δy) = (NoTangent(), oftype(x, Δy)) return x > 0 ? Float64(x) : Float32(x), f_inferrable_pullback_only_pullback end test_frule(f_inferrable_pullback_only, 2.0; check_inferred=true) @@ -212,7 +212,7 @@ end # define rrule using ChainRulesCore's v0.9.0 convention, conjugating the derivative # in the rrule function ChainRulesCore.rrule(::typeof(sinconj), x) - sinconj_pullback(ΔΩ) = (NO_FIELDS, conj(cos(x)) * ΔΩ) + sinconj_pullback(ΔΩ) = (NoTangent(), conj(cos(x)) * ΔΩ) return sin(x), sinconj_pullback end @@ -225,7 +225,7 @@ end ChainRulesCore.frule((_, dx, dy), ::typeof(fst), x, y) = (x, dx) function ChainRulesCore.rrule(::typeof(fst), x, y) function fst_pullback(Δx) - return (NO_FIELDS, Δx, ZeroTangent()) + return (NoTangent(), Δx, ZeroTangent()) end return x, fst_pullback end @@ -242,7 +242,7 @@ end @testset "single input, multiple output" begin simo(x) = (x, 2x) function ChainRulesCore.rrule(simo, x) - simo_pullback((a, b)) = (NO_FIELDS, a .+ 2 .* b) + simo_pullback((a, b)) = (NoTangent(), a .+ 2 .* b) return simo(x), simo_pullback end function ChainRulesCore.frule((_, ẋ), simo, x) @@ -264,7 +264,7 @@ end ChainRulesCore.frule((_, dx), ::typeof(first), xs::Tuple) = (first(xs), first(dx)) function ChainRulesCore.rrule(::typeof(first), x::Tuple) function first_pullback(Δx) - return (NO_FIELDS, Tangent{typeof(x)}(Δx, falses(length(x) - 1)...)) + return (NoTangent(), Tangent{typeof(x)}(Δx, falses(length(x) - 1)...)) end return first(x), first_pullback end @@ -294,7 +294,7 @@ end ChainRulesCore.frule((_, Δx, _), ::typeof(fsymtest), x, s) = (x, Δx) function ChainRulesCore.rrule(::typeof(fsymtest), x, s) function fsymtest_pullback(Δx) - return NO_FIELDS, Δx, NoTangent() + return NoTangent(), Δx, NoTangent() end return x, fsymtest_pullback end @@ -314,7 +314,7 @@ end end function ChainRulesCore.rrule(::typeof(futestkws), x; err=true) function futestkws_pullback(Δx) - return (NO_FIELDS, Δx) + return (NoTangent(), Δx) end return futestkws(x; err=err), futestkws_pullback end @@ -348,7 +348,7 @@ end end function ChainRulesCore.rrule(::typeof(fbtestkws), x, y; err=true) function fbtestkws_pullback(Δx) - return (NO_FIELDS, Δx, ZeroTangent()) + return (NoTangent(), Δx, ZeroTangent()) end return fbtestkws(x, y; err=err), fbtestkws_pullback end @@ -381,7 +381,7 @@ end function ChainRulesCore.rrule(::typeof(primalapprox), x) function primalapprox_pullback(Δx) - return (NO_FIELDS, Δx) + return (NoTangent(), Δx) end return x + sqrt(eps(x)), primalapprox_pullback end @@ -391,21 +391,21 @@ end end @testset "frule with mutation" begin - function ChainRulesCore.frule((_, ẋ), ::typeof(finplace!), x; y=[1]) + function ChainRulesCore.frule((_, ẋ), ::typeof(finplace!), x; y=[1]) y[1] *= 2 x .*= y[1] - ẋ .*= 2 # hardcoded to match y defined below - return x, ẋ + ẋ .*= 2 # hardcoded to match y defined below + return x, ẋ end # these pass in tangents explictly so that we can check them after x = randn(3) - ẋ = [4.0, 5.0, 6.0] - xcopy, ẋcopy = copy(x), copy(ẋ) + ẋ = [4.0, 5.0, 6.0] + xcopy, ẋcopy = copy(x), copy(ẋ) y = [1, 2] - test_frule(finplace!, x ⊢ ẋ; fkwargs=(y=y,)) + test_frule(finplace!, x ⊢ ẋ; fkwargs=(y=y,)) @test x == xcopy - @test ẋ == ẋcopy + @test ẋ == ẋcopy @test y == [1, 2] end @@ -450,7 +450,7 @@ end ∂iter = TestIterator( ∂data, Base.IteratorSize(iter), Base.IteratorEltype(iter) ) - return (NO_FIELDS, ∂iter) + return (NoTangent(), ∂iter) end return iterfun(iter), iterfun_pullback end @@ -471,7 +471,7 @@ end end function ChainRulesCore.rrule(::typeof(my_identity1), x) function identity_pullback(ȳ) - return (NO_FIELDS, ȳ) + return (NoTangent(), ȳ) end return 2.5 * x, identity_pullback end @@ -487,7 +487,7 @@ end end function ChainRulesCore.rrule(::typeof(my_identity2), x) function identity_pullback(ȳ) - return (NO_FIELDS, 31.8 * ȳ) + return (NoTangent(), 31.8 * ȳ) end return x, identity_pullback end @@ -505,7 +505,7 @@ end rev_trouble((x, y)) = y function ChainRulesCore.rrule(::typeof(rev_trouble), (x, y)::P) where {P} - rev_trouble_pullback(ȳ) = (NO_FIELDS, Tangent{P}(ZeroTangent(), ȳ)) + rev_trouble_pullback(ȳ) = (NoTangent(), Tangent{P}(ZeroTangent(), ȳ)) return y, rev_trouble_pullback end test_rrule(rev_trouble, (3, 3.0) ⊢ Tangent{Tuple{Int,Float64}}(ZeroTangent(), 1.0)) @@ -517,7 +517,7 @@ end function foo_pullback(Δy) da = zeros(size(a)) da[i] = Δy - return NO_FIELDS, da, ZeroTangent() + return NoTangent(), da, ZeroTangent() end return foo(a, i), foo_pullback end