Commit: migrate mxnet_trainer to estimator

Showing 4 changed files with 242 additions and 1 deletion.

@@ -0,0 +1,69 @@
# MXNet LeNet example

Here we demonstrate how to easily run synchronous distributed [MXNet](https://github.com/apache/incubator-mxnet) training using
MXNetTrainer, implemented in Analytics Zoo on top of [RayOnSpark](https://analytics-zoo.github.io/master/#ProgrammingGuide/rayonspark/).
We use the LeNet model to train on the MNIST dataset for handwritten digit recognition.
See [here](https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/image/mnist.html) for the original single-node version of this example provided by MXNet.

In the distributed setting, the whole MNIST dataset is split into several parts, and each MXNet worker takes one part for data-parallel training.
At the same time, the MXNet servers are responsible for aggregating the parameters and sending them back to the workers.
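The worker/server scheme described above can be sketched in plain Python, without any framework: each worker owns one shard of the dataset and computes an update on it, while a server-side reduction averages the per-worker results. The `shard` and `aggregate` helpers below are illustrative stand-ins, not Analytics Zoo or MXNet APIs.

```python
# A minimal, framework-free sketch of data-parallel training with a
# parameter server. Names here are illustrative only.

def shard(dataset, num_parts, part_index):
    """Return the slice of `dataset` owned by worker `part_index`."""
    return dataset[part_index::num_parts]

def aggregate(gradients):
    """Server-side reduction: average the per-worker gradients."""
    return sum(gradients) / len(gradients)

data = list(range(8))          # stand-in for the MNIST samples
num_workers = 2
shards = [shard(data, num_workers, i) for i in range(num_workers)]
# Each worker computes a (toy) gradient from its own shard ...
worker_grads = [sum(s) for s in shards]
# ... and the server averages them into a single update.
update = aggregate(worker_grads)
print(shards, update)  # [[0, 2, 4, 6], [1, 3, 5, 7]] 14.0
```

In the real example below, the sharding is done by `mx.io.MNISTIter` via its `num_parts` and `part_index` arguments, and the aggregation is performed by the MXNet servers through the kvstore.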

## Prepare environments
Follow steps 1 to 4 [here](https://analytics-zoo.github.io/master/#ProgrammingGuide/rayonspark/#steps-to-run-rayonspark)
to prepare your Python environment.

You also need to install **MXNet** in your conda environment via pip. We have tested this example with MXNet 1.6.0.
```bash
pip install mxnet==1.6.0
```
If you are running on Intel Xeon Scalable processors, you probably want to install the [MKLDNN](https://github.com/oneapi-src/oneDNN) build of MXNet for better performance:
```bash
pip install mxnet-mkl==1.6.0
```

See [here](https://analytics-zoo.github.io/master/#PythonUserGuide/run/#run-after-pip-install)
for more guidance on running after pip install.

## Run on local after pip install
```bash
python lenet_mnist.py -n 2
```
See [here](#options) for more configurable options for this example.

## Run on yarn cluster for yarn-client mode after pip install
```bash
python lenet_mnist.py --hadoop_conf <path to your hadoop/yarn configuration directory> --conda_name <your conda environment name>
```

See [here](#options) for more configurable options for this example.

## Options
- `-n` `--num_workers`: The number of MXNet workers to be launched for distributed training. Default is 2.
- `-s` `--num_servers`: The number of MXNet servers to be launched for distributed training. If not specified, defaults to the number of workers.
- `-b` `--batch_size`: The number of samples per gradient update for each worker. Default is 100.
- `-e` `--epochs`: The number of epochs to train the model. Default is 10.
- `-l` `--learning_rate`: The learning rate for the LeNet model. Default is 0.02.
- `--log_interval`: The number of batches to wait before logging throughput and metrics information during the training process. Default is 100.

**Options for yarn only**
- `--hadoop_conf`: **Required** when running on yarn. The path to your hadoop configuration folder.
- `--conda_name`: **Required** when running on yarn. The name of your conda environment.
- `--executor_cores`: The number of executor CPU cores you want to use. Default is 4.
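The options above map onto a standard `argparse` parser. A minimal sketch (mirroring only the documented flags and defaults, not the full script) shows how the `num_servers` fallback behaves:

```python
# Sketch of the example's CLI options; defaults match the list above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-n', '--num_workers', type=int, default=2)
parser.add_argument('-s', '--num_servers', type=int)  # falls back to num_workers
parser.add_argument('-b', '--batch_size', type=int, default=100)
parser.add_argument('-e', '--epochs', type=int, default=10)
parser.add_argument('-l', '--learning_rate', type=float, default=0.02)
parser.add_argument('--log_interval', type=int, default=100)

opt = parser.parse_args(['-n', '4', '-b', '128'])
# When -s is omitted, the number of servers equals the number of workers.
num_servers = opt.num_servers if opt.num_servers is not None else opt.num_workers
print(opt.num_workers, num_servers, opt.batch_size)  # 4 4 128
```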

## Results
You can find the accuracy information in the logs during the training process:
```
(pid=34395) INFO:root:Epoch[6] Batch[99] Speed: 4513.416662 samples/sec loss=0.893286 accuracy=0.961562
(pid=34361) INFO:root:Epoch[6] Batch[99] Speed: 4465.811376 samples/sec loss=0.900434 accuracy=0.966484
(pid=34395) INFO:root:Epoch[6] Batch[199] Speed: 4877.452140 samples/sec loss=0.875496 accuracy=0.964102
(pid=34361) INFO:root:Epoch[6] Batch[199] Speed: 4322.318930 samples/sec loss=0.901634 accuracy=0.965000
(pid=34395) INFO:root:[Epoch 6] time cost: 7.500912
(pid=34395) INFO:root:[Epoch 6] training: accuracy=0.964410
(pid=34361) INFO:root:[Epoch 6] time cost: 7.522801
(pid=34361) INFO:root:[Epoch 6] training: accuracy=0.965178
(pid=34395) INFO:root:[Epoch 6] validation: accuracy=0.982171
(pid=34361) INFO:root:[Epoch 6] validation: accuracy=0.957131
```
Note that the training and validation accuracy of each worker may differ slightly, since each worker computes its accuracy
based on its own portion of the whole dataset.
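Because each worker reports accuracy only for its own shard, the per-worker values can be averaged to approximate the accuracy over the full dataset (assuming equally sized shards). A small sketch, using the log format shown above:

```python
# Extract per-worker validation accuracies from log lines like those above
# and average them. The sample lines here are copied from the log excerpt.
import re

log = """\
(pid=34395) INFO:root:[Epoch 6] validation: accuracy=0.982171
(pid=34361) INFO:root:[Epoch 6] validation: accuracy=0.957131
"""

pattern = re.compile(r"validation: accuracy=([0-9.]+)")
accs = [float(m.group(1)) for m in pattern.finditer(log)]
mean_acc = sum(accs) / len(accs)
print(round(mean_acc, 6))  # 0.969651
```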

@@ -0,0 +1,15 @@
```python
#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
```
@@ -0,0 +1,157 @@
```python
#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Reference: https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/image/mnist.html

import argparse

from zoo import init_spark_on_local, init_spark_on_yarn
from zoo.ray import RayContext
from zoo.orca.learn.mxnet import MXNetTrainer, create_trainer_config


def get_data_iters(config, kv):
    import os
    import zipfile
    import mxnet as mx
    from bigdl.dataset.base import maybe_download

    # To avoid conflicts where multiple workers on the same node download and
    # unzip data to the same location, each worker gets its own folder.

    # Not using mxnet.test_utils.get_mnist_iterator directly because the data
    # path is hard-coded in that function.

    # In practice, data is supposed to be stored on a file system accessible to
    # workers on all nodes, for example, on HDFS or S3.
    maybe_download("mnist.zip", "worker" + str(kv.rank),
                   "http://data.mxnet.io/mxnet/data/mnist.zip")
    if not os.path.isdir("worker" + str(kv.rank) + "/data"):
        with zipfile.ZipFile("worker" + str(kv.rank) + "/mnist.zip") as zf:
            zf.extractall("worker" + str(kv.rank) + "/data")

    train_iter = mx.io.MNISTIter(
        image="worker" + str(kv.rank) + "/data/train-images-idx3-ubyte",
        label="worker" + str(kv.rank) + "/data/train-labels-idx1-ubyte",
        input_shape=(1, 28, 28),
        batch_size=config["batch_size"],
        shuffle=True,
        flat=False,
        num_parts=kv.num_workers,
        part_index=kv.rank)
    val_iter = mx.io.MNISTIter(
        image="worker" + str(kv.rank) + "/data/t10k-images-idx3-ubyte",
        label="worker" + str(kv.rank) + "/data/t10k-labels-idx1-ubyte",
        input_shape=(1, 28, 28),
        batch_size=config["batch_size"],
        flat=False,
        num_parts=kv.num_workers,
        part_index=kv.rank)
    return train_iter, val_iter


def get_model(config):
    import mxnet as mx
    from mxnet import gluon
    from mxnet.gluon import nn
    import mxnet.ndarray as F

    class LeNet(gluon.Block):
        def __init__(self, **kwargs):
            super(LeNet, self).__init__(**kwargs)
            with self.name_scope():
                # Layers created in name_scope will inherit the name space
                # from the parent layer.
                self.conv1 = nn.Conv2D(20, kernel_size=(5, 5))
                self.pool1 = nn.MaxPool2D(pool_size=(2, 2), strides=(2, 2))
                self.conv2 = nn.Conv2D(50, kernel_size=(5, 5))
                self.pool2 = nn.MaxPool2D(pool_size=(2, 2), strides=(2, 2))
                self.fc1 = nn.Dense(500)
                self.fc2 = nn.Dense(10)

        def forward(self, x):
            x = self.pool1(F.tanh(self.conv1(x)))
            x = self.pool2(F.tanh(self.conv2(x)))
            # 0 means copy over the size from the corresponding dimension.
            # -1 means infer the size from the rest of the dimensions.
            x = x.reshape((0, -1))
            x = F.tanh(self.fc1(x))
            x = F.tanh(self.fc2(x))
            return x

    net = LeNet()
    net.initialize(mx.init.Xavier(magnitude=2.24), ctx=[mx.cpu()])
    return net


def get_loss(config):
    from mxnet import gluon
    return gluon.loss.SoftmaxCrossEntropyLoss()


def get_metrics(config):
    import mxnet as mx
    return mx.metric.Accuracy()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Train a LeNet model for handwritten digit recognition.')
    parser.add_argument('--hadoop_conf', type=str,
                        help='The path to the hadoop configuration folder. Required if you '
                             'wish to run on yarn clusters. Otherwise, run in local mode.')
    parser.add_argument('--conda_name', type=str,
                        help='The name of the conda environment. Required if you '
                             'wish to run on yarn clusters.')
    parser.add_argument('--executor_cores', type=int, default=4,
                        help='The number of executor cores you want to use.')
    parser.add_argument('-n', '--num_workers', type=int, default=2,
                        help='The number of MXNet workers to be launched.')
    parser.add_argument('-s', '--num_servers', type=int,
                        help='The number of MXNet servers to be launched. If not specified, '
                             'defaults to the number of workers.')
    parser.add_argument('-b', '--batch_size', type=int, default=100,
                        help='The number of samples per gradient update for each worker.')
    parser.add_argument('-e', '--epochs', type=int, default=10,
                        help='The number of epochs to train the model.')
    parser.add_argument('-l', '--learning_rate', type=float, default=0.02,
                        help='Learning rate for the LeNet model.')
    parser.add_argument('--log_interval', type=int, default=100,
                        help='The number of batches to wait before logging throughput and '
                             'metrics information during the training process.')
    opt = parser.parse_args()

    if opt.hadoop_conf:
        assert opt.conda_name is not None, "conda_name must be specified for yarn mode"
        sc = init_spark_on_yarn(
            hadoop_conf=opt.hadoop_conf,
            conda_name=opt.conda_name,
            num_executor=opt.num_workers,
            executor_cores=opt.executor_cores)
    else:
        sc = init_spark_on_local(cores="*")
    ray_ctx = RayContext(sc=sc)
    ray_ctx.init()

    config = create_trainer_config(opt.batch_size, optimizer="sgd",
                                   optimizer_params={'learning_rate': opt.learning_rate},
                                   log_interval=opt.log_interval, seed=42)
    trainer = MXNetTrainer(config, data_creator=get_data_iters, model_creator=get_model,
                           loss_creator=get_loss, metrics_creator=get_metrics,
                           num_workers=opt.num_workers, num_servers=opt.num_servers)
    trainer.train(nb_epoch=opt.epochs)
    ray_ctx.stop()
    sc.stop()
```