forked from intel-analytics/BigDL-2.x
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add big TF model support (intel-analytics#2974)
* test * add print * add graphdef print * update * update * test * update * add print * update broadcast * fix spark file location * update broadcast * fix extra init * update broadcast * update property * update broadcast * restore broadcast * restore clone * fix clone * fix get extra * update collect weights * update * update * update property * update get extra param * update * update * restore * remove unused import * fix style * add methods * fix style
- Loading branch information
Showing
4 changed files
with
455 additions
and
27 deletions.
There are no files selected for viewing
395 changes: 395 additions & 0 deletions
395
orca/src/main/scala/com/intel/analytics/bigdl/orca/tfpark/TFModelBroadcast.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,395 @@ | ||
/* | ||
* Copyright 2018 Analytics Zoo Authors. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.intel.analytics.zoo.tfpark | ||
|
||
import java.io.{IOException, ObjectInputStream, ObjectOutputStream} | ||
|
||
import com.intel.analytics.bigdl.Module | ||
import com.intel.analytics.bigdl.models.utils.{CachedModels, ModelBroadcast, ModelInfo} | ||
import com.intel.analytics.bigdl.nn.Container | ||
import com.intel.analytics.bigdl.nn.abstractnn.Activity | ||
import com.intel.analytics.bigdl.nn.mkldnn.{MklDnnLayer, TensorMMap} | ||
import com.intel.analytics.bigdl.nn.tf.Const | ||
import com.intel.analytics.bigdl.tensor.{QuantizedTensor, QuantizedType, Storage, Tensor} | ||
import com.intel.analytics.bigdl.tensor.TensorNumericMath.{NumericWildcard, TensorNumeric} | ||
import com.intel.analytics.bigdl.nn.Module | ||
import com.intel.analytics.bigdl.optim.DistriOptimizer.CacheV1 | ||
import com.intel.analytics.bigdl.utils.Engine | ||
import com.intel.analytics.bigdl.utils.intermediate.IRGraph | ||
import com.intel.analytics.zoo.pipeline.api.keras.layers.utils.EngineRef | ||
import com.intel.analytics.zoo.tfpark.Util._ | ||
import org.apache.commons.lang3.SerializationUtils | ||
import org.apache.spark.SparkContext | ||
import org.apache.spark.broadcast.Broadcast | ||
import org.apache.spark.rdd.RDD | ||
|
||
import scala.collection.mutable.ArrayBuffer | ||
import scala.reflect.ClassTag | ||
|
||
/**
 * A [[ModelBroadcast]] implementation tailored for (potentially very large) TFPark models.
 *
 * The model structure and its parameter tensors are broadcast as separate Spark
 * broadcast variables: the weights/bias and extra parameters are stripped out of the
 * model before the structure is broadcast, which avoids serializing the large tensors
 * together with the graph and allows executors to share a single copy of the weights.
 */
class TFModelBroadcast[T: ClassTag]()
                                   (implicit ev: TensorNumeric[T]) extends ModelBroadcast[T] {
  private var broadcastModel: Broadcast[ModelInfo[T]] = _
  // Reserved for broadcasting Const values separately; currently unused.
  private var broadcastConsts: Broadcast[Map[String, Tensor[_]]] = _
  private var broadcastParameters: Broadcast[Array[Tensor[T]]] = _
  private var broadcastExtraParameters: Broadcast[Array[Tensor[T]]] = _
  private var nodeNumber: Int = _
  private var coreNumber: Int = _

  // Snapshot the cluster topology on the driver so `value` can restore the
  // core count on the executors.
  private def setNodeAndCore(): Unit = {
    nodeNumber = EngineRef.getNodeNumber()
    coreNumber = EngineRef.getCoreNumber()
  }

  /**
   * Broadcast the model.
   * First get and clear the weight and bias parameters from the model,
   * then broadcast the parameters and the model (without parameters) separately.
   *
   * @param sc    SparkContext
   * @param model model to broadcast
   * @return this
   */
  override def broadcast(sc: SparkContext, model: Module[T]): this.type = {
    CachedModels.deleteAll(uuid) // delete the models on driver

    // Strip the (large) tensors out of the model so structure and weights are
    // broadcast as separate variables.
    val weightsBias = getAndClearWeightBias(model.parameters())
    val extraParams = getAndClearExtraParameters(model.getExtraParameter())
    broadcastModel = sc.broadcast(ModelInfo[T](uuid, model))
    // FIX: broadcast the parameters exactly once. The original code called
    // `sc.broadcast(weightsBias)` twice, orphaning the first broadcast variable
    // (never destroyed) and shipping the weights to the cluster twice.
    broadcastParameters = sc.broadcast(weightsBias)
    broadcastExtraParameters = sc.broadcast(extraParams)

    // For quantized model if we don't clone weightsBias, the original model will be
    // released also when we delete all models used in `ModelBroadcast`.
    putWeightBias(cloneParameters(weightsBias), model)
    initGradWeightBias(weightsBias, model)
    putExtraParams(extraParams, model)

    setNodeAndCore()
    this
  }

  /**
   * Get the broadcast model on an executor and put the weight and bias back into it.
   *
   * @param initGradient whether to create gradient tensors when fetching the model.
   *                     Note that gradients are not needed for inference.
   * @param shareWeight  when true (default) the returned model shares the broadcast
   *                     weight storage with other local models; when false it gets
   *                     its own deep copy.
   * @return the local model
   */
  override def value(initGradient: Boolean = false, shareWeight: Boolean = true): Module[T] = {
    EngineRef.setCoreNumber(coreNumber)
    CachedModels.deleteAll(this.uuid)

    val localModel = broadcastModel.value.model.cloneModule()
    val uuid = broadcastModel.value.uuid
    CachedModels.add(uuid, localModel)

    val parameters = if (shareWeight) {
      broadcastParameters.value
    } else {
      // Deep-copy the tensors so this model does not alias the shared storage.
      SerializationUtils.clone(broadcastParameters.value)
    }
    // share weight
    putWeightBias(parameters, localModel)

    // init gradient
    if (initGradient) {
      initGradWeightBias(broadcastParameters.value, localModel)
    }

    putExtraParams(broadcastExtraParameters.value, localModel)

    localModel
  }

  // `dummyInput` is not needed by this implementation; delegate to the plain overload.
  override def broadcast(sc: SparkContext, model: Module[T],
                         dummyInput: Activity): this.type = {
    this.broadcast(sc, model)
    this
  }

  // `dummyInput` is not needed by this implementation; delegate to the plain overload.
  override def value(initGradient: Boolean, shareWeight: Boolean,
                     dummyInput: Activity): Module[T] = {
    value(initGradient, shareWeight)
  }
}
|
||
/**
 * Pairs a model with a stable `uuid` and customizes Java serialization so that:
 *  - serializing writes a *clone* of the model (the driver-side instance keeps
 *    its own weights untouched), and
 *  - every (de)serialized model is registered in [[CachedModels]] under `uuid`
 *    so it can be released later.
 */
private[zoo] class ModelInfo[T: ClassTag](val uuid: String, @transient var model: Module[T])(
  implicit ev: TensorNumeric[T]) extends Serializable {

  @throws(classOf[IOException])
  private def writeObject(out: ObjectOutputStream): Unit = {
    out.defaultWriteObject()
    // Serialize a clone rather than the live model, and track the clone for cleanup.
    val clonedModel = model.cloneModule()
    out.writeObject(clonedModel)
    CachedModels.add(uuid, clonedModel)
  }

  @throws(classOf[IOException])
  private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject()
    // `model` is @transient, so restore it explicitly and register it for cleanup.
    model = in.readObject().asInstanceOf[Module[T]]
    CachedModels.add(uuid, model)
  }
}
|
||
|
||
private[zoo] object ModelInfo {
  /** Factory shortcut for [[ModelInfo]]. */
  def apply[T: ClassTag](uuid: String, model: Module[T])(
    implicit ev: TensorNumeric[T]): ModelInfo[T] = {
    new ModelInfo[T](uuid, model)
  }
}
|
||
|
||
/**
 * Process-wide registry of models keyed by broadcast uuid, used to release
 * native resources of stale models when a new broadcast replaces them.
 * All mutating operations are synchronized on this object.
 */
