Skip to content

Commit

Permalink
Fix AUC changing output and target (intel-analytics#2541)
Browse files Browse the repository at this point in the history
  • Loading branch information
hkvision committed Jul 8, 2020
1 parent c710643 commit f05594e
Showing 1 changed file with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,15 @@ class AUC[T](thresholdNum: Int = 200)(implicit ev: TensorNumeric[T])
override def apply(output: Activity, target: Activity):
ValidationResult = {
val _output = if (output.asInstanceOf[Tensor[T]].dim() == 2) {
output.asInstanceOf[Tensor[T]].squeeze(2)
output.asInstanceOf[Tensor[T]].clone().squeeze(2)
} else {
output.asInstanceOf[Tensor[T]].squeeze()
output.asInstanceOf[Tensor[T]].clone().squeeze()
}

val _target = if (target.asInstanceOf[Tensor[T]].dim() == 2) {
target.asInstanceOf[Tensor[T]].squeeze(2)
target.asInstanceOf[Tensor[T]].clone().squeeze(2)
} else {
target.asInstanceOf[Tensor[T]].squeeze()
target.asInstanceOf[Tensor[T]].clone().squeeze()
}
require(_output.dim() <= 2 && _target.dim() <= 2,
s"${_output.dim()} dim format is not supported")
Expand Down

0 comments on commit f05594e

Please sign in to comment.