Commit

fixes an issue with wandb image logging, adds configuration for data loader workers
mariusgiger committed Apr 19, 2022
1 parent 3b6540c commit 9a050c9
Showing 2 changed files with 20 additions and 17 deletions.
src/sdo/cmd/sood/ce_vae/cmd_train.py (7 changes: 5 additions & 2 deletions)
@@ -24,6 +24,7 @@
 @click.option(
     "--dataset", type=click.Choice(["CuratedImageParameterDataset", "SDOMLDatasetV1"], case_sensitive=False), required=False, default="CuratedImageParameterDataset"
 )
+@click.option("--num-data-loader-workers", type=int, default=0)
 @pass_environment
 def train(ctx,
           target_size,
@@ -40,7 +41,8 @@ def train(ctx,
           load_path,
           log_dir,
           data_dir,
-          dataset):
+          dataset,
+          num_data_loader_workers):
 
     main(run="train",
          target_size=target_size,
@@ -56,4 +58,5 @@ def train(ctx,
          load_path=load_path,
          log_dir=log_dir,
          data_dir=data_dir,
-         dataset=dataset)
+         dataset=dataset,
+         num_data_loader_workers=num_data_loader_workers)
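
As a side note, the option added above follows the usual click pattern: declare the flag on the command and thread the parsed value through to the training entry point. A minimal, self-contained sketch of that pattern follows; it is not the project's actual CLI, only the option name mirrors the diff:

    import click

    @click.command()
    @click.option("--num-data-loader-workers", type=int, default=0,
                  help="Worker processes for data loading (0 = load in the main process).")
    def train(num_data_loader_workers):
        # The real command forwards this value to main(); here it is only echoed.
        click.echo(f"using {num_data_loader_workers} data loader worker(s)")

    if __name__ == "__main__":
        train()

Run as, for example, python sketch_cli.py --num-data-loader-workers 4 (the file name is hypothetical).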
src/sood/algorithms/ce_vae.py (30 changes: 15 additions & 15 deletions)
@@ -176,16 +176,15 @@ def train(self):
                 log_dict = {}
 
                 # VAE
-                image_table = wandb.Table()
+
+                image_data = [cnt, None, None, None, None]
                 if self.ce_factor < 1:
                     input_vae_path = save_image_grid(inpt, name="Input-VAE", save_dir=self.work_dir / Path("save/imgs"),
                                                      image_args={"normalize": False}, n_iter=cnt)
-                    image_table.add_column(
-                        "Input-VAE", wandb.Image(input_vae_path))
+                    image_data[1] = wandb.Image(input_vae_path)
                     output_vae_path = save_image_grid(
                         x_rec_vae, name="Output-VAE", save_dir=self.work_dir / Path("save/imgs"), image_args={"normalize": True}, n_iter=cnt)
-                    image_table.add_column(
-                        "Output-VAE", wandb.Image(output_vae_path))
+                    image_data[2] = wandb.Image(output_vae_path)
 
                     if self.beta > 0:
                         log_dict["Kl-loss"] = torch.mean(kl_loss).item()
@@ -197,16 +196,16 @@ def train(self):
                 if self.ce_factor > 0:
                     input_ce_path = save_image_grid(
                         inpt_noisy, name="Input-CE", save_dir=self.work_dir / Path("save/imgs"), image_args={"normalize": False})
-                    image_table.add_column(
-                        "Input-CE", wandb.Image(input_ce_path))
+                    image_data[3] = wandb.Image(input_ce_path)
                     output_ce_path = save_image_grid(
                         x_rec_ce, name="Output-CE", save_dir=self.work_dir / Path("save/imgs"), image_args={"normalize": True})
-                    image_table.add_column(
-                        "Output-CE", wandb.Image(output_ce_path))
+                    image_data[4] = wandb.Image(output_ce_path)
 
                     log_dict["CE-train-loss"] = loss_ce.item()
 
-                # TODO why normalize by the length of the input (batch length)?
+                image_table = wandb.Table(
+                    columns=["Step", "Input-VAE", "Output-VAE", "Input-CE", "Output-CE"], data=[image_data])
+                # TODO why normalize by the length of the input (batch length)?
                 log_dict["CEVAE-train-loss"] = loss.item() / len(inpt)
                 log_dict["epoch"] = epoch
                 log_dict["counter"] = cnt
@@ -448,7 +447,8 @@ def main(
         test_dir=None,
         pred_dir=None,
         data_dir=None,
-        dataset="CuratedImageParameterDataset"
+        dataset="CuratedImageParameterDataset",
+        num_data_loader_workers=0
 ):
     input_shape = (batch_size, 1, target_size, target_size)
 
@@ -458,15 +458,15 @@
     if dataset == "CuratedImageParameterDataset":
         train_loader = get_dataset(
             base_dir=data_dir,
-            num_processes=16,
+            num_processes=num_data_loader_workers,
             pin_memory=False,
             batch_size=batch_size,
             mode="train",
             target_size=input_shape[2],
         )
         val_loader = get_dataset(
             base_dir=data_dir,
-            num_processes=8,
+            num_processes=num_data_loader_workers,
             pin_memory=False,
             batch_size=batch_size,
             mode="val",
@@ -476,15 +476,15 @@
         # due to a bug on Mac, num processes needs to be 0: https://github.com/pyg-team/pytorch_geometric/issues/366
         train_loader = get_sdo_ml_v1_dataset(
             base_dir=data_dir,
-            num_processes=0,
+            num_processes=num_data_loader_workers,
             pin_memory=False,
             batch_size=batch_size,
             mode="train",
             target_size=input_shape[2],
         )
         val_loader = get_sdo_ml_v1_dataset(
             base_dir=data_dir,
-            num_processes=0,
+            num_processes=num_data_loader_workers,
             pin_memory=False,
             batch_size=batch_size,
             mode="val",
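
The num_processes value that now comes from --num-data-loader-workers is presumably handed to a PyTorch DataLoader as num_workers inside get_dataset / get_sdo_ml_v1_dataset, whose implementations are not part of this diff. A generic sketch under that assumption, with a toy tensor dataset standing in for the real one:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Toy stand-in for the real image dataset (64 single-channel 128x128 samples).
    dataset = TensorDataset(torch.randn(64, 1, 128, 128))

    num_data_loader_workers = 0  # the new CLI default: load data in the main process
    train_loader = DataLoader(
        dataset,
        batch_size=16,
        shuffle=True,
        num_workers=num_data_loader_workers,  # >0 spawns that many loader processes
        pin_memory=False,
    )

    for (batch,) in train_loader:
        pass  # one training step per batch would go here

A default of 0 keeps the previous SDOMLDatasetV1 behaviour (the macOS multiprocessing issue linked in the comment above), while passing a larger value restores parallel loading for the CuratedImageParameterDataset case, which was previously hard-coded to 16/8 workers.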
