Skip to content

Commit

Permalink
save TorchNet as Pytorch script module (intel-analytics#1564)
Browse files Browse the repository at this point in the history
* support save pytorch model to script

* unit test

* use temp folder

* add import

* correct evaluate

* import

* style fix
  • Loading branch information
hhbyyh committed Sep 25, 2019
1 parent 13ea6f8 commit 07f4ebe
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class TorchNet private(private val modelHolder: TorchModelHolder)
}

override def evaluate(): this.type = {
nativeRef
super.evaluate()
if (!weights.isEmpty) {
PytorchModel.updateWeightNative(nativeRef, weights.storage().array())
Expand Down Expand Up @@ -144,6 +145,14 @@ class TorchNet private(private val modelHolder: TorchModelHolder)
super.finalize()
PytorchModel.releaseModelNative(nativeRef)
}

/**
* export the model to path as a torch script module.
*/
def savePytorch(path : String, overWrite: Boolean = false): Unit = {
PytorchModel.updateWeightNative(this.nativeRef, weights.storage().array())
PytorchModel.saveModelNative(nativeRef, path)
}
}

object TorchNet {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -301,4 +301,8 @@ class PythonZooNet[T: ClassTag](implicit ev: TensorNumeric[T]) extends PythonZoo
TorchCriterion(lossPath)
}

def torchNetSavePytorch(torchnet: TorchNet, path: String): Unit = {
torchnet.savePytorch(path)
}

}

0 comments on commit 07f4ebe

Please sign in to comment.