Add README for MXNet LeNet example (intel-analytics#2208)
* initial readme

* update

* typo

* typo

* add batch size

* update style

* add validation for gluon

* fix style

* minor

* update

* minor

* more doc

* minor
hkvision committed Apr 14, 2020
1 parent 9adf3a6 commit 78dcd75
Showing 3 changed files with 44 additions and 17 deletions.
57 changes: 42 additions & 15 deletions python/orca/src/bigdl/orca/ray/mxnet/mxnet_runner.py
@@ -21,6 +21,7 @@
 import ray.services
 import mxnet as mx
 from mxnet import gluon
+from zoo.ray.utils import to_list
 from zoo.ray.mxnet.utils import find_free_port
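For context: the newly imported to_list is not shown in this diff, but judging from how it replaces the isinstance checks further down, it presumably normalizes a scalar or tuple into a list. A minimal sketch under that assumption:

# Hypothetical sketch of zoo.ray.utils.to_list (the real helper is not shown
# in this diff); it normalizes its argument so callers can always iterate.
def to_list(x):
    if isinstance(x, (list, tuple)):
        return list(x)
    return [x]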


@@ -100,6 +101,7 @@ def train(self, nb_epoch=1):
                     if self.metrics:
                         self.metrics.reset()  # metrics will accumulate for one batch
                     batch_start_time = time.time()
+                    epoch_start_time = time.time()
                     for i, batch in enumerate(self.train_data):
                         data = gluon.utils.split_and_load(
                             batch.data[0].astype("float32"), ctx_list=[mx.cpu()], batch_axis=0)
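A note on the Gluon call in this loop: gluon.utils.split_and_load slices a batch along batch_axis and copies each slice onto one of the contexts in ctx_list, which is how MXNet spreads a batch across devices. A standalone illustration using the standard MXNet API (the toy tensor is made up for the example):

import mxnet as mx
from mxnet import gluon

batch = mx.nd.arange(12).reshape((4, 3))  # toy batch of 4 samples
# With a single context, the whole batch lands on that one device.
parts = gluon.utils.split_and_load(batch, ctx_list=[mx.cpu()], batch_axis=0)
print(len(parts), parts[0].shape)  # 1 (4, 3)
# With two contexts, the batch is split evenly along the batch axis.
parts = gluon.utils.split_and_load(batch, ctx_list=[mx.cpu(0), mx.cpu(1)], batch_axis=0)
print(len(parts), parts[0].shape)  # 2 (2, 3)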
@@ -121,27 +123,52 @@ def train(self, nb_epoch=1):
                         self.metrics.update(label, outputs)
                         if not (i + 1) % self.config["log_interval"]:
                             # This would be logged on driver for each worker process.
-                            print_output = ""
-                            print_output \
-                                += 'Epoch[%d] Batch[%d] Speed: %f samples/sec %s=%f' \
-                                % (epoch, i,
-                                   self.config["batch_size"] / (time.time() - batch_start_time),
-                                   "loss", Ls[0].asnumpy().mean())
+                            iteration_log = \
+                                "Epoch[%d] Batch[%d] Speed: %f samples/sec %s=%f" \
+                                % (epoch, i,
+                                   self.config["batch_size"] / (time.time() - batch_start_time),
+                                   "loss", Ls[0].asnumpy().mean())
                             if self.metrics:
                                 names, accs = self.metrics.get()
-                                if not isinstance(names, list):
-                                    names = [names]
-                                    accs = [accs]
+                                names, accs = to_list(names), to_list(accs)
                                 for name, acc in zip(names, accs):
-                                    print_output += ' %s=%f' % (name, acc)
-                            self.logger.info(print_output)
+                                    iteration_log += " %s=%f" % (name, acc)
+                            self.logger.info(iteration_log)
                         batch_start_time = time.time()
-                # TODO: save checkpoints
+                    # Epoch time log
+                    self.logger.info("[Epoch %d] time cost: %f" %
+                                     (epoch, time.time() - epoch_start_time))
+                    # Epoch metrics log on train data
+                    if self.metrics:
+                        epoch_train_log = "[Epoch %d] training: " % epoch
+                        names, accs = self.metrics.get()
+                        names, accs = to_list(names), to_list(accs)
+                        for name, acc in zip(names, accs):
+                            epoch_train_log += "%s=%f " % (name, acc)
+                        self.logger.info(epoch_train_log)
+                    # Epoch metrics log on validation data if any:
+                    if self.val_data:
+                        self.metrics.reset()
+                        self.val_data.reset()
+                        for batch in self.val_data:
+                            data = gluon.utils.split_and_load(
+                                batch.data[0].astype("float32", copy=False),
+                                ctx_list=[mx.cpu()], batch_axis=0)
+                            label = gluon.utils.split_and_load(
+                                batch.label[0].astype("float32", copy=False),
+                                ctx_list=[mx.cpu()], batch_axis=0)
+                            outputs = [self.model(X) for X in data]
+                            self.metrics.update(label, outputs)
+                        epoch_val_log = "[Epoch %d] validation: " % epoch
+                        names, accs = self.metrics.get()
+                        names, accs = to_list(names), to_list(accs)
+                        for name, acc in zip(names, accs):
+                            epoch_val_log += "%s=%f " % (name, acc)
+                        self.logger.info(epoch_val_log)
+                    # TODO: save checkpoints
                 if self.metrics:
                     names, accs = self.metrics.get()
-                    if not isinstance(names, list):
-                        names = [names]
-                        accs = [accs]
+                    names, accs = to_list(names), to_list(accs)
                     for name, acc in zip(names, accs):
                         stats[name] = acc
             else:  # Symbolic API
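The repeated isinstance-to-to_list cleanup above is needed because metric.get() in MXNet returns a plain (name, value) pair for a single metric but parallel lists for a composite metric, so the result has to be normalized before zipping. A small demonstration of that asymmetry using only the standard mxnet API:

import mxnet as mx

single = mx.metric.Accuracy()
print(single.get())  # ('accuracy', nan): a plain string and a float

composite = mx.metric.CompositeEvalMetric()
composite.add(mx.metric.Accuracy())
composite.add(mx.metric.TopKAccuracy(top_k=5))
print(composite.get())  # (['accuracy', 'top_k_accuracy_5'], [nan, nan]): lists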
2 changes: 1 addition & 1 deletion python/orca/src/bigdl/orca/ray/mxnet/mxnet_trainer.py
@@ -33,7 +33,7 @@ class MXNetTrainer(object):
     optimizer_params should be a dict that accompanies the optimizer. It can contain learning_rate
     and other optimization configurations.
     log_interval should be an integer, specifying the interval for logging throughput and metrics
-    if any during the training process.
+    information (if any) during the training process.
     You can call create_trainer_config to create the config easily.
     You can specify "seed" in config to set the random seed.
     You can specify "init" in config to set the model initializer.
2 changes: 1 addition & 1 deletion python/orca/src/bigdl/orca/ray/mxnet/utils.py
@@ -25,7 +25,7 @@ def find_free_port():
         return s.getsockname()[1]
 
 
-def create_trainer_config(batch_size, optimizer="sgd", optimizer_params=None,
+def create_trainer_config(batch_size=32, optimizer="sgd", optimizer_params=None,
                           log_interval=10, seed=None, extra_config=None):
     if not optimizer_params:
         optimizer_params = {'learning_rate': 0.01}
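With the new default above, create_trainer_config can be called with no arguments at all, falling back to batch_size=32, SGD, and a 0.01 learning rate. A usage sketch based only on the signature shown above (the chosen values are illustrative, not recommendations):

from zoo.ray.mxnet.utils import create_trainer_config

# All defaults: batch_size=32, optimizer="sgd", learning_rate 0.01, log_interval=10.
config = create_trainer_config()

# Explicit settings; "seed" fixes the random seed, as documented in MXNetTrainer.
config = create_trainer_config(batch_size=64,
                               optimizer="adam",
                               optimizer_params={"learning_rate": 0.001},
                               log_interval=20,
                               seed=42)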
