Skip to content

Commit

Permalink
Changed to 'categorical_value' and added better doctests
Browse files Browse the repository at this point in the history
  • Loading branch information
mmcdermott committed Aug 7, 2024
1 parent 6d6b318 commit 10fedf0
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 15 deletions.
6 changes: 3 additions & 3 deletions src/MEDS_transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
"time": pl.Datetime("us"),
"code": pl.String,
"numeric_value": pl.Float32,
"categoric_value": pl.String,
"categorical_value": pl.String,
"text_value": pl.String,
}

DEPRECATED_NAMES = {
"numerical_value": "numeric_value",
"categorical_value": "categoric_value",
"category_value": "categoric_value",
"category_value": "categorical_value",
"categoric_value": "categorical_value",
"textual_value": "text_value",
"timestamp": "time",
"subject_id": "patient_id",
Expand Down
57 changes: 45 additions & 12 deletions src/MEDS_transforms/transforms/extract_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,56 @@ def extract_values_fntr(stage_cfg: DictConfig) -> Callable[[pl.LazyFrame], pl.La
values.
Examples:
>>> stage_cfg = {"numeric_value": "foo", "categoric_value": "bar"}
>>> stage_cfg = {"numeric_value": "foo", "categorical_value": "bar"}
>>> fn = extract_values_fntr(stage_cfg)
>>> df = pl.DataFrame({
... "patient_id": [1, 1, 1], "time": [1, 2, 3],
... "foo": ["1", "2", "3"], "bar": [1.0, 2.0, 4.0],
... })
>>> fn(df)
shape: (3, 6)
┌────────────┬──────┬─────┬─────┬───────────────┬─────────────────┐
│ patient_id ┆ time ┆ foo ┆ bar ┆ numeric_value ┆ categoric_value │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ str ┆ f64 ┆ f32 ┆ str │
╞════════════╪══════╪═════╪═════╪═══════════════╪═════════════════╡
│ 1 ┆ 1 ┆ 1 ┆ 1.0 ┆ 1.0 ┆ 1.0 │
│ 1 ┆ 2 ┆ 2 ┆ 2.0 ┆ 2.0 ┆ 2.0 │
│ 1 ┆ 3 ┆ 3 ┆ 4.0 ┆ 3.0 ┆ 4.0 │
└────────────┴──────┴─────┴─────┴───────────────┴─────────────────┘
┌────────────┬──────┬─────┬─────┬───────────────┬───────────────────┐
│ patient_id ┆ time ┆ foo ┆ bar ┆ numeric_value ┆ categorical_value │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ str ┆ f64 ┆ f32 ┆ str │
╞════════════╪══════╪═════╪═════╪═══════════════╪═══════════════════╡
│ 1 ┆ 1 ┆ 1 ┆ 1.0 ┆ 1.0 ┆ 1.0 │
│ 1 ┆ 2 ┆ 2 ┆ 2.0 ┆ 2.0 ┆ 2.0 │
│ 1 ┆ 3 ┆ 3 ┆ 4.0 ┆ 3.0 ┆ 4.0 │
└────────────┴──────┴─────┴─────┴───────────────┴───────────────────┘
>>> stage_cfg = {32: "foo"}
>>> fn = extract_values_fntr(stage_cfg)
Traceback (most recent call last):
...
ValueError: Invalid column name: 32
>>> stage_cfg = {"numeric_value": {"lit": 1}}
>>> fn = extract_values_fntr(stage_cfg)
Traceback (most recent call last):
...
ValueError: Error building expression for numeric_value...
>>> stage_cfg = {"numeric_value": "foo", "categorical_value": "bar"}
>>> fn = extract_values_fntr(stage_cfg)
>>> df = pl.DataFrame({"patient_id": [1, 1, 1], "time": [1, 2, 3]})
>>> fn(df)
Traceback (most recent call last):
...
ValueError: Missing columns: ['bar', 'foo']
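Note that deprecated column names such as "numerical_value" or "timestamp" are not mapped to their current names here, so the extracted column keeps its input dtype instead of being cast to the mandatory type: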
>>> stage_cfg = {"numerical_value": "foo"}
>>> fn = extract_values_fntr(stage_cfg)
>>> df = pl.DataFrame({"patient_id": [1, 1, 1], "time": [1, 2, 3], "foo": ["1", "2", "3"]})
>>> fn(df)
shape: (3, 4)
┌────────────┬──────┬─────┬─────────────────┐
│ patient_id ┆ time ┆ foo ┆ numerical_value │
│ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ str ┆ str │
╞════════════╪══════╪═════╪═════════════════╡
│ 1 ┆ 1 ┆ 1 ┆ 1 │
│ 1 ┆ 2 ┆ 2 ┆ 2 │
│ 1 ┆ 3 ┆ 3 ┆ 3 │
└────────────┴──────┴─────┴─────────────────┘
"""

new_cols = []
Expand All @@ -53,7 +86,7 @@ def extract_values_fntr(stage_cfg: DictConfig) -> Callable[[pl.LazyFrame], pl.La
try:
expr, cols = cfg_to_expr(value_cfg)
except ValueError as e:
raise ValueError(f"Error in {out_col_n}") from e
raise ValueError(f"Error building expression for {out_col_n}") from e

match out_col_n:
case str() if out_col_n in MANDATORY_TYPES:
Expand All @@ -78,7 +111,7 @@ def extract_values_fntr(stage_cfg: DictConfig) -> Callable[[pl.LazyFrame], pl.La
def compute_fn(df: pl.LazyFrame) -> pl.LazyFrame:
in_cols = set(df.collect_schema().names())
if not need_cols.issubset(in_cols):
raise ValueError(f"Missing columns: {need_cols - in_cols}")
raise ValueError(f"Missing columns: {sorted(list(need_cols - in_cols))}")

return df.with_columns(new_cols).sort("patient_id", "time", maintain_order=True)

Expand Down

0 comments on commit 10fedf0

Please sign in to comment.