MXNet Estimator support multiple inputs/outputs for Gluon (intel-analytics#2765)

* Support multiple inputs and outputs

* Bug fix

* Update

* Update context

* Add unit test
hkvision authored and yangw1234 committed Sep 27, 2021
1 parent b991951 commit d6454d5
Showing 3 changed files with 137 additions and 38 deletions.
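
The heart of the change: an MXNet DataIter built from several arrays exposes batch.data as a list of NDArrays, one per input, and the runner now forwards that whole list to the Gluon model instead of assuming batch.data[0]. A minimal standalone sketch of the DataIter behavior this relies on (stock mxnet and numpy; the shapes here are made up for illustration, not taken from the diff):

# Sketch: two input arrays feeding one NDArrayIter.
import mxnet as mx
import numpy as np

data = [np.random.rand(8, 30), np.random.rand(8, 20)]  # two inputs, 8 samples each
label = np.random.randint(0, 10, (8,))
it = mx.io.NDArrayIter(data, label, batch_size=4)
batch = next(it)
print(len(batch.data))                           # 2 -> one NDArray per input
print(batch.data[0].shape, batch.data[1].shape)  # (4, 30) (4, 20)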
87 changes: 50 additions & 37 deletions python/orca/src/bigdl/orca/learn/mxnet/mxnet_runner.py
@@ -20,7 +20,6 @@
 import subprocess
 import ray.services
 import mxnet as mx
-import numpy as np
 from mxnet import gluon
 from zoo.ray.utils import to_list

@@ -39,7 +38,7 @@ def setup_distributed(self, env, config, model_creator, loss_creator=None,
         self.model_creator = model_creator
         self.loss_creator = loss_creator
         self.validation_metrics_creator = validation_metrics_creator
-        self.eval_metircs_creator = eval_metrics_creator
+        self.eval_metrics_creator = eval_metrics_creator
         self.is_worker = False
         env["DMLC_NODE_HOST"] = self.get_node_ip()
         if env["DMLC_ROLE"] == "worker":
@@ -48,16 +47,21 @@ def setup_distributed(self, env, config, model_creator, loss_creator=None,
         if self.is_worker:
             os.environ.update(env)
             self.kv = mx.kv.create("dist_sync")
-            # Set seed so that the model on each worker is initialized with the same weights
+            # Set seed so that the model on each worker is initialized with the same weights.
             if "seed" in self.config:
                 mx.random.seed(self.config["seed"])

             self.model = self.model_creator(self.config)
             self.loss = self.loss_creator(self.config) if self.loss_creator else None
-            self.eval_metrics = self.eval_metircs_creator(self.config) \
-                if self.eval_metircs_creator else None
+            self.eval_metrics = self.eval_metrics_creator(self.config) \
+                if self.eval_metrics_creator else None
+            from mxnet.metric import CompositeEvalMetric
+            if isinstance(self.eval_metrics, list):
+                self.eval_metrics = CompositeEvalMetric(self.eval_metrics)
             self.val_metrics = self.validation_metrics_creator(self.config) \
                 if self.validation_metrics_creator else None
+            if isinstance(self.val_metrics, list):
+                self.val_metrics = CompositeEvalMetric(self.val_metrics)
             # For BaseModule, use symbolic API. Otherwise, use imperative API.
             # TODO: change Gluon Trainer to Estimator API?
             if not isinstance(self.model, mx.module.BaseModule):
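
The CompositeEvalMetric wrapping added above lets a metrics creator return a plain Python list. A hedged standalone sketch of what the wrapper does (stock MXNet; the label and prediction arrays are made up):

# Sketch: CompositeEvalMetric drives several metrics through one update()/get().
from mxnet import nd, metric

metrics = metric.CompositeEvalMetric([metric.Accuracy(), metric.TopKAccuracy(3)])
labels = [nd.array([1, 2, 3, 4])]
preds = [nd.random.uniform(shape=(4, 10))]  # 4 samples, 10 classes
metrics.update(labels, preds)
names, values = metrics.get()  # one (name, value) pair per child metric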
@@ -72,7 +76,7 @@ def setup_distributed(self, env, config, model_creator, loss_creator=None,
             # TODO: Need to kill this process manually?
             modified_env = os.environ.copy()
             modified_env.update(env)
-            # For servers, just import mxnet and no need to do anything else
+            # For servers, just import mxnet and no need to do anything else.
             subprocess.Popen("python -c 'import mxnet'", shell=True, env=modified_env)

     def train(self, train_data, epochs=1, batch_size=32,
@@ -106,69 +110,78 @@ def train(self, train_data, epochs=1, batch_size=32,
         val_data_iter = validation_data(config, self.kv) if validation_data else None
         start_time = time.time()
         if self.trainer:  # Imperative API
+            def cpu_context(target_data):
+                if isinstance(target_data, list):
+                    return [cpu_context(d) for d in target_data]
+                else:
+                    return target_data.as_in_context(mx.cpu())
+
             for epoch in range(epochs):
-                train_data_iter.reset()
+                # DataLoader doesn't need to be reset.
+                if isinstance(train_data_iter, mx.io.DataIter):
+                    train_data_iter.reset()
                 if self.eval_metrics:
-                    self.eval_metrics.reset()  # metrics will accumulate for one batch
+                    self.eval_metrics.reset()  # metrics will accumulate for one batch.
                 batch_start_time = time.time()
                 epoch_start_time = time.time()
                 for i, batch in enumerate(train_data_iter):
-                    data = gluon.utils.split_and_load(
-                        batch.data[0].astype("float32"), ctx_list=[mx.cpu()], batch_axis=0)
-                    label = gluon.utils.split_and_load(
-                        batch.label[0].astype("float32"), ctx_list=[mx.cpu()], batch_axis=0)
-                    outputs = []
-                    Ls = []
+                    data = cpu_context(batch.data)
+                    label = cpu_context(batch.label)
+                    if not isinstance(data, list):
+                        data = [data]
+                    if not isinstance(label, list):
+                        label = [label]
                     from mxnet import autograd as ag
                     with ag.record():
-                        for x, y in zip(data, label):
-                            z = self.model(x)  # forward
-                            L = self.loss(z, y)
-                            # store the loss and do backward on a batch for better speed
-                            Ls.append(L)
-                            outputs.append(z)
+                        output = self.model(*data)  # forward
+                        if not isinstance(output, list):
+                            output = [output]
+                        Ls = self.loss(*output, *label)
                         ag.backward(Ls)
-                    self.trainer.step(batch.data[0].shape[0])
+                    self.trainer.step(batch_size)
                     if self.eval_metrics:
-                        self.eval_metrics.update(label, outputs)
+                        self.eval_metrics.update(label, output)
                     if not (i + 1) % self.config["log_interval"]:
                         # This would be logged on driver for each worker process.
                         iteration_log = \
                             "Epoch[%d] Batch[%d] Speed: %f samples/sec %s=%f" \
                             % (epoch, i,
                                batch_size / (time.time() - batch_start_time),
-                               "loss", Ls[0].asnumpy().mean())
+                               "loss", Ls.asnumpy().mean())
                         if self.eval_metrics:
                             names, accs = self.eval_metrics.get()
                             names, accs = to_list(names), to_list(accs)
                             for name, acc in zip(names, accs):
                                 iteration_log += " %s=%f" % (name, acc)
                         self.logger.info(iteration_log)
                         batch_start_time = time.time()
-                # Epoch time log
+                # Epoch time log.
                 self.logger.info("[Epoch %d] time cost: %f" %
                                  (epoch, time.time() - epoch_start_time))
-                # Epoch metrics log on train data
+                # Epoch metrics log on train data.
                 if self.eval_metrics:
                     epoch_train_log = "[Epoch %d] training: " % epoch
                     names, accs = self.eval_metrics.get()
                     names, accs = to_list(names), to_list(accs)
                     for name, acc in zip(names, accs):
                         epoch_train_log += "%s=%f " % (name, acc)
                     self.logger.info(epoch_train_log)
-                # Epoch metrics log on validation data if any:
+                # Epoch metrics log on validation data if any.
                 if val_data_iter:
+                    if isinstance(val_data_iter, mx.io.DataIter):
+                        val_data_iter.reset()
                     self.val_metrics.reset()
-                    val_data_iter.reset()
                     for batch in val_data_iter:
-                        data = gluon.utils.split_and_load(
-                            batch.data[0].astype("float32", copy=False),
-                            ctx_list=[mx.cpu()], batch_axis=0)
-                        label = gluon.utils.split_and_load(
-                            batch.label[0].astype("float32", copy=False),
-                            ctx_list=[mx.cpu()], batch_axis=0)
-                        outputs = [self.model(X) for X in data]
-                        self.val_metrics.update(label, outputs)
+                        data = cpu_context(batch.data)
+                        label = cpu_context(batch.label)
+                        if not isinstance(data, list):
+                            data = [data]
+                        if not isinstance(label, list):
+                            label = [label]
+                        output = self.model(*data)
+                        if not isinstance(output, list):
+                            output = [output]
+                        self.val_metrics.update(label, output)
                     epoch_val_log = "[Epoch %d] validation: " % epoch
                     names, accs = self.val_metrics.get()
                     names, accs = to_list(names), to_list(accs)
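
Reduced to its essentials, the new imperative path calls the Gluon block with unpacked inputs and backpropagates one loss array per batch. A standalone sketch of that pattern under stock MXNet (the two-input block and shapes are illustrative, not the patch's own code):

import mxnet as mx
from mxnet import autograd, gluon


class TwoInput(gluon.nn.HybridBlock):
    def __init__(self, **kwargs):
        super(TwoInput, self).__init__(**kwargs)
        self.fc = gluon.nn.Dense(10)

    def hybrid_forward(self, F, x1, x2):
        return self.fc(F.concat(x1, x2, dim=1))


net = TwoInput()
net.initialize()
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
trainer = gluon.Trainer(net.collect_params(), "sgd", {"learning_rate": 0.1})

data = [mx.nd.random.uniform(shape=(4, 30)), mx.nd.random.uniform(shape=(4, 20))]
label = mx.nd.array([0, 1, 2, 3])
with autograd.record():
    output = net(*data)            # forward: one positional arg per input
    loss = loss_fn(output, label)  # shape (4,): per-sample losses
loss.backward()                    # implicit head gradient of ones
trainer.step(4)                    # normalize the update by the batch size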
@@ -185,9 +198,9 @@ def train(self, train_data, epochs=1, batch_size=32,
             # TODO: seems no history (i.e. validation accuracy) returned by fit?
             if "init" not in self.config:
                 from mxnet.initializer import Uniform
-                self.config["init"] = Uniform(0.01)  # This is the default value for MXNet
+                self.config["init"] = Uniform(0.01)  # This is the default value for MXNet.
             if self.eval_metrics is None:
-                self.eval_metrics = 'acc'
+                self.eval_metrics = 'acc'  # This is the default value for MXNet.
             self.model.fit(train_data=train_data_iter,
                            num_epoch=epochs,
                            initializer=self.config["init"],
2 changes: 1 addition & 1 deletion python/orca/src/bigdl/orca/learn/mxnet/mxnet_trainer.py
@@ -134,7 +134,7 @@ def fit(self, data, epochs=1, batch_size=32, validation_data=None, train_resize_
         :param epochs: The number of epochs to train the MXNet model. Default is 1.
-        :param batch_size: The number of samples per batch. Default is 32.
+        :param batch_size: The number of samples per batch for each worker. Default is 32.
         :param validation_data: An instance of SparkXShards or a function that takes config and
                kv as arguments and returns an MXNet DataIter/DataLoader for validation.
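
With the per-worker semantics the docstring now states, the effective cluster-wide batch per synchronous step is batch_size times the worker count: e.g. num_workers=4 with batch_size=32 would consume 4 × 32 = 128 samples per step.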
@@ -0,0 +1,86 @@
#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from unittest import TestCase

import numpy as np
import pytest

import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
from zoo.ray import RayContext
from zoo.orca.learn.mxnet import Estimator, create_config

np.random.seed(1337) # for reproducibility


def get_train_data_iter(config, kv):
    train_data = [np.random.rand(100, 30), np.random.rand(100, 20)]
    train_label = np.random.randint(0, 10, (100,))  # one label per sample
    train = mx.io.NDArrayIter(train_data, train_label,
                              batch_size=config["batch_size"], shuffle=True)
    return train


def get_test_data_iter(config, kv):
    test_data = [np.random.rand(40, 30), np.random.rand(40, 20)]
    test_label = np.random.randint(0, 10, (40,))  # one label per sample
    test = mx.io.NDArrayIter(test_data, test_label,
                             batch_size=config["batch_size"], shuffle=True)
    return test


def get_model(config):
    class SimpleModel(gluon.nn.HybridBlock):
        def __init__(self, **kwargs):
            super(SimpleModel, self).__init__(**kwargs)
            self.fc1 = nn.Dense(20)
            self.fc2 = nn.Dense(40)
            self.fc3 = nn.Dense(10)

        def hybrid_forward(self, F, x1, x2):
            y1 = self.fc1(x1)
            y2 = self.fc2(x2)
            y = F.concat(y1, y2, dim=1)
            return self.fc3(y)

    net = SimpleModel()
    net.initialize(mx.init.Xavier(rnd_type="gaussian"), ctx=[mx.cpu()], force_reinit=True)
    return net


def get_loss(config):
    return gluon.loss.SoftmaxCrossEntropyLoss()


def get_metrics(config):
    return ['accuracy', mx.metric.TopKAccuracy(3)]


class TestMXNetGluonMultipleInput(TestCase):
    def test_gluon_multiple_input(self):
        config = create_config(log_interval=2, optimizer="adagrad", seed=1128,
                               optimizer_params={'learning_rate': 0.02})
        estimator = Estimator(config, get_model, get_loss,
                              eval_metrics_creator=get_metrics,
                              validation_metrics_creator=get_metrics,
                              num_workers=4)
        estimator.fit(get_train_data_iter, validation_data=get_test_data_iter, epochs=2)
        estimator.shutdown()


if __name__ == "__main__":
    pytest.main([__file__])
