Skip to content

Commit

Permalink
updated tabular model
Browse files (browse the repository at this point in the history)
  • Loading branch information
manikyabard committed Jun 17, 2021
1 parent 4852dcf commit fa5c563
Showing 1 changed file with 17 additions and 43 deletions.
60 changes: 17 additions & 43 deletions src/models/tabularmodel.jl
Original file line number Diff line number Diff line change
@@ -6,10 +6,10 @@ function _one_emb_sz(catdict, catcol::Symbol, sz_dict=nothing)
sz_dict = isnothing(sz_dict) ? Dict() : sz_dict
n_cat = length(catdict[catcol])
sz = catcol in keys(sz_dict) ? sz_dict[catcol] : emb_sz_rule(n_cat)
n_cat, sz
Int64(n_cat), Int64(sz)
end

function get_emb_sz(catdict, cols, sz_dict=nothing)
function get_emb_sz(catdict, cols; sz_dict=nothing)
[_one_emb_sz(catdict, catcol, sz_dict) for catcol in cols]
end

@@ -18,21 +18,12 @@ end
# [_one_emb_sz(catdict, catcol, sz_dict) for catcol in cols]
# end

struct TabularModel
embeds
emb_drop
bn_cont
n_emb
n_cont
layers
end

function TabularModel(
layers;
emb_szs,
n_cont,
n_cont::Int64,
out_sz,
ps::Union{Tuple, Vector, Number, Nothing}=nothing,
ps::Union{Tuple, Vector, Number}=0,
embed_p::Float64=0.,
y_range=nothing,
use_bn::Bool=true,
@@ -41,39 +32,22 @@ function TabularModel(
act_cls=Flux.relu,
lin_first::Bool=true)

n_cont = Int64(n_cont)
if isnothing(ps)
ps = zeros(length(layers))
end
if ps isa Number
ps = fill(ps, length(layers))
end
embedslist = [Embedding(ni, nf) for (ni, nf) in emb_szs]
emb_drop = Dropout(embed_p)
bn_cont = bn_cont ? BatchNorm(n_cont) : false
embeds = Chain(x -> ntuple(i -> x[i, :], length(emb_szs)), Parallel(vcat, embedslist...), emb_drop)

bn_cont = bn_cont ? BatchNorm(n_cont) : identity

n_emb = sum(size(embedlayer.weight)[1] for embedlayer in embedslist)
sizes = append!(zeros(0), [n_emb+n_cont], layers, [out_sz])
actns = append!([], [act_cls for i in 1:(length(sizes)-1)], [nothing])
_layers = [linbndrop(Int64(sizes[i]), Int64(sizes[i+1]), use_bn=(use_bn && ((i!=(length(actns)-1)) || bn_final)), p=p, act=a, lin_first=lin_first) for (i, (p, a)) in enumerate(zip(push!(ps, 0.), actns))]
if !isnothing(y_range)
push!(_layers, Chain(@. x->Flux.sigmoid(x) * (y_range[2] - y_range[1]) + y_range[1]))
end
layers = Chain(_layers...)
TabularModel(embedslist, emb_drop, bn_cont, n_emb, n_cont, layers)
end
sizes = append!(zeros(0), [n_emb+n_cont], layers)
actns = append!([], [act_cls for i in 1:(length(sizes)-1)])

function (tm::TabularModel)(x)
x_cat, x_cont = x
if tm.n_emb != 0
x = [e(x_cat[i, :]) for (i, e) in enumerate(tm.embeds)]
x = vcat(x...)
x = tm.emb_drop(x)
_layers = []
for (i, (p, a)) in enumerate(zip(Iterators.cycle(ps), actns))
layer = linbndrop(Int64(sizes[i]), Int64(sizes[i+1]), use_bn=use_bn, p=p, act=a, lin_first=lin_first)
push!(_layers, layer)
end
if tm.n_cont != 0
if (tm.bn_cont != false)
x_cont = tm.bn_cont(x_cont)
end
x = tm.n_emb!=0 ? vcat(x, x_cont) : x_cont
end
tm.layers(x)
push!(_layers, linbndrop(Int64(last(sizes)), Int64(out_sz), use_bn=bn_final, lin_first=lin_first))
layers = isnothing(y_range) ? Chain(Parallel(vcat, embeds, bn_cont), _layers...) : Chain(Parallel(vcat, embeds, bn_cont), _layers..., @. x->Flux.sigmoid(x) * (y_range[2] - y_range[1]) + y_range[1])
layers
end

0 comments on commit fa5c563

Please sign in to comment.