diff --git a/orca/src/main/scala/com/intel/analytics/bigdl/orca/tfpark/TFModelBroadcast.scala b/orca/src/main/scala/com/intel/analytics/bigdl/orca/tfpark/TFModelBroadcast.scala index 2865890d6ff..4e78edaa953 100644 --- a/orca/src/main/scala/com/intel/analytics/bigdl/orca/tfpark/TFModelBroadcast.scala +++ b/orca/src/main/scala/com/intel/analytics/bigdl/orca/tfpark/TFModelBroadcast.scala @@ -16,7 +16,7 @@ package com.intel.analytics.zoo.tfpark -import java.io.{IOException, ObjectInputStream, ObjectOutputStream} +import java.io.{IOException, ObjectInputStream, ObjectOutputStream, InvalidClassException, InputStream, ObjectStreamClass} import com.intel.analytics.bigdl.Module import com.intel.analytics.bigdl.models.utils.{CachedModels, ModelBroadcast, ModelInfo} @@ -117,7 +117,7 @@ class TFModelBroadcast[T: ClassTag]() } else { SerializationUtils.clone(broadcastParameters.value) } - // + // share weight putWeightBias(parameters, localModel) @@ -125,7 +125,6 @@ class TFModelBroadcast[T: ClassTag]() // if (localModel.isInstanceOf[Container[_, _, T]] && broadcastConsts.value.nonEmpty) { // putConsts(localModel.asInstanceOf[Container[_, _, T]], broadcastConsts.value) // } - // init gradient if (initGradient) { initGradWeightBias(broadcastParameters.value, localModel) } @@ -161,17 +160,28 @@ private[zoo] class ModelInfo[T: ClassTag](val uuid: String, @transient var model @throws(classOf[IOException]) private def readObject(in: ObjectInputStream): Unit = { in.defaultReadObject() - model = in.readObject().asInstanceOf[Module[T]] + val vin = new ModelInfoObjectInputStream(in) + model = vin.readObject().asInstanceOf[Module[T]] CachedModels.add(uuid, model) } } - private[zoo] object ModelInfo { def apply[T: ClassTag](uuid: String, model: Module[T])( implicit ev: TensorNumeric[T]): ModelInfo[T] = new ModelInfo[T](uuid, model) } +private class ModelInfoObjectInputStream[T: ClassTag](val inputStream: InputStream) + extends ObjectInputStream(inputStream) { + @throws[IOException] + @throws[ClassNotFoundException] + override protected def resolveClass(desc: ObjectStreamClass): Class[_] = { + if (!desc.getName.equals(classOf[ModelInfo[T]].getName)) { + throw new InvalidClassException("Unrecognized Class", desc.getName) + } + super.resolveClass(desc) + } +} private[zoo] object CachedModels { @@ -183,7 +193,7 @@ private[zoo] object CachedModels { type Modles = ArrayBuffer[Module[_]] - private val cachedModels: concurrent.Map[String, Modles] = + private val cachedModels = new ConcurrentHashMap[String, Modles]().asScala def add[T: ClassTag](uuid: String, model: Module[T])(implicit ev: TensorNumeric[T]): Unit =