Skip to content

Commit

Permalink
Added outpad argument to ConvTranspose (#2462)
Browse files Browse the repository at this point in the history
  • Loading branch information
guiyrt committed Jun 30, 2024
1 parent 9061b79 commit 36abc73
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 8 deletions.
26 changes: 18 additions & 8 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,14 +224,18 @@ function _print_conv_opt(io::IO, l)
end

"""
ConvTranspose(filter, in => out, σ=identity; stride=1, pad=0, dilation=1, [bias, init])
ConvTranspose(filter, in => out, σ=identity; stride=1, pad=0, outpad=0, dilation=1, [bias, init])
Standard convolutional transpose layer. `filter` is a tuple of integers
specifying the size of the convolutional kernel, while
`in` and `out` specify the number of input and output channels.
Note that `pad=SamePad()` here tries to ensure `size(output,d) == size(x,d) * stride`.
To conserve [`Conv`](@ref) inversability when `stride > 1`, `outpad` can be used to increase the size
of the output in the desired dimensions. Whereas `pad` is used to zero-pad the input,
`outpad` only affects the output shape.
Parameters are controlled by additional keywords, with defaults
`init=glorot_uniform` and `bias=true`.
Expand All @@ -250,6 +254,9 @@ julia> layer(xs) |> size
julia> ConvTranspose((5,5), 3 => 7, stride=2)(xs) |> size
(203, 203, 7, 50)
julia> ConvTranspose((5,5), 3 => 7, stride=2, outpad=1)(xs) |> size
(204, 204, 7, 50)
julia> ConvTranspose((5,5), 3 => 7, stride=3, pad=SamePad())(xs) |> size
(300, 300, 7, 50)
```
Expand All @@ -260,6 +267,7 @@ struct ConvTranspose{N,M,F,A,V}
bias::V
stride::NTuple{N,Int}
pad::NTuple{M,Int}
outpad::NTuple{N,Int}
dilation::NTuple{N,Int}
groups::Int
end
Expand All @@ -268,7 +276,7 @@ _channels_in(l::ConvTranspose) = size(l.weight)[end]
_channels_out(l::ConvTranspose) = size(l.weight)[end-1]*l.groups

"""
ConvTranspose(weight::AbstractArray, [bias, activation; stride, pad, dilation, groups])
ConvTranspose(weight::AbstractArray, [bias, activation; stride, pad, outpad, dilation, groups])
Constructs a ConvTranspose layer with the given weight and bias.
Accepts the same keywords and has the same defaults as
Expand All @@ -291,31 +299,32 @@ julia> Flux.params(layer) |> length
```
"""
function ConvTranspose(w::AbstractArray{T,N}, bias = true, σ = identity;
stride = 1, pad = 0, dilation = 1, groups=1) where {T,N}
stride = 1, pad = 0, outpad = 0, dilation = 1, groups = 1) where {T,N}
stride = expand(Val(N-2), stride)
dilation = expand(Val(N-2), dilation)
pad = calc_padding(ConvTranspose, pad, size(w)[1:N-2], dilation, stride)
b = create_bias(w, bias, size(w, N-1) * groups)
return ConvTranspose(σ, w, b, stride, pad, dilation, groups)
outpad = expand(Val(N-2), outpad)
return ConvTranspose(σ, w, b, stride, pad, outpad, dilation, groups)
end

function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
init = glorot_uniform, stride = 1, pad = 0, outpad = 0, dilation = 1,
groups = 1,
bias = true,
) where N

weight = convfilter(k, reverse(ch); init, groups)
ConvTranspose(weight, bias, σ; stride, pad, dilation, groups)
ConvTranspose(weight, bias, σ; stride, pad, outpad, dilation, groups)
end

@layer ConvTranspose

function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
# Calculate size of "input", from ∇conv_data()'s perspective...
calc_dim(xsz, wsz, stride, dilation, pad) = (xsz - 1) * stride + 1 + (wsz - 1) * dilation - pad
calc_dim(xsz, wsz, stride, dilation, pad, outpad) = (xsz - 1) * stride + 1 + (wsz - 1) * dilation - pad + outpad
combined_pad = ntuple(i -> c.pad[2i-1] + c.pad[2i], length(c.pad) ÷ 2)
I = map(calc_dim, size(x)[1:end-2], size(c.weight)[1:end-2], c.stride, c.dilation, combined_pad)
I = map(calc_dim, size(x)[1:end-2], size(c.weight)[1:end-2], c.stride, c.dilation, combined_pad, c.outpad)
C_in = size(c.weight)[end-1] * c.groups
batch_size = size(x)[end]
# Create DenseConvDims() that looks like the corresponding conv()
Expand All @@ -342,6 +351,7 @@ function Base.show(io::IO, l::ConvTranspose)
print(io, "ConvTranspose(", size(l.weight)[1:ndims(l.weight)-2])
print(io, ", ", _channels_in(l), " => ", _channels_out(l))
_print_conv_opt(io, l)
all(==(0), l.outpad) || print(io, ", outpad=", l.outpad)
print(io, ")")
end

Expand Down
11 changes: 11 additions & 0 deletions test/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,17 @@ end

@test occursin("groups=2", sprint(show, ConvTranspose((3,3), 2=>4, groups=2)))
@test occursin("2 => 4" , sprint(show, ConvTranspose((3,3), 2=>4, groups=2)))

# test ConvTranspose outpad argument for stride > 1
x = randn(Float32, 10, 11, 3,2)
m1 = ConvTranspose((3,5), 3=>6, stride=3)
m2 = ConvTranspose((3,5), 3=>6, stride=3, outpad=(1,0))
@test size(m2(x))[1:2] == (size(m1(x))[1:2] .+ (1,0))

x = randn(Float32, 10, 11, 12, 3,2)
m1 = ConvTranspose((3,5,3), 3=>6, stride=3)
m2 = ConvTranspose((3,5,3), 3=>6, stride=3, outpad=(1,0,1))
@test size(m2(x))[1:3] == (size(m1(x))[1:3] .+ (1,0,1))
end

@testset "CrossCor" begin
Expand Down

0 comments on commit 36abc73

Please sign in to comment.