-
-
Notifications
You must be signed in to change notification settings - Fork 51
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add tabular model #124
Add tabular model #124
Changes from 12 commits
2c85ed4
97546c7
c1bd73b
2551fbb
ef11450
c0b2922
c2c95d5
a081616
f565675
04d27d4
e1c2263
b4d7149
bc250a1
506f889
725c6dd
59eb66a
d4fded0
96564f0
ddb4d62
825d146
979c9ba
8f8c65a
f2042a4
928509b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
function emb_sz_rule(n_cat) | ||
min(600, round(1.6 * n_cat^0.56)) | ||
end | ||
manikyabard marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
function _one_emb_sz(catdict, catcol::Symbol, sz_dict=nothing) | ||
sz_dict = isnothing(sz_dict) ? Dict() : sz_dict | ||
n_cat = length(catdict[catcol]) | ||
sz = catcol in keys(sz_dict) ? sz_dict[catcol] : emb_sz_rule(n_cat) | ||
Int64(n_cat)+1, Int64(sz) | ||
end | ||
|
||
manikyabard marked this conversation as resolved.
Show resolved
Hide resolved
|
||
function get_emb_sz(catdict, cols; sz_dict=nothing) | ||
[_one_emb_sz(catdict, catcol, sz_dict) for catcol in cols] | ||
end | ||
|
||
function sigmoidrange(x, low, high) | ||
@. Flux.sigmoid(x) * (high - low) + low | ||
end | ||
manikyabard marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
function embeddingbackbone(embedding_sizes, dropoutprob=0.) | ||
embedslist = [Flux.Embedding(ni, nf) for (ni, nf) in embedding_sizes] | ||
emb_drop = dropoutprob==0. ? identity : Dropout(dropoutprob) | ||
Chain( | ||
x -> tuple(eachrow(x)...), | ||
Parallel(vcat, embedslist), | ||
emb_drop | ||
) | ||
end | ||
|
||
function continuousbackbone(n_cont) | ||
n_cont > 0 ? BatchNorm(n_cont) : identity | ||
end | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't know how useful it is to have this function. |
||
|
||
function classifierbackbone( | ||
layers; | ||
ps=0, | ||
use_bn=true, | ||
manikyabard marked this conversation as resolved.
Show resolved
Hide resolved
|
||
bn_final=false, | ||
act_cls=Flux.relu, | ||
manikyabard marked this conversation as resolved.
Show resolved
Hide resolved
|
||
lin_first=true) | ||
manikyabard marked this conversation as resolved.
Show resolved
Hide resolved
|
||
ps = Iterators.cycle(ps) | ||
classifiers = [] | ||
|
||
manikyabard marked this conversation as resolved.
Show resolved
Hide resolved
|
||
for (isize, osize, p) in zip(layers[1:(end-1)], layers[2:end], ps) | ||
layer = linbndrop(isize, osize; use_bn=use_bn, p=p, act=act_cls, lin_first=lin_first) | ||
push!(classifiers, layer) | ||
end | ||
Chain(classifiers...) | ||
end | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wasn't thinking that the linear chain should be customizable. I think this loop should get pushed into |
||
|
||
function TabularModel( | ||
catbackbone, | ||
contbackbone, | ||
classifierbackbone; | ||
final_activation=identity) | ||
tabularbackbone = Parallel(vcat, catbackbone, contbackbone) | ||
Chain( | ||
tabularbackbone, | ||
classifierbackbone, | ||
final_activation | ||
) | ||
end | ||
|
||
function TabularModel( | ||
catcols, | ||
n_cont::Number, | ||
out_sz::Number, | ||
layers=[200, 100]; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use a tuple |
||
catdict, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have some questions about understanding |
||
sz_dict=nothing, | ||
ps=0.) | ||
embedszs = get_emb_sz(catdict, catcols, sz_dict=sz_dict) | ||
catback = embeddingbackbone(embedszs) | ||
contback = continuousbackbone(n_cont) | ||
|
||
classifierin = mapreduce(layer -> size(layer.weight)[1], +, catback[2].layers, init = n_cont) | ||
layers = append!([classifierin], layers, [out_sz]) | ||
classback = classifierbackbone(layers, ps=ps) | ||
|
||
TabularModel(catback, contback, classback) | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
include("../imports.jl") | ||
|
||
@testset ExtendedTestSet "TabularModel Components" begin | ||
@testset ExtendedTestSet "embeddingbackbone" begin | ||
embed_szs = [(5, 10), (100, 30), (2, 30)] | ||
embeds = embeddingbackbone(embed_szs, 0.) | ||
x = [rand(1:n) for (n, _) in embed_szs] | ||
|
||
@test size(embeds(x)) == (70, 1) | ||
end | ||
|
||
@testset ExtendedTestSet "continuousbackbone" begin | ||
n = 5 | ||
contback = continuousbackbone(n) | ||
x = rand(5, 1) | ||
@test size(contback(x)) == (5, 1) | ||
end | ||
|
||
@testset ExtendedTestSet "classifierbackbone" begin | ||
classback = classifierbackbone([10, 200, 100, 2]) | ||
x = rand(10, 2) | ||
@test size(classback(x)) == (2, 2) | ||
end | ||
|
||
@testset ExtendedTestSet "TabularModel" begin | ||
n = 5 | ||
embed_szs = [(5, 10), (100, 30), (2, 30)] | ||
|
||
embeds = embeddingbackbone(embed_szs, 0.) | ||
contback = continuousbackbone(n) | ||
classback = classifierbackbone([75, 200, 100, 4]) | ||
|
||
tm = TabularModel(embeds, contback, classback, final_activation = x->FastAI.sigmoidrange(x, 2, 5)) | ||
|
||
x = ([rand(1:n) for (n, _) in embed_szs], rand(5, 1)) | ||
y1 = tm(x) | ||
@test size(y1) == (4, 1) | ||
@test all(y1.> 2) && all(y1.<5) | ||
|
||
catcols = [:a, :b, :c] | ||
catdict = Dict(:a => rand(4), :b => rand(99), :c => rand(1)) | ||
tm2 = TabularModel(catcols, n, 4, [200, 100], catdict = catdict, sz_dict = Dict(:a=>10, :b=>30, :c=>30)) | ||
@test size(tm2(x)) == (4, 1) | ||
end | ||
end | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add a comment with a link to where this is taken from?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure I can add this link. I believe they got this formula experimentally.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add it?