fix: softmax and bn+scale fusion (intel-analytics#2937)
i8run committed Oct 24, 2019
1 parent 85efd88 commit e1a0f05
Showing 2 changed files with 35 additions and 4 deletions.
com/intel/analytics/bigdl/nn/mkldnn/Fusion.scala
@@ -17,7 +17,7 @@
 package com.intel.analytics.bigdl.nn.mkldnn
 
 import com.intel.analytics.bigdl.Module
-import com.intel.analytics.bigdl.nn.MklInt8Convertible
+import com.intel.analytics.bigdl.nn.{MklInt8Convertible, Scale => ScaleLayer}
 import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity}
 import com.intel.analytics.bigdl.tensor.Tensor
 import com.intel.analytics.bigdl.utils.Node
@@ -38,6 +38,8 @@ private[mkldnn] object Fusion {
     node.element match {
       case relu: ReLU => fusionRelu(node)
       case bn: SpatialBatchNormalization => fusionBN(node)
+      case blasWrapper: BlasWrapper if blasWrapper.module.isInstanceOf[ScaleLayer[Float]] =>
+        fuseScale(node)
       case _ =>
     }
   }
@@ -76,7 +78,8 @@
    */
   private def fusionRelu(node: Node[AbstractModule[Activity, Activity, Float]]): Unit = {
     node.prevNodes.foreach(n => {
-      n.element match {
+      val notIdentity = findPrevious(n)
+      notIdentity.element match {
         case conv: SpatialConvolution =>
           if (!conv.relu) {
             conv.setReLU(true)
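
Note: the findPrevious helper referenced above lies outside this hunk. A minimal sketch of the traversal it presumably performs, under the assumption that earlier fusions leave single-input Identity placeholders behind (see fuseScale below); this is an illustration, not the committed implementation:

// Hypothetical sketch (not part of this diff): walk past Identity nodes that
// earlier fusions such as fuseScale leave behind, so that ReLU fusion can
// still reach the SpatialConvolution it should merge into.
private def findPrevious(node: Node[AbstractModule[Activity, Activity, Float]])
  : Node[AbstractModule[Activity, Activity, Float]] = {
  if (node.element.isInstanceOf[Identity[_]] && node.prevNodes.length == 1) {
    findPrevious(node.prevNodes(0))
  } else {
    node
  }
}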
@@ -284,6 +287,34 @@
     }
   }
 
+  def fuseScale(node: Node[AbstractModule[Activity, Activity, Float]]): Unit = {
+    // check that all prevNodes are SpatialBatchNormalization
+    val isValid = node.prevNodes.forall(_.element.isInstanceOf[SpatialBatchNormalization])
+    if (!isValid) { return }
+
+    node.prevNodes.foreach { prevNode =>
+      val bn = prevNode.element.asInstanceOf[SpatialBatchNormalization]
+      val weightAndBias = bn.weightAndBias.dense.clone()
+      val weight = weightAndBias.narrow(1, 1, bn.nOutput)
+      val bias = weightAndBias.narrow(1, bn.nOutput + 1, bn.nOutput)
+
+      val scale = node.element.asInstanceOf[BlasWrapper].module.asInstanceOf[ScaleLayer[Float]]
+      val scaleWeight = scale.parameters()._1(0)
+      val scaleBias = scale.parameters()._1(1)
+
+      weight.cmul(scaleWeight)
+      bias.cmul(scaleWeight)
+      bias.add(scaleBias)
+
+
+      // set the fused weight and bias on a cloned tensor; we do not modify the
+      // original model's tensors, because the model sometimes needs to be reused.
+      bn.weightAndBias.dense.set(weightAndBias)
+    }
+
+    node.element = Identity[Float]() // replace the BlasWrapper with Identity; the scale is no longer needed
+  }
+
   private def findAllNonIdentityPrevs(node: Node[AbstractModule[Activity, Activity, Float]])
     : Seq[Node[AbstractModule[Activity, Activity, Float]]] = {
     // TODO currently, it will only skip the Identity, MaxPooling, AvgPooling, JoinTable
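For reference, the algebra behind fuseScale: per channel, batch normalization produces y = gamma * xHat + beta, and the following Scale layer produces z = s * y + b. Composing the two affine maps gives the fused parameters gamma' = gamma * s and beta' = beta * s + b, which is exactly the cmul/cmul/add sequence in the hunk above. A minimal self-contained numeric check (plain Scala with made-up values, not the BigDL API):

// Verifies that scale(bn(x)) == bn'(x) once the Scale parameters are folded
// into batch norm. All values are hypothetical.
object FuseScaleCheck extends App {
  val (gamma, beta) = (Array(1.0f, 2.0f), Array(0.5f, -0.5f))  // bn weight/bias
  val (s, b)        = (Array(3.0f, 0.5f), Array(1.0f, 2.0f))   // scale weight/bias

  // fold: gamma' = gamma * s, beta' = beta * s + b
  val fusedGamma = gamma.zip(s).map { case (g, sv) => g * sv }
  val fusedBeta  = beta.zip(s).zip(b).map { case ((bt, sv), bv) => bt * sv + bv }

  val xHat    = Array(0.2f, -1.0f)  // normalized input
  val unfused = xHat.indices.map(i => s(i) * (gamma(i) * xHat(i) + beta(i)) + b(i))
  val fused   = xHat.indices.map(i => fusedGamma(i) * xHat(i) + fusedBeta(i))

  assert(unfused.zip(fused).forall { case (u, f) => math.abs(u - f) < 1e-6f })
  println(s"unfused = $unfused")
  println(s"fused   = $fused")
}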
com/intel/analytics/bigdl/nn/mkldnn/SoftMax.scala
@@ -16,7 +16,7 @@
 
 package com.intel.analytics.bigdl.nn.mkldnn
 
-import com.intel.analytics.bigdl.mkl.{Memory, MklDnn, PropKind, Stream => DnnStream}
+import com.intel.analytics.bigdl.mkl.{DataType, Memory, MklDnn, PropKind, Stream => DnnStream}
 import com.intel.analytics.bigdl.nn
 import com.intel.analytics.bigdl.nn.abstractnn.Activity
 import com.intel.analytics.bigdl.nn.mkldnn.Phase.{InferencePhase, TrainingPhase}
@@ -68,7 +68,7 @@ class SoftMax(val axis: Int = -1) extends MklDnnLayer {
       case _ => throw new UnsupportedOperationException("1D, 2D, 3D or 4D tensor expected")
     }
 
-    _inputFormats = singleNativeData(inputs)
+    _inputFormats = Array(NativeData(inputs(0).shape, inputs(0).layout, DataType.F32))
 
     val localInputFormat = if (inputs(0).shape.length == 3 &&
       inputs(0).layout == Memory.Format.ntc) {
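The one-line SoftMax change matters for int8 graphs: singleNativeData(inputs) would adopt the upstream layer's format wholesale, so in a quantized model the softmax input could come through as int8. The new line keeps the incoming shape and layout but pins the data type to DataType.F32, forcing a reorder to float before softmax runs. A small sketch of the distinction, with assumed shape, layout, and data-type values:

import com.intel.analytics.bigdl.mkl.{DataType, Memory}
import com.intel.analytics.bigdl.nn.mkldnn.NativeData

object SoftMaxFormatSketch extends App {
  // Hypothetical upstream format, e.g. the int8 output of a quantized layer.
  val upstream = NativeData(Array(4, 1000), Memory.Format.nc, DataType.S8)

  // What the fix does: keep shape and layout, pin the data type to f32.
  val pinned = NativeData(upstream.shape, upstream.layout, DataType.F32)

  println(s"shape = ${pinned.shape.mkString("x")}, layout = ${pinned.layout}")
}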
