Skip to content

Commit

Permalink
pytorch native code cleanup (intel-analytics#1642)
Browse files Browse the repository at this point in the history
* code clean

* remove load

* fix style
  • Loading branch information
qiuxin2012 committed Sep 29, 2019
1 parent a731c53 commit fd42e16
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 30 deletions.
5 changes: 3 additions & 2 deletions TorchCriterion.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,14 @@ class TorchCriterion private(private val lossHolder: TorchModelHolder)
val (sto1, off1, shape1) = TorchCriterion.extract(inputTabel)
val (sto2, off2, shape2) = TorchCriterion.extract(targetTable)

val result = PytorchModel.lossForwardNative(nativeRef, sto1, off1, shape1, sto2, off2, shape2)
val result = PytorchModelWrapper.lossForwardNative(nativeRef, sto1, off1,
shape1, sto2, off2, shape2)
Tensor(result.getData, result.getShape).mean()
}

override def updateGradInput(input: Activity, target: Activity): Activity = {

val result = PytorchModel.lossBackwardNative(nativeRef)
val result = PytorchModelWrapper.lossBackwardNative(nativeRef)
if (result.length == 1) {
val resultTensor = Tensor(result(0).getData, result(0).getShape)
if (gradInput == null) {
Expand Down
31 changes: 3 additions & 28 deletions TorchNet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ class TorchNet private(private val modelHolder: TorchModelHolder)
PytorchModel.updateWeightNative(this.nativeRef, weights.storage().array())
}

val result = PytorchModel.modelForwardNative(nativeRef, this.isTraining(), sto1, off1, shape1)
val result = PytorchModelWrapper.modelForwardNative(nativeRef,
this.isTraining(), sto1, off1, shape1)
if (result.length == 1) {
val resultTensor = Tensor(result(0).getData, result(0).getShape)
if (output == null) {
Expand All @@ -115,7 +116,7 @@ class TorchNet private(private val modelHolder: TorchModelHolder)

val (sto1, off1, shape1) = TorchCriterion.extract(gradOutputTable)

val result = PytorchModel.modelBackwardNative(nativeRef, sto1, off1, shape1)
val result = PytorchModelWrapper.modelBackwardNative(nativeRef, sto1, off1, shape1)
// update gradients
gradients.resizeAs(weights)
val g = PytorchModel.getGradientNative(this.nativeRef)
Expand Down Expand Up @@ -156,16 +157,11 @@ class TorchNet private(private val modelHolder: TorchModelHolder)
}

object TorchNet {

PytorchModel.isLoaded
loadPytorchNatives() // load once per JVM

private val modelBytesRegistry = new RegistryMap[Array[Byte]]()

@transient
private lazy val inDriver = NetUtils.isDriver


class TorchModelHolder(@transient var torchBytes: Array[Byte], private var id: String)
extends SerializationHolder {

Expand Down Expand Up @@ -218,27 +214,6 @@ object TorchNet {
new TorchNet(new TorchModelHolder(modelbytes, modelPath))
}

// extract libs from zoo jar file
private def loadPytorchNatives(): Unit = {
loadNativelib("pytorch/libpytorch-engine.so")
}

private def loadNativelib(path: String): Unit = {
val inputStream = TorchNet.getClass.getResourceAsStream(s"/${path}")
val file = File.createTempFile("PytorchLoader", "tmp")
val src = Channels.newChannel(inputStream)
val dest = new FileOutputStream(file).getChannel
dest.transferFrom(src, 0, Long.MaxValue)
dest.close()
src.close()
val filePath = file.getAbsolutePath
try {
System.load(filePath)
} finally {
file.delete()
}
}

private[net] def loadPytorchModel(bytes: Array[Byte]): Long = {
var nativeRef = -1L
try {
Expand Down

0 comments on commit fd42e16

Please sign in to comment.