Skip to content

Commit

Permalink
Validated fit normalization stage as well.
Browse files Browse the repository at this point in the history
  • Loading branch information
mmcdermott committed Aug 15, 2024
1 parent 8736279 commit 6114c42
Showing 1 changed file with 128 additions and 0 deletions.
128 changes: 128 additions & 0 deletions tests/test_multi_stage_preprocess_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,133 @@
"""
)

FIT_NORMALIZATION_CODE = """
```python
>>> from tests.test_multi_stage_preprocess_pipeline import WANT_OCCLUDE_OUTLIERS as dfs
>>> import polars as pl
>>> VALS = pl.col("numeric_value").drop_nulls().drop_nans()
>>> post_transform = (
... dfs[next(k for k in dfs.keys() if k.endswith("/train/0"))]
... .group_by("code")
... .agg(
... pl.len().alias("code/n_occurrences"),
... pl.col("patient_id").n_unique().alias("code/n_patients"),
... VALS.len().alias("values/n_occurrences"),
... VALS.sum().alias("values/sum"),
... (VALS**2).sum().alias("values/sum_sqd")
... )
... )
>>> post_transform.filter(pl.col("values/n_occurrences") > 0)
shape: (3, 6)
┌──────┬────────────────────┬─────────────────┬──────────────────────┬────────────┬────────────────┐
│ code ┆ code/n_occurrences ┆ code/n_patients ┆ values/n_occurrences ┆ values/sum ┆ values/sum_sqd │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ u32 ┆ u32 ┆ u32 ┆ f32 ┆ f32 │
╞══════╪════════════════════╪═════════════════╪══════════════════════╪════════════╪════════════════╡
│ HR ┆ 10 ┆ 2 ┆ 7 ┆ 776.799988 ┆ 86249.921875 │
│ TEMP ┆ 10 ┆ 2 ┆ 6 ┆ 600.100037 ┆ 60020.214844 │
│ AGE ┆ 12 ┆ 2 ┆ 7 ┆ 224.020844 ┆ 7169.333496 │
└──────┴────────────────────┴─────────────────┴──────────────────────┴────────────┴────────────────┘
>>> print(post_transform.filter(pl.col("values/n_occurrences") > 0).to_dict(as_series=False))
{'code': ['HR', 'TEMP', 'AGE'],
'code/n_occurrences': [10, 10, 12],
'code/n_patients': [2, 2, 2],
'values/n_occurrences': [7, 6, 7],
'values/sum': [776.7999877929688, 600.1000366210938, 224.02084350585938],
'values/sum_sqd': [86249.921875, 60020.21484375, 7169.33349609375]}
"""

WANT_FIT_NORMALIZATION = {
"fit_normalization/codes.parquet": pl.DataFrame(
{
"code": [
"EYE_COLOR//BLUE",
"EYE_COLOR//BROWN",
"HR",
"TEMP",
"AGE",
"HEIGHT",
"TIME_OF_DAY//[18,24)",
"TIME_OF_DAY//[12,18)",
"TIME_OF_DAY//[00,06)",
"ADMISSION//CARDIAC",
"DISCHARGE",
"DOB",
],
"code/n_occurrences": [1, 1, 10, 10, 12, 2, 10, 2, 2, 2, 2, 2],
"code/n_patients": [1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2],
"values/n_occurrences": [0, 0, 7, 6, 7, 0, 0, 0, 0, 0, 0, 0],
"values/sum": [
0.0,
0.0,
776.7999877929688,
600.1000366210938,
224.0208376784967,
0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
],
"values/sum_sqd": [
0.0,
0.0,
86249.921875,
60020.21484375,
7169.33349609375,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
],
"description": [
"Blue Eyes. Less common than brown.",
"Brown Eyes. The most common eye color.",
"Heart Rate",
"Body Temperature",
None,
None,
None,
None,
None,
None,
None,
None,
],
"parent_codes": [
None,
None,
["LOINC/8867-4"],
["LOINC/8310-5"],
None,
None,
None,
None,
None,
None,
None,
None,
],
},
schema={
"code": pl.String,
"description": pl.String,
"parent_codes": pl.List(pl.String),
"code/n_occurrences": pl.UInt8,
"code/n_patients": pl.UInt8,
"values/n_occurrences": pl.UInt8, # In the real stage, this is shrunk, so it differs from the ex.
"values/sum": pl.Float32,
"values/sum_sqd": pl.Float32,
},
).sort(by="code")
}


WANT_NRTs = {
"data/train/1.nrt": JointNestedRaggedTensorDict({}), # this shard was fully filtered out.
Expand Down Expand Up @@ -509,6 +636,7 @@ def test_pipeline():
stage_configs=STAGE_CONFIG_YAML,
want_metadata={
**WANT_FIT_OUTLIERS,
**WANT_FIT_NORMALIZATION,
},
want_data={
**WANT_FILTER,
Expand Down

0 comments on commit 6114c42

Please sign in to comment.