diff --git a/bigdl/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/optim/DistriOptimizerV2.scala b/bigdl/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/optim/DistriOptimizerV2.scala index cdaa781f773..a9ea527c3e8 100644 --- a/bigdl/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/optim/DistriOptimizerV2.scala +++ b/bigdl/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/optim/DistriOptimizerV2.scala @@ -186,12 +186,17 @@ object DistriOptimizerV2 extends AbstractOptimizer { val size = cached.parameter.size val weights = cached.modelWeights.head.narrow(1, offset, size) - TrainingTrace.time ( - cached.parameter.getWeights(weights).waitResult(), + val miniBatchBuffer = TrainingTrace.time ( + { + val weightsResults = cached.parameter.getWeights(weights) + val batch = context.fetchBatch(data) + weightsResults.waitResult() + batch + }, metrics )(Array(GET_WEIGHTS_AVERAGE, GET_WEIGHTS_EACH_NODE)) - val results = train(cached, data, context, metrics) + val results = train(cached, miniBatchBuffer, context, metrics) lossSum += results.loss recordsNum += results.records @@ -375,15 +380,14 @@ object DistriOptimizerV2 extends AbstractOptimizer { private case class TrainingResults(successed: Int, loss: Double, records: Int) private def train[T: ClassTag]( cached: Cache[T], - data: Iterator[MiniBatch[T]], + data: Array[MiniBatch[T]], context: TrainingContext[T], metrics: Metrics)(implicit ev: TensorNumeric[T]): TrainingResults = { - val miniBatchBuffer = context.preTrain(data) - val stackSize = miniBatchBuffer.head.size() + val stackSize = data.head.size() // ======================Start train models=================================== val modelsResult = TrainingTrace.time ( - context.train(miniBatchBuffer, cached.localModels, cached.localCriterions), + context.train(data, cached.localModels, cached.localCriterions), metrics )(Array(COMPUTING_TIME_EACH_NODE, COMPUTING_TIME_AVERAGE)) @@ -895,8 +899,7 @@ class TrainingContext[T: ClassTag]( recordsProcessed >= numSamples } - def preTrain[T: ClassTag](data: Iterator[MiniBatch[T]]): Array[MiniBatch[T]] = { - val syWStart = System.nanoTime() + def fetchBatch[T: ClassTag](data: Iterator[MiniBatch[T]]): Array[MiniBatch[T]] = { val miniBatchBuffer = new Array[MiniBatch[T]](subModelNumber) val batch = data.next() val stackSize = batch.size() / subModelNumber