Skip to content

Commit

Permalink
Starting to track some of the intermediate outputs to identify the ta…
Browse files Browse the repository at this point in the history
…rget outputs
  • Loading branch information
mmcdermott committed Aug 14, 2024
1 parent 9304bf7 commit c7d56a9
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 1 deletion.
64 changes: 64 additions & 0 deletions tests/test_multi_stage_preprocess_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,18 @@
TENSORIZATION_SCRIPT,
TOKENIZATION_SCRIPT,
multi_stage_transform_tester,
parse_shards_yaml,
)

# Input code-metadata CSV (columns: code, description, parent_codes) handed to the pipeline test via
# the `input_code_metadata` argument. The eye-color rows leave `parent_codes` empty.
MEDS_CODE_METADATA_FILE = """
code,description,parent_codes
EYE_COLOR//BLUE,"Blue Eyes. Less common than brown.",
EYE_COLOR//BROWN,"Brown Eyes. The most common eye color.",
EYE_COLOR//HAZEL,"Hazel eyes. These are uncommon",
HR,"Heart Rate",LOINC/8867-4
TEMP,"Body Temperature",LOINC/8310-5
"""

STAGE_CONFIG_YAML = """
filter_patients:
min_events_per_patient: 5
Expand All @@ -56,6 +66,59 @@
- "values/sum_sqd"
"""

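# Intermediate per-shard output of the `filter_patients` stage, i.e. after dropping patients with
# fewer than 5 events (`min_events_per_patient: 5` in STAGE_CONFIG_YAML):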
POST_FILTER_YAML = parse_shards_yaml("""
"filter_patients/train/0": |-2
patient_id,time,code,numeric_value
239684,,EYE_COLOR//BROWN,
239684,,HEIGHT,175.271115221764
239684,"12/28/1980, 00:00:00",DOB,
239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC,
239684,"05/11/2010, 17:41:51",HR,102.6
239684,"05/11/2010, 17:41:51",TEMP,96.0
239684,"05/11/2010, 17:48:48",HR,105.1
239684,"05/11/2010, 17:48:48",TEMP,96.2
239684,"05/11/2010, 18:25:35",HR,113.4
239684,"05/11/2010, 18:25:35",TEMP,95.8
239684,"05/11/2010, 18:57:18",HR,112.6
239684,"05/11/2010, 18:57:18",TEMP,95.5
239684,"05/11/2010, 19:27:19",DISCHARGE,
1195293,,EYE_COLOR//BLUE,
1195293,,HEIGHT,164.6868838269085
1195293,"06/20/1978, 00:00:00",DOB,
1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC,
1195293,"06/20/2010, 19:23:52",HR,109.0
1195293,"06/20/2010, 19:23:52",TEMP,100.0
1195293,"06/20/2010, 19:25:32",HR,114.1
1195293,"06/20/2010, 19:25:32",TEMP,100.0
1195293,"06/20/2010, 19:45:19",HR,119.8
1195293,"06/20/2010, 19:45:19",TEMP,99.9
1195293,"06/20/2010, 20:12:31",HR,112.5
1195293,"06/20/2010, 20:12:31",TEMP,99.8
1195293,"06/20/2010, 20:24:44",HR,107.7
1195293,"06/20/2010, 20:24:44",TEMP,100.0
1195293,"06/20/2010, 20:41:33",HR,107.5
1195293,"06/20/2010, 20:41:33",TEMP,100.4
1195293,"06/20/2010, 20:50:04",DISCHARGE,
"filter_patients/train/1": |-2
patient_id,time,code,numeric_value
"filter_patients/tuning/0": |-2
patient_id,time,code,numeric_value
"filter_patients/held_out/0": |-2
patient_id,time,code,numeric_value
1500733,,EYE_COLOR//BROWN,
1500733,,HEIGHT,158.60131573580904
1500733,"07/20/1986, 00:00:00",DOB,
1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC,
1500733,"06/03/2010, 14:54:38",HR,91.4
1500733,"06/03/2010, 14:54:38",TEMP,100.0
1500733,"06/03/2010, 15:39:49",HR,84.4
1500733,"06/03/2010, 15:39:49",TEMP,100.3
1500733,"06/03/2010, 16:20:49",HR,90.1
1500733,"06/03/2010, 16:20:49",TEMP,100.1
1500733,"06/03/2010, 16:44:26",DISCHARGE,
""")

WANT_NRTs = {
"train/1.nrt": JointNestedRaggedTensorDict({}), # this shard was fully filtered out.
"tuning/0.nrt": JointNestedRaggedTensorDict({}), # this shard was fully filtered out.
Expand Down Expand Up @@ -88,4 +151,5 @@ def test_pipeline():
],
stage_configs=STAGE_CONFIG_YAML,
want_data=WANT_NRTs,
input_code_metadata=MEDS_CODE_METADATA_FILE,
)
6 changes: 5 additions & 1 deletion tests/transform_tester_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import rootutils
from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict

from .utils import assert_df_equal, parse_meds_csvs, run_command
from .utils import assert_df_equal, parse_meds_csvs, run_command, MEDS_PL_SCHEMA

root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True)

Expand Down Expand Up @@ -201,6 +201,10 @@
"code/vocab_index": pl.UInt8,
}

def parse_shards_yaml(yaml_str: str, **schema_updates) -> pl.DataFrame:
    """Load a YAML string of MEDS CSV shard text and parse it with ``parse_meds_csvs``.

    Args:
        yaml_str: YAML text; it is loaded with ``load_yaml`` and the loaded object is passed
            straight to ``parse_meds_csvs``.
        **schema_updates: Extra or overriding schema entries. They are layered on top of
            ``MEDS_PL_SCHEMA``, so a key given here wins over the default for that key.

    Returns:
        Whatever ``parse_meds_csvs`` returns for the loaded YAML content.
        NOTE(review): the annotation says ``pl.DataFrame``; if the YAML is a mapping of shard
        name to CSV text, this may actually be a mapping of frames -- confirm against
        ``parse_meds_csvs``.
    """
    # Start from the default MEDS schema, then apply caller overrides on top.
    merged_schema = dict(MEDS_PL_SCHEMA)
    merged_schema.update(schema_updates)

    shard_csvs = load_yaml(yaml_str, Loader=Loader)
    return parse_meds_csvs(shard_csvs, schema=merged_schema)


def parse_code_metadata_csv(csv_str: str) -> pl.DataFrame:
cols = csv_str.strip().split("\n")[0].split(",")
Expand Down

0 comments on commit c7d56a9

Please sign in to comment.