[WIP] spark 3.0 (intel-analytics#3054)
* spark 3.0
Le-Zheng committed Sep 27, 2020
1 parent b8daffe commit f6d9181
Showing 18 changed files with 591 additions and 49 deletions.
55 changes: 55 additions & 0 deletions scala/common/spark-version/3.0/pom.xml
@@ -0,0 +1,55 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <parent>
        <artifactId>spark-version</artifactId>
        <groupId>com.intel.analytics.bigdl</groupId>
        <version>0.12.0-SNAPSHOT</version>
    </parent>
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.intel.analytics.bigdl.spark-version</groupId>
    <artifactId>3.0</artifactId>
    <packaging>jar</packaging>

    <dependencies>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_2.12</artifactId>
            <version>${spark.version}</version>
            <scope>provided</scope>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.12</artifactId>
            <version>${spark.version}</version>
            <scope>provided</scope>
        </dependency>
    </dependencies>
    <build>
        <plugins>
            <plugin>
                <groupId>org.scalastyle</groupId>
                <artifactId>scalastyle-maven-plugin</artifactId>
                <version>0.8.0</version>
                <configuration>
                    <verbose>false</verbose>
                    <failOnViolation>true</failOnViolation>
                    <failOnWarning>false</failOnWarning>
                    <sourceDirectory>${basedir}/src/main/scala</sourceDirectory>
                    <configLocation>${project.parent.parent.parent.basedir}/scalastyle_config.xml</configLocation>
                    <outputFile>${project.build.directory}/stylecheck/scalastyle-output.xml</outputFile>
                    <outputEncoding>UTF-8</outputEncoding>
                </configuration>
                <executions>
                    <execution>
                        <goals>
                            <goal>check</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>
</project>
@@ -0,0 +1,66 @@
/*
* Copyright 2016 The BigDL Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml

import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.HasLabelCol
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset, Row}

/**
 * Handles different Vector types in Spark 1.5/1.6 and Spark 2.0+.
 * Supports both ML Vector and MLlib Vector for Spark 2.0+.
*/
trait VectorCompatibility {

  val validVectorTypes = Seq(new VectorUDT, new org.apache.spark.mllib.linalg.VectorUDT)

  def getVectorSeq(row: Row, colType: DataType, index: Int): Seq[AnyVal] = {
    if (colType == new VectorUDT) {
      row.getAs[Vector](index).toArray.toSeq
    } else if (colType == new org.apache.spark.mllib.linalg.VectorUDT) {
      row.getAs[org.apache.spark.mllib.linalg.Vector](index).toArray.toSeq
    } else {
      throw new IllegalArgumentException(
        s"$colType is not a supported vector type.")
    }
  }
}
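For reference, a minimal usage sketch of this trait; the object name, the sample data, and the "features" column are illustrative, not part of this commit:

import org.apache.spark.ml.VectorCompatibility
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.{Row, SparkSession}

// Illustrative consumer of VectorCompatibility (hypothetical, for demonstration only).
object VectorCompatibilityExample extends VectorCompatibility {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("vector-compat-example")
      .master("local[1]").getOrCreate()
    import spark.implicits._

    // An ML Vector column; an MLlib Vector column would be handled the same way.
    val df = Seq(Tuple1(Vectors.dense(1.0, 2.0, 3.0))).toDF("features")
    val colType = df.schema("features").dataType
    val first: Row = df.head()

    // getVectorSeq extracts the values regardless of which Vector UDT backs the column.
    val values = getVectorSeq(first, colType, first.fieldIndex("features"))
    println(values.mkString(", ")) // 1.0, 2.0, 3.0
    spark.stop()
  }
}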


/**
 * A wrapper for org.apache.spark.ml.Estimator.
 * Extends MLEstimator and overrides process to gain compatibility with
 * both Spark 1.5 and Spark 2.0.
*/
abstract class DLEstimatorBase[Learner <: DLEstimatorBase[Learner, M],
    M <: DLTransformerBase[M]]
  extends Estimator[M] with HasLabelCol {

  protected def internalFit(dataFrame: DataFrame): M

  override def fit(dataset: Dataset[_]): M = {
    transformSchema(dataset.schema, logging = true)
    internalFit(dataset.toDF())
  }

  override def copy(extra: ParamMap): Learner = defaultCopy(extra)
}



@@ -0,0 +1,41 @@
/*
* Copyright 2016 The BigDL Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml

import org.apache.spark.ml.param.ParamMap
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}

/**
* A wrapper for org.apache.spark.ml.Transformer.
 * Extends MLTransformer and overrides process to gain compatibility with
 * both Spark 1.5 and Spark 2.0.
*/
abstract class DLTransformerBase[M <: DLTransformerBase[M]]
  extends Model[M] {

  /**
   * Converts feature columns (MLlib Vectors or Array) to Seq format.
   */
  protected def internalTransform(dataFrame: DataFrame): DataFrame

  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    internalTransform(dataset.toDF())
  }

  override def copy(extra: ParamMap): M = defaultCopy(extra)
}
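For reference, a minimal sketch of how a concrete estimator/model pair could be built on DLEstimatorBase and DLTransformerBase. The names (IdentityEstimator, IdentityModel) and the pass-through logic are illustrative, not part of this commit:

import org.apache.spark.ml.{DLEstimatorBase, DLTransformerBase}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType

// Illustrative model: internalTransform simply passes the DataFrame through unchanged.
class IdentityModel(override val uid: String)
  extends DLTransformerBase[IdentityModel] {

  def this() = this(Identifiable.randomUID("identityModel"))

  override protected def internalTransform(dataFrame: DataFrame): DataFrame = dataFrame

  override def transformSchema(schema: StructType): StructType = schema
}

// Illustrative estimator: "fitting" just returns a fresh IdentityModel.
class IdentityEstimator(override val uid: String)
  extends DLEstimatorBase[IdentityEstimator, IdentityModel] {

  def this() = this(Identifiable.randomUID("identityEstimator"))

  override protected def internalFit(dataFrame: DataFrame): IdentityModel = new IdentityModel()

  override def transformSchema(schema: StructType): StructType = schema
}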
@@ -0,0 +1,135 @@
/*
* Copyright 2016 The BigDL Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.rdd

import java.io.{IOException, ObjectOutputStream}

import org.apache.log4j.Logger
import org.apache.spark.util.Utils
import org.apache.spark.{Partition, SparkContext}

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

object ZippedPartitionsWithLocalityRDD {
  def apply[T: ClassTag, B: ClassTag, V: ClassTag]
    (rdd1: RDD[T], rdd2: RDD[B], preservesPartitioning: Boolean = false)
    (f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] = rdd1.withScope {
    val sc = rdd1.sparkContext
    new ZippedPartitionsWithLocalityRDD(
      sc, sc.clean(f), rdd1, rdd2, preservesPartitioning)
  }

  val logger: Logger = Logger.getLogger(getClass)
}

/**
 * Prefers to zip partitions of rdd1 and rdd2 that are in the same location.
 * Remaining partitions that are not in the same location are zipped in order.
 * For example:
 * Say we have two RDDs, rdd1 and rdd2. The first partition of rdd1 is on node A, and the second
 * is on node B. The first partition of rdd2 is on node B, and the second is on node A.
 * If we just use rdd1.zipPartitions(rdd2), the first partition of rdd1 is zipped with the
 * first partition of rdd2, so there is cross-node communication, which is bad for performance.
 * That is why we introduce ZippedPartitionsWithLocalityRDD.
 * With this RDD, the first partition of rdd1 is zipped with the second partition of rdd2,
 * as they are on the same node. This reduces the network communication cost and results in
 * better performance.
 * @param sc the SparkContext
 * @param _f the function used to zip the partition iterators of the two RDDs
 * @param _rdd1 the first RDD to zip
 * @param _rdd2 the second RDD to zip
 * @param preservesPartitioning whether to preserve the first RDD's partitioner
*/
class ZippedPartitionsWithLocalityRDD[A: ClassTag, B: ClassTag, V: ClassTag](
    sc: SparkContext,
    _f: (Iterator[A], Iterator[B]) => Iterator[V],
    _rdd1: RDD[A],
    _rdd2: RDD[B],
    preservesPartitioning: Boolean = false)
  extends ZippedPartitionsRDD2[A, B, V](sc, _f, _rdd1, _rdd2, preservesPartitioning) {

  override def getPartitions: Array[Partition] = {
    require(rdds.length == 2, "this is only for 2 rdd zip")
    val numParts = rdds.head.partitions.length
    if (!rdds.forall(rdd => rdd.partitions.length == numParts)) {
      throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
    }

    val candidateLocs = new ArrayBuffer[(Int, Seq[String])]()
    (0 until numParts).foreach(p => {
      candidateLocs.append((p, rdds(1)
        .context.getPreferredLocs(rdds(1), p)
        .map(_.toString).distinct))
    })
    val nonmatchPartitionId = new ArrayBuffer[Int]()
    val parts = new Array[Partition](numParts)

    (0 until numParts).foreach { i =>
      val curPrefs = rdds(0).context.getPreferredLocs(rdds(0), i).map(_.toString).distinct
      var p = 0
      var matchPartition: (Int, Seq[String]) = null
      var locs: Seq[String] = null
      while (p < candidateLocs.length) {
        locs = candidateLocs(p)._2.intersect(curPrefs)
        if (locs.nonEmpty) {
          matchPartition = candidateLocs.remove(p)
          // Found a co-located partition; force the while loop to terminate.
          p = Integer.MAX_VALUE - 1
        }
        p += 1
      }
      if (matchPartition != null) {
        parts(i) =
          new ZippedPartitionsLocalityPartition(i, Array(i, matchPartition._1), rdds, locs)
      } else {
        ZippedPartitionsWithLocalityRDD.logger.warn(s"Can't find a locality partition " +
          s"for partition $i. Partition locations are (${curPrefs}). Candidate partition " +
          s"locations are:\n${candidateLocs.mkString("\n")}.")
        nonmatchPartitionId.append(i)
      }
    }

    require(nonmatchPartitionId.size == candidateLocs.size,
      "unmatched partition size should be the same with candidateLocs size")
    nonmatchPartitionId.foreach { i =>
      val locs = rdds(0).context.getPreferredLocs(rdds(0), i).map(_.toString).distinct
      val matchPartition = candidateLocs.remove(0)
      parts(i) = new ZippedPartitionsLocalityPartition(i, Array(i, matchPartition._1), rdds, locs)
    }
    parts
  }
}


private[spark] class ZippedPartitionsLocalityPartition(
    idx: Int,
    @transient val indexes: Seq[Int],
    @transient val rdds: Seq[RDD[_]],
    @transient override val preferredLocations: Seq[String])
  extends ZippedPartitionsPartition(idx, rdds, preferredLocations) {

  override val index: Int = idx
  var _partitionValues = rdds.zip(indexes).map { case (rdd, i) => rdd.partitions(i) }
  override def partitions: Seq[Partition] = _partitionValues

  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
    // Update the reference to the parent split at the time of task serialization
    _partitionValues = rdds.zip(indexes).map { case (rdd, i) => rdd.partitions(i) }
    oos.defaultWriteObject()
  }
}
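For reference, a minimal usage sketch of the companion object's apply; the driver object and toy data are illustrative, not part of this commit:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.ZippedPartitionsWithLocalityRDD

// Illustrative driver program (hypothetical, for demonstration only).
object LocalityZipExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("locality-zip-example").setMaster("local[2]")
    val sc = new SparkContext(conf)

    // Two RDDs with the same number of partitions; partitions that share a
    // preferred location are zipped together where possible.
    val weights = sc.parallelize(1 to 8, 4)
    val gradients = sc.parallelize(101 to 108, 4)

    val zipped = ZippedPartitionsWithLocalityRDD(weights, gradients) {
      (wIter, gIter) => wIter.zip(gIter).map { case (w, g) => w + g }
    }
    println(zipped.collect().mkString(", "))
    sc.stop()
  }
}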
@@ -0,0 +1,27 @@
/*
* Copyright 2016 The BigDL Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql

import org.apache.spark.sql.expressions.SparkUserDefinedFunction
import org.apache.spark.sql.types.DataType

object SqlAdapter {

  def getUDF(f: AnyRef, dataType: DataType): SparkUserDefinedFunction = {
    SparkUserDefinedFunction(f, dataType)
  }
}
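For reference, a sketch of how the adapter might be used from user code; the function, column name, and example app are illustrative, not part of this commit, and assume the two-argument SparkUserDefinedFunction call shown above behaves like an ordinary untyped UDF:

import org.apache.spark.sql.{SparkSession, SqlAdapter}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType

// Illustrative example of wrapping a plain Scala function through the adapter.
object SqlAdapterExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("sql-adapter-example")
      .master("local[1]").getOrCreate()
    import spark.implicits._

    // Wrap a Scala function as a UDF returning DoubleType.
    val plusOne: UserDefinedFunction =
      SqlAdapter.getUDF((x: Double) => x + 1.0, DoubleType)

    val df = Seq(1.0, 2.0, 3.0).toDF("value")
    df.withColumn("value_plus_one", plusOne(col("value"))).show()
    spark.stop()
  }
}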