Skip to content

Commit

Permalink
Update pytorch dataset download logic (#952)
Browse files Browse the repository at this point in the history
  • Loading branch information
jrbourbeau authored Aug 29, 2023
1 parent 71696ae commit ae64958
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 19 deletions.
2 changes: 1 addition & 1 deletion AB_environments/AB_baseline.conda.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ dependencies:
   - toolz ==0.12.0
   - zict ==3.0.0
   - xgboost ==1.7.4
-  - optuna ==3.2.0
+  - optuna ==3.3.0
   - scipy ==1.10.1
   - snowflake-connector-python ==3.0.4
   - snowflake-sqlalchemy ==1.4.7
Expand Down
2 changes: 1 addition & 1 deletion AB_environments/AB_sample.conda.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ dependencies:
   - toolz ==0.12.0
   - zict ==3.0.0
   - xgboost ==1.7.4
-  - optuna ==3.2.0
+  - optuna ==3.3.0
   - scipy ==1.10.1
   - snowflake-connector-python ==3.0.4
   - snowflake-sqlalchemy ==1.4.7
Expand Down
2 changes: 1 addition & 1 deletion ci/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ dependencies:
   - toolz ==0.12.0
   - zict ==3.0.0
   - xgboost ==1.7.4
-  - optuna ==3.2.0
+  - optuna ==3.3.0
   - scipy ==1.10.1
   - snowflake-connector-python ==3.0.4
   - snowflake-sqlalchemy ==1.4.7
Expand Down
36 changes: 20 additions & 16 deletions tests/workflows/test_pytorch_optuna.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
 import zipfile

 import pytest
-from dask.distributed import PipInstall
+from dask.distributed import Lock, PipInstall, get_worker

 from ..utils_test import wait

Expand Down Expand Up @@ -47,23 +47,27 @@ def weights_init(m):
 def download_data():
     import s3fs # FIXME: see above install w/ urllib3 - import here after reinstall

-    tmpdir = tempfile.gettempdir()
-    zip = pathlib.Path(tmpdir).joinpath("img_align_celeba.zip")
-    dataset_dir = zip.parent.joinpath("img_align_celeba")
-
-    if zip.exists():
-        print("Dataset already downloaded, returning dataset dir")
+    worker = get_worker()
+    with Lock(worker.address):
+        tmpdir = tempfile.gettempdir()
+        zip = pathlib.Path(tmpdir).joinpath("img_align_celeba.zip")
+        dataset_dir = zip.parent.joinpath("img_align_celeba")
+
+        if zip.exists():
+            print("Dataset already downloaded, returning dataset dir")
+            return dataset_dir
+
+        print("Downloading dataset...")
+        fs = s3fs.S3FileSystem(anon=True)
+        fs.download(
+            "s3://coiled-datasets/CelebA-Faces/img_align_celeba.zip", str(zip)
+        )
+
+        print(f"Unzipping into {dataset_dir}")
+        with zipfile.ZipFile(str(zip), "r") as zipped:
+            zipped.extractall(dataset_dir)
         return dataset_dir
-
-    print("Downloading dataset...")
-    fs = s3fs.S3FileSystem(anon=True)
-    fs.download("s3://coiled-datasets/CelebA-Faces/img_align_celeba.zip", str(zip))
-
-    print(f"Unzipping into {dataset_dir}")
-    with zipfile.ZipFile(str(zip), "r") as zipped:
-        zipped.extractall(dataset_dir)
-    return dataset_dir

def get_generator(trial):
import torch.nn as nn

Expand Down

0 comments on commit ae64958

Please sign in to comment.