-
Notifications
You must be signed in to change notification settings - Fork 729
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Save Keras-like model to pure keras or tensorflow protobuf. #1600
Changes from 5 commits
00c2b28
8ccd8af
738533d
f388524
5ad78d8
33bdc34
4fef833
8139084
697a4a4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,21 +16,33 @@ | |
|
||
package com.intel.analytics.zoo.pipeline.api | ||
|
||
import java.io.{BufferedReader, BufferedWriter, FileOutputStream, FileWriter, InputStreamReader, File => JFile} | ||
import java.nio.ByteOrder | ||
import java.util | ||
|
||
import com.intel.analytics.bigdl.Module | ||
import com.intel.analytics.bigdl.nn.Graph._ | ||
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity, Initializable} | ||
import com.intel.analytics.bigdl.nn.keras.KerasLayer | ||
import com.intel.analytics.bigdl.nn.keras.{KerasIdentityWrapper, KerasLayer} | ||
import com.intel.analytics.bigdl.nn.{Container, Graph, InitializationMethod} | ||
import com.intel.analytics.bigdl.nn.{Sequential => TSequential} | ||
import com.intel.analytics.bigdl.python.api.PythonBigDL | ||
import com.intel.analytics.bigdl.tensor.Tensor | ||
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric | ||
import com.intel.analytics.bigdl.utils.File | ||
import com.intel.analytics.bigdl.utils.{File, Shape} | ||
import com.intel.analytics.zoo.models.caffe.CaffeLoader | ||
import com.intel.analytics.bigdl.utils.serializer.ModuleLoader | ||
import com.intel.analytics.bigdl.utils.tf.{Session, TensorflowLoader} | ||
import com.intel.analytics.zoo.common.Utils | ||
import com.intel.analytics.zoo.pipeline.api.autograd.Variable | ||
import com.intel.analytics.zoo.pipeline.api.keras.layers.WordEmbedding | ||
import com.intel.analytics.zoo.pipeline.api.keras.layers.{KerasLayerWrapper, WordEmbedding} | ||
import com.intel.analytics.zoo.pipeline.api.keras.layers.utils.KerasUtils | ||
import com.intel.analytics.zoo.pipeline.api.keras.models.{KerasNet, Model, Sequential} | ||
import com.intel.analytics.zoo.pipeline.api.net.{GraphNet, NetUtils} | ||
import org.apache.commons.io.FileUtils | ||
import org.apache.log4j.Logger | ||
import org.apache.spark.bigdl.api.python.BigDLSerDe | ||
import org.apache.zookeeper.KeeperException.UnimplementedException | ||
|
||
import scala.reflect.ClassTag | ||
|
||
|
@@ -53,6 +65,23 @@ trait Net { | |
new Variable( | ||
this.asInstanceOf[AbstractModule[Activity, Activity, T]].inputs(vars.map(_.node): _*)) | ||
} | ||
|
||
private[zoo] def toKeras2(dir: String): String = { | ||
throw new UnimplementedException() | ||
} | ||
|
||
/** | ||
* Get keras-like weights. | ||
* @tparam T | ||
* @return | ||
*/ | ||
private[zoo] def getKerasWeights(): Array[Tensor[Float]] = { | ||
if (this.asInstanceOf[AbstractModule[_, _, _]].parameters()._1.length != 0) { | ||
throw new UnimplementedException() | ||
} else { | ||
Array() | ||
} | ||
} | ||
} | ||
|
||
object Net { | ||
|
@@ -186,4 +215,192 @@ object Net { | |
implicit ev: TensorNumeric[T]): Session[T] = { | ||
TensorflowLoader.checkpoints(graphFile, binFile, byteOrder) | ||
} | ||
|
||
def saveToKeras2[T: ClassTag](model: Net, filePath: String, python: String = "python") | ||
(implicit ev: TensorNumeric[T]): Unit= { | ||
NetSaver.saveToKeras2(model.asInstanceOf[Module[T]], filePath, python) | ||
} | ||
|
||
def saveToTf[T: ClassTag](model: Net, dir: String, python: String = "python") | ||
(implicit ev: TensorNumeric[T]): Unit= { | ||
NetSaver.saveToTf(model.asInstanceOf[Module[T]], dir, python) | ||
} | ||
|
||
private[zoo] def getName(name: String): String = { | ||
name.split("\\.").last | ||
} | ||
|
||
private[zoo] def arrayToString(array: Seq[Int]): String = { | ||
s"(${array.mkString(", ")})" | ||
} | ||
|
||
private[zoo] def inputShapeToString(inputShape: Shape): String = { | ||
if (inputShape != null) { | ||
s", input_shape=(${inputShape.toSingle().mkString(", ")},)" | ||
} else { | ||
"" | ||
} | ||
} | ||
|
||
private[zoo] def activationToString(activation: AbstractModule[_, _, _], | ||
paramName: String = "activation"): String = { | ||
val trueActivation = if (activation.isInstanceOf[KerasIdentityWrapper[_]]) { | ||
activation.asInstanceOf[KerasIdentityWrapper[_]].layer | ||
} else { | ||
activation | ||
} | ||
if (activation != null) { | ||
s", $paramName='${KerasUtils.getActivationName(trueActivation)}'" | ||
} else { | ||
"" | ||
} | ||
|
||
} | ||
|
||
private[zoo] def booleanToString(boolean: Boolean, | ||
booleanName: String): String = { | ||
s", $booleanName=${if(boolean) "True" else "False"}" | ||
} | ||
|
||
private[zoo] def nameToString(name: String): String = { | ||
s", name='$name'" | ||
} | ||
|
||
|
||
protected object NetSaver { | ||
private val logger = Logger.getLogger(getClass) | ||
|
||
protected val header = | ||
""" | ||
|from keras.models import Sequential | ||
|from keras.layers import * | ||
|from pyspark.serializers import PickleSerializer | ||
| | ||
|def load_to_numpy(file): | ||
| in_file = open(file, "rb") | ||
| data = in_file.read() | ||
| in_file.close() | ||
| r=PickleSerializer().loads(data, encoding="bytes") | ||
| return r.to_ndarray() | ||
""".stripMargin + "\n" | ||
|
||
protected val tfHeader = | ||
""" | ||
|from zoo.util.tf import export_tf | ||
|from keras import backend as K | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why keras but not tensorflow.keras? |
||
|import tensorflow as tf | ||
""".stripMargin + "\n" | ||
|
||
def save[T: ClassTag]( | ||
m: Module[T], | ||
path: String, | ||
python: String, | ||
saveCommand: String) | ||
(implicit ev: TensorNumeric[T]): Unit = { | ||
val tmpDir = Utils.createTmpDir("ZooKeras") | ||
logger.info(s"Write to ${tmpDir}") | ||
val modelFile = tmpDir.toString + s"/${m.getName()}.py" | ||
val bw = new BufferedWriter(new FileWriter(modelFile)) | ||
bw.write(header) | ||
if (m.isInstanceOf[Sequential[T]]) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. raise exception here if otherwise? |
||
export(m.asInstanceOf[Sequential[T]], tmpDir.toString, bw) | ||
} | ||
bw.write(saveWeights(m, tmpDir.toString)) | ||
bw.write(saveCommand) | ||
bw.flush() | ||
bw.close() | ||
execCommand(s"${python} ${modelFile}") | ||
FileUtils.deleteDirectory(tmpDir.toFile()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. surround with finally ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When save failed, user or developer can check the python file to see what's happening. |
||
} | ||
|
||
def saveToTf[T: ClassTag](m: Module[T], path: String, python: String) | ||
(implicit ev: TensorNumeric[T]): Unit = { | ||
val saveCommand = tfHeader + | ||
s"export_tf(K.get_session(), '${path}', model.inputs, model.outputs)\n" | ||
save(m, path, python, saveCommand) | ||
} | ||
|
||
def saveToKeras2[T: ClassTag](m: Module[T], path: String, python: String) | ||
(implicit ev: TensorNumeric[T]): Unit = { | ||
save(m, path, python, s"model.save('$path')\n") | ||
} | ||
|
||
def execCommand(command: String): Unit = { | ||
val proc = Runtime.getRuntime().exec(command) | ||
proc.waitFor() | ||
if (proc.exitValue() != 0) { | ||
val error = new BufferedReader(new InputStreamReader(proc.getErrorStream())) | ||
val errorMsg = new StringBuilder() | ||
var line = error.readLine() | ||
while(line != null) { | ||
errorMsg.append(line + "\n") | ||
line = error.readLine() | ||
} | ||
throw new RuntimeException(s"Export Keras2 model failed:\n" + errorMsg.toString()) | ||
} | ||
|
||
} | ||
|
||
def export[T: ClassTag]( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In order to support the |
||
sequential: Sequential[T], | ||
path: String, | ||
writer: BufferedWriter): Unit = { | ||
writer.write(s"${sequential.getName()} = " + | ||
s"Sequential(name='${(sequential.getName())}')\n") | ||
val modules = sequential.modules(0).asInstanceOf[TSequential[T]].modules | ||
modules.foreach{ module => | ||
if (module.isInstanceOf[Sequential[T]]) { | ||
export(module.asInstanceOf[Sequential[T]], path, writer) | ||
writer.write(s"${sequential.getName()}.add(${module.getName})\n") | ||
} else if (module.isInstanceOf[Net]){ | ||
writer.write(s"${module.getName()} = ${module.asInstanceOf[Net].toKeras2(path)}\n") | ||
writer.write(s"${sequential.getName()}.add(${module.getName})\n") | ||
} else { | ||
throw new IllegalArgumentException(s"unkown type ${this.getClass.getName}") | ||
} | ||
} | ||
} | ||
|
||
private[zoo] def saveWeights[T: ClassTag]( | ||
module: AbstractModule[_, _, T], path: String) | ||
(implicit ev: TensorNumeric[T]): String = { | ||
val moduleName = module.getName() | ||
var i = 0 | ||
val wStrings = module.asInstanceOf[Net].getKerasWeights().map{p => | ||
val pName = s"${moduleName}_p${i}" | ||
val pPath = getUniqueFile(s"${path}/${pName}") | ||
saveToJTensor(p, pPath) | ||
i += 1 | ||
(s"${pName} = load_to_numpy('${pPath}')", | ||
pName) | ||
} | ||
val loadWeights = wStrings.map(_._1).mkString("\n") | ||
val weightsList = wStrings.map(_._2).mkString(",") | ||
loadWeights + "\n" + | ||
s"${moduleName}.set_weights([${weightsList}])\n" | ||
} | ||
|
||
private def getUniqueFile(path: String): JFile = { | ||
var file = new JFile(path) | ||
var i = 0 | ||
while(file.exists()) { | ||
file = new JFile(path+s".$i") | ||
i += 1 | ||
} | ||
file | ||
} | ||
|
||
private def saveToJTensor[T: ClassTag]( | ||
tensor: Tensor[T], file: JFile) | ||
(implicit ev: TensorNumeric[T]): Unit = { | ||
val pythonBigDL = new PythonBigDL[T]() | ||
val jt = pythonBigDL.toJTensor(tensor) | ||
val bytes = BigDLSerDe.dumps(jt) | ||
val fio = new FileOutputStream(file) | ||
fio.write(bytes) | ||
fio.flush() | ||
fio.close() | ||
} | ||
|
||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -59,6 +59,27 @@ class Dense[T: ClassTag]( | |
override val inputShape: Shape = null)(implicit ev: TensorNumeric[T]) | ||
extends BigDLDense[T](outputDim, init, activation, wRegularizer, bRegularizer, bias, | ||
inputShape) with Net { | ||
|
||
override private[zoo] def toKeras2(dir: String): String = { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It might be more readable if they can be organized in json format and use
|
||
val inputString = Net.inputShapeToString(inputShape) | ||
val act = Net.activationToString(activation) | ||
val kname = Net.nameToString(getName()) | ||
s"${Net.getName(this.getClass.getName)}" + | ||
s"(units=$outputDim" + | ||
s"$inputString" + | ||
s", use_bias=${if(bias) "True" else "False"}" + | ||
s"${kname}" + | ||
s"$act)\n" | ||
} | ||
|
||
override private[zoo] def getKerasWeights(): Array[Tensor[Float]] = { | ||
val weights = this.parameters()._1 | ||
val kWeights = Array.tabulate(weights.length)(_ => Tensor[Float]()) | ||
weights(0) = weights(0).t().contiguous() | ||
weights(0).cast[Float](kWeights(0).resizeAs(weights(0))) | ||
weights(1).cast[Float](kWeights(1).resizeAs(weights(1))) | ||
kWeights | ||
} | ||
} | ||
|
||
object Dense { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not put this method in
KerasNet
so that users can directly callmodel.saveToKeras2
instead ofNet.save....