fix: fuse bn scale and relu to bn. (intel-analytics#2966)
* fix: fuse bn scale and relu.
i8run committed Nov 27, 2019
1 parent b4d4975 commit 2b1b5c6
Showing 3 changed files with 62 additions and 26 deletions.
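
In short: a per-channel Scale layer is just an affine map, so it can be folded into the preceding BatchNorm ahead of time. With BN computing γ·x̂ + β and Scale computing w·y + b (all products elementwise, per channel), the composition is (w·γ)·x̂ + (w·β + b). The new fuseScale pass performs exactly this rewrite on the BN weight and bias, replaces the Scale wrapper with Identity, and runs before the other fusion passes so that the ReLU behind it can then fuse into the BN as well.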
DnnGraph.scala
@@ -394,14 +394,15 @@ class DnnGraph(
    */
   private def fusion(): Unit = {
     if (!this.train) {
-      for (j <- 0 to 3) {
+      for (j <- 0 to 4) {
         var i = forwardExecution.length - 1
         while (i >= 0) {
-          if (j == 0) Fusion.fuseModule(forwardExecution(i))
+          if (j == 0) Fusion.fuseScale(forwardExecution(i))
+          if (j == 1) Fusion.fuseModule(forwardExecution(i))
           // we should do this before sum fusion, because it will change the structure of graph
-          if (j == 1) Fusion.setNegativeInputOfConv(forwardExecution(i))
-          if (j == 2) Fusion.fuseCAdd(forwardExecution(i))
-          if (j == 3) Fusion.setScalesPrevousJoinTable(forwardExecution(i))
+          if (j == 2) Fusion.setNegativeInputOfConv(forwardExecution(i))
+          if (j == 3) Fusion.fuseCAdd(forwardExecution(i))
+          if (j == 4) Fusion.setScalesPrevousJoinTable(forwardExecution(i))
           i -= 1
         }
       }
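
The pass ordering is the point of this hunk: fuseScale now runs as its own pass before fuseModule, so the Scale wrapper is already reduced to Identity by the time the ReLU/BN fusions inspect each node's predecessors. A sketch of the same loop as an explicit pass list (illustrative only; the passes value and the loop shape below are not in the patch, only the Fusion methods are):

```scala
// Hypothetical restatement of the loop above; pass names come from the patch.
val passes: Seq[Node[AbstractModule[Activity, Activity, Float]] => Unit] = Seq(
  Fusion.fuseScale _,                // 0: fold BlasWrapper(Scale) into the preceding BN
  Fusion.fuseModule _,               // 1: dispatches fusionRelu / fusionBN per node
  Fusion.setNegativeInputOfConv _,   // 2: must precede sum fusion, which restructures the graph
  Fusion.fuseCAdd _,                 // 3: sum (CAdd) fusion
  Fusion.setScalesPrevousJoinTable _ // 4: propagate scales around JoinTable
)
// Same traversal order as the while loop: last executed node first.
for (pass <- passes; node <- forwardExecution.reverse) pass(node)
```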
Fusion.scala
@@ -38,8 +38,6 @@ private[mkldnn] object Fusion {
     node.element match {
       case relu: ReLU => fusionRelu(node)
       case bn: SpatialBatchNormalization => fusionBN(node)
-      case blasWrapper: BlasWrapper if blasWrapper.module.isInstanceOf[ScaleLayer[Float]] =>
-        fuseScale(node)
       case _ =>
     }
   }
@@ -288,31 +286,35 @@
   }

   def fuseScale(node: Node[AbstractModule[Activity, Activity, Float]]): Unit = {
-    // check all prevNodes are SpatialBatchNormalization
-    val isValid = node.prevNodes.forall(_.element.isInstanceOf[SpatialBatchNormalization])
-    if (!isValid) { return }
+    node.element match {
+      case wrapper: BlasWrapper if wrapper.module.isInstanceOf[ScaleLayer[Float]] =>
+        // check all prevNodes are SpatialBatchNormalization
+        val isValid = node.prevNodes.forall(_.element.isInstanceOf[SpatialBatchNormalization])
+        if (!isValid) { return }

-    node.prevNodes.foreach { prevNode =>
-      val bn = prevNode.element.asInstanceOf[SpatialBatchNormalization]
-      val weightAndBias = bn.weightAndBias.dense
-      val weight = weightAndBias.narrow(1, 1, bn.nOutput)
-      val bias = weightAndBias.narrow(1, bn.nOutput + 1, bn.nOutput)
+        node.prevNodes.foreach { prevNode =>
+          val bn = prevNode.element.asInstanceOf[SpatialBatchNormalization]
+          val weightAndBias = bn.weightAndBias.dense
+          val weight = weightAndBias.narrow(1, 1, bn.nOutput)
+          val bias = weightAndBias.narrow(1, bn.nOutput + 1, bn.nOutput)

-      val scale = node.element.asInstanceOf[BlasWrapper].module.asInstanceOf[ScaleLayer[Float]]
-      val scaleWeight = scale.parameters()._1(0)
-      val scaleBias = scale.parameters()._1(1)
+          val scale = node.element.asInstanceOf[BlasWrapper].module.asInstanceOf[ScaleLayer[Float]]
+          val scaleWeight = scale.parameters()._1(0)
+          val scaleBias = scale.parameters()._1(1)

-      weight.cmul(scaleWeight)
-      bias.cmul(scaleWeight)
-      bias.add(scaleBias)
+          weight.cmul(scaleWeight)
+          bias.cmul(scaleWeight)
+          bias.add(scaleBias)


-      // set the weight and bias to new tensor, we do not modify the original model's tensor.
-      // sometimes, the model need to be reused.
-      bn.weightAndBias.dense.set(weightAndBias)
-    }
+          // set the weight and bias to new tensor, we do not modify the original model's tensor.
+          // sometimes, the model need to be reused.
+          bn.weightAndBias.dense.set(weightAndBias)
+        }

-    node.element = Identity[Float]() // set the BlasWrapper to Identity, we need no scale now
+        node.element = Identity[Float]() // set the BlasWrapper to Identity, we need no scale now
+      case _ =>
+    }
   }

   private def findAllNonIdentityPrevs(node: Node[AbstractModule[Activity, Activity, Float]])
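
The three tensor ops in the middle (weight.cmul, bias.cmul, bias.add) are the whole fusion: composing two per-channel affine maps. A minimal, self-contained numeric check of that algebra, with plain arrays standing in for BigDL tensors (every name below is illustrative, none of it is from the patch):

```scala
// Check: Scale(BN(xHat)) = w*(gamma*xHat + beta) + b = (w*gamma)*xHat + (w*beta + b),
// i.e. weight.cmul(scaleWeight); bias.cmul(scaleWeight); bias.add(scaleBias).
object ScaleFoldingCheck extends App {
  val gamma = Array(1.0f, 2.0f)   // BN weight, one entry per channel
  val beta  = Array(0.5f, -1.0f)  // BN bias
  val w     = Array(3.0f, 0.5f)   // Scale layer weight
  val b     = Array(1.0f, 4.0f)   // Scale layer bias
  val xHat  = Array(0.7f, -0.2f)  // a normalized activation, one value per channel

  val fusedGamma = gamma.indices.map(c => gamma(c) * w(c))       // weight.cmul(scaleWeight)
  val fusedBeta  = gamma.indices.map(c => beta(c) * w(c) + b(c)) // bias.cmul(...).add(...)

  val twoStep = gamma.indices.map(c => w(c) * (gamma(c) * xHat(c) + beta(c)) + b(c))
  val fused   = gamma.indices.map(c => fusedGamma(c) * xHat(c) + fusedBeta(c))

  assert(twoStep.zip(fused).forall { case (x, y) => math.abs(x - y) < 1e-5f },
    "folding changed the result")
  println(s"two-step: $twoStep, fused: $fused") // both ≈ Vector(4.6, 3.3)
}
```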
FusionSpec.scala
@@ -406,4 +406,37 @@ class FusionSpec extends FlatSpec with Matchers {

     System.clearProperty("bigdl.mkldnn.fusion")
   }
+
+  "bn + scale + relu fusion" should "work correctly" in {
+    import com.intel.analytics.bigdl.nn.{Scale => NNScale}
+    val inputShape = Array(4, 64, 3, 3)
+    val input = Input(inputShape, Memory.Format.nchw).inputs()
+    val bn1 = SpatialBatchNormalization(64).inputs(input)
+    val scale1 = BlasWrapper(NNScale[Float](Array(1, 64, 1, 1))).inputs(bn1)
+    val relu1 = ReLU()
+    val output = Output(Memory.Format.nchw).inputs(scale1)
+
+    // the running mean and running variance should be 1.
+    bn1.element.getExtraParameter().foreach(_.fill(1))
+
+    val model = DnnGraph(Seq(input), Seq(output))
+    val fused = model.cloneModule()
+
+    model.evaluate()
+    fused.evaluate()
+
+    val tensor = Tensor[Float](inputShape).rand(-1, 1)
+
+    System.setProperty("bigdl.mkldnn.fusion", "false")
+    model.compile(InferencePhase)
+    model.forward(tensor)
+
+    System.setProperty("bigdl.mkldnn.fusion", "true")
+    fused.compile(InferencePhase)
+    fused.forward(tensor)
+
+    Equivalent.nearequals(model.output.toTensor[Float], fused.output.toTensor[Float], 1e-7)
+
+    System.clearProperty("bigdl.mkldnn.fusion")
+  }
 }
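
For reference, a hedged end-to-end sketch of the user-facing flow, reusing only calls that appear in the test above and assuming the same imports as FusionSpec; unlike the test graph, the ReLU here is wired between Scale and Output so the whole bn -> scale -> relu chain is exercised:

```scala
import com.intel.analytics.bigdl.nn.{Scale => NNScale}

val inputShape = Array(4, 64, 3, 3)
val in    = Input(inputShape, Memory.Format.nchw).inputs()
val bn    = SpatialBatchNormalization(64).inputs(in)
val scale = BlasWrapper(NNScale[Float](Array(1, 64, 1, 1))).inputs(bn)
val relu  = ReLU().inputs(scale)
val out   = Output(Memory.Format.nchw).inputs(relu)

val graph = DnnGraph(Seq(in), Seq(out))
graph.evaluate()  // fusion only runs when the graph is not in training mode

System.setProperty("bigdl.mkldnn.fusion", "true")
graph.compile(InferencePhase)  // fuseScale folds Scale into BN; ReLU then fuses into BN too
val result = graph.forward(Tensor[Float](inputShape).rand(-1, 1))
System.clearProperty("bigdl.mkldnn.fusion")
```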
