diff --git a/python/orca/src/bigdl/orca/data/pandas/preprocessing.py b/python/orca/src/bigdl/orca/data/pandas/preprocessing.py
index 7cd55df6881..239e681b76f 100644
--- a/python/orca/src/bigdl/orca/data/pandas/preprocessing.py
+++ b/python/orca/src/bigdl/orca/data/pandas/preprocessing.py
@@ -19,6 +19,7 @@
 from bigdl.util.common import get_node_and_core_number
 from pyspark.context import SparkContext
+from zoo import ZooContext
 from zoo.ray import RayContext
 from zoo.orca.data.shard import RayXShards, RayPartition, SparkXShards
 from zoo.orca.data.utils import *
 
@@ -107,22 +108,48 @@ def read_file_spark(context, file_path, file_type, **kwargs):
     if not file_paths:
         raise Exception("The file path is invalid/empty or does not include csv/json files")
 
-    num_files = len(file_paths)
-    total_cores = node_num * core_num
-    num_partitions = num_files if num_files < total_cores else total_cores
-    rdd = context.parallelize(file_paths, num_partitions)
+    if ZooContext.orca_pandas_read_backend == "pandas":
+        num_files = len(file_paths)
+        total_cores = node_num * core_num
+        num_partitions = num_files if num_files < total_cores else total_cores
+        rdd = context.parallelize(file_paths, num_partitions)
 
-    if prefix == "hdfs":
-        pd_rdd = rdd.mapPartitions(lambda iter: read_pd_hdfs_file_list(iter, file_type, **kwargs))
-    elif prefix == "s3":
-        pd_rdd = rdd.mapPartitions(lambda iter: read_pd_s3_file_list(iter, file_type, **kwargs))
+        if prefix == "hdfs":
+            pd_rdd = rdd.mapPartitions(
+                lambda iter: read_pd_hdfs_file_list(iter, file_type, **kwargs))
+        elif prefix == "s3":
+            pd_rdd = rdd.mapPartitions(
+                lambda iter: read_pd_s3_file_list(iter, file_type, **kwargs))
+        else:
+            def loadFile(iterator):
+                for x in iterator:
+                    df = read_pd_file(x, file_type, **kwargs)
+                    yield df
+
+            pd_rdd = rdd.mapPartitions(loadFile)
     else:
-        def loadFile(iterator):
-            for x in iterator:
-                df = read_pd_file(x, file_type, **kwargs)
-                yield df
+        from pyspark.sql import SQLContext
+        sqlContext = SQLContext.getOrCreate(context)
+        spark = sqlContext.sparkSession
+        # TODO: add S3 credentials
+        if file_type == "json":
+            df = spark.read.json(file_paths, **kwargs)
+        elif file_type == "csv":
+            df = spark.read.csv(file_paths, **kwargs)
+        else:
+            raise Exception("Unsupported file type")
+        if df.rdd.getNumPartitions() < node_num:
+            df = df.repartition(node_num)
+
+        def to_pandas(columns):
+            def f(iter):
+                import pandas as pd
+                data = list(iter)
+                yield pd.DataFrame(data, columns=columns)
+
+            return f
 
-        pd_rdd = rdd.mapPartitions(loadFile)
+        pd_rdd = df.rdd.mapPartitions(to_pandas(df.columns))
     data_shards = SparkXShards(pd_rdd)
     return data_shards
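A note on the `read_file_spark` change above: file reading now dispatches on `ZooContext.orca_pandas_read_backend`, keeping the original pandas-per-file path as the default and adding a Spark-native path that reads, repartitions, and then converts each partition to a pandas DataFrame. A minimal usage sketch of the switch (the folder path is illustrative, and `sc` is assumed to be an existing SparkContext, as in the tests below):

```python
from zoo import ZooContext
import zoo.orca.data.pandas

# Default backend: each partition reads whole files with pandas.
ZooContext.orca_pandas_read_backend = "pandas"
shards = zoo.orca.data.pandas.read_csv("/path/to/csv_folder", sc)

# New backend: Spark reads and partitions the files before each partition
# is converted to a pandas DataFrame. Spark's csv reader options apply,
# e.g. header=True to treat the first row as column names.
ZooContext.orca_pandas_read_backend = "spark"
shards = zoo.orca.data.pandas.read_csv("/path/to/csv_folder", sc, header=True)

# Either way, collect() returns one pandas DataFrame per partition.
dfs = shards.collect()
```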
diff --git a/python/orca/test/bigdl/orca/data/conftest.py b/python/orca/test/bigdl/orca/data/conftest.py
index 8c98c07fea6..59bdabe9278 100644
--- a/python/orca/test/bigdl/orca/data/conftest.py
+++ b/python/orca/test/bigdl/orca/data/conftest.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import os.path
+import os
 
 from zoo import ZooContext
 import pytest
@@ -26,6 +26,7 @@
 def orca_data_fixture():
     from zoo import init_spark_on_local
     from zoo.ray import RayContext
+    global sc
     global ray_ctx
     ZooContext._orca_eager_mode = True
     sc = init_spark_on_local(cores=4, spark_log_level="INFO")
@@ -45,16 +46,9 @@ def orca_data_fixture():
     sc.stop()
 
-# @pytest.fixture()
-# def setUpModule():
-#     sc = init_spark_on_local(cores=4, spark_log_level="INFO")
-#     ray_ctx = RayContext(sc=sc)
-#     ray_ctx.init()
-#
-#
-# def tearDownModule():
-#     ray_ctx.stop()
-#     sc.stop()
-
 
 def get_ray_ctx():
     return ray_ctx
+
+
+def get_spark_ctx():
+    return sc
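A note on the `global sc` addition above: the fixture assigns `sc` inside a function body, so the declaration is what makes that assignment bind the module-level name that the new `get_spark_ctx()` helper returns; without it, `sc` would stay local to the fixture and the lookup in `get_spark_ctx()` would fail. A standalone sketch of the pattern (toy stand-ins, not project code):

```python
def fixture_body():
    # Without "global sc", this assignment would create a function-local
    # name and get_spark_ctx() would raise NameError at lookup time.
    global sc
    sc = "stand-in for init_spark_on_local(cores=4)"


def get_spark_ctx():
    return sc


fixture_body()
print(get_spark_ctx())  # prints the value bound by the fixture
```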
diff --git a/python/orca/test/bigdl/orca/data/test_spark_pandas.py b/python/orca/test/bigdl/orca/data/test_spark_pandas.py
index d8543a9d8f5..a67d8247a81 100644
--- a/python/orca/test/bigdl/orca/data/test_spark_pandas.py
+++ b/python/orca/test/bigdl/orca/data/test_spark_pandas.py
@@ -18,27 +18,22 @@
 import shutil
 
 import pytest
+from unittest import TestCase
 
 import zoo.orca.data
 import zoo.orca.data.pandas
 from zoo.orca.data.shard import SharedValue
-from test.zoo.pipeline.utils.test_utils import ZooTestCase
 from zoo.common.nncontext import *
+from test.zoo.orca.data.conftest import get_spark_ctx
 
 
-class TestSparkXShards(ZooTestCase):
+class TestSparkXShards(TestCase):
     def setup_method(self, method):
         self.resource_path = os.path.join(os.path.split(__file__)[0], "../../resources")
-        sparkConf = init_spark_conf().setMaster("local[4]").setAppName("testSparkXShards")
-        self.sc = init_nncontext(sparkConf)
+        self.sc = get_spark_ctx()
 
-    def teardown_method(self, method):
-        """ teardown any state that was previously setup with a setup_method
-        call.
-        """
-        self.sc.stop()
-
-    def test_read_local_csv(self):
+    def test_read_local_csv_pandas_backend(self):
+        ZooContext.orca_pandas_read_backend = "pandas"
         file_path = os.path.join(self.resource_path, "orca/data/csv")
         data_shard = zoo.orca.data.pandas.read_csv(file_path, self.sc)
         data = data_shard.collect()
@@ -50,7 +45,22 @@ def test_read_local_csv(self):
         xshards = zoo.orca.data.pandas.read_csv(file_path, self.sc)
         self.assertTrue('The file path is invalid/empty' in str(context.exception))
 
-    def test_read_local_json(self):
+    def test_read_local_csv_spark_backend(self):
+        ZooContext.orca_pandas_read_backend = "spark"
+        file_path = os.path.join(self.resource_path, "orca/data/csv")
+        data_shard = zoo.orca.data.pandas.read_csv(file_path, self.sc, header=True)
+        data = data_shard.collect()
+        df = data[0]
+        assert "location" in df.columns, "location is not in columns"
+        file_path = os.path.join(self.resource_path, "abc")
+        with self.assertRaises(Exception) as context:
+            xshards = zoo.orca.data.pandas.read_csv(file_path, self.sc)
+        self.assertTrue('The file path is invalid/empty' in str(context.exception))
+        # Change the backend back to the default pandas so that this won't affect other unit tests.
+        ZooContext.orca_pandas_read_backend = "pandas"
+
+    def test_read_local_json_pandas_backend(self):
+        ZooContext.orca_pandas_read_backend = "pandas"
         file_path = os.path.join(self.resource_path, "orca/data/json")
         data_shard = zoo.orca.data.pandas.read_json(file_path, self.sc, orient='columns', lines=True)
         data = data_shard.collect()
@@ -59,6 +69,15 @@
         df = data[0]
         assert "value" in df.columns, "value is not in columns"
 
+    def test_read_local_json_spark_backend(self):
+        ZooContext.orca_pandas_read_backend = "spark"
+        file_path = os.path.join(self.resource_path, "orca/data/json")
+        data_shard = zoo.orca.data.pandas.read_json(file_path, self.sc)
+        data = data_shard.collect()
+        df = data[0]
+        assert "value" in df.columns, "value is not in columns"
+        ZooContext.orca_pandas_read_backend = "pandas"
+
     def test_read_s3(self):
         access_key_id = os.getenv("AWS_ACCESS_KEY_ID")
         secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY")
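Finally, the heart of the Spark backend is the per-partition conversion performed by `to_pandas` in preprocessing.py. Below is a self-contained sketch of that conversion, assuming a local SparkSession and toy data (the session setup and sample rows are illustrative, not part of the patch):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()
df = spark.createDataFrame([(1, "NY"), (2, "SF")], ["id", "location"])


def to_pandas(columns):
    # Same shape as the helper in the patch: gather one partition's Rows
    # into a single pandas DataFrame whose columns mirror the Spark schema.
    def f(iter):
        import pandas as pd  # imported inside so executors resolve it, as in the patch
        yield pd.DataFrame(list(iter), columns=columns)
    return f


# One pandas DataFrame per Spark partition; SparkXShards wraps this RDD.
pd_rdd = df.rdd.mapPartitions(to_pandas(df.columns))
print(pd_rdd.collect())
spark.stop()
```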