diff --git a/python/orca/src/bigdl/orca/learn/tf/estimator.py b/python/orca/src/bigdl/orca/learn/tf/estimator.py
index ecc5aed47bc..29c06c647cf 100644
--- a/python/orca/src/bigdl/orca/learn/tf/estimator.py
+++ b/python/orca/src/bigdl/orca/learn/tf/estimator.py
@@ -13,8 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
 
+from pyspark.sql import DataFrame
 from bigdl.optim.optimizer import MaxEpoch
+from zoo.orca.data.tf.data import Dataset, TFDataDataset2
 from zoo.orca.learn.tf.utils import *
 from zoo.orca.learn.utils import find_latest_checkpoint
 
@@ -193,6 +195,66 @@ def save_keras_model(self, path, overwrite=True):
         raise NotImplementedError()
 
 
+def is_tf_data_dataset(data):
+    is_dataset = isinstance(data, tf.data.Dataset)
+    is_dataset_v2 = isinstance(data, tf.python.data.ops.dataset_ops.DatasetV2)
+    return is_dataset or is_dataset_v2
+
+
+def to_dataset(data, batch_size, batch_per_thread, validation_data,
+               feature_cols, labels_cols, hard_code_batch_size,
+               sequential_order, shuffle, auto_shard_files):
+    # todo wrap argument into kwargs
+    if validation_data:
+        if isinstance(data, SparkXShards):
+            assert isinstance(validation_data, SparkXShards), \
+                "train data and validation data should be both SparkXShards"
+        if isinstance(data, Dataset):
+            assert isinstance(validation_data, Dataset), \
+                "train data and validation data should be both orca.data.tf.Dataset"
+        if isinstance(data, DataFrame):
+            assert isinstance(validation_data, DataFrame), \
+                "train data and validation data should be both Spark DataFrame"
+        if isinstance(data, tf.data.Dataset):
+            assert isinstance(validation_data, tf.data.Dataset), \
+                "train data and validation data should be both tf.data.Dataset"
+
+    if isinstance(data, SparkXShards):
+        dataset = xshards_to_tf_dataset(data,
+                                        batch_size,
+                                        batch_per_thread,
+                                        validation_data,
+                                        hard_code_batch_size=hard_code_batch_size,
+                                        sequential_order=sequential_order,
+                                        shuffle=shuffle)
+    elif isinstance(data, Dataset):
+        dataset = TFDataDataset2(data, batch_size=batch_size,
+                                 batch_per_thread=batch_per_thread,
+                                 validation_dataset=validation_data)
+    elif isinstance(data, DataFrame):
+        dataset = TFDataset.from_dataframe(data, feature_cols, labels_cols,
+                                           batch_size,
+                                           batch_per_thread,
+                                           hard_code_batch_size,
+                                           validation_data,
+                                           sequential_order,
+                                           shuffle
+                                           )
+    elif is_tf_data_dataset(data):
+        dataset = TFDataset.from_tf_data_dataset(data,
+                                                 batch_size,
+                                                 batch_per_thread,
+                                                 hard_code_batch_size,
+                                                 validation_data,
+                                                 sequential_order,
+                                                 shuffle, auto_shard_files=auto_shard_files)
+    else:
+        raise ValueError("data must be SparkXShards or orca.data.tf.Dataset or "
+                         "Spark DataFrame or tf.data.Dataset")
+
+    return dataset
+
+
 class TFOptimizerWrapper(Estimator):
     def __init__(self, *, inputs, outputs, labels, loss,
                  optimizer, clip_norm, clip_value,
@@ -493,7 +555,7 @@ def fit(self, data,
                              feature_cols=feature_cols, labels_cols=labels_cols,
                              hard_code_batch_size=hard_code_batch_size,
                              sequential_order=False, shuffle=True,
-                             auto_shard_files=auto_shard_files,)
+                             auto_shard_files=auto_shard_files)
 
         self.tf_optimizer = TFOptimizer.from_keras(self.model.model, dataset,
                                                    model_dir=self.model.model_dir,
diff --git a/python/orca/src/bigdl/orca/learn/tf/utils.py b/python/orca/src/bigdl/orca/learn/tf/utils.py
index 031b782ef78..83a2af420b6 100644
--- a/python/orca/src/bigdl/orca/learn/tf/utils.py
+++ b/python/orca/src/bigdl/orca/learn/tf/utils.py
@@ -20,10 +20,8 @@
 import shutil
 import tensorflow as tf
 import numpy as np
-from pyspark.sql.dataframe import DataFrame
 
 from zoo.orca.data import SparkXShards
-from zoo.orca.data.tf.data import Dataset, TFDataDataset2
 from zoo.tfpark.tf_dataset import TFDataset
 from zoo.orca.data.utils import get_spec, flatten_xy
 from zoo.common.utils import put_local_file_to_remote, get_remote_file_to_local, get_file_list,\
@@ -64,66 +62,6 @@ def xshards_to_tf_dataset(data_shard,
     return dataset
 
 
-def is_tf_data_dataset(data):
-    is_dataset = isinstance(data, tf.data.Dataset)
-    is_dataset_v2 = isinstance(data, tf.python.data.ops.dataset_ops.DatasetV2)
-    return is_dataset or is_dataset_v2
-
-
-def to_dataset(data, batch_size, batch_per_thread, validation_data,
-               feature_cols, labels_cols, hard_code_batch_size,
-               sequential_order, shuffle, auto_shard_files):
-    # todo wrap argument into kwargs
-    if validation_data:
-        if isinstance(data, SparkXShards):
-            assert isinstance(validation_data, SparkXShards), \
-                "train data and validation data should be both SparkXShards"
-        if isinstance(data, Dataset):
-            assert isinstance(validation_data, Dataset), \
-                "train data and validation data should be both orca.data.tf.Dataset"
-        if isinstance(data, DataFrame):
-            assert isinstance(validation_data, DataFrame), \
-                "train data and validation data should be both Spark DataFrame"
-        if isinstance(data, tf.data.Dataset):
-            assert isinstance(validation_data, tf.data.Dataset), \
-                "train data and validation data should be both tf.data.Dataset"
-
-    if isinstance(data, SparkXShards):
-        dataset = xshards_to_tf_dataset(data,
-                                        batch_size,
-                                        batch_per_thread,
-                                        validation_data,
-                                        hard_code_batch_size=hard_code_batch_size,
-                                        sequential_order=sequential_order,
-                                        shuffle=shuffle)
-    elif isinstance(data, Dataset):
-        dataset = TFDataDataset2(data, batch_size=batch_size,
-                                 batch_per_thread=batch_per_thread,
-                                 validation_dataset=validation_data)
-    elif isinstance(data, DataFrame):
-        dataset = TFDataset.from_dataframe(data, feature_cols, labels_cols,
-                                           batch_size,
-                                           batch_per_thread,
-                                           hard_code_batch_size,
-                                           validation_data,
-                                           sequential_order,
-                                           shuffle
-                                           )
-    elif is_tf_data_dataset(data):
-        dataset = TFDataset.from_tf_data_dataset(data,
-                                                 batch_size,
-                                                 batch_per_thread,
-                                                 hard_code_batch_size,
-                                                 validation_data,
-                                                 sequential_order,
-                                                 shuffle, auto_shard_files=auto_shard_files)
-    else:
-        raise ValueError("data must be SparkXShards or orca.data.tf.Dataset or "
-                         "Spark DataFrame or tf.data.Dataset")
-
-    return dataset
-
-
 def convert_predict_to_dataframe(df, prediction_rdd):
     from pyspark.sql import Row
     from pyspark.sql.types import StructType, StructField, FloatType, ArrayType
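
For reviewers, the relocated to_dataset helper is a pure dispatch function: it first asserts that data and validation_data are the same kind of input, then builds a TFDataset through the branch matching the input type. A minimal usage sketch follows; the import path, the SparkSession, the toy DataFrames, and the column names are assumptions made for illustration, and only the keyword arguments mirror the signature moved in this diff.

    # Hypothetical sketch of calling the relocated helper; everything below
    # except the keyword argument names is an assumption for illustration.
    from pyspark.sql import SparkSession
    from zoo.orca.learn.tf.estimator import to_dataset  # assumed import path

    spark = SparkSession.builder.getOrCreate()
    train_df = spark.createDataFrame([(0.1, 0), (0.9, 1)], ["feature", "label"])  # hypothetical train data
    val_df = spark.createDataFrame([(0.5, 1)], ["feature", "label"])               # hypothetical validation data

    # Both inputs are Spark DataFrames, so the type check passes and the
    # DataFrame branch (TFDataset.from_dataframe) is taken.
    dataset = to_dataset(train_df,
                         batch_size=32,
                         batch_per_thread=-1,
                         validation_data=val_df,
                         feature_cols=["feature"],
                         labels_cols=["label"],
                         hard_code_batch_size=False,
                         sequential_order=False,
                         shuffle=True,
                         auto_shard_files=False)

The same call with a SparkXShards, an orca.data.tf.Dataset, or a tf.data.Dataset input would take the corresponding branch; any other type raises the ValueError at the end of the helper.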