Remove redundant imputation flag columns [Resolves #544] #676

Merged
merged 3 commits on Apr 25, 2019
Changes from 2 commits
3 changes: 2 additions & 1 deletion docs/sources/experiments/algorithm.md
@@ -126,7 +126,8 @@ in the aggregation, pre-imputation. Its output location is generally `{prefix}_a
A table that looks similar, but with imputed values is created. The cohort table from above is passed into collate as
the comprehensive set of entities and dates for which output should be generated, regardless if they exist in the
`from_obj`. Each feature column has an imputation rule, inherited from some level of the feature definition. The
imputation rules that are based on data (e.g. `mean`) use the rows from the `as_of_time` to produce the imputed value.
imputation rules that are based on data (e.g. `mean`) use the rows from the `as_of_time` to produce the imputed value.
In addition, each column that needs imputation has an imputation flag column created, containing a boolean that marks which rows were imputed. For most aggregate functions over a given timespan these flag columns are redundant (a value is imputed only when no events fall within that timespan), so only one imputation flag column per timespan is created. The exception is statistical functions that require more than one value, such as standard deviation and variance; their boolean imputation flags are *not* merged in with the others.
Its output location is generally `{prefix}_aggregation_imputed`

### Recap
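To make the doc change above concrete, here is a small illustrative snippet. The column names are taken from the test expectations updated later in this PR; the snippet only illustrates the naming convention, it is not Triage's implementation.

```python
# Illustration only: feature column names as they appear in the updated tests.
# Before this change, every aggregate function carried its own imputation flag:
old_flag_columns = [
    "aprefix_entity_id_all_quantity_one_sum_imp",
    "aprefix_entity_id_all_quantity_one_count_imp",
]

# After this change, functions over the same source column and timespan share
# a single flag, because they need imputation for exactly the same rows:
new_flag_column = "aprefix_entity_id_all_quantity_one_imp"

# Functions that need more than one value (e.g. stddev, variance) keep their
# own flag, e.g. "events_entity_id_1y_outcome::int_stddev_imp".
```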
30 changes: 10 additions & 20 deletions src/tests/architect_tests/test_feature_generators.py
@@ -109,10 +109,8 @@ def test_feature_generation(test_engine):
"aprefix_zip_code_all_cat_one_good_sum": 0,
"aprefix_zip_code_all_cat_one_bad_sum": 0,
"aprefix_zip_code_all_cat_one__NULL_sum": 1,
"aprefix_entity_id_all_quantity_one_sum_imp": 1,
"aprefix_entity_id_all_quantity_one_count_imp": 1,
"aprefix_zip_code_all_quantity_one_sum_imp": 1,
"aprefix_zip_code_all_quantity_one_count_imp": 1,
"aprefix_entity_id_all_quantity_one_imp": 1,
"aprefix_zip_code_all_quantity_one_imp": 1,
},
{
"entity_id": 1,
@@ -128,10 +126,8 @@ def test_feature_generation(test_engine):
"aprefix_zip_code_all_cat_one_good_sum": 1,
"aprefix_zip_code_all_cat_one_bad_sum": 0,
"aprefix_zip_code_all_cat_one__NULL_sum": 0,
"aprefix_entity_id_all_quantity_one_sum_imp": 0,
"aprefix_entity_id_all_quantity_one_count_imp": 0,
"aprefix_zip_code_all_quantity_one_sum_imp": 0,
"aprefix_zip_code_all_quantity_one_count_imp": 0,
"aprefix_entity_id_all_quantity_one_imp": 0,
"aprefix_zip_code_all_quantity_one_imp": 0,
},
{
"entity_id": 3,
@@ -147,10 +143,8 @@ def test_feature_generation(test_engine):
"aprefix_zip_code_all_cat_one_good_sum": 0,
"aprefix_zip_code_all_cat_one_bad_sum": 1,
"aprefix_zip_code_all_cat_one__NULL_sum": 0,
"aprefix_entity_id_all_quantity_one_sum_imp": 0,
"aprefix_entity_id_all_quantity_one_count_imp": 0,
"aprefix_zip_code_all_quantity_one_sum_imp": 0,
"aprefix_zip_code_all_quantity_one_count_imp": 0,
"aprefix_entity_id_all_quantity_one_imp": 0,
"aprefix_zip_code_all_quantity_one_imp": 0,
},
{
"entity_id": 3,
@@ -166,10 +160,8 @@ def test_feature_generation(test_engine):
"aprefix_zip_code_all_cat_one_good_sum": 0,
"aprefix_zip_code_all_cat_one_bad_sum": 2,
"aprefix_zip_code_all_cat_one__NULL_sum": 0,
"aprefix_entity_id_all_quantity_one_sum_imp": 0,
"aprefix_entity_id_all_quantity_one_count_imp": 0,
"aprefix_zip_code_all_quantity_one_sum_imp": 0,
"aprefix_zip_code_all_quantity_one_count_imp": 0,
"aprefix_entity_id_all_quantity_one_imp": 0,
"aprefix_zip_code_all_quantity_one_imp": 0,
},
{
"entity_id": 4,
@@ -185,10 +177,8 @@ def test_feature_generation(test_engine):
"aprefix_zip_code_all_cat_one_good_sum": 0,
"aprefix_zip_code_all_cat_one_bad_sum": 2,
"aprefix_zip_code_all_cat_one__NULL_sum": 0,
"aprefix_entity_id_all_quantity_one_sum_imp": 0,
"aprefix_entity_id_all_quantity_one_count_imp": 0,
"aprefix_zip_code_all_quantity_one_sum_imp": 0,
"aprefix_zip_code_all_quantity_one_count_imp": 0,
"aprefix_entity_id_all_quantity_one_imp": 0,
"aprefix_zip_code_all_quantity_one_imp": 0,
},
]
}
6 changes: 3 additions & 3 deletions src/tests/collate_tests/test_imputation_output.py
@@ -212,15 +212,15 @@ def test_imputation_output(feat_list, exp_imp_cols, feat_table):
and imp != "zero_noflag"
):
assert (
"prefix_entity_id_1y_%s_max_imp" % feat in df.columns.values
"prefix_entity_id_1y_%s_imp" % feat in df.columns.values
)
assert (
df["prefix_entity_id_1y_%s_max_imp" % feat].isnull().sum()
df["prefix_entity_id_1y_%s_imp" % feat].isnull().sum()
== 0
)
else:
# should not generate an imputed column when not needed
assert (
"prefix_entity_id_1y_%s_max_imp" % feat
"prefix_entity_id_1y_%s_imp" % feat
not in df.columns.values
)
9 changes: 6 additions & 3 deletions src/tests/collate_tests/test_imputations.py
@@ -20,13 +20,16 @@
def test_impute_flag():
imp = BaseImputation(column="a", coltype="aggregate")
assert (
imp.imputed_flag_sql() == 'CASE WHEN "a" IS NULL THEN 1::SMALLINT ELSE 0::SMALLINT END AS "a_imp" '
imp.imputed_flag_select_and_alias() == (
'CASE WHEN "a" IS NULL THEN 1::SMALLINT ELSE 0::SMALLINT END',
'a_imp'
)
)


def test_impute_flag_categorical():
imp = BaseImputation(column="a", coltype="categorical")
assert imp.imputed_flag_sql() is None
assert imp.imputed_flag_select_and_alias() == (None, None)


def test_mean_imputation():
@@ -77,7 +80,7 @@ def test_impute_zero():
def test_impute_zero_noflag():
imp = ImputeZeroNoFlag(column="a", coltype="aggregate")
assert imp.to_sql() == 'COALESCE("a", 0::SMALLINT) AS "a" '
assert imp.imputed_flag_sql() is None
assert imp.imputed_flag_select_and_alias() == (None, None)
assert imp.noflag

imp = ImputeZeroNoFlag(column="a_myval_max", coltype="categorical")
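A hedged sketch of how a caller might consume the `imputed_flag_select_and_alias()` tuple exercised in the tests above. The helper name and the `used_impflags` set are illustrative; they mirror the `collate.py` change further down, not a public Triage API.

```python
# Sketch, assuming an imputer object with the interface exercised above:
# imputed_flag_select_and_alias() returns (SELECT expression, alias), or
# (None, None) when no flag should be emitted (categorical, zero_noflag).
def append_imputation_flag(query, imputer, used_impflags):
    select, alias = imputer.imputed_flag_select_and_alias()
    if select is None:
        return query
    if alias in used_impflags:
        # a functionally identical flag was already added for this timespan
        return query
    used_impflags.add(alias)
    return query + '\n,%s as "%s" ' % (select, alias)
```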
35 changes: 21 additions & 14 deletions src/tests/collate_tests/test_spacetime.py
@@ -56,12 +56,12 @@ def test_basic_spacetime():

agg = Aggregate(
"outcome::int",
["sum", "avg"],
["sum", "avg", "stddev"],
{
"coltype": "aggregate",
"avg": {"type": "mean"},
"sum": {"type": "constant", "value": 3},
"max": {"type": "zero"},
"stddev": {"type": "constant", "value": 2},
},
)
st = SpacetimeAggregation(
@@ -152,37 +152,44 @@ def test_basic_spacetime():
assert rows[6]["entity_id"] == 4
assert rows[6]["as_of_date"] == date(2015, 1, 1)
assert rows[6]["events_entity_id_1y_outcome::int_sum"] == 3
assert rows[6]["events_entity_id_1y_outcome::int_sum_imp"] == 1
assert rows[6]["events_entity_id_1y_outcome::int_imp"] == 1
assert rows[6]["events_entity_id_1y_outcome::int_stddev"] == 2
assert rows[6]["events_entity_id_1y_outcome::int_stddev_imp"] == 1
assert (
round(float(rows[6]["events_entity_id_1y_outcome::int_avg"]), 4) == 0.1667
)
assert rows[6]["events_entity_id_1y_outcome::int_avg_imp"] == 1
assert rows[6]["events_entity_id_2y_outcome::int_sum"] == 3
assert rows[6]["events_entity_id_2y_outcome::int_sum_imp"] == 1
assert rows[6]["events_entity_id_2y_outcome::int_imp"] == 1
assert rows[6]["events_entity_id_2y_outcome::int_stddev"] == 2
assert rows[6]["events_entity_id_2y_outcome::int_stddev_imp"] == 1
assert (
round(float(rows[6]["events_entity_id_2y_outcome::int_avg"]), 4) == 0.3333
)
assert rows[6]["events_entity_id_2y_outcome::int_avg_imp"] == 1
assert rows[6]["events_entity_id_all_outcome::int_sum"] == 3
assert rows[6]["events_entity_id_all_outcome::int_sum_imp"] == 1
assert rows[6]["events_entity_id_all_outcome::int_imp"] == 1
assert rows[6]["events_entity_id_all_outcome::int_stddev"] == 2
assert rows[6]["events_entity_id_all_outcome::int_stddev_imp"] == 1
assert (
round(float(rows[6]["events_entity_id_all_outcome::int_avg"]), 4) == 0.3333
)
assert rows[6]["events_entity_id_all_outcome::int_avg_imp"] == 1
assert rows[6]["events_entity_id_all_outcome::int_imp"] == 1
assert rows[7]["entity_id"] == 4
assert rows[7]["as_of_date"] == date(2016, 1, 1)
assert rows[7]["events_entity_id_1y_outcome::int_sum"] == 0
assert rows[7]["events_entity_id_1y_outcome::int_sum_imp"] == 0
assert rows[7]["events_entity_id_1y_outcome::int_imp"] == 0
assert rows[7]["events_entity_id_1y_outcome::int_avg"] == 0
assert rows[7]["events_entity_id_1y_outcome::int_avg_imp"] == 0
assert rows[7]["events_entity_id_1y_outcome::int_stddev"] == 2
assert rows[7]["events_entity_id_1y_outcome::int_stddev_imp"] == 1
assert rows[7]["events_entity_id_2y_outcome::int_sum"] == 0
assert rows[7]["events_entity_id_2y_outcome::int_sum_imp"] == 0
assert rows[7]["events_entity_id_2y_outcome::int_imp"] == 0
assert rows[7]["events_entity_id_2y_outcome::int_avg"] == 0
assert rows[7]["events_entity_id_2y_outcome::int_avg_imp"] == 0
assert rows[7]["events_entity_id_2y_outcome::int_stddev"] == 2
assert rows[7]["events_entity_id_2y_outcome::int_stddev_imp"] == 1
assert rows[7]["events_entity_id_all_outcome::int_sum"] == 0
assert rows[7]["events_entity_id_all_outcome::int_sum_imp"] == 0
assert rows[7]["events_entity_id_all_outcome::int_imp"] == 0
assert rows[7]["events_entity_id_all_outcome::int_avg"] == 0
assert rows[7]["events_entity_id_all_outcome::int_avg_imp"] == 0
assert rows[7]["events_entity_id_all_outcome::int_stddev"] == 2
assert rows[7]["events_entity_id_all_outcome::int_stddev_imp"] == 1
assert len(rows) == 8


58 changes: 53 additions & 5 deletions src/triage/component/collate/collate.py
@@ -1,8 +1,10 @@
# -*- coding: utf-8 -*-
import logging
from numbers import Number
from itertools import product, chain
import sqlalchemy.sql.expression as ex
import re
from descriptors import cachedproperty

from .sql import make_sql_clause, to_sql_name, CreateTableAs, InsertFromSelect
from .imputations import (
@@ -35,6 +37,7 @@ def make_tuple(a):


DISTINCT_REGEX = re.compile(r"distinct[ (]")
AGGFUNCS_NEED_MULTIPLE_VALUES = set(['stddev', 'stddev_samp', 'variance', 'var_samp'])


def split_distinct(quantity):
@@ -478,16 +481,37 @@ def __init__(
self.suffix = suffix if suffix else "aggregation"
self.schema = schema

@cachedproperty
def colname_aggregate_lookup(self):
"""A reverse lookup from column name to the source collate.Aggregate

Will error if the Aggregation contains duplicate column names
"""
lookup = {}
for group, groupby in self.groups.items():
for agg in self.aggregates:
for col in agg.get_columns(prefix=self._col_prefix(group)):
if col.name in lookup:
raise ValueError("Duplicate feature column name found: ", col.name)
lookup[col.name] = agg
return lookup

def _col_prefix(self, group):
"""
Helper for creating a column prefix for the group
group: group clause, for naming columns
Returns: string for a common column prefix for columns in that group
"""
return "{prefix}_{group}_".format(prefix=self.prefix, group=group)

def _get_aggregates_sql(self, group):
"""
Helper for getting aggregates sql
Args:
group: group clause, for naming columns
Returns: collection of aggregate column SQL strings
"""
prefix = "{prefix}_{group}_".format(prefix=self.prefix, group=group)

return chain(*[a.get_columns(prefix=prefix) for a in self.aggregates])
return chain(*[a.get_columns(prefix=self._col_prefix(group)) for a in self.aggregates])

def get_selects(self):
"""
@@ -683,6 +707,7 @@ def _get_impute_select(self, impute_cols, nonimpute_cols, partitionby=None):
# key columns and date column
query = ""

used_impflags = set()
# pre-sort and iterate through the combined set to ensure column order
for col in sorted(nonimpute_cols + impute_cols):
# just pass through columns that don't require imputation (no nulls found)
@@ -693,6 +718,25 @@ def _get_impute_select(self, impute_cols, nonimpute_cols, partitionby=None):
# and a flag for whether the value was imputed
if col in impute_cols:

# we don't want to add redundant imputation flags: for a given source
# column and time interval, all of the functions will have identical
# sets of rows that needed imputation.
# to reliably merge these, we look up the original aggregate that produced
# the column and check its available functions. we expect exactly one of
# these functions to end the column name and, if so, strip it off;
# the resulting base column is passed to the imputer
if hasattr(self.colname_aggregate_lookup[col], 'functions'):
agg_functions = self.colname_aggregate_lookup[col].functions
used_function = next(funcname for funcname in agg_functions if col.endswith(funcname))
if used_function in AGGFUNCS_NEED_MULTIPLE_VALUES:
impflag_basecol = col
else:
impflag_basecol = col.rstrip('_' + used_function)
else:
logging.warning("Imputation flag merging is not implemented for "
"AggregateExpression objects that don't define an aggregate "
"function (e.g. composites)")
impflag_basecol = col
impute_rule = imprules[col]

try:
@@ -703,13 +747,17 @@ def _get_impute_select(self, impute_cols, nonimpute_cols, partitionby=None):
% (impute_rule.get("type", ""), col)
) from err

imputer = imputer(column=col, partitionby=partitionby, **impute_rule)
imputer = imputer(column=col, column_base_for_impflag=impflag_basecol, partitionby=partitionby, **impute_rule)

query += "\n,%s" % imputer.to_sql()
if not imputer.noflag:
# Add an imputation flag for non-categorical columns (this is handled
# for categorical columns with a separate NULL category)
query += "\n,%s" % imputer.imputed_flag_sql()
# but only add it if another functionally equivalent impflag hasn't already been added
impflag_select, impflag_alias = imputer.imputed_flag_select_and_alias()
if impflag_alias not in used_impflags:
used_impflags.add(impflag_alias)
query += "\n,%s as \"%s\" " % (impflag_select, impflag_alias)

return query

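As a standalone illustration of the base-column derivation performed in `_get_impute_select` above: a minimal sketch with made-up column names, kept separate from the PR code. Suffix removal here uses explicit slicing rather than `str.rstrip`, since `rstrip` strips a character set rather than a suffix.

```python
# Minimal sketch of deriving the shared imputation-flag base column for a
# feature column, mirroring the logic added to _get_impute_select above.
AGGFUNCS_NEED_MULTIPLE_VALUES = {"stddev", "stddev_samp", "variance", "var_samp"}

def impflag_base_column(col, agg_functions):
    """Return the column name the imputation flag should be derived from."""
    # exactly one of the aggregate's functions is expected to end the column name
    used_function = next(f for f in agg_functions if col.endswith(f))
    if used_function in AGGFUNCS_NEED_MULTIPLE_VALUES:
        # stddev/variance need more than one value, so they keep their own flag
        return col
    # strip the trailing "_<function>" by slicing; str.rstrip would strip a
    # character set rather than a suffix
    return col[: -len("_" + used_function)]

# e.g. sum (like avg and count) maps to the shared base column, so its flag
# becomes "events_entity_id_1y_outcome::int_imp", while stddev keeps its own:
assert impflag_base_column(
    "events_entity_id_1y_outcome::int_sum", ["sum", "avg", "stddev"]
) == "events_entity_id_1y_outcome::int"
assert impflag_base_column(
    "events_entity_id_1y_outcome::int_stddev", ["sum", "avg", "stddev"]
) == "events_entity_id_1y_outcome::int_stddev"
```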