diff --git a/python/orca/src/bigdl/orca/data/image/parquet_dataset.py b/python/orca/src/bigdl/orca/data/image/parquet_dataset.py
index 0b1ba5b8b68..bebd7813aa4 100644
--- a/python/orca/src/bigdl/orca/data/image/parquet_dataset.py
+++ b/python/orca/src/bigdl/orca/data/image/parquet_dataset.py
@@ -21,12 +21,14 @@
 from zoo.orca.data import SparkXShards
 from zoo.orca.data.file import open_text, write_text
 from zoo.orca.data.image.utils import chunks, dict_to_row, row_to_dict, encode_schema, \
-    decode_schema, SchemaField, FeatureType, DType, ndarray_dtype_to_dtype
+    decode_schema, SchemaField, FeatureType, DType, ndarray_dtype_to_dtype, \
+    decode_feature_type_ndarray, pa_fs
 from zoo.orca.data.image.voc_dataset import VOCDatasets
 from bigdl.util.common import get_node_and_core_number
 import os
 import numpy as np
 import random
+import pyarrow.parquet as pq
 import io
@@ -265,3 +267,42 @@ def write_parquet(format, output_path, *args, **kwargs):
     func, required_args = format_to_function[format]
     _check_arguments(format, kwargs, required_args)
     func(output_path=output_path, *args, **kwargs)
+
+
+def read_as_tfdataset(path, output_types, output_shapes=None, *args, **kwargs):
+    """
+    Read an Orca ParquetDataset as a tf.data.Dataset.
+    :param path: the path to the ParquetDataset written by ParquetDataset.write.
+    :return: a tf.data.Dataset, each element of which is a dict keyed by the schema fields.
+    """
+    path, _ = pa_fs(path)
+    import tensorflow as tf
+
+    schema_path = os.path.join(path, "_orca_metadata")
+    j_str = open_text(schema_path)[0]
+    schema = decode_schema(j_str)
+
+    def generator():
+        for root, dirs, files in os.walk(path):
+            for name in dirs:
+                if name.startswith("chunk="):
+                    chunk_path = os.path.join(path, name)
+                    pq_table = pq.read_table(chunk_path)
+                    df = decode_feature_type_ndarray(pq_table.to_pandas(), schema)
+                    for record in df.to_dict("records"):
+                        yield record
+
+    dataset = tf.data.Dataset.from_generator(generator, output_types=output_types,
+                                             output_shapes=output_shapes)
+    return dataset
+
+
+def read_parquet(format, input_path, *args, **kwargs):
+    supported_format = {"tf_dataset"}
+    if format not in supported_format:
+        raise ValueError(format + " is not supported, should be 'tf_dataset'.")
+
+    format_to_function = {"tf_dataset": (read_as_tfdataset, ["output_types"])}
+    func, required_args = format_to_function[format]
+    _check_arguments(format, kwargs, required_args)
+    return func(path=input_path, *args, **kwargs)
diff --git a/python/orca/src/bigdl/orca/data/image/utils.py b/python/orca/src/bigdl/orca/data/image/utils.py
index 9ba40c12dbc..67dfc5bcae8 100644
--- a/python/orca/src/bigdl/orca/data/image/utils.py
+++ b/python/orca/src/bigdl/orca/data/image/utils.py
@@ -15,9 +15,11 @@
 #
 
 import copy
+import os
 from collections import namedtuple
 from io import BytesIO
 import numpy as np
+import pyarrow as pa
 from itertools import chain, islice
 
 from enum import Enum
@@ -146,7 +148,27 @@ def dict_to_row(schema, row_dict):
     return pyspark.Row(**row)
 
 
+def decode_feature_type_ndarray(df, schema):
+    for n, field in schema.items():
+        if field.feature_type == FeatureType.NDARRAY:
+            df[n] = df[n].map(lambda k: decode_ndarray(k))
+    return df
+
+
 def chunks(iterable, size=10):
     iterator = iter(iterable)
     for first in iterator:
         yield chain([first], islice(iterator, size - 1))
