Skip to content

Commit

Permalink
Orca change default backend of PyTorch estimator to bigdl (intel-anal…
Browse files Browse the repository at this point in the history
…ytics#2966)

* change type

* Change default backend of PyTorch estimator to bigdl

* Update PythonInferenceModel.scala

* fix

* update docs
  • Loading branch information
cyita authored and yangw1234 committed Sep 26, 2021
1 parent 8fa79a0 commit f3fc75f
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 10 deletions.
4 changes: 2 additions & 2 deletions python/orca/src/bigdl/orca/learn/pytorch/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def from_torch(*,
use_tqdm=False,
workers_per_node=1,
model_dir=None,
backend="horovod"):
backend="bigdl"):
if backend == "horovod":
return PyTorchHorovodEstimatorWrapper(model_creator=model,
optimizer_creator=optimizer,
Expand Down Expand Up @@ -253,7 +253,7 @@ def get_model(self):
def save(self, checkpoint):
pass

def load(self, checkpoint, loss=None, model_dir=None):
def load(self, checkpoint, loss=None):
from zoo.orca.learn.utils import find_latest_checkpoint
from bigdl.nn.layer import Model
from bigdl.optim.optimizer import OptimMethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,11 @@ def transform(df):

with tempfile.TemporaryDirectory() as temp_dir_name:
estimator = Estimator.from_torch(model=model, loss=loss_func,
optimizer=SGD(), model_dir=temp_dir_name,
backend="bigdl")
optimizer=SGD(), model_dir=temp_dir_name)
estimator.fit(data=data_shard, epochs=4, batch_size=2, validation_data=data_shard,
validation_methods=[Accuracy()], checkpoint_trigger=EveryEpoch())
estimator.evaluate(data_shard, validation_methods=[Accuracy()], batch_size=2)
est2 = Estimator.from_torch(model=model, loss=loss_func, optimizer=None,
backend="bigdl")
est2 = Estimator.from_torch(model=model, loss=loss_func, optimizer=None)
est2.load(temp_dir_name, loss=loss_func)
est2.fit(data=data_shard, epochs=8, batch_size=2, validation_data=data_shard,
validation_methods=[Accuracy()], checkpoint_trigger=EveryEpoch())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def forward(self, x):
model = SimpleModel()

estimator = Estimator.from_torch(model=model, loss=nn.BCELoss(),
optimizer=Adam(), backend="bigdl")
optimizer=Adam())

def get_dataloader():
inputs = torch.Tensor([[1, 2], [1, 3], [3, 2], [5, 6], [8, 9], [1, 9]])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_train(self):
"lr": 1e-2, # used in optimizer_creator
"hidden_size": 1, # used in model_creator
"batch_size": 4, # used in data_creator
})
}, backend="horovod")
stats1 = estimator.fit(train_data_creator, epochs=5)
train_loss1 = stats1[-1]["train_loss"]
validation_loss1 = estimator.evaluate(validation_data_creator)["val_loss"]
Expand All @@ -62,7 +62,7 @@ def test_save_and_restore(self):
"lr": 1e-2, # used in optimizer_creator
"hidden_size": 1, # used in model_creator
"batch_size": 4, # used in data_creator
})
}, backend="horovod")
with TemporaryDirectory() as tmp_path:
estimator1.fit(train_data_creator, epochs=1)
checkpoint_path = os.path.join(tmp_path, "checkpoint")
Expand All @@ -81,7 +81,7 @@ def test_save_and_restore(self):
"lr": 1e-2, # used in optimizer_creator
"hidden_size": 1, # used in model_creator
"batch_size": 4, # used in data_creator
})
}, backend="horovod")
estimator2.load(checkpoint_path)

model2 = estimator2.get_model()
Expand Down

0 comments on commit f3fc75f

Please sign in to comment.