Skip to content

Commit

Permalink
add load TF saved model as TFNet inference model (intel-analytics#1745)
Browse files Browse the repository at this point in the history
  • Loading branch information
glorysdj committed Nov 11, 2019
1 parent 6cd5916 commit ed3679c
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,18 +93,18 @@ class InferenceModel(private var autoScalingEnabled: Boolean = true,
}

/**
* loads a TF model as TFNet
* loads a TF frozen model as TFNet
*
* @param modelPath the path of the tensorflow model file
* @param modelPath the path of the tensorflow frozen model file
*/
def doLoadTF(modelPath: String): Unit = {
doLoadTensorflowModel(modelPath, 1, 1, true)
}

/**
* loads a TF model as TFNet
* loads a TF frozen model as TFNet
*
* @param modelPath the path of the tensorflow model
* @param modelPath the path of the tensorflow frozen model
* @param intraOpParallelismThreads the num of intraOpParallelismThreads
* @param interOpParallelismThreads the num of interOpParallelismThreads
* @param usePerSessionThreads whether to use per-session threads
Expand All @@ -120,6 +120,42 @@ class InferenceModel(private var autoScalingEnabled: Boolean = true,
usePerSessionThreads)
}

/**
 * Loads a TensorFlow saved model as a TFNet using default session settings
 * (single intra-op thread, single inter-op thread, per-session threads on).
 *
 * @param modelPath the path of the TensorFlow saved model directory
 * @param inputs    the input tensor names of the model
 * @param outputs   the output tensor names of the model
 */
def doLoadTF(modelPath: String, inputs: Array[String], outputs: Array[String]): Unit = {
  doLoadTensorflowSavedModel(
    modelPath,
    inputs,
    outputs,
    intraOpParallelismThreads = 1,
    interOpParallelismThreads = 1,
    usePerSessionThreads = true)
}

/**
 * Loads a TensorFlow saved model as a TFNet with explicit session threading options.
 *
 * @param modelPath the path of the TensorFlow saved model directory
 * @param inputs the input tensor names of the model
 * @param outputs the output tensor names of the model
 * @param intraOpParallelismThreads the number of intra-op parallelism threads
 * @param interOpParallelismThreads the number of inter-op parallelism threads
 * @param usePerSessionThreads whether to use per-session threads
 */
def doLoadTF(modelPath: String,
             inputs: Array[String],
             outputs: Array[String],
             intraOpParallelismThreads: Int,
             interOpParallelismThreads: Int,
             usePerSessionThreads: Boolean): Unit = {
  // Delegate straight to the private loader; all parameters pass through unchanged.
  doLoadTensorflowSavedModel(modelPath, inputs, outputs,
    intraOpParallelismThreads, interOpParallelismThreads, usePerSessionThreads)
}

/**
* loads a TF model as OpenVINO
*
Expand Down Expand Up @@ -362,6 +398,19 @@ class InferenceModel(private var autoScalingEnabled: Boolean = true,
offerModelQueue()
}

/**
 * Swaps the currently held model for a TFNet built from a TF saved model.
 * Clears the shared model queue first, then re-populates it from the new model.
 *
 * @param modelPath the path of the TensorFlow saved model directory
 * @param inputs the input tensor names of the model
 * @param outputs the output tensor names of the model
 * @param intraOpParallelismThreads the number of intra-op parallelism threads
 * @param interOpParallelismThreads the number of inter-op parallelism threads
 * @param usePerSessionThreads whether to use per-session threads
 */
private def doLoadTensorflowSavedModel(modelPath: String,
                                       inputs: Array[String],
                                       outputs: Array[String],
                                       intraOpParallelismThreads: Int,
                                       interOpParallelismThreads: Int,
                                       usePerSessionThreads: Boolean): Unit = {
  // Drop any previously queued model copies before loading the replacement.
  clearModelQueue()
  val loaded = InferenceModelFactory.loadFloatModelForTFSavedModel(
    modelPath, inputs, outputs,
    intraOpParallelismThreads, interOpParallelismThreads, usePerSessionThreads)
  this.originalModel = loaded
  // Refill the queue with copies of the freshly loaded model.
  offerModelQueue()
}

private def doLoadTensorflowModelAsOpenVINO(modelPath: String,
modelType: String,
pipelineConfigPath: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,21 @@ object InferenceModelFactory extends InferenceSupportive {
new FloatModel(model, metaModel, true)
}

/**
 * Builds a [[FloatModel]] from a TensorFlow saved model directory.
 * The loaded module is switched to evaluate (inference) mode before wrapping.
 *
 * @param modelPath the path of the TensorFlow saved model directory
 * @param inputs the input tensor names of the model
 * @param outputs the output tensor names of the model
 * @param intraOpParallelismThreads the number of intra-op parallelism threads
 * @param interOpParallelismThreads the number of inter-op parallelism threads
 * @param usePerSessionThreads whether to use per-session threads
 */
def loadFloatModelForTFSavedModel(modelPath: String,
                                  inputs: Array[String],
                                  outputs: Array[String],
                                  intraOpParallelismThreads: Int = 1,
                                  interOpParallelismThreads: Int = 1,
                                  usePerSessionThreads: Boolean = true): FloatModel = {
  val config = TFNet.SessionConfig(
    intraOpParallelismThreads, interOpParallelismThreads, usePerSessionThreads)
  val loaded = ModelLoader.loadFloatModelForTFSavedModel(modelPath, inputs, outputs, config)
  // Put the module in inference mode before it is shared.
  loaded.evaluate()
  new FloatModel(loaded, makeMetaModel(loaded), true)
}


def loadOpenVINOModelForTF(modelPath: String,
modelType: String,
pipelineConfigPath: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,18 @@ object ModelLoader extends InferenceSupportive {
model
}
}

/**
 * Loads a TensorFlow saved model from `modelPath` as a TFNet module,
 * logging and timing the load.
 *
 * @param modelPath the path of the TensorFlow saved model directory
 * @param inputs the input tensor names of the model
 * @param outputs the output tensor names of the model
 * @param config session configuration for the TF session.
 *               NOTE(review): `config` is currently ignored — the body calls
 *               `TFNet.fromSavedModel(modelPath, inputs, outputs)` without it;
 *               confirm whether `fromSavedModel` should receive the config.
 * @return the loaded module as an `AbstractModule[Activity, Activity, Float]`
 */
def loadFloatModelForTFSavedModel(modelPath: String,
inputs: Array[String],
outputs: Array[String],
config: TFNet.SessionConfig = TFNet.defaultSessionConfig)
: AbstractModule[Activity, Activity, Float] = {
timing("load model") {
logger.info(s"load model from $modelPath")
// NOTE(review): `config` is not forwarded here — see the scaladoc above.
val model = TFNet.fromSavedModel(modelPath, inputs, outputs)
logger.info(s"loaded model as $model")
model
}
}
}

0 comments on commit ed3679c

Please sign in to comment.