From 0851ec1d3090c9bbe23789c66d3ef60592f0b55e Mon Sep 17 00:00:00 2001
From: Yu Shan
Date: Mon, 22 Feb 2021 10:27:58 +0800
Subject: [PATCH] add shard size in orca context and fix style

---
 pyzoo/test/zoo/orca/learn/test_utils.py | 14 +++++++++-----
 pyzoo/zoo/orca/common.py                | 18 ++++++++++++++++++
 pyzoo/zoo/orca/learn/utils.py           | 15 ++++++++-------
 3 files changed, 35 insertions(+), 12 deletions(-)

diff --git a/pyzoo/test/zoo/orca/learn/test_utils.py b/pyzoo/test/zoo/orca/learn/test_utils.py
index e3072b99273..36ec4067d8b 100644
--- a/pyzoo/test/zoo/orca/learn/test_utils.py
+++ b/pyzoo/test/zoo/orca/learn/test_utils.py
@@ -187,8 +187,10 @@ def test_array2dict(self):
         from zoo.orca.learn.utils import arrays2dict
         record_num = 100
         shard_size = 30
-        data = [(np.float32(np.random.randn(1, 50)), np.float32([np.random.randint(0, 2,)])) for i in range(record_num)]
-        result = arrays2dict(data, feature_cols=["feature"], label_cols=["label"], shard_size=shard_size)
+        data = [(np.float32(np.random.randn(1, 50)), np.float32([np.random.randint(0, 2,)]))
+                for i in range(record_num)]
+        result = arrays2dict(data, feature_cols=["feature"], label_cols=["label"],
+                             shard_size=shard_size)
         for i, d in enumerate(result):
             if (record_num % shard_size == 0) or (i != record_num // shard_size):
                 assert d['x'].shape[0] == shard_size
@@ -200,7 +202,8 @@ def test_array2dict_shard_size_none(self):
         from zoo.orca.learn.utils import arrays2dict
         record_num = 100
-        data = [(np.float32(np.random.randn(1, 50)), np.float32([np.random.randint(0, 2,)])) for i in range(record_num)]
+        data = [(np.float32(np.random.randn(1, 50)), np.float32([np.random.randint(0, 2,)]))
+                for i in range(record_num)]
         result = arrays2dict(data, feature_cols=["feature"], label_cols=["label"], shard_size=None)
         for i, d in enumerate(result):
             assert d['x'].shape[0] == record_num
 
@@ -217,8 +220,9 @@ def test_dataframe_to_xshards(self):
         num_shards = shards.rdd.count()
         assert num_shards == num_partitions
 
-        shard_size = 1
-        shards = _dataframe_to_xshards(df, feature_cols=["feature"], label_cols=["label"], shard_size=shard_size)
+        from zoo.orca import OrcaContext
+        OrcaContext.shard_size = 1
+        shards = _dataframe_to_xshards(df, feature_cols=["feature"], label_cols=["label"])
         num_shards = shards.rdd.count()
         assert num_shards == df.rdd.count()
 
diff --git a/pyzoo/zoo/orca/common.py b/pyzoo/zoo/orca/common.py
index f5d1f457eb0..10af46b9233 100644
--- a/pyzoo/zoo/orca/common.py
+++ b/pyzoo/zoo/orca/common.py
@@ -24,6 +24,7 @@ class OrcaContextMeta(type):
     __eager_mode = True
     _serialize_data_creator = False
     _train_data_store = "DRAM"
+    _shard_size = None
 
     @property
     def log_output(cls):
@@ -101,6 +102,23 @@ def train_data_store(cls, value):
             "train_data_store must be either DRAM or PMEM or DIRECT or DISK_n"
         cls._train_data_store = value
 
+    @property
+    def shard_size(cls):
+        """
+        The number of Rows in a Spark DataFrame to transform into one shard of SparkXShards.
+        We convert Spark DataFrame input to SparkXShards internally in fit/predict/evaluate
+        of PyTorchRayEstimator and TensorFlow2Estimator. This parameter may affect the
+        performance of transferring a SparkXShards to a RayXShards.
+        """
+        return cls._shard_size
+
+    @shard_size.setter
+    def shard_size(cls, value):
+        if value is not None:
+            assert isinstance(value, int) and value > 0, \
+                "shard size should be either None or a positive integer."
+        cls._shard_size = value
+
 
 class OrcaContext(metaclass=OrcaContextMeta):
 
     @staticmethod
diff --git a/pyzoo/zoo/orca/learn/utils.py b/pyzoo/zoo/orca/learn/utils.py
index d3d033032c2..67437493466 100644
--- a/pyzoo/zoo/orca/learn/utils.py
+++ b/pyzoo/zoo/orca/learn/utils.py
@@ -228,8 +228,10 @@ def generate_output(feature_lists, label_lists):
         yield generate_output(feature_lists, label_lists)
 
 
-def _dataframe_to_xshards(data, feature_cols, label_cols=None, shard_size=None):
+def _dataframe_to_xshards(data, feature_cols, label_cols=None):
+    from zoo.orca import OrcaContext
     schema = data.schema
+    shard_size = OrcaContext.shard_size
     numpy_rdd = data.rdd.map(lambda row: convert_row_to_numpy(row,
                                                               schema,
                                                               feature_cols,
@@ -241,7 +243,7 @@ def _dataframe_to_xshards(data, feature_cols, label_cols=None, shard_size=None):
     return SparkXShards(shard_rdd)
 
 
-def dataframe_to_xshards(data, validation_data, feature_cols, label_cols, mode="fit", shard_size=None):
+def dataframe_to_xshards(data, validation_data, feature_cols, label_cols, mode="fit"):
     from pyspark.sql import DataFrame
     valid_mode = {"fit", "evaluate", "predict"}
     assert mode in valid_mode, f"invalid mode {mode} " \
@@ -255,21 +257,20 @@ def dataframe_to_xshards(data, validation_data, feature_cols, label_cols, mode="
         assert label_cols is not None, \
             "label_cols must be provided if data is a spark dataframe"
 
-    data = _dataframe_to_xshards(data, feature_cols, label_cols, shard_size=shard_size)
+    data = _dataframe_to_xshards(data, feature_cols, label_cols)
     if validation_data is not None:
-        validation_data = _dataframe_to_xshards(validation_data, feature_cols, label_cols, shard_size=shard_size)
+        validation_data = _dataframe_to_xshards(validation_data, feature_cols, label_cols)
 
     return data, validation_data
 
 
-def maybe_dataframe_to_xshards(data, validation_data, feature_cols, label_cols, mode="fit", shard_size=None):
+def maybe_dataframe_to_xshards(data, validation_data, feature_cols, label_cols, mode="fit"):
     from pyspark.sql import DataFrame
     if isinstance(data, DataFrame):
         data, validation_data = dataframe_to_xshards(data, validation_data,
                                                      feature_cols=feature_cols,
                                                      label_cols=label_cols,
-                                                     mode=mode,
-                                                     shard_size=shard_size)
+                                                     mode=mode)
     return data, validation_data
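Usage sketch (editor's note, not part of the patch): with this change, callers set the shard size once on OrcaContext instead of threading a shard_size argument through the conversion helpers. The snippet below is a minimal sketch assuming a local Analytics Zoo setup; the toy DataFrame, its column names, and the SparkSession construction are illustrative assumptions, while OrcaContext.shard_size and dataframe_to_xshards come from this patch.

    from pyspark.sql import SparkSession
    from zoo.orca import init_orca_context, OrcaContext
    from zoo.orca.learn.utils import dataframe_to_xshards

    # init_orca_context returns a SparkContext in local mode.
    sc = init_orca_context(cluster_mode="local", cores=4)
    spark = SparkSession(sc)

    # Toy DataFrame with one feature column and one label column (assumed names).
    df = spark.createDataFrame([(float(i), i % 2) for i in range(100)],
                               ["feature", "label"])

    # Each shard of the resulting SparkXShards holds at most 30 Rows. The setter
    # asserts the value is either None (the default, which keeps one shard per
    # partition) or a positive integer.
    OrcaContext.shard_size = 30

    shards, _ = dataframe_to_xshards(df, validation_data=None,
                                     feature_cols=["feature"], label_cols=["label"],
                                     mode="fit")
    print(shards.rdd.count())  # grows once shard_size drops below rows per partition

The same setting is picked up internally by fit/predict/evaluate of PyTorchRayEstimator and TensorFlow2Estimator whenever a Spark DataFrame is passed in, so no estimator-facing API changes are needed.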