-
Notifications
You must be signed in to change notification settings - Fork 729
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
refine orca tf example #2821
refine orca tf example #2821
Changes from 3 commits
3e529c1
cd18ebc
c4a23fa
9842911
dbcad66
c910a56
1edba28
06f6769
46a3f18
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,47 +3,43 @@ | |
This is an example to demonstrate how to use Analytics-Zoo's Orca TF Estimator API to run distributed | ||
Tensorflow and Keras on Spark. | ||
|
||
## Install or download Analytics Zoo | ||
Follow the instructions [here](https://analytics-zoo.github.io/master/#PythonUserGuide/install/) to install analytics-zoo via __pip__ or __download the prebuilt package__. | ||
|
||
## Environment Preparation | ||
``` | ||
pip install tensorflow==1.15 tensorflow-datasets==2.0 | ||
|
||
Download and install latest analytics whl from source forge ([here](https://sourceforge.net/projects/analytics-zoo/files/zoo-py/)). | ||
|
||
```bash | ||
conda create -y -n analytics-zoo python==3.7.7 | ||
conda activate analytics-zoo | ||
pip install analytics_zoo-${VERSION}-${TIMESTAMP}-py2.py3-none-${OS}_x86_64.whl | ||
pip install tensorflow==1.15.0 | ||
pip install psutil | ||
``` | ||
|
||
## Model Preparation | ||
Note: conda environment is required to run on Yarn, but not strictly necessary for running on local. | ||
|
||
In this example, we will use the **slim** library to construct the model. You can | ||
clone it [here](https://github.com/tensorflow/models/tree/master/research/slim) and add | ||
the `research/slim` directory to `PYTHONPATH`. | ||
## Run examples on local | ||
|
||
```bash | ||
git clone https://github.com/tensorflow/models/ | ||
export PYTHONPATH=$PWD/models/research/slim:$PYTHONPATH | ||
python lenet_mnist_graph.py --cluster_mode local | ||
``` | ||
|
||
## Run tf graph model example after pip install | ||
|
||
```bash | ||
python lenet_mnist_graph.py | ||
python lenet_mnist_keras.py --cluster_mode local | ||
``` | ||
## Run tf graph model example with prebuilt package | ||
|
||
## Run examples on yarn cluster | ||
```bash | ||
export ANALYTICS_ZOO_HOME=... # the directory where you extract the downloaded Analytics Zoo zip package | ||
export SPARK_HOME=... # the root directory of Spark | ||
bash $ANALYTICS_ZOO_HOME/bin/spark-submit-python-with-zoo.sh --master local[4] lenet_mnist_graph.py | ||
python lenet_mnist_graph.py --cluster_mode yarn-client --num_nodes 2 --cores 4 --memory 4g | ||
``` | ||
|
||
## Run tf keras model example after pip install | ||
```bash | ||
python lenet_mnist_keras.py | ||
python lenet_mnist_keras.py --cluster_mode yarn-client --num_nodes 2 --cores 4 --memory 4g | ||
``` | ||
|
||
## Run tf keras model example with prebuilt package | ||
```bash | ||
export ANALYTICS_ZOO_HOME=... # the directory where you extract the downloaded Analytics Zoo zip package | ||
export SPARK_HOME=... # the root directory of Spark | ||
bash $ANALYTICS_ZOO_HOME/bin/spark-submit-python-with-zoo.sh --master local[4] lenet_mnist_keras.py | ||
``` | ||
## Additional Resources | ||
The application is also be able to run on spark standalone cluster or in yarn cluster mode. | ||
Please refer to the following links to learn more details. | ||
|
||
1. [Orca Overview](https://analytics-zoo.github.io/master/#Orca/overview/) and [`init_orca_context`](link_to_be_added) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
2. [Download and install Analytics Zoo](https://analytics-zoo.github.io/master/#PythonUserGuide/install/) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,47 +13,50 @@ | |
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
import tensorflow as tf | ||
import tensorflow_datasets as tfds | ||
import sys | ||
import argparse | ||
|
||
import tensorflow as tf | ||
from zoo.orca.data import XShards | ||
from zoo.orca.learn.tf.estimator import Estimator | ||
from zoo.orca import init_orca_context, stop_orca_context | ||
|
||
sys.path.append("/tmp/models/slim") # add the slim library | ||
from nets import lenet | ||
|
||
slim = tf.contrib.slim | ||
|
||
|
||
def accuracy(logits, labels): | ||
predictions = tf.argmax(logits, axis=1, output_type=labels.dtype) | ||
is_correct = tf.cast(tf.equal(predictions, labels), dtype=tf.float32) | ||
return tf.reduce_mean(is_correct) | ||
|
||
|
||
def main(max_epoch): | ||
sc = init_orca_context(cores=4, memory="2g") | ||
def lenet(images, is_training): | ||
with tf.variable_scope('LeNet', [images]): | ||
net = tf.layers.conv2d(images, 32, (5, 5), activation=tf.nn.relu, name='conv1') | ||
net = tf.layers.max_pooling2d(net, (2, 2), 2, name='pool1') | ||
net = tf.layers.conv2d(net, 64, (5, 5), activation=tf.nn.relu, name='conv2') | ||
net = tf.layers.max_pooling2d(net, (2, 2), 2, name='pool2') | ||
net = tf.layers.flatten(net) | ||
net = tf.layers.dense(net, 1024, activation=tf.nn.relu, name='fc3') | ||
net = tf.layers.dropout( | ||
net, 0.5, training=is_training, name='dropout3') | ||
logits = tf.layers.dense(net, 10) | ||
return logits | ||
|
||
# get DataSet | ||
mnist_train = tfds.load(name="mnist", split="train") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. tfds requires every node's local file system has the pre-downloaded data files. It is a little complex to set it up on yarn. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it will be automatically downloaded? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it will be only downloaded on driver. |
||
mnist_test = tfds.load(name="mnist", split="test") | ||
|
||
# Normalizes images | ||
def normalize_img(data): | ||
data['image'] = tf.cast(data["image"], tf.float32) / 255. | ||
return data | ||
def main(max_epoch): | ||
|
||
mnist_train = mnist_train.map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE) | ||
mnist_test = mnist_test.map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE) | ||
(train_feature, train_label), (val_feature, val_label) = tf.keras.datasets.mnist.load_data() | ||
train_feature = train_feature.reshape(-1, 28, 28, 1) / 255.0 | ||
val_feature = val_feature.reshape(-1, 28, 28, 1) / 255.0 | ||
train_data = XShards.partition({"x": train_feature, "y": train_label}) | ||
val_data = XShards.partition({"x": val_feature, "y": val_label}) | ||
|
||
# tensorflow inputs | ||
images = tf.placeholder(dtype=tf.float32, shape=(None, 28, 28, 1)) | ||
# tensorflow labels | ||
labels = tf.placeholder(dtype=tf.int32, shape=(None,)) | ||
|
||
with slim.arg_scope(lenet.lenet_arg_scope()): | ||
logits, end_points = lenet.lenet(images, num_classes=10, is_training=True) | ||
is_training = tf.placeholder_with_default(False, shape=()) | ||
|
||
logits = lenet(images, is_training=is_training) | ||
|
||
loss = tf.reduce_mean(tf.losses.sparse_softmax_cross_entropy(logits=logits, labels=labels)) | ||
|
||
|
@@ -66,21 +69,36 @@ def normalize_img(data): | |
loss=loss, | ||
optimizer=tf.train.AdamOptimizer(), | ||
metrics={"acc": acc}) | ||
est.fit(data=mnist_train, | ||
est.fit(data=train_data, | ||
batch_size=320, | ||
epochs=max_epoch, | ||
validation_data=mnist_test) | ||
validation_data=val_data, | ||
feed_dict={is_training: (True, False)}) | ||
|
||
result = est.evaluate(mnist_test) | ||
result = est.evaluate(val_data) | ||
print(result) | ||
|
||
est.save_tf_checkpoint("/tmp/lenet/model") | ||
stop_orca_context() | ||
|
||
|
||
if __name__ == '__main__': | ||
max_epoch = 5 | ||
|
||
if len(sys.argv) > 1: | ||
max_epoch = int(sys.argv[1]) | ||
main(max_epoch) | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--cluster_mode', type=str, default="local", | ||
help='The mode for the Spark cluster.') | ||
parser.add_argument("--num_nodes", type=int, default=1, | ||
help="The number of nodes to be used in the cluster. " | ||
"You can change it depending on your own cluster setting.") | ||
parser.add_argument("--cores", type=int, default=4, | ||
help="The number of cpu cores you want to use on each node. " | ||
"You can change it depending on your own cluster setting.") | ||
parser.add_argument("--memory", type=str, default="10g", | ||
help="The memory you want to use on each node. " | ||
"You can change it depending on your own cluster setting.") | ||
|
||
parser.add_argument("--max_epoch", type=int, default=5) | ||
|
||
args = parser.parse_args() | ||
init_orca_context(cluster_mode=args.cluster_mode, cores=args.cores, | ||
num_nodes=args.num_nodes, memory=args.memory) | ||
main(args.max_epoch) | ||
stop_orca_context() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Change to this: https://analytics-zoo.github.io/master/#PythonUserGuide/install/#install-the-latest-nightly-build-wheels-for-pip