Skip to content

Commit

Permalink
Fix ConvTranspose symmetric non-constant padding (#2463)
Browse files Browse the repository at this point in the history
ConvTranspose was not able to handle symmetric non-constant
padding (e.g., `pad=(1, 0)` for a 2D ConvTranspose). Constant padding (e.g.,
`pad=1` for a 2D ConvTranspose) and asymmetric non-constant padding (e.g.,
`pad=(1, 0, 2, 3)`) worked correctly. This commit fixes symmetric
non-constant padding and adds unit tests to ensure it produces the same
output size as an equivalent fully expanded padding.

Fixes #2424
  • Loading branch information
paulnovo committed Aug 1, 2024
1 parent 942c6e5 commit 7c1ef13
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 2 deletions.
10 changes: 9 additions & 1 deletion ext/FluxAMDGPUExt/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,16 @@ function Flux.conv_dims(c::Conv, x::T) where T <: ROCArray
end

function Flux.conv_transpose_dims(c::ConvTranspose, x::T) where T <: ROCArray
# Calculate combined pad in each dimension
nd = ndims(x) - 2
if length(c.pad) == nd
# Handle symmetric non-constant padding
combined_pad = ntuple(i -> 2 * c.pad[i], nd)
else
combined_pad = ntuple(i -> c.pad[2i-1] + c.pad[2i], nd)
end

# Calculate size of "input", from ∇conv_data()'s perspective...
combined_pad = (c.pad[1:2:end] .+ c.pad[2:2:end])
I = (size(x)[1:end - 2] .- 1) .* c.stride .+ 1 .+
(size(c.weight)[1:end - 2] .- 1) .* c.dilation .- combined_pad .+ c.outpad
C_in = size(c.weight)[end - 1] * c.groups
Expand Down
11 changes: 10 additions & 1 deletion src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,12 +321,21 @@ end
@layer ConvTranspose

function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
# Calculate combined pad in each dimension
nd = ndims(x) - 2
if length(c.pad) == nd
# Handle symmetric non-constant padding
combined_pad = ntuple(i -> 2 * c.pad[i], nd)
else
combined_pad = ntuple(i -> c.pad[2i-1] + c.pad[2i], nd)
end

# Calculate size of "input", from ∇conv_data()'s perspective...
calc_dim(xsz, wsz, stride, dilation, pad, outpad) = (xsz - 1) * stride + 1 + (wsz - 1) * dilation - pad + outpad
combined_pad = ntuple(i -> c.pad[2i-1] + c.pad[2i], length(c.pad) ÷ 2)
I = map(calc_dim, size(x)[1:end-2], size(c.weight)[1:end-2], c.stride, c.dilation, combined_pad, c.outpad)
C_in = size(c.weight)[end-1] * c.groups
batch_size = size(x)[end]

# Create DenseConvDims() that looks like the corresponding conv()
w_size = size(c.weight)
return DenseConvDims((I..., C_in, batch_size), w_size;
Expand Down
15 changes: 15 additions & 0 deletions test/ext_amdgpu/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,21 @@ end
end
end

# For each layer type and 1-3 spatial dimensions, build one layer with a
# per-dimension `pad` tuple and one with the same values repeated twice per
# dimension, then check that both produce outputs of the same size.
@testset "Convolution with symmetric non-constant padding" begin
    for conv_type in (Conv, ConvTranspose), nd in 1:3
        ksize = ntuple(_ -> 2, nd)
        spatial = ntuple(_ -> 10, nd)
        x = gpu(rand(Float32, spatial..., 3, 5))

        # A different padding value per spatial dimension: (1,), (1, 2), (1, 2, 3).
        pad = ntuple(identity, nd)
        m = gpu(f32(conv_type(ksize, 3 => 4; pad=pad)))

        # Repeat every entry twice, e.g. (1, 2) -> (1, 1, 2, 2).
        expanded_pad = ntuple(i -> pad[cld(i, 2)], 2 * nd)
        m_expanded = gpu(f32(conv_type(ksize, 3 => 4; pad=expanded_pad)))

        @test size(m(x)) == size(m_expanded(x))
    end
end

@testset "ConvTranspose output padding" begin
x = randn(Float32, 10, 11, 3, 2)
m = ConvTranspose((3, 5), 3=>6, stride=3, outpad=(1, 0))
Expand Down
13 changes: 13 additions & 0 deletions test/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,19 @@ end
end
end

# For every layer type and 1-3 spatial dimensions, a per-dimension `pad` tuple
# must give the same output size as the tuple with each entry repeated twice.
@testset "$ltype $(nd)D symmetric non-constant padding" for ltype in (Conv, ConvTranspose, DepthwiseConv, CrossCor), nd in (1, 2, 3)
    ksize = ntuple(_ -> 3, nd)
    # Spatial extent is kernel size + 5 in every dimension; one channel, one sample.
    spatial = map(k -> k + 5, ksize)
    data = ones(Float32, spatial..., 1, 1)

    # A different padding value per spatial dimension: (1,), (1, 2), (1, 2, 3).
    pad = ntuple(identity, nd)
    l = ltype(ksize, 1 => 1; pad=pad)

    # Repeat every entry twice, e.g. (1, 2) -> (1, 1, 2, 2).
    expanded_pad = ntuple(i -> pad[cld(i, 2)], 2 * nd)
    l_expanded = ltype(ksize, 1 => 1; pad=expanded_pad)

    @test size(l(data)) == size(l_expanded(data))
end

@testset "$ltype SamePad kernelsize $k" for ltype in (Conv, ConvTranspose, DepthwiseConv, CrossCor), k in ( (1,), (2,), (3,), (4,5), (6,7,8))
data = ones(Float32, (k .+ 3)..., 1,1)
l = ltype(k, 1=>1, pad=SamePad())
Expand Down

0 comments on commit 7c1ef13

Please sign in to comment.