-
-
Notifications
You must be signed in to change notification settings - Fork 51
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* added tabular model * fixed batchnorm in tabular model * added function for calculating embedding dimensions * updated tabular model * Apply suggestions from code review Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com> * simplified tabular model * updated tabular model * refactored TabularModel * updated tabular model, and added tests * added classifierbackbone * update tablemodel tests * export classifierbackbone * refactored TabularModel methods * updated tabular model tests * add TabularModel docstring * renamed args Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com> * updated docstrings and embed dims calculation, made args usage consistent * docstring fixes * made methods concise Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com> * updated docstrings and get_emb_sz * updated model test * undo unintentional comments * Docstring updates Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com> * minor docstring fix Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
- Loading branch information
1 parent
9fb14af
commit b306865
Showing
5 changed files
with
208 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
"""
    emb_sz_rule(n_cat)

Return the embedding size suggested for a categorical variable with `n_cat`
classes, using the rule of thumb present in python fastai:
`1.6 * n_cat^0.56` rounded to an integer and capped at 600.
(see https://github.com/fastai/fastai/blob/2742fe844573d06e700f869839fb9ec5f3a9bca9/fastai/tabular/model.py#L12)
"""
function emb_sz_rule(n_cat)
    suggested = round(Int, 1.6 * n_cat^0.56)
    return min(600, suggested)
end
|
||
"""
    get_emb_sz(cardinalities::AbstractVector, [size_overrides::AbstractVector])

Return a vector of `(in_size, out_size)` tuples, one per categorical column,
describing the embedding layer for that column.

Each element of `cardinalities` is the number of classes in the corresponding
categorical column. The input size of an embedding is `cardinality + 1`, and the
output size is `emb_sz_rule(cardinality + 1)` (see [`emb_sz_rule`](#)) unless an
override is given for that column.

## Arguments

- `cardinalities`: Number of classes of each categorical column.
- `size_overrides`: Optional positional collection of integers (or `nothing` to skip
  the override), paired element-by-element with `cardinalities`. A non-`nothing`
  value at an index is used as the output embedding size for that column.
  Defaults to no overrides.
"""
function get_emb_sz(
        cardinalities::AbstractVector{<:Integer},
        size_overrides=fill(nothing, length(cardinalities)))
    return [(n_classes + 1, isnothing(forced) ? emb_sz_rule(n_classes + 1) : Int64(forced))
            for (n_classes, forced) in zip(cardinalities, size_overrides)]
end
|
||
"""
    get_emb_sz(cardinalities::Dict, [size_overrides::Dict])

Given a map from columns to `cardinalities`, compute the output embedding size
according to [`emb_sz_rule`](#).
Return a vector of tuples where each element is `(in_size, out_size)` for an
embedding layer. The elements follow the iteration order of `keys(cardinalities)`.

## Arguments

- `cardinalities`: A map from each categorical column to its number of classes.
- `size_overrides`: An optional positional map of output embedding size overrides
  (i.e. `size_overrides[col]` is the output embedding size for `col`).
  Columns that are absent from the map use the rule of thumb.
"""
function get_emb_sz(cardinalities::Dict{<:Any, <:Integer}, size_overrides=Dict())
    # `pairs(::Dict)` is the dictionary itself and `map` is not defined on
    # dictionaries (it throws), so walk the keys explicitly instead.
    cols = collect(keys(cardinalities))
    # Typed comprehension keeps an `<:Integer` eltype even for an empty dictionary,
    # which the vector method requires for dispatch.
    sizes = valtype(cardinalities)[cardinalities[col] for col in cols]
    overrides = [get(size_overrides, col, nothing) for col in cols]
    return get_emb_sz(sizes, overrides)
end
|
||
"""
    sigmoidrange(x, low, high)

Apply `Flux.sigmoid` elementwise to `x`, then scale the result by `high - low`
and shift it by `low`.
"""
function sigmoidrange(x, low, high)
    return Flux.sigmoid.(x) .* (high .- low) .+ low
