Remove redundant imputation flag columns [Resolves #544]
The content of the imputation flag columns across all functions for a
given timespan will be the same. This commit removes the redundant
columns, and names the imputation flag column without any function name
(e.g. 'events_entity_id_1y_outcome_imp' instead of
'events_entity_id_1y_outcome_avg_imp')

- Change the Imputer class interface:
    - Add column_imputation_base to constructor
    - Change imputation_flag_sql to imputation_flag_select_and_alias so
    the caller can keep track of the aliases without parsing SQL (see the
    sketch below)
- Change the Aggregation/SpacetimeAggregation to:
    - Create a reverse column name -> Aggregate lookup (with some
    refactoring so it can build this without duplicating a bunch of
    existing logic)
    - When creating the imputation SQL, query the lookup to create the
    column_imputation_base
- Modify experiment algorithm doc to describe imputation flag behavior
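A minimal sketch of the reworked flag interface, mirroring the tests changed in this commit; the import path is an assumption, and the new base-column constructor argument is omitted since only its effect on the alias matters here:

```python
from triage.component.collate.imputations import BaseImputation  # assumed import path

# The imputer now reports its flag as a (select_expression, alias) pair rather
# than a pre-rendered SQL string, so callers can deduplicate flags by alias.
imp = BaseImputation(column="a", coltype="aggregate")
select_expr, alias = imp.imputed_flag_select_and_alias()
# select_expr == 'CASE WHEN "a" IS NULL THEN 1::SMALLINT ELSE 0::SMALLINT END'
# alias == 'a_imp'

# Categorical columns are handled with a separate NULL category, so no flag is emitted.
assert BaseImputation(column="a", coltype="categorical").imputed_flag_select_and_alias() == (None, None)
```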
thcrock committed Apr 19, 2019
1 parent ec78a2a commit c9c5182
Showing 8 changed files with 160 additions and 70 deletions.
3 changes: 2 additions & 1 deletion docs/sources/experiments/algorithm.md
@@ -126,7 +126,8 @@ in the aggregation, pre-imputation. Its output location is generally `{prefix}_a
A table that looks similar, but with imputed values is created. The cohort table from above is passed into collate as
the comprehensive set of entities and dates for which output should be generated, regardless if they exist in the
`from_obj`. Each feature column has an imputation rule, inherited from some level of the feature definition. The
imputation rules that are based on data (e.g. `mean`) use the rows from the `as_of_time` to produce the imputed value.
In addition, each column that needs imputation gets an imputation flag column containing a boolean that marks which rows were imputed. Because these flag values are identical across all aggregate functions applied to a given source column and timespan, only one imputation flag column is created per source column and timespan.
Its output location is generally `{prefix}_aggregation_imputed`
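For example, for a single quantity aggregated with both `sum` and `avg` over a one-year window, the imputed table might contain one value column per function but a single shared flag (hypothetical names following the collate test conventions):

```python
imputed_table_columns = [
    "events_entity_id_1y_outcome_sum",  # imputed sum
    "events_entity_id_1y_outcome_avg",  # imputed average
    "events_entity_id_1y_outcome_imp",  # 1 if this row's value was imputed, else 0
]
```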

### Recap
30 changes: 10 additions & 20 deletions src/tests/architect_tests/test_feature_generators.py
@@ -109,10 +109,8 @@ def test_feature_generation(test_engine):
"aprefix_zip_code_all_cat_one_good_sum": 0,
"aprefix_zip_code_all_cat_one_bad_sum": 0,
"aprefix_zip_code_all_cat_one__NULL_sum": 1,
"aprefix_entity_id_all_quantity_one_sum_imp": 1,
"aprefix_entity_id_all_quantity_one_count_imp": 1,
"aprefix_zip_code_all_quantity_one_sum_imp": 1,
"aprefix_zip_code_all_quantity_one_count_imp": 1,
"aprefix_entity_id_all_quantity_one_imp": 1,
"aprefix_zip_code_all_quantity_one_imp": 1,
},
{
"entity_id": 1,
@@ -128,10 +126,8 @@
"aprefix_zip_code_all_cat_one_good_sum": 1,
"aprefix_zip_code_all_cat_one_bad_sum": 0,
"aprefix_zip_code_all_cat_one__NULL_sum": 0,
"aprefix_entity_id_all_quantity_one_sum_imp": 0,
"aprefix_entity_id_all_quantity_one_count_imp": 0,
"aprefix_zip_code_all_quantity_one_sum_imp": 0,
"aprefix_zip_code_all_quantity_one_count_imp": 0,
"aprefix_entity_id_all_quantity_one_imp": 0,
"aprefix_zip_code_all_quantity_one_imp": 0,
},
{
"entity_id": 3,
@@ -147,10 +143,8 @@
"aprefix_zip_code_all_cat_one_good_sum": 0,
"aprefix_zip_code_all_cat_one_bad_sum": 1,
"aprefix_zip_code_all_cat_one__NULL_sum": 0,
"aprefix_entity_id_all_quantity_one_sum_imp": 0,
"aprefix_entity_id_all_quantity_one_count_imp": 0,
"aprefix_zip_code_all_quantity_one_sum_imp": 0,
"aprefix_zip_code_all_quantity_one_count_imp": 0,
"aprefix_entity_id_all_quantity_one_imp": 0,
"aprefix_zip_code_all_quantity_one_imp": 0,
},
{
"entity_id": 3,
@@ -166,10 +160,8 @@
"aprefix_zip_code_all_cat_one_good_sum": 0,
"aprefix_zip_code_all_cat_one_bad_sum": 2,
"aprefix_zip_code_all_cat_one__NULL_sum": 0,
"aprefix_entity_id_all_quantity_one_sum_imp": 0,
"aprefix_entity_id_all_quantity_one_count_imp": 0,
"aprefix_zip_code_all_quantity_one_sum_imp": 0,
"aprefix_zip_code_all_quantity_one_count_imp": 0,
"aprefix_entity_id_all_quantity_one_imp": 0,
"aprefix_zip_code_all_quantity_one_imp": 0,
},
{
"entity_id": 4,
@@ -185,10 +177,8 @@
"aprefix_zip_code_all_cat_one_good_sum": 0,
"aprefix_zip_code_all_cat_one_bad_sum": 2,
"aprefix_zip_code_all_cat_one__NULL_sum": 0,
"aprefix_entity_id_all_quantity_one_sum_imp": 0,
"aprefix_entity_id_all_quantity_one_count_imp": 0,
"aprefix_zip_code_all_quantity_one_sum_imp": 0,
"aprefix_zip_code_all_quantity_one_count_imp": 0,
"aprefix_entity_id_all_quantity_one_imp": 0,
"aprefix_zip_code_all_quantity_one_imp": 0,
},
]
}
6 changes: 3 additions & 3 deletions src/tests/collate_tests/test_imputation_output.py
@@ -212,15 +212,15 @@ def test_imputation_output(feat_list, exp_imp_cols, feat_table):
and imp != "zero_noflag"
):
assert (
"prefix_entity_id_1y_%s_max_imp" % feat in df.columns.values
"prefix_entity_id_1y_%s_imp" % feat in df.columns.values
)
assert (
df["prefix_entity_id_1y_%s_max_imp" % feat].isnull().sum()
df["prefix_entity_id_1y_%s_imp" % feat].isnull().sum()
== 0
)
else:
# should not generate an imputed column when not needed
assert (
"prefix_entity_id_1y_%s_max_imp" % feat
"prefix_entity_id_1y_%s_imp" % feat
not in df.columns.values
)
9 changes: 6 additions & 3 deletions src/tests/collate_tests/test_imputations.py
@@ -20,13 +20,16 @@
def test_impute_flag():
imp = BaseImputation(column="a", coltype="aggregate")
assert (
imp.imputed_flag_sql() == 'CASE WHEN "a" IS NULL THEN 1::SMALLINT ELSE 0::SMALLINT END AS "a_imp" '
imp.imputed_flag_select_and_alias() == (
'CASE WHEN "a" IS NULL THEN 1::SMALLINT ELSE 0::SMALLINT END',
'a_imp'
)
)


def test_impute_flag_categorical():
imp = BaseImputation(column="a", coltype="categorical")
assert imp.imputed_flag_sql() is None
assert imp.imputed_flag_select_and_alias() == (None, None)


def test_mean_imputation():
@@ -77,7 +80,7 @@ def test_impute_zero():
def test_impute_zero_noflag():
imp = ImputeZeroNoFlag(column="a", coltype="aggregate")
assert imp.to_sql() == 'COALESCE("a", 0::SMALLINT) AS "a" '
assert imp.imputed_flag_sql() is None
assert imp.imputed_flag_select_and_alias() == (None, None)
assert imp.noflag

imp = ImputeZeroNoFlag(column="a_myval_max", coltype="categorical")
21 changes: 9 additions & 12 deletions src/tests/collate_tests/test_spacetime.py
@@ -152,37 +152,34 @@ def test_basic_spacetime():
assert rows[6]["entity_id"] == 4
assert rows[6]["as_of_date"] == date(2015, 1, 1)
assert rows[6]["events_entity_id_1y_outcome::int_sum"] == 3
assert rows[6]["events_entity_id_1y_outcome::int_sum_imp"] == 1
assert rows[6]["events_entity_id_1y_outcome::int_imp"] == 1
assert (
round(float(rows[6]["events_entity_id_1y_outcome::int_avg"]), 4) == 0.1667
)
assert rows[6]["events_entity_id_1y_outcome::int_avg_imp"] == 1
assert rows[6]["events_entity_id_1y_outcome::int_imp"] == 1
assert rows[6]["events_entity_id_2y_outcome::int_sum"] == 3
assert rows[6]["events_entity_id_2y_outcome::int_sum_imp"] == 1
assert rows[6]["events_entity_id_2y_outcome::int_imp"] == 1
assert (
round(float(rows[6]["events_entity_id_2y_outcome::int_avg"]), 4) == 0.3333
)
assert rows[6]["events_entity_id_2y_outcome::int_avg_imp"] == 1
assert rows[6]["events_entity_id_2y_outcome::int_imp"] == 1
assert rows[6]["events_entity_id_all_outcome::int_sum"] == 3
assert rows[6]["events_entity_id_all_outcome::int_sum_imp"] == 1
assert rows[6]["events_entity_id_all_outcome::int_imp"] == 1
assert (
round(float(rows[6]["events_entity_id_all_outcome::int_avg"]), 4) == 0.3333
)
assert rows[6]["events_entity_id_all_outcome::int_avg_imp"] == 1
assert rows[6]["events_entity_id_all_outcome::int_imp"] == 1
assert rows[7]["entity_id"] == 4
assert rows[7]["as_of_date"] == date(2016, 1, 1)
assert rows[7]["events_entity_id_1y_outcome::int_sum"] == 0
assert rows[7]["events_entity_id_1y_outcome::int_sum_imp"] == 0
assert rows[7]["events_entity_id_1y_outcome::int_imp"] == 0
assert rows[7]["events_entity_id_1y_outcome::int_avg"] == 0
assert rows[7]["events_entity_id_1y_outcome::int_avg_imp"] == 0
assert rows[7]["events_entity_id_2y_outcome::int_sum"] == 0
assert rows[7]["events_entity_id_2y_outcome::int_sum_imp"] == 0
assert rows[7]["events_entity_id_2y_outcome::int_imp"] == 0
assert rows[7]["events_entity_id_2y_outcome::int_avg"] == 0
assert rows[7]["events_entity_id_2y_outcome::int_avg_imp"] == 0
assert rows[7]["events_entity_id_all_outcome::int_sum"] == 0
assert rows[7]["events_entity_id_all_outcome::int_sum_imp"] == 0
assert rows[7]["events_entity_id_all_outcome::int_imp"] == 0
assert rows[7]["events_entity_id_all_outcome::int_avg"] == 0
assert rows[7]["events_entity_id_all_outcome::int_avg_imp"] == 0
assert len(rows) == 8


54 changes: 49 additions & 5 deletions src/triage/component/collate/collate.py
@@ -1,8 +1,10 @@
# -*- coding: utf-8 -*-
import logging
from numbers import Number
from itertools import product, chain
import sqlalchemy.sql.expression as ex
import re
from descriptors import cachedproperty

from .sql import make_sql_clause, to_sql_name, CreateTableAs, InsertFromSelect
from .imputations import (
@@ -478,16 +480,37 @@ def __init__(
        self.suffix = suffix if suffix else "aggregation"
        self.schema = schema

    @cachedproperty
    def colname_aggregate_lookup(self):
        """A reverse lookup from column name to the source collate.Aggregate
        Will error if the Aggregation contains duplicate column names
        """
        lookup = {}
        for group, groupby in self.groups.items():
            for agg in self.aggregates:
                for col in agg.get_columns(prefix=self._col_prefix(group)):
                    if col.name in lookup:
                        raise ValueError("Duplicate feature column name found: ", col.name)
                    lookup[col.name] = agg
        return lookup

    def _col_prefix(self, group):
        """
        Helper for creating a column prefix for the group
        group: group clause, for naming columns
        Returns: string for a common column prefix for columns in that group
        """
        return "{prefix}_{group}_".format(prefix=self.prefix, group=group)

    def _get_aggregates_sql(self, group):
        """
        Helper for getting aggregates sql
        Args:
            group: group clause, for naming columns
        Returns: collection of aggregate column SQL strings
        """
        prefix = "{prefix}_{group}_".format(prefix=self.prefix, group=group)

        return chain(*[a.get_columns(prefix=prefix) for a in self.aggregates])
        return chain(*[a.get_columns(prefix=self._col_prefix(group)) for a in self.aggregates])

    def get_selects(self):
        """
@@ -683,6 +706,7 @@ def _get_impute_select(self, impute_cols, nonimpute_cols, partitionby=None):
        # key columns and date column
        query = ""

        used_impflags = set()
        # pre-sort and iterate through the combined set to ensure column order
        for col in sorted(nonimpute_cols + impute_cols):
            # just pass through columns that don't require imputation (no nulls found)
@@ -693,6 +717,22 @@
            # and a flag for whether the value was imputed
            if col in impute_cols:

                # we don't want to add redundant imputation flags: for a given source
                # column and time interval, all of the functions will have identical
                # sets of rows that needed imputation.
                # to merge these reliably, we look up the original aggregate that produced
                # the column and check its available functions. we expect exactly one of
                # these functions to end the column name; if so, we strip it off and pass
                # the result to the imputer as the base name for the flag column.
                if hasattr(self.colname_aggregate_lookup[col], 'functions'):
                    agg_functions = self.colname_aggregate_lookup[col].functions
                    used_function = next(funcname for funcname in agg_functions if col.endswith(funcname))
                    # strip the trailing '_<function>' suffix to get the base column name
                    impflag_basecol = col[:-len('_' + used_function)]
                else:
                    logging.warning("Imputation flag merging is not implemented for "
                                    "AggregateExpression objects that don't define an aggregate "
                                    "function (e.g. composites)")
                    impflag_basecol = col
                impute_rule = imprules[col]

                try:
@@ -703,13 +743,17 @@
                        % (impute_rule.get("type", ""), col)
                    ) from err

                imputer = imputer(column=col, partitionby=partitionby, **impute_rule)
                imputer = imputer(column=col, column_base_for_impflag=impflag_basecol, partitionby=partitionby, **impute_rule)

                query += "\n,%s" % imputer.to_sql()
                if not imputer.noflag:
                    # Add an imputation flag for non-categorical columns (this is handled
                    # for categorical columns with a separate NULL category)
                    query += "\n,%s" % imputer.imputed_flag_sql()
                    # but only add it if another functionally equivalent impflag hasn't already been added
                    impflag_select, impflag_alias = imputer.imputed_flag_select_and_alias()
                    if impflag_alias not in used_impflags:
                        used_impflags.add(impflag_alias)
                        query += "\n,%s as \"%s\" " % (impflag_select, impflag_alias)

        return query
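A small stand-alone illustration of the flag-merging logic above, with hypothetical column names; the real code looks the available functions up on the source Aggregate rather than taking them as an argument:

```python
# Strip the aggregate-function suffix so every function over the same source
# column and timespan maps to the same flag base name.
def impflag_base(col, functions):
    used = next(f for f in functions if col.endswith(f))
    return col[:-len("_" + used)]

used_impflags = set()
for col in ["events_entity_id_1y_outcome_sum", "events_entity_id_1y_outcome_avg"]:
    alias = impflag_base(col, ["sum", "avg"]) + "_imp"
    if alias not in used_impflags:  # only the first function contributes the flag column
        used_impflags.add(alias)

print(used_impflags)  # {'events_entity_id_1y_outcome_imp'} -- one flag, not one per function
```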

