-
Notifications
You must be signed in to change notification settings - Fork 729
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Incremental Training for imagenet #1391
Changes from 15 commits
0bf1ebe
39aaea7
a3e1ca7
f01d657
a020f63
ba531d5
65135bd
401819f
918ba2f
4f9a1e2
4825759
41fe249
537a463
c1adca0
a316ae3
285c184
033cc4d
32b0fae
caa46e8
c07cced
f40c529
3c4f518
1abf724
f9efc2c
d3476c7
2fc958f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
package com.intel.analytics.zoo.common | ||
|
||
import com.intel.analytics.bigdl.optim.Trigger | ||
import com.intel.analytics.bigdl.utils.{T, Table} | ||
|
||
/** | ||
* A trigger specifies a timespot or several timespots during training, | ||
* and a corresponding action will be taken when the timespot(s) | ||
* is reached. | ||
*/ | ||
trait ZooTrigger extends Trigger { | ||
protected var zooState: Table = T() | ||
|
||
/** | ||
* We also hold some training metrics to control trigger. | ||
* @param zooState zoo state table | ||
*/ | ||
private[zoo] def setZooState(zooState: Table): Unit = { | ||
this.zooState = zooState | ||
} | ||
} | ||
|
||
/** | ||
* A trigger that triggers an action when each epoch finishs. | ||
* Could be used as trigger in setValidation and setCheckpoint | ||
* in Optimizer, and also in TrainSummary.setSummaryTrigger. | ||
*/ | ||
case class EveryEpoch() extends ZooTrigger{ | ||
private var lastEpoch = -1 | ||
|
||
override def apply(state: Table): Boolean = { | ||
if (lastEpoch == -1) { | ||
lastEpoch = state[Int]("epoch") | ||
false | ||
} else { | ||
if (state[Int]("epoch") <= lastEpoch) { | ||
false | ||
} else { | ||
if (zooState.contains("numSlice") && zooState.contains("currentSlice") | ||
&& zooState[Int]("currentSlice") % zooState[Int]("numSlice") == 0) { | ||
lastEpoch = state[Int]("epoch") | ||
true | ||
} else { | ||
false | ||
} | ||
} | ||
} | ||
} | ||
} | ||
/** | ||
* A trigger that triggers an action every "n" iterations. | ||
* Could be used as trigger in setValidation and setCheckpoint | ||
* in Optimizer, and also in TrainSummary.setSummaryTrigger. | ||
* | ||
* @param interval - trigger interval "n" | ||
*/ | ||
case class SeveralIteration(interval: Int) extends ZooTrigger{ | ||
override def apply(state: Table): Boolean = { | ||
val curIteration = state[Int]("neval") - 1 | ||
curIteration != 0 && curIteration % interval == 0 | ||
} | ||
} | ||
|
||
/** | ||
* A trigger that triggers an action when training reaches | ||
* the number of epochs specified by "max". | ||
* Usually used in Optimizer.setEndWhen. | ||
* | ||
* @param max the epoch when the action takes place | ||
*/ | ||
case class MaxEpoch(max: Int) extends ZooTrigger{ | ||
override def apply(state: Table): Boolean = { | ||
state[Int]("epoch") > max | ||
} | ||
} | ||
|
||
/** | ||
* A trigger that triggers an action when training reaches | ||
* the number of iterations specified by "max". | ||
* Usually used in Optimizer.setEndWhen. | ||
* | ||
* @param max the iteration when the action takes place | ||
* | ||
*/ | ||
case class MaxIteration(max: Int) extends ZooTrigger { | ||
override def apply(state: Table): Boolean = { | ||
state[Int]("neval") > max | ||
} | ||
} | ||
|
||
/** | ||
* A trigger that triggers an action when validation score larger than "max" score | ||
* @param max max score | ||
*/ | ||
case class MaxScore(max: Float) extends ZooTrigger { | ||
override def apply(state: Table): Boolean = { | ||
state[Float]("score") > max | ||
} | ||
} | ||
|
||
/** | ||
* A trigger that triggers an action when training loss less than "min" loss | ||
* @param min min loss | ||
*/ | ||
case class MinLoss(min: Float) extends ZooTrigger { | ||
override def apply(state: Table): Boolean = { | ||
state[Float]("Loss") < min | ||
} | ||
} | ||
|
||
/** | ||
* A trigger contains other triggers and triggers when all of them trigger (logical AND) | ||
* @param first first trigger | ||
* @param others others triggers | ||
*/ | ||
case class And(first : ZooTrigger, others : ZooTrigger*) extends ZooTrigger { | ||
override def setZooState(zooState: Table): Unit = { | ||
super.setZooState(zooState) | ||
first.setZooState(zooState) | ||
others.foreach{zt => | ||
zt.setZooState(zooState) | ||
} | ||
} | ||
|
||
override def apply(state: Table): Boolean = { | ||
first.apply(state) && others.forall(_.apply(state)) | ||
} | ||
} | ||
|
||
/** | ||
* A trigger contains other triggers and triggers when any of them trigger (logical OR) | ||
* @param first first trigger | ||
* @param others others triggers | ||
*/ | ||
case class Or(first : ZooTrigger, others : ZooTrigger*) extends ZooTrigger { | ||
override def setZooState(zooState: Table): Unit = { | ||
super.setZooState(zooState) | ||
first.setZooState(zooState) | ||
others.foreach{zt => | ||
zt.setZooState(zooState) | ||
} | ||
} | ||
|
||
override def apply(state: Table): Boolean = { | ||
first.apply(state) || others.exists(_.apply(state)) | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,14 +19,13 @@ package com.intel.analytics.zoo.examples.inception | |
import java.nio.ByteBuffer | ||
|
||
import com.intel.analytics.bigdl.dataset._ | ||
import com.intel.analytics.bigdl.dataset.image.CropCenter | ||
import com.intel.analytics.bigdl.dataset.image.{BGRImgCropper, BGRImgNormalizer, BytesToBGRImg, MTLabeledBGRImgToBatch, HFlip => DatasetHFlip} | ||
import com.intel.analytics.bigdl.dataset.image.{BGRImgCropper, BGRImgNormalizer, BGRImgToSample, BytesToBGRImg, CropCenter, MTLabeledBGRImgToBatch, HFlip => DatasetHFlip} | ||
import com.intel.analytics.bigdl.tensor.Tensor | ||
import com.intel.analytics.bigdl.transform.vision.image._ | ||
import com.intel.analytics.bigdl.utils.{Engine, T} | ||
import com.intel.analytics.zoo.feature.image._ | ||
import com.intel.analytics.zoo.feature.{DistributedFeatureSet, FeatureSet} | ||
import com.intel.analytics.zoo.feature.pmem.{DRAM, MemoryType, PMEM} | ||
import com.intel.analytics.zoo.feature.pmem._ | ||
import com.intel.analytics.zoo.pipeline.api.keras.layers.utils.EngineRef | ||
import org.apache.hadoop.io.Text | ||
import org.apache.log4j.Logger | ||
|
@@ -78,27 +77,24 @@ object ImageNet2012 { | |
coresPerNode: Int, | ||
classNumber: Int, | ||
memoryType: MemoryType = DRAM, | ||
opencvPreprocessing: Boolean = false | ||
opencvPreprocessing: Boolean = false, | ||
dataStrategy: DataStrategy = PARTITIONED | ||
) | ||
: FeatureSet[MiniBatch[Float]] = { | ||
if (opencvPreprocessing) { | ||
logger.info("Using opencv preprocessing for training set") | ||
opencv(path, sc, imageSize, batchSize, | ||
nodeNumber, coresPerNode, classNumber, memoryType) | ||
nodeNumber, coresPerNode, classNumber, memoryType, dataStrategy) | ||
} else { | ||
val rawData = readFromSeqFiles(path, sc, classNumber) | ||
.setName("ImageNet2012 Training Set") | ||
val featureSet = FeatureSet.rdd(rawData, memoryType = memoryType) | ||
featureSet.transform( | ||
MTLabeledBGRImgToBatch[ByteRecord]( | ||
width = imageSize, | ||
height = imageSize, | ||
batchSize = batchSize, | ||
transformer = (BytesToBGRImg() | ||
val featureSet = FeatureSet.rdd(rawData, memoryType = memoryType, dataStrategy) | ||
featureSet.transform(BytesToBGRImg() | ||
-> BGRImgCropper(imageSize, imageSize) | ||
-> DatasetHFlip(0.5) | ||
-> BGRImgNormalizer(0.485, 0.456, 0.406, 0.229, 0.224, 0.225)) | ||
)) | ||
-> BGRImgNormalizer(0.485, 0.456, 0.406, 0.229, 0.224, 0.225) | ||
-> BGRImgToSample() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why add toSample and toBatch here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As the origin rdd is persisted on disk, when count the size of the dataset, MTLabeledBGRImgToBatch will throw an exception. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. reverted, As I won't count the hold rdd on disk. |
||
-> SampleToMiniBatch(batchSize)) | ||
} | ||
} | ||
|
||
|
@@ -123,11 +119,12 @@ object ImageNet2012 { | |
nodeNumber: Int, | ||
coresPerNode: Int, | ||
classNumber: Int, | ||
memoryType: MemoryType = DRAM): FeatureSet[MiniBatch[Float]] = { | ||
memoryType: MemoryType = DRAM, | ||
dataStrategy: DataStrategy = PARTITIONED): FeatureSet[MiniBatch[Float]] = { | ||
val rawData = readFromSeqFiles(path, sc, classNumber) | ||
.map(byteRecordToImageFeature(_)) | ||
.setName("ImageNet2012 Training Set") | ||
val featureSet = FeatureSet.rdd(rawData, memoryType = memoryType) | ||
val featureSet = FeatureSet.rdd(rawData, memoryType = memoryType, dataStrategy) | ||
val transformer = ImagePixelBytesToMat() -> | ||
ImageRandomCrop(imageSize, imageSize) -> | ||
ImageChannelNormalize(0.485f, 0.456f, 0.406f, 0.229f, 0.224f, 0.225f) -> | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If memoryType DRAM is used, the validation phase is ignored at the end of each epoch.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed