fix loss when minibatch size is different (intel-analytics#3021)
* fix loss

* fix ut
qiuxin2012 committed Jul 1, 2020
1 parent 89992a6 commit dd96ee3
Showing 3 changed files with 30 additions and 6 deletions.
@@ -1089,9 +1089,9 @@ class Loss[@specialized(Float, Double)T: ClassTag](
       output.toTensor[T]
     }
     val loss = ev.toType[Float](criterion.forward(_output, _target))
-    val count = 1
+    val count = _target.size().head
 
-    new LossResult(loss, count)
+    new LossResult(loss * count, count)
   }
 
   override def format(): String = "Loss"
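Before this change every mini-batch reported count = 1 together with its mean loss, so when results were merged across batches each batch carried the same weight regardless of how many samples it held. The fix stores loss * count and the actual number of samples (_target.size().head), so the merged result is a true per-sample average even when mini-batch sizes differ. A minimal sketch of the difference, in plain Scala with made-up numbers (this is not the actual LossResult API):

```scala
// Sketch only: contrasting the old and new bookkeeping for per-batch losses.
// The (meanLoss, batchSize) pairs are invented for illustration.
object LossAggregationSketch extends App {
  val batches = Seq((2.0f, 10), (3.0f, 90)) // (per-batch mean loss, batch size)

  // Old behaviour: count = 1 per batch, so a 10-sample batch weighs as much as a 90-sample one.
  val unweighted = batches.map(_._1).sum / batches.size                 // 2.5

  // Fixed behaviour: accumulate loss * count and count, then divide.
  val weighted = batches.map { case (l, n) => l * n }.sum /
    batches.map(_._2).sum                                               // 2.9

  println(s"unweighted = $unweighted, weighted = $weighted")
}
```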
@@ -58,7 +58,7 @@ class EvaluatorSpec extends SparkContextLifeCycle with Matchers {
 
     result(0)._1 should be (new AccuracyResult(0, 100))
     result(1)._1 should be (new AccuracyResult(100, 100))
-    result(2)._1 should be (new LossResult(57.610695f, 25))
+    result(2)._1 should be (new LossResult(230.44278f, 100))
     result(0)._1.result()._1 should be (0f)
     result(1)._1.result()._1 should be (1f)
     result(2)._1.result()._1 should be (2.3044279f+-0.000001f)
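The asserted per-sample loss does not change; only the LossResult bookkeeping does. The old pair (57.610695f, 25) is the sum of 25 per-batch mean losses over 25 batches, while the new pair (230.44278f, 100) is the sample-weighted loss over 100 samples, and 57.610695 / 25 ≈ 230.44278 / 100 ≈ 2.3044278. The two forms only agree here because this test presumably evaluates equal-sized batches; with unequal batches, as in the new test added below, only the fixed form gives the correct mean.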
@@ -82,7 +82,31 @@ class EvaluatorSpec extends SparkContextLifeCycle with Matchers {
 
     result(0)._1 should be (new AccuracyResult(0, 100))
     result(1)._1 should be (new AccuracyResult(100, 100))
-    result(2)._1 should be (new LossResult(57.610695f, 25))
+    result(2)._1 should be (new LossResult(230.44278f, 100))
     result(0)._1.result()._1 should be (0f)
     result(1)._1.result()._1 should be (1f)
     result(2)._1.result()._1 should be (2.3044279f+-0.000001f)
   }
 
+  "Evaluator different MiniBatch" should "be correct" in {
+    RNG.setSeed(100)
+    val tmp = new Array[MiniBatch[Float]](25)
+    var i = 1
+    while (i <= tmp.length) {
+      val input = Tensor[Float](i, 28, 28).fill(0.8f)
+      val label = Tensor[Float](i).fill(1.0f)
+      tmp(i - 1) = MiniBatch(input, label)
+      i += 1
+    }
+    val model = LeNet5(classNum = 10)
+    val dataSet = DataSet.array(tmp, sc).toDistributed().data(train = false)
+
+    val result = model.evaluate(dataSet, Array(new Top1Accuracy[Float](), new Top5Accuracy[Float](),
+      new Loss[Float](CrossEntropyCriterion[Float]())))
+
+    result(0)._1 should be (new AccuracyResult(0, 325))
+    result(1)._1 should be (new AccuracyResult(325, 325))
+    result(2)._1 should be (new LossResult(748.93896f, 325))
+    result(0)._1.result()._1 should be (0f)
+    result(1)._1.result()._1 should be (1f)
+    result(2)._1.result()._1 should be (2.3044279f+-0.000001f)
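The new test is exactly the situation the fix targets: 25 mini-batches whose sizes run from 1 to 25, so the expected counts follow from the number of samples rather than the number of batches. A quick sanity check of the asserted figures (illustrative code, not part of the test suite):

```scala
// Sanity check of the expected counts and loss in the new test (illustrative only).
object DifferentBatchSizeCheck extends App {
  val totalSamples = (1 to 25).sum             // 325, the count asserted for every metric
  val perSampleLoss = 748.93896f / totalSamples
  println(s"$totalSamples samples, per-sample loss ≈ $perSampleLoss")  // ≈ 2.3044276
}
```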
@@ -142,7 +142,7 @@ class DnnPredictorSpec extends FlatSpec with Matchers with BeforeAndAfter {
     dataSet.partitions.length should be(4)
     result(0)._1 should be (new AccuracyResult(0, 100))
     result(1)._1 should be (new AccuracyResult(100, 100))
-    result(2)._1 should be (new LossResult(16.130993f, 7))
+    result(2)._1 should be (new LossResult(230.44278f, 100))
     result(0)._1.result()._1 should be (0f)
     result(1)._1.result()._1 should be (1f)
     result(2)._1.result()._1 should be (2.3044279f+-0.000001f)
@@ -167,7 +167,7 @@ class DnnPredictorSpec extends FlatSpec with Matchers with BeforeAndAfter {
     dataSet.partitions.length should be(4)
     result(0)._1 should be (new AccuracyResult(0, 100))
     result(1)._1 should be (new AccuracyResult(100, 100))
-    result(2)._1 should be (new LossResult(57.610695f, 25))
+    result(2)._1 should be (new LossResult(230.44278f, 100))
     result(0)._1.result()._1 should be (0f)
     result(1)._1.result()._1 should be (1f)
     result(2)._1.result()._1 should be (2.3044279f+-0.000001f)