support tfds on hdfs (intel-analytics#2921)
* support tfds on hdfs

* add documents

* fix keras

* remove duplicate

* fix style

* fix style and tests
yangw1234 authored and Wang, Yang committed Sep 26, 2021
1 parent 14ee639 commit a9c95e8
Showing 2 changed files with 63 additions and 63 deletions.
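
For context, the sketch below shows the kind of workflow this change targets: a tf.data.Dataset built from TFRecord files on HDFS and handed to the Orca TF Estimator, which routes it through the relocated to_dataset helper. The HDFS path, record schema, and model are hypothetical, it assumes an already-initialized SparkContext/OrcaContext, and it assumes fit exposes the auto_shard_files flag that the diff below passes through to to_dataset; treat it as a sketch, not the commit's own example.

import tensorflow as tf
from zoo.orca.learn.tf.estimator import Estimator

# Hypothetical TFRecord schema: a 10-dim float feature vector and an integer label.
def parse_example(serialized):
    parsed = tf.io.parse_single_example(
        serialized,
        {"x": tf.io.FixedLenFeature([10], tf.float32),
         "y": tf.io.FixedLenFeature([], tf.int64)})
    return parsed["x"], parsed["y"]

# Hypothetical HDFS location of the TFRecord shards.
files = tf.data.Dataset.list_files("hdfs://namenode:9000/data/train-*.tfrecord")
dataset = files.interleave(tf.data.TFRecordDataset, cycle_length=4).map(parse_example)

# A toy Keras model; any compiled tf.keras model can be wrapped by Estimator.from_keras.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu", input_shape=(10,)),
    tf.keras.layers.Dense(2, activation="softmax")])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

est = Estimator.from_keras(keras_model=model)
# fit dispatches through to_dataset; auto_shard_files shards the input files
# across workers instead of sharding individual records.
est.fit(data=dataset, batch_size=32, epochs=1, auto_shard_files=True)
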
64 changes: 63 additions & 1 deletion python/orca/src/bigdl/orca/learn/tf/estimator.py
@@ -13,8 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from pyspark.sql import DataFrame

 from bigdl.optim.optimizer import MaxEpoch
+from zoo.orca.data.tf.data import Dataset, TFDataDataset2

 from zoo.orca.learn.tf.utils import *
 from zoo.orca.learn.utils import find_latest_checkpoint
@@ -193,6 +195,66 @@ def save_keras_model(self, path, overwrite=True):
         raise NotImplementedError()


+def is_tf_data_dataset(data):
+    is_dataset = isinstance(data, tf.data.Dataset)
+    is_dataset_v2 = isinstance(data, tf.python.data.ops.dataset_ops.DatasetV2)
+    return is_dataset or is_dataset_v2
+
+
+def to_dataset(data, batch_size, batch_per_thread, validation_data,
+               feature_cols, labels_cols, hard_code_batch_size,
+               sequential_order, shuffle, auto_shard_files):
+    # todo wrap argument into kwargs
+    if validation_data:
+        if isinstance(data, SparkXShards):
+            assert isinstance(validation_data, SparkXShards), \
+                "train data and validation data should be both SparkXShards"
+        if isinstance(data, Dataset):
+            assert isinstance(validation_data, Dataset), \
+                "train data and validation data should be both orca.data.tf.Dataset"
+        if isinstance(data, DataFrame):
+            assert isinstance(validation_data, DataFrame), \
+                "train data and validation data should be both Spark DataFrame"
+        if isinstance(data, tf.data.Dataset):
+            assert isinstance(validation_data, tf.data.Dataset), \
+                "train data and validation data should be both tf.data.Dataset"
+
+    if isinstance(data, SparkXShards):
+        dataset = xshards_to_tf_dataset(data,
+                                        batch_size,
+                                        batch_per_thread,
+                                        validation_data,
+                                        hard_code_batch_size=hard_code_batch_size,
+                                        sequential_order=sequential_order,
+                                        shuffle=shuffle)
+    elif isinstance(data, Dataset):
+        dataset = TFDataDataset2(data, batch_size=batch_size,
+                                 batch_per_thread=batch_per_thread,
+                                 validation_dataset=validation_data)
+    elif isinstance(data, DataFrame):
+        dataset = TFDataset.from_dataframe(data, feature_cols, labels_cols,
+                                           batch_size,
+                                           batch_per_thread,
+                                           hard_code_batch_size,
+                                           validation_data,
+                                           sequential_order,
+                                           shuffle
+                                           )
+    elif is_tf_data_dataset(data):
+        dataset = TFDataset.from_tf_data_dataset(data,
+                                                 batch_size,
+                                                 batch_per_thread,
+                                                 hard_code_batch_size,
+                                                 validation_data,
+                                                 sequential_order,
+                                                 shuffle, auto_shard_files=auto_shard_files)
+    else:
+        raise ValueError("data must be SparkXShards or orca.data.tf.Dataset or "
+                         "Spark DataFrame or tf.data.Dataset")
+
+    return dataset
+
+
 class TFOptimizerWrapper(Estimator):
     def __init__(self, *, inputs, outputs, labels, loss,
                  optimizer, clip_norm, clip_value,
@@ -493,7 +555,7 @@ def fit(self, data,
                              feature_cols=feature_cols, labels_cols=labels_cols,
                              hard_code_batch_size=hard_code_batch_size,
                              sequential_order=False, shuffle=True,
-                             auto_shard_files=auto_shard_files,)
+                             auto_shard_files=auto_shard_files)

         self.tf_optimizer = TFOptimizer.from_keras(self.model.model, dataset,
                                                    model_dir=self.model.model_dir,
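
Along the same lines, a minimal sketch of the Spark DataFrame branch that to_dataset dispatches to via TFDataset.from_dataframe. The column names and toy rows are hypothetical, and est is the Estimator from the sketch above; whether an array-typed feature column works end to end depends on the from_dataframe support in this Analytics Zoo version.

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# Two toy rows: an array feature column and an integer label column.
df = spark.createDataFrame(
    [([0.0, 1.0], 0), ([1.0, 0.0], 1)],
    ["features", "label"])

# feature_cols/labels_cols name the DataFrame columns that to_dataset passes
# straight into TFDataset.from_dataframe.
est.fit(data=df, batch_size=4, epochs=1,
        feature_cols=["features"], labels_cols=["label"])
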
62 changes: 0 additions & 62 deletions python/orca/src/bigdl/orca/learn/tf/utils.py
@@ -20,10 +20,8 @@
 import shutil
 import tensorflow as tf
 import numpy as np
-from pyspark.sql.dataframe import DataFrame

 from zoo.orca.data import SparkXShards
-from zoo.orca.data.tf.data import Dataset, TFDataDataset2
 from zoo.tfpark.tf_dataset import TFDataset
 from zoo.orca.data.utils import get_spec, flatten_xy
 from zoo.common.utils import put_local_file_to_remote, get_remote_file_to_local, get_file_list,\
@@ -64,66 +62,6 @@ def xshards_to_tf_dataset(data_shard,
     return dataset


-def is_tf_data_dataset(data):
-    is_dataset = isinstance(data, tf.data.Dataset)
-    is_dataset_v2 = isinstance(data, tf.python.data.ops.dataset_ops.DatasetV2)
-    return is_dataset or is_dataset_v2
-
-
-def to_dataset(data, batch_size, batch_per_thread, validation_data,
-               feature_cols, labels_cols, hard_code_batch_size,
-               sequential_order, shuffle, auto_shard_files):
-    # todo wrap argument into kwargs
-    if validation_data:
-        if isinstance(data, SparkXShards):
-            assert isinstance(validation_data, SparkXShards), \
-                "train data and validation data should be both SparkXShards"
-        if isinstance(data, Dataset):
-            assert isinstance(validation_data, Dataset), \
-                "train data and validation data should be both orca.data.tf.Dataset"
-        if isinstance(data, DataFrame):
-            assert isinstance(validation_data, DataFrame), \
-                "train data and validation data should be both Spark DataFrame"
-        if isinstance(data, tf.data.Dataset):
-            assert isinstance(validation_data, tf.data.Dataset), \
-                "train data and validation data should be both tf.data.Dataset"
-
-    if isinstance(data, SparkXShards):
-        dataset = xshards_to_tf_dataset(data,
-                                        batch_size,
-                                        batch_per_thread,
-                                        validation_data,
-                                        hard_code_batch_size=hard_code_batch_size,
-                                        sequential_order=sequential_order,
-                                        shuffle=shuffle)
-    elif isinstance(data, Dataset):
-        dataset = TFDataDataset2(data, batch_size=batch_size,
-                                 batch_per_thread=batch_per_thread,
-                                 validation_dataset=validation_data)
-    elif isinstance(data, DataFrame):
-        dataset = TFDataset.from_dataframe(data, feature_cols, labels_cols,
-                                           batch_size,
-                                           batch_per_thread,
-                                           hard_code_batch_size,
-                                           validation_data,
-                                           sequential_order,
-                                           shuffle
-                                           )
-    elif is_tf_data_dataset(data):
-        dataset = TFDataset.from_tf_data_dataset(data,
-                                                 batch_size,
-                                                 batch_per_thread,
-                                                 hard_code_batch_size,
-                                                 validation_data,
-                                                 sequential_order,
-                                                 shuffle, auto_shard_files=auto_shard_files)
-    else:
-        raise ValueError("data must be SparkXShards or orca.data.tf.Dataset or "
-                         "Spark DataFrame or tf.data.Dataset")
-
-    return dataset
-
-
 def convert_predict_to_dataframe(df, prediction_rdd):
     from pyspark.sql import Row
     from pyspark.sql.types import StructType, StructField, FloatType, ArrayType
