diff --git a/python/orca/example/learn/horovod/pytorch_estimator.py b/python/orca/example/learn/horovod/pytorch_estimator.py index 1024b87f114..17494b0c9e4 100644 --- a/python/orca/example/learn/horovod/pytorch_estimator.py +++ b/python/orca/example/learn/horovod/pytorch_estimator.py @@ -60,24 +60,26 @@ def scheduler_creator(optimizer, config): return torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9) -def data_creator(config): - """Returns training dataloader, validation dataloader.""" +def train_data_creator(config): train_dataset = LinearDataset(2, 5, size=config.get("data_size", 1000)) - val_dataset = LinearDataset(2, 5, size=config.get("val_size", 400)) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=config.get("batch_size", 32), ) + return train_loader + + +def validation_data_creator(config): + val_dataset = LinearDataset(2, 5, size=config.get("val_size", 400)) validation_loader = torch.utils.data.DataLoader( val_dataset, batch_size=config.get("batch_size", 32)) - return train_loader, validation_loader + return validation_loader def train_example(): trainer1 = PyTorchHorovodEstimator( model_creator=model_creator, - data_creator=data_creator, optimizer_creator=optimizer_creator, loss_creator=nn.MSELoss, scheduler_creator=scheduler_creator, @@ -89,8 +91,10 @@ def train_example(): # train 5 epochs for i in range(5): - stats = trainer1.train() - print(stats) + stats = trainer1.train(train_data_creator) + print("train stats: {}".format(stats)) + val_stats = trainer1.validate(validation_data_creator) + print("validation stats: {}".format(val_stats)) parser = argparse.ArgumentParser()