introduce trilinear upsampling
Max Freudenberg committed Apr 19, 2021
1 parent 2c3bdb8 commit 3e557aa
Showing 3 changed files with 307 additions and 26 deletions.
214 changes: 191 additions & 23 deletions lib/NNlibCUDA/src/upsample.jl
@@ -1,4 +1,3 @@

#
# Upsampling
#
@@ -44,7 +43,16 @@
# Forward and backward passes have been tested to produce the same output
# as PyTorch with align_corners=True; it works modulo bit noise.

function upsample_bilinear_whcn_kernel!(n_elem, rheight, rwidth, x, y)
@inline function compute_source_index(ratio::T, dst_index, align_corners) where T
if align_corners
return ratio*dst_index
else
src_idx = ratio * (dst_index + T(0.5)) - T(0.5)
return max(zero(T), src_idx)
end
end
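
# A minimal sketch (illustrative, not part of the diff) of what this helper
# yields when a length-4 axis is upsampled to length 8, with zero-based
# indices as in the kernels below:
#
# align_corners = true: ratio = (in-1)/(out-1); endpoints map exactly
ratio_true = Float32((4 - 1) / (8 - 1))
src_true = [compute_source_index(ratio_true, d, true) for d in 0:7]
# ≈ [0.0, 0.43, 0.86, 1.29, 1.71, 2.14, 2.57, 3.0]
#
# align_corners = false: ratio = in/out; half-pixel offset, clamped at zero
ratio_false = Float32(4 / 8)
src_false = [compute_source_index(ratio_false, d, false) for d in 0:7]
# ≈ [0.0, 0.25, 0.75, 1.25, 1.75, 2.25, 2.75, 3.25]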

function upsample_bilinear_whcn_kernel!(n_elem, rwidth, rheight, x, y, align_corners)
index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x

if index < n_elem
@@ -54,15 +62,17 @@ function upsample_bilinear_whcn_kernel!(n_elem, rheight, rwidth, x, y)
ow = index % out_w
oh = index ÷ out_w

real_index = rheight*oh
# real_index = rheight*oh
real_index = compute_source_index(rheight, oh, align_corners)
ih0 = Base.floor(Int, real_index)
offset = (ih0 < in_h-1) ? 1 : 0
ih1 = ih0 + offset + 1
h1lambda = real_index - ih0
h0lambda = 1 - h1lambda
ih0 += 1

real_index = rwidth*ow
# real_index = rwidth*ow
real_index = compute_source_index(rwidth, ow, align_corners)
iw0 = Base.floor(Int, real_index)
offset = (iw0 < in_w-1) ? 1 : 0
iw1 = iw0 + offset + 1
@@ -84,7 +94,7 @@ function upsample_bilinear_whcn_kernel!(n_elem, rheight, rwidth, x, y)
end
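
# The collapsed lines above perform the standard four-tap bilinear blend from
# the two index pairs and four lambdas; a minimal CPU sketch of that
# arithmetic (assuming indices and weights as computed above, applied per
# channel and batch):
function bilinear_blend(x::AbstractMatrix, iw0, iw1, ih0, ih1,
                        w0lambda, w1lambda, h0lambda, h1lambda)
    return h0lambda * (w0lambda * x[iw0, ih0] + w1lambda * x[iw1, ih0]) +
           h1lambda * (w0lambda * x[iw0, ih1] + w1lambda * x[iw1, ih1])
end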

# Δ is the gradient backpropagated from downstream layers
function ∇upsample_bilinear_whcn_kernel!(n_elem, rheight, rwidth, Δ, dx)
function ∇upsample_bilinear_whcn_kernel!(n_elem, rwidth, rheight, Δ, dx, align_corners)
index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x

if index < n_elem
@@ -95,7 +105,8 @@ function ∇upsample_bilinear_whcn_kernel!(n_elem, rheight, rwidth, Δ, dx)
ih = index ÷ in_width

# Compute Y axis lambdas
real_index_h = rheight*ih
# real_index_h = rheight*ih
real_index_h = compute_source_index(rheight, ih, align_corners)
oh0 = Base.floor(Int, real_index_h)
offset = (oh0 < out_height-1) ? 1 : 0
oh1 = oh0 + offset + 1
@@ -104,7 +115,8 @@ function ∇upsample_bilinear_whcn_kernel!(n_elem, rheight, rwidth, Δ, dx)
oh0 += 1

# Compute X axis lambdas
real_index_w = rwidth * iw
# real_index_w = rwidth * iw
real_index_w = compute_source_index(rwidth, iw, align_corners)
ow0 = Base.floor(Int, real_index_w)
offset = (ow0 < out_width - 1) ? 1 : 0
ow1 = ow0 + offset + 1
@@ -125,33 +137,189 @@ function ∇upsample_bilinear_whcn_kernel!(n_elem, rheight, rwidth, Δ, dx)
return nothing
end

function NNlib.upsample_bilinear_whcn!(y::CuArray{T,4}, x::CuArray{T,4}) where T
w,h,c,n = size(x)
out_w, out_h = (size(y,1), size(y,2))
function NNlib.upsample_bilinear_whcn!(y::CuArray{T,4}, x::CuArray{T,4}; align_corners=true) where T
out_size = prod(size(y)[1:2]) # w*h

out_size = out_h*out_w
rheight = T((h-1)/(out_h-1))
rwidth = T((w-1)/(out_w-1))
if align_corners
ratios = ntuple(i -> T((size(x,i)-1) / (size(y,i)-1)), 2)
else
ratios = ntuple(i -> T(size(x,i) / size(y,i)), 2)
end

kernel = @cuda launch=false upsample_bilinear_whcn_kernel!(out_size, rheight, rwidth, x, y)
kernel = @cuda launch=false upsample_bilinear_whcn_kernel!(out_size, ratios..., x, y, align_corners)
config = launch_configuration(kernel.fun; max_threads=256)
threads = Base.min(out_size, config.threads)
blocks = cld(out_size, threads)
kernel(out_size, rheight, rwidth, x, y; threads=threads, blocks=blocks)
kernel(out_size, ratios..., x, y, align_corners; threads=threads, blocks=blocks)
return y
end

function NNlib.∇upsample_bilinear_whcn!(dx::CuArray{T,4}, Δ::CuArray{T,4}) where T
w,h,c,n = Base.size(Δ)
out_w, out_h = (size(dx, 1), size(dx, 2))
in_size = h*w
rheight = T((out_h-1)/(h-1)) # reversed compared to forward pass
rwidth = T((out_w-1)/(w-1))
function NNlib.∇upsample_bilinear_whcn!(dx::CuArray{T,4}, Δ::CuArray{T,4}; align_corners=true) where T
in_size = prod(size(Δ)[1:2])
if align_corners
ratios = ntuple(i -> T((size(dx,i)-1) / (size(Δ,i)-1)), 2) # reversed compared to forward pass
else
ratios = ntuple(i -> T(size(dx,i) / size(Δ,i)), 2)
end

kernel = @cuda launch=false ∇upsample_bilinear_whcn_kernel!(in_size, ratios..., Δ, dx, align_corners)
config = launch_configuration(kernel.fun; max_threads=256)
threads = Base.min(in_size, config.threads)
blocks = cld(in_size, threads)
kernel(in_size, ratios..., Δ, dx, align_corners; threads=threads, blocks=blocks)
return dx
end


###########
# trilinear
###########
function upsample_trilinear_whdcn_kernel!(n_elem, rwidth, rheight, rdepth, x, y, align_corners)
index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x

if index < n_elem
in_w, in_h, in_d, channels, batchsize = size(x)
out_w, out_h, out_d, _, _ = size(y)

ow = (index % (out_w * out_h)) % out_w
oh = (index % (out_w * out_h)) ÷ out_w
od = index ÷ (out_w * out_h)

# real_index = rwidth*ow
real_index = compute_source_index(rwidth, ow, align_corners)
iw0 = Base.floor(Int, real_index)
offset = (iw0 < in_w-1) ? 1 : 0
iw1 = iw0 + offset + 1
w1lambda = real_index - iw0
w0lambda = 1 - w1lambda
iw0 += 1

# real_index = rheight*oh
real_index = compute_source_index(rheight, oh, align_corners)
ih0 = Base.floor(Int, real_index)
offset = (ih0 < in_h-1) ? 1 : 0
ih1 = ih0 + offset + 1
h1lambda = real_index - ih0
h0lambda = 1 - h1lambda
ih0 += 1

# real_index = rdepth*od
real_index = compute_source_index(rdepth, od, align_corners)
id0 = Base.floor(Int, real_index)
offset = (id0 < in_d-1) ? 1 : 0
id1 = id0 + offset + 1
d1lambda = real_index - id0
d0lambda = 1 - d1lambda
id0 += 1

@inbounds for n in 1:batchsize
for c in 1:channels
val = d0lambda *
(h0lambda *
(w0lambda * x[iw0, ih0, id0, c, n] +
w1lambda * x[iw1, ih0, id0, c, n]) +
h1lambda *
(w0lambda * x[iw0, ih1, id0, c, n] +
w1lambda * x[iw1, ih1, id0, c, n])) +
d1lambda *
(h0lambda *
(w0lambda * x[iw0, ih0, id1, c, n] +
w1lambda * x[iw1, ih0, id1, c, n]) +
h1lambda *
(w0lambda * x[iw0, ih1, id1, c, n] +
w1lambda * x[iw1, ih1, id1, c, n]))

y[ow+1, oh+1, od+1, c, n] = val
end
end
end
return nothing
end
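
# The flat thread index is unpacked into (ow, oh, od) with two mod/div steps;
# a quick worked example, assuming out_w = out_h = 5 (so out_w * out_h = 25):
#   index = 37 (zero-based)
#   ow = (37 % 25) % 5 = 12 % 5 = 2
#   oh = (37 % 25) ÷ 5 = 12 ÷ 5 = 2
#   od = 37 ÷ 25       = 1
# so thread 37 writes y[3, 3, 2, c, n] after the one-based shift.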

# Δ is the gradient backpropagated from downstream layers
function ∇upsample_trilinear_whdcn_kernel!(n_elem, rwidth, rheight, rdepth, Δ, dx, align_corners)
index = (threadIdx().x - 1) + (blockIdx().x - 1) * blockDim().x

if index < n_elem
in_width, in_height, in_depth, channels, batchsize = size(Δ)
out_width, out_height, out_depth, _, _ = size(dx)

iw = (index % (in_height * in_width)) % in_width
ih = (index % (in_height * in_width)) ÷ in_width
id = index ÷ (in_height * in_width)

real_index_w = compute_source_index(rwidth, iw, align_corners)
ow0 = Base.floor(Int, real_index_w)
offset = (ow0 < out_width - 1) ? 1 : 0
ow1 = ow0 + offset + 1
w1lambda = real_index_w - ow0
w0lambda = 1 - w1lambda
ow0 += 1

real_index_h = compute_source_index(rheight, ih, align_corners)
oh0 = Base.floor(Int, real_index_h)
offset = (oh0 < out_height-1) ? 1 : 0
oh1 = oh0 + offset + 1
h1lambda = real_index_h - oh0
h0lambda = 1 - h1lambda
oh0 += 1

real_index_d = compute_source_index(rdepth, id, align_corners)
od0 = Base.floor(Int, real_index_d)
offset = (od0 < out_depth-1) ? 1 : 0
od1 = od0 + offset + 1
d1lambda = real_index_d - od0
d0lambda = 1 - d1lambda
od0 += 1

@inbounds for n in 1:batchsize
for c in 1:channels
val = Δ[iw+1, ih+1, id+1, c, n]
@atomic dx[ow0, oh0, od0, c, n] += w0lambda * h0lambda * d0lambda * val
@atomic dx[ow1, oh0, od0, c, n] += w1lambda * h0lambda * d0lambda * val
@atomic dx[ow0, oh1, od0, c, n] += w0lambda * h1lambda * d0lambda * val
@atomic dx[ow1, oh1, od0, c, n] += w1lambda * h1lambda * d0lambda * val

@atomic dx[ow0, oh0, od1, c, n] += w0lambda * h0lambda * d1lambda * val
@atomic dx[ow1, oh0, od1, c, n] += w1lambda * h0lambda * d1lambda * val
@atomic dx[ow0, oh1, od1, c, n] += w0lambda * h1lambda * d1lambda * val
@atomic dx[ow1, oh1, od1, c, n] += w1lambda * h1lambda * d1lambda * val
end
end
end # if
return nothing
end
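
# The @atomic accumulations are needed because neighbouring Δ entries scatter
# into overlapping dx entries. A small worked check, assuming out_depth = 3
# and in_depth = 5 with align_corners = true, so the ratio is 0.5:
touched = [(i0 = floor(Int, 0.5d) + 1; lam = 0.5d - floor(0.5d);
            lam == 0 ? (i0,) : (i0, i0 + 1)) for d in 0:4]
# [(1,), (1, 2), (2,), (2, 3), (3,)] → dx depth 2 is hit by three different
# threads, which would race without the atomic adds.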

function NNlib.upsample_trilinear_whdcn!(y::CuArray{T,5}, x::CuArray{T,5}; align_corners=true) where T
out_size = prod(size(y)[1:3]) # w*h*d

if align_corners
ratios = ntuple(i -> T((size(x,i)-1) / (size(y,i)-1)), 3)
else
ratios = ntuple(i -> T(size(x,i) / size(y,i)), 3)
end

kernel = @cuda launch=false upsample_trilinear_whdcn_kernel!(out_size, ratios..., x, y, align_corners)
config = launch_configuration(kernel.fun; max_threads=256)
threads = Base.min(out_size, config.threads)
blocks = cld(out_size, threads)
kernel(out_size, ratios..., x, y, align_corners; threads=threads, blocks=blocks)
return y
end

function NNlib.∇upsample_trilinear_whdcn!(dx::CuArray{T,5}, Δ::CuArray{T,5}; align_corners=true) where T
in_size = prod(size(Δ)[1:3])

if align_corners
ratios = ntuple(i -> T((size(dx,i)-1) / (size(Δ,i)-1)), 3) # reversed compared to forward pass
else
ratios = ntuple(i -> T(size(dx,i) / size(Δ,i)), 3)
end

kernel = @cuda launch=false ∇upsample_trilinear_whdcn_kernel!(in_size, ratios..., Δ, dx, align_corners)
config = launch_configuration(kernel.fun; max_threads=256)
threads = Base.min(in_size, config.threads)
blocks = cld(in_size, threads)
kernel(in_size, rheight, rwidth, Δ, dx; threads=threads, blocks=blocks)
kernel(in_size, ratios..., Δ, dx, align_corners; threads=threads, blocks=blocks)
return dx
end
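
# A minimal usage sketch of the new entry points (hypothetical shapes; assumes
# a working CUDA device and the NNlib stubs that the wrappers above extend):
using CUDA, NNlib, NNlibCUDA

x = CUDA.rand(Float32, 8, 8, 8, 2, 1)     # WHDCN input
y = similar(x, (16, 16, 16, 2, 1))        # 2x upsampling target
NNlib.upsample_trilinear_whdcn!(y, x)     # align_corners = true by default

Δ = CUDA.ones(Float32, 16, 16, 16, 2, 1)
dx = CUDA.zeros(Float32, 8, 8, 8, 2, 1)   # must start zeroed: the kernel accumulates
NNlib.∇upsample_trilinear_whdcn!(dx, Δ)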
34 changes: 32 additions & 2 deletions lib/NNlibCUDA/test/upsample.jl
@@ -3,7 +3,7 @@
x = cat(x,x; dims=3)
x = cat(x,x; dims=4)
xgpu = cu(x)

y_true = Float32[ 1//1 4//3 5//3 2//1;
7//5 26//15 31//15 12//5;
9//5 32//15 37//15 14//5;
@@ -13,7 +13,7 @@
y_true = cat(y_true,y_true; dims=3)
y_true = cat(y_true,y_true; dims=4)
y_true_gpu = cu(y_true)

y = upsample_bilinear(xgpu, (3,2))
@test size(y) == size(y_true_gpu)
@test eltype(y) == Float32
@@ -25,3 +25,33 @@

gputest(x -> upsample_bilinear(x, (3, 2)), x, atol=1e-5)
end

@testset "Trilinear upsampling" begin
# Layout: WHDCN, where D is depth
# we generate data which is constant along W & H and differs in D
# then we upsample along all dimensions
x = CUDA.ones(Float32, 3,3,3,1,1)
x[:,:,1,:,:] .= 1.
x[:,:,2,:,:] .= 2.
x[:,:,3,:,:] .= 3.

y_true = CUDA.ones(Float32, 5,5,5,1,1)
y_true[:,:,1,:,:] .= 1.
y_true[:,:,2,:,:] .= 1.5
y_true[:,:,3,:,:] .= 2.
y_true[:,:,4,:,:] .= 2.5
y_true[:,:,5,:,:] .= 3.

y = upsample_trilinear(x; size=(5,5,5))

@test size(y) == size(y_true)
@test eltype(y) == Float32
@test collect(y) ≈ collect(y_true)

# this test only works when align_corners=false
# o = CUDA.ones(Float32,8,8,8,1,1)
# grad_true = 8*CUDA.ones(Float32,4,4,4,1,1)
# @test ∇upsample_trilinear(o; size=(4,4,4)) ≈ grad_true

gputest(x -> upsample_trilinear(x, (2,2,2)), x, atol=1e-5)
end
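
# A sanity check on y_true above (editor's sketch): with align_corners = true
# the depth ratio is (3-1)/(5-1) = 0.5, so output depths 0:4 sample the input
# at source depths 0, 0.5, 1, 1.5, 2; interpolating the slice values [1, 2, 3]
# at those positions gives [1, 1.5, 2, 2.5, 3], matching the expected slices:
vals = [1.0, 2.0, 3.0]
expected = [(s = 0.5d; i = floor(Int, s); lam = s - i;
             (1 - lam) * vals[i + 1] + lam * vals[min(i + 2, 3)]) for d in 0:4]
# expected == [1.0, 1.5, 2.0, 2.5, 3.0]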
