Skip to content

Commit

Permalink
Added test for final metadata stage
Browse files Browse the repository at this point in the history
  • Loading branch information
mmcdermott committed Aug 15, 2024
1 parent 6114c42 commit 3c5c7d9
Showing 1 changed file with 94 additions and 0 deletions.
94 changes: 94 additions & 0 deletions tests/test_multi_stage_preprocess_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,99 @@
).sort(by="code")
}

# As the last metadata stage, this gets a special directory.
WANT_FIT_VOCABULARY_INDICES = {
"metadata/codes.parquet": pl.DataFrame(
{
"code": [
"EYE_COLOR//BLUE",
"EYE_COLOR//BROWN",
"HR",
"TEMP",
"AGE",
"HEIGHT",
"TIME_OF_DAY//[18,24)",
"TIME_OF_DAY//[12,18)",
"TIME_OF_DAY//[00,06)",
"ADMISSION//CARDIAC",
"DISCHARGE",
"DOB",
],
"code/vocab_index": [5, 6, 8, 9, 2, 7, 12, 11, 10, 1, 3, 4],
"code/n_occurrences": [1, 1, 10, 10, 12, 2, 10, 2, 2, 2, 2, 2],
"code/n_patients": [1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2],
"values/n_occurrences": [0, 0, 7, 6, 7, 0, 0, 0, 0, 0, 0, 0],
"values/sum": [
0.0,
0.0,
776.7999877929688,
600.1000366210938,
224.0208376784967,
0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
],
"values/sum_sqd": [
0.0,
0.0,
86249.921875,
60020.21484375,
7169.33349609375,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
],
"description": [
"Blue Eyes. Less common than brown.",
"Brown Eyes. The most common eye color.",
"Heart Rate",
"Body Temperature",
None,
None,
None,
None,
None,
None,
None,
None,
],
"parent_codes": [
None,
None,
["LOINC/8867-4"],
["LOINC/8310-5"],
None,
None,
None,
None,
None,
None,
None,
None,
],
},
schema={
"code": pl.String,
"description": pl.String,
"parent_codes": pl.List(pl.String),
"code/n_occurrences": pl.UInt8,
"code/n_patients": pl.UInt8,
"code/vocab_index": pl.UInt8,
"values/n_occurrences": pl.UInt8, # In the real stage, this is shrunk, so it differs from the ex.
"values/sum": pl.Float32,
"values/sum_sqd": pl.Float32,
},
).sort(by="code")
}


WANT_NRTs = {
"data/train/1.nrt": JointNestedRaggedTensorDict({}), # this shard was fully filtered out.
Expand Down Expand Up @@ -637,6 +730,7 @@ def test_pipeline():
want_metadata={
**WANT_FIT_OUTLIERS,
**WANT_FIT_NORMALIZATION,
**WANT_FIT_VOCABULARY_INDICES,
},
want_data={
**WANT_FILTER,
Expand Down

0 comments on commit 3c5c7d9

Please sign in to comment.