Pytorch training and inference for model with multiple output and multiple input (intel-analytics#1544)

* inference with table output

* add unit test

* multi and unit test

* remove duplicate ut

* clear caching data

* support multiple shape

* multi input ut

* release

* remove empty line

* update so
hhbyyh committed Aug 9, 2019
1 parent 9e3a515 commit f807dac
Showing 2 changed files with 87 additions and 41 deletions.
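
A minimal usage sketch of what this change enables (illustration only, not part of the diff): a multi-input TorchScript model can be fed a BigDL Table built with T(...), and a multi-output model returns a Table. The TorchNet(path) / TorchCriterion(path) factory calls and the file paths and shapes below are assumed placeholders.

import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.T
import com.intel.analytics.zoo.pipeline.api.net.{TorchCriterion, TorchNet}

implicit val ev = TensorNumeric.NumericFloat

// Placeholder paths to TorchScript modules traced/saved in Python.
val net = TorchNet("/tmp/model.pt")
val criterion = TorchCriterion("/tmp/loss.pt")

// A two-input model takes a Table; a single tensor still works as before.
val input = T(Tensor[Float](4, 3).rand(), Tensor[Float](4, 5).rand())
val output = net.forward(input)   // Tensor for one output, Table for several

// The target may likewise be a single tensor or a Table.
val target = Tensor[Float](4, 1).rand()
val loss = criterion.forward(output, target)

// Training step: feed the criterion's gradient back through the net.
val gradOutput = criterion.backward(output, target)
net.backward(input, gradOutput)
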
TorchCriterion.scala

@@ -22,6 +22,7 @@ import com.intel.analytics.bigdl.Module
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractCriterion, Activity}
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.{T, Table}
import com.intel.analytics.zoo.pipeline.api.net.TorchNet.TorchModelHolder
import org.apache.commons.io.FileUtils

@@ -32,7 +33,6 @@ class TorchCriterion private(private val lossHolder: TorchModelHolder)

implicit val ev = TensorNumeric.NumericFloat
implicit val tag: ClassTag[Float] = ClassTag.Float
gradInput = Activity.allocate[Tensor[Float], Float]().toTensor

/**
* sequential id in cpp: std::vector<std::shared_ptr<torch::jit::script::Module>> handles;
@@ -46,35 +46,34 @@ class TorchCriterion private(private val lossHolder: TorchModelHolder)
}

override def updateOutput(input: Activity, target: Activity): Float = {
if (input.isTable || target.isTable) {
throw new UnsupportedOperationException()
}
val inputTensor = input.toTensor
val targetTensor = target.toTensor

require(inputTensor.isContiguous())
require(targetTensor.isContiguous())

val result = PytorchModel.lossForwardNative(nativeRef,
inputTensor.storage().array(),
inputTensor.storageOffset() - 1,
inputTensor.size(),
targetTensor.storage().array(),
targetTensor.storageOffset() - 1,
targetTensor.size()
)
val inputTable = if (input.isTensor) T(input.toTensor) else input.toTable
val targetTable = if (target.isTensor) T(target.toTensor) else target.toTable

val (sto1, off1, shape1) = TorchCriterion.extract(inputTable)
val (sto2, off2, shape2) = TorchCriterion.extract(targetTable)

val result = PytorchModel.lossForwardNative(nativeRef, sto1, off1, shape1, sto2, off2, shape2)
Tensor(result.getData, result.getShape).mean()
}

override def updateGradInput(input: Activity, target: Activity): Activity = {
if (input.isTable || target.isTable) {
throw new UnsupportedOperationException()
}

gradInput.asInstanceOf[Tensor[Float]].resizeAs(input.toTensor)
val result = PytorchModel.lossBackwardNative(nativeRef)
val resultTensor = Tensor(result.getData, result.getShape)
gradInput.toTensor.set(resultTensor)
if (result.length == 1) {
val resultTensor = Tensor(result(0).getData, result(0).getShape)
if (gradInput == null) {
gradInput = Tensor()
}
gradInput.toTensor.set(resultTensor)
} else {
if (gradInput == null) {
gradInput = T()
}
gradInput.toTable.clear()
result.foreach { t =>
gradInput.toTable.insert(Tensor(t.getData, t.getShape))
}
}
gradInput
}

@@ -113,4 +112,18 @@ object TorchCriterion {
nativeRef
}


private[net] def extract(t: Table): (Array[Array[Float]], Array[Int], Array[Array[Int]]) = {
val tensors = t.toSeq[Tensor[Float]]

val tuples = tensors.map { t =>
require(t.isContiguous())
(t.storage(), t.storageOffset() - 1, t.size())
}
val storages = tuples.map(_._1.array()).toArray
val offsets = tuples.map(_._2).toArray
val shapes = tuples.map(_._3).toArray
(storages, offsets, shapes)
}

}
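
For reference, extract flattens a Table of contiguous tensors into three parallel arrays, which is the layout the native lossForwardNative/modelForwardNative calls consume. A small sketch of that layout (illustration only; extract is private[net], so the call below only compiles from inside com.intel.analytics.zoo.pipeline.api.net):

import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.T

implicit val ev = TensorNumeric.NumericFloat

// Two tensors of different rank wrapped in a Table, as updateOutput does.
val table = T(Tensor[Float](2, 3).rand(), Tensor[Float](4).rand())

// For each position i:
//   storages(i) -- the backing Array[Float] of tensor i
//   offsets(i)  -- its 0-based offset (storageOffset() - 1)
//   shapes(i)   -- its size, here Array(2, 3) and Array(4)
val (storages, offsets, shapes) = TorchCriterion.extract(table)
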
TorchNet.scala

@@ -24,21 +24,24 @@ import com.intel.analytics.bigdl.Module
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity}
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.T
import com.intel.analytics.zoo.pipeline.api.Predictable
import com.intel.analytics.zoo.pipeline.api.net.TorchNet.TorchModelHolder
import org.apache.commons.io.FileUtils
import org.slf4j.LoggerFactory

import scala.reflect.ClassTag

/**
* [[TorchNet]] wraps a TorchScript model as a single layer.
*/
class TorchNet private(private val modelHolder: TorchModelHolder)
extends AbstractModule[Tensor[Float], Tensor[Float], Float] with Predictable[Float] {
extends AbstractModule[Activity, Activity, Float] with Predictable[Float] {

protected val module: Module[Float] = this
implicit val ev = TensorNumeric.NumericFloat
implicit val tag: ClassTag[Float] = ClassTag.Float
val logger = LoggerFactory.getLogger(getClass)

var weights: Tensor[Float] = _
var gradients: Tensor[Float] = _
@@ -70,31 +73,62 @@ class TorchNet private(private val modelHolder: TorchModelHolder)
(Array(weights), Array(gradients))
}

override def updateOutput(input: Tensor[Float]): Tensor[Float] = {
override def updateOutput(input: Activity): Activity = {
val inputTable = if (input.isTensor) T(input.toTensor) else input.toTable

val (sto1, off1, shape1) = TorchCriterion.extract(inputTable)

if (this.isTraining()) {
PytorchModel.updateWeightNative(this.nativeRef, weights.storage().array())
}

require(input.isContiguous())
val data = input.storage().array()
val size = input.size()
val offset = input.storageOffset() - 1
val result = PytorchModel.modelForwardNative(nativeRef, this.isTraining(), data, offset, size)
val resultTensor = Tensor(result.getData, result.getShape)
output.set(resultTensor)
val result = PytorchModel.modelForwardNative(nativeRef, this.isTraining(), sto1, off1, shape1)
if (result.length == 1) {
val resultTensor = Tensor(result(0).getData, result(0).getShape)
if (output == null) {
output = Tensor()
}
output.toTensor.set(resultTensor)
} else {
if (output == null) {
output = T()
}
output.toTable.clear()
result.foreach { t =>
output.toTable.insert(Tensor(t.getData, t.getShape))
}
}
output
}

override def updateGradInput(input: Tensor[Float], gradOutput: Tensor[Float]): Tensor[Float] = {
val data = gradOutput.storage().array()
val size = gradOutput.size()
val offset = gradOutput.storageOffset() - 1
val result = PytorchModel.modelBackwardNative(nativeRef, data, offset, size)
val resultTensor = Tensor(result.getData, result.getShape)
override def updateGradInput(input: Activity, gradOutput: Activity): Activity = {
val gradOutputTable = if (gradOutput.isTensor) T(gradOutput.toTensor) else gradOutput.toTable

val (sto1, off1, shape1) = TorchCriterion.extract(gradOutputTable)

val result = PytorchModel.modelBackwardNative(nativeRef, sto1, off1, shape1)
// update gradients
gradients.resizeAs(weights)
val g = PytorchModel.getGradientNative(this.nativeRef)
System.arraycopy(g, 0, gradients.storage().array(), 0, g.length)
gradInput.set(resultTensor)

// update gradInput
if (result.length == 1) {
val resultTensor = Tensor(result(0).getData, result(0).getShape)
if (gradInput == null) {
gradInput = Tensor()
}
gradInput.toTensor.set(resultTensor)
} else {
if (gradInput == null) {
gradInput = T()
}
gradInput.toTable.clear()
result.foreach { t =>
gradInput.toTable.insert(Tensor(t.getData, t.getShape))
}
}
gradInput
}
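
Because forward and backward now return either a Tensor or a Table depending on how many outputs the TorchScript module produces, callers typically branch on the Activity kind. A minimal sketch of that dispatch (illustration only, not part of the diff):

import com.intel.analytics.bigdl.nn.abstractnn.Activity
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric

implicit val ev = TensorNumeric.NumericFloat

// Print the shape(s) of whatever TorchNet.forward returned.
def describe(out: Activity): Unit = {
  if (out.isTensor) {
    println(out.toTensor[Float].size().mkString("x"))
  } else {
    val outputs = out.toTable
    (1 to outputs.length()).foreach { i =>
      println(outputs[Tensor[Float]](i).size().mkString("x"))
    }
  }
}
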

// TODO: use release if possible. now for larger model it's causing early release
@@ -188,7 +222,6 @@ object TorchNet {
}
}


private[net] def loadPytorchModel(bytes: Array[Byte]): Long = {
var nativeRef = -1L
try {