end
|
||
"""
    tabular_embedding_backbone(embedding_sizes, dropout_rate=0.)

Build the categorical backbone as a `Chain` that splits its input into rows,
feeds each row to its own `Flux.Embedding` (one per `(in_size, out_size)` entry of
`embedding_sizes`), concatenates the results with `vcat`, and finally applies
`Dropout(dropout_rate)` (or `identity` when `dropout_rate` is zero).
"""
function tabular_embedding_backbone(embedding_sizes, dropout_rate=0.)
    embeddings = [Flux.Embedding(insize, outsize) for (insize, outsize) in embedding_sizes]

    # One tuple element per input row, so `Parallel` can hand each row to its embedding.
    splitrows = x -> tuple(eachrow(x)...)

    dropout = if iszero(dropout_rate)
        identity
    else
        Dropout(dropout_rate)
    end

    # The embeddings are passed as a single collection (not splatted); it ends up
    # as the `layers` field of the `Parallel`.
    return Chain(splitrows, Parallel(vcat, embeddings), dropout)
end
|
||
"""
    tabular_continuous_backbone(n_cont)

Return the continuous backbone: a `BatchNorm` layer constructed with `n_cont`.
"""
function tabular_continuous_backbone(n_cont)
    return BatchNorm(n_cont)
end
|
||
"""
    TabularModel(catbackbone, contbackbone, [finalclassifier]; kwargs...)

Create a tabular model which operates on a tuple of categorical values
(label or one-hot encoded) and continuous values.
The categorical backbone (`catbackbone`) and continuous backbone (`contbackbone`)
each operate on one element of the input tuple.
Their outputs are then passed through a series of linear-batch norm-dropout
layers before a `finalclassifier` block.

When `finalclassifier` is omitted, a `Dense(layersizes[end], outsize)` layer is used.

## Keyword arguments

- `outsize`: The output size of the final classifier block. For single classification tasks,
  this would be the number of classes, and for regression tasks, this would be the
  number of target continuous variables. Required when `finalclassifier` is omitted.
- `layersizes`: A vector of sizes for each hidden layer in the sequence of linear layers.
- `dropout_rates`: Dropout probabilities for the linear-batch norm-dropout layers.
  This could either be a single number which would be used for all the layers,
  or a collection of numbers which are cycled through for each layer.
- `batchnorm`: Set to `false` to skip each batch norm in the linear-batch norm-dropout sequence.
- `activation`: The activation function to use in the classifier layers.
- `linear_first`: Controls if the linear layer comes before or after batch norm and dropout.
"""
function TabularModel(
        catbackbone,
        contbackbone;
        outsize,
        layersizes=(200, 100),
        kwargs...)
    # Default head: map the last hidden layer onto the requested output size.
    head = Dense(layersizes[end], outsize)
    return TabularModel(catbackbone, contbackbone, head; layersizes=layersizes, kwargs...)
end
|
||
# Assemble the full model from explicit backbones and a final classifier.
#
# Requirements read off the body: `catbackbone[2]` must expose `layers`, each with a
# `weight` whose first dimension is that layer's output size, and `contbackbone`
# must expose `chs`. The sum of those sizes is the input size of the first
# linear-batch norm-dropout block.
function TabularModel(
        catbackbone,
        contbackbone,
        finalclassifier;
        layersizes=(200, 100),
        dropout_rates=0.,
        batchnorm=true,
        activation=Flux.relu,
        linear_first=true)

    # Width of the concatenated backbone output.
    classifierin = mapreduce(+, catbackbone[2].layers; init = contbackbone.chs) do layer
        size(layer.weight)[1]
    end

    # Every hidden block shares the same options; only sizes and dropout differ.
    makeblock(nin, nout, p) =
        linbndrop(nin, nout; use_bn=batchnorm, p=p, lin_first=linear_first, act=activation)

    # A single number or a short collection is cycled so every block gets a rate;
    # the first rate goes to the first block, the rest are consumed in order.
    head_rate, tail_rates = Iterators.peel(Iterators.cycle(dropout_rates))

    blocks = Any[makeblock(classifierin, first(layersizes), head_rate)]
    for (nin, nout, p) in zip(layersizes[1:(end-1)], layersizes[2:end], tail_rates)
        push!(blocks, makeblock(nin, nout, p))
    end

    # Both backbones run side by side and their outputs are concatenated.
    backbone = Parallel(vcat, catbackbone, contbackbone)

    return Chain(backbone, blocks..., finalclassifier)
