feat: add softmax backward (intel-analytics#2967)

* feat: add softmax backward

i8run committed Nov 26, 2019
1 parent c79f818 commit b4d4975

Showing 4 changed files with 149 additions and 95 deletions.
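For reference, the backward pass this commit wires up computes the standard softmax gradient: with y = softmax(x) and dL/dy given, dL/dx_i = y_i * (dL/dy_i - sum_j(dL/dy_j * y_j)). Below is a minimal dense Scala sketch of that formula, for intuition only; the layer in this diff delegates the actual computation to the MKL-DNN softmax backward primitive, and the helper name here is illustrative.

```scala
// Reference softmax backward over a single 1-D slice along the softmax axis.
// `output` holds softmax(input); `gradOutput` holds dL/d(output).
def softmaxBackwardReference(output: Array[Float], gradOutput: Array[Float]): Array[Float] = {
  require(output.length == gradOutput.length, "output and gradOutput must have the same length")
  // dot = sum_j gradOutput(j) * output(j)
  var dot = 0.0f
  var j = 0
  while (j < output.length) {
    dot += gradOutput(j) * output(j)
    j += 1
  }
  // gradInput(i) = output(i) * (gradOutput(i) - dot)
  Array.tabulate(output.length)(i => output(i) * (gradOutput(i) - dot))
}
```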
@@ -126,6 +126,11 @@ object MklDnnMemory {
MklDnn.SoftMaxForwardDescInit(prop_kind, dataDesc, axis)).ptr
}

def SoftMaxBackwardDescInit(propKind: Int, diffDesc: Long, dstDesc: Long,
axis: Int)(implicit owner: MemoryOwner): Long = {
new MklMemoryDescInit(MklDnn.SoftMaxBackwardDescInit(diffDesc, dstDesc, axis)).ptr
}
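As a quick orientation, the sketch below shows how this new wrapper fits into the usual descriptor chain (the same chain appears in SoftMax.initBwdPrimitives further down): the operation descriptor is created first, then turned into a primitive descriptor on the runtime engine. The helper name createSoftMaxBwdPrimDesc and its parameters are illustrative, not part of this commit.

```scala
// Illustrative helper, not part of the commit: op descriptor -> primitive descriptor.
// diffDesc and dstDesc are the memory descriptions of gradOutput and the forward output;
// the implicit MemoryOwner releases the native handles, as in the wrapper above.
def createSoftMaxBwdPrimDesc(diffDesc: Long, dstDesc: Long, axis: Int, engine: Long)
    (implicit owner: MemoryOwner): Long = {
  val desc = MklDnnMemory.SoftMaxBackwardDescInit(PropKind.Backward, diffDesc, dstDesc, axis)
  MklDnnMemory.PrimitiveDescCreate(desc, engine, 0L)
}
```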

def ConvForwardDescInit(prop_kind: Int, alg_kind: Int, src_desc: Long, weights_desc: Long,
bias_desc: Long, dst_desc: Long, strides: Array[Int], padding_l: Array[Int],
padding_r: Array[Int], padding_kind: Int)(implicit owner: MemoryOwner): Long = {
@@ -16,7 +16,7 @@

package com.intel.analytics.bigdl.nn.mkldnn

import com.intel.analytics.bigdl.mkl.{DataType, Memory, MklDnn, PropKind, Stream => DnnStream}
import com.intel.analytics.bigdl.mkl.{DataType, Memory, MklDnn, PropKind, Query, Stream => DnnStream}
import com.intel.analytics.bigdl.nn
import com.intel.analytics.bigdl.nn.abstractnn.Activity
import com.intel.analytics.bigdl.nn.mkldnn.Phase.{InferencePhase, TrainingPhase}
@@ -27,12 +27,14 @@ import com.intel.analytics.bigdl.utils.Shape
import scala.collection.mutable.ArrayBuffer

class SoftMax(val axis: Int = -1) extends MklDnnLayer {
private val nnSoftMax = nn.SoftMax[Float]()

@transient private var updateOutputTensors: Array[Tensor[Float]] = _
@transient private var updateOutputMemoryPrimitives: Array[Long] = _
@transient private var updateGradInputTensors: Array[Tensor[Float]] = _
@transient private var updateGradInputMemoryPrimitives: Array[Long] = _
@transient private var modelPhase: Phase = null

private var defaultAxis = 0

private def initPhase(phase: Phase): Unit = {
if (phase != null) return modelPhase = phase
isTraining() match {
@@ -53,96 +55,118 @@ class SoftMax(val axis: Int = -1) extends MklDnnLayer

override private[mkldnn] def initFwdPrimitives(inputs: Array[MemoryData], phase: Phase) = {
initPhase(phase)
modelPhase match {
case TrainingPhase =>
_inputFormats = inputs.map(x => HeapData(x.shape, format(x.shape)))
_outputFormats = inputs.map(x => HeapData(x.shape, format(x.shape)))

(_inputFormats, _outputFormats)
case InferencePhase =>
val defaultAxis = inputs(0).shape.length match {
case 1 => 0
case 2 => 1
case 3 => 0
case 4 => 1
case _ => throw new UnsupportedOperationException("1D, 2D, 3D or 4D tensor expected")
}

_inputFormats = Array(NativeData(inputs(0).shape, inputs(0).layout, DataType.F32))

val localInputFormat = if (inputs(0).shape.length == 3 &&
inputs(0).layout == Memory.Format.ntc) {
// note: here, the declared format and the true memory layout are not consistent.
// for ntc input, we reshape the `shape` and set the format to tnc
val shape = Array(inputs(0).shape(1), inputs(0).shape(0), inputs(0).shape(2))
NativeData(shape, Memory.Format.tnc)
} else {
_inputFormats(0)
}

val desc = MklDnnMemory.SoftMaxForwardDescInit(PropKind.ForwardInference,
localInputFormat.getMemoryDescription(), if (axis == -1) defaultAxis else axis)
val forwardPrimDesc = MklDnnMemory.PrimitiveDescCreate(desc, runtime.engine, 0L)

_outputFormats = if (inputs(0).shape.length == 3 &&
inputs(0).layout == Memory.Format.ntc) {
// because the input format was set to tnc first, we set the output format back to ntc.
Array(NativeData(inputs(0).shape, Memory.Format.ntc))
} else {
Array(MemoryData.primitiveOutput(forwardPrimDesc))
}

val srcs = Array(inputs(0).getPrimitive(runtime))
val indexes = Array(0)
val dsts = Array(_outputFormats(0).getPrimitive(runtime))

val primitive = MklDnnMemory.PrimitiveCreate2(forwardPrimDesc, srcs, indexes,
srcs.length, dsts, dsts.length)

updateOutputPrimitives = Array(primitive)
updateOutputMemoryPrimitives = srcs ++ dsts

output = initTensor(_outputFormats(0))

(_inputFormats, _outputFormats)
case _ => throw new UnsupportedOperationException
defaultAxis = inputs(0).shape.length match {
case 1 => 0
case 2 => 1
case 3 => 0
case 4 => 1
case _ => throw new UnsupportedOperationException("1D, 2D, 3D or 4D tensor expected")
}

_inputFormats = Array(NativeData(inputs(0).shape, inputs(0).layout, DataType.F32))

val localInputFormat = if (inputs(0).shape.length == 3 &&
inputs(0).layout == Memory.Format.ntc) {
// note: here, the declared format and the true memory layout are not consistent.
// for ntc input, we reshape the `shape` and set the format to tnc
val shape = Array(inputs(0).shape(1), inputs(0).shape(0), inputs(0).shape(2))
NativeData(shape, Memory.Format.tnc)
} else {
_inputFormats(0)
}

val desc = MklDnnMemory.SoftMaxForwardDescInit(PropKind.Forward,
localInputFormat.getMemoryDescription(), if (axis == -1) defaultAxis else axis)
val forwardPrimDesc = MklDnnMemory.PrimitiveDescCreate(desc, runtime.engine, 0L)

_outputFormats = if (inputs(0).shape.length == 3 &&
inputs(0).layout == Memory.Format.ntc) {
// because the input format was set to tnc first, we set the output format back to ntc.
Array(NativeData(inputs(0).shape, Memory.Format.ntc))
} else {
Array(MemoryData.primitiveOutput(forwardPrimDesc))
}

val srcs = Array(inputs(0).getPrimitive(runtime))
val indexes = Array(0)
val dsts = Array(_outputFormats(0).getPrimitive(runtime))

val primitive = MklDnnMemory.PrimitiveCreate2(forwardPrimDesc, srcs, indexes,
srcs.length, dsts, dsts.length)

updateOutputPrimitives = Array(primitive)
updateOutputMemoryPrimitives = srcs ++ dsts

output = initTensor(_outputFormats(0))

(_inputFormats, _outputFormats)
}
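A small illustration of the ntc handling above (values are made up): for a 3-D input declared as ntc, the memory description handed to MKL-DNN swaps the first two dimensions and uses the tnc format tag, while the reported output format keeps the original ntc shape.

```scala
// Illustrative values only; mirrors the shape shuffle in initFwdPrimitives above.
val ntcShape = Array(4, 10, 8)                               // (batch n, time t, channel c)
val tncShape = Array(ntcShape(1), ntcShape(0), ntcShape(2))  // (time t, batch n, channel c)
```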

override private[mkldnn] def initBwdPrimitives(grad: Array[MemoryData], phase: Phase) = {
_gradInputFormats = grad.clone()
_gradOutputFormats = grad.clone()
val desc = MklDnnMemory.SoftMaxBackwardDescInit(PropKind.Backward,
grad(0).getMemoryDescription(), outputFormats()(0).getMemoryDescription(),
if (axis == -1) defaultAxis else axis)
val primDesc = MklDnnMemory.PrimitiveDescCreate(desc, runtime.engine, 0L)

_gradOutputFormats = grad
_gradInputFormats = Array(MemoryData.operationWant(primDesc, Query.DiffSrcPd))

val srcs = Array(grad(0).getPrimitive(runtime), outputFormats()(0).getPrimitive(runtime))
val indexes = Array(0)
val dsts = Array(_gradInputFormats(0).getPrimitive(runtime))

val primitive = MklDnnMemory.PrimitiveCreate2(primDesc, srcs, indexes,
srcs.length, dsts, dsts.length)

updateGradInputPrimitives = Array(primitive)
updateGradInputMemoryPrimitives = srcs ++ dsts

gradInput = initTensor(_gradInputFormats(0))

(_gradInputFormats, _gradOutputFormats)
}
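One detail worth noting about the backward primitive set up above: it reads the gradient w.r.t. the output and the forward output itself (not the original input), and MklDnnOps.streamSubmit pairs memory primitives with tensors by position. A tiny sketch of that ordering convention, matching srcs ++ dsts above and the buffer built in updateGradInput below; the helper name is illustrative.

```scala
// Illustrative helper: the tensor buffer must follow the primitive order
// [gradOutput, forward output, gradInput], i.e. srcs ++ dsts above.
def bindBackwardBuffers(gradOutput: Tensor[Float], output: Tensor[Float],
    gradInput: Tensor[Float]): Array[Tensor[Float]] = {
  Array(gradOutput, output, gradInput)
}
```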

override def updateOutput(input: Activity): Activity = {
if (this.isTraining()) {
nnSoftMax.forward(input)
output = nnSoftMax.output
} else {
if (updateOutputTensors == null) {
val buffer = new ArrayBuffer[Tensor[Float]]()
buffer.append(input.asInstanceOf[Tensor[Float]])
buffer.append(output.asInstanceOf[Tensor[Float]])
updateOutputTensors = buffer.toArray
}

input.toTensor[Float].getTensorType match {
case DenseType => updateOutputTensors(0) = input.toTensor
case _ =>
}

MklDnnOps.streamSubmit(runtime.stream, 1,
updateOutputPrimitives,
updateOutputPrimitives.length,
updateOutputMemoryPrimitives, updateOutputTensors)
if (updateOutputTensors == null) {
val buffer = new ArrayBuffer[Tensor[Float]]()
buffer.append(input.asInstanceOf[Tensor[Float]])
buffer.append(output.asInstanceOf[Tensor[Float]])
updateOutputTensors = buffer.toArray
}

input.toTensor[Float].getTensorType match {
case DenseType => updateOutputTensors(0) = input.toTensor
case _ =>
}

MklDnnOps.streamSubmit(runtime.stream, 1,
updateOutputPrimitives,
updateOutputPrimitives.length,
updateOutputMemoryPrimitives, updateOutputTensors)
output
}

override def updateGradInput(input: Activity, gradOutput: Activity): Activity = {
gradInput = nnSoftMax.backward(input, gradOutput)
if (updateGradInputTensors == null) {
val buffer = new ArrayBuffer[Tensor[Float]]()
buffer.append(gradOutput.asInstanceOf[Tensor[Float]])
buffer.append(output.asInstanceOf[Tensor[Float]])
buffer.append(gradInput.asInstanceOf[Tensor[Float]])

updateGradInputTensors = buffer.toArray
}

gradOutput.toTensor[Float].getTensorType match {
case DenseType => updateGradInputTensors(0) = gradOutput.toTensor
case _ =>
}

MklDnnOps.streamSubmit(runtime.stream, 1,
updateGradInputPrimitives,
updateGradInputPrimitives.length,
updateGradInputMemoryPrimitives, updateGradInputTensors
)

gradInput
}
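Putting it together, here is a minimal training-phase usage sketch modeled on the updated spec below. Shapes are illustrative; initFwdPrimitives and initBwdPrimitives are private[mkldnn], so this is assumed to run from within that package, exactly as the spec does.

```scala
import com.intel.analytics.bigdl.mkl.Memory
import com.intel.analytics.bigdl.nn.mkldnn.Phase.TrainingPhase
import com.intel.analytics.bigdl.tensor.Tensor

val sm = SoftMax()
sm.setRuntime(new MklDnnRuntime)
sm.initFwdPrimitives(Array(HeapData(Array(4, 10), Memory.Format.nc)), TrainingPhase)
sm.initBwdPrimitives(Array(HeapData(Array(4, 10), Memory.Format.nc)), TrainingPhase)

val input = Tensor[Float](4, 10).rand()
val output = sm.forward(input)             // MKL-DNN forward primitive
val gradInput = sm.backward(input, output) // new MKL-DNN backward primitive
```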

@@ -31,9 +31,9 @@ class SoftMaxSpec extends FlatSpec with Matchers {

for (x <- tests) {
val sm = SoftMax()
sm.evaluate()
sm.setRuntime(new MklDnnRuntime)
sm.initFwdPrimitives(Array(HeapData(Array(x), Memory.Format.x)), InferencePhase)
sm.initFwdPrimitives(Array(HeapData(Array(x), Memory.Format.x)), TrainingPhase)
sm.initBwdPrimitives(Array(HeapData(Array(x), Memory.Format.x)), TrainingPhase)

val input = Tensor(x).rand()

@@ -43,6 +43,11 @@ class SoftMaxSpec extends FlatSpec with Matchers {
val nnOutput = nnSm.forward(input)

Tools.dense(output) should be (nnOutput)

sm.backward(input, nnOutput)
nnSm.backward(input, nnOutput)

Tools.dense(sm.gradInput) should be (nnSm.gradInput)
}
}

@@ -57,8 +62,9 @@ class SoftMaxSpec extends FlatSpec with Matchers {
val sm = SoftMax()
sm.setRuntime(new MklDnnRuntime)
sm.initFwdPrimitives(Array(HeapData(Array(batchSize, channel), Memory.Format.nc)),
InferencePhase)
sm.evaluate()
TrainingPhase)
sm.initBwdPrimitives(Array(HeapData(Array(batchSize, channel), Memory.Format.nc)),
TrainingPhase)

val input = Tensor(batchSize, channel).rand()

@@ -68,6 +74,11 @@ class SoftMaxSpec extends FlatSpec with Matchers {
val nnOutput = nnSm.forward(input)

Tools.dense(output) shouldEqual nnOutput

sm.backward(input, nnOutput)
nnSm.backward(input, nnOutput)

Tools.dense(sm.gradInput) should be (nnSm.gradInput)
}
}

@@ -86,8 +97,9 @@ class SoftMaxSpec extends FlatSpec with Matchers {
val sm = SoftMax()
sm.setRuntime(new MklDnnRuntime)
sm.initFwdPrimitives(Array(HeapData(Array(batchSize, channel, height, width),
Memory.Format.nchw)), InferencePhase)
sm.evaluate()
Memory.Format.nchw)), TrainingPhase)
sm.initBwdPrimitives(Array(HeapData(Array(batchSize, channel, height, width),
Memory.Format.nchw)), TrainingPhase)

val input = Tensor(batchSize, channel, height, width).rand()

@@ -97,6 +109,12 @@ class SoftMaxSpec extends FlatSpec with Matchers {
val nnOutput = nnSm.forward(input)

Tools.dense(output) should be (nnOutput)

sm.backward(input, nnOutput)
nnSm.backward(input, nnOutput)

Equivalent.nearequals(Tools.dense(sm.gradInput).toTensor, nnSm.gradInput.toTensor,
epsilon = 1e-5) should be (true)
}
}

@@ -114,9 +132,8 @@ class SoftMaxSpec extends FlatSpec with Matchers {
for ((i, j, k) <- tests) {
val sm = SoftMax()
sm.setRuntime(new MklDnnRuntime)
sm.initFwdPrimitives(Array(HeapData(Array(i, j, k),
Memory.Format.ncw)), InferencePhase)
sm.evaluate()
sm.initFwdPrimitives(Array(HeapData(Array(i, j, k), Memory.Format.ncw)), TrainingPhase)
sm.initBwdPrimitives(Array(HeapData(Array(i, j, k), Memory.Format.ncw)), TrainingPhase)

val input = Tensor(i, j, k).rand()

@@ -126,6 +143,11 @@ class SoftMaxSpec extends FlatSpec with Matchers {
val nnOutput = nnSm.forward(input)

Tools.dense(output) should be (nnOutput)
sm.backward(input, nnOutput)
nnSm.backward(input, nnOutput)

Equivalent.nearequals(Tools.dense(sm.gradInput).toTensor, nnSm.gradInput.toTensor,
epsilon = 1e-5) should be (true)
}
}

@@ -134,7 +156,9 @@ class SoftMaxSpec extends FlatSpec with Matchers {
val sm = SoftMax()
sm.setRuntime(new MklDnnRuntime)
sm.initFwdPrimitives(Array(HeapData(Array(batchSize, channel, height, width),
Memory.Format.nchw)), InferencePhase)
Memory.Format.nchw)), TrainingPhase)
sm.initBwdPrimitives(Array(HeapData(Array(batchSize, channel, height, width),
Memory.Format.nchw)), TrainingPhase)

val nnSm = nn.SoftMax()

@@ -147,8 +171,10 @@ class SoftMaxSpec extends FlatSpec with Matchers {
sm.backward(input, gradOutput)
nnSm.backward(input, gradOutput)

sm.output should be (nnSm.output)
sm.gradInput should be (nnSm.gradInput)
Equivalent.nearequals(Tools.dense(sm.output).toTensor, nnSm.output.toTensor,
epsilon = 1e-4) should be (true)
Equivalent.nearequals(Tools.dense(sm.gradInput).toTensor, nnSm.gradInput.toTensor,
epsilon = 1e-4) should be (true)
}

"SoftMax multi times forward" should "work correctly" in {
@@ -178,8 +204,7 @@ class SoftMaxSpec extends FlatSpec with Matchers {
.add(Input(Array(2, 24564, 21), Memory.Format.ntc))
.add(sm1)
.add(Output(Memory.Format.ntc))
seq1.asInstanceOf[MklDnnContainer].compile(InferencePhase)
seq1.evaluate()
seq1.asInstanceOf[MklDnnContainer].compile(TrainingPhase)

seq1.forward(input)

@@ -189,7 +214,7 @@ class SoftMaxSpec extends FlatSpec with Matchers {
val seq2 = Sequential().add(Input(Array(2 * 24564, 21), Memory.Format.nc))
.add(sm2)
.add(Output())
seq2.asInstanceOf[MklDnnContainer].compile(InferencePhase)
seq2.asInstanceOf[MklDnnContainer].compile(TrainingPhase)
sm2.evaluate()

seq2.forward(input)
@@ -1052,7 +1052,7 @@ class TopologySpec extends FlatSpec with Matchers {

val tmp = fusion.output.toTensor.max(1)

val softmax = SoftMax()
val softmax = nn.SoftMax()

softmax.forward(fusion.output).toTensor.max(2) should be (
softmax.forward(quant.output).toTensor.max(2))