Spark backend for reading csv/json (intel-analytics#2510)
* initial

* update doc

* add utr

* style

* meet review

* fix ut
hkvision authored and Wang, Yang committed Sep 26, 2021
1 parent d039732 commit 24b747d
Showing 3 changed files with 77 additions and 37 deletions.
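Note: this commit lets Orca's pandas data loading (zoo.orca.data.pandas.read_csv / read_json) run on a Spark backend in addition to the default pandas backend, selected via ZooContext.orca_pandas_read_backend. A rough usage sketch based on the tests added below (not part of the diff; the local SparkContext setup and the file path are placeholder assumptions):

from zoo import init_spark_on_local, ZooContext
import zoo.orca.data.pandas

sc = init_spark_on_local(cores=4)
ZooContext.orca_pandas_read_backend = "spark"   # default is "pandas"
data_shard = zoo.orca.data.pandas.read_csv("/path/to/csv_dir", sc, header=True)
df = data_shard.collect()[0]                    # each element is a pandas DataFrame
ZooContext.orca_pandas_read_backend = "pandas"  # restore the default, as the tests do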
53 changes: 40 additions & 13 deletions python/orca/src/bigdl/orca/data/pandas/preprocessing.py
@@ -19,6 +19,7 @@
from bigdl.util.common import get_node_and_core_number
from pyspark.context import SparkContext

+from zoo import ZooContext
from zoo.ray import RayContext
from zoo.orca.data.shard import RayXShards, RayPartition, SparkXShards
from zoo.orca.data.utils import *
@@ -107,22 +108,48 @@ def read_file_spark(context, file_path, file_type, **kwargs):
    if not file_paths:
        raise Exception("The file path is invalid/empty or does not include csv/json files")

-    num_files = len(file_paths)
-    total_cores = node_num * core_num
-    num_partitions = num_files if num_files < total_cores else total_cores
-    rdd = context.parallelize(file_paths, num_partitions)
+    if ZooContext.orca_pandas_read_backend == "pandas":
+        num_files = len(file_paths)
+        total_cores = node_num * core_num
+        num_partitions = num_files if num_files < total_cores else total_cores
+        rdd = context.parallelize(file_paths, num_partitions)

-    if prefix == "hdfs":
-        pd_rdd = rdd.mapPartitions(lambda iter: read_pd_hdfs_file_list(iter, file_type, **kwargs))
-    elif prefix == "s3":
-        pd_rdd = rdd.mapPartitions(lambda iter: read_pd_s3_file_list(iter, file_type, **kwargs))
+        if prefix == "hdfs":
+            pd_rdd = rdd.mapPartitions(
+                lambda iter: read_pd_hdfs_file_list(iter, file_type, **kwargs))
+        elif prefix == "s3":
+            pd_rdd = rdd.mapPartitions(
+                lambda iter: read_pd_s3_file_list(iter, file_type, **kwargs))
+        else:
+            def loadFile(iterator):
+                for x in iterator:
+                    df = read_pd_file(x, file_type, **kwargs)
+                    yield df
+
+            pd_rdd = rdd.mapPartitions(loadFile)
    else:
-        def loadFile(iterator):
-            for x in iterator:
-                df = read_pd_file(x, file_type, **kwargs)
-                yield df
+        from pyspark.sql import SQLContext
+        sqlContext = SQLContext.getOrCreate(context)
+        spark = sqlContext.sparkSession
+        # TODO: add S3 confidentials
+        if file_type == "json":
+            df = spark.read.json(file_paths, **kwargs)
+        elif file_type == "csv":
+            df = spark.read.csv(file_paths, **kwargs)
+        else:
+            raise Exception("Unsupported file type")
+        if df.rdd.getNumPartitions() < node_num:
+            df = df.repartition(node_num)

-        pd_rdd = rdd.mapPartitions(loadFile)
+        def to_pandas(columns):
+            def f(iter):
+                import pandas as pd
+                data = list(iter)
+                yield pd.DataFrame(data, columns=columns)
+
+            return f
+
+        pd_rdd = df.rdd.mapPartitions(to_pandas(df.columns))

    data_shards = SparkXShards(pd_rdd)
    return data_shards
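For reference, a minimal standalone sketch of what the new Spark code path does, outside of the XShards plumbing (the SparkSession setup and the file path are assumptions, not part of the diff):

import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[4]").getOrCreate()
# Spark reads the folder of csv/json files into one distributed DataFrame.
df = spark.read.csv("/path/to/csv_dir", header=True)   # or spark.read.json(...)

def to_pandas(columns):
    def f(iter):
        # Each Spark partition (an iterator of Rows) becomes one pandas DataFrame.
        yield pd.DataFrame(list(iter), columns=columns)
    return f

pd_rdd = df.rdd.mapPartitions(to_pandas(df.columns))
print(pd_rdd.map(len).collect())   # number of rows held by each partition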
18 changes: 6 additions & 12 deletions python/orca/test/bigdl/orca/data/conftest.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import os.path
+import os
from zoo import ZooContext

import pytest
@@ -26,6 +26,7 @@
def orca_data_fixture():
    from zoo import init_spark_on_local
    from zoo.ray import RayContext
+    global sc
    global ray_ctx
    ZooContext._orca_eager_mode = True
    sc = init_spark_on_local(cores=4, spark_log_level="INFO")
@@ -45,16 +46,9 @@ def orca_data_fixture():
    sc.stop()


-# @pytest.fixture()
-# def setUpModule():
-#     sc = init_spark_on_local(cores=4, spark_log_level="INFO")
-#     ray_ctx = RayContext(sc=sc)
-#     ray_ctx.init()
-#
-#
-# def tearDownModule():
-#     ray_ctx.stop()
-#     sc.stop()
-
def get_ray_ctx():
    return ray_ctx
+
+
+def get_spark_ctx():
+    return sc
43 changes: 31 additions & 12 deletions python/orca/test/bigdl/orca/data/test_spark_pandas.py
@@ -18,27 +18,22 @@
import shutil

import pytest
+from unittest import TestCase

import zoo.orca.data
import zoo.orca.data.pandas
from zoo.orca.data.shard import SharedValue
-from test.zoo.pipeline.utils.test_utils import ZooTestCase
from zoo.common.nncontext import *
+from test.zoo.orca.data.conftest import get_spark_ctx


-class TestSparkXShards(ZooTestCase):
+class TestSparkXShards(TestCase):
    def setup_method(self, method):
        self.resource_path = os.path.join(os.path.split(__file__)[0], "../../resources")
-        sparkConf = init_spark_conf().setMaster("local[4]").setAppName("testSparkXShards")
-        self.sc = init_nncontext(sparkConf)
+        self.sc = get_spark_ctx()

-    def teardown_method(self, method):
-        """ teardown any state that was previously setup with a setup_method
-        call.
-        """
-        self.sc.stop()
-
-    def test_read_local_csv(self):
+    def test_read_local_csv_pandas_backend(self):
+        ZooContext.orca_pandas_read_backend = "pandas"
        file_path = os.path.join(self.resource_path, "orca/data/csv")
        data_shard = zoo.orca.data.pandas.read_csv(file_path, self.sc)
        data = data_shard.collect()
@@ -50,7 +45,22 @@ def test_read_local_csv(self):
            xshards = zoo.orca.data.pandas.read_csv(file_path, self.sc)
        self.assertTrue('The file path is invalid/empty' in str(context.exception))

-    def test_read_local_json(self):
+    def test_read_local_csv_spark_backend(self):
+        ZooContext.orca_pandas_read_backend = "spark"
+        file_path = os.path.join(self.resource_path, "orca/data/csv")
+        data_shard = zoo.orca.data.pandas.read_csv(file_path, self.sc, header=True)
+        data = data_shard.collect()
+        df = data[0]
+        assert "location" in df.columns, "location is not in columns"
+        file_path = os.path.join(self.resource_path, "abc")
+        with self.assertRaises(Exception) as context:
+            xshards = zoo.orca.data.pandas.read_csv(file_path, self.sc)
+        self.assertTrue('The file path is invalid/empty' in str(context.exception))
+        # Change the backend to default pandas so that this won't affect other unit tests.
+        ZooContext.orca_pandas_read_backend = "pandas"
+
+    def test_read_local_json_pandas_backend(self):
+        ZooContext.orca_pandas_read_backend = "pandas"
        file_path = os.path.join(self.resource_path, "orca/data/json")
        data_shard = zoo.orca.data.pandas.read_json(file_path, self.sc,
                                                    orient='columns', lines=True)
@@ -59,6 +69,15 @@ def test_read_local_json(self):
        df = data[0]
        assert "value" in df.columns, "value is not in columns"

+    def test_read_local_json_spark_backend(self):
+        ZooContext.orca_pandas_read_backend = "spark"
+        file_path = os.path.join(self.resource_path, "orca/data/json")
+        data_shard = zoo.orca.data.pandas.read_json(file_path, self.sc)
+        data = data_shard.collect()
+        df = data[0]
+        assert "value" in df.columns, "value is not in columns"
+        ZooContext.orca_pandas_read_backend = "pandas"
+
    def test_read_s3(self):
        access_key_id = os.getenv("AWS_ACCESS_KEY_ID")
        secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY")
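One detail visible in the tests above: keyword arguments passed to read_csv / read_json are forwarded to whichever backend is active, so the pandas backend takes pandas-style options (orient, lines, ...) while the spark backend takes Spark reader options (header, ...). A small sketch of the difference (paths and the SparkContext sc are placeholders):

ZooContext.orca_pandas_read_backend = "pandas"
shard = zoo.orca.data.pandas.read_json("/path/json_dir", sc, orient='columns', lines=True)

ZooContext.orca_pandas_read_backend = "spark"
shard = zoo.orca.data.pandas.read_csv("/path/csv_dir", sc, header=True)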
