diff --git a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/mkldnn/DnnGraph.scala b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/mkldnn/DnnGraph.scala
index 61c9aa8a951..60cee93d3c6 100644
--- a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/mkldnn/DnnGraph.scala
+++ b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/mkldnn/DnnGraph.scala
@@ -394,14 +394,15 @@ class DnnGraph(
    */
   private def fusion(): Unit = {
     if (!this.train) {
-      for (j <- 0 to 3) {
+      for (j <- 0 to 4) {
         var i = forwardExecution.length - 1
         while (i >= 0) {
-          if (j == 0) Fusion.fuseModule(forwardExecution(i))
+          if (j == 0) Fusion.fuseScale(forwardExecution(i))
+          if (j == 1) Fusion.fuseModule(forwardExecution(i))
           // we should do this before sum fusion, because it will change the structure of graph
-          if (j == 1) Fusion.setNegativeInputOfConv(forwardExecution(i))
-          if (j == 2) Fusion.fuseCAdd(forwardExecution(i))
-          if (j == 3) Fusion.setScalesPrevousJoinTable(forwardExecution(i))
+          if (j == 2) Fusion.setNegativeInputOfConv(forwardExecution(i))
+          if (j == 3) Fusion.fuseCAdd(forwardExecution(i))
+          if (j == 4) Fusion.setScalesPrevousJoinTable(forwardExecution(i))
           i -= 1
         }
       }
diff --git a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/mkldnn/Fusion.scala b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/mkldnn/Fusion.scala
index 87cb0145b81..b99a7b19f95 100644
--- a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/mkldnn/Fusion.scala
+++ b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/mkldnn/Fusion.scala
@@ -38,8 +38,6 @@ private[mkldnn] object Fusion {
     node.element match {
       case relu: ReLU => fusionRelu(node)
       case bn: SpatialBatchNormalization => fusionBN(node)
-      case blasWrapper: BlasWrapper if blasWrapper.module.isInstanceOf[ScaleLayer[Float]] =>
-        fuseScale(node)
       case _ =>
     }
   }
@@ -288,31 +286,35 @@
   }
 
   def fuseScale(node: Node[AbstractModule[Activity, Activity, Float]]): Unit = {
-    // check all prevNodes are SpatialBatchNormalization
-    val isValid = node.prevNodes.forall(_.element.isInstanceOf[SpatialBatchNormalization])
-    if (!isValid) { return }
+    node.element match {
+      case wrapper: BlasWrapper if wrapper.module.isInstanceOf[ScaleLayer[Float]] =>
+        // check all prevNodes are SpatialBatchNormalization
+        val isValid = node.prevNodes.forall(_.element.isInstanceOf[SpatialBatchNormalization])
+        if (!isValid) { return }
 
-    node.prevNodes.foreach { prevNode =>
-      val bn = prevNode.element.asInstanceOf[SpatialBatchNormalization]
-      val weightAndBias = bn.weightAndBias.dense
-      val weight = weightAndBias.narrow(1, 1, bn.nOutput)
-      val bias = weightAndBias.narrow(1, bn.nOutput + 1, bn.nOutput)
+        node.prevNodes.foreach { prevNode =>
+          val bn = prevNode.element.asInstanceOf[SpatialBatchNormalization]
+          val weightAndBias = bn.weightAndBias.dense
+          val weight = weightAndBias.narrow(1, 1, bn.nOutput)
+          val bias = weightAndBias.narrow(1, bn.nOutput + 1, bn.nOutput)
 
-      val scale = node.element.asInstanceOf[BlasWrapper].module.asInstanceOf[ScaleLayer[Float]]
-      val scaleWeight = scale.parameters()._1(0)
-      val scaleBias = scale.parameters()._1(1)
+          val scale = node.element.asInstanceOf[BlasWrapper].module.asInstanceOf[ScaleLayer[Float]]
+          val scaleWeight = scale.parameters()._1(0)
+          val scaleBias = scale.parameters()._1(1)
 
-      weight.cmul(scaleWeight)
-      bias.cmul(scaleWeight)
-      bias.add(scaleBias)
+          weight.cmul(scaleWeight)
+          bias.cmul(scaleWeight)
+          bias.add(scaleBias)
 
-      // set the weight and bias to new tensor, we do not modify the original model's tensor.
-      // sometimes, the model need to be reused.
-      bn.weightAndBias.dense.set(weightAndBias)
-    }
+          // set the weight and bias to new tensor, we do not modify the original model's tensor.
+          // sometimes, the model need to be reused.
+          bn.weightAndBias.dense.set(weightAndBias)
+        }
 
-    node.element = Identity[Float]() // set the BlasWrapper to Identity, we need no scale now
+        node.element = Identity[Float]() // set the BlasWrapper to Identity, we need no scale now
+      case _ =>
+    }
   }
 
   private def findAllNonIdentityPrevs(node: Node[AbstractModule[Activity, Activity, Float]])
diff --git a/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/mkldnn/FusionSpec.scala b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/mkldnn/FusionSpec.scala
index 891c15141f9..29f69048c0a 100644
--- a/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/mkldnn/FusionSpec.scala
+++ b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/mkldnn/FusionSpec.scala
@@ -406,4 +406,37 @@ class FusionSpec extends FlatSpec with Matchers {
 
     System.clearProperty("bigdl.mkldnn.fusion")
   }
+
+  "bn + scale + relu fusion" should "work correctly" in {
+    import com.intel.analytics.bigdl.nn.{Scale => NNScale}
+    val inputShape = Array(4, 64, 3, 3)
+    val input = Input(inputShape, Memory.Format.nchw).inputs()
+    val bn1 = SpatialBatchNormalization(64).inputs(input)
+    val scale1 = BlasWrapper(NNScale[Float](Array(1, 64, 1, 1))).inputs(bn1)
+    val relu1 = ReLU()
+    val output = Output(Memory.Format.nchw).inputs(scale1)
+
+    // the running mean and running variance should be 1.
+    bn1.element.getExtraParameter().foreach(_.fill(1))
+
+    val model = DnnGraph(Seq(input), Seq(output))
+    val fused = model.cloneModule()
+
+    model.evaluate()
+    fused.evaluate()
+
+    val tensor = Tensor[Float](inputShape).rand(-1, 1)
+
+    System.setProperty("bigdl.mkldnn.fusion", "false")
+    model.compile(InferencePhase)
+    model.forward(tensor)
+
+    System.setProperty("bigdl.mkldnn.fusion", "true")
+    fused.compile(InferencePhase)
+    fused.forward(tensor)
+
+    Equivalent.nearequals(model.output.toTensor[Float], fused.output.toTensor[Float], 1e-7)
+
+    System.clearProperty("bigdl.mkldnn.fusion")
+  }
 }
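For reference, the transformation fuseScale performs can be stated independently of the mkl-dnn wrappers: folding a per-channel Scale layer (weight s, bias b) into the preceding batch normalization's affine parameters (gamma, beta) gives gamma' = gamma * s and beta' = beta * s + b, which is exactly the cmul/add sequence in the patch. Below is a minimal sketch of that arithmetic using BigDL's Tensor API; the helper name foldScaleIntoBn and the example values are illustrative only and not part of the patch, and the Tensor import path assumes the pre-2.0 package layout used by the test file.

  import com.intel.analytics.bigdl.tensor.Tensor

  // Illustrative helper (not in the patch): fold a per-channel scale (s, b)
  // into BN's affine parameters so that bn'(x) == bn(x) * s + b.
  def foldScaleIntoBn(
      gamma: Tensor[Float], beta: Tensor[Float],
      scaleWeight: Tensor[Float], scaleBias: Tensor[Float]): Unit = {
    gamma.cmul(scaleWeight)                 // gamma' = gamma * s
    beta.cmul(scaleWeight).add(scaleBias)   // beta'  = beta * s + b
  }

  // Example with 4 output channels: gamma becomes 0.5, beta becomes 0.1,
  // matching what fuseScale writes back into bn.weightAndBias.
  val gamma = Tensor[Float](4).fill(1f)
  val beta = Tensor[Float](4).fill(0f)
  foldScaleIntoBn(gamma, beta, Tensor[Float](4).fill(0.5f), Tensor[Float](4).fill(0.1f))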