private[zoo] object CachedModels {

  import java.util.concurrent.ConcurrentHashMap

  import scala.collection._
  import scala.collection.convert.decorateAsScala._
  import scala.language.existentials

  // NOTE: alias name ("Modles") kept as-is for source compatibility with existing callers.
  type Modles = ArrayBuffer[Module[_]]

  private val cachedModels: concurrent.Map[String, Modles] =
    new ConcurrentHashMap[String, Modles]().asScala

  /** Register `model` under `uuid` so it can be released later. */
  def add[T: ClassTag](uuid: String, model: Module[T])(implicit ev: TensorNumeric[T]): Unit =
    CachedModels.synchronized {
      val models = cachedModels.get(uuid) match {
        case Some(values) => values += model.asInstanceOf[Module[_]]
        case _ => ArrayBuffer(model.asInstanceOf[Module[_]])
      }
      cachedModels.put(uuid, models.asInstanceOf[Modles])
    }

  /** Release and forget every cached model except those registered under `currentKey`. */
  def deleteAll[T: ClassTag](currentKey: String)(implicit ev: TensorNumeric[T]): Unit =
    CachedModels.synchronized {
      for (key <- cachedModels.keys if key != currentKey) {
        cachedModels(key).foreach(_.release())
        cachedModels.remove(key)
      }
    }

  /** Release and forget the models registered under `key`, if any. */
  def deleteKey[T: ClassTag](key: String)(implicit ev: TensorNumeric[T]): Unit =
    CachedModels.synchronized {
      // FIX: look the key up directly instead of scanning every key in the map
      // and comparing it against `key` (the original was an accidental O(n) scan).
      cachedModels.get(key) match {
        case Some(models) =>
          models.foreach(_.release())
          cachedModels.remove(key)
        case None => // nothing cached under this key
      }
    }
}
|
||
/**
 * Helpers for detaching a model's parameter tensors from its structure (so the
 * structure can be broadcast without the weights), and for re-attaching shared
 * or cloned parameter tensors on the executors.
 */
