Skip to content

Commit

Permalink
Add Horovod tests (intel-analytics#2761)
Browse files Browse the repository at this point in the history
* add pytorch horovod tests

* add horovod tf tests

* fix

* fix style

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests
  • Loading branch information
yangw1234 committed Sep 23, 2021
1 parent 6469e4f commit 78df3bc
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 14 deletions.
5 changes: 2 additions & 3 deletions python/orca/src/bigdl/orca/learn/tf2/tf_ray_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,16 +126,15 @@ def fit(self, data_creator, epochs=1, verbose=1,
return stats

def evaluate(self, data_creator, verbose=1, sample_weight=None,
steps=None, callbacks=None, return_dict=False):
steps=None, callbacks=None):
"""Evaluates the model on the validation data set."""
logger.info("Starting validation step.")
params = dict(
data_creator=data_creator,
verbose=verbose,
sample_weight=sample_weight,
steps=steps,
callbacks=callbacks,
return_dict=return_dict
callbacks=callbacks
)
# see ./tf_runner.py:setup_distributed
# for an explanation of only taking the first worker's data
Expand Down
3 changes: 1 addition & 2 deletions python/orca/src/bigdl/orca/learn/tf2/tf_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def step(self, data_creator, epochs=1, verbose=1,
return stats

def validate(self, data_creator, verbose=1, sample_weight=None,
steps=None, callbacks=None, return_dict=False):
steps=None, callbacks=None):
"""Evaluates the model on the validation data set."""

dataset = data_creator(self.config)
Expand All @@ -208,7 +208,6 @@ def validate(self, data_creator, verbose=1, sample_weight=None,
sample_weight=sample_weight,
steps=steps,
callbacks=callbacks,
return_dict=return_dict
)
results = self.model.evaluate(dataset, **params)
if results is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
class TestPyTorchTrainer(TestCase):
def test_train(self):
estimator = Estimator.from_torch(
model_creator=model_creator,
optimizer_creator=optimizer_creator,
loss_creator=nn.MSELoss,
model=model_creator,
optimizer=optimizer_creator,
loss=nn.MSELoss,
scheduler_creator=scheduler_creator,
config={
"lr": 1e-2, # used in optimizer_creator
Expand All @@ -54,9 +54,9 @@ def test_train(self):

def test_save_and_restore(self):
estimator1 = Estimator.from_torch(
model_creator=model_creator,
optimizer_creator=optimizer_creator,
loss_creator=nn.MSELoss,
model=model_creator,
optimizer=optimizer_creator,
loss=nn.MSELoss,
scheduler_creator=scheduler_creator,
config={
"lr": 1e-2, # used in optimizer_creator
Expand All @@ -73,9 +73,9 @@ def test_save_and_restore(self):
estimator1.shutdown()

estimator2 = Estimator.from_torch(
model_creator=model_creator,
optimizer_creator=optimizer_creator,
loss_creator=nn.MSELoss,
model=model_creator,
optimizer=optimizer_creator,
loss=nn.MSELoss,
scheduler_creator=scheduler_creator,
config={
"lr": 1e-2, # used in optimizer_creator
Expand Down
15 changes: 15 additions & 0 deletions python/orca/test/bigdl/orca/learn/ray/tf/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
31 changes: 31 additions & 0 deletions python/orca/test/bigdl/orca/learn/ray/tf/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest

# Module-level placeholders. NOTE(review): the fixture in this file assigns to
# local variables named ``sc`` and ``ray_ctx`` (there is no ``global``
# statement), so these two names remain ``None`` for the life of the module.
sc = None
ray_ctx = None


@pytest.fixture(autouse=True, scope='package')
def rayonspark_fixture():
    """Start a local Spark context and a Ray context for the test package.

    The fixture is package-scoped and autouse, so it is set up once before
    the first test in the package and torn down after the last one.
    Teardown stops Ray first and then Spark.

    The setup is wrapped in ``try/finally`` so that the Spark context is
    always stopped, even when creating or initialising the Ray context
    raises before the fixture reaches ``yield``.
    """
    # Imported lazily so that merely collecting this conftest does not
    # require the zoo package to be importable.
    from zoo import init_spark_on_local
    from zoo.ray import RayContext

    sc = init_spark_on_local(cores=8, spark_log_level="INFO")
    try:
        ray_ctx = RayContext(sc=sc, object_store_memory="1g")
        ray_ctx.init()
        try:
            yield
        finally:
            # Ray was started successfully; shut it down before Spark.
            ray_ctx.stop()
    finally:
        sc.stop()
118 changes: 118 additions & 0 deletions python/orca/test/bigdl/orca/learn/ray/tf/test_tf_ray_estimator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from unittest import TestCase

import numpy as np

from zoo.orca.learn.tf2 import Estimator
from zoo.ray import RayContext

NUM_TRAIN_SAMPLES = 1000
NUM_TEST_SAMPLES = 400


def linear_dataset(a=2, size=1000):
    """Generate a random linear regression dataset ``y = x / a``.

    Args:
        a: Divisor applied to ``x`` to obtain ``y``. The default of 2
            reproduces the previously hard-coded ``y = x / 2``.
        size: Number of samples to draw.

    Returns:
        A tuple ``(x, y)`` of NumPy arrays, each reshaped to ``(size, 1)``.
        ``x`` is drawn with ``np.random.rand`` (uniform on ``[0, 1)``).
    """
    x = np.random.rand(size)
    # Previously the divisor was the literal 2 and ``a`` was silently ignored.
    y = x / a

    # Column vectors: one feature / one target per sample.
    x = x.reshape((-1, 1))
    y = y.reshape((-1, 1))

    return x, y


def create_train_datasets(config):
    """Data creator for training: a shuffled, batched ``tf.data.Dataset``.

    Reads ``config["batch_size"]`` for the batch size and draws
    ``NUM_TRAIN_SAMPLES`` samples from ``linear_dataset``.
    """
    import tensorflow as tf

    per_batch = config["batch_size"]
    features, targets = linear_dataset(size=NUM_TRAIN_SAMPLES)

    slices = tf.data.Dataset.from_tensor_slices((features, targets))
    # Shuffle buffer covers the whole training set, then group into batches.
    shuffled = slices.shuffle(NUM_TRAIN_SAMPLES)
    return shuffled.batch(per_batch)


def create_test_dataset(config):
    """Data creator for evaluation: a batched, unshuffled ``tf.data.Dataset``.

    Reads ``config["batch_size"]`` for the batch size and draws
    ``NUM_TEST_SAMPLES`` samples from ``linear_dataset``.
    """
    import tensorflow as tf

    per_batch = config["batch_size"]
    features, targets = linear_dataset(size=NUM_TEST_SAMPLES)

    slices = tf.data.Dataset.from_tensor_slices((features, targets))
    return slices.batch(per_batch)


def simple_model(config):
    """Model creator: a two-layer dense Keras ``Sequential`` model.

    The network takes a single input feature, passes it through a
    10-unit dense layer and a 1-unit dense output layer. ``config`` is
    accepted but not read.
    """
    import tensorflow as tf

    keras_layers = tf.keras.layers
    hidden = keras_layers.Dense(10, input_shape=(1,))
    output = keras_layers.Dense(1)
    return tf.keras.models.Sequential([hidden, output])


def compile_args(config):
    """Compile-args creator: optimizer, loss and metrics for the model.

    Returns a dict with an Adam optimizer (default settings), mean squared
    error as the loss, and mean squared error as the only metric.
    ``config`` is accepted but not read.
    """
    import tensorflow as tf

    return dict(
        optimizer=tf.keras.optimizers.Adam(),
        loss="mean_squared_error",
        metrics=["mean_squared_error"],
    )


class TestTFRayEstimator(TestCase):
    """End-to-end check of the TF2 Ray ``Estimator`` on a toy regression."""

    def test_fit_and_evaluate(self):
        """Evaluate, train, evaluate again, and assert both metrics dropped."""
        import tensorflow as tf

        ray_ctx = RayContext.get()
        batch_size = 32
        # Global batch = per-worker batch times the number of Ray nodes.
        global_batch_size = batch_size * ray_ctx.num_ray_nodes
        eval_steps = NUM_TEST_SAMPLES // global_batch_size

        trainer = Estimator(
            model_creator=simple_model,
            compile_args_creator=compile_args,
            verbose=True,
            config={"batch_size": batch_size})

        # Baseline metrics of the untrained model.
        start_stats = trainer.evaluate(create_test_dataset, steps=eval_steps)
        print(start_stats)

        def lr_schedule(epoch):
            # Constant rate for the first two epochs, exponential decay after.
            if epoch < 2:
                return 0.001
            return 0.001 * tf.math.exp(0.1 * (2 - epoch))

        lr_callback = tf.keras.callbacks.LearningRateScheduler(
            lr_schedule, verbose=1)

        # Two consecutive fit calls of 2 epochs each.
        for _ in range(2):
            trainer.fit(create_train_datasets, epochs=2,
                        callbacks=[lr_callback])

        # Metrics after training; both are expected to be lower than baseline.
        end_stats = trainer.evaluate(create_test_dataset, steps=eval_steps)
        print(end_stats)

        # Negative deltas mean the metric improved.
        dloss = end_stats["validation_loss"] - start_stats["validation_loss"]
        dmse = (end_stats["validation_mean_squared_error"] -
                start_stats["validation_mean_squared_error"])
        print(f"dLoss: {dloss}, dMSE: {dmse}")

        assert dloss < 0 and dmse < 0, "training sanity check failed. loss increased!"

0 comments on commit 78df3bc

Please sign in to comment.