Skip to content

Commit

Permalink
Fix security issue in TFModelBroadcast (intel-analytics#3141)
Browse files Browse the repository at this point in the history
* update model broadcast

* update

* fix style

* fix style

* update
  • Loading branch information
jenniew committed Nov 29, 2020
1 parent 18d76e3 commit 2169b74
Showing 1 changed file with 16 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package com.intel.analytics.zoo.tfpark

import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
import java.io.{IOException, ObjectInputStream, ObjectOutputStream, InvalidClassException, InputStream, ObjectStreamClass}

import com.intel.analytics.bigdl.Module
import com.intel.analytics.bigdl.models.utils.{CachedModels, ModelBroadcast, ModelInfo}
Expand Down Expand Up @@ -117,15 +117,14 @@ class TFModelBroadcast[T: ClassTag]()
} else {
SerializationUtils.clone(broadcastParameters.value)
}
//

// share weight
putWeightBias(parameters, localModel)

// // share Consts
// if (localModel.isInstanceOf[Container[_, _, T]] && broadcastConsts.value.nonEmpty) {
// putConsts(localModel.asInstanceOf[Container[_, _, T]], broadcastConsts.value)
// }
// init gradient
if (initGradient) {
initGradWeightBias(broadcastParameters.value, localModel)
}
Expand Down Expand Up @@ -161,17 +160,28 @@ private[zoo] class ModelInfo[T: ClassTag](val uuid: String, @transient var model
@throws(classOf[IOException])
private def readObject(in: ObjectInputStream): Unit = {
in.defaultReadObject()
model = in.readObject().asInstanceOf[Module[T]]
val vin = new ModelInfoObjectInputStream(in)
model = vin.readObject().asInstanceOf[Module[T]]
CachedModels.add(uuid, model)
}
}


private[zoo] object ModelInfo {
def apply[T: ClassTag](uuid: String, model: Module[T])(
implicit ev: TensorNumeric[T]): ModelInfo[T] = new ModelInfo[T](uuid, model)
}

private class ModelInfoObjectInputStream[T: ClassTag](val inputStream: InputStream)
extends ObjectInputStream(inputStream) {
@throws[IOException]
@throws[ClassNotFoundException]
override protected def resolveClass(desc: ObjectStreamClass): Class[_] = {
if (!desc.getName.equals(classOf[ModelInfo[T]].getName)) {
throw new InvalidClassException("Unrecognized Class", desc.getName)
}
super.resolveClass(desc)
}
}

private[zoo] object CachedModels {

Expand All @@ -183,7 +193,7 @@ private[zoo] object CachedModels {

type Modles = ArrayBuffer[Module[_]]

private val cachedModels: concurrent.Map[String, Modles] =
private val cachedModels =
new ConcurrentHashMap[String, Modles]().asScala

def add[T: ClassTag](uuid: String, model: Module[T])(implicit ev: TensorNumeric[T]): Unit =
Expand Down

0 comments on commit 2169b74

Please sign in to comment.