Orca MXNetTrainer migration #2320

Merged · 10 commits · May 19, 2020
Changes from 6 commits
6 changes: 3 additions & 3 deletions pyzoo/dev/run-pytests-ray
@@ -26,16 +26,16 @@ export PYSPARK_DRIVER_PYTHON=python
echo "Running RayOnSpark tests"
python -m pytest -v ../test/zoo/ray/ \
--ignore=../test/zoo/ray/integration/ \
--ignore=../test/zoo/ray/mxnet/ \
--ignore=../test/zoo/orca/learn/mxnet/ \
--ignore=../test/zoo/ray/test_reinit_raycontext.py
exit_status_1=$?
if [ $exit_status_1 -ne 0 ];
then
exit $exit_status_1
fi

echo "Running MXNetTrainer tests"
python -m pytest -v ../test/zoo/ray/mxnet
echo "Running MXNet Estimator tests"
python -m pytest -v ../test/zoo/orca/learn/mxnet
exit_status_2=$?
if [ $exit_status_2 -ne 0 ];
then
@@ -22,7 +22,7 @@
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
from zoo.ray.mxnet import MXNetTrainer, create_trainer_config
from zoo.orca.learn.mxnet import Estimator, create_trainer_config

np.random.seed(1337) # for reproducibility

@@ -68,8 +68,7 @@ class TestMXNetGluon(TestCase):
def test_gluon(self):
config = create_trainer_config(batch_size=32, log_interval=2, optimizer="adam",
optimizer_params={'learning_rate': 0.02})
trainer = MXNetTrainer(config, get_data_iters, get_model, get_loss, get_metrics,
num_workers=2)
trainer = Estimator(config, get_data_iters, get_model, get_loss, get_metrics, num_workers=2)
trainer.train(nb_epoch=2)


@@ -19,7 +19,7 @@
import pytest

import mxnet as mx
from zoo.ray.mxnet import MXNetTrainer, create_trainer_config
from zoo.orca.learn.mxnet import Estimator, create_trainer_config

np.random.seed(1337) # for reproducibility

@@ -56,7 +56,7 @@ def get_metrics(config):
class TestMXNetSymbol(TestCase):
def test_symbol(self):
config = create_trainer_config(batch_size=32, log_interval=2, seed=42)
trainer = MXNetTrainer(config, get_data_iters, get_model, metrics_creator=get_metrics)
trainer = Estimator(config, get_data_iters, get_model, metrics_creator=get_metrics)
trainer.train(nb_epoch=2)


@@ -1,7 +1,7 @@
# MXNet LeNet example

Here we demonstrate how to easily run synchronous distributed [MXNet](https://github.com/apache/incubator-mxnet) training using
MXNetTrainer implemented in Analytics Zoo on top of [RayOnSpark](https://analytics-zoo.github.io/master/#ProgrammingGuide/rayonspark/).
MXNet Estimator implemented in Analytics Zoo on top of [RayOnSpark](https://analytics-zoo.github.io/master/#ProgrammingGuide/rayonspark/).
We use the LeNet model to train on MNIST dataset for handwritten digit recognition.
See [here](https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/image/mnist.html) for the original single-node version of this example provided by MXNet.
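
For readers skimming the diff, the migrated driver flow of the LeNet example boils down to the sketch below. It is assembled only from calls that appear elsewhere in this PR (`init_spark_on_local`, `RayContext`, `create_trainer_config`, `Estimator`, `train`); the creator functions are the ones defined in `lenet_mnist.py` and are not reproduced here, and the argument values are illustrative rather than the example's defaults.

```python
# Condensed sketch of the migrated LeNet driver flow (illustrative values).
# get_data_iters / get_model / get_loss / get_metrics are the creator
# functions defined in lenet_mnist.py; they are not reproduced here.
from zoo import init_spark_on_local
from zoo.ray import RayContext
from zoo.orca.learn.mxnet import Estimator, create_trainer_config

sc = init_spark_on_local(cores=4)   # init_spark_on_yarn for cluster mode
ray_ctx = RayContext(sc=sc)
ray_ctx.init()

config = create_trainer_config(batch_size=256, optimizer="sgd",
                               optimizer_params={'learning_rate': 0.02},
                               log_interval=10, seed=42)
trainer = Estimator(config, data_creator=get_data_iters, model_creator=get_model,
                    loss_creator=get_loss, metrics_creator=get_metrics,
                    num_workers=2, num_servers=2)
trainer.train(nb_epoch=1)

ray_ctx.stop()
sc.stop()
```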

@@ -13,6 +13,3 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

from .mxnet_trainer import MXNetTrainer
from .utils import create_trainer_config
@@ -20,7 +20,7 @@

from zoo import init_spark_on_local, init_spark_on_yarn
from zoo.ray import RayContext
from zoo.ray.mxnet import MXNetTrainer, create_trainer_config
from zoo.orca.learn.mxnet import Estimator, create_trainer_config


def get_data_iters(config, kv):
@@ -149,9 +149,9 @@ def get_metrics(config):
config = create_trainer_config(opt.batch_size, optimizer="sgd",
optimizer_params={'learning_rate': opt.learning_rate},
log_interval=opt.log_interval, seed=42)
trainer = MXNetTrainer(config, data_creator=get_data_iters, model_creator=get_model,
loss_creator=get_loss, metrics_creator=get_metrics,
num_workers=opt.num_workers, num_servers=opt.num_servers)
trainer = Estimator(config, data_creator=get_data_iters, model_creator=get_model,
loss_creator=get_loss, metrics_creator=get_metrics,
num_workers=opt.num_workers, num_servers=opt.num_servers)
trainer.train(nb_epoch=opt.epochs)
ray_ctx.stop()
sc.stop()
2 changes: 1 addition & 1 deletion pyzoo/zoo/examples/run-example-test-ray.sh
@@ -41,7 +41,7 @@ time4=$((now-start))

echo "#5 Start mxnet lenet example"
start=$(date "+%s")
python ${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/ray/mxnet/lenet_mnist.py -e 1 -b 256
python ${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/orca/learn/mxnet/lenet_mnist.py -e 1 -b 256
now=$(date "+%s")
time5=$((now-start))

2 changes: 1 addition & 1 deletion pyzoo/zoo/examples/run-example-tests-pip-ray.sh
@@ -33,7 +33,7 @@ time3=$?
execute_ray_test multiagent_two_trainers ${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/ray/rllib/multiagent_two_trainers.py
time4=$?

execute_ray_test lenet_mnist ${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/ray/mxnet/lenet_mnist.py -e 1 -b 256
execute_ray_test lenet_mnist ${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/orca/learn/mxnet/lenet_mnist.py -e 1 -b 256
time5=$?

echo "#1 rl_pong time used:$time1 seconds"
Empty file.
3 changes: 3 additions & 0 deletions pyzoo/zoo/orca/learn/mxnet/__init__.py
@@ -13,3 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

from .estimator import Estimator
from .utils import create_trainer_config
@@ -19,13 +19,13 @@
import subprocess
import ray.services
from dmlc_tracker.tracker import get_host_ip
from zoo.ray.mxnet.mxnet_runner import MXNetRunner
from zoo.ray.mxnet.utils import find_free_port
from zoo.orca.learn.mxnet.mxnet_runner import MXNetRunner
from zoo.orca.learn.mxnet.utils import find_free_port


class MXNetTrainer(object):
class Estimator(object):
Collaborator: if we still use the train API, we can keep it as MXNetTrainer

Contributor (author): fixed

"""
MXNetTrainer provides an automatic setup for synchronous distributed MXNet training.
MXNet Estimator provides an automatic setup for synchronous distributed MXNet training.
Collaborator: Need to support DataShards; see the PR link in https://github.com/analytics-zoo/orca/issues/3

Contributor (author): ok

:param config: A dictionary for training configurations. Keys must include the following:
batch_size, optimizer, optimizer_params, log_interval.
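
To make the config contract above concrete, here is a minimal, self-contained sketch modelled on the gluon test in this PR. The toy dataset and two-layer network are invented purely for illustration, and the assumption that the data creator returns a `(train_iter, val_iter)` pair follows common MXNet practice rather than anything stated in this diff.

```python
# Minimal illustrative sketch: toy data and model are invented; the creator
# signatures follow the gluon test in this PR.
import numpy as np
import mxnet as mx
from mxnet import gluon
from zoo.orca.learn.mxnet import Estimator, create_trainer_config

def get_data_iters(config, kv):
    # Toy dataset: 200 random samples, 10 features, 2 classes.
    x = np.random.rand(200, 10).astype("float32")
    y = np.random.randint(0, 2, size=(200,)).astype("float32")
    train = mx.io.NDArrayIter(x, y, batch_size=config["batch_size"], shuffle=True)
    val = mx.io.NDArrayIter(x, y, batch_size=config["batch_size"])
    return train, val  # assumed contract: (train_iter, val_iter)

def get_model(config):
    net = gluon.nn.Sequential()
    net.add(gluon.nn.Dense(8, activation="relu"))
    net.add(gluon.nn.Dense(2))
    return net

def get_loss(config):
    return gluon.loss.SoftmaxCrossEntropyLoss()

def get_metrics(config):
    return mx.metric.Accuracy()

config = create_trainer_config(batch_size=32, log_interval=2, optimizer="adam",
                               optimizer_params={'learning_rate': 0.02})
trainer = Estimator(config, get_data_iters, get_model, get_loss, get_metrics,
                    num_workers=2)
trainer.train(nb_epoch=2)
```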
@@ -22,7 +22,7 @@
import mxnet as mx
from mxnet import gluon
from zoo.ray.utils import to_list
from zoo.ray.mxnet.utils import find_free_port
from zoo.orca.learn.mxnet.utils import find_free_port


class MXNetRunner(object):
File renamed without changes.