Remove scan for pre-marking of nodes #166

Open · wants to merge 4 commits into base: master
178 changes: 127 additions & 51 deletions src/back.jl
@@ -1,11 +1,14 @@
 # The AD generates fairly large backtraces that are unhelpful if you interrupt
 # while training; this just cleans that up.
 macro interrupts(ex)
-  :(try $(esc(ex))
+  :(
+    try
+      $(esc(ex))
     catch e
       e isa InterruptException || rethrow()
       throw(e)
-    end)
+    end
+  )
 end
 
 # In-place gradients
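The macro is used further down in this file as `@interrupts back!(l1)`. As a point of reference, that call expands to roughly the following (a sketch only; `l1` stands for whatever expression is wrapped):

```julia
# Anything that is not an InterruptException is rethrown with its full
# backtrace; an interrupt is rethrown from the catch site via `throw`,
# which drops the large AD backtrace built up during training.
try
    back!(l1)
catch e
    e isa InterruptException || rethrow()
    throw(e)
end
```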
@@ -14,54 +17,113 @@ init_grad(x) = zero(x)
 zero_grad!(x) = zero(x)
 zero_grad!(x::AbstractArray) = (x .= 0)
 
-scan(c::Call) = foreach(scan, c.args)
-
-function scan(x::Tracked)
-  x.isleaf && return
-  ref = x.ref += 1
-  if ref == 1
-    scan(x.f)
-    isdefined(x, :grad) && (x.grad = zero_grad!(x.grad))
-  end
-  return
-end
-
-function scan(x)
-  istracked(x) && scan(tracker(x))
-  return
-end
-
-function back_(c::Call, Δ, once)
+# scan(c::Call) = foreach(scan, c.args)
+
+# function scan(x::Tracked)
+#   x.isleaf && return
+#   ref = x.ref += 1
+#   if ref == 1
+#     scan(x.f)
+#     isdefined(x, :grad) && (x.grad = zero_grad!(x.grad))
+#   end
+#   return
+# end
+
+# function scan(x)
+#   istracked(x) && scan(tracker(x))
+#   return
+# end
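The deleted `scan` was the first half of a two-pass scheme: it walked the graph ahead of backpropagation, incrementing `x.ref` once per consumer of each node, so that `back` (below) could delay recursing into a node until every downstream gradient contribution had arrived. A self-contained toy version of that scheme, with simplified stand-ins for Tracker's `Tracked` and `Call`:

```julia
# Toy two-pass refcounting, mirroring the scan/back pair being removed here.
mutable struct Node
    parents::Vector{Node}
    ref::Int
    grad::Float64
end
Node(parents::Node...) = Node(Node[parents...], 0, 0.0)

# Pass 1 ("scan"): count how many consumers each node has.
function scan!(n::Node)
    for p in n.parents
        p.ref += 1
        p.ref == 1 && scan!(p)   # recurse only on the first visit
    end
end

# Pass 2 ("back"): accumulate into each parent, and recurse only once
# the last consumer has delivered its contribution (ref reaches zero).
function back!(n::Node, Δ)
    for p in n.parents
        p.grad += Δ
        p.ref -= 1
        p.ref == 0 && back!(p, p.grad)
    end
end

x = Node(); y = Node(x, x)   # x is consumed twice by y
scan!(y)                     # x.ref == 2
back!(y, 1.0)                # x.grad == 2.0; x's own pullback fires only once
```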

+# function back_(c::Call, Δ, once)
+#   Δs = c.func(Δ)
+#   (Δs isa Tuple && length(Δs) >= length(c.args)) ||
+#     error("Gradient is not a tuple of length $(length(c.args))")
+#   foreach((x, d) -> back(x, d, once), c.args, data.(Δs))
+# end
+
+function back_(c::Call, Δ)
   Δs = c.func(Δ)
   (Δs isa Tuple && length(Δs) >= length(c.args)) ||
     error("Gradient is not a tuple of length $(length(c.args))")
-  foreach((x, d) -> back(x, d, once), c.args, data.(Δs))
+  foreach((x, d) -> back(x, d), c.args, data.(Δs))
 end
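`back_` itself is unchanged apart from dropping the `once` flag: the stored pullback `c.func` must map the output gradient to a tuple with one entry per recorded argument. A standalone sketch of that contract (toy types, not Tracker's; in Tracker each pair would be fed back into `back`):

```julia
struct ToyCall{F,As<:Tuple}
    func::F   # pullback: output gradient -> per-argument gradients
    args::As
end

function toy_back_(c::ToyCall, Δ)
    Δs = c.func(Δ)
    (Δs isa Tuple && length(Δs) >= length(c.args)) ||
        error("Gradient is not a tuple of length $(length(c.args))")
    collect(zip(c.args, Δs))   # Tracker instead calls back(arg, d) per pair
end

# For z = x * y with x = 2.0 and y = 3.0, the pullback of Δ is (Δ*y, Δ*x):
c = ToyCall(Δ -> (Δ * 3.0, Δ * 2.0), (:x, :y))
toy_back_(c, 1.0)   # => [(:x, 3.0), (:y, 2.0)]
```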

-back_(::Call{Nothing}, Δ, once) = nothing
-back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
+# back_(::Call{Nothing}, Δ, once) = nothing
+# back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
+
+back_(::Call{Nothing}, Δ) = nothing
+back_(::Call{Missing}, Δ) = error("`back!` was already used")
 
 accum!(x, Δ) = x .+ Δ
 accum!(x::AbstractArray, Δ) = (x .+= Δ)
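The two `accum!` methods are untouched but central to both versions of `back`: arrays accumulate in place, everything else falls back to an out-of-place broadcast. In isolation:

```julia
julia> accum!(x, Δ) = x .+ Δ;

julia> accum!(x::AbstractArray, Δ) = (x .+= Δ);

julia> g = [1.0, 2.0]; accum!(g, [0.5, 0.5]); g   # arrays mutate in place
2-element Vector{Float64}:
 1.5
 2.5

julia> accum!(1.0, 0.5)   # scalars take the generic, out-of-place method
1.5
```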

-function back(x::Tracked, Δ, once)
-  x.isleaf && (x.grad = accum!(x.grad, Δ); return)
-  ref = x.ref -= 1
-  grad = if isdefined(x, :grad)
+# function back(x::Tracked, Δ, once)
+#   x.isleaf && (x.grad = accum!(x.grad, Δ); return)
+#   ref = x.ref -= 1
+#   grad = if isdefined(x, :grad)
+#     x.grad = accum!(x.grad, Δ)
+#   elseif ref > 0
+#     x.grad = Δ
+#   else
+#     Δ
+#   end
+#   if ref == 0
+#     back_(x.f, grad, once)
+#     once && !x.isleaf && (x.f = Call(missing, ()))
+#   end
+#   return
+# end
+
+
+# function back(x::Tracked, Δ)
+#   # Increment the reference count
+#   x.ref += 1
+
+#   # Handle gradient accumulation and backpropagation based on the reference count
+#   if x.ref == 1
+#     # Node has no more references, perform backpropagation and reset gradient
+#     x.grad = Δ
+#     back_(x.f, Δ)
+#   else
+#     # Node already has additional references, accumulate gradient into the gradient buffer
+#     x.grad = accum!(x.grad, Δ)
+#   end
+
+#   # Decrement the reference count
+#   x.ref -= 1
+
+#   return
+# end
+
+
+
+function back(x::Tracked, Δ)
+  if x.isleaf
+    x.grad = accum!(x.grad, Δ)
+    return
+  end
+
+  x.ref -= 1
+  if isdefined(x, :grad)
     x.grad = accum!(x.grad, Δ)
-  elseif ref > 0
+  elseif x.ref > 0
     x.grad = Δ
   else
-    Δ
+    x.grad = Δ
   end
-  if ref == 0
-    back_(x.f, grad, once)
-    once && !x.isleaf && (x.f = Call(missing, ()))
+
+  if x.ref == 0
+    Δs = x.f(Δ)
+    for (arg, d) in zip(x.args, Δs)
+      back(arg, d)
+    end
   end
   return
 end

-back(::Nothing, Δ, once) = return
+
+
+# back(::Nothing, Δ, once) = return
+back(::Nothing, Δ) = return
 
 # Interface methods

@@ -71,10 +133,24 @@ back(::Nothing, Δ, once) = return
 # Refcounts are also probably not safe in some situations (e.g. back called
 # from within a backpropagator)
 
-function back!(x, Δ; once = true)
+
+function back!(x, Δ)
+  istracked(x) || return
+  back(tracker(x), Δ)
+end
+
+
+# function back!(x, Δ; once=true)
+#   # back(tracker(x), Δ, once) # Call the back function starting from the tracker of x
+#   back(tracker(x), Δ) # Call the back function starting from the tracker of x
+#   return
+# end
+
+function back!(x, Δ; once=true)
   istracked(x) || return
-  scan(x)
-  back(tracker(x), Δ, once)
+  # scan(x)
+  # back(tracker(x), Δ, once)
+  back(tracker(x), Δ)
   return
 end
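For orientation, the public entry point being rewired here is typically reached as in the following sketch (standard Tracker.jl API; the gradient values assume a simple, non-shared graph):

```julia
using Tracker: param, back!, grad

w = param([1.0, 2.0])   # tracked leaf
loss = sum(w .* w)      # tracked scalar: 5.0
back!(loss, 1)          # seed the output gradient; the `once` keyword is dropped by this PR
grad(w)                 # => [2.0, 4.0]
```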

@@ -161,7 +237,7 @@ function gradient_nested(f, args...)
   return back(1)
 end
 
-gradient(f, xs...; nest = false) =
+gradient(f, xs...; nest=false) =
   nest ? gradient_nested(f, xs...) : gradient_(f, xs...)
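The only change here is the formatting of the `nest` keyword, which selects between the two implementations: the default buffered path (`gradient_`) and the re-entrant path (`gradient_nested`) for gradients taken inside a backpropagator. A usage sketch against Tracker.jl's public `gradient`; the shown results are up to tracking wrappers:

```julia
using Tracker

# Default path (gradient_):
Tracker.gradient((x, y) -> sum(x .* y), [1.0, 2.0], [3.0, 4.0])
# => ([3.0, 4.0], [1.0, 2.0])

# Nested path, needed e.g. for higher-order derivatives:
Tracker.gradient(x -> sum(x .* x), [1.0, 2.0]; nest = true)
# => ([2.0, 4.0],)
```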

 # Jacobians and Hessians
@@ -219,22 +295,22 @@ julia> withgradient(model, rand(Float32, 2)) do m, x
 ```
 """
 function withgradient(f, xs...)
-  pxs = fmap(param, xs; exclude = isnumeric, walk = _trainable_walk)
-  l = f(pxs...)
-  l1 = l isa Union{Tuple, NamedTuple} ? first(l) : l
-  val = l isa Union{Tuple, NamedTuple} ? fmap(data, l) : data(l)
-  losscheck(l1)
-  l1 isa TrackedReal || return (; val, grad = map(_ -> nothing, xs))
-  @interrupts back!(l1)
-  (; val, grad = rec_grad(pxs))
+    pxs = fmap(param, xs; exclude=isnumeric, walk=_trainable_walk)
+    l = f(pxs...)
+    l1 = l isa Union{Tuple,NamedTuple} ? first(l) : l
+    val = l isa Union{Tuple,NamedTuple} ? fmap(data, l) : data(l)
+    losscheck(l1)
+    l1 isa TrackedReal || return (; val, grad=map(_ -> nothing, xs))
+    @interrupts back!(l1)
+    (; val, grad=rec_grad(pxs))
 end
 
 function _trainable_walk(f, x)
   func, re = functor(x)
   isempty(func) && return x
   done = map(f, _trainable(x)) # recurse only into trainable fields, this contains `nothing` elsewhere
   map(func, merge(func, done)) do n, t
-    isnothing(t) ? n : t
+        isnothing(t) ? n : t
   end |> re # reconstruct the whole thing
 end
 _trainable_walk(f, x::Tuple) = map(f, x)
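The `merge(func, done)` in `_trainable_walk` is what lets non-trainable fields pass through untouched: `done` holds recursed results for trainable fields and `nothing` elsewhere, and the `do`-block falls back to the original field whenever it meets a `nothing`. The same logic in isolation, with hypothetical field values:

```julia
julia> func = (W = [1.0], b = [0.0], σ = identity);   # all fields of the layer

julia> done = (W = [10.0], b = [20.0], σ = nothing);  # `nothing` for non-trainable

julia> map((n, t) -> isnothing(t) ? n : t, func, merge(func, done))
(W = [10.0], b = [20.0], σ = identity)
```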
Expand All @@ -247,9 +323,9 @@ rec_grad(x::Number) = nothing

rec_grad(x::Union{Tuple,NamedTuple,AbstractArray}) = map(rec_grad, x)
rec_grad(::Tuple{}) = nothing
rec_grad(::NamedTuple{(), Tuple{}}) = nothing
rec_grad(::NamedTuple{(),Tuple{}}) = nothing
function rec_grad(x::T) where {T}
F = fieldnames(T)
isempty(F) && return nothing
map(f -> rec_grad(getfield(x, f)), NamedTuple{F}(F))
F = fieldnames(T)
isempty(F) && return nothing
map(f -> rec_grad(getfield(x, f)), NamedTuple{F}(F))
end
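The reindented `rec_grad` method relies on a small idiom worth noting: `NamedTuple{F}(F)` builds a named tuple mapping each field name to itself, so the `map` returns the per-field results keyed by field name:

```julia
julia> F = (:W, :b);

julia> NamedTuple{F}(F)
(W = :W, b = :b)

julia> map(f -> "grad of $f", NamedTuple{F}(F))
(W = "grad of W", b = "grad of b")
```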