From 2febf9f1737eeb48f8d841eba0c3b332adc78f7b Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 28 Aug 2024 21:01:45 -0400 Subject: [PATCH 1/2] Fix Base.zero type output zero(x::T)::T is a standard that applies to pretty much any other array type, but TrackedArray fails to match the standard interfaces. This fixes that issue. The only major violation to where this behavior is expected is if you're trying to write a grad rule that's mutating, which really only shows up in rules libraries, and those are thus updated here. Note that there is an alternative implementation via `zero.(x)`, but this implementation drops the compute graph that isn't needed if you have a zero. --- Project.toml | 2 +- src/lib/array.jl | 13 +++++++------ test/tracker.jl | 5 +++++ 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index dd704ac..0cf5020 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Tracker" uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -version = "0.2.34" +version = "0.2.35" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/lib/array.jl b/src/lib/array.jl index 2022ff4..aee6872 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -99,7 +99,7 @@ Base.getindex(xs::TrackedArray, i...; kwargs...) = track(getindex, xs, i...; kwa @grad function getindex(xs::AbstractArray, i...; kwargs...) getindex(data(xs), i...; kwargs...), function (Δ) - Δ′ = zero(xs) + Δ′ = zero(data(xs)) setindex!(Δ′, data(Δ), i...; kwargs...) (nobacksies(:getindex, Δ′), map(_->nothing, i)...) end @@ -107,7 +107,7 @@ end @grad function getindex(xs::AbstractArray, i::Array...) data(xs)[i...], function (Δ) - Δ′ = zero(xs) + Δ′ = zero(data(xs)) @views Δ′[i...] .+= data(Δ) (nobacksies(:getindex, Δ′), map(_->nothing, i)...) end @@ -117,7 +117,7 @@ Base.view(x::TrackedArray, inds...; kwargs...) = track(Base.view, x, inds...; kw @grad function view(x::AbstractArray, inds...; kwargs...) view(data(x), inds...; kwargs...), function (Δ) - grad_output = zero(x) + grad_output = zero(data(x)) subgrad = view(grad_output, inds...; kwargs...) subgrad[:] = data(Δ) (nobacksies(:view, grad_output), map(_->nothing, inds)...) @@ -144,10 +144,11 @@ logabsdet(xs::TrackedArray) = track(logabsdet, xs) @grad logabsdet(xs) = logabsdet(data(xs)), Δ -> (Δ[1] * transpose(inv(xs)),) Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...) +Base.zero(x::Tracker.TrackedArray) = zero.(x) @grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs))) repeat(data(xs), inner = inner, outer = outer), function (Δ) - Δ′ = zero(xs) + Δ′ = zero(data(xs)) S = size(xs) # Loop through each element of Δ, calculate source dimensions, accumulate into Δ′ @@ -433,7 +434,7 @@ Base.minimum(xs::TrackedArray; dims = :) = track(minimum, xs, dims = dims) @grad function maximum(xs; dims = dims) maximum(data(xs), dims = dims), function (Δ) - Δ′ = zero(xs) + Δ′ = zero(data(xs)) _, i = findmax(data(xs), dims = dims) Δ′[i] = data(Δ) return (nobacksies(:maximum, Δ′),) @@ -442,7 +443,7 @@ end @grad function minimum(xs; dims = dims) minimum(data(xs), dims = dims), function (Δ) - Δ′ = zero(xs) + Δ′ = zero(data(xs)) _, i = findmin(data(xs), dims = dims) Δ′[i] = data(Δ) return (nobacksies(:minimum, Δ′),) diff --git a/test/tracker.jl b/test/tracker.jl index d17c259..1661355 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -52,6 +52,11 @@ RNG = NNlib.Random.MersenneTwister(1) end # @testset gradtests +@testset "zero" begin + @test zero(TrackedArray(rand(2))) isa TrackedArray + @test gradtest(x-> zero(x) .* x, (2,)) +end + @testset "indexing & slicing" begin @test gradtest(x->view(x, 1:2, 1:2), rand(4, 4)) end From a816377efc485655ccb674f70865167c50ee10fe Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 28 Aug 2024 21:23:21 -0400 Subject: [PATCH 2/2] smaller computational graph --- src/lib/array.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index aee6872..231d61c 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -144,7 +144,7 @@ logabsdet(xs::TrackedArray) = track(logabsdet, xs) @grad logabsdet(xs) = logabsdet(data(xs)), Δ -> (Δ[1] * transpose(inv(xs)),) Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...) -Base.zero(x::Tracker.TrackedArray) = zero.(x) +Base.zero(x::Tracker.TrackedArray) = TrackedArray(zero(x.data)) @grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs))) repeat(data(xs), inner = inner, outer = outer), function (Δ)