From c5c2582a07d9f6f26543725ddfa8b986b16628cb Mon Sep 17 00:00:00 2001 From: zhangxiaoli73 <380761639@qq.com> Date: Mon, 28 Oct 2019 16:07:27 +0800 Subject: [PATCH] add maskrcnn inference example (#2944) * add maskrcnn inference example * meet pr comments * add model download url --- .../com/intel/analytics/bigdl/dllib/optim/Evaluator.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dl/src/main/scala/com/intel/analytics/bigdl/dllib/optim/Evaluator.scala b/dl/src/main/scala/com/intel/analytics/bigdl/dllib/optim/Evaluator.scala index 6089286f5ee..de091ecf689 100644 --- a/dl/src/main/scala/com/intel/analytics/bigdl/dllib/optim/Evaluator.scala +++ b/dl/src/main/scala/com/intel/analytics/bigdl/dllib/optim/Evaluator.scala @@ -88,14 +88,14 @@ class Evaluator[T: ClassTag] private[optim](model: Module[T])(implicit ev: Tenso vMethods: Array[ValidationMethod[T]] ): Array[(ValidationResult, ValidationMethod[T])] = { - val dummyInput = dataset.takeSample(withReplacement = false, num = 1).head.getInput() val rdd = ConversionUtils.coalesce(dataset) val modelBroad = ModelBroadcast[T]().broadcast(rdd.sparkContext, - ConversionUtils.convert(model.evaluate()), dummyInput) + ConversionUtils.convert(model.evaluate())) val otherBroad = rdd.sparkContext.broadcast(vMethods) + rdd.mapPartitions(miniBatch => { - val localModel = modelBroad.value(false, true, dummyInput) + val localModel = modelBroad.value() val localMethod = otherBroad.value miniBatch.map(batch => { val output = localModel.forward(batch.getInput())