Skip to content

Commit

Permalink
hotfix ClassNLLCriterion with cloned target (#3081)
Browse files Browse the repository at this point in the history
* hotfix ClassNLLCriterion with cloned target
  • Loading branch information
Le-Zheng committed Nov 23, 2020
1 parent 0b91c45 commit 2a25887
Showing 1 changed file with 7 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,15 @@ class ClassNLLCriterion[@specialized(Float, Double) T: ClassTag]
s"input dim(${input.dim()})")
val nClasses = input.size(input.dim())
if (input.dim() == 1) {
if (target.dim() == 2 && target.size(1) == 1) {
target.squeeze()
val newTarget = if (target.dim() == 2 && target.size(1) == 1) {
target.clone().squeeze()
} else {
target
}
require(input.dim() == target.dim(),
require(input.dim() == newTarget.dim(),
"ClassNLLCriterion: " + ErrorInfo.constrainInputDimSameAsTarget +
s" Input dimension is: ${ input.dim() } , target dimension is: ${ target.dim() }")
val curTarget = ev.toType[Int](target.valueAt(1))
s" Input dimension is: ${ input.dim() } , target dimension is: ${ newTarget.dim() }")
val curTarget = ev.toType[Int](newTarget.valueAt(1))
assert(curTarget >= 1 && curTarget <= nClasses || curTarget == paddingValue,
s"curTarget ${curTarget} is out of range, should be 1 to ${nClasses}")
total_weight = if (weights != null) weights(Array(curTarget)) else ev.fromType[Int](1)
Expand Down

0 comments on commit 2a25887

Please sign in to comment.