Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Release Candidate 0.0.6 #175

Merged
merged 67 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
67 commits
Select commit Hold shift + click to select a range
d6061cf
Added an 'extract_values' transform.
mmcdermott Aug 6, 2024
5efbbcc
Updated mandatory types to be separate from mandatory MEDS columns
mmcdermott Aug 6, 2024
6d6b318
Merge branch 'dev' into 57_extract_numeric_values
mmcdermott Aug 7, 2024
10fedf0
Changed to 'categorical_value' and added better doctests
mmcdermott Aug 7, 2024
c463548
Added an integration test for extract_values -- it needs to wait for …
mmcdermott Aug 7, 2024
f8b260d
Merge branch 'dev' into 57_extract_numeric_values
mmcdermott Aug 8, 2024
6ca0054
Merge branch 'dev' into 57_extract_numeric_values
mmcdermott Aug 8, 2024
26db680
Merge branch 'dev' into 57_extract_numeric_values
mmcdermott Aug 9, 2024
f3da768
Merged dev
mmcdermott Aug 10, 2024
350a82e
Merged to dev
mmcdermott Aug 10, 2024
f7d7e27
Merge branch 'dev' into 57_extract_numeric_values
mmcdermott Aug 12, 2024
2d59ff2
Exit if there is no _metadata in the event configs
prenc Aug 13, 2024
d217885
Use a more precise message
prenc Aug 13, 2024
1969ea1
Removed aggregate from default ETL and added dedicated test.
mmcdermott Aug 14, 2024
fc17317
Fixed a parent_codes schema error in vocabulary index
mmcdermott Aug 14, 2024
a751f72
Merge pull request #157 from mmcdermott/110_remove_agg_from_default_e…
mmcdermott Aug 14, 2024
d830d49
Merge branch 'dev' into 57_extract_numeric_values
mmcdermott Aug 14, 2024
df3517d
Corrected a typo in a doctest.
mmcdermott Aug 14, 2024
87dde78
Everything but normalization succeeding -- I need to re-calculate exp…
mmcdermott Aug 14, 2024
089ae24
Updated aggregations stage config syntax in mapper_fntr test.
mmcdermott Aug 14, 2024
1fc2e4d
Added failing test to capture mapper_fntr bug with quantils when summ…
mmcdermott Aug 14, 2024
6b7a6b8
Added tests and all are passing.
mmcdermott Aug 14, 2024
9948bc9
Fixed normalization test.
mmcdermott Aug 14, 2024
31e3652
Fixed tokenization test
mmcdermott Aug 14, 2024
3cab512
Merge pull request #159 from mmcdermott/158_fix_typing_issue
mmcdermott Aug 14, 2024
017d126
Merge branch 'dev' into 163_aggregation_test_and_mapper_bug
mmcdermott Aug 14, 2024
497d411
Update type
mmcdermott Aug 14, 2024
49fb1ce
Merge pull request #166 from mmcdermott/163_aggregation_test_and_mapp…
mmcdermott Aug 14, 2024
9bb19bf
Separated input data preparation and renamed it to make tests cleaner…
mmcdermott Aug 14, 2024
9cf0bc4
Making stage-name always be passed via the command line if it is used.
mmcdermott Aug 14, 2024
d5b0782
Updated test output syntax to support data and metadata checking and …
mmcdermott Aug 14, 2024
138eb1e
Added a multi-stage test which currently, appropriately, fails due to…
mmcdermott Aug 14, 2024
74d3ce1
Corrected #161. This makes the normalization error disappear, but mor…
mmcdermott Aug 14, 2024
5911c90
Added a graceful exit mode for cases where shards are empty. Should s…
mmcdermott Aug 14, 2024
ff46645
Using more default parameters.
mmcdermott Aug 14, 2024
9304bf7
Tests re-failing as I haven't added all the NRT tests, just the fully…
mmcdermott Aug 14, 2024
c7d56a9
Starting to track some of the intermediate outputs to identify the ta…
mmcdermott Aug 14, 2024
210a9ce
Added capability to test intermediate files and started validating fu…
mmcdermott Aug 14, 2024
747ce11
Tested outputs up to occlude outliers.
mmcdermott Aug 14, 2024
8736279
Removed broken assert_no_other_outcomes
mmcdermott Aug 14, 2024
6114c42
Validated fit normalization stage as well.
mmcdermott Aug 15, 2024
3c5c7d9
Added test for final metadata stage
mmcdermott Aug 15, 2024
4a466c0
Added normalization shard test.
mmcdermott Aug 15, 2024
2a58ec1
Added tokenization schemas test.
mmcdermott Aug 15, 2024
aeedb9e
Added tokenized output validations.
mmcdermott Aug 15, 2024
80b31fd
Tested tensorized outputs.
mmcdermott Aug 15, 2024
a27d361
Merge pull request #167 from mmcdermott/160-multi-stage-integration-t…
mmcdermott Aug 15, 2024
7c55b34
Merge branch 'dev' into 125_metadata_extraction_crashes
mmcdermott Aug 15, 2024
766f685
Added a test for the no metadata case.
mmcdermott Aug 15, 2024
7f11b9e
Merge pull request #154 from mmcdermott/125_metadata_extraction_crashes
mmcdermott Aug 15, 2024
0c4a942
Incorporating Ethan's update from #108 and fixing other issues with t…
mmcdermott Aug 19, 2024
3fab90a
Added badges to the README.
mmcdermott Aug 19, 2024
6046856
Fixed a typo in a badge link
mmcdermott Aug 19, 2024
727f4bd
Fixed another typo in a badge link
mmcdermott Aug 19, 2024
775ba5e
Merge pull request #171 from mmcdermott/170_badges
mmcdermott Aug 19, 2024
4b2d10c
Set-up python 3.11 support.
mmcdermott Aug 26, 2024
2f7382a
Made publishing workflow only use 3.12
mmcdermott Aug 26, 2024
5a54e9a
Made code quality just use python 3.12
mmcdermott Aug 26, 2024
1f80b15
Merge pull request #178 from mmcdermott/176_py311_support
mmcdermott Aug 26, 2024
4ded5b0
Merge branch 'dev' into 57_extract_numeric_values
mmcdermott Aug 26, 2024
c4375c6
Fixed test (still errors, but now is due to proper issue.
mmcdermott Aug 26, 2024
95e66fc
Added regex capabilities to parser.
mmcdermott Aug 26, 2024
074237e
Added regex capabilities to matchers.
mmcdermott Aug 26, 2024
64807c5
Added a possibly working additional match revise stage type and updat…
mmcdermott Aug 26, 2024
b15a8a1
Fixed things up to test config for extracting temp.
mmcdermott Aug 26, 2024
cd03fa9
Updated test outputs to be in the proper order and use the right rege…
mmcdermott Aug 26, 2024
3f73a35
Merge pull request #121 from mmcdermott/57_extract_numeric_values
mmcdermott Aug 27, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions .github/workflows/code-quality-main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,17 @@ jobs:
code-quality:
runs-on: ubuntu-latest

strategy:
matrix:
python-version: ["3.12"]

steps:
- name: Checkout
uses: actions/checkout@v3
- uses: actions/checkout@v4

- name: Set up Python 3.12
uses: actions/setup-python@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: "3.12"
python-version: ${{ matrix.python-version }}

- name: Run pre-commits
uses: pre-commit/action@v3.0.1
13 changes: 8 additions & 5 deletions .github/workflows/code-quality-pr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@ jobs:
code-quality:
runs-on: ubuntu-latest

strategy:
matrix:
python-version: ["3.12"]

steps:
- name: Checkout
uses: actions/checkout@v3
- uses: actions/checkout@v4

- name: Set up Python 3.12
uses: actions/setup-python@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: "3.12"
python-version: ${{ matrix.python-version }}

- name: Find modified files
id: file_changes
Expand Down
10 changes: 7 additions & 3 deletions .github/workflows/python-build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@ jobs:
name: Build distribution 📦
runs-on: ubuntu-latest

strategy:
matrix:
python-version: ["3.12"]

steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: "3.12"
python-version: ${{ matrix.python-version }}
- name: Install pypa/build
run: >-
python3 -m
Expand Down
10 changes: 6 additions & 4 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,20 @@ jobs:
runs-on: ubuntu-latest

strategy:
matrix:
python-version: ["3.11", "3.12"]
fail-fast: false

timeout-minutes: 30

steps:
- name: Checkout
uses: actions/checkout@v3
uses: actions/checkout@v4

- name: Set up Python 3.12
uses: actions/setup-python@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: "3.12"
python-version: ${{ matrix.python-version }}

- name: Install packages
run: |
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ repos:
rev: v3.10.1
hooks:
- id: pyupgrade
args: [--py310-plus]
args: [--py311-plus]

# python docstring formatting
- repo: https://github.com/myint/docformatter
Expand Down
4 changes: 2 additions & 2 deletions MIMIC-IV_Example/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ pypi installation, which is covered below, so make sure you are in a suitable di
```bash
conda create -n MEDS python=3.12
conda activate MEDS
pip install MEDS_transforms[examples,local_parallelism]
pip install "MEDS_transforms[local_parallelism]"
mkdir MIMIC-IV_Example
cd MIMIC-IV_Example
wget https://raw.githubusercontent.com/mmcdermott/MEDS_transforms/main/MIMIC-IV_Example/joint_script.sh
Expand All @@ -32,7 +32,7 @@ git clone git@github.com:mmcdermott/MEDS_transforms.git
cd MEDS_transforms
conda create -n MEDS python=3.12
conda activate MEDS
pip install .[examples,local_parallelism]
pip install .[local_parallelism]
```

## Step 1: Download MIMIC-IV
Expand Down
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
# MEDS Transforms

[![codecov](https://codecov.io/gh/mmcdermott/MEDS_transforms/graph/badge.svg?token=5RORKQOZF9)](https://codecov.io/gh/mmcdermott/MEDS_transforms)
[![PyPI - Version](https://img.shields.io/pypi/v/MEDS-transforms)](https://pypi.org/project/MEDS-transforms/)
![python](https://img.shields.io/badge/-Python_3.12-blue?logo=python&logoColor=white)
[![Documentation Status](https://readthedocs.org/projects/meds-transforms/badge/?version=latest)](https://meds-transforms.readthedocs.io/en/latest/?badge=latest)
[![codecov](https://codecov.io/gh/mmcdermott/MEDS_transforms/graph/badge.svg?token=5RORKQOZF9)](https://codecov.io/gh/mmcdermott/MEDS_transforms)
[![tests](https://github.com/mmcdermott/MEDS_transforms/actions/workflows/tests.yaml/badge.svg)](https://github.com/mmcdermott/MEDS_transforms/actions/workflows/tests.yml)
[![code-quality](https://github.com/mmcdermott/MEDS_transforms/actions/workflows/code-quality-main.yaml/badge.svg)](https://github.com/mmcdermott/MEDS_transforms/actions/workflows/code-quality-main.yaml)
[![hydra](https://img.shields.io/badge/Config-Hydra_1.3-89b8cd)](https://hydra.cc/)
[![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://github.com/mmcdermott/MEDS_transforms#license)
[![PRs](https://img.shields.io/badge/PRs-welcome-brightgreen.svg)](https://github.com/mmcdermott/MEDS_transforms/pulls)
[![contributors](https://img.shields.io/github/contributors/mmcdermott/MEDS_transforms.svg)](https://github.com/mmcdermott/MEDS_transforms/graphs/contributors)

This repository contains a set of functions and scripts for extraction to and transformation/pre-processing of
MEDS-formatted data.
Expand Down
14 changes: 10 additions & 4 deletions eICU_Example/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,20 @@ up from this one).

## Step 0: Installation

Download this repository and install the requirements:
Install the requirements and source the requisite scripts

```bash
git clone git@github.com:mmcdermott/MEDS_transforms.git
cd MEDS_transforms
conda create -n MEDS python=3.12
conda activate MEDS
pip install .[examples]
pip install "MEDS_transforms[local_parallelism]"
mkdir eICU_Example
cd eICU_Example
wget https://raw.githubusercontent.com/mmcdermott/MEDS_transforms/main/eICU_Example/joint_script.sh
wget https://raw.githubusercontent.com/mmcdermott/MEDS_transforms/main/eICU_Example/pre_MEDS.py
chmod +x joint_script.sh
chmod +x joint_script_slurm.sh
chmod +x pre_MEDS.py
cd ..
```

## Step 1: Download eICU
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ authors = [
]
description = "MEDS ETL and transformation functions leveraging a sharding-based parallelism model & polars."
readme = "README.md"
requires-python = ">=3.12"
requires-python = ">=3.11"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
Expand Down Expand Up @@ -54,6 +54,7 @@ MEDS_transform-filter_patients = "MEDS_transforms.filters.filter_patients:main"
## Transforms
MEDS_transform-reorder_measurements = "MEDS_transforms.transforms.reorder_measurements:main"
MEDS_transform-add_time_derived_measurements = "MEDS_transforms.transforms.add_time_derived_measurements:main"
MEDS_transform-extract_values = "MEDS_transforms.transforms.extract_values:main"
MEDS_transform-normalization = "MEDS_transforms.transforms.normalization:main"
MEDS_transform-occlude_outliers = "MEDS_transforms.transforms.occlude_outliers:main"
MEDS_transform-tensorization = "MEDS_transforms.transforms.tensorization:main"
Expand Down
8 changes: 8 additions & 0 deletions src/MEDS_transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,11 @@
"timestamp": "time",
"subject_id": "patient_id",
}

INFERRED_STAGE_KEYS = {
"is_metadata",
"data_input_dir",
"metadata_input_dir",
"output_dir",
"reducer_output_dir",
}
43 changes: 39 additions & 4 deletions src/MEDS_transforms/aggregate_code_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def validate_args_and_get_code_cols(stage_cfg: DictConfig, code_modifiers: list[
for agg in aggregations:
if isinstance(agg, (dict, DictConfig)):
agg = agg.get("name", None)
if agg not in METADATA_FN:
if agg not in {fn.value for fn in METADATA_FN}:
raise ValueError(
f"Metadata aggregation function {agg} not found in METADATA_FN enumeration. Values are: "
f"{', '.join([fn.value for fn in METADATA_FN])}"
Expand Down Expand Up @@ -406,7 +406,9 @@ def mapper_fntr(
│ C ┆ 1 ┆ 81.25 ┆ 5.0 ┆ 7.5 │
│ D ┆ null ┆ 0.0 ┆ null ┆ null │
└──────┴───────────┴────────────────┴────────────┴────────────┘
>>> stage_cfg = DictConfig({"aggregations": ["values/quantiles"]})
>>> stage_cfg = DictConfig({
... "aggregations": [{"name": "values/quantiles", "quantiles": [0.25, 0.5, 0.75]}]
... })
>>> mapper = mapper_fntr(stage_cfg, code_modifiers)
>>> mapper(df.lazy()).collect().select("code", "modifier1", pl.col("values/quantiles"))
shape: (5, 3)
Expand All @@ -421,6 +423,25 @@ def mapper_fntr(
│ C ┆ 1 ┆ [5.0, 7.5] │
│ D ┆ null ┆ [] │
└──────┴───────────┴──────────────────┘
>>> stage_cfg = DictConfig({
... "aggregations": [{"name": "values/quantiles", "quantiles": [0.25, 0.5, 0.75]}],
... "do_summarize_over_all_codes": True,
... })
>>> mapper = mapper_fntr(stage_cfg, code_modifiers)
>>> mapper(df.lazy()).collect().select("code", "modifier1", pl.col("values/quantiles"))
shape: (6, 3)
┌──────┬───────────┬───────────────────┐
│ code ┆ modifier1 ┆ values/quantiles │
│ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ list[f64] │
╞══════╪═══════════╪═══════════════════╡
│ null ┆ null ┆ [1.1, 2.0, … 7.5] │
│ A ┆ 1 ┆ [1.1, 1.1] │
│ A ┆ 2 ┆ [6.0] │
│ B ┆ 2 ┆ [2.0, 4.0] │
│ C ┆ 1 ┆ [5.0, 7.5] │
│ D ┆ null ┆ [] │
└──────┴───────────┴───────────────────┘
"""

code_key_columns = validate_args_and_get_code_cols(stage_cfg, code_modifiers)
Expand All @@ -435,15 +456,20 @@ def by_code_mapper(df: pl.LazyFrame) -> pl.LazyFrame:
return df.group_by(code_key_columns).agg(**agg_operations).sort(code_key_columns)

def all_patients_mapper(df: pl.LazyFrame) -> pl.LazyFrame:
return df.select(**agg_operations)
local_agg_operations = agg_operations.copy()
if METADATA_FN.VALUES_QUANTILES in agg_operations:
local_agg_operations[METADATA_FN.VALUES_QUANTILES] = agg_operations[
METADATA_FN.VALUES_QUANTILES
].implode()
return df.select(**local_agg_operations)

if stage_cfg.get("do_summarize_over_all_codes", False):

def mapper(df: pl.LazyFrame) -> pl.LazyFrame:
by_code = by_code_mapper(df)
all_patients = all_patients_mapper(df)
return pl.concat([all_patients, by_code], how="diagonal_relaxed").select(
*code_key_columns, *aggregations
*code_key_columns, *agg_operations.keys()
)

else:
Expand Down Expand Up @@ -682,6 +708,15 @@ def run_map_reduce(cfg: DictConfig):
cs.numeric().shrink_dtype().name.keep()
)

old_metadata_fp = Path(cfg.stage_cfg.metadata_input_dir) / "codes.parquet"
join_cols = ["code", *cfg.get("code_modifier_cols", [])]

if old_metadata_fp.exists():
logger.info(f"Joining to existing code metadata at {str(old_metadata_fp.resolve())}")
existing = pl.scan_parquet(old_metadata_fp)
existing = existing.drop(*[c for c in existing.columns if c in set(reduced.columns) - set(join_cols)])
reduced = reduced.join(existing, on=join_cols, how="left", coalesce=True)

write_lazyframe(reduced, reducer_fp)
logger.info(f"Finished reduction in {datetime.now() - start}")

Expand Down
19 changes: 0 additions & 19 deletions src/MEDS_transforms/configs/extract.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ description: |-

# The event conversion configuration file is used throughout the pipeline to define the events to extract.
event_conversion_config_fp: ???
# The code modifier columns are in this pipeline only used in the aggregate_code_metadata stage.
code_modifiers: null
# The shards mapping is stored in the root of the final output directory.
shards_map_fp: "${cohort_dir}/metadata/.shards.json"

Expand All @@ -37,27 +35,10 @@ stages:
- split_and_shard_patients
- convert_to_sharded_events
- merge_to_MEDS_cohort
- aggregate_code_metadata
- extract_code_metadata
- finalize_MEDS_metadata
- finalize_MEDS_data

stage_configs:
shard_events:
data_input_dir: "${input_dir}"
aggregate_code_metadata:
description: |-
This stage collects some descriptive metadata about the codes in the cohort.

Args:
stage_cfg.aggregations: The aggregations to compute over the codes.
Defaults to counts of code occurrences, counts of patients with the code, and counts of value
occurrences per code, as well as the sum and sum of squares of values (for use in computing means
and variances).
aggregations:
- "code/n_occurrences"
- "code/n_patients"
- "values/n_occurrences"
- "values/sum"
- "values/sum_sqd"
do_summarize_over_all_codes: true # This indicates we should include overall, code-independent counts
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
fit_vocabulary_indices:
is_metadata: true
ordering_method: "lexicographic"
output_dir: "${cohort_dir}"
4 changes: 4 additions & 0 deletions src/MEDS_transforms/extract/extract_code_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,10 @@ def main(cfg: DictConfig):
OmegaConf.save(event_conversion_cfg, partial_metadata_dir / "event_conversion_config.yaml")

events_and_metadata_by_metadata_fp = get_events_and_metadata_by_metadata_fp(event_conversion_cfg)
if not events_and_metadata_by_metadata_fp:
logger.info("No _metadata blocks in the event_conversion_config.yaml found. Exiting...")
return

event_metadata_configs = list(events_and_metadata_by_metadata_fp.items())
random.shuffle(event_metadata_configs)

Expand Down
20 changes: 12 additions & 8 deletions src/MEDS_transforms/extract/split_and_shard_patients.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/usr/bin/env python
import json
import math
from collections.abc import Sequence
from pathlib import Path

Expand All @@ -13,15 +14,13 @@
from MEDS_transforms.utils import stage_init


def shard_patients[
SUBJ_ID_T
](
def shard_patients(
patients: np.ndarray,
n_patients_per_shard: int = 50000,
external_splits: dict[str, Sequence[SUBJ_ID_T]] | None = None,
external_splits: dict[str, Sequence[int]] | None = None,
split_fracs_dict: dict[str, float] | None = {"train": 0.8, "tuning": 0.1, "held_out": 0.1},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Avoid using mutable data structures for argument defaults.

Using mutable default arguments can lead to unexpected behavior. Set the default value to None and initialize within the function.

Apply this diff to fix the issue:

-    split_fracs_dict: dict[str, float] | None = {"train": 0.8, "tuning": 0.1, "held_out": 0.1},
+    split_fracs_dict: dict[str, float] | None = None,

Then, initialize the default value within the function:

    if external_splits is None:
        external_splits = {}
+    if split_fracs_dict is None:
+        split_fracs_dict = {"train": 0.8, "tuning": 0.1, "held_out": 0.1}
Tools
Ruff

21-21: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)

seed: int = 1,
) -> dict[str, list[SUBJ_ID_T]]:
) -> dict[str, list[int]]:
"""Shard a list of patients, nested within train/tuning/held-out splits.

This function takes a list of patients and shards them into train/tuning/held-out splits, with the shards
Expand Down Expand Up @@ -72,7 +71,7 @@ def shard_patients[
>>> shard_patients(patients, n_patients_per_shard=3, split_fracs_dict={'train': 0.5})
Traceback (most recent call last):
...
ValueError: The sum of the split fractions must be equal to 1.
ValueError: The sum of the split fractions must be equal to 1. Got 0.5 through {'train': 0.5}.
>>> shard_patients([1, 2], n_patients_per_shard=3)
Traceback (most recent call last):
...
Expand Down Expand Up @@ -107,10 +106,15 @@ def shard_patients[

splits = external_splits

splits_cover = sum(split_fracs_dict.values()) if split_fracs_dict else 0

rng = np.random.default_rng(seed)
if n_patients := len(patient_ids_to_split):
if sum(split_fracs_dict.values()) != 1:
raise ValueError("The sum of the split fractions must be equal to 1.")
if not math.isclose(splits_cover, 1):
raise ValueError(
f"The sum of the split fractions must be equal to 1. Got {splits_cover} "
f"through {split_fracs_dict}."
)
split_names_idx = rng.permutation(len(split_fracs_dict))
split_names = np.array(list(split_fracs_dict.keys()))[split_names_idx]
split_fracs = np.array([split_fracs_dict[k] for k in split_names])
Expand Down
3 changes: 1 addition & 2 deletions src/MEDS_transforms/fit_vocabulary_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,7 @@ def lexicographic_indices(code_metadata: pl.DataFrame, code_modifiers: list[str]
... "modifier2": [None, None, None, None, 2, 1],
... })
>>> code_modifiers = ["modifier1", "modifier2"]
>>> expr = lexicographic_indices(code_metadata, code_modifiers)
>>> code_metadata.with_columns(expr)
>>> lexicographic_indices(code_metadata, code_modifiers)
shape: (6, 4)
┌──────┬───────────┬───────────┬──────────────────┐
│ code ┆ modifier1 ┆ modifier2 ┆ code/vocab_index │
Expand Down
Loading
Loading