end
|
||
"""
    TabularModel(n_cont, outsize, [layersizes]; cardinalities, [size_overrides])

Create a tabular model which operates on a tuple of categorical values
(label or one-hot encoded) and continuous values, using default backbones.

The categorical backbone is built by `tabular_embedding_backbone` from the
embedding sizes returned by `get_emb_sz(cardinalities, size_overrides)`: a
[`Flux.Parallel`](https://fluxml.ai/Flux.jl/stable/models/layers/#Flux.Parallel)
set of `Flux.Embedding` layers, one per categorical variable.
The continuous backbone is a single
[`Flux.BatchNorm`](https://fluxml.ai/Flux.jl/stable/models/layers/#Flux.BatchNorm).
The outputs of both backbones are concatenated and passed through a series of
linear-batch norm-dropout layers before a final classifier block.

## Arguments

- `n_cont`: The number of continuous columns.
- `outsize`: The output size of the model.
- `layersizes`: A vector of sizes for each hidden layer in the sequence of linear layers.

## Keyword arguments

- `cardinalities`: A collection of sizes (number of classes) for each categorical column.
- `size_overrides`: An optional collection of embedding sizes that override the value
  returned by the "rule of thumb" at the matching index of `cardinalities`;
  use `nothing` at an index to keep the rule of thumb.
"""
function TabularModel(
        n_cont::Number,
        outsize::Number,
        layersizes=(200, 100);
        cardinalities,
        size_overrides=fill(nothing, length(cardinalities)))
    catbackbone = tabular_embedding_backbone(get_emb_sz(cardinalities, size_overrides))
    contbackbone = tabular_continuous_backbone(n_cont)

    return TabularModel(catbackbone, contbackbone; layersizes=layersizes, outsize=outsize)
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
include("../imports.jl") | ||
|
||
@testset ExtendedTestSet "TabularModel Components" begin
    @testset ExtendedTestSet "embeddingbackbone" begin
        # (in_size, out_size) per embedding; the outputs add up to 10 + 30 + 30 = 70.
        embed_szs = [(5, 10), (100, 30), (2, 30)]
        catback = FastAI.Models.tabular_embedding_backbone(embed_szs, 0.)
        # One valid class index per categorical column.
        catinput = [rand(1:ncls) for (ncls, _) in embed_szs]

        @test size(catback(catinput)) == (70, 1)
    end

    @testset ExtendedTestSet "continuousbackbone" begin
        n = 5
        contback = FastAI.Models.tabular_continuous_backbone(n)
        continput = rand(5, 1)
        # The continuous backbone must keep the input shape.
        @test size(contback(continput)) == (5, 1)
    end

    @testset ExtendedTestSet "TabularModel" begin
        n = 5
        embed_szs = [(5, 10), (100, 30), (2, 30)]

        catback = FastAI.Models.tabular_embedding_backbone(embed_szs, 0.)
        contback = FastAI.Models.tabular_continuous_backbone(n)

        # Model input: (categorical indices, continuous values).
        input = ([rand(1:ncls) for (ncls, _) in embed_szs], rand(5, 1))

        # Default final classifier sized by `outsize`.
        tm = TabularModel(catback, contback; outsize=4)
        @test size(tm(input)) == (4, 1)

        # Custom final classifier whose output is squashed into the range 2..5.
        ranged_head = Chain(Dense(100, 4), y -> FastAI.Models.sigmoidrange(y, 2, 5))
        tm2 = TabularModel(catback, contback, ranged_head)
        y2 = tm2(input)
        @test all(y2.> 2) && all(y2.<5)

        # Convenience constructor: cardinalities + 1 reproduce the in_sizes above.
        cardinalities = [4, 99, 1]
        tm3 = TabularModel(n, 4, [200, 100], cardinalities = cardinalities, size_overrides = (10, 30, 30))
        @test size(tm3(input)) == (4, 1)
    end
end
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters