add shard size in orca context and fix style
shanyu-sys committed Feb 22, 2021
1 parent 5b04049 commit 0851ec1
Showing 3 changed files with 35 additions and 12 deletions.
pyzoo/test/zoo/orca/learn/test_utils.py (14 changes: 9 additions & 5 deletions)
@@ -187,8 +187,10 @@ def test_array2dict(self):
         from zoo.orca.learn.utils import arrays2dict
         record_num = 100
         shard_size = 30
-        data = [(np.float32(np.random.randn(1, 50)), np.float32([np.random.randint(0, 2,)])) for i in range(record_num)]
-        result = arrays2dict(data, feature_cols=["feature"], label_cols=["label"], shard_size=shard_size)
+        data = [(np.float32(np.random.randn(1, 50)), np.float32([np.random.randint(0, 2,)]))
+                for i in range(record_num)]
+        result = arrays2dict(data, feature_cols=["feature"], label_cols=["label"],
+                             shard_size=shard_size)
         for i, d in enumerate(result):
             if (record_num % shard_size == 0) or (i != record_num // shard_size):
                 assert d['x'].shape[0] == shard_size
@@ -200,7 +202,8 @@ def test_array2dict(self):
     def test_array2dict_shard_size_none(self):
         from zoo.orca.learn.utils import arrays2dict
         record_num = 100
-        data = [(np.float32(np.random.randn(1, 50)), np.float32([np.random.randint(0, 2,)])) for i in range(record_num)]
+        data = [(np.float32(np.random.randn(1, 50)), np.float32([np.random.randint(0, 2,)]))
+                for i in range(record_num)]
         result = arrays2dict(data, feature_cols=["feature"], label_cols=["label"], shard_size=None)
         for i, d in enumerate(result):
             assert d['x'].shape[0] == record_num
@@ -217,8 +220,9 @@ def test_dataframe_to_xshards(self):
         num_shards = shards.rdd.count()
         assert num_shards == num_partitions

-        shard_size = 1
-        shards = _dataframe_to_xshards(df, feature_cols=["feature"], label_cols=["label"], shard_size=shard_size)
+        from zoo.orca import OrcaContext
+        OrcaContext.shard_size = 1
+        shards = _dataframe_to_xshards(df, feature_cols=["feature"], label_cols=["label"])
         num_shards = shards.rdd.count()
         assert num_shards == df.rdd.count()
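A standalone sketch (not part of the patch) of the shard-size arithmetic the tests above assert: with record_num = 100 and shard_size = 30, every shard holds shard_size rows except a smaller final one.

    record_num, shard_size = 100, 30
    sizes = [min(shard_size, record_num - start)
             for start in range(0, record_num, shard_size)]
    assert sizes == [30, 30, 30, 10]
    assert all(s == shard_size for s in sizes[:-1])  # only the last shard may be smaller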
pyzoo/zoo/orca/common.py (18 changes: 18 additions & 0 deletions)
@@ -24,6 +24,7 @@ class OrcaContextMeta(type):
     __eager_mode = True
     _serialize_data_creator = False
     _train_data_store = "DRAM"
+    _shard_size = None

     @property
     def log_output(cls):
@@ -101,6 +102,23 @@ def train_data_store(cls, value):
             "train_data_store must be either DRAM or PMEM or DIRECT or DISK_n"
         cls._train_data_store = value

+    @property
+    def shard_size(cls):
+        """
+        The number of Rows in a Spark DataFrame to transform into one shard of SparkXShards.
+        Spark DataFrame input is converted to SparkXShards internally in fit/predict/evaluate
+        of PyTorchRayEstimator and TensorFlow2Estimator. This parameter may affect the
+        performance of transferring a SparkXShards to a RayXShards.
+        """
+        return cls._shard_size
+
+    @shard_size.setter
+    def shard_size(cls, value):
+        if value is not None:
+            assert isinstance(value, int) and value > 0, \
+                "shard size should be either None or a positive integer."
+        cls._shard_size = value
+

 class OrcaContext(metaclass=OrcaContextMeta):
     @staticmethod
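A minimal usage sketch for the shard_size property added above; the value 1000 is an arbitrary illustration, and the error case follows the setter's validation.

    from zoo.orca import OrcaContext

    print(OrcaContext.shard_size)  # None by default
    OrcaContext.shard_size = 1000  # each shard will hold at most 1000 DataFrame rows

    try:
        OrcaContext.shard_size = -1
    except AssertionError as e:
        print(e)  # shard size should be either None or a positive integer.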
pyzoo/zoo/orca/learn/utils.py (15 changes: 8 additions & 7 deletions)
@@ -228,8 +228,10 @@ def generate_output(feature_lists, label_lists):
         yield generate_output(feature_lists, label_lists)


-def _dataframe_to_xshards(data, feature_cols, label_cols=None, shard_size=None):
+def _dataframe_to_xshards(data, feature_cols, label_cols=None):
+    from zoo.orca import OrcaContext
     schema = data.schema
+    shard_size = OrcaContext.shard_size
     numpy_rdd = data.rdd.map(lambda row: convert_row_to_numpy(row,
                                                               schema,
                                                               feature_cols,
@@ -241,7 +243,7 @@ def _dataframe_to_xshards(data, feature_cols, label_cols=None):
     return SparkXShards(shard_rdd)


-def dataframe_to_xshards(data, validation_data, feature_cols, label_cols, mode="fit", shard_size=None):
+def dataframe_to_xshards(data, validation_data, feature_cols, label_cols, mode="fit"):
     from pyspark.sql import DataFrame
     valid_mode = {"fit", "evaluate", "predict"}
     assert mode in valid_mode, f"invalid mode {mode} " \
@@ -255,21 +257,20 @@ def dataframe_to_xshards(data, validation_data, feature_cols, label_cols, mode="fit"):
         assert label_cols is not None, \
             "label_cols must be provided if data is a spark dataframe"

-    data = _dataframe_to_xshards(data, feature_cols, label_cols, shard_size=shard_size)
+    data = _dataframe_to_xshards(data, feature_cols, label_cols)
     if validation_data is not None:
-        validation_data = _dataframe_to_xshards(validation_data, feature_cols, label_cols, shard_size=shard_size)
+        validation_data = _dataframe_to_xshards(validation_data, feature_cols, label_cols)

     return data, validation_data


-def maybe_dataframe_to_xshards(data, validation_data, feature_cols, label_cols, mode="fit", shard_size=None):
+def maybe_dataframe_to_xshards(data, validation_data, feature_cols, label_cols, mode="fit"):
     from pyspark.sql import DataFrame
     if isinstance(data, DataFrame):
         data, validation_data = dataframe_to_xshards(data, validation_data,
                                                      feature_cols=feature_cols,
                                                      label_cols=label_cols,
-                                                     mode=mode,
-                                                     shard_size=shard_size)
+                                                     mode=mode)
     return data, validation_data
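With shard_size no longer a parameter, callers now control sharding through OrcaContext before invoking the conversion, as the updated test does. Below is a hypothetical, simplified sketch (not the library's implementation) of the chunking contract arrays2dict is tested against: group (feature, label) records into dicts of stacked arrays, shard_size rows at a time, or all rows when it is None. Shapes in the comments assume the (1, 50) features and (1,) labels used in the tests above.

    import numpy as np

    def arrays2dict_sketch(records, shard_size=None):
        records = list(records)
        step = shard_size or len(records)
        for start in range(0, len(records), step):
            chunk = records[start:start + step]  # at most shard_size records
            yield {"x": np.concatenate([f for f, _ in chunk]),  # (n, 50) features
                   "y": np.stack([l for _, l in chunk])}        # (n, 1) labels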
