Skip to content

Commit

Permalink
Error on cohort or label duplicates (#889)
Browse files Browse the repository at this point in the history
* check for duplicates when generating cohort and labels

* add unit tests

* debug

* multiple database_reflection files...

* add import

* consolidate database_reflection
  • Loading branch information
shaycrk authored Apr 18, 2022
1 parent be04c9f commit 335f264
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 13 deletions.
28 changes: 28 additions & 0 deletions src/tests/architect_tests/test_label_generators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import date, timedelta

import testing.postgresql
import pytest
from sqlalchemy import create_engine

from triage.component.architect.label_generators import LabelGenerator
Expand Down Expand Up @@ -169,3 +170,30 @@ def test_generate_all_labels_noreplace():
(4, date(2014, 9, 30), timedelta(90), "outcome", "binary", False),
]
assert records == expected


def test_generate_all_labels_errors_on_duplicates():

# label query that will yield duplicates (one row for each event in the timespan)
BAD_LABEL_GENERATE_QUERY = """
select
events.entity_id,
1 as outcome
from events
where
'{as_of_date}' <= outcome_date
and outcome_date < '{as_of_date}'::timestamp + interval '{label_timespan}'
"""

with testing.postgresql.Postgresql() as postgresql:
engine = create_engine(postgresql.url())
create_binary_outcome_events(engine, "events", events_data)

label_generator = LabelGenerator(db_engine=engine, query=BAD_LABEL_GENERATE_QUERY, replace=True)
with pytest.raises(ValueError):
label_generator.generate_all_labels(
labels_table=LABELS_TABLE_NAME,
as_of_dates=["2014-09-30", "2015-03-30"],
label_timespans=["6month", "3month"],
)

2 changes: 1 addition & 1 deletion src/tests/postmodeling_tests/test_add_predictions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from triage.component.postmodeling.utils.add_predictions import add_predictions
from triage.component.architect.database_reflection import table_has_data
from triage.database_reflection import table_has_data


MODEL_IDS_QUERY = """
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from testing.postgresql import Postgresql
from unittest import TestCase

from triage.component.architect import database_reflection as dbreflect
import triage.database_reflection as dbreflect


class TestDatabaseReflection(TestCase):
Expand Down Expand Up @@ -40,6 +40,16 @@ def test_table_has_data(self):
assert dbreflect.table_has_data("compliments", self.engine)
assert not dbreflect.table_has_data("incidents", self.engine)

def test_table_has_duplicates(self):
self.engine.execute("create table events (col1 int, col2 int)")
assert not dbreflect.table_has_duplicates("events", ['col1', 'col2'], self.engine)
self.engine.execute("insert into events values (1,2)")
self.engine.execute("insert into events values (1,3)")
assert dbreflect.table_has_duplicates("events", ['col1'], self.engine)
assert not dbreflect.table_has_duplicates("events", ['col1', 'col2'], self.engine)
self.engine.execute("insert into events values (1,2)")
assert dbreflect.table_has_duplicates("events", ['col1', 'col2'], self.engine)

def test_table_has_column(self):
self.engine.execute("create table incidents (col1 varchar)")
assert dbreflect.table_has_column("incidents", "col1", self.engine)
Expand Down
10 changes: 8 additions & 2 deletions src/triage/component/architect/entity_date_table_generators.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import verboselogs

from triage.component.architect.database_reflection import table_has_data
from triage.database_reflection import table_row_count, table_exists
from triage.database_reflection import table_has_data, table_row_count, table_exists, table_has_duplicates


logger = verboselogs.VerboseLogger(__name__)
Expand Down Expand Up @@ -55,6 +54,13 @@ def generate_entity_date_table(self, as_of_dates):
if not table_has_data(self.entity_date_table_name, self.db_engine):
raise ValueError(self._empty_table_message(as_of_dates))

if table_has_duplicates(
self.entity_date_table_name,
['entity_id', 'as_of_date'],
self.db_engine
):
raise ValueError(f"Duplicates found in {self.entity_date_table_name}!")

logger.debug(f"Entity-date table generated at {self.entity_date_table_name}")
logger.spam(f"Generating stats on {self.entity_date_table_name}")
logger.spam(f"Row count of {self.entity_date_table_name}: {table_row_count(self.entity_date_table_name, self.db_engine)}")
Expand Down
17 changes: 12 additions & 5 deletions src/triage/component/architect/label_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
logger = verboselogs.VerboseLogger(__name__)

import textwrap
from triage.database_reflection import table_row_count, table_exists
from triage.database_reflection import table_row_count, table_exists, table_has_duplicates

DEFAULT_LABEL_NAME = "outcome"

Expand Down Expand Up @@ -86,10 +86,17 @@ def generate_all_labels(self, labels_table, as_of_dates, label_timespans):

if nrows == 0:
logger.warning(f"Done creating labels, but no rows in {labels_table} table!")
raise ValueError(f"{label_table} is empty!")
else:
logger.debug(f"Labels table generated at {labels_table}")
logger.spam(f"Row count of {labels_table}: {nrows}")
raise ValueError(f"{labels_table} is empty!")

if table_has_duplicates(
labels_table,
['entity_id', 'as_of_date', 'label_timespan', 'label_name', 'label_type'],
self.db_engine
):
raise ValueError(f"Duplicates found in {labels_table}!")

logger.debug(f"Labels table generated at {labels_table}")
logger.spam(f"Row count of {labels_table}: {nrows}")

def generate(self, start_date, label_timespan, labels_table):
"""Generate labels table using a query
Expand Down
2 changes: 1 addition & 1 deletion src/triage/component/architect/validations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Functions for validating input, mostly around database schema and state"""
from triage.component.architect.database_reflection import (
from triage.database_reflection import (
table_exists,
table_has_column,
column_type,
Expand Down
35 changes: 32 additions & 3 deletions src/triage/database_reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ def table_has_data(table_name, db_engine):
"""
if not table_exists(table_name, db_engine):
return False
result = [
row for row in db_engine.execute("select 1 from {} limit 1".format(table_name))
results = [
row for row in db_engine.execute("select * from {} limit 1".format(table_name))
]

return any(result)
return len(results) > 0


def table_row_count(table_name, db_engine):
Expand All @@ -99,6 +99,35 @@ def table_row_count(table_name, db_engine):
)


def table_has_duplicates(table_name, column_list, db_engine):
"""Check whether the table has duplicate rows on the set of columns.
The table is expected to exist and contain the columns in column_list.
Args:
table_name (string) A table name (with schema)
column_list (list) A list of column names
db_engine (sqlalchemy.engine)
Returns: (boolean) Whether or not duplicates are found
"""
if not table_has_data(table_name, db_engine):
return False

cols = ','.join(['"%s"' % c for c in column_list])
sql = f"""
WITH counts AS (
SELECT {cols}
, COUNT(*) AS num_records
FROM {table_name}
GROUP BY {cols}
)
SELECT MAX(num_records) FROM counts
"""
result = next(db_engine.execute(sql))[0]
return result > 1


def table_has_column(table_name, column, db_engine):
"""Check whether the table contains a column of the given name
Expand Down

0 comments on commit 335f264

Please sign in to comment.