Skip to content

Commit

Permalink
Bugfix for obtaining cohort from labels table (#911)
Browse files Browse the repository at this point in the history
* remove default config filling

* add active status when using label table for cohort

* when creating cohort from labels don't insert for existing as_of_dates

* debug log string

* debug

* debug more

* fix error inserting from labels table
  • Loading branch information
shaycrk authored Oct 25, 2022
1 parent 971a872 commit 05d1dc1
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 52 deletions.
16 changes: 0 additions & 16 deletions src/tests/test_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,12 @@
from tests.utils import sample_config, populate_source_data
from triage.experiments.defaults import (
fill_timechop_config_missing,
fill_cohort_config_missing,
fill_feature_group_definition,
fill_model_grid_presets,
model_grid_preset,
)


def test_fill_cohort_config_missing():
config = sample_config()
config.pop('cohort_config')
cohort_config = fill_cohort_config_missing(config)
assert cohort_config == {
'query': "select distinct entity_id from "
"((select entity_id, as_of_date as knowledge_date from "
"(select * from cat_complaints) as t)\n union \n(select entity_id, "
"as_of_date as knowledge_date from (select * from entity_zip_codes "
"join zip_code_events using (zip_code)) as t)) as e "
"where knowledge_date < '{as_of_date}'",
'name': 'all_entities'
}


def test_fill_feature_group_definition():
config = sample_config()
fg_definition = fill_feature_group_definition(config)
Expand Down
43 changes: 37 additions & 6 deletions src/triage/component/architect/entity_date_table_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def _create_and_populate_entity_date_table_from_query(self, as_of_dates):
"""
))
if len(any_existing) == 1:
logger.spam(f"Since >0 entity_date rows found for date {as_of_date}, skipping")
logger.notice(f"Since >0 entity_date rows found for date {as_of_date}, skipping")
continue
dated_query = self.query.format(as_of_date=formatted_date)
full_query = f"""insert into {self.entity_date_table_name}
Expand All @@ -129,16 +129,47 @@ def _create_and_populate_entity_date_table_from_labels(self):
logger.warning("Labels table does not exist, cannot populate entity-dates")
return

self.db_engine.execute(f"""insert into {self.entity_date_table_name}
select distinct entity_id, as_of_date
from {self.labels_table_name}
""")
# If any rows exist in the entity_date table, don't insert any for dates
# already in the table. This replicates the logic used above by
# _create_and_populate_entity_date_table_from_query
logger.spam(f"Looking for existing entity_date rows for label as of dates")
existing_dates = list(self.db_engine.execute(
f"""
with label_dates as (
select distinct as_of_date::DATE AS as_of_date FROM {self.labels_table_name}
)
, cohort_dates as (
select distinct as_of_date::DATE AS as_of_date FROM {self.entity_date_table_name}
)
select distinct l.as_of_date
from label_dates l
join cohort_dates c using(as_of_date)
"""
))
if len(existing_dates) > 0:
existing_dates = ', '.join(existing_dates)
logger.notice(f'Existing entity_dates records found for the following dates, '
f'so new records will not be inserted for these dates {existing_dates}')

insert_query = f"""
insert into {self.entity_date_table_name}
select distinct entity_id, as_of_date, true
from (
select distinct l.entity_id, l.as_of_date
from {self.labels_table_name} as l
left join (select distinct as_of_date from {self.entity_date_table_name}) as c
on l.as_of_date::DATE = c.as_of_date::DATE
where c.as_of_date IS NULL
) as sub
"""
logger.spam(f"Running entity_date query from labels table: {insert_query}")
self.db_engine.execute(insert_query)

def _empty_table_message(self, as_of_dates):
return """Query does not return any rows for the given as_of_dates:
{as_of_dates}
'{query}'""".format(
query=self.query,
query=self.query or "labels table",
as_of_dates=", ".join(
str(as_of_date)
for as_of_date in (
Expand Down
3 changes: 0 additions & 3 deletions src/triage/experiments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@

from triage.experiments.defaults import (
fill_timechop_config_missing,
fill_cohort_config_missing,
fill_feature_group_definition,
fill_model_grid_presets,
)
Expand Down Expand Up @@ -225,8 +224,6 @@ def __init__(
self.config["temporal_config"] = fill_timechop_config_missing(
self.config, self.db_engine
)
## Defaults to all the entities found in the features_aggregation's from_obj
self.config["cohort_config"] = fill_cohort_config_missing(self.config)
## Defaults to all the feature_aggregation's prefixes
self.config["feature_group_definition"] = fill_feature_group_definition(
self.config
Expand Down
27 changes: 0 additions & 27 deletions src/triage/experiments/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,33 +56,6 @@ def fill_timechop_config_missing(config, db_engine):
return default_config


def fill_cohort_config_missing(config):
"""
If none cohort_config section is provided, include all the entities by default
Args:
config (dict) a triage experiment configuration
Returns: (dict) a triage cohort config
"""
from_query = "(select entity_id, {knowledge_date} as knowledge_date from (select * from {from_obj}) as t)"

feature_aggregations = config['feature_aggregations']

from_queries = [from_query.format(knowledge_date = agg['knowledge_date_column'], from_obj=agg['from_obj']) for agg in feature_aggregations]

unions = "\n union \n".join(from_queries)

query = f"select distinct entity_id from ({unions}) as e" +" where knowledge_date < '{as_of_date}'"

cohort_config = config.get('cohort_config', {})
default_config = {'query': query, 'name': 'all_entities'}

default_config.update(cohort_config)

return default_config


def fill_feature_group_definition(config):
"""
If feature_group_definition is not presents, this function sets it to all
Expand Down
3 changes: 3 additions & 0 deletions src/triage/experiments/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,9 @@ def _run(self, label_config):
class CohortConfigValidator(Validator):
def _run(self, cohort_config):
logger.spam("Validating of cohort configuration")
if not cohort_config:
logger.debug("No cohort config specified, label config will be used instead")
return
if len(set(cohort_config.keys()).intersection({"query", "filepath"})) != 1:
raise ValueError(
dedent(
Expand Down

0 comments on commit 05d1dc1

Please sign in to comment.