+
+
+def pa_fs(path):
+    if path.startswith("hdfs"):  # hdfs://url:port/file_path
+        fs = pa.hdfs.connect()
+        path = path[len("hdfs://"):]
+        return path, fs
+    elif path.startswith("s3"):
+        raise ValueError("aws s3 is not supported for now")
+    else:  # Local path
+        if path.startswith("file://"):
+            path = path[len("file://"):]
+        return path, pa.LocalFileSystem()
diff --git a/python/orca/test/bigdl/orca/data/test_read_parquet_images.py b/python/orca/test/bigdl/orca/data/test_read_parquet_images.py
new file mode 100644
index 00000000000..45f07380fa4
--- /dev/null
+++ b/python/orca/test/bigdl/orca/data/test_read_parquet_images.py
@@ -0,0 +1,115 @@
+#
+# Copyright 2018 Analytics Zoo Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tempfile
+import shutil
+
+import pytest
+from unittest import TestCase
+import os
+from zoo.orca.data.image.parquet_dataset import ParquetDataset, read_parquet
+from zoo.orca.data.image.utils import DType, FeatureType, SchemaField
+import tensorflow as tf
+from zoo.ray import RayContext
+
+resource_path = os.path.join(os.path.split(__file__)[0], "../../resources")
+WIDTH, HEIGHT, NUM_CHANNELS = 224, 224, 3
+
+
+def images_generator():
+    dataset_path = os.path.join(resource_path, "cat_dog")
+    for root, dirs, files in os.walk(os.path.join(dataset_path, "cats")):
+        for name in files:
+            image_path = os.path.join(root, name)
+            yield {"image": image_path, "label": 1, "id": image_path}
+
+    for root, dirs, files in os.walk(os.path.join(dataset_path, "dogs")):
+        for name in files:
+            image_path = os.path.join(root, name)
+            yield {"image": image_path, "label": 0, "id": image_path}
+
+
+images_schema = {
+    "image": SchemaField(feature_type=FeatureType.IMAGE, dtype=DType.FLOAT32, shape=()),
+    "label": SchemaField(feature_type=FeatureType.SCALAR, dtype=DType.FLOAT32, shape=()),
+    "id": SchemaField(feature_type=FeatureType.SCALAR, dtype=DType.STRING, shape=())
+}
+
+
+def parse_data_train(image, label):
+    image = tf.io.decode_jpeg(image, NUM_CHANNELS)
+    image = tf.image.resize(image, size=(WIDTH, HEIGHT))
+    image = tf.reshape(image, [WIDTH, HEIGHT, NUM_CHANNELS])
+    return image, label
+
+
+def model_creator(config):
+    import tensorflow as tf
+    model = tf.keras.Sequential([
+        tf.keras.layers.Flatten(input_shape=(224, 224, 3)),
+        tf.keras.layers.Dense(64, activation='relu'),
+        tf.keras.layers.Dense(2)
+    ])
+    model.compile(optimizer='adam',
+                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+                  metrics=['accuracy'])
+    return model
+
+
+class TestReadParquet(TestCase):
+    def test_read_parquet_images_tf_dataset(self):
+        temp_dir = tempfile.mkdtemp()
+
+        try:
+            ParquetDataset.write("file://" + temp_dir, images_generator(), images_schema)
+            path = "file://" + temp_dir
+            output_types = {"id": tf.string, "image": tf.string, "label": tf.float32}
+            dataset = read_parquet("tf_dataset", input_path=path, output_types=output_types)
+            for dt in dataset.take(1):
+                print(dt.keys())
+
+        finally:
+            shutil.rmtree(temp_dir)
+
+    def test_parquet_images_training(self):
+        from zoo.orca.learn.tf2 import Estimator
+        temp_dir = tempfile.mkdtemp()
+        try:
+            ParquetDataset.write("file://" + temp_dir, images_generator(), images_schema)
+            path = "file://" + temp_dir
+            output_types = {"id": tf.string, "image": tf.string, "label": tf.float32}
+            output_shapes = {"id": (), "image": (), "label": ()}
+
+            def data_creator(config, batch_size):
+                dataset = read_parquet("tf_dataset", input_path=path,
+                                       output_types=output_types, output_shapes=output_shapes)
+                dataset = dataset.shuffle(10)
+                dataset = dataset.map(lambda data_dict: (data_dict["image"], data_dict["label"]))
+                dataset = dataset.map(parse_data_train)
+                dataset = dataset.batch(batch_size)
+                return dataset
+
+            ray_ctx = RayContext.get()
+            trainer = Estimator.from_keras(model_creator=model_creator)
+            trainer.fit(data=data_creator,
+                        epochs=1,
+                        batch_size=2)
+        finally:
+            shutil.rmtree(temp_dir)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])