object Util {

  /**
   * Extract the weight/bias tensors from `parameters` and release the storage of
   * both the originals and the gradients.
   *
   * @param parameters (weights, gradients) as returned by `Module.parameters()`
   * @return detached copies of the weight tensors (sharing the original storage)
   */
  private[zoo] def getAndClearWeightBias[T: ClassTag]
  (parameters: (Array[Tensor[T]], Array[Tensor[T]]))(implicit ev: TensorNumeric[T])
  : Array[Tensor[T]] = {
    // Gradients are not broadcast; just release their storage.
    clearTensor(parameters._2)
    getAndClearParameters(parameters._1)
  }

  /** Same as [[getAndClearParameters]], named for extra (non-trainable) parameters. */
  private[zoo] def getAndClearExtraParameters[T: ClassTag]
  (parameters: Array[Tensor[T]])(implicit ev: TensorNumeric[T])
  : Array[Tensor[T]] = {
    getAndClearParameters(parameters)
  }

  /**
   * Build detached tensors that view the same storage as `parameters`, then release
   * the originals' storage.
   *
   * Returns `null` when `parameters` is null, an empty array when it is empty.
   * Null entries in `parameters` produce null entries in the result.
   */
  private[zoo] def getAndClearParameters[T: ClassTag]
  (parameters: Array[Tensor[T]])(implicit ev: TensorNumeric[T])
  : Array[Tensor[T]] = {
    if (parameters != null) {
      if (parameters.length != 0) {
        val retParams = new Array[Tensor[T]](parameters.length)
        // Parameters are "compacted" when they all live contiguously in one shared
        // storage; in that case a single storage can back every returned tensor.
        val (isCompacted, storage) = {
          val storage = Storage(parameters(0).storage.array())
          (parameters.map(_.nElement()).sum == storage.length(), storage)
        }

        // get parameters
        var i = 0
        while (i < parameters.length) {
          if (parameters(i) != null) {
            val wb = parameters(i)
            retParams(i) = if (isCompacted) {
              Tensor[T](storage, wb.storageOffset(), wb.size(), wb.stride())
            } else {
              Tensor[T](Storage(wb.storage().array()), wb.storageOffset(), wb.size(), wb.stride())
            }
          }
          // FIX: advance unconditionally. The original incremented `i` only inside the
          // null check, so a single null parameter slot caused an infinite loop.
          i += 1
        }
        // clear parameters
        clearTensor(parameters)

        retParams
      } else {
        // just return an empty array when parameters is empty.
        Array()
      }
    } else {
      null
    }
  }

  /** Release the storage of every non-null tensor in `tensors` (no-op on null). */
  private def clearTensor[T: ClassTag](tensors: Array[Tensor[T]])
                                     (implicit ev: TensorNumeric[T]): Unit = {
    if (tensors != null) {
      var i = 0
      while (i < tensors.length) {
        if (tensors(i) != null) {
          tensors(i).set()
        }
        i += 1
      }
    }
  }

