
Commit

Orca DataShards to XShards (intel-analytics#2344)
* init

* datashards to xshards
cyita authored and yangw1234 committed Sep 27, 2021
1 parent 831dcda commit e9eedf0
Showing 4 changed files with 22 additions and 22 deletions.
14 changes: 7 additions & 7 deletions python/orca/src/bigdl/orca/data/pandas/preprocessing.py
@@ -22,18 +22,18 @@
from bigdl.util.common import get_node_and_core_number

from zoo.ray import RayContext
-from zoo.orca.data.shard import RayDataShards, RayPartition, SparkDataShards
+from zoo.orca.data.shard import RayXShards, RayPartition, SparkXShards
from zoo.orca.data.utils import *


def read_csv(file_path, context, **kwargs):
"""
-Read csv files to DataShards
+Read csv files to XShards
:param file_path: could be a csv file, multiple csv file paths separated by comma,
a directory containing csv files.
Supported file systems are local file system, hdfs, and s3.
:param context: SparkContext or RayContext
-:return: DataShards
+:return: XShards
"""
if isinstance(context, RayContext):
return read_file_ray(context, file_path, "csv", **kwargs)
@@ -45,12 +45,12 @@ def read_csv(file_path, context, **kwargs):

def read_json(file_path, context, **kwargs):
"""
-Read json files to DataShards
+Read json files to XShards
:param file_path: could be a json file, multiple json file paths separated by comma,
a directory containing json files.
Supported file systems are local file system, hdfs, and s3.
:param context: SparkContext or RayContext
-:return: DataShards
+:return: XShards
"""
if isinstance(context, RayContext):
return read_file_ray(context, file_path, "json", **kwargs)
@@ -86,7 +86,7 @@ def read_file_ray(context, file_path, file_type, **kwargs):

# create initial partition
partitions = [RayPartition([shard]) for shard in shards]
-data_shards = RayDataShards(partitions)
+data_shards = RayXShards(partitions)
return data_shards


@@ -158,7 +158,7 @@ def loadFile(iterator):

pd_rdd = rdd.mapPartitions(loadFile)

-data_shards = SparkDataShards(pd_rdd)
+data_shards = SparkXShards(pd_rdd)
return data_shards


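The read path itself is unchanged by the rename. Below is a minimal usage sketch under the new names, assuming read_csv is exposed from zoo.orca.data.pandas and reusing the local Spark setup shown in the test file later in this diff; the CSV path is hypothetical.

from zoo.common.nncontext import init_spark_conf, init_nncontext
from zoo.orca.data.pandas import read_csv  # assumed re-export of the function above

# Local Spark context, mirroring the renamed test setup below.
spark_conf = init_spark_conf().setMaster("local[4]").setAppName("xshards_demo")
sc = init_nncontext(spark_conf)

# With a SparkContext this returns a SparkXShards; with a RayContext,
# read_file_ray would build a RayXShards (one RayPartition per shard).
data_shards = read_csv("hdfs://some/dir/of/csvs", sc)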
24 changes: 12 additions & 12 deletions python/orca/src/bigdl/orca/data/shard.py
@@ -17,14 +17,14 @@
from zoo.orca.data.utils import *


-class DataShards(object):
+class XShards(object):
"""
A collection of data which can be pre-processed parallelly.
"""

def transform_shard(self, func, *args):
"""
-Transform each shard in the DataShards using func
+Transform each shard in the XShards using func
:param func: pre-processing function
:param args: arguments for the pre-processing function
:return: DataShard
@@ -33,13 +33,13 @@ def transform_shard(self, func, *args):

def collect(self):
"""
-Returns a list that contains all of the elements in this DataShards
+Returns a list that contains all of the elements in this XShards
:return: list of elements
"""
pass


-class RayDataShards(DataShards):
+class RayXShards(XShards):
"""
A collection of data which can be pre-processed parallelly on Ray
"""
@@ -49,7 +49,7 @@ def __init__(self, partitions):

def transform_shard(self, func, *args):
"""
-Transform each shard in the DataShards using func
+Transform each shard in the XShards using func
:param func: pre-processing function.
In the function, the element object should be the first argument
:param args: rest arguments for the pre-processing function
@@ -64,33 +64,33 @@ def transform_shard(self, func, *args):

def collect(self):
"""
-Returns a list that contains all of the elements in this DataShards
+Returns a list that contains all of the elements in this XShards
:return: list of elements
"""
import ray
return ray.get([shard.get_data.remote() for shard in self.shard_list])

def repartition(self, num_partitions):
"""
-Repartition DataShards.
+Repartition XShards.
:param num_partitions: number of partitions
-:return: this DataShards
+:return: this XShards
"""
shards_partitions = list(chunk(self.shard_list, num_partitions))
self.partitions = [RayPartition(shards) for shards in shards_partitions]
return self

def get_partitions(self):
"""
-Return partition list of the DataShards
+Return partition list of the XShards
:return: partition list
"""
return self.partitions


class RayPartition(object):
"""
-Partition of RayDataShards
+Partition of RayXShards
"""

def __init__(self, shard_list):
@@ -100,7 +100,7 @@ def get_data(self):
return [shard.get_data.remote() for shard in self.shard_list]


-class SparkDataShards(DataShards):
+class SparkXShards(XShards):
def __init__(self, rdd):
self.rdd = rdd

@@ -149,5 +149,5 @@ def merge(iterator):
self.rdd = partitioned_rdd.mapPartitions(merge)
return self
else:
raise Exception("Currently only support partition by for Datashards"
raise Exception("Currently only support partition by for XShards"
" of Pandas DataFrame")
2 changes: 1 addition & 1 deletion python/orca/test/bigdl/orca/data/test_ray_pandas.py
@@ -24,7 +24,7 @@
from test.zoo.orca.data.conftest import get_ray_ctx


-class TestRayDataShards(ZooTestCase):
+class TestRayXShards(ZooTestCase):
def setup_method(self, method):
self.resource_path = os.path.join(os.path.split(__file__)[0], "../../resources")
self.ray_ctx = get_ray_ctx()
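For the Ray path, the flow this renamed test exercises looks roughly as follows, assuming get_ray_ctx() has returned a live RayContext; the CSV directory is hypothetical.

from zoo.orca.data.pandas import read_csv  # assumed import path, as above

# read_csv dispatches on the context type, so a RayContext (ray_ctx,
# as produced in setup_method) yields a RayXShards.
shards = read_csv("file:///tmp/sample_csvs", ray_ctx)

# RayXShards.collect() launches one get_data.remote() per shard and
# gathers all results with a single ray.get().
dataframes = shards.collect()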
4 changes: 2 additions & 2 deletions python/orca/test/bigdl/orca/data/test_spark_pandas.py
@@ -23,10 +23,10 @@
from zoo.common.nncontext import *


-class TestSparkDataShards(ZooTestCase):
+class TestSparkXShards(ZooTestCase):
def setup_method(self, method):
self.resource_path = os.path.join(os.path.split(__file__)[0], "../../resources")
-sparkConf = init_spark_conf().setMaster("local[4]").setAppName("testSparkDataShards")
+sparkConf = init_spark_conf().setMaster("local[4]").setAppName("testSparkXShards")
self.sc = init_nncontext(sparkConf)

def teardown_method(self, method):
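The renamed Spark test keeps the usual local-mode lifecycle. The teardown body is elided in this diff; a sketch of the pattern for a standalone script, assuming stopping the SparkContext is the cleanup step:

from zoo.common.nncontext import init_spark_conf, init_nncontext

spark_conf = init_spark_conf().setMaster("local[4]").setAppName("testSparkXShards")
sc = init_nncontext(spark_conf)
try:
    pass  # exercise SparkXShards here, e.g. via read_csv / read_json
finally:
    sc.stop()  # assumed cleanup; the diff does not show teardown_method's body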
