diff --git a/python/orca/example/learn/horovod/pytorch_estimator.py b/python/orca/example/learn/horovod/pytorch_estimator.py index 17494b0c9e4..83f4f609bcd 100644 --- a/python/orca/example/learn/horovod/pytorch_estimator.py +++ b/python/orca/example/learn/horovod/pytorch_estimator.py @@ -22,7 +22,7 @@ from zoo import init_spark_on_yarn, init_spark_on_local from zoo.ray import RayContext -from zoo.orca.learn.pytorch.pytorch_horovod_estimator import PyTorchHorovodEstimator +from zoo.orca.learn.pytorch.estimator import Estimator class LinearDataset(torch.utils.data.Dataset): @@ -78,7 +78,7 @@ def validation_data_creator(config): def train_example(): - trainer1 = PyTorchHorovodEstimator( + estimator = Estimator.from_model_creator( model_creator=model_creator, optimizer_creator=optimizer_creator, loss_creator=nn.MSELoss, @@ -90,11 +90,10 @@ def train_example(): }) # train 5 epochs - for i in range(5): - stats = trainer1.train(train_data_creator) - print("train stats: {}".format(stats)) - val_stats = trainer1.validate(validation_data_creator) - print("validation stats: {}".format(val_stats)) + stats = estimator.fit(train_data_creator, epochs=5) + print("train stats: {}".format(stats)) + val_stats = estimator.evaluate(validation_data_creator) + print("validation stats: {}".format(val_stats)) parser = argparse.ArgumentParser()