  /**
   * Point `localModel`'s weight/bias tensors at the broadcast tensors, positionally.
   * Null slots in the local parameter array are skipped.
   */
  private[zoo] def putWeightBias[T: ClassTag](broadcastWeightBias: Array[Tensor[T]],
                                              localModel: Module[T])(
    implicit ev: TensorNumeric[T]): Unit = {
    val localWeightBias = localModel.parameters()._1
    var i = 0
    while (i < localWeightBias.length) {
      if (localWeightBias(i) != null) {
        clearAndSet(localWeightBias(i), broadcastWeightBias(i))
      }
      i += 1
    }

    def clearAndSet(old: Tensor[T], other: Tensor[T]): Unit = {
      old.set(other)
    }
  }

  /**
   * Point `localModel`'s extra (non-trainable) parameter tensors at the broadcast
   * tensors, positionally. No-op when the model has no extra parameters.
   */
  private[zoo] def putExtraParams[T: ClassTag](broadcastExtraParams: Array[Tensor[T]],
                                               localModel: Module[T])(
    implicit ev: TensorNumeric[T]): Unit = {
    val localExtraParams = localModel.getExtraParameter()
    if (localExtraParams != null) {
      var i = 0
      while (i < localExtraParams.length) {
        if (localExtraParams(i) != null) {
          localExtraParams(i).set(broadcastExtraParams(i))
        }
        i += 1
      }
    }
  }

  /**
   * Allocate gradient tensors for `localModel`, backed by one compact shared storage,
   * laid out with the same offsets/sizes/strides as the broadcast weights.
   * Quantized weights get a placeholder 1-element gradient instead.
   */
  private[zoo] def initGradWeightBias[T: ClassTag](broadcastWeightBias: Array[Tensor[T]],
                                                   localModel: Module[T])(
    implicit ev: TensorNumeric[T]): Unit = {
    val (localWeightBias, localGradWeightBias) = localModel.parameters()
    // init gradient with a compacted storage
    val storage = Storage[T](localGradWeightBias.map(_.nElement()).sum)
    val isQuantized = broadcastWeightBias.exists(_.getTensorType == QuantizedType)
    var i = 0
    while (i < localWeightBias.length) {
      if (localWeightBias(i) != null) {
        val wb = broadcastWeightBias(i)
        wb.getTensorType match {
          case QuantizedType =>
            // Quantized weights have no meaningful gradient layout; use a placeholder.
            localGradWeightBias(i).set(Tensor(1))
          case _ =>
            localGradWeightBias(i).set(storage, wb.storageOffset(), wb.size(), wb.stride())
        }
      }
      i += 1
    }
  }

  /**
   * Deep-copy an array of parameter tensors. When the originals share one compact
   * storage, the copies share a single freshly-allocated compact storage too;
   * otherwise each tensor is cloned independently.
   *
   * Returns `null` when `parameters` is null, an empty array when it is empty.
   * Null entries produce null entries in the result.
   */
  private[zoo] def cloneParameters[T: ClassTag]
  (parameters: Array[Tensor[T]])(implicit ev: TensorNumeric[T])
  : Array[Tensor[T]] = {
    if (parameters != null) {
      if (parameters.length != 0) {
        val retParams = new Array[Tensor[T]](parameters.length)

        val (isCompacted, storage) = {
          val storage = Storage(parameters(0).storage.array())
          (parameters.map(_.nElement()).sum == storage.length(), storage)
        }

        // For compacted parameters, copy the whole backing storage once.
        val resultStorage = if (isCompacted) {
          val resultStorage = Storage[T](storage.length())
          System.arraycopy(storage.array(), parameters(0).storageOffset() - 1,
            resultStorage.array(), 0, storage.length())
          resultStorage
        } else {
          null
        }

        // clone parameters
        var i = 0
        while (i < parameters.length) {
          if (parameters(i) != null) {
            val wb = parameters(i)
            retParams(i) = if (isCompacted) {
              Tensor[T](resultStorage, wb.storageOffset(), wb.size(), wb.stride())
            } else {
              wb.clone()
            }
          }
          // FIX: advance unconditionally. The original incremented `i` only inside the
          // null check, so a single null parameter slot caused an infinite loop.
          i += 1
        }

        retParams
      } else {
        // just return an empty array when parameters is empty.
        Array()
      }
    } else {
      null
    }
  }
}
Oops, something went wrong.