From b4d49754644a70337fc2303b73cb8d1c6b807bb6 Mon Sep 17 00:00:00 2001 From: Yanzhang Wang Date: Tue, 26 Nov 2019 16:32:38 +0800 Subject: [PATCH] feat: add softmax backward (#2967) * feat: add softmax backward --- .../bigdl/dllib/nn/mkldnn/MklDnnMemory.scala | 5 + .../bigdl/dllib/nn/mkldnn/SoftMax.scala | 182 ++++++++++-------- .../bigdl/dllib/nn/mkldnn/SoftMaxSpec.scala | 55 ++++-- .../bigdl/dllib/nn/mkldnn/TopologySpec.scala | 2 +- 4 files changed, 149 insertions(+), 95 deletions(-) diff --git a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/mkldnn/MklDnnMemory.scala b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/mkldnn/MklDnnMemory.scala index 18363940893..a9c82ab5b36 100644 --- a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/mkldnn/MklDnnMemory.scala +++ b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/mkldnn/MklDnnMemory.scala @@ -126,6 +126,11 @@ object MklDnnMemory { MklDnn.SoftMaxForwardDescInit(prop_kind, dataDesc, axis)).ptr } + def SoftMaxBackwardDescInit(propKind: Int, diffDesc: Long, dstDesc: Long, + axis: Int)(implicit owner: MemoryOwner): Long = { + new MklMemoryDescInit(MklDnn.SoftMaxBackwardDescInit(diffDesc, dstDesc, axis)).ptr + } + def ConvForwardDescInit(prop_kind: Int, alg_kind: Int, src_desc: Long, weights_desc: Long, bias_desc: Long, dst_desc: Long, strides: Array[Int], padding_l: Array[Int], padding_r: Array[Int], padding_kind: Int)(implicit owner: MemoryOwner): Long = { diff --git a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/mkldnn/SoftMax.scala b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/mkldnn/SoftMax.scala index 336282488ab..62079435f78 100644 --- a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/mkldnn/SoftMax.scala +++ b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/mkldnn/SoftMax.scala @@ -16,7 +16,7 @@ package com.intel.analytics.bigdl.nn.mkldnn -import com.intel.analytics.bigdl.mkl.{DataType, Memory, MklDnn, PropKind, Stream => DnnStream} +import com.intel.analytics.bigdl.mkl.{DataType, Memory, MklDnn, PropKind, Query, Stream => DnnStream} import com.intel.analytics.bigdl.nn import com.intel.analytics.bigdl.nn.abstractnn.Activity import com.intel.analytics.bigdl.nn.mkldnn.Phase.{InferencePhase, TrainingPhase} @@ -27,12 +27,14 @@ import com.intel.analytics.bigdl.utils.Shape import scala.collection.mutable.ArrayBuffer class SoftMax(val axis: Int = -1) extends MklDnnLayer { - private val nnSoftMax = nn.SoftMax[Float]() - @transient private var updateOutputTensors: Array[Tensor[Float]] = _ @transient private var updateOutputMemoryPrimitives: Array[Long] = _ + @transient private var updateGradInputTensors: Array[Tensor[Float]] = _ + @transient private var updateGradInputMemoryPrimitives: Array[Long] = _ @transient private var modelPhase: Phase = null + private var defaultAxis = 0 + private def initPhase(phase: Phase): Unit = { if (phase != null) return modelPhase = phase isTraining() match { @@ -53,96 +55,118 @@ class SoftMax(val axis: Int = -1) extends MklDnnLayer { override private[mkldnn] def initFwdPrimitives(inputs: Array[MemoryData], phase: Phase) = { initPhase(phase) - modelPhase match { - case TrainingPhase => - _inputFormats = inputs.map(x => HeapData(x.shape, format(x.shape))) - _outputFormats = inputs.map(x => HeapData(x.shape, format(x.shape))) - - (_inputFormats, _outputFormats) - case InferencePhase => - val defaultAxis = inputs(0).shape.length match { - case 1 => 0 - case 2 => 1 - case 3 => 0 - case 4 => 1 - case _ 
=> throw new UnsupportedOperationException("1D, 2D, 3D or 4D tensor expected") - } - - _inputFormats = Array(NativeData(inputs(0).shape, inputs(0).layout, DataType.F32)) - - val localInputFormat = if (inputs(0).shape.length == 3 && - inputs(0).layout == Memory.Format.ntc) { - // note: here, the format and the true memory layout is not consistent. - // for ntc input, we should reshape the `shape` and make the format to tnc - val shape = Array(inputs(0).shape(1), inputs(0).shape(0), inputs(0).shape(2)) - NativeData(shape, Memory.Format.tnc) - } else { - _inputFormats(0) - } - - val desc = MklDnnMemory.SoftMaxForwardDescInit(PropKind.ForwardInference, - localInputFormat.getMemoryDescription(), if (axis == -1) defaultAxis else axis) - val forwardPrimDesc = MklDnnMemory.PrimitiveDescCreate(desc, runtime.engine, 0L) - - _outputFormats = if (inputs(0).shape.length ==3 && - inputs(0).layout == Memory.Format.ntc) { - // because set the input format as tnc first, we should set the output to ntc. - Array(NativeData(inputs(0).shape, Memory.Format.ntc)) - } else { - Array(MemoryData.primitiveOutput(forwardPrimDesc)) - } - - val srcs = Array(inputs(0).getPrimitive(runtime)) - val indexes = Array(0) - val dsts = Array(_outputFormats(0).getPrimitive(runtime)) - - val primitive = MklDnnMemory.PrimitiveCreate2(forwardPrimDesc, srcs, indexes, - srcs.length, dsts, dsts.length) - - updateOutputPrimitives = Array(primitive) - updateOutputMemoryPrimitives = srcs ++ dsts - - output = initTensor(_outputFormats(0)) - - (_inputFormats, _outputFormats) - case _ => throw new UnsupportedOperationException + defaultAxis = inputs(0).shape.length match { + case 1 => 0 + case 2 => 1 + case 3 => 0 + case 4 => 1 + case _ => throw new UnsupportedOperationException("1D, 2D, 3D or 4D tensor expected") + } + + _inputFormats = Array(NativeData(inputs(0).shape, inputs(0).layout, DataType.F32)) + + val localInputFormat = if (inputs(0).shape.length == 3 && + inputs(0).layout == Memory.Format.ntc) { + // note: here, the format and the true memory layout is not consistent. + // for ntc input, we should reshape the `shape` and make the format to tnc + val shape = Array(inputs(0).shape(1), inputs(0).shape(0), inputs(0).shape(2)) + NativeData(shape, Memory.Format.tnc) + } else { + _inputFormats(0) + } + + val desc = MklDnnMemory.SoftMaxForwardDescInit(PropKind.Forward, + localInputFormat.getMemoryDescription(), if (axis == -1) defaultAxis else axis) + val forwardPrimDesc = MklDnnMemory.PrimitiveDescCreate(desc, runtime.engine, 0L) + + _outputFormats = if (inputs(0).shape.length ==3 && + inputs(0).layout == Memory.Format.ntc) { + // because set the input format as tnc first, we should set the output to ntc. 
+ Array(NativeData(inputs(0).shape, Memory.Format.ntc)) + } else { + Array(MemoryData.primitiveOutput(forwardPrimDesc)) } + + val srcs = Array(inputs(0).getPrimitive(runtime)) + val indexes = Array(0) + val dsts = Array(_outputFormats(0).getPrimitive(runtime)) + + val primitive = MklDnnMemory.PrimitiveCreate2(forwardPrimDesc, srcs, indexes, + srcs.length, dsts, dsts.length) + + updateOutputPrimitives = Array(primitive) + updateOutputMemoryPrimitives = srcs ++ dsts + + output = initTensor(_outputFormats(0)) + + (_inputFormats, _outputFormats) } override private[mkldnn] def initBwdPrimitives(grad: Array[MemoryData], phase: Phase) = { - _gradInputFormats = grad.clone() - _gradOutputFormats = grad.clone() + val desc = MklDnnMemory.SoftMaxBackwardDescInit(PropKind.Backward, + grad(0).getMemoryDescription(), outputFormats()(0).getMemoryDescription(), + if (axis == -1) defaultAxis else axis) + val primDesc = MklDnnMemory.PrimitiveDescCreate(desc, runtime.engine, 0L) + + _gradOutputFormats = grad + _gradInputFormats = Array(MemoryData.operationWant(primDesc, Query.DiffSrcPd)) + + val srcs = Array(grad(0).getPrimitive(runtime), outputFormats()(0).getPrimitive(runtime)) + val indexes = Array(0) + val dsts = Array(_gradInputFormats(0).getPrimitive(runtime)) + + val primitive = MklDnnMemory.PrimitiveCreate2(primDesc, srcs, indexes, + srcs.length, dsts, dsts.length) + + updateGradInputPrimitives = Array(primitive) + updateGradInputMemoryPrimitives = srcs ++ dsts + + gradInput = initTensor(_gradInputFormats(0)) + (_gradInputFormats, _gradOutputFormats) } override def updateOutput(input: Activity): Activity = { - if (this.isTraining()) { - nnSoftMax.forward(input) - output = nnSoftMax.output - } else { - if (updateOutputTensors == null) { - val buffer = new ArrayBuffer[Tensor[Float]]() - buffer.append(input.asInstanceOf[Tensor[Float]]) - buffer.append(output.asInstanceOf[Tensor[Float]]) - updateOutputTensors = buffer.toArray - } - - input.toTensor[Float].getTensorType match { - case DenseType => updateOutputTensors(0) = input.toTensor - case _ => - } - - MklDnnOps.streamSubmit(runtime.stream, 1, - updateOutputPrimitives, - updateOutputPrimitives.length, - updateOutputMemoryPrimitives, updateOutputTensors) + if (updateOutputTensors == null) { + val buffer = new ArrayBuffer[Tensor[Float]]() + buffer.append(input.asInstanceOf[Tensor[Float]]) + buffer.append(output.asInstanceOf[Tensor[Float]]) + updateOutputTensors = buffer.toArray + } + + input.toTensor[Float].getTensorType match { + case DenseType => updateOutputTensors(0) = input.toTensor + case _ => } + MklDnnOps.streamSubmit(runtime.stream, 1, + updateOutputPrimitives, + updateOutputPrimitives.length, + updateOutputMemoryPrimitives, updateOutputTensors) output } override def updateGradInput(input: Activity, gradOutput: Activity): Activity = { - gradInput = nnSoftMax.backward(input, gradOutput) + if (updateGradInputTensors == null) { + val buffer = new ArrayBuffer[Tensor[Float]]() + buffer.append(gradOutput.asInstanceOf[Tensor[Float]]) + buffer.append(output.asInstanceOf[Tensor[Float]]) + buffer.append(gradInput.asInstanceOf[Tensor[Float]]) + + updateGradInputTensors = buffer.toArray + } + + gradOutput.toTensor[Float].getTensorType match { + case DenseType => updateGradInputTensors(0) = gradOutput.toTensor + case _ => + } + + MklDnnOps.streamSubmit(runtime.stream, 1, + updateGradInputPrimitives, + updateGradInputPrimitives.length, + updateGradInputMemoryPrimitives, updateGradInputTensors + ) + gradInput } diff --git 
a/scala/dllib/src/test/scala/com/intel/analytics/bigdl/dllib/nn/mkldnn/SoftMaxSpec.scala b/scala/dllib/src/test/scala/com/intel/analytics/bigdl/dllib/nn/mkldnn/SoftMaxSpec.scala
index 02cfc3a87ba..fc185e765d9 100644
--- a/scala/dllib/src/test/scala/com/intel/analytics/bigdl/dllib/nn/mkldnn/SoftMaxSpec.scala
+++ b/scala/dllib/src/test/scala/com/intel/analytics/bigdl/dllib/nn/mkldnn/SoftMaxSpec.scala
@@ -31,9 +31,9 @@ class SoftMaxSpec extends FlatSpec with Matchers {
 
     for (x <- tests) {
       val sm = SoftMax()
-      sm.evaluate()
       sm.setRuntime(new MklDnnRuntime)
-      sm.initFwdPrimitives(Array(HeapData(Array(x), Memory.Format.x)), InferencePhase)
+      sm.initFwdPrimitives(Array(HeapData(Array(x), Memory.Format.x)), TrainingPhase)
+      sm.initBwdPrimitives(Array(HeapData(Array(x), Memory.Format.x)), TrainingPhase)
 
       val input = Tensor(x).rand()
 
@@ -43,6 +43,11 @@ class SoftMaxSpec extends FlatSpec with Matchers {
       val nnOutput = nnSm.forward(input)
 
       Tools.dense(output) should be (nnOutput)
+
+      sm.backward(input, nnOutput)
+      nnSm.backward(input, nnOutput)
+
+      Tools.dense(sm.gradInput) should be (nnSm.gradInput)
     }
   }
 
@@ -57,8 +62,9 @@ class SoftMaxSpec extends FlatSpec with Matchers {
       val sm = SoftMax()
       sm.setRuntime(new MklDnnRuntime)
       sm.initFwdPrimitives(Array(HeapData(Array(batchSize, channel), Memory.Format.nc)),
-        InferencePhase)
-      sm.evaluate()
+        TrainingPhase)
+      sm.initBwdPrimitives(Array(HeapData(Array(batchSize, channel), Memory.Format.nc)),
+        TrainingPhase)
 
       val input = Tensor(batchSize, channel).rand()
 
@@ -68,6 +74,11 @@ class SoftMaxSpec extends FlatSpec with Matchers {
       val nnOutput = nnSm.forward(input)
 
       Tools.dense(output) shouldEqual nnOutput
+
+      sm.backward(input, nnOutput)
+      nnSm.backward(input, nnOutput)
+
+      Tools.dense(sm.gradInput) should be (nnSm.gradInput)
     }
   }
 
@@ -86,8 +97,9 @@ class SoftMaxSpec extends FlatSpec with Matchers {
      val sm = SoftMax()
       sm.setRuntime(new MklDnnRuntime)
       sm.initFwdPrimitives(Array(HeapData(Array(batchSize, channel, height, width),
-        Memory.Format.nchw)), InferencePhase)
-      sm.evaluate()
+        Memory.Format.nchw)), TrainingPhase)
+      sm.initBwdPrimitives(Array(HeapData(Array(batchSize, channel, height, width),
+        Memory.Format.nchw)), TrainingPhase)
 
       val input = Tensor(batchSize, channel, height, width).rand()
 
@@ -97,6 +109,12 @@ class SoftMaxSpec extends FlatSpec with Matchers {
       val nnOutput = nnSm.forward(input)
 
       Tools.dense(output) should be (nnOutput)
+
+      sm.backward(input, nnOutput)
+      nnSm.backward(input, nnOutput)
+
+      Equivalent.nearequals(Tools.dense(sm.gradInput).toTensor, nnSm.gradInput.toTensor,
+        epsilon = 1e-5) should be (true)
     }
   }
 
@@ -114,9 +132,8 @@ class SoftMaxSpec extends FlatSpec with Matchers {
     for ((i, j, k) <- tests) {
       val sm = SoftMax()
       sm.setRuntime(new MklDnnRuntime)
-      sm.initFwdPrimitives(Array(HeapData(Array(i, j, k),
-        Memory.Format.ncw)), InferencePhase)
-      sm.evaluate()
+      sm.initFwdPrimitives(Array(HeapData(Array(i, j, k), Memory.Format.ncw)), TrainingPhase)
+      sm.initBwdPrimitives(Array(HeapData(Array(i, j, k), Memory.Format.ncw)), TrainingPhase)
 
       val input = Tensor(i, j, k).rand()
 
@@ -126,6 +143,11 @@ class SoftMaxSpec extends FlatSpec with Matchers {
       val nnOutput = nnSm.forward(input)
 
       Tools.dense(output) should be (nnOutput)
+      sm.backward(input, nnOutput)
+      nnSm.backward(input, nnOutput)
+
+      Equivalent.nearequals(Tools.dense(sm.gradInput).toTensor, nnSm.gradInput.toTensor,
+        epsilon = 1e-5) should be (true)
     }
   }
 
@@ -134,7 +156,9 @@ class SoftMaxSpec extends FlatSpec with Matchers {
     val sm = SoftMax()
     sm.setRuntime(new MklDnnRuntime)
     sm.initFwdPrimitives(Array(HeapData(Array(batchSize, channel, height, width),
-      Memory.Format.nchw)), InferencePhase)
+      Memory.Format.nchw)), TrainingPhase)
+    sm.initBwdPrimitives(Array(HeapData(Array(batchSize, channel, height, width),
+      Memory.Format.nchw)), TrainingPhase)
 
     val nnSm = nn.SoftMax()
 
@@ -147,8 +171,10 @@ class SoftMaxSpec extends FlatSpec with Matchers {
     sm.backward(input, gradOutput)
     nnSm.backward(input, gradOutput)
 
-    sm.output should be (nnSm.output)
-    sm.gradInput should be (nnSm.gradInput)
+    Equivalent.nearequals(Tools.dense(sm.output).toTensor, nnSm.output.toTensor,
+      epsilon = 1e-4) should be (true)
+    Equivalent.nearequals(Tools.dense(sm.gradInput).toTensor, nnSm.gradInput.toTensor,
+      epsilon = 1e-4) should be (true)
   }
 
   "SoftMax multi times forward" should "work correctly" in {
@@ -178,8 +204,7 @@ class SoftMaxSpec extends FlatSpec with Matchers {
       .add(Input(Array(2, 24564, 21), Memory.Format.ntc))
       .add(sm1)
       .add(Output(Memory.Format.ntc))
-    seq1.asInstanceOf[MklDnnContainer].compile(InferencePhase)
-    seq1.evaluate()
+    seq1.asInstanceOf[MklDnnContainer].compile(TrainingPhase)
 
     seq1.forward(input)
 
@@ -189,7 +214,7 @@ class SoftMaxSpec extends FlatSpec with Matchers {
     val seq2 = Sequential().add(Input(Array(2 * 24564, 21), Memory.Format.nc))
       .add(sm2)
       .add(Output())
-    seq2.asInstanceOf[MklDnnContainer].compile(InferencePhase)
+    seq2.asInstanceOf[MklDnnContainer].compile(TrainingPhase)
     sm2.evaluate()
 
     seq2.forward(input)
diff --git a/scala/dllib/src/test/scala/com/intel/analytics/bigdl/dllib/nn/mkldnn/TopologySpec.scala b/scala/dllib/src/test/scala/com/intel/analytics/bigdl/dllib/nn/mkldnn/TopologySpec.scala
index 0e3d87e91e4..3f1099ba6b8 100644
--- a/scala/dllib/src/test/scala/com/intel/analytics/bigdl/dllib/nn/mkldnn/TopologySpec.scala
+++ b/scala/dllib/src/test/scala/com/intel/analytics/bigdl/dllib/nn/mkldnn/TopologySpec.scala
@@ -1052,7 +1052,7 @@ class TopologySpec extends FlatSpec with Matchers {
 
     val tmp = fusion.output.toTensor.max(1)
 
-    val softmax = SoftMax()
+    val softmax = nn.SoftMax()
 
     softmax.forward(fusion.output).toTensor.max(2) should be (
       softmax.forward(quant.output).toTensor.max(2))
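
Note on the math the backward tests above exercise: for y = softmax(x) along a given axis, the gradient is gradInput = y * (gradOutput - sum(gradOutput * y)) over that axis, which is also what nn.SoftMax computes and why the specs compare the dnn gradInput against nn.SoftMax's gradInput. Below is a minimal, dependency-free Scala sketch of that rule; the object and method names are illustrative only and are not BigDL or MKL-DNN API.

object SoftMaxGradReference {
  // gradInput(i) = output(i) * (gradOutput(i) - sum_j(gradOutput(j) * output(j)))
  // for one row whose softmax was taken over the whole row.
  def backwardRow(output: Array[Float], gradOutput: Array[Float]): Array[Float] = {
    require(output.length == gradOutput.length, "output/gradOutput size mismatch")
    val dot = output.indices.map(j => gradOutput(j) * output(j)).sum
    output.indices.map(i => output(i) * (gradOutput(i) - dot)).toArray
  }

  // Row-wise application for a (batch x classes) matrix, i.e. axis = 1 for nc input.
  def backward(output: Array[Array[Float]],
               gradOutput: Array[Array[Float]]): Array[Array[Float]] = {
    output.zip(gradOutput).map { case (y, dy) => backwardRow(y, dy) }
  }
}

Given the forward outputs, backwardRow(y, dy) can serve as a plain-Scala oracle when debugging mismatches between Tools.dense(sm.gradInput) and nnSm.gradInput in the specs.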