
Commit

Added capability to test intermediate files and started validating full correctness.
mmcdermott committed Aug 14, 2024
1 parent c7d56a9 commit 210a9ce
Showing 3 changed files with 289 additions and 15 deletions.
268 changes: 262 additions & 6 deletions tests/test_multi_stage_preprocess_pipeline.py
@@ -47,9 +47,12 @@
min_events_per_patient: 5
add_time_derived_measurements:
age:
DOB_code: "DOB"
DOB_code: "DOB" # This is the MEDS official code for BIRTH
age_code: "AGE"
age_unit: "years"
time_of_day:
time_of_day_code: "TIME_OF_DAY"
endpoints: [6, 12, 18, 24]
fit_outlier_detection:
aggregations:
- "values/n_occurrences"
@@ -67,7 +70,8 @@
"""

# After filtering out patients with fewer than 5 events:
POST_FILTER_YAML = parse_shards_yaml("""
WANT_POST_FILTER = parse_shards_yaml(
"""
"filter_patients/train/0": |-2
patient_id,time,code,numeric_value
239684,,EYE_COLOR//BROWN,
@@ -100,10 +104,13 @@
1195293,"06/20/2010, 20:41:33",HR,107.5
1195293,"06/20/2010, 20:41:33",TEMP,100.4
1195293,"06/20/2010, 20:50:04",DISCHARGE,
"filter_patients/train/1": |-2
patient_id,time,code,numeric_value
"filter_patients/tuning/0": |-2
patient_id,time,code,numeric_value
"filter_patients/held_out/0": |-2
patient_id,time,code,numeric_value
1500733,,EYE_COLOR//BROWN,
@@ -117,11 +124,254 @@
1500733,"06/03/2010, 16:20:49",HR,90.1
1500733,"06/03/2010, 16:20:49",TEMP,100.1
1500733,"06/03/2010, 16:44:26",DISCHARGE,
""")
"""
)
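For orientation, the following is a minimal polars sketch of the filtering rule these expectations encode, under the assumption that an "event" is a unique non-null timestamp per patient and that the stage keeps patients with at least min_events_per_patient such events (5 in the config above). filter_patients_sketch is a hypothetical helper for illustration, not the filter_patients stage implementation.

import polars as pl

def filter_patients_sketch(df: pl.DataFrame, min_events_per_patient: int = 5) -> pl.DataFrame:
    # Count distinct non-null event times per patient, then keep only patients that
    # meet the threshold; static (null-time) rows survive iff their patient does.
    n_events = (
        df.filter(pl.col("time").is_not_null())
        .group_by("patient_id")
        .agg(pl.col("time").n_unique().alias("n_events"))
    )
    keep = n_events.filter(pl.col("n_events") >= min_events_per_patient)["patient_id"]
    return df.filter(pl.col("patient_id").is_in(keep))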

WANT_POST_TIME_DERIVED = parse_shards_yaml(
"""
"add_time_derived_measurements/train/0": |-2
patient_id,time,code,numeric_value
239684,,EYE_COLOR//BROWN,
239684,,HEIGHT,175.271115221764
239684,"12/28/1980, 00:00:00","TIME_OF_DAY//[00,06)",
239684,"12/28/1980, 00:00:00",DOB,
239684,"05/11/2010, 17:41:51","TIME_OF_DAY//[12,18)",
239684,"05/11/2010, 17:41:51",AGE,29.36883360091833
239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC,
239684,"05/11/2010, 17:41:51",HR,102.6
239684,"05/11/2010, 17:41:51",TEMP,96.0
239684,"05/11/2010, 17:48:48","TIME_OF_DAY//[12,18)",
239684,"05/11/2010, 17:48:48",AGE,29.36884681513314
239684,"05/11/2010, 17:48:48",HR,105.1
239684,"05/11/2010, 17:48:48",TEMP,96.2
239684,"05/11/2010, 18:25:35","TIME_OF_DAY//[18,24)",
239684,"05/11/2010, 18:25:35",AGE,29.36891675223647
239684,"05/11/2010, 18:25:35",HR,113.4
239684,"05/11/2010, 18:25:35",TEMP,95.8
239684,"05/11/2010, 18:57:18","TIME_OF_DAY//[18,24)",
239684,"05/11/2010, 18:57:18",AGE,29.36897705595538
239684,"05/11/2010, 18:57:18",HR,112.6
239684,"05/11/2010, 18:57:18",TEMP,95.5
239684,"05/11/2010, 19:27:19","TIME_OF_DAY//[18,24)",
239684,"05/11/2010, 19:27:19",AGE,29.369034127420306
239684,"05/11/2010, 19:27:19",DISCHARGE,
1195293,,EYE_COLOR//BLUE,
1195293,,HEIGHT,164.6868838269085
1195293,"06/20/1978, 00:00:00","TIME_OF_DAY//[00,06)",
1195293,"06/20/1978, 00:00:00",DOB,
1195293,"06/20/2010, 19:23:52","TIME_OF_DAY//[18,24)",
1195293,"06/20/2010, 19:23:52",AGE,32.002896271955265
1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC,
1195293,"06/20/2010, 19:23:52",HR,109.0
1195293,"06/20/2010, 19:23:52",TEMP,100.0
1195293,"06/20/2010, 19:25:32","TIME_OF_DAY//[18,24)",
1195293,"06/20/2010, 19:25:32",AGE,32.00289944083172
1195293,"06/20/2010, 19:25:32",HR,114.1
1195293,"06/20/2010, 19:25:32",TEMP,100.0
1195293,"06/20/2010, 19:45:19","TIME_OF_DAY//[18,24)",
1195293,"06/20/2010, 19:45:19",AGE,32.00293705539522
1195293,"06/20/2010, 19:45:19",HR,119.8
1195293,"06/20/2010, 19:45:19",TEMP,99.9
1195293,"06/20/2010, 20:12:31","TIME_OF_DAY//[18,24)",
1195293,"06/20/2010, 20:12:31",AGE,32.002988771458945
1195293,"06/20/2010, 20:12:31",HR,112.5
1195293,"06/20/2010, 20:12:31",TEMP,99.8
1195293,"06/20/2010, 20:24:44","TIME_OF_DAY//[18,24)",
1195293,"06/20/2010, 20:24:44",AGE,32.00301199932335
1195293,"06/20/2010, 20:24:44",HR,107.7
1195293,"06/20/2010, 20:24:44",TEMP,100.0
1195293,"06/20/2010, 20:41:33","TIME_OF_DAY//[18,24)",
1195293,"06/20/2010, 20:41:33",AGE,32.003043973286765
1195293,"06/20/2010, 20:41:33",HR,107.5
1195293,"06/20/2010, 20:41:33",TEMP,100.4
1195293,"06/20/2010, 20:50:04","TIME_OF_DAY//[18,24)",
1195293,"06/20/2010, 20:50:04",AGE,32.00306016624544
1195293,"06/20/2010, 20:50:04",DISCHARGE,
"add_time_derived_measurements/train/1": |-2
patient_id,time,code,numeric_value
"add_time_derived_measurements/tuning/0": |-2
patient_id,time,code,numeric_value
"add_time_derived_measurements/held_out/0": |-2
patient_id,time,code,numeric_value
1500733,,EYE_COLOR//BROWN,
1500733,,HEIGHT,158.60131573580904
1500733,"07/20/1986, 00:00:00","TIME_OF_DAY//[00,06)",
1500733,"07/20/1986, 00:00:00",DOB,
1500733,"06/03/2010, 14:54:38","TIME_OF_DAY//[12,18)",
1500733,"06/03/2010, 14:54:38",AGE,23.873531791091356
1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC,
1500733,"06/03/2010, 14:54:38",HR,91.4
1500733,"06/03/2010, 14:54:38",TEMP,100.0
1500733,"06/03/2010, 15:39:49","TIME_OF_DAY//[12,18)",
1500733,"06/03/2010, 15:39:49",AGE,23.873617699332012
1500733,"06/03/2010, 15:39:49",HR,84.4
1500733,"06/03/2010, 15:39:49",TEMP,100.3
1500733,"06/03/2010, 16:20:49","TIME_OF_DAY//[12,18)",
1500733,"06/03/2010, 16:20:49",AGE,23.873695653692767
1500733,"06/03/2010, 16:20:49",HR,90.1
1500733,"06/03/2010, 16:20:49",TEMP,100.1
1500733,"06/03/2010, 16:44:26","TIME_OF_DAY//[12,18)",
1500733,"06/03/2010, 16:44:26",AGE,23.873740556672114
1500733,"06/03/2010, 16:44:26",DISCHARGE,
"""
)
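The sketch below shows roughly how the two time-derived codes in these expectations could be produced for a single event time, given the config above (endpoints [6, 12, 18, 24], codes TIME_OF_DAY and AGE, age_unit years). It is an assumption for illustration, not the add_time_derived_measurements implementation; in particular, the exact year-length constant used by the pipeline may differ.

from datetime import datetime

def time_of_day_code(t: datetime, endpoints: tuple[int, ...] = (6, 12, 18, 24)) -> str:
    # Map the event's hour into a half-open bucket label such as "TIME_OF_DAY//[12,18)".
    lo = 0
    for hi in endpoints:
        if t.hour < hi:
            return f"TIME_OF_DAY//[{lo:02},{hi:02})"
        lo = hi
    return f"TIME_OF_DAY//[{lo:02},24)"

def age_in_years(t: datetime, dob: datetime) -> float:
    # Approximate age as elapsed seconds over a mean year length (assumed constant).
    seconds_per_year = 365.25 * 24 * 60 * 60
    return (t - dob).total_seconds() / seconds_per_year

# e.g. time_of_day_code(datetime(2010, 5, 11, 17, 41, 51)) -> "TIME_OF_DAY//[12,18)"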

FIT_OUTLIERS_NEW_METADATA = """
>>> import polars as pl
>>> VALS = pl.col("numeric_value").drop_nulls().drop_nans()
>>> post_outliers = (
... pl.concat(WANT_POST_TIME_DERIVED.values(), how='vertical')
... .group_by("code")
... .agg(
... VALS.len().alias("values/n_occurrences"),
... VALS.sum().alias("values/sum"),
... (VALS**2).sum().alias("values/sum_sqd")
... )
... .filter(pl.col("values/n_occurrences") > 0)
... )
>>> post_outliers
shape: (4, 4)
┌────────┬──────────────────────┬─────────────┬────────────────┐
│ code ┆ values/n_occurrences ┆ values/sum ┆ values/sum_sqd │
│ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ u32 ┆ f32 ┆ f32 │
╞════════╪══════════════════════╪═════════════╪════════════════╡
│ HR ┆ 13 ┆ 1370.200073 ┆ 145770.0625 │
│ TEMP ┆ 13 ┆ 1284.000122 ┆ 126868.632812 │
│ AGE ┆ 16 ┆ 466.360046 ┆ 13761.804688 │
│ HEIGHT ┆ 3 ┆ 498.559326 ┆ 82996.109375 │
└────────┴──────────────────────┴─────────────┴────────────────┘
# This implies the following means and standard deviations
>>> mean_col = pl.col("values/sum") / pl.col("values/n_occurrences")
>>> stddev_col = (pl.col("values/sum_sqd") / pl.col("values/n_occurrences") - mean_col**2) ** 0.5
>>> post_outliers.select("code", mean_col.alias("values/mean"), stddev_col.alias("values/std"))
shape: (4, 3)
┌────────┬─────────────┬────────────┐
│ code ┆ values/mean ┆ values/std │
│ --- ┆ --- ┆ --- │
│ str ┆ f64 ┆ f64 │
╞════════╪═════════════╪════════════╡
│ AGE ┆ 29.147503 ┆ 3.2459 │
│ HR ┆ 105.400006 ┆ 10.194143 │
│ TEMP ┆ 98.76924 ┆ 1.939794 │
│ HEIGHT ┆ 166.186442 ┆ 6.887399 │
└────────┴─────────────┴────────────┘
>>> post_outliers.select(
... "code",
... (mean_col + stddev_col).alias("values/inlier_upper_bound"),
... (mean_col - stddev_col).alias("values/inlier_lower_bound")
... )
shape: (4, 3)
┌────────┬───────────────────────────┬───────────────────────────┐
│ code ┆ values/inlier_upper_bound ┆ values/inlier_lower_bound │
│ --- ┆ --- ┆ --- │
│ str ┆ f64 ┆ f64 │
╞════════╪═══════════════════════════╪═══════════════════════════╡
│ AGE ┆ 32.393403 ┆ 25.901603 │
│ HR ┆ 115.594148 ┆ 95.205863 │
│ TEMP ┆ 100.709034 ┆ 96.829447 │
│ HEIGHT ┆ 173.073841 ┆ 159.299043 │
└────────┴───────────────────────────┴───────────────────────────┘
"""

WANT_POST_OCCLUDE_OUTLIERS = parse_shards_yaml(
"""
"occlude_outliers/train/0": |-2
patient_id,time,code,numeric_value,numeric_value/is_inlier
239684,,EYE_COLOR//BROWN,,
239684,,HEIGHT,,false
239684,"12/28/1980, 00:00:00","TIME_OF_DAY//[00,06)",,
239684,"12/28/1980, 00:00:00",DOB,,
239684,"05/11/2010, 17:41:51","TIME_OF_DAY//[12,18)",,
239684,"05/11/2010, 17:41:51",AGE,29.36883360091833,true
239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC,,
239684,"05/11/2010, 17:41:51",HR,102.6,true
239684,"05/11/2010, 17:41:51",TEMP,,false
239684,"05/11/2010, 17:48:48","TIME_OF_DAY//[12,18)",,
239684,"05/11/2010, 17:48:48",AGE,29.36884681513314,true
239684,"05/11/2010, 17:48:48",HR,105.1,true
239684,"05/11/2010, 17:48:48",TEMP,,false
239684,"05/11/2010, 18:25:35","TIME_OF_DAY//[18,24)",,
239684,"05/11/2010, 18:25:35",AGE,29.36891675223647,true
239684,"05/11/2010, 18:25:35",HR,113.4,true
239684,"05/11/2010, 18:25:35",TEMP,,false
239684,"05/11/2010, 18:57:18","TIME_OF_DAY//[18,24)",,
239684,"05/11/2010, 18:57:18",AGE,29.36897705595538,true
239684,"05/11/2010, 18:57:18",HR,112.6,true
239684,"05/11/2010, 18:57:18",TEMP,,false
239684,"05/11/2010, 19:27:19","TIME_OF_DAY//[18,24)",,
239684,"05/11/2010, 19:27:19",AGE,29.369034127420306,true
239684,"05/11/2010, 19:27:19",DISCHARGE,,
1195293,,EYE_COLOR//BLUE,,
1195293,,HEIGHT,164.6868838269085,true
1195293,"06/20/1978, 00:00:00","TIME_OF_DAY//[00,06)",,
1195293,"06/20/1978, 00:00:00",DOB,,
1195293,"06/20/2010, 19:23:52","TIME_OF_DAY//[18,24)",,
1195293,"06/20/2010, 19:23:52",AGE,32.002896271955265,true
1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC,,
1195293,"06/20/2010, 19:23:52",HR,109.0,true
1195293,"06/20/2010, 19:23:52",TEMP,100.0,true
1195293,"06/20/2010, 19:25:32","TIME_OF_DAY//[18,24)",,
1195293,"06/20/2010, 19:25:32",AGE,32.00289944083172,true
1195293,"06/20/2010, 19:25:32",HR,114.1,true
1195293,"06/20/2010, 19:25:32",TEMP,100.0,true
1195293,"06/20/2010, 19:45:19","TIME_OF_DAY//[18,24)",,
1195293,"06/20/2010, 19:45:19",AGE,32.00293705539522,true
1195293,"06/20/2010, 19:45:19",HR,,false
1195293,"06/20/2010, 19:45:19",TEMP,99.9,true
1195293,"06/20/2010, 20:12:31","TIME_OF_DAY//[18,24)",,
1195293,"06/20/2010, 20:12:31",AGE,32.002988771458945,true
1195293,"06/20/2010, 20:12:31",HR,112.5,true
1195293,"06/20/2010, 20:12:31",TEMP,99.8,true
1195293,"06/20/2010, 20:24:44","TIME_OF_DAY//[18,24)",
1195293,"06/20/2010, 20:24:44",AGE,32.00301199932335,true
1195293,"06/20/2010, 20:24:44",HR,107.7,true
1195293,"06/20/2010, 20:24:44",TEMP,100.0,true
1195293,"06/20/2010, 20:41:33","TIME_OF_DAY//[18,24)",,
1195293,"06/20/2010, 20:41:33",AGE,32.003043973286765,true
1195293,"06/20/2010, 20:41:33",HR,107.5,true
1195293,"06/20/2010, 20:41:33",TEMP,100.4,true
1195293,"06/20/2010, 20:50:04","TIME_OF_DAY//[18,24)",,
1195293,"06/20/2010, 20:50:04",AGE,32.00306016624544,true
1195293,"06/20/2010, 20:50:04",DISCHARGE,,
"occlude_outliers/train/1": |-2
patient_id,time,code,numeric_value
"occlude_outliers/tuning/0": |-2
patient_id,time,code,numeric_value
"occlude_outliers/held_out/0": |-2
patient_id,time,code,numeric_value,numeric_value/is_inlier
1500733,,EYE_COLOR//BROWN,,
1500733,,HEIGHT,,false
1500733,"07/20/1986, 00:00:00","TIME_OF_DAY//[00,06)",,
1500733,"07/20/1986, 00:00:00",DOB,,
1500733,"06/03/2010, 14:54:38","TIME_OF_DAY//[12,18)",,
1500733,"06/03/2010, 14:54:38",AGE,,false
1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC,,
1500733,"06/03/2010, 14:54:38",HR,,false
1500733,"06/03/2010, 14:54:38",TEMP,100.0,true
1500733,"06/03/2010, 15:39:49","TIME_OF_DAY//[12,18)",,
1500733,"06/03/2010, 15:39:49",AGE,,false
1500733,"06/03/2010, 15:39:49",HR,,false
1500733,"06/03/2010, 15:39:49",TEMP,100.3,true
1500733,"06/03/2010, 16:20:49","TIME_OF_DAY//[12,18)",,
1500733,"06/03/2010, 16:20:49",AGE,,false
1500733,"06/03/2010, 16:20:49",HR,,false
1500733,"06/03/2010, 16:20:49",TEMP,100.1,true
1500733,"06/03/2010, 16:44:26","TIME_OF_DAY//[12,18)",,
1500733,"06/03/2010, 16:44:26",AGE,,false
1500733,"06/03/2010, 16:44:26",DISCHARGE,,
"""
)


WANT_NRTs = {
"train/1.nrt": JointNestedRaggedTensorDict({}), # this shard was fully filtered out.
"tuning/0.nrt": JointNestedRaggedTensorDict({}), # this shard was fully filtered out.
"data/train/1.nrt": JointNestedRaggedTensorDict({}), # this shard was fully filtered out.
"data/tuning/0.nrt": JointNestedRaggedTensorDict({}), # this shard was fully filtered out.
}


@@ -150,6 +400,12 @@ def test_pipeline():
"tensorization",
],
stage_configs=STAGE_CONFIG_YAML,
want_data=WANT_NRTs,
want_data={
**WANT_POST_FILTER,
**WANT_POST_TIME_DERIVED,
**WANT_POST_OCCLUDE_OUTLIERS,
**WANT_NRTs,
},
outputs_from_cohort_dir=True,
input_code_metadata=MEDS_CODE_METADATA_FILE,
)
29 changes: 22 additions & 7 deletions tests/transform_tester_base.py
@@ -10,6 +10,7 @@
from yaml import CLoader as Loader
except ImportError:
from yaml import Loader

import json
import os
import tempfile
@@ -23,7 +24,7 @@
import rootutils
from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict

from .utils import assert_df_equal, parse_meds_csvs, run_command, MEDS_PL_SCHEMA
from .utils import MEDS_PL_SCHEMA, assert_df_equal, parse_meds_csvs, run_command

root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True)

@@ -201,6 +202,7 @@
"code/vocab_index": pl.UInt8,
}


def parse_shards_yaml(yaml_str: str, **schema_updates) -> dict[str, pl.DataFrame]:
schema = {**MEDS_PL_SCHEMA, **schema_updates}
return parse_meds_csvs(load_yaml(yaml_str, Loader=Loader), schema=schema)
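For context, a hypothetical usage of this helper (shard name and rows invented for illustration): the YAML maps shard names to literal CSV blocks, and parse_meds_csvs parses each block into a DataFrame using MEDS_PL_SCHEMA plus any per-column overrides passed as keyword arguments.

_example_shards = parse_shards_yaml(
    """
  "train/0": |-2
    patient_id,time,code,numeric_value,numeric_value/is_inlier
    239684,"05/11/2010, 17:41:51",HR,102.6,true
    """,
    **{"numeric_value/is_inlier": pl.Boolean},  # hypothetical override for the extra column
)
# _example_shards maps shard names to DataFrames; _example_shards["train/0"] is a pl.DataFrame.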
@@ -350,21 +352,28 @@ def input_MEDS_dataset(
def check_outputs(
cohort_dir: Path,
want_data: dict[str, pl.DataFrame] | None = None,
want_metadata: pl.DataFrame | None = None,
want_metadata: dict[str, pl.DataFrame] | pl.DataFrame | None = None,
assert_no_other_outputs: bool = True,
outputs_from_cohort_dir: bool = False,
):
if want_metadata is not None:
cohort_metadata_dir = cohort_dir / "metadata"
check_df_output(cohort_metadata_dir / "codes.parquet", want_metadata)
if isinstance(want_metadata, pl.DataFrame):
want_metadata = {"codes.parquet": want_metadata}
metadata_root = cohort_dir if outputs_from_cohort_dir else cohort_dir / "metadata"
for shard_name, want in want_metadata.items():
if Path(shard_name).suffix == "":
shard_name = f"{shard_name}.parquet"
check_df_output(metadata_root / shard_name, want)

if want_data:
data_root = cohort_dir if outputs_from_cohort_dir else cohort_dir / "data"
for shard_name, want in want_data.items():
if Path(shard_name).suffix == "":
shard_name = f"{shard_name}.parquet"

file_suffix = Path(shard_name).suffix

output_fp = cohort_dir / "data" / f"{shard_name}"
output_fp = data_root / f"{shard_name}"
if file_suffix == ".parquet":
check_df_output(output_fp, want)
elif file_suffix == ".nrt":
@@ -373,7 +382,7 @@ def check_outputs(
raise ValueError(f"Unknown file suffix: {file_suffix}")

if assert_no_other_outputs:
all_outputs = list((cohort_dir / "data").glob(f"**/*{file_suffix}"))
all_outputs = list((data_root).glob(f"**/*{file_suffix}"))
assert len(want_data) == len(all_outputs), (
f"Expected {len(want_data)} outputs, but found {len(all_outputs)}.\n"
f"Found outputs: {[fp.relative_to(cohort_dir/'data') for fp in all_outputs]}\n"
@@ -434,6 +443,7 @@ def multi_stage_transform_tester(
do_pass_stage_name: bool | dict[str, bool] = True,
want_data: dict[str, pl.DataFrame] | None = None,
want_metadata: pl.DataFrame | None = None,
outputs_from_cohort_dir: bool = True,
**input_data_kwargs,
):
with input_MEDS_dataset(**input_data_kwargs) as (MEDS_dir, cohort_dir):
@@ -478,4 +488,9 @@ def multi_stage_transform_tester(
do_pass_stage_name=do_pass_stage_name[stage],
)

check_outputs(cohort_dir, want_data=want_data, want_metadata=want_metadata)
check_outputs(
cohort_dir,
want_data=want_data,
want_metadata=want_metadata,
outputs_from_cohort_dir=outputs_from_cohort_dir,
)
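To illustrate the updated behaviour with a hypothetical call (the DataFrames, paths, and shard keys below are invented): want_data may mix parquet and .nrt expectations, suffix-less keys default to .parquet, and with outputs_from_cohort_dir=True keys are resolved relative to the cohort directory itself rather than to its data/ and metadata/ subdirectories.

import polars as pl
from pathlib import Path

# Placeholder expectations, purely for illustration:
some_filtered_df = pl.DataFrame({"patient_id": [239684], "code": ["HR"], "numeric_value": [102.6]})
some_codes_df = pl.DataFrame({"code": ["HR"], "values/n_occurrences": [13]})

check_outputs(
    Path("/tmp/example_cohort"),  # hypothetical cohort_dir
    want_data={
        "filter_patients/train/0": some_filtered_df,          # checked as filter_patients/train/0.parquet
        "data/train/1.nrt": JointNestedRaggedTensorDict({}),  # checked as data/train/1.nrt
    },
    want_metadata={"fit_outlier_detection/codes": some_codes_df},  # checked as fit_outlier_detection/codes.parquet
    outputs_from_cohort_dir=True,
)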