From b306865d06fae860be00ae9f623293c558893484 Mon Sep 17 00:00:00 2001 From: Manikya Bardhan Date: Sun, 22 Aug 2021 22:40:45 +0530 Subject: [PATCH] Add tabular model (#124) * added tabular model * fixed batchnorm in tabular model * added function for calculating embedding dimensions * updated tabular model * Apply suggestions from code review Co-authored-by: Kyle Daruwalla * simplified tabular model * updated tabular model * refactored TabularModel * updated tabular model, and added tests * added classifierbackbone * update tablemodel tests * export classifierbackbone * refactored TabularModel methods * updated tabular model tests * add TabularModel docstring * renamed args Co-authored-by: Kyle Daruwalla * updated docstrings and embed dims calculation, made args usage consistent * docstring fixes * made methods concise Co-authored-by: Kyle Daruwalla * updated docstrings and get_emb_sz * updated model test * undo unintentional comments * Docstring updates Co-authored-by: Kyle Daruwalla * minor docstring fix Co-authored-by: Kyle Daruwalla --- src/models/Models.jl | 5 +- src/models/tabularmodel.jl | 157 ++++++++++++++++++++++++++++++++++++ test/imports.jl | 1 + test/models/tabularmodel.jl | 41 ++++++++++ test/runtests.jl | 6 ++ 5 files changed, 208 insertions(+), 2 deletions(-) create mode 100644 src/models/tabularmodel.jl create mode 100644 test/models/tabularmodel.jl diff --git a/src/models/Models.jl b/src/models/Models.jl index 65c70ae4d5..9401596839 100644 --- a/src/models/Models.jl +++ b/src/models/Models.jl @@ -1,5 +1,6 @@ module Models +using Base: Bool, Symbol using ..FastAI using BSON @@ -13,9 +14,9 @@ include("blocks.jl") include("xresnet.jl") include("unet.jl") +include("tabularmodel.jl") -export xresnet18, xresnet50, UNetDynamic - +export xresnet18, xresnet50, UNetDynamic, TabularModel end diff --git a/src/models/tabularmodel.jl b/src/models/tabularmodel.jl new file mode 100644 index 0000000000..626293e6f5 --- /dev/null +++ b/src/models/tabularmodel.jl @@ -0,0 +1,157 @@ +""" + emb_sz_rule(n_cat) + +Compute an embedding size corresponding to the number of classes for a +categorical variable using the rule of thumb present in python fastai. +(see https://github.com/fastai/fastai/blob/2742fe844573d06e700f869839fb9ec5f3a9bca9/fastai/tabular/model.py#L12) +""" +emb_sz_rule(n_cat) = min(600, round(Int, 1.6 * n_cat^0.56)) + +""" + get_emb_sz(cardinalities::AbstractVector, [size_overrides::AbstractVector]) + +Given a vector of `cardinalities` of each categorical column +(i.e. each element of `cardinalities` is the number of classes in that categorical column), +compute the output embedding size according to [`emb_sz_rule`](#). +Return a vector of tuples where each element is `(in_size, out_size)` for an embedding layer. + +## Keyword arguments + +- `size_overrides`: A collection of integers (or `nothing` to skip override) where the value present at any index + will be used to as the output embedding size for that column. +""" +get_emb_sz(cardinalities::AbstractVector{<:Integer}, size_overrides=fill(nothing, length(cardinalities))) = + map(zip(cardinalities, size_overrides)) do (cardinality, override) + emb_dim = isnothing(override) ? emb_sz_rule(cardinality + 1) : Int64(override) + return (cardinality + 1, emb_dim) + end + +""" + get_emb_sz(cardinalities::Dict, [size_overrides::Dict]) + +Given a map from columns to `cardinalities`, compute the output embedding size according to [`emb_sz_rule`](#). +Return a vector of tuples where each element is `(in_size, out_size)` for an embedding layer. + +## Keyword arguments + +- `size_overrides`: A map of output embedding size overrides + (i.e. `size_overrides[col]` is the output embedding size for `col`). +""" +function get_emb_sz(cardinalities::Dict{<:Any, <:Integer}, size_overrides=Dict()) + values_and_overrides = map(pairs(cardinalities)) do (col, cardinality) + cardinality, get(size_overrides, col, nothing) + end + get_emb_sz(first.(values_and_overrides), last.(values_and_overrides)) +end + +sigmoidrange(x, low, high) = @. Flux.sigmoid(x) * (high - low) + low + +function tabular_embedding_backbone(embedding_sizes, dropout_rate=0.) + embedslist = [Flux.Embedding(ni, nf) for (ni, nf) in embedding_sizes] + emb_drop = iszero(dropout_rate) ? identity : Dropout(dropout_rate) + Chain( + x -> tuple(eachrow(x)...), + Parallel(vcat, embedslist), + emb_drop + ) +end + +tabular_continuous_backbone(n_cont) = BatchNorm(n_cont) + +""" + TabularModel(catbackbone, contbackbone, [finalclassifier]; kwargs...) + +Create a tabular model which operates on a tuple of categorical values +(label or one-hot encoded) and continuous values. +The categorical backbones (`catbackbone`) and continuous backbone (`contbackbone`) operate on each element of the input tuple. +The output from these backbones is then passed through a series of linear-batch norm-dropout layers before a `finalclassifier` block. + +## Keyword arguments + +- `outsize`: The output size of the final classifier block. For single classification tasks, + this would be the number of classes, and for regression tasks, this would be the + number of target continuous variables. +- `layersizes`: A vector of sizes for each hidden layer in the sequence of linear layers. +- `dropout_rates`: Dropout probabilities for the linear-batch norm-dropout layers. + This could either be a single number which would be used for for all the layers, + or a collection of numbers which are cycled through for each layer. +- `batchnorm`: Set to `false` to skip each batch norm in the linear-batch norm-dropout sequence. +- `activation`: The activation function to use in the classifier layers. +- `linear_first`: Controls if the linear layer comes before or after batch norm and dropout. +""" +function TabularModel( + catbackbone, + contbackbone; + outsize, + layersizes=(200, 100), + kwargs...) + TabularModel(catbackbone, contbackbone, Dense(layersizes[end], outsize); layersizes=layersizes, kwargs...) +end + +function TabularModel( + catbackbone, + contbackbone, + finalclassifier; + layersizes=(200, 100), + dropout_rates=0., + batchnorm=true, + activation=Flux.relu, + linear_first=true) + + tabularbackbone = Parallel(vcat, catbackbone, contbackbone) + + classifierin = mapreduce(layer -> size(layer.weight)[1], +, catbackbone[2].layers; + init = contbackbone.chs) + dropout_rates = Iterators.cycle(dropout_rates) + classifiers = [] + + first_ps, dropout_rates = Iterators.peel(dropout_rates) + push!(classifiers, linbndrop(classifierin, first(layersizes); + use_bn=batchnorm, p=first_ps, lin_first=linear_first, act=activation)) + + for (isize, osize, p) in zip(layersizes[1:(end-1)], layersizes[2:end], dropout_rates) + layer = linbndrop(isize, osize; use_bn=batchnorm, p=p, act=activation, lin_first=linear_first) + push!(classifiers, layer) + end + + Chain( + tabularbackbone, + classifiers..., + finalclassifier + ) +end + +""" + TabularModel(n_cont, outsize, [layersizes; kwargs...]) + +Create a tabular model which operates on a tuple of categorical values +(label or one-hot encoded) and continuous values. The default categorical backbone (`catbackbone`) is +a [`Flux.Parallel`](https://fluxml.ai/Flux.jl/stable/models/layers/#Flux.Parallel) set of `Flux.Embedding` layers corresponding to each categorical variable. +The default continuous backbone (`contbackbone`) is a single [`Flux.BatchNorm`](https://fluxml.ai/Flux.jl/stable/models/layers/#Flux.BatchNorm). +The output from these backbones is concatenated then passed through a series of linear-batch norm-dropout layers before a `finalclassifier` block. + +## Arguments + +- `n_cont`: The number of continuous columns. +- `outsize`: The output size of the model. +- `layersizes`: A vector of sizes for each hidden layer in the sequence of linear layers. + +## Keyword arguments + +- `cardinalities`: A collection of sizes (number of classes) for each categorical column. +- `size_overrides`: An optional argument which corresponds to a collection containing + embedding sizes to override the value returned by the "rule of thumb" for a particular index + corresponding to `cardinalities`, or `nothing`. +""" +function TabularModel( + n_cont::Number, + outsize::Number, + layersizes=(200, 100); + cardinalities, + size_overrides=fill(nothing, length(cardinalities))) + embedszs = get_emb_sz(cardinalities, size_overrides) + catback = tabular_embedding_backbone(embedszs) + contback = tabular_continuous_backbone(n_cont) + + TabularModel(catback, contback; layersizes=layersizes, outsize=outsize) +end diff --git a/test/imports.jl b/test/imports.jl index 9923365b70..f5f5721602 100644 --- a/test/imports.jl +++ b/test/imports.jl @@ -6,6 +6,7 @@ import FastAI: Image, Keypoints, Mask, testencoding, Label, OneHot, ProjectiveTr encodedblock, decodedblock, encode, decode, mockblock, checkblock, Block, Encoding using FilePathsBase using FastAI.Datasets +using FastAI.Models using DLPipelines import DataAugmentation import DataAugmentation: getbounds diff --git a/test/models/tabularmodel.jl b/test/models/tabularmodel.jl new file mode 100644 index 0000000000..b2aad087ee --- /dev/null +++ b/test/models/tabularmodel.jl @@ -0,0 +1,41 @@ +include("../imports.jl") + +@testset ExtendedTestSet "TabularModel Components" begin + @testset ExtendedTestSet "embeddingbackbone" begin + embed_szs = [(5, 10), (100, 30), (2, 30)] + embeds = FastAI.Models.tabular_embedding_backbone(embed_szs, 0.) + x = [rand(1:n) for (n, _) in embed_szs] + + @test size(embeds(x)) == (70, 1) + end + + @testset ExtendedTestSet "continuousbackbone" begin + n = 5 + contback = FastAI.Models.tabular_continuous_backbone(n) + x = rand(5, 1) + @test size(contback(x)) == (5, 1) + end + + @testset ExtendedTestSet "TabularModel" begin + n = 5 + embed_szs = [(5, 10), (100, 30), (2, 30)] + + embeds = FastAI.Models.tabular_embedding_backbone(embed_szs, 0.) + contback = FastAI.Models.tabular_continuous_backbone(n) + + x = ([rand(1:n) for (n, _) in embed_szs], rand(5, 1)) + + tm = TabularModel(embeds, contback; outsize=4) + @test size(tm(x)) == (4, 1) + + tm2 = TabularModel(embeds, contback, Chain(Dense(100, 4), x->FastAI.Models.sigmoidrange(x, 2, 5))) + y2 = tm2(x) + @test all(y2.> 2) && all(y2.<5) + + cardinalities = [4, 99, 1] + tm3 = TabularModel(n, 4, [200, 100], cardinalities = cardinalities, size_overrides = (10, 30, 30)) + @test size(tm3(x)) == (4, 1) + end +end + + diff --git a/test/runtests.jl b/test/runtests.jl index 1333fc1b92..6bbfc87f95 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -72,4 +72,10 @@ include("imports.jl") end # TODO: test learning rate finder end + + @testset ExtendedTestSet "models/" begin + @testset ExtendedTestSet "tabularmodel.jl" begin + include("models/tabularmodel.jl") + end + end end