Skip to content

Commit

Permalink
[WIP] spark 3.0 (#3054)
Browse files Browse the repository at this point in the history
* spark 3.0
  • Loading branch information
Le-Zheng committed Sep 27, 2020
1 parent 6038919 commit 24e9173
Show file tree
Hide file tree
Showing 21 changed files with 692 additions and 61 deletions.
106 changes: 96 additions & 10 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,12 @@
<findbugs.version>3.0.0</findbugs.version>

<!-- define the Java language version used by the compiler -->
<java.version>1.7</java.version>
<javac.version>1.7</javac.version>
<scala.major.version>2.10</scala.major.version>
<scala.version>2.10.7</scala.version>
<java.version>1.8</java.version>
<javac.version>1.8</javac.version>
<scala.major.version>2.11</scala.major.version>
<scala.version>2.11.8</scala.version>
<scala.macros.version>2.1.0</scala.macros.version>
<scalatest.version>2.2.4</scalatest.version>
<scalatest.version>3.0.7</scalatest.version>

<!-- platform encoding override -->
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
Expand Down Expand Up @@ -142,7 +142,7 @@
<maven-failsafe-plugin.version>${maven-surefire-plugin.version}</maven-failsafe-plugin.version>

<maven-clean-plugin.version>2.5</maven-clean-plugin.version>
<maven-compiler-plugin.version>3.1</maven-compiler-plugin.version>
<maven-compiler-plugin.version>3.2.1</maven-compiler-plugin.version>
<maven-install-plugin.version>2.5.1</maven-install-plugin.version>
<maven-resources-plugin.version>2.6</maven-resources-plugin.version>
<maven-jar-plugin.version>2.5</maven-jar-plugin.version>
Expand Down Expand Up @@ -172,9 +172,9 @@
<commons-math.version>2.2</commons-math.version>
<collections.version>3.2.1</collections.version>
<scoverage.plugin.version>1.1.1</scoverage.plugin.version>
<spark-version.project>1.5-plus</spark-version.project>
<spark.version>1.5.1</spark.version>
<breeze.version>0.11.2</breeze.version>
<spark-version.project>2.0</spark-version.project>
<spark.version>2.4.0</spark.version>
<breeze.version>0.13.2</breeze.version>
<spark-scope>provided</spark-scope>

<bigdl-core-all-scope>compile</bigdl-core-all-scope>
Expand Down Expand Up @@ -544,7 +544,7 @@
<id>spark_2.x</id>
<properties>
<spark-version.project>2.0</spark-version.project>
<spark.version>2.0.0</spark.version>
<spark.version>2.4.3</spark.version>
<scala.major.version>2.11</scala.major.version>
<scala.version>2.11.8</scala.version>
<scala.macros.version>2.1.0</scala.macros.version>
Expand Down Expand Up @@ -618,6 +618,92 @@
</plugins>
</build>
</profile>
<profile>
<id>spark_3.x</id>
<properties>
<spark-version.project>3.0</spark-version.project>
<spark.version>3.0.0</spark.version>
<scala.major.version>2.12</scala.major.version>
<scala.version>2.12.8</scala.version>
<scala.macros.version>2.1.0</scala.macros.version>
</properties>
<build>
<plugins>
          <!-- Redefine the plugin to enable fatal warnings -->
<plugin>
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
<version>3.2.0</version>
<executions>
<execution>
<id>eclipse-add-source</id>
<goals>
<goal>add-source</goal>
</goals>
</execution>
<execution>
<id>scala-compile-first</id>
<phase>process-resources</phase>
<goals>
<goal>compile</goal>
</goals>
</execution>
<execution>
<id>scala-test-compile-first</id>
<phase>process-test-resources</phase>
<goals>
<goal>testCompile</goal>
</goals>
</execution>
<execution>
<id>attach-scaladocs</id>
<phase>verify</phase>
<goals>
<goal>doc-jar</goal>
</goals>
<configuration>
<args>
<!-- Do not change the arg orders. It is a weird way to pass in
this arg. Maybe it is a bug of the plugin. -->
<arg>-skip-packages</arg>
<arg>caffe:org.tensorflow:netty:org.apache.spark.sparkExtension:org.apache.spark.rdd:org.apache.spark.storage:org.apache.spark.bigdl</arg>
</args>
</configuration>
</execution>
</executions>
<configuration>
<scalaVersion>${scala.version}</scalaVersion>
<recompileMode>incremental</recompileMode>
<useZincServer>false</useZincServer>
<args>
<arg>-unchecked</arg>
<!-- Too many deprecation usage, let's suspend it temporary-->
<arg>-deprecation:false</arg>
<arg>-feature</arg>
<arg>-Xfatal-warnings</arg>
</args>
<!-- The following plugin is required to use quasiquotes in Scala 2.10 and is used
by Spark SQL for code generation. -->
<compilerPlugins>
<compilerPlugin>
<groupId>org.scalamacros</groupId>
<artifactId>paradise_${scala.version}</artifactId>
<version>${scala.macros.version}</version>
</compilerPlugin>
</compilerPlugins>
</configuration>
</plugin>
</plugins>
</build>
</profile>
<profile>
<id>scala_2.12</id>
<properties>
<scala.major.version>2.12</scala.major.version>
<scala.version>2.12.12</scala.version>
<scala.macros.version>2.1.0</scala.macros.version>
</properties>
</profile>
<profile>
<id>scala_2.11</id>
<properties>
Expand Down
5 changes: 4 additions & 1 deletion pyspark/bigdl/models/utils/model_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import gc
from tempfile import NamedTemporaryFile

from pyspark.cloudpickle import print_exec
from pyspark.broadcast import Broadcast
from pyspark.broadcast import _from_id
from bigdl.nn.layer import Model
Expand Down Expand Up @@ -49,6 +48,10 @@ def dump(self, value, f):
value.saveModel(f.name, over_write=True)
except Exception as e:
msg = "Could not serialize broadcast: %s" % e.__class__.__name__
if not self.sc.version.startswith("2.1"):
from pyspark.cloudpickle import print_exec
else:
from pyspark.util import print_exec
print_exec(sys.stderr)
raise ValueError(msg)
f.close()
Expand Down
2 changes: 1 addition & 1 deletion pyspark/test/bigdl/dlframes/test_dl_image_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_transform_image(self):
# test, and withhold the support for Spark 1.5, until the unit test failure reason
# is clarified.

if not self.sc.version.startswith("1.5"):
if not self.sc.version.startswith("1.5" and "3.0"):
image_frame = DLImageReader.readImages(self.image_path, self.sc)
transformer = DLImageTransformer(
Pipeline([Resize(256, 256), CenterCrop(224, 224),
Expand Down
4 changes: 2 additions & 2 deletions spark/dl/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@
<dependency>
<groupId>com.github.scopt</groupId>
<artifactId>scopt_${scala.major.version}</artifactId>
<version>3.2.0</version>
<version>3.5.0</version>
</dependency>
<dependency>
<groupId>it.unimi.dsi</groupId>
Expand Down Expand Up @@ -190,7 +190,7 @@
or shade plugin will be executed after assembly plugin. -->
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
<version>3.0.0</version>
<version>3.2.1</version>
<configuration>
<filters>
<filter>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class Pooler[T: ClassTag] (
val num_rois = rois.size(1)
totalNum += num_rois

if (out.getOrElse(i + 1, null) == null) out(i + 1) = Tensor[T]()
if (!out.contains(i + 1)) out(i + 1) = Tensor[T]()
val outROI = out[Tensor[T]](i + 1)
outROI.resize(num_rois, num_channels, resolution, resolution)
.fill(ev.fromType[Float](Float.MinValue))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class RegionProposal(
val postNmsTopN = if (this.isTraining()) min(postNmsTopNTrain, bboxNumber)
else min(postNmsTopNTest, bboxNumber)

if (output.getOrElse(b, null) == null) {
if (!output.contains(b)) {
output(b) = Tensor[Float]()
}
output[Tensor[Float]](b).resize(postNmsTopN, 4)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1064,6 +1064,8 @@ abstract class AbstractModule[A <: Activity: ClassTag, B <: Activity: ClassTag,
require(copiedModuleParamTable.get(name) != None, s"cloned module should have for $name")
setLayerWeightAndBias(params,
copiedModuleParamTable.get(name).get.asInstanceOf[Table], deepCopy)
case _ =>
throw new UnsupportedOperationException("unsupported $name and $params")
}
}
}
Expand Down Expand Up @@ -1125,6 +1127,8 @@ abstract class AbstractModule[A <: Activity: ClassTag, B <: Activity: ClassTag,
} else {
if (matchAll) new Exception(s"module $name cannot find corresponding weight bias")
}
case _ =>
throw new UnsupportedOperationException("unsupported $name and $targetParams")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import org.apache.commons.lang.exception.ExceptionUtils
import org.apache.log4j.Logger
import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.util.AccumulatorV2
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.Future
Expand Down Expand Up @@ -195,8 +196,8 @@ object DistriOptimizer extends AbstractOptimizer {
var dataRDD = dataset.data(train = true)

while (!endWhen(driverState)) {
val lossSum = sc.accumulator(0.0, "loss sum")
val recordsNum = sc.accumulator(0, "record number")
val lossSum = sc.doubleAccumulator("loss sum")
val recordsNum = sc.doubleAccumulator("record number")
metrics.set("computing time for each node", mutable.ArrayBuffer[Double](), sc)
metrics.set("get weights for each node", mutable.ArrayBuffer[Double](), sc)
metrics.set("computing time average", 0.0, sc, partitionNum)
Expand Down Expand Up @@ -293,10 +294,10 @@ object DistriOptimizer extends AbstractOptimizer {
driverMetrics.add("computing time for each node", computingTime)

val finishedThreads = trainingThreads.filter(!_.isCancelled).map(_.get())
recordsNum += finishedThreads.size * stackSize
recordsNum.add(finishedThreads.size * stackSize)
var i = 0
while (i < finishedThreads.size) {
lossSum += lossArray(finishedThreads(i))
lossSum.add(lossArray(finishedThreads(i)))
i += 1
}

Expand Down Expand Up @@ -409,7 +410,7 @@ object DistriOptimizer extends AbstractOptimizer {
}.count()

stateBroadcast.destroy()
recordsProcessedThisEpoch += recordsNum.value
recordsProcessedThisEpoch += (recordsNum.value).toInt
val end = System.nanoTime()
wallClockTime += end - start
driverState("isGradientUpdated") = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ object DistriOptimizerV2 extends AbstractOptimizer {
cacheOfMaster: MasterCache[T],
context: TrainingContext[T], trainingTrace: TrainingTrace
)(implicit ev: TensorNumeric[T]): Unit = {
val lossSum = sc.accumulator(0.0, "loss sum")
val recordsNum = sc.accumulator(0, "record number")
val lossSum = sc.doubleAccumulator("loss sum")
val recordsNum = sc.doubleAccumulator("record number")
val metrics = cacheOfMaster.metrics
val partitionNum = cacheOfMaster.partitionNum
initMetrics(sc, metrics, partitionNum)
Expand Down Expand Up @@ -202,16 +202,16 @@ object DistriOptimizerV2 extends AbstractOptimizer {

val results = train(cached, miniBatchBuffer, context, metrics)

lossSum += results.loss
recordsNum += results.records
lossSum.add(results.loss)
recordsNum.add(results.records)

Iterator.single(results.successed)
}.reduce(_ + _)

parameterSync(lossSum.value, successModels, cacheOfMaster, models, context)
})

driverStatesUpdate(cacheOfMaster, recordsNum.value,
driverStatesUpdate(cacheOfMaster, (recordsNum.value).toInt,
context, trainingTrace, metrics)
}

Expand Down Expand Up @@ -240,10 +240,6 @@ object DistriOptimizerV2 extends AbstractOptimizer {
parameterProcessers: Array[ParameterProcessor]
)

case class Replica(model: Module[T], weights: Tensor[T], gradients: Tensor[T],
criterion: Criterion[T], state: Table,
validationMethods: Option[Array[ValidationMethod[T]]])

val config = TrainingConfig(
cacheOfMaster.criterion,
cacheOfMaster.validationMethods,
Expand Down Expand Up @@ -1056,6 +1052,10 @@ private case object AGGREGATE_PARTITION_GRADIENT extends MetricEntry("aggregrate
// scalastyle:on
private case object SEND_WEIGHTS_AVERAGE extends MetricEntry("send weights average")

private case class Replica[T](model: Module[T], weights: Tensor[T], gradients: Tensor[T],
criterion: Criterion[T], state: Table,
validationMethods: Option[Array[ValidationMethod[T]]])

private class TrainingTrace(
private var _records: Int = 0,
private var _iterations: Int = 0,
Expand Down
Loading

0 comments on commit 24e9173

Please sign in to comment.