diff --git a/dl/src/main/scala/com/intel/analytics/bigdl/common/Util.scala b/dl/src/main/scala/com/intel/analytics/bigdl/common/Util.scala index 81337e07c47..3b539b6fcd7 100644 --- a/dl/src/main/scala/com/intel/analytics/bigdl/common/Util.scala +++ b/dl/src/main/scala/com/intel/analytics/bigdl/common/Util.scala @@ -19,7 +19,7 @@ package com.intel.analytics.bigdl.utils import java.io._ import com.intel.analytics.bigdl._ -import com.intel.analytics.bigdl.nn.{Container, Graph} +import com.intel.analytics.bigdl.nn.Container import com.intel.analytics.bigdl.nn.tf.Const import com.intel.analytics.bigdl.tensor.TensorNumericMath.{NumericWildcard, TensorNumeric} import com.intel.analytics.bigdl.tensor._ @@ -166,6 +166,10 @@ object Util { var i = 0 while (i < tensors.length) { if (tensors(i) != null) { + if (tensors(i).getTensorType == QuantizedType) { + tensors(i).toQuantizedTensor.release() + } + tensors(i).set() } i += 1 @@ -179,10 +183,23 @@ object Util { var i = 0 while (i < localWeightBias.length) { if (localWeightBias(i) != null) { - localWeightBias(i).set(broadcastWeightBias(i)) + clearAndSet(localWeightBias(i), broadcastWeightBias(i)) } i += 1 } + + def clearAndSet(old: Tensor[T], other: Tensor[T]): Unit = { + if (old.getTensorType == QuantizedType && other.getTensorType == QuantizedType) { + val quantOld = old.asInstanceOf[QuantizedTensor[T]] + val quantOther = other.asInstanceOf[QuantizedTensor[T]] + + if (quantOld.getNativeStorage != quantOther.getNativeStorage) { + quantOld.release() + } + } + + old.set(other) + } } private[bigdl] def initGradWeightBias[T: ClassTag](