diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/dllib/zooKeras/metrics/AUC.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/dllib/zooKeras/metrics/AUC.scala index c3cc0e62e9b..61d61ebcfb3 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/dllib/zooKeras/metrics/AUC.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/dllib/zooKeras/metrics/AUC.scala @@ -131,15 +131,15 @@ class AUC[T](thresholdNum: Int = 200)(implicit ev: TensorNumeric[T]) override def apply(output: Activity, target: Activity): ValidationResult = { val _output = if (output.asInstanceOf[Tensor[T]].dim() == 2) { - output.asInstanceOf[Tensor[T]].squeeze(2) + output.asInstanceOf[Tensor[T]].clone().squeeze(2) } else { - output.asInstanceOf[Tensor[T]].squeeze() + output.asInstanceOf[Tensor[T]].clone().squeeze() } val _target = if (target.asInstanceOf[Tensor[T]].dim() == 2) { - target.asInstanceOf[Tensor[T]].squeeze(2) + target.asInstanceOf[Tensor[T]].clone().squeeze(2) } else { - target.asInstanceOf[Tensor[T]].squeeze() + target.asInstanceOf[Tensor[T]].clone().squeeze() } require(_output.dim() <= 2 && _target.dim() <= 2, s"${_output.dim()} dim format is not supported")