Skip to content

Commit

Permalink
move pretrain in DistriOptimizerV2 (intel-analytics#3016)
Browse files: browse the repository at this point in the history
* move getData (the mini-batch is now fetched in the timed getWeights block and passed into train)

* rename (preTrain -> fetchBatch)

* remove time counting (drop the `syWStart = System.nanoTime()` timestamp from the former preTrain)
  • Loading branch information
qiuxin2012 committed Jun 19, 2020
1 parent 0cf6e37 commit 6792cd6
Showing 1 changed file with 12 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -186,12 +186,17 @@ object DistriOptimizerV2 extends AbstractOptimizer {
val size = cached.parameter.size
val weights = cached.modelWeights.head.narrow(1, offset, size)

TrainingTrace.time (
cached.parameter.getWeights(weights).waitResult(),
val miniBatchBuffer = TrainingTrace.time (
{
val weightsResults = cached.parameter.getWeights(weights)
val batch = context.fetchBatch(data)
weightsResults.waitResult()
batch
},
metrics
)(Array(GET_WEIGHTS_AVERAGE, GET_WEIGHTS_EACH_NODE))

val results = train(cached, data, context, metrics)
val results = train(cached, miniBatchBuffer, context, metrics)

lossSum += results.loss
recordsNum += results.records
Expand Down Expand Up @@ -375,15 +380,14 @@ object DistriOptimizerV2 extends AbstractOptimizer {
/**
 * Value returned by a single `train` call.
 *
 * @param successed NOTE(review): presumably the count of sub-model training tasks
 *                  that completed successfully -- confirm against the body of `train`
 * @param loss      loss reported for the call; the caller adds it to its running `lossSum`
 * @param records   record count reported for the call; the caller adds it to its
 *                  running `recordsNum`
 */
private case class TrainingResults(
  successed: Int,
  loss: Double,
  records: Int
)
private def train[T: ClassTag](
cached: Cache[T],
data: Iterator[MiniBatch[T]],
data: Array[MiniBatch[T]],
context: TrainingContext[T],
metrics: Metrics)(implicit ev: TensorNumeric[T]): TrainingResults = {
val miniBatchBuffer = context.preTrain(data)
val stackSize = miniBatchBuffer.head.size()
val stackSize = data.head.size()

// ======================Start train models===================================
val modelsResult = TrainingTrace.time (
context.train(miniBatchBuffer, cached.localModels, cached.localCriterions),
context.train(data, cached.localModels, cached.localCriterions),
metrics
)(Array(COMPUTING_TIME_EACH_NODE, COMPUTING_TIME_AVERAGE))

Expand Down Expand Up @@ -895,8 +899,7 @@ class TrainingContext[T: ClassTag](
recordsProcessed >= numSamples
}

def preTrain[T: ClassTag](data: Iterator[MiniBatch[T]]): Array[MiniBatch[T]] = {
val syWStart = System.nanoTime()
def fetchBatch[T: ClassTag](data: Iterator[MiniBatch[T]]): Array[MiniBatch[T]] = {
val miniBatchBuffer = new Array[MiniBatch[T]](subModelNumber)
val batch = data.next()
val stackSize = batch.size() / subModelNumber
Expand Down

0 comments on commit 6792cd6

Please sign in to comment.