Skip to content

Commit

Permalink
Orca MXNetTrainer migration (#2320)
Browse files Browse the repository at this point in the history
* migrate mxnet_trainer to estimator

* fix

* newline

* indent

* fix test path

* fix

* ignore mxnet estimator test in spark2.4-

* estimator to trainer

* style
  • Loading branch information
cyita authored and yangw1234 committed Sep 27, 2021
1 parent f56582a commit bfcde04
Show file tree
Hide file tree
Showing 9 changed files with 609 additions and 0 deletions.
3 changes: 3 additions & 0 deletions python/orca/src/bigdl/orca/learn/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

from .mxnet_trainer import MXNetTrainer
from .utils import create_trainer_config
215 changes: 215 additions & 0 deletions python/orca/src/bigdl/orca/learn/mxnet/mxnet_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import time
import logging
import subprocess
import ray.services
import mxnet as mx
from mxnet import gluon
from zoo.ray.utils import to_list
from zoo.orca.learn.mxnet.utils import find_free_port


class MXNetRunner(object):
    """Manages one MXNet process for synchronous distributed training.

    Each runner is deployed as a Ray actor and plays one DMLC role:
    - "worker": holds the data, model, loss and metrics and runs training
      against a "dist_sync" kvstore.
    - "server": only needs an ``import mxnet`` executed in a subprocess with
      the proper DMLC_* environment; MXNet then runs the parameter server
      inside that process.
    """

    def setup_distributed(self, env, config, data_creator, model_creator,
                          loss_creator=None, metrics_creator=None):
        """Set up the distributed environment, data and model on this runner.

        :param env: Dict of DMLC_* environment variables describing this
            process' role in the MXNet parameter-server topology. Must
            contain "DMLC_ROLE" ("worker" or "server").
        :param config: Training configuration dict. Must contain batch_size,
            optimizer, optimizer_params and log_interval.
        :param data_creator: Function (config, kv) -> train data, or a tuple
            of (train data, validation data).
        :param model_creator: Function config -> MXNet model. Either a
            symbolic mx.module.BaseModule or a gluon model.
        :param loss_creator: Function config -> loss. Required for gluon
            models; symbolic models already define loss as part of the graph.
        :param metrics_creator: Function config -> metrics. Required when
            validation data is provided.
        """
        logging.basicConfig(level=logging.INFO)  # This can print log messages to console.
        self.logger = logging.getLogger()
        assert isinstance(config, dict), "config must be a dict"
        for param in ["batch_size", "optimizer", "optimizer_params", "log_interval"]:
            assert param in config, param + " must be specified in config"
        self.config = config
        self.data_creator = data_creator
        self.model_creator = model_creator
        self.loss_creator = loss_creator
        self.metrics_creator = metrics_creator
        self.is_worker = False
        env["DMLC_NODE_HOST"] = self.get_node_ip()
        if env["DMLC_ROLE"] == "worker":
            self.is_worker = True

        if self.is_worker:
            os.environ.update(env)
            self.kv = mx.kv.create("dist_sync")
            # Set seed so that the model on each worker is initialized with the same weights
            if "seed" in self.config:
                mx.random.seed(self.config["seed"])

            data = self.data_creator(self.config, self.kv)
            if isinstance(data, tuple):
                assert len(data) == 1 or len(data) == 2, \
                    "data_creator should return either train_data only or a tuple of " \
                    "(train_data, val_data), which can be directly fed to model training"
                if len(data) == 1:
                    self.train_data, self.val_data = data[0], None
                else:
                    self.train_data, self.val_data = data
            else:  # Only return one object, supposed to be train data.
                self.train_data, self.val_data = data, None
            self.model = self.model_creator(self.config)
            if self.loss_creator:
                self.loss = self.loss_creator(self.config)
            else:
                self.loss = None
            if self.val_data:
                assert self.metrics_creator, \
                    "Metrics not defined for validation, please specify metrics_creator"
                self.metrics = self.metrics_creator(self.config)
            else:
                self.metrics = None
            # For BaseModule, use symbolic API. Otherwise, use imperative API.
            # TODO: change to Estimator API?
            if not isinstance(self.model, mx.module.BaseModule):
                assert self.loss, "Loss not defined for gluon model, please specify loss_creator"
                self.trainer = gluon.Trainer(self.model.collect_params(), self.config["optimizer"],
                                             optimizer_params=self.config["optimizer_params"],
                                             kvstore=self.kv)
            else:  # Trainer is not needed for symbolic API.
                self.trainer = None
        else:  # server
            # Need to use the environment on each raylet process for the correct python environment.
            # TODO: Need to kill this process manually?
            modified_env = os.environ.copy()
            modified_env.update(env)
            # For servers, just import mxnet and no need to do anything else
            subprocess.Popen("python -c 'import mxnet'", shell=True, env=modified_env)

    def train(self, nb_epoch=1):
        """Train the model for nb_epoch epochs and update its parameters.

        Only workers train; servers return an empty dict.

        :param nb_epoch: Number of epochs to train. Default is 1.
        :return: A dict of training metric values (for gluon models) plus
            the total "epoch_time" in seconds.
        """
        stats = dict()
        if self.is_worker:
            start_time = time.time()
            if self.trainer:  # Imperative API
                # Hoisted out of the batch loop; the import is cached by
                # Python but there is no reason to re-execute it per batch.
                from mxnet import autograd as ag
                for epoch in range(nb_epoch):
                    self.train_data.reset()
                    if self.metrics:
                        self.metrics.reset()  # metrics will accumulate for one batch
                    batch_start_time = time.time()
                    epoch_start_time = time.time()
                    for i, batch in enumerate(self.train_data):
                        data = gluon.utils.split_and_load(
                            batch.data[0].astype("float32"), ctx_list=[mx.cpu()], batch_axis=0)
                        label = gluon.utils.split_and_load(
                            batch.label[0].astype("float32"), ctx_list=[mx.cpu()], batch_axis=0)
                        outputs = []
                        Ls = []
                        with ag.record():
                            for x, y in zip(data, label):
                                z = self.model(x)  # forward
                                L = self.loss(z, y)
                                # store the loss and do backward on a batch for better speed
                                Ls.append(L)
                                outputs.append(z)
                            ag.backward(Ls)
                        self.trainer.step(batch.data[0].shape[0])
                        if self.metrics:
                            self.metrics.update(label, outputs)
                        if not (i + 1) % self.config["log_interval"]:
                            # This would be logged on driver for each worker process.
                            # batch_start_time was reset at the previous log, so the
                            # elapsed time covers log_interval batches; multiply
                            # batch_size accordingly so samples/sec is accurate
                            # (same formula as mx.callback.Speedometer).
                            samples = self.config["batch_size"] * self.config["log_interval"]
                            iteration_log = \
                                "Epoch[%d] Batch[%d] Speed: %f samples/sec %s=%f" \
                                % (epoch, i,
                                   samples / (time.time() - batch_start_time),
                                   "loss", Ls[0].asnumpy().mean())
                            if self.metrics:
                                names, accs = self.metrics.get()
                                names, accs = to_list(names), to_list(accs)
                                for name, acc in zip(names, accs):
                                    iteration_log += " %s=%f" % (name, acc)
                            self.logger.info(iteration_log)
                            batch_start_time = time.time()
                    # Epoch time log
                    self.logger.info("[Epoch %d] time cost: %f" %
                                     (epoch, time.time() - epoch_start_time))
                    # Epoch metrics log on train data
                    if self.metrics:
                        epoch_train_log = "[Epoch %d] training: " % epoch
                        names, accs = self.metrics.get()
                        names, accs = to_list(names), to_list(accs)
                        for name, acc in zip(names, accs):
                            epoch_train_log += "%s=%f " % (name, acc)
                        self.logger.info(epoch_train_log)
                    # Epoch metrics log on validation data if any:
                    if self.val_data:
                        self.metrics.reset()
                        self.val_data.reset()
                        for batch in self.val_data:
                            data = gluon.utils.split_and_load(
                                batch.data[0].astype("float32", copy=False),
                                ctx_list=[mx.cpu()], batch_axis=0)
                            label = gluon.utils.split_and_load(
                                batch.label[0].astype("float32", copy=False),
                                ctx_list=[mx.cpu()], batch_axis=0)
                            outputs = [self.model(X) for X in data]
                            self.metrics.update(label, outputs)
                        epoch_val_log = "[Epoch %d] validation: " % epoch
                        names, accs = self.metrics.get()
                        names, accs = to_list(names), to_list(accs)
                        for name, acc in zip(names, accs):
                            epoch_val_log += "%s=%f " % (name, acc)
                        self.logger.info(epoch_val_log)
                    # TODO: save checkpoints
                if self.metrics:
                    names, accs = self.metrics.get()
                    names, accs = to_list(names), to_list(accs)
                    for name, acc in zip(names, accs):
                        stats[name] = acc
            else:  # Symbolic API
                # TODO: seems no history (i.e. validation accuracy) returned by fit?
                if "init" not in self.config:
                    from mxnet.initializer import Uniform
                    self.config["init"] = Uniform(0.01)  # This is the default value for MXNet
                self.model.fit(train_data=self.train_data,
                               num_epoch=nb_epoch,
                               initializer=self.config["init"],
                               kvstore=self.kv,
                               optimizer=self.config["optimizer"],
                               optimizer_params=self.config["optimizer_params"],
                               eval_data=self.val_data,
                               # TODO: eval and validation metrics could be different
                               eval_metric=self.metrics,
                               validation_metric=self.metrics,
                               batch_end_callback=mx.callback.Speedometer(
                                   self.config["batch_size"], self.config["log_interval"]),
                               epoch_end_callback=None if "model" not in self.config
                               else mx.callback.do_checkpoint(self.config["model"]))
            epoch_time = time.time() - start_time
            stats["epoch_time"] = epoch_time
        return stats

    def shutdown(self):
        """Attempts to shut down the runner by releasing training state."""
        del self.logger
        if self.is_worker:
            del self.kv
            del self.model
            del self.train_data
            del self.val_data
            del self.trainer
            del self.loss
        # TODO: also delete downloaded data as well?

    def get_node_ip(self):
        """Returns the IP address of the current node."""
        return ray.services.get_node_ip_address()

    def find_free_port(self):
        """Finds a free port on the current node."""
        return find_free_port()
146 changes: 146 additions & 0 deletions python/orca/src/bigdl/orca/learn/mxnet/mxnet_trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import logging
import subprocess
import ray.services
from dmlc_tracker.tracker import get_host_ip
from zoo.orca.learn.mxnet.mxnet_runner import MXNetRunner
from zoo.orca.learn.mxnet.utils import find_free_port


class MXNetTrainer(object):
    """
    MXNetTrainer provides an automatic setup for synchronous distributed MXNet training.
    :param config: A dictionary for training configurations. Keys must include the following:
    batch_size, optimizer, optimizer_params, log_interval.
    optimizer should be an MXNet optimizer or its string representation.
    optimizer_params should be a dict in companion with the optimizer. It can contain learning_rate
    and other optimization configurations.
    log_interval should be an integer, specifying the interval for logging throughput and metrics
    information (if any) during the training process.
    You can call create_trainer_config to create the config easily.
    You can specify "seed" in config to set random seed.
    You can specify "init" in config to set model initializer.
    :param data_creator: A function that takes config and kv as arguments and returns an MXNet
    DataIter/DataLoader for training or a tuple of training and validation datasets.
    You can specify data related configurations for this function in the config argument above.
    kv is an instance of MXNet distributed key-value store. kv.num_workers and kv.rank
    can be used in this function to split data for different workers if necessary.
    :param model_creator: A function that takes config as argument and returns an MXNet model.
    The model can be defined either using MXNet symbolic API or imperative(gluon) API.
    :param loss_creator: A function that takes config as argument and returns an MXNet loss.
    This is not needed for symbolic API where loss is already defined as model output.
    :param metrics_creator: A function that takes config as argument and returns one or a list of
    MXNet metrics or corresponding string representations of metrics, for example, 'accuracy'.
    This is not needed if you don't have validation data throughout the training.
    :param num_workers: The number of workers for distributed training. Default is 1.
    :param num_servers: The number of servers for distributed training. Default is None and in this
    case it would be equal to the number of workers.
    :param runner_cores: The number of CPU cores allocated for each MXNet worker and server.
    Default is None. You may need to specify this for better performance.
    """
    def __init__(self, config, data_creator, model_creator,
                 loss_creator=None, metrics_creator=None,
                 num_workers=1, num_servers=None, runner_cores=None):
        self.config = config
        self.data_creator = data_creator
        self.model_creator = model_creator
        self.loss_creator = loss_creator
        self.metrics_creator = metrics_creator
        self.num_workers = num_workers
        self.num_servers = num_servers if num_servers else self.num_workers

        # Generate actor class
        # Add a dummy custom resource: _mxnet_worker and _mxnet_server to diff worker from server
        # if runner_cores is specified so that we can place one worker and one server on a node
        # for better performance.
        Worker = ray.remote(num_cpus=runner_cores, resources={"_mxnet_worker": 1})(MXNetRunner) \
            if runner_cores else ray.remote(MXNetRunner)
        Server = ray.remote(num_cpus=runner_cores, resources={"_mxnet_server": 1})(MXNetRunner) \
            if runner_cores else ray.remote(MXNetRunner)

        # Start runners: workers followed by servers
        self.runners = [
            Worker.remote()
            for i in range(self.num_workers)
        ]
        self.runners += [
            Server.remote()
            for i in range(self.num_servers)
        ]

        # Compute URL for initializing distributed setup
        ips = ray.get(
            [runner.get_node_ip.remote() for runner in self.runners])
        ports = ray.get(
            [runner.find_free_port.remote() for runner in self.runners])
        logger = logging.getLogger()
        logger.info(ips)
        logger.info(ports)

        # DMLC_* variables shared by every role; the scheduler runs on the
        # driver node, so its host/port are computed locally.
        env = {
            "DMLC_PS_ROOT_URI": str(get_host_ip()),
            "DMLC_PS_ROOT_PORT": str(find_free_port()),
            "DMLC_NUM_SERVER": str(self.num_servers),
            "DMLC_NUM_WORKER": str(self.num_workers),
        }
        # One env per runner, in the same order the runners were created:
        # workers first, then servers.
        envs = []
        for i in range(self.num_workers):
            current_env = env.copy()
            current_env['DMLC_ROLE'] = 'worker'
            envs.append(current_env)
        for i in range(self.num_servers):
            current_env = env.copy()
            current_env['DMLC_ROLE'] = 'server'
            envs.append(current_env)

        env['DMLC_ROLE'] = 'scheduler'
        modified_env = os.environ.copy()
        modified_env.update(env)
        # Need to contain system env to run bash
        # TODO: Need to kill this process manually?
        subprocess.Popen("python -c 'import mxnet'", shell=True, env=modified_env)

        ray.get([
            runner.setup_distributed.remote(envs[i], self.config,
                                            self.data_creator,
                                            self.model_creator,
                                            self.loss_creator,
                                            self.metrics_creator)
            for i, runner in enumerate(self.runners)
        ])

    def train(self, nb_epoch=1):
        """Trains an MXNet model for several epochs.

        :param nb_epoch: Number of epochs to train. Default is 1.
        :return: A list of stats dicts, one per runner (servers contribute
            empty dicts; see MXNetRunner.train).
        """
        stats = ray.get([w.train.remote(nb_epoch) for w in self.runners])
        return stats

    def shutdown(self):
        """Shuts down runners and releases resources."""
        # Wait for shutdown to complete (and surface any errors) before
        # terminating the actors; fire-and-forget would silently drop
        # cleanup failures.
        ray.get([runner.shutdown.remote() for runner in self.runners])
        for runner in self.runners:
            runner.__ray_terminate__.remote()

# TODO: add model save and restore
# TODO: add predict, evaluate
Loading

0 comments on commit bfcde04

Please sign in to comment.