added integration tests for tokenization alignment fix
Oufattole committed Sep 1, 2024
1 parent 82589d8 commit 8724aa0
Showing 1 changed file with 83 additions and 1 deletion.
tests/MEDS_Transforms/test_tokenization.py (83 additions, 1 deletion)
@@ -12,8 +12,10 @@

from tests.MEDS_Transforms import TOKENIZATION_SCRIPT

from .test_normalization import NORMALIZED_MEDS_SCHEMA
from ..utils import parse_meds_csvs
from .test_normalization import NORMALIZED_MEDS_SCHEMA, WANT_HELD_OUT_0
from .test_normalization import WANT_SHARDS as NORMALIZED_SHARDS
from .test_normalization import WANT_TRAIN_1, WANT_TUNING_0
from .transform_tester_base import single_stage_transform_tester

SECONDS_PER_DAY = 60 * 60 * 24
@@ -77,6 +79,17 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]:
    schema=SCHEMAS_SCHEMA,
)

WANT_SCHEMAS_TRAIN_0_MISSING_STATIC = pl.DataFrame(
    {
        "subject_id": [239684, 1195293],
        "code": [None, [6, 9]],
        "numeric_value": [None, [None, 0.06802856922149658]],
        "start_time": [ts[0] for ts in TRAIN_0_TIMES],
        "time": TRAIN_0_TIMES,
    },
    schema=SCHEMAS_SCHEMA,
)

WANT_EVENT_SEQ_TRAIN_0 = pl.DataFrame(
    {
        "subject_id": [239684, 1195293],
@@ -211,13 +224,62 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]:
    "schemas/held_out/0": WANT_SCHEMAS_HELD_OUT_0,
}

WANT_SCHEMAS_MISSING_STATIC = {
    "schemas/train/0": WANT_SCHEMAS_TRAIN_0_MISSING_STATIC,
    "schemas/train/1": WANT_SCHEMAS_TRAIN_1,
    "schemas/tuning/0": WANT_SCHEMAS_TUNING_0,
    "schemas/held_out/0": WANT_SCHEMAS_HELD_OUT_0,
}

WANT_EVENT_SEQS = {
    "event_seqs/train/0": WANT_EVENT_SEQ_TRAIN_0,
    "event_seqs/train/1": WANT_EVENT_SEQ_TRAIN_1,
    "event_seqs/tuning/0": WANT_EVENT_SEQ_TUNING_0,
    "event_seqs/held_out/0": WANT_EVENT_SEQ_HELD_OUT_0,
}

WANT_TRAIN_0 = """
subject_id,time,code,numeric_value
239684,"12/28/1980, 00:00:00",5,
239684,"05/11/2010, 17:41:51",1,
239684,"05/11/2010, 17:41:51",10,-0.569736897945404
239684,"05/11/2010, 17:41:51",11,-1.2714673280715942
239684,"05/11/2010, 17:48:48",10,-0.43754738569259644
239684,"05/11/2010, 17:48:48",11,-1.168027639389038
239684,"05/11/2010, 18:25:35",10,0.001321975840255618
239684,"05/11/2010, 18:25:35",11,-1.37490713596344
239684,"05/11/2010, 18:57:18",10,-0.04097883030772209
239684,"05/11/2010, 18:57:18",11,-1.5300706624984741
239684,"05/11/2010, 19:27:19",4,
1195293,,6,
1195293,,9,0.06802856922149658
1195293,"06/20/1978, 00:00:00",5,
1195293,"06/20/2010, 19:23:52",1,
1195293,"06/20/2010, 19:23:52",10,-0.23133166134357452
1195293,"06/20/2010, 19:23:52",11,0.7973587512969971
1195293,"06/20/2010, 19:25:32",10,0.03833488002419472
1195293,"06/20/2010, 19:25:32",11,0.7973587512969971
1195293,"06/20/2010, 19:45:19",10,0.3397272229194641
1195293,"06/20/2010, 19:45:19",11,0.745638906955719
1195293,"06/20/2010, 20:12:31",10,-0.046266332268714905
1195293,"06/20/2010, 20:12:31",11,0.6939190626144409
1195293,"06/20/2010, 20:24:44",10,-0.3000703752040863
1195293,"06/20/2010, 20:24:44",11,0.7973587512969971
1195293,"06/20/2010, 20:41:33",10,-0.31064537167549133
1195293,"06/20/2010, 20:41:33",11,1.004242181777954
1195293,"06/20/2010, 20:50:04",4,
"""

NORMALIZED_SHARDS_MISSING_STATIC = parse_meds_csvs(
    {
        "train/0": WANT_TRAIN_0,
        "train/1": WANT_TRAIN_1,
        "tuning/0": WANT_TUNING_0,
        "held_out/0": WANT_HELD_OUT_0,
    },
    schema=NORMALIZED_MEDS_SCHEMA,
)


def test_tokenization():
    single_stage_transform_tester(
@@ -237,3 +299,23 @@ def test_tokenization():
        want_data={**WANT_SCHEMAS, **WANT_EVENT_SEQS},
        should_error=True,
    )


def test_tokenization_missing_static():
    single_stage_transform_tester(
        transform_script=TOKENIZATION_SCRIPT,
        stage_name="tokenization",
        transform_stage_kwargs=None,
        input_shards=NORMALIZED_SHARDS_MISSING_STATIC,
        want_data={**WANT_SCHEMAS_MISSING_STATIC, **WANT_EVENT_SEQS},
        df_check_kwargs={"check_column_order": False},
    )

    single_stage_transform_tester(
        transform_script=TOKENIZATION_SCRIPT,
        stage_name="tokenization",
        transform_stage_kwargs={"train_only": True},
        input_shards=NORMALIZED_SHARDS_MISSING_STATIC,
        want_data={**WANT_SCHEMAS_MISSING_STATIC, **WANT_EVENT_SEQS},
        should_error=True,
    )
