Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tabular model #124

Merged
merged 24 commits into from
Aug 22, 2021
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ Flux = "0.12"
FluxTraining = "0.2"
Glob = "1"
IndirectArrays = "0.5"
LearnBase = "0.3, 0.4"
JLD2 = "0.4"
LearnBase = "0.3, 0.4"
MLDataPattern = "0.5"
Makie = "0.15"
MosaicViews = "0.2, 0.3"
Expand Down
5 changes: 3 additions & 2 deletions src/models/Models.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module Models

using Base: Bool, Symbol
using ..FastAI

using BSON
Expand All @@ -13,9 +14,9 @@ include("blocks.jl")

include("xresnet.jl")
include("unet.jl")
include("tabularmodel.jl")


export xresnet18, xresnet50, UNetDynamic

export xresnet18, xresnet50, UNetDynamic, TabularModel

end
141 changes: 141 additions & 0 deletions src/models/tabularmodel.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
"""
emb_sz_rule(n_cat)

Returns an embedding size corresponding to the number of classes for a
manikyabard marked this conversation as resolved.
Show resolved Hide resolved
categorical variable using the rule of thumb present in python fastai.
(see https://github.com/fastai/fastai/blob/2742fe844573d06e700f869839fb9ec5f3a9bca9/fastai/tabular/model.py#L12)
"""
function emb_sz_rule(n_cat)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a comment with a link to where this is taken from?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure I can add this link. I believe they got this formula experimentally.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add it?

min(600, round(1.6 * n_cat^0.56))
end
manikyabard marked this conversation as resolved.
Show resolved Hide resolved

