diff --git a/EventStream/data/pytorch_dataset.py b/EventStream/data/pytorch_dataset.py index 68fad67e..72ff2234 100644 --- a/EventStream/data/pytorch_dataset.py +++ b/EventStream/data/pytorch_dataset.py @@ -141,9 +141,6 @@ def __init__(self, config: PytorchDatasetConfig, split: str, just_cache: bool = logger.info("Reading splits & patient shards") self.read_shards() - logger.info("Setting measurement configs") - self.set_measurement_configs() - logger.info("Reading patient descriptors") self.read_patient_descriptors() @@ -180,9 +177,10 @@ def read_shards(self): self.shards = {sp: subjs for sp, subjs in all_shards.items() if sp.startswith(f"{self.split}/")} self.subj_map = {subj: sp for sp, subjs in self.shards.items() for subj in subjs} - def set_measurement_configs(self): - """Sets the measurement configs from the source ESGPT or MEDS dataset.""" - self.measurement_configs = self.config.measurement_configs + @property + def measurement_configs(self): + """Grabs the measurement configs from the config.""" + return self.config.measurement_configs def read_patient_descriptors(self): """Reads the patient descriptors from the ESGPT or MEDS dataset."""