Cannot run EfficientNetv2 on gpu #261

Closed
IanButterworth opened this issue Dec 6, 2023 · 7 comments
Comments

@IanButterworth (Contributor) commented:

julia> using Metalhead, CUDA, Flux

julia> model = EfficientNetv2(:small; pretrain=false, inchannels=1, nclasses=2);

julia> batch = rand(Float32, 224, 224, 1, 1);

julia> model(batch)
2×1 Matrix{Float32}:
 -6.368171f-11
  1.7750237f-10

julia> model(gpu(batch))
┌ Warning: Performing scalar indexing on task Task (runnable) @0x00007f5312b82720.
│ Invocation of getindex resulted in scalar indexing of a GPU array.
│ This is typically caused by calling an iterating implementation of a method.
│ Such implementations *do not* execute on the GPU, but very slowly on the CPU,
│ and therefore are only permitted from the REPL for prototyping purposes.
│ If you did intend to index this array, annotate the caller with @allowscalar.
└ @ GPUArraysCore ~/.julia/packages/GPUArraysCore/uOYfN/src/GPUArraysCore.jl:106
ERROR: TaskFailedException

    nested task error: TaskFailedException
    
        nested task error: MethodError: no method matching gemm!(::Val{false}, ::Val{false}, ::Int64, ::Int64, ::Int64, ::Float32, ::CuPtr{Float32}, ::Ptr{Float32}, ::Float32, ::CuPtr{Float32})
        
        Closest candidates are:
          gemm!(::Val, ::Val, ::Int64, ::Int64, ::Int64, ::Float32, ::Ptr{Float32}, ::Ptr{Float32}, ::Float32, ::Ptr{Float32})
           @ NNlib ~/.julia/packages/NNlib/sXmAj/src/gemm.jl:29
          gemm!(::Val, ::Val, ::Int64, ::Int64, ::Int64, ::Float64, ::Ptr{Float64}, ::Ptr{Float64}, ::Float64, ::Ptr{Float64})
           @ NNlib ~/.julia/packages/NNlib/sXmAj/src/gemm.jl:29
          gemm!(::Val, ::Val, ::Int64, ::Int64, ::Int64, ::ComplexF64, ::Ptr{ComplexF64}, ::Ptr{ComplexF64}, ::ComplexF64, ::Ptr{ComplexF64})
           @ NNlib ~/.julia/packages/NNlib/sXmAj/src/gemm.jl:29
          ...
        
        Stacktrace:
         [1] macro expansion
           @ ~/.julia/packages/NNlib/sXmAj/src/impl/conv_im2col.jl:59 [inlined]
         [2] (::NNlib.var"#647#648"{CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}, Float32, Float32, SubArray{Float32, 5, CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, …}, SubArray{Float32, 5, CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, …}, SubArray{Float32, 5, Array{Float32, 5}, …}, DenseConvDims{3, 3, 3, 6, 3}, Int64, Int64, Int64, UnitRange{Int64}, Int64})()
           @ NNlib ./threadingconstructs.jl:416
    Stacktrace:
     [1] sync_end(c::Channel{Any})
       @ Base ./task.jl:445
     [2] macro expansion
       @ ./task.jl:477 [inlined]
     [3] conv_im2col!(y::SubArray{Float32, 5, CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, …}, x::SubArray{Float32, 5, CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, …}, w::SubArray{Float32, 5, Array{Float32, 5}, …}, cdims::DenseConvDims{3, 3, 3, 6, 3}; col::CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}, alpha::Float32, beta::Float32, ntasks::Int64)
       @ NNlib ~/.julia/packages/NNlib/sXmAj/src/impl/conv_im2col.jl:50
     [4] conv_im2col!(y::SubArray{Float32, 5, CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, …}, x::SubArray{Float32, 5, CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, …}, w::SubArray{Float32, 5, Array{Float32, 5}, …}, cdims::DenseConvDims{3, 3, 3, 6, 3})
       @ NNlib ~/.julia/packages/NNlib/sXmAj/src/impl/conv_im2col.jl:23
     [5] (::NNlib.var"#305#309"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, DenseConvDims{3, 3, 3, 6, 3}, SubArray{Float32, 5, CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, …}, SubArray{Float32, 5, Array{Float32, 5}, …}, SubArray{Float32, 5, CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, …}})()
       @ NNlib ./threadingconstructs.jl:416
Stacktrace:
  [1] sync_end(c::Channel{Any})
    @ Base ./task.jl:445
  [2] macro expansion
    @ ./task.jl:477 [inlined]
  [3] conv!(out::CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, in1::CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, in2::Array{Float32, 5}, cdims::DenseConvDims{3, 3, 3, 6, 3}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NNlib ~/.julia/packages/NNlib/sXmAj/src/conv.jl:205
  [4] conv!
    @ ~/.julia/packages/NNlib/sXmAj/src/conv.jl:185 [inlined]
  [5] conv!(y::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, x::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, w::Array{Float32, 4}, cdims::DenseConvDims{2, 2, 2, 4, 2}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NNlib ~/.julia/packages/NNlib/sXmAj/src/conv.jl:145
  [6] conv!
    @ ~/.julia/packages/NNlib/sXmAj/src/conv.jl:140 [inlined]
  [7] conv(x::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, w::Array{Float32, 4}, cdims::DenseConvDims{2, 2, 2, 4, 2}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NNlib ~/.julia/packages/NNlib/sXmAj/src/conv.jl:88
  [8] conv
    @ ~/.julia/packages/NNlib/sXmAj/src/conv.jl:83 [inlined]
  [9] (::Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool})(x::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer})
    @ Flux ~/.julia/packages/Flux/jgpVj/src/layers/conv.jl:202
 [10] macro expansion
    @ ~/.julia/packages/Flux/jgpVj/src/layers/basic.jl:53 [inlined]
 [11] _applychain(layers::Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(swish), Vector{Float32}, Float32, Vector{Float32}}, Chain{Tuple{Parallel{typeof(+), Tuple{Dropout{Float64, Int64, TaskLocalRNG}, Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(swish), Vector{Float32}, Float32, Vector{Float32}}}}}}, Parallel{typeof(+), Tuple{Dropout{Float64, Int64, TaskLocalRNG}, Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(swish), Vector{Float32}, Float32, Vector{Float32}}}}}}}}, Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(swish), Vector{Float32}, Float32, Vector{Float32}}, Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}, Parallel{typeof(+), Tuple{Dropout{Float64, Int64, TaskLocalRNG}, Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(swish), Vector{Float32}, Float32, Vector{Float32}}, Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}}}, Parallel{typeof(+), Tuple{Dropout{Float64, Int64, TaskLocalRNG}, Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(swish), Vector{Float32}, Float32, Vector{Float32}}, Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}}}, Parallel{typeof(+), Tuple{Dropout{Float64, Int64, TaskLocalRNG}, Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(swish), Vector{Float32}, Float32, Vector{Float32}}, Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}}}}}, Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(swish), Vector{Float32}, Float32, 
Vector{Float32}}, Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}, Parallel{typeof(+), Tuple{Dropout{Float64, Int64, TaskLocalRNG}, Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(swish), Vector{Float32}, Float32, Vector{Float32}}, Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}}}, Parallel{typeof(+), Tuple{Dropout{Float64, Int64, TaskLocalRNG}, Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(swish), Vector{Float32}, Float32, Vector{Float32}}, Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}}}, Parallel{typeof(+), Tuple{Dropout{Float64, Int64, TaskLocalRNG}, Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(swish), Vector{Float32}, Float32, Vector{Float32}}, Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}}}}}, Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(swish), Vector{Float32}, Float32, Vector{Float32}}, Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(swish), Vector{Float32}, Float32, Vector{Float32}}, SkipConnection{Chain{Tuple{AdaptiveMeanPool{4, 2}, Conv{2, 4, typeof(swish), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(hardσ), Array{Float32, 4}, Vector{Float32}}}}, Base.Broadcast.BroadcastFunction{typeof(*)}}, Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}, Vararg{Parallel{typeof(+), Tuple{Dropout{Float64, Int64, TaskLocalRNG}, Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(swish), Vector{Float32}, Float32, Vector{Float32}}, Conv{2, 4, 
typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(swish), Vector{Float32}, Float32, Vector{Float32}}, SkipConnection{Chain{Tuple{AdaptiveMeanPool{4, 2}, Conv{2, 4, typeof(swish), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(hardσ), Array{Float32, 4}, Vector{Float32}}}}, Base.Broadcast.BroadcastFunction{typeof(*)}}, Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}}}, 5}}}, Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(swish), Vector{Float32}, Float32, Vector{Float32}}, Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(swish), Vector{Float32}, Float32, Vector{Float32}}, SkipConnection{Chain{Tuple{AdaptiveMeanPool{4, 2}, Conv{2, 4, typeof(swish), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(hardσ), Array{Float32, 4}, Vector{Float32}}}}, Base.Broadcast.BroadcastFunction{typeof(*)}}, Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}, Vararg{Parallel{typeof(+), Tuple{Dropout{Float64, Int64, TaskLocalRNG}, Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(swish), Vector{Float32}, Float32, Vector{Float32}}, Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(swish), Vector{Float32}, Float32, Vector{Float32}}, SkipConnection{Chain{Tuple{AdaptiveMeanPool{4, 2}, Conv{2, 4, typeof(swish), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(hardσ), Array{Float32, 4}, Vector{Float32}}}}, Base.Broadcast.BroadcastFunction{typeof(*)}}, Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}}}, 8}}}, Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(swish), Vector{Float32}, Float32, Vector{Float32}}, Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, 
BatchNorm{typeof(swish), Vector{Float32}, Float32, Vector{Float32}}, SkipConnection{Chain{Tuple{AdaptiveMeanPool{4, 2}, Conv{2, 4, typeof(swish), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(hardσ), Array{Float32, 4}, Vector{Float32}}}}, Base.Broadcast.BroadcastFunction{typeof(*)}}, Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}, Vararg{Parallel{typeof(+), Tuple{Dropout{Float64, Int64, TaskLocalRNG}, Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(swish), Vector{Float32}, Float32, Vector{Float32}}, Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(swish), Vector{Float32}, Float32, Vector{Float32}}, SkipConnection{Chain{Tuple{AdaptiveMeanPool{4, 2}, Conv{2, 4, typeof(swish), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(hardσ), Array{Float32, 4}, Vector{Float32}}}}, Base.Broadcast.BroadcastFunction{typeof(*)}}, Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}}}, 14}}}, Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(swish), Vector{Float32}, Float32, Vector{Float32}}}, x::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer})
    @ Flux ~/.julia/packages/Flux/jgpVj/src/layers/basic.jl:53
 [12] (::Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(swish), Vector{Float32}, Float32, Vector{Float32}}, Chain{…}, Chain{…}, Chain{…}, Chain{…}, Chain{…}, Chain{…}, Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(swish), Vector{Float32}, Float32, Vector{Float32}}}})(x::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer})
    @ Flux ~/.julia/packages/Flux/jgpVj/src/layers/basic.jl:51
 [13] macro expansion
    @ ~/.julia/packages/Flux/jgpVj/src/layers/basic.jl:53 [inlined]
 [14] _applychain(layers::Tuple{Chain{…}, Chain{Tuple{AdaptiveMeanPool{4, 2}, typeof(MLUtils.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, x::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer})
    @ Flux ~/.julia/packages/Flux/jgpVj/src/layers/basic.jl:53
 [15] (::Chain{Tuple{Chain{…}, Chain{Tuple{AdaptiveMeanPool{4, 2}, typeof(MLUtils.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}})(x::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer})
    @ Flux ~/.julia/packages/Flux/jgpVj/src/layers/basic.jl:51
 [16] (::EfficientNetv2)(x::CuArray{Float32, 4, CUDA.Mem.DeviceBuffer})
    @ Metalhead ~/.julia/packages/Metalhead/NDEgB/src/convnets/efficientnets/efficientnetv2.jl:99
  [052768ef] CUDA v5.1.1
  [587475ba] Flux v0.14.7
  [dbeba491] Metalhead v0.9.0
  [02a925ec] cuDNN v1.2.1
@IanButterworth
Contributor Author

@CarloLucibello could it be that the extensions aren't being loaded properly, even though `gpu` is putting things on the GPU?

@ToucheSir
Member

Make sure you've moved model to gpu as well ;)
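For reference, a minimal sketch of the intended usage (assuming a CUDA-capable setup; the point is that both the parameters and the input must be moved):

```julia
using Metalhead, CUDA, Flux

model = EfficientNetv2(:small; pretrain = false, inchannels = 1, nclasses = 2)
batch = rand(Float32, 224, 224, 1, 1)

# Move the parameters AND the input to the GPU. Moving only the input
# leaves CPU parameters behind, which triggers the scalar-indexing
# error shown above.
gpu_model = gpu(model)
gpu_model(gpu(batch))
```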

@IanButterworth
Contributor Author

Oh ok. Is there a way to catch that and warn the user?
Also, the Metalhead docs don't appear to mention `gpu` (searching gives 3 results, but `gpu` isn't actually found on any of those pages).
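Until something lands upstream, a user-side guard is one option. This is only a sketch: `checked_call` is a hypothetical helper, not part of Flux or Metalhead, and it assumes GPU-resident parameters are plain `CuArray`s.

```julia
using Flux, CUDA

# Hypothetical helper: warn when the model's parameters and the input
# live on different devices, then run the forward pass as usual.
function checked_call(model, x)
    p = first(Flux.params(model))
    if (p isa CuArray) != (x isa CuArray)
        @warn "Model parameters and input are on different devices; " *
              "did you forget `model = gpu(model)`?" typeof(p) typeof(x)
    end
    return model(x)
end
```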

@ToucheSir
Member

ToucheSir commented Dec 6, 2023

Warning the user is probably best handled at the NNlib level, that's what FluxML/NNlib.jl#523 is tracking. We assume most users will have read through a page such as https://fluxml.ai/Flux.jl/stable/gpu/#Basic-GPU-Usage before using Metalhead (or any other Flux-based) models with GPUs. Those pages introduce the gpu function and show how to apply it to both inputs and models.

@IanButterworth
Contributor Author

Fair enough. I came straight to Metalhead, though, and the Flux models I've used before have managed GPU state internally, so it wasn't clear to me.

I'll try to come up with a docs PR

@ToucheSir
Member

ToucheSir commented Dec 8, 2023

I'm not aware of any models which manage GPU state internally for parameters but not also for inputs. Let alone ones which then point people to functions like gpu. Out of curiosity, do you have some examples? If this is happening in active downstream libraries, we could look into submitting some docs PRs.

@IanButterworth
Contributor Author

ObjectDetector.jl is the one I had in mind. Note I'm a maintainer... so I could be the issue!

The challenge there is that part of the post-processing inside the model has been difficult to make runnable on the GPU.
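One common workaround for that situation (a sketch, assuming the post-processing only needs plain `Array`s; `postprocess` here is a hypothetical stand-in for the CPU-only step):

```julia
using Flux, CUDA

# Run the network on the GPU, then move the output back to the CPU so
# that post-processing code which cannot handle CuArrays still works.
out = gpu(model)(gpu(batch))
postprocess(cpu(out))
```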
