Orca tf lenet mnist slim & keras example #2792

Merged: cyita merged 7 commits into intel-analytics:master from cyita:orca-tf-lenet-mnist-example on Aug 31, 2020
Commits (7):
- 23e520b add orca tf lenet mnist slim example (cyita)
- 66a2447 Delete lenet_mnist_keras.py (cyita)
- 4bb326c Merge remote-tracking branch 'upstream/master' into orca-tf-lenet-mni… (cyita)
- 2c62e05 add keras example (cyita)
- 3e7d20a merge upstream (cyita)
- d0190c5 keras change to lenet (cyita)
- e46144d add comments (cyita)
@@ -0,0 +1,49 @@
# Orca TF Estimator

This is an example to demonstrate how to use Analytics Zoo's Orca TF Estimator API to run distributed TensorFlow and Keras on Spark.

## Install or download Analytics Zoo
Follow the instructions [here](https://analytics-zoo.github.io/master/#PythonUserGuide/install/) to install analytics-zoo via __pip__ or __download the prebuilt package__.

## Environment Preparation
```
pip install tensorflow==1.15 tensorflow-datasets==2.0
pip install psutil
```

## Model Preparation

In this example, we use the **slim** library to construct the model. You can clone it [here](https://github.com/tensorflow/models/tree/master/research/slim) and add the `research/slim` directory to `PYTHONPATH`:

```bash
git clone https://github.com/tensorflow/models/
export PYTHONPATH=$PWD/models/research/slim:$PYTHONPATH
```

## Run the TF graph model example after pip install

```bash
python lenet_mnist_graph.py
```

## Run the TF graph model example with the prebuilt package

```bash
export ANALYTICS_ZOO_HOME=...  # the directory where you extracted the downloaded Analytics Zoo zip package
export SPARK_HOME=...  # the root directory of Spark
bash $ANALYTICS_ZOO_HOME/bin/spark-submit-python-with-zoo.sh --master local[4] lenet_mnist_graph.py
```

## Run the TF Keras model example after pip install

```bash
python lenet_mnist_keras.py
```

## Run the TF Keras model example with the prebuilt package

```bash
export ANALYTICS_ZOO_HOME=...  # the directory where you extracted the downloaded Analytics Zoo zip package
export SPARK_HOME=...  # the root directory of Spark
bash $ANALYTICS_ZOO_HOME/bin/spark-submit-python-with-zoo.sh --master local[4] lenet_mnist_keras.py
```
@@ -0,0 +1,15 @@
#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
@@ -0,0 +1,86 @@
#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys

import tensorflow as tf
import tensorflow_datasets as tfds

from zoo.orca import init_orca_context, stop_orca_context
from zoo.orca.learn.tf.estimator import Estimator

sys.path.append("/tmp/models/slim")  # add the slim library
from nets import lenet

slim = tf.contrib.slim


def accuracy(logits, labels):
    predictions = tf.argmax(logits, axis=1, output_type=labels.dtype)
    is_correct = tf.cast(tf.equal(predictions, labels), dtype=tf.float32)
    return tf.reduce_mean(is_correct)


def main(max_epoch):
    sc = init_orca_context(cluster_mode="local", cores=4, memory="2g")

    # get the MNIST DataSet
    mnist_train = tfds.load(name="mnist", split="train")
    mnist_test = tfds.load(name="mnist", split="test")

    # normalize images: uint8 -> float32 in [0, 1]
    def normalize_img(data):
        data['image'] = tf.cast(data["image"], tf.float32) / 255.
        return data

    mnist_train = mnist_train.map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    mnist_test = mnist_test.map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # TensorFlow image inputs
    images = tf.placeholder(dtype=tf.float32, shape=(None, 28, 28, 1))

    # TensorFlow labels
    labels = tf.placeholder(dtype=tf.int32, shape=(None,))

    with slim.arg_scope(lenet.lenet_arg_scope()):
        logits, end_points = lenet.lenet(images, num_classes=10, is_training=True)

    loss = tf.reduce_mean(tf.losses.sparse_softmax_cross_entropy(logits=logits, labels=labels))

    acc = accuracy(logits, labels)

    # create an estimator from the TF graph
    est = Estimator.from_graph(inputs=images,
                               outputs=logits,
                               labels=labels,
                               loss=loss,
                               optimizer=tf.train.AdamOptimizer(),
                               metrics={"acc": acc})
    est.fit(data=mnist_train,
            batch_size=320,
            epochs=max_epoch,
            validation_data=mnist_test)

    result = est.evaluate(mnist_test)
    print(result)

    est.save_tf_checkpoint("/tmp/lenet/model")
    stop_orca_context()


if __name__ == '__main__':
    max_epoch = 5

    if len(sys.argv) > 1:
        max_epoch = int(sys.argv[1])
    main(max_epoch)
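The `accuracy` helper above takes the mean of per-example correctness (argmax of the logits compared against the integer labels). The same computation in plain Python, without TensorFlow (a sketch over already-argmaxed predictions):

```python
def accuracy(predictions, labels):
    # fraction of positions where the predicted class equals the true label
    correct = [1.0 if p == y else 0.0 for p, y in zip(predictions, labels)]
    return sum(correct) / len(correct)


# three of four predictions match the labels
print(accuracy([1, 2, 3, 4], [1, 2, 0, 4]))  # 0.75
```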
@@ -0,0 +1,76 @@
#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys

import tensorflow as tf
import tensorflow_datasets as tfds

from zoo.orca import init_orca_context, stop_orca_context
from zoo.orca.learn.tf.estimator import Estimator


def main(max_epoch):
    sc = init_orca_context(cluster_mode="local", cores=4, memory="2g")

    # get the MNIST DataSet
    # as_supervised returns a tuple (img, label) instead of a dict {'image': img, 'label': label}
    mnist_train = tfds.load(name="mnist", split="train", as_supervised=True)
    mnist_test = tfds.load(name="mnist", split="test", as_supervised=True)

    # normalize images: uint8 -> float32
    def normalize_img(image, label):
        return tf.cast(image, tf.float32) / 255., label

    mnist_train = mnist_train.map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    mnist_test = mnist_test.map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    model = tf.keras.Sequential(
        [tf.keras.layers.Conv2D(20, kernel_size=(5, 5), strides=(1, 1), activation='tanh',
                                input_shape=(28, 28, 1), padding='valid'),
         tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='valid'),
         tf.keras.layers.Conv2D(50, kernel_size=(5, 5), strides=(1, 1), activation='tanh',
                                padding='valid'),
         tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='valid'),
         tf.keras.layers.Flatten(),
         tf.keras.layers.Dense(500, activation='tanh'),
         tf.keras.layers.Dense(10, activation='softmax'),
         ]
    )

    model.compile(optimizer=tf.keras.optimizers.RMSprop(),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    est = Estimator.from_keras(keras_model=model)
    est.fit(data=mnist_train,
            batch_size=320,
            epochs=max_epoch,
            validation_data=mnist_test)

    result = est.evaluate(mnist_test)
    print(result)

    est.save_keras_model("/tmp/mnist_keras.h5")
    stop_orca_context()


if __name__ == '__main__':
    max_epoch = 5

    if len(sys.argv) > 1:
        max_epoch = int(sys.argv[1])
    main(max_epoch)
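As a sanity check on the Keras LeNet above, the spatial dimensions can be traced by hand. With 'valid' padding, a convolution gives out = (in - kernel) // stride + 1, and pooling follows the same formula with the pool size as the kernel:

```python
def conv_out(size, kernel, stride=1):
    # 'valid' padding: output = (input - kernel) // stride + 1
    return (size - kernel) // stride + 1


def pool_out(size, pool, stride):
    # same formula, with the pool window as the kernel
    return (size - pool) // stride + 1


s = 28                  # MNIST input is 28x28x1
s = conv_out(s, 5)      # Conv2D(20, 5x5)  -> 24x24
s = pool_out(s, 2, 2)   # MaxPool 2x2 /2   -> 12x12
s = conv_out(s, 5)      # Conv2D(50, 5x5)  -> 8x8
s = pool_out(s, 2, 2)   # MaxPool 2x2 /2   -> 4x4
flat = s * s * 50       # Flatten          -> 800 features into Dense(500)
print(s, flat)          # 4 800
```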
Review comment: Maybe we should add this example into auto-test. :)