Skip to content

Commit

Permalink
[Python] Support loading of TF models with saved weights (#25496)
Browse files Browse the repository at this point in the history
* load model with weight

* example

* update test

* update test

* make create model fn optional

* change tf to tensorflow

* add readme and change urls

* fix whitespace

* add doc and changes.md

* add tensorflow dependency

* remove tf dependency
  • Loading branch information
riteshghorse committed Feb 22, 2023
1 parent 3a62599 commit 33750c1
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
container was based upon Debian 11.
* RunInference PTransform will accept model paths as SideInputs in Python SDK. ([#24042](https://github.com/apache/beam/issues/24042))
* RunInference supports ONNX runtime in Python SDK ([#22972](https://github.com/apache/beam/issues/22972))
* Tensorflow Model Handler for RunInference in Python SDK ([#25366](https://github.com/apache/beam/issues/25366))

## I/Os

Expand Down
53 changes: 53 additions & 0 deletions sdks/python/apache_beam/examples/inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -523,3 +523,56 @@ background
...
```
Each line has a list of predicted label.

---
## MNIST digit classification with Tensorflow using Saved Model Weights
[`tensorflow_mnist_with_weights.py`](./tensorflow_mnist_with_weights.py) contains an implementation for a RunInference pipeline that performs image classification on handwritten digits from the [MNIST](https://en.wikipedia.org/wiki/MNIST_database) database.

The pipeline reads rows of pixels corresponding to a digit, performs basic preprocessing(converts the input shape to 28x28), passes the pixels to the trained Tensorflow model with RunInference, and then writes the predictions to a text file.

The model is loaded from the saved model weights. This can be done by passing a function which creates the model and setting the model type as
`ModelType.SAVED_WEIGHTS` to the `TFModelHandler`. The path to saved weights saved using `model.save_weights(path)` should be passed to the `model_path` argument.

### Dataset and model for language modeling

To use this transform, you need a dataset and model for language modeling.

1. Create a file named [`INPUT.csv`](gs://apache-beam-ml/testing/inputs/it_mnist_data.csv) that contains labels and pixels to feed into the model. Each row should have comma-separated elements. The first element is the label. All other elements are pixel values. The csv should not have column headers. The content of the file should be similar to the following example:
```
1,0,0,0...
0,0,0,0...
1,0,0,0...
4,0,0,0...
...
```
2. Save the weights of trained tensorflow model to a directory `SAVED_WEIGHTS_DIR` .


### Running `tensorflow_mnist_with_weights.py`

To run the MNIST classification pipeline locally, use the following command:
```sh
python -m apache_beam.examples.inference.tensorflow_mnist_with_weights.py \
--input INPUT \
--output OUTPUT \
--model_path SAVED_WEIGHTS_DIR
```
For example:
```sh
python -m apache_beam.examples.inference.tensorflow_mnist_with_weights.py \
--input INPUT.csv \
--output predictions.txt \
--model_path SAVED_WEIGHTS_DIR
```

This writes the output to the `predictions.txt` with contents like:
```
1,1
4,4
0,0
7,7
3,3
5,5
...
```
Each line has data separated by a comma ",". The first item is the actual label of the digit. The second item is the predicted label of the digit.
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import logging

import apache_beam as beam
import tensorflow as tf
from apache_beam.examples.inference.tensorflow_mnist_classification import PostProcessor
from apache_beam.examples.inference.tensorflow_mnist_classification import parse_known_args
from apache_beam.examples.inference.tensorflow_mnist_classification import process_input
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.tensorflow_inference import ModelType
from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerNumpy
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.runners.runner import PipelineResult


def get_model():
inputs = tf.keras.layers.Input(shape=(28, 28, 1))
x = tf.keras.layers.Conv2D(32, 3, activation="relu")(inputs)
x = tf.keras.layers.Conv2D(32, 3, activation="relu")(x)
x = tf.keras.layers.MaxPooling2D(2)(x)
x = tf.keras.layers.Conv2D(64, 3, activation="relu")(x)
x = tf.keras.layers.Conv2D(64, 3, activation="relu")(x)
x = tf.keras.layers.MaxPooling2D(2)(x)
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = tf.keras.layers.Dense(10, activation='softmax')(x)
model = tf.keras.Model(inputs, outputs)
return model


def run(
argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult:
"""
Args:
argv: Command line arguments defined for this example.
save_main_session: Used for internal testing.
test_pipeline: Used for internal testing.
"""
known_args, pipeline_args = parse_known_args(argv)
pipeline_options = PipelineOptions(pipeline_args)
pipeline_options.view_as(SetupOptions).save_main_session = save_main_session

# In this example we pass keyed inputs to RunInference transform.
# Therefore, we use KeyedModelHandler wrapper over TFModelHandlerNumpy.
model_loader = KeyedModelHandler(
TFModelHandlerNumpy(
model_uri=known_args.model_path,
model_type=ModelType.SAVED_WEIGHTS,
create_model_fn=get_model))

pipeline = test_pipeline
if not test_pipeline:
pipeline = beam.Pipeline(options=pipeline_options)

label_pixel_tuple = (
pipeline
| "ReadFromInput" >> beam.io.ReadFromText(known_args.input)
| "PreProcessInputs" >> beam.Map(process_input))

predictions = (
label_pixel_tuple
| "RunInference" >> RunInference(model_loader)
| "PostProcessOutputs" >> beam.ParDo(PostProcessor()))

_ = predictions | "WriteOutput" >> beam.io.WriteToText(
known_args.output, shard_name_template='', append_trailing_newlines=True)

result = pipeline.run()
result.wait_until_finish()
return result


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
run()
29 changes: 29 additions & 0 deletions sdks/python/apache_beam/ml/inference/tensorflow_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
class ModelType(enum.Enum):
"""Defines how a model file should be loaded."""
SAVED_MODEL = 1
SAVED_WEIGHTS = 2


def _load_model(model_uri, model_type):
Expand All @@ -61,6 +62,12 @@ def _load_model(model_uri, model_type):
raise AssertionError('Unsupported model type for loading.')


def _load_model_from_weights(create_model_fn, weights_path):
model = create_model_fn()
model.load_weights(weights_path)
return model


def default_numpy_inference_fn(
model: tf.Module,
batch: Sequence[numpy.ndarray],
Expand Down Expand Up @@ -88,6 +95,7 @@ def __init__(
self,
model_uri: str,
model_type: ModelType = ModelType.SAVED_MODEL,
create_model_fn: Optional[Callable] = None,
*,
inference_fn: TensorInferenceFn = default_numpy_inference_fn):
"""Implementation of the ModelHandler interface for Tensorflow.
Expand All @@ -101,6 +109,9 @@ def __init__(
Args:
model_uri (str): path to the trained model.
model_type: type of model to be loaded. Defaults to SAVED_MODEL.
create_model_fn: a function that creates and returns a new
tensorflow model to load the saved weights.
It should be used with ModelType.SAVED_WEIGHTS.
inference_fn: inference function to use during RunInference.
Defaults to default_numpy_inference_fn.
Expand All @@ -110,9 +121,16 @@ def __init__(
self._model_uri = model_uri
self._model_type = model_type
self._inference_fn = inference_fn
self._create_model_fn = create_model_fn

def load_model(self) -> tf.Module:
"""Loads and initializes a Tensorflow model for processing."""
if self._model_type == ModelType.SAVED_WEIGHTS:
if not self._create_model_fn:
raise ValueError(
"Callable create_model_fn must be passed"
"with ModelType.SAVED_WEIGHTS")
return _load_model_from_weights(self._create_model_fn, self._model_uri)
return _load_model(self._model_uri, self._model_type)

def update_model_path(self, model_path: Optional[str] = None):
Expand Down Expand Up @@ -169,6 +187,7 @@ def __init__(
self,
model_uri: str,
model_type: ModelType = ModelType.SAVED_MODEL,
create_model_fn: Optional[Callable] = None,
*,
inference_fn: TensorInferenceFn = default_tensor_inference_fn):
"""Implementation of the ModelHandler interface for Tensorflow.
Expand All @@ -183,6 +202,9 @@ def __init__(
model_uri (str): path to the trained model.
model_type: type of model to be loaded.
Defaults to SAVED_MODEL.
create_model_fn: a function that creates and returns a new
tensorflow model to load the saved weights.
It should be used with ModelType.SAVED_WEIGHTS.
inference_fn: inference function to use during RunInference.
Defaults to default_numpy_inference_fn.
Expand All @@ -192,9 +214,16 @@ def __init__(
self._model_uri = model_uri
self._model_type = model_type
self._inference_fn = inference_fn
self._create_model_fn = create_model_fn

def load_model(self) -> tf.Module:
"""Loads and initializes a tensorflow model for processing."""
if self._model_type == ModelType.SAVED_WEIGHTS:
if not self._create_model_fn:
raise ValueError(
"Callable create_model_fn must be passed"
"with ModelType.SAVED_WEIGHTS")
return _load_model_from_weights(self._create_model_fn, self._model_uri)
return _load_model(self._model_uri, self._model_type)

def update_model_path(self, model_path: Optional[str] = None):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import tensorflow as tf
from apache_beam.examples.inference import tensorflow_imagenet_segmentation
from apache_beam.examples.inference import tensorflow_mnist_classification
from apache_beam.examples.inference import tensorflow_mnist_with_weights
except ImportError as e:
tf = None

Expand Down Expand Up @@ -108,6 +109,36 @@ def test_tf_imagenet_image_segmentation(self):
for true_label, predicted_label in zip(expected_outputs, predicted_outputs):
self.assertEqual(true_label, predicted_label)

def test_tf_mnist_with_weights_classification(self):
test_pipeline = TestPipeline(is_integration_test=True)
input_file = 'gs://apache-beam-ml/testing/inputs/it_mnist_data.csv'
output_file_dir = 'gs://apache-beam-ml/testing/outputs'
output_file = '/'.join([output_file_dir, str(uuid.uuid4()), 'result.txt'])
model_path = 'gs://apache-beam-ml/models/tensorflow/mnist'
extra_opts = {
'input': input_file,
'output': output_file,
'model_path': model_path,
}
tensorflow_mnist_with_weights.run(
test_pipeline.get_full_options_as_args(**extra_opts),
save_main_session=False)
self.assertEqual(FileSystems().exists(output_file), True)

expected_output_filepath = 'gs://apache-beam-ml/testing/expected_outputs/test_sklearn_mnist_classification_actuals.txt' # pylint: disable=line-too-long
expected_outputs = process_outputs(expected_output_filepath)
predicted_outputs = process_outputs(output_file)
self.assertEqual(len(expected_outputs), len(predicted_outputs))

predictions_dict = {}
for i in range(len(predicted_outputs)):
true_label, prediction = predicted_outputs[i].split(',')
predictions_dict[true_label] = prediction

for i in range(len(expected_outputs)):
true_label, expected_prediction = expected_outputs[i].split(',')
self.assertEqual(predictions_dict[true_label], expected_prediction)


if __name__ == '__main__':
logging.getLogger().setLevel(logging.DEBUG)
Expand Down
2 changes: 1 addition & 1 deletion sdks/python/tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ commands =
/bin/sh -c "pip freeze | grep -E onnx"
# Run all ONNX unit tests
pytest -o junit_suite_name={envname} --junitxml=pytest_{envname}.xml -n 6 -m uses_onnx {posargs}

[testenv:py{37,38,39,310}-tensorflow-{29,210,211}]
deps =
-r build-requirements.txt
Expand Down

0 comments on commit 33750c1

Please sign in to comment.