Skip to content

Commit

Permalink
Fix UNet for 3D convolutions (specify ndim to convxlayer and ResBlock) (
Browse files Browse the repository at this point in the history
#263)

* Fix UNet for 3D convolutions (specify ndim to convxlayer and ResBlock)
  • Loading branch information
itan1 committed Oct 22, 2022
1 parent 767aa2b commit dee6399
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions FastVision/src/models/unet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ function UNetDynamic(backbone,
inputsize;
m_middle = UNetMiddleBlock,
skip_upscale = fdownscale,
kwargs...)
kwargs...)
outsz = Flux.outputsize(unet, inputsize)
return Chain(unet, final(outsz[end - 1], k_out))
return Chain(unet, final(outsz[end - 1], k_out, length(outsz) - 2))
end

function catchannels(x1, x2)
Expand All @@ -50,7 +50,7 @@ function catchannels(x1, x2)
end

function unetlayers(layers,
sz;
sz;
k_out = nothing,
skip_upscale = 0,
m_middle = _ -> (identity,))
Expand Down Expand Up @@ -81,7 +81,8 @@ function unetlayers(layers,
return UNetBlock(Chain(layer, childunet),
k_in, # Input channels to upsampling layer
k_mid,
k_out)
k_out,
length(outsz) - 2)
end
end

Expand All @@ -95,28 +96,28 @@ Given convolutional module `m` that halves the spatial dimensions
and outputs `k_in` filters, create a module that upsamples the
spatial dimensions and then aggregates features via a skip connection.
"""
function UNetBlock(m_child, k_in, k_mid, k_out = 2k_in)
function UNetBlock(m_child, k_in, k_mid, k_out = 2k_in, ndim = 2)
return Chain(upsample = SkipConnection(Chain(child = m_child, # Downsampling and processing
upsample = PixelShuffleICNR(k_mid, k_mid)),
upsample = PixelShuffleICNR(k_mid, k_mid, ndim)),
Parallel(catchannels, identity, BatchNorm(k_in))),
act = xs -> relu.(xs),
combine = UNetCombineLayer(k_in + k_mid, k_out))
combine = UNetCombineLayer(k_in + k_mid, k_out, ndim))
end

function PixelShuffleICNR(k_in, k_out; r = 2)
return Chain(convxlayer(k_in, k_out * (r^2), ks = 1), Flux.PixelShuffle(r))
function PixelShuffleICNR(k_in, k_out, ndim; r = 2)
return Chain(convxlayer(k_in, k_out * (r^ndim), ks = 1, ndim = ndim), Flux.PixelShuffle(r))
end

function UNetCombineLayer(k_in, k_out)
return Chain(convxlayer(k_in, k_out), convxlayer(k_out, k_out))
function UNetCombineLayer(k_in, k_out, ndim)
return Chain(convxlayer(k_in, k_out, ndim = ndim), convxlayer(k_out, k_out, ndim = ndim))
end

function UNetMiddleBlock(k)
return Chain(convxlayer(k, 2k), convxlayer(2k, k))
function UNetMiddleBlock(k, ndim)
return Chain(convxlayer(k, 2k, ndim = ndim), convxlayer(2k, k, ndim = ndim))
end

function UNetFinalBlock(k_in, k_out)
return Chain(ResBlock(1, k_in, k_in), convxlayer(k_in, k_out, ks = 1))
function UNetFinalBlock(k_in, k_out, ndim)
return Chain(ResBlock(1, k_in, k_in, ndim = ndim), convxlayer(k_in, k_out, ks = 1, ndim = ndim))
end

"""
Expand All @@ -139,4 +140,7 @@ end

model = UNetDynamic(Models.xresnet18(), (128, 128, 3, 1), 4, fdownscale = 1)
@test Flux.outputsize(model, (128, 128, 3, 1)) == (64, 64, 4, 1)

model = UNetDynamic(Models.xresnet18(ndim = 3), (128, 128, 128, 3, 1), 4)
@test Flux.outputsize(model, (128, 128, 128, 3, 1)) == (128, 128, 128, 4, 1)
end end

2 comments on commit dee6399

@lorenzoh
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error while trying to register: "Pre-release version not allowed"

Please sign in to comment.