forked from intel-analytics/BigDL-2.x
Commit
[WIP] spark 3.0 (intel-analytics#3054)
* spark 3.0
Showing 18 changed files with 591 additions and 49 deletions.
55 changes: 55 additions & 0 deletions
scala/common/spark-version/3.0/pom.xml

@@ -0,0 +1,55 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <parent>
        <artifactId>spark-version</artifactId>
        <groupId>com.intel.analytics.bigdl</groupId>
        <version>0.12.0-SNAPSHOT</version>
    </parent>
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.intel.analytics.bigdl.spark-version</groupId>
    <artifactId>3.0</artifactId>
    <packaging>jar</packaging>

    <dependencies>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_2.12</artifactId>
            <version>${spark.version}</version>
            <scope>provided</scope>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.12</artifactId>
            <version>${spark.version}</version>
            <scope>provided</scope>
        </dependency>
    </dependencies>
    <build>
        <plugins>
            <plugin>
                <groupId>org.scalastyle</groupId>
                <artifactId>scalastyle-maven-plugin</artifactId>
                <version>0.8.0</version>
                <configuration>
                    <verbose>false</verbose>
                    <failOnViolation>true</failOnViolation>
                    <failOnWarning>false</failOnWarning>
                    <sourceDirectory>${basedir}/src/main/scala</sourceDirectory>
                    <configLocation>${project.parent.parent.parent.basedir}/scalastyle_config.xml</configLocation>
                    <outputFile>${project.build.directory}/stylecheck/scalastyle-output.xml</outputFile>
                    <outputEncoding>UTF-8</outputEncoding>
                </configuration>
                <executions>
                    <execution>
                        <goals>
                            <goal>check</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>
</project>
66 changes: 66 additions & 0 deletions
scala/common/spark-version/3.0/src/main/scala/org/apache/spark/ml/DLEstimatorBase.scala
@@ -0,0 +1,66 @@
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.ml

import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.HasLabelCol
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset, Row}

/**
 * Handle different Vector types in Spark 1.5/1.6 and Spark 2.0+.
 * Support both ML Vector and MLlib Vector for Spark 2.0+.
 */
trait VectorCompatibility {

  val validVectorTypes = Seq(new VectorUDT, new org.apache.spark.mllib.linalg.VectorUDT)

  def getVectorSeq(row: Row, colType: DataType, index: Int): Seq[AnyVal] = {
    if (colType == new VectorUDT) {
      row.getAs[Vector](index).toArray.toSeq
    } else if (colType == new org.apache.spark.mllib.linalg.VectorUDT) {
      row.getAs[org.apache.spark.mllib.linalg.Vector](index).toArray.toSeq
    } else {
      throw new IllegalArgumentException(
        s"$colType is not a supported vector type.")
    }
  }
}

/**
 * A wrapper for org.apache.spark.ml.Estimator.
 * Extends Estimator and overrides fit to gain compatibility with
 * both Spark 1.5 and Spark 2.0.
 */
abstract class DLEstimatorBase[Learner <: DLEstimatorBase[Learner, M],
    M <: DLTransformerBase[M]]
  extends Estimator[M] with HasLabelCol {

  protected def internalFit(dataFrame: DataFrame): M

  override def fit(dataset: Dataset[_]): M = {
    transformSchema(dataset.schema, logging = true)
    internalFit(dataset.toDF())
  }

  override def copy(extra: ParamMap): Learner = defaultCopy(extra)
}
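
For context, a minimal sketch of how a class mixing in VectorCompatibility might read a feature column that holds either vector flavor; the object and column name below are hypothetical and not part of this commit:

// Hypothetical usage sketch (not part of this commit): reading a "features"
// column that may hold either an ML Vector or an MLlib Vector.
import org.apache.spark.ml.VectorCompatibility
import org.apache.spark.sql.DataFrame

object VectorColumnReader extends VectorCompatibility {
  def readFeatures(df: DataFrame, col: String = "features"): Array[Seq[AnyVal]] = {
    val index = df.schema.fieldIndex(col)
    val colType = df.schema(index).dataType
    // getVectorSeq dispatches on the column's UDT, so both vector flavors work.
    df.collect().map(row => getVectorSeq(row, colType, index))
  }
}
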
41 changes: 41 additions & 0 deletions
scala/common/spark-version/3.0/src/main/scala/org/apache/spark/ml/DLTransformerBase.scala
@@ -0,0 +1,41 @@
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.ml

import org.apache.spark.ml.param.ParamMap
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}

/**
 * A wrapper for org.apache.spark.ml.Transformer.
 * Extends Model and overrides transform to gain compatibility with
 * both Spark 1.5 and Spark 2.0.
 */
abstract class DLTransformerBase[M <: DLTransformerBase[M]]
  extends Model[M] {

  /**
   * Convert feature columns (MLlib Vectors or Array) to Seq format.
   */
  protected def internalTransform(dataFrame: DataFrame): DataFrame

  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    internalTransform(dataset.toDF())
  }

  override def copy(extra: ParamMap): M = defaultCopy(extra)
}
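
To illustrate the hook pattern, here is a minimal hypothetical subclass (all names invented for the example) that implements internalTransform by appending a constant prediction column:

// Hypothetical example (not part of this commit): a trivial DLTransformerBase
// subclass that appends a constant "prediction" column via internalTransform.
import org.apache.spark.ml.DLTransformerBase
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types.{DoubleType, StructType}

class ConstantPredictor(override val uid: String)
  extends DLTransformerBase[ConstantPredictor] {

  def this() = this(Identifiable.randomUID("constantPredictor"))

  override protected def internalTransform(dataFrame: DataFrame): DataFrame =
    dataFrame.withColumn("prediction", lit(0.0))

  override def transformSchema(schema: StructType): StructType =
    schema.add("prediction", DoubleType)
}

The base class supplies transform and copy, so a concrete model only needs uid, transformSchema, and the internalTransform hook.
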
135 changes: 135 additions & 0 deletions
...ark-version/3.0/src/main/scala/org/apache/spark/rdd/ZippedPartitionsWithLocalityRDD.scala
@@ -0,0 +1,135 @@
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.rdd

import java.io.{IOException, ObjectOutputStream}

import org.apache.log4j.Logger
import org.apache.spark.util.Utils
import org.apache.spark.{Partition, SparkContext}

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

object ZippedPartitionsWithLocalityRDD {
  def apply[T: ClassTag, B: ClassTag, V: ClassTag]
  (rdd1: RDD[T], rdd2: RDD[B], preservesPartitioning: Boolean = false)
  (f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] = rdd1.withScope {
    val sc = rdd1.sparkContext
    new ZippedPartitionsWithLocalityRDD(
      sc, sc.clean(f), rdd1, rdd2, preservesPartitioning)
  }

  val logger: Logger = Logger.getLogger(getClass)
}

/**
 * Prefer to zip partitions of rdd1 and rdd2 that are in the same location.
 * Remaining partitions not in the same location are zipped in order.
 * For example:
 * Say we have two RDDs, rdd1 and rdd2. The first partition of rdd1 is on node A, and the second
 * is on node B. The first partition of rdd2 is on node B and the second one is on node A.
 * If we just use rdd1.zipPartitions(rdd2), the first partition of rdd1 will be zipped with the
 * first partition of rdd2, so there will be cross-node communication, which is bad for
 * performance. That is why we introduce ZippedPartitionsWithLocalityRDD:
 * here the first partition of rdd1 is zipped with the second partition of rdd2,
 * as they are on the same node. This reduces the network communication cost and results in
 * better performance.
 * @param sc spark context
 * @param _f the zip function applied to each pair of partition iterators
 * @param _rdd1 the first RDD to zip
 * @param _rdd2 the second RDD to zip
 * @param preservesPartitioning whether the zipped RDD preserves the parent partitioner
 */
class ZippedPartitionsWithLocalityRDD[A: ClassTag, B: ClassTag, V: ClassTag](
    sc: SparkContext,
    _f: (Iterator[A], Iterator[B]) => Iterator[V],
    _rdd1: RDD[A],
    _rdd2: RDD[B],
    preservesPartitioning: Boolean = false)
  extends ZippedPartitionsRDD2[A, B, V](sc, _f, _rdd1, _rdd2, preservesPartitioning) {

  override def getPartitions: Array[Partition] = {
    require(rdds.length == 2, "this is only for 2 rdd zip")
    val numParts = rdds.head.partitions.length
    if (!rdds.forall(rdd => rdd.partitions.length == numParts)) {
      throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
    }

    // Collect the preferred locations of every partition of rdd2 as candidates.
    val candidateLocs = new ArrayBuffer[(Int, Seq[String])]()
    (0 until numParts).foreach(p => {
      candidateLocs.append((p, rdds(1)
        .context.getPreferredLocs(rdds(1), p)
        .map(_.toString).distinct))
    })
    val nonmatchPartitionId = new ArrayBuffer[Int]()
    val parts = new Array[Partition](numParts)

    (0 until numParts).foreach { i =>
      val curPrefs = rdds(0).context.getPreferredLocs(rdds(0), i).map(_.toString).distinct
      var p = 0
      var matchPartition: (Int, Seq[String]) = null
      var locs: Seq[String] = null
      while (p < candidateLocs.length) {
        locs = candidateLocs(p)._2.intersect(curPrefs)
        if (locs.nonEmpty) {
          matchPartition = candidateLocs.remove(p)
          p = Integer.MAX_VALUE - 1 // found a co-located candidate; break out of the loop
        }
        p += 1
      }
      if (matchPartition != null) {
        parts(i) =
          new ZippedPartitionsLocalityPartition(i, Array(i, matchPartition._1), rdds, locs)
      } else {
        ZippedPartitionsWithLocalityRDD.logger.warn(s"can't find locality partition " +
          s"for partition $i. Partition locations are (${curPrefs}). Candidate partition" +
          s" locations are\n${candidateLocs.mkString("\n")}.")
        nonmatchPartitionId.append(i)
      }
    }

    // Zip the leftover, non-co-located partitions in order.
    require(nonmatchPartitionId.size == candidateLocs.size,
      "unmatched partition size should be the same as the candidateLocs size")
    nonmatchPartitionId.foreach { i =>
      val locs = rdds(0).context.getPreferredLocs(rdds(0), i).map(_.toString).distinct
      val matchPartition = candidateLocs.remove(0)
      parts(i) = new ZippedPartitionsLocalityPartition(i, Array(i, matchPartition._1), rdds, locs)
    }
    parts
  }
}

private[spark] class ZippedPartitionsLocalityPartition(
    idx: Int,
    @transient val indexes: Seq[Int],
    @transient val rdds: Seq[RDD[_]],
    @transient override val preferredLocations: Seq[String])
  extends ZippedPartitionsPartition(idx, rdds, preferredLocations) {

  override val index: Int = idx
  var _partitionValues = rdds.zip(indexes).map { case (rdd, i) => rdd.partitions(i) }
  override def partitions: Seq[Partition] = _partitionValues

  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
    // Update the reference to the parent split at the time of task serialization
    _partitionValues = rdds.zip(indexes).map { case (rdd, i) => rdd.partitions(i) }
    oos.defaultWriteObject()
  }
}
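
As a usage sketch (assuming an existing SparkContext named sc; the RDD contents are invented for the example), the companion apply pairs co-located partitions of two RDDs:

// Hypothetical usage sketch (not part of this commit): pair co-located
// partitions of a data RDD with a per-partition weights RDD.
import org.apache.spark.rdd.ZippedPartitionsWithLocalityRDD

val data = sc.parallelize(1 to 1000, numSlices = 4).cache()
val weights = sc.parallelize(Seq.fill(4)(0.5), numSlices = 4).cache()
data.count(); weights.count() // materialize so preferred locations are known

val scaled = ZippedPartitionsWithLocalityRDD(data, weights) {
  (dataIter, weightIter) =>
    val w = weightIter.next() // one weight per partition in this toy setup
    dataIter.map(_ * w)
}
scaled.count()
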
27 changes: 27 additions & 0 deletions
scala/common/spark-version/3.0/src/main/scala/org/apache/spark/sql/SqlAdapter.scala
@@ -0,0 +1,27 @@
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql

import org.apache.spark.sql.expressions.SparkUserDefinedFunction
import org.apache.spark.sql.types.DataType

object SqlAdapter {

  def getUDF(f: AnyRef, dataType: DataType): SparkUserDefinedFunction = {
    SparkUserDefinedFunction(f, dataType)
  }

}
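
A hypothetical usage sketch (assuming a SparkSession named spark) of the adapter, which wraps the version-specific UDF construction behind one call:

// Hypothetical usage sketch (not part of this commit): build a UDF through the
// version-specific adapter and apply it to a column.
import org.apache.spark.sql.SqlAdapter
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType

val square = SqlAdapter.getUDF((x: Double) => x * x, DoubleType)

val df = spark.range(5).select(col("id").cast(DoubleType).as("x"))
df.select(col("x"), square(col("x")).as("x_squared")).show()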