"""
get_emb_sz(cardinalities, [size_overrides])
manikyabard marked this conversation as resolved.
Show resolved Hide resolved
get_emb_sz(cardinalities; catcols, [size_overrides])

Returns a collection of tuples containing embedding dimensions corresponding to
number of classes in categorical columns present in `cardinalities` and adjusting for nans.
manikyabard marked this conversation as resolved.
Show resolved Hide resolved

## Keyword arguments

- `size_overrides`: Depending on the method used, this could either be a collection of
Integers and `nothing` or an indexable collection with column name as key and size
to override it with as the value. In the first case, the integer present at any index
will be used to override the rule of thumb for getting embedding sizes.
- `categorical_cols`: A collection of categorical column names.
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This docstring is better off split into two separate docstrings for each method. Both will show up automatically in the docs, but you will be able to tailor the explanation of size_overrides to each method (instead of explaining both in one paragraph).


manikyabard marked this conversation as resolved.
Show resolved Hide resolved
function get_emb_sz(cardinalities, size_overrides=fill(nothing, length(cardinalities)))
map(Iterators.enumerate(cardinalities)) do (i, cardinality)
emb_dim = isnothing(size_overrides[i]) ? emb_sz_rule(cardinality+1) : size_overrides[i]
(Int64(cardinality)+1, Int64(emb_dim))
end
end
manikyabard marked this conversation as resolved.
Show resolved Hide resolved

function get_emb_sz(cardinalities; catcols, size_overrides=Dict())
keylist = keys(size_overrides)
overrides = map(catcols) do col
col in keylist ? size_overrides[col] : nothing
end
get_emb_sz(cardinalities, overrides)
end

function sigmoidrange(x, low, high)
@. Flux.sigmoid(x) * (high - low) + low
end
manikyabard marked this conversation as resolved.
Show resolved Hide resolved

function tabular_embedding_backbone(embedding_sizes, dropoutprob=0.)
embedslist = [Flux.Embedding(ni, nf) for (ni, nf) in embedding_sizes]
emb_drop = dropoutprob==0. ? identity : Dropout(dropoutprob)
Chain(
x -> tuple(eachrow(x)...),
Parallel(vcat, embedslist),
emb_drop
)
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
function tabular_embedding_backbone(embedding_sizes, dropoutprob=0.)
embedslist = [Flux.Embedding(ni, nf) for (ni, nf) in embedding_sizes]
emb_drop = dropoutprob==0. ? identity : Dropout(dropoutprob)
Chain(
x -> tuple(eachrow(x)...),
Parallel(vcat, embedslist),
emb_drop
)
end
function tabular_embedding_backbone(embedding_sizes, dropout_rates=0.)
embedslist = [Flux.Embedding(ni, nf) for (ni, nf) in embedding_sizes]
emb_drop = iszero(dropout_rates) ? identity : Dropout(dropout_rates)
Chain(
x -> tuple(eachrow(x)...),
Parallel(vcat, embedslist),
emb_drop
)
end

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

emb_drop could even be further inlined if desired. If preserving model structure is more important than a bit of lost performance, then passing active=iszero(dropout_rates) instead of using a ternary is also a valid option.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is active an argument to something? How would that work?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

active is the third arg to BatchNorm, yes. I wouldn't worry about that now though, it's a minor tweak we can always revisit later.


function tabular_continuous_backbone(n_cont)
BatchNorm(n_cont)
end
manikyabard marked this conversation as resolved.
Show resolved Hide resolved

"""
TabularModel(catbackbone, contbackbone, [finalclassifier]; kwargs...)
TabularModel(`n_cont::Number, outsize::Number[; kwargs...])
manikyabard marked this conversation as resolved.
Show resolved Hide resolved

Create a tabular model which takes in a tuple of categorical values
(label or one-hot encoded) and continuous values. The default categorical backbone is
a Parallel of Embedding layers corresponding to each categorical variable, and continuous
variables are just BatchNormed. The output from these backbones is then passed through
a final classifier block.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is another case where two separate docstrings would help. catbackbone, etc. don't have default values except implicitly in second method.


## Keyword arguments

- `outsize`: The output size of the final classifier block. For single classification tasks,
this would just be the number of classes and for regression tasks, this could be the
number of target continuous variables.
manikyabard marked this conversation as resolved.
Show resolved Hide resolved
- `layersizes`: The sizes of the hidden layers in the classifier block.
manikyabard marked this conversation as resolved.
Show resolved Hide resolved
- `dropout_rates`: Dropout probability. This could either be a single number which would be
used for for all the classifier layers, or a collection of numbers which are cycled through
for each layer.
manikyabard marked this conversation as resolved.
Show resolved Hide resolved
- `batchnorm`: Boolean variable which controls whether to use batch normalization in the classifier.
manikyabard marked this conversation as resolved.
Show resolved Hide resolved
- `activation`: The activation function to use in the classifier layers.
- `linear_first`: Controls if the linear layer comes before or after BatchNorm and Dropout.
manikyabard marked this conversation as resolved.
Show resolved Hide resolved
- `cardinalities`: A collection of sizes (number of classes) for each categorical column.
- `size_overrides`: An optional argument which corresponds to a collection containing
embedding sizes to override the value returned by the "rule of thumb" for a particular index
corresponding to `cardinalities`, or `nothing`.
"""

manikyabard marked this conversation as resolved.
Show resolved Hide resolved
function TabularModel(
catbackbone,
contbackbone;
outsize,
layersizes=[200, 100],
manikyabard marked this conversation as resolved.
Show resolved Hide resolved
kwargs...)
TabularModel(catbackbone, contbackbone, Dense(layersizes[end], outsize); layersizes=layersizes, kwargs...)
end

function TabularModel(
catbackbone,
contbackbone,
finalclassifier;
layersizes=[200, 100],
manikyabard marked this conversation as resolved.
Show resolved Hide resolved
dropout_rates=0.,
batchnorm=true,
activation=Flux.relu,
linear_first=true)

tabularbackbone = Parallel(vcat, catbackbone, contbackbone)

classifierin = mapreduce(layer -> size(layer.weight)[1], +, catbackbone[2].layers;
init = contbackbone.chs)
dropout_rates = Iterators.cycle(dropout_rates)
classifiers = []

first_ps, dropout_rates = Iterators.peel(dropout_rates)
push!(classifiers, linbndrop(classifierin, first(layersizes); use_bn=batchnorm, p=first_ps, lin_first=linear_first, act=activation))
manikyabard marked this conversation as resolved.
Show resolved Hide resolved

for (isize, osize, p) in zip(layersizes[1:(end-1)], layersizes[2:end], dropout_rates)
layer = linbndrop(isize, osize; use_bn=batchnorm, p=p, act=activation, lin_first=linear_first)
push!(classifiers, layer)
end

Chain(
tabularbackbone,
classifiers...,
finalclassifier
)
end

function TabularModel(
n_cont::Number,
outsize::Number,
layersizes=[200, 100];
manikyabard marked this conversation as resolved.
Show resolved Hide resolved
cardinalities,
size_overrides=fill(nothing, length(cardinalities)))
embedszs = get_emb_sz(cardinalities, size_overrides)
catback = tabular_embedding_backbone(embedszs)
contback = tabular_continuous_backbone(n_cont)

TabularModel(catback, contback; layersizes=layersizes, outsize=outsize)
end
1 change: 1 addition & 0 deletions test/imports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using FastAI: Image, Keypoints, Mask, testencoding, Label, OneHot, ProjectiveTra
encodedblock, decodedblock, encode, decode, mockblock
using FilePathsBase
using FastAI.Datasets
using FastAI.Models
using DLPipelines
import DataAugmentation
import DataAugmentation: getbounds
Expand Down
41 changes: 41 additions & 0 deletions test/models/tabularmodel.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
include("../imports.jl")

@testset ExtendedTestSet "TabularModel Components" begin
@testset ExtendedTestSet "embeddingbackbone" begin
embed_szs = [(5, 10), (100, 30), (2, 30)]
embeds = FastAI.Models.tabular_embedding_backbone(embed_szs, 0.)
x = [rand(1:n) for (n, _) in embed_szs]

@test size(embeds(x)) == (70, 1)
end

@testset ExtendedTestSet "continuousbackbone" begin
n = 5
contback = FastAI.Models.tabular_continuous_backbone(n)
x = rand(5, 1)
@test size(contback(x)) == (5, 1)
end

@testset ExtendedTestSet "TabularModel" begin
n = 5
embed_szs = [(5, 10), (100, 30), (2, 30)]

embeds = FastAI.Models.tabular_embedding_backbone(embed_szs, 0.)
contback = FastAI.Models.tabular_continuous_backbone(n)

x = ([rand(1:n) for (n, _) in embed_szs], rand(5, 1))

tm = TabularModel(embeds, contback; outsize=4)
@test size(tm(x)) == (4, 1)

tm2 = TabularModel(embeds, contback, Chain(Dense(100, 4), x->FastAI.Models.sigmoidrange(x, 2, 5)))
y2 = tm2(x)
@test all(y2.> 2) && all(y2.<5)

cardinalities = (4, 99, 1)
tm3 = TabularModel(n, 4, [200, 100], cardinalities = cardinalities, size_overrides = (10, 30, 30))
@test size(tm3(x)) == (4, 1)
end
end


6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,10 @@ include("imports.jl")
end
# TODO: test learning rate finder
end

@testset ExtendedTestSet "models/" begin
@testset ExtendedTestSet "tabularmodel.jl" begin
include("models/tabularmodel.jl")
end
end
end