Skip to content

Commit

Permalink
replace NO_FIELDS with NoTangent() (#155)
Browse files Browse the repository at this point in the history
  • Loading branch information
mzgubic committed Jun 1, 2021
1 parent c405d75 commit 203fce0
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 51 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesTestUtils"
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
version = "0.6.15"
version = "0.7.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -11,7 +11,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
ChainRulesCore = "0.9.44"
ChainRulesCore = "0.10"
Compat = "3"
FiniteDifferences = "0.12"
julia = "1"
14 changes: 7 additions & 7 deletions docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "b391f22252b8754f4440de1f37ece49d8a7314bb"
git-tree-sha1 = "5d64be50ea9b43a89b476be773e125cef03c7cd5"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.9.44"
version = "0.10.1"

[[ChainRulesTestUtils]]
deps = ["ChainRulesCore", "Compat", "FiniteDifferences", "LinearAlgebra", "Random", "Test"]
path = ".."
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
version = "0.6.13"
version = "0.7.0"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
Expand Down Expand Up @@ -57,9 +57,9 @@ uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"

[[FiniteDifferences]]
deps = ["ChainRulesCore", "LinearAlgebra", "Printf", "Random", "Richardson", "StaticArrays"]
git-tree-sha1 = "8662836e29702fdfdb1b90cbe4162e31b94f1e51"
git-tree-sha1 = "f8c8e287c1d68abc2719ad58fb39de9f6c0d71b1"
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
version = "0.12.7"
version = "0.12.10"

[[IOCapture]]
deps = ["Logging"]
Expand Down Expand Up @@ -167,9 +167,9 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "a1f226ebe197578c25fcf948bfff3d0d12f2ff20"
git-tree-sha1 = "42378d3bab8b4f57aa1ca443821b752850592668"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.2.1"
version = "1.2.2"

[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
Expand Down
2 changes: 1 addition & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ and `rrule`
function ChainRulesCore.rrule(::typeof(two2three), x1, x2)
y = two2three(x1, x2)
function two2three_pullback(Ȳ)
return (NO_FIELDS, 2.0*Ȳ[2], 3.0*Ȳ[3])
return (NoTangent(), 2.0*Ȳ[2], 3.0*Ȳ[3])
end
return y, two2three_pullback
end
Expand Down
6 changes: 3 additions & 3 deletions src/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ function test_frule(
xs = primal.(xẋs)
ẋs = tangent.(xẋs)
if check_inferred && _is_inferrable(f, deepcopy(xs)...; deepcopy(fkwargs)...)
_test_inferred(frule, (NO_FIELDS, deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
_test_inferred(frule, (NoTangent(), deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
end
res = frule((NO_FIELDS, deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
res = frule((NoTangent(), deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
res === nothing && throw(MethodError(frule, typeof((f, xs...))))
res isa Tuple || error("The frule should return (y, ∂y), not $res.")
Ω_ad, dΩ_ad = res
Expand Down Expand Up @@ -190,7 +190,7 @@ function test_rrule(
∂s isa Tuple || error("The pullback must return (∂self, ∂args...), not $∂s.")
∂self = ∂s[1]
x̄s_ad = ∂s[2:end]
@test ∂self === NO_FIELDS # No internal fields
@test ∂self === NoTangent() # No internal fields

# Correctness testing via finite differencing.
# TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
Expand Down
16 changes: 8 additions & 8 deletions test/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ end
end
function ChainRulesCore.rrule(::typeof(identity), x)
function identity_pullback(ȳ)
return (NO_FIELDS, ȳ)
return (NoTangent(), ȳ)
end
return x, identity_pullback
end
Expand All @@ -53,7 +53,7 @@ end
# define rrule using ChainRulesCore's v0.9.0 convention, conjugating the derivative
# in the rrule
function ChainRulesCore.rrule(::typeof(sinconj), x)
sinconj_pullback(ΔΩ) = (NO_FIELDS, conj(cos(x)) * ΔΩ)
sinconj_pullback(ΔΩ) = (NoTangent(), conj(cos(x)) * ΔΩ)
return sin(x), sinconj_pullback
end

Expand All @@ -66,7 +66,7 @@ end
ChainRulesCore.frule((_, dx, dy), ::typeof(fst), x, y) = (x, dx)
function ChainRulesCore.rrule(::typeof(fst), x, y)
function fst_pullback(Δx)
return (NO_FIELDS, Δx, ZeroTangent())
return (NoTangent(), Δx, ZeroTangent())
end
return x, fst_pullback
end
Expand All @@ -83,7 +83,7 @@ end
@testset "single input, multiple output" begin
simo(x) = (x, 2x)
function ChainRulesCore.rrule(simo, x)
simo_pullback((a, b)) = (NO_FIELDS, a .+ 2 .* b)
simo_pullback((a, b)) = (NoTangent(), a .+ 2 .* b)
return simo(x), simo_pullback
end
function ChainRulesCore.frule((_, ẋ), simo, x)
Expand All @@ -106,7 +106,7 @@ end
ChainRulesCore.frule((_, dx), ::typeof(first), xs::Tuple) = (first(xs), first(dx))
function ChainRulesCore.rrule(::typeof(first), x::Tuple)
function first_pullback(Δx)
return (NO_FIELDS, Tangent{typeof(x)}(Δx, falses(length(x) - 1)...))
return (NoTangent(), Tangent{typeof(x)}(Δx, falses(length(x) - 1)...))
end
return first(x), first_pullback
end
Expand Down Expand Up @@ -142,7 +142,7 @@ end
ChainRulesCore.frule((_, Δx, _), ::typeof(fsymtest), x, s) = (x, Δx)
function ChainRulesCore.rrule(::typeof(fsymtest), x, s)
function fsymtest_pullback(Δx)
return NO_FIELDS, Δx, NoTangent()
return NoTangent(), Δx, NoTangent()
end
return x, fsymtest_pullback
end
Expand All @@ -164,7 +164,7 @@ end
end
function ChainRulesCore.rrule(::typeof(futestkws), x; err=true)
function futestkws_pullback(Δx)
return (NO_FIELDS, Δx)
return (NoTangent(), Δx)
end
return futestkws(x; err=err), futestkws_pullback
end
Expand Down Expand Up @@ -198,7 +198,7 @@ end
end
function ChainRulesCore.rrule(::typeof(fbtestkws), x, y; err=true)
function fbtestkws_pullback(Δx)
return (NO_FIELDS, Δx, ZeroTangent())
return (NoTangent(), Δx, ZeroTangent())
end
return fbtestkws(x, y; err=err), fbtestkws_pullback
end
Expand Down
60 changes: 30 additions & 30 deletions test/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ end
end
function ChainRulesCore.rrule(::typeof(identity), x)
function identity_pullback(ȳ)
return (NO_FIELDS, ȳ)
return (NoTangent(), ȳ)
end
return x, identity_pullback
end
Expand All @@ -67,7 +67,7 @@ end
x̄_ret = InplaceableThunk(
@thunk(ȳ), ā -> (inplace_used = true; ā .+= ȳ)
)
return (NO_FIELDS, x̄_ret)
return (NoTangent(), x̄_ret)
end
return identity(x), identity_pullback
end
Expand All @@ -93,7 +93,7 @@ end
function my_identity_pullback(ȳ)
# only the in-place part is incorrect
x̄_ret = InplaceableThunk(@thunk(ȳ), ā -> ā .+= 200 .* ȳ)
return (NO_FIELDS, x̄_ret)
return (NoTangent(), x̄_ret)
end
return my_identity(x), my_identity_pullback
end
Expand All @@ -106,7 +106,7 @@ end
@testset "check inferred" begin
ChainRulesCore.frule((_, Δx), ::typeof(f_inferrable), x) = (x, Δx)
function ChainRulesCore.rrule(::typeof(f_inferrable), x)
f_inferrable_pullback(Δy) = (NO_FIELDS, Δy)
f_inferrable_pullback(Δy) = (NoTangent(), Δy)
return x, f_inferrable_pullback
end

Expand All @@ -123,7 +123,7 @@ end
return (x, x > 0 ? Float64(Δx) : Float32(Δx))
end
function ChainRulesCore.rrule(::typeof(f_noninferrable_frule), x)
f_noninferrable_frule_pullback(Δy) = (NO_FIELDS, Δy)
f_noninferrable_frule_pullback(Δy) = (NoTangent(), Δy)
return x, f_noninferrable_frule_pullback
end

Expand All @@ -144,10 +144,10 @@ end
ChainRulesCore.frule((_, Δx), ::typeof(f_noninferrable_rrule), x) = (x, Δx)
function ChainRulesCore.rrule(::typeof(f_noninferrable_rrule), x)
if x > 0
f_noninferrable_rrule_pullback(Δy) = (NO_FIELDS, Δy)
f_noninferrable_rrule_pullback(Δy) = (NoTangent(), Δy)
return x, f_noninferrable_rrule_pullback
else
return x, _ -> (NO_FIELDS, Δy) # this is not hit by the used point
return x, _ -> (NoTangent(), Δy) # this is not hit by the used point
end
end

Expand All @@ -167,7 +167,7 @@ end
@testset "check not inferred in pullback" begin
function ChainRulesCore.rrule(::typeof(f_noninferrable_pullback), x)
function f_noninferrable_pullback_pullback(Δy)
return (NO_FIELDS, x > 0 ? Float64(Δy) : Float32(Δy))
return (NoTangent(), x > 0 ? Float64(Δy) : Float32(Δy))
end
return x, f_noninferrable_pullback_pullback
end
Expand All @@ -182,7 +182,7 @@ end
function ChainRulesCore.rrule(::typeof(f_noninferrable_thunk), x, y)
function f_noninferrable_thunk_pullback(Δz)
∂x = @thunk(x > 0 ? Float64(Δz) : Float32(Δz))
return (NO_FIELDS, ∂x, Δz)
return (NoTangent(), ∂x, Δz)
end
return x + y, f_noninferrable_thunk_pullback
end
Expand All @@ -198,7 +198,7 @@ end
return (x > 0 ? Float64(x) : Float32(x), x > 0 ? Float64(Δx) : Float32(Δx))
end
function ChainRulesCore.rrule(::typeof(f_inferrable_pullback_only), x)
f_inferrable_pullback_only_pullback(Δy) = (NO_FIELDS, oftype(x, Δy))
f_inferrable_pullback_only_pullback(Δy) = (NoTangent(), oftype(x, Δy))
return x > 0 ? Float64(x) : Float32(x), f_inferrable_pullback_only_pullback
end
test_frule(f_inferrable_pullback_only, 2.0; check_inferred=true)
Expand All @@ -212,7 +212,7 @@ end
# define rrule using ChainRulesCore's v0.9.0 convention, conjugating the derivative
# in the rrule
function ChainRulesCore.rrule(::typeof(sinconj), x)
sinconj_pullback(ΔΩ) = (NO_FIELDS, conj(cos(x)) * ΔΩ)
sinconj_pullback(ΔΩ) = (NoTangent(), conj(cos(x)) * ΔΩ)
return sin(x), sinconj_pullback
end

Expand All @@ -225,7 +225,7 @@ end
ChainRulesCore.frule((_, dx, dy), ::typeof(fst), x, y) = (x, dx)
function ChainRulesCore.rrule(::typeof(fst), x, y)
function fst_pullback(Δx)
return (NO_FIELDS, Δx, ZeroTangent())
return (NoTangent(), Δx, ZeroTangent())
end
return x, fst_pullback
end
Expand All @@ -242,7 +242,7 @@ end
@testset "single input, multiple output" begin
simo(x) = (x, 2x)
function ChainRulesCore.rrule(simo, x)
simo_pullback((a, b)) = (NO_FIELDS, a .+ 2 .* b)
simo_pullback((a, b)) = (NoTangent(), a .+ 2 .* b)
return simo(x), simo_pullback
end
function ChainRulesCore.frule((_, ẋ), simo, x)
Expand All @@ -264,7 +264,7 @@ end
ChainRulesCore.frule((_, dx), ::typeof(first), xs::Tuple) = (first(xs), first(dx))
function ChainRulesCore.rrule(::typeof(first), x::Tuple)
function first_pullback(Δx)
return (NO_FIELDS, Tangent{typeof(x)}(Δx, falses(length(x) - 1)...))
return (NoTangent(), Tangent{typeof(x)}(Δx, falses(length(x) - 1)...))
end
return first(x), first_pullback
end
Expand Down Expand Up @@ -294,7 +294,7 @@ end
ChainRulesCore.frule((_, Δx, _), ::typeof(fsymtest), x, s) = (x, Δx)
function ChainRulesCore.rrule(::typeof(fsymtest), x, s)
function fsymtest_pullback(Δx)
return NO_FIELDS, Δx, NoTangent()
return NoTangent(), Δx, NoTangent()
end
return x, fsymtest_pullback
end
Expand All @@ -314,7 +314,7 @@ end
end
function ChainRulesCore.rrule(::typeof(futestkws), x; err=true)
function futestkws_pullback(Δx)
return (NO_FIELDS, Δx)
return (NoTangent(), Δx)
end
return futestkws(x; err=err), futestkws_pullback
end
Expand Down Expand Up @@ -348,7 +348,7 @@ end
end
function ChainRulesCore.rrule(::typeof(fbtestkws), x, y; err=true)
function fbtestkws_pullback(Δx)
return (NO_FIELDS, Δx, ZeroTangent())
return (NoTangent(), Δx, ZeroTangent())
end
return fbtestkws(x, y; err=err), fbtestkws_pullback
end
Expand Down Expand Up @@ -381,7 +381,7 @@ end

function ChainRulesCore.rrule(::typeof(primalapprox), x)
function primalapprox_pullback(Δx)
return (NO_FIELDS, Δx)
return (NoTangent(), Δx)
end
return x + sqrt(eps(x)), primalapprox_pullback
end
Expand All @@ -391,21 +391,21 @@ end
end

@testset "frule with mutation" begin
function ChainRulesCore.frule((_, ), ::typeof(finplace!), x; y=[1])
function ChainRulesCore.frule((_, ), ::typeof(finplace!), x; y=[1])
y[1] *= 2
x .*= y[1]
.*= 2 # hardcoded to match y defined below
return x,
.*= 2 # hardcoded to match y defined below
return x,
end

# these pass in tangents explictly so that we can check them after
x = randn(3)
= [4.0, 5.0, 6.0]
xcopy, ẋcopy = copy(x), copy()
= [4.0, 5.0, 6.0]
xcopy, ẋcopy = copy(x), copy()
y = [1, 2]
test_frule(finplace!, x ; fkwargs=(y=y,))
test_frule(finplace!, x ; fkwargs=(y=y,))
@test x == xcopy
@test == ẋcopy
@test == ẋcopy
@test y == [1, 2]
end

Expand Down Expand Up @@ -450,7 +450,7 @@ end
∂iter = TestIterator(
∂data, Base.IteratorSize(iter), Base.IteratorEltype(iter)
)
return (NO_FIELDS, ∂iter)
return (NoTangent(), ∂iter)
end
return iterfun(iter), iterfun_pullback
end
Expand All @@ -471,7 +471,7 @@ end
end
function ChainRulesCore.rrule(::typeof(my_identity1), x)
function identity_pullback(ȳ)
return (NO_FIELDS, ȳ)
return (NoTangent(), ȳ)
end
return 2.5 * x, identity_pullback
end
Expand All @@ -487,7 +487,7 @@ end
end
function ChainRulesCore.rrule(::typeof(my_identity2), x)
function identity_pullback(ȳ)
return (NO_FIELDS, 31.8 * ȳ)
return (NoTangent(), 31.8 * ȳ)
end
return x, identity_pullback
end
Expand All @@ -505,7 +505,7 @@ end

rev_trouble((x, y)) = y
function ChainRulesCore.rrule(::typeof(rev_trouble), (x, y)::P) where {P}
rev_trouble_pullback(ȳ) = (NO_FIELDS, Tangent{P}(ZeroTangent(), ȳ))
rev_trouble_pullback(ȳ) = (NoTangent(), Tangent{P}(ZeroTangent(), ȳ))
return y, rev_trouble_pullback
end
test_rrule(rev_trouble, (3, 3.0) Tangent{Tuple{Int,Float64}}(ZeroTangent(), 1.0))
Expand All @@ -517,7 +517,7 @@ end
function foo_pullback(Δy)
da = zeros(size(a))
da[i] = Δy
return NO_FIELDS, da, ZeroTangent()
return NoTangent(), da, ZeroTangent()
end
return foo(a, i), foo_pullback
end
Expand Down

4 comments on commit 203fce0

@mzgubic
Copy link
Member Author

@mzgubic mzgubic commented on 203fce0 Jun 1, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/37968

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.0 -m "<description of version>" 203fce008b55540b7e86ba6da5d345939d176a46
git push origin v0.7.0

@mzgubic
Copy link
Member Author

@mzgubic mzgubic commented on 203fce0 Jun 2, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request updated: JuliaRegistries/General/37968

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.0 -m "<description of version>" 203fce008b55540b7e86ba6da5d345939d176a46
git push origin v0.7.0

Please sign in to comment.