Skip to content

Commit

Permalink
support samples with different size to one mini batch (intel-analytic…
Browse files Browse the repository at this point in the history
…s#2929)

* add to batch with resize

* meet comments
  • Loading branch information
zhangxiaoli73 committed Oct 11, 2019
1 parent bdbdd8e commit 5d71611
Show file tree
Hide file tree
Showing 2 changed files with 228 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,20 @@ object MTImageFeatureToBatch {
}
}

object MTImageFeatureToBatchWithResize {
/**
* The transformer from ImageFeature to mini-batches, and extract ROI labels for segmentation
* if roi labels are set.
* @param sizeDivisible when it's greater than 0, height and wide should be divisible by this size
* @param batchSize global batch size
* @param transformer pipeline for pre-processing
* @param toRGB if converted to RGB, default format is BGR
*/
def apply(sizeDivisible: Int = -1, batchSize: Int, transformer: FeatureTransformer,
toRGB : Boolean = false): MTImageFeatureToBatch =
new RoiImageFeatureToBatchWithResize(sizeDivisible, batchSize, transformer, toRGB)
}

/**
* An abstract class to convert ImageFeature iterator to MiniBatches. This transformer will be run
* on each image feature. "processImageFeature" will be called to buffer the image features. When
Expand Down Expand Up @@ -267,3 +281,66 @@ class RoiMTImageFeatureToBatch private[bigdl](width: Int, height: Int,
RoiMiniBatch(featureTensor, labelData.view, isCrowdData.view, origSizeData.view)
}
}

/**
* A transformer pipeline wrapper to create RoiMiniBatch in multiple threads.
* Image features may have different sizes, so firstly we need to calculate max size in one batch,
* then padding all features to one batch with max size.
* @param sizeDivisible when it's greater than 0,
* height and wide will be round up to multiple of this divisible size
* @param totalBatchSize global batch size
* @param transformer pipeline for pre-processing
* @param toRGB
*/
class RoiImageFeatureToBatchWithResize private[bigdl](sizeDivisible: Int = -1, totalBatchSize: Int,
transformer: FeatureTransformer, toRGB: Boolean = false)
extends MTImageFeatureToBatch(totalBatchSize, transformer) {

private val labelData: Array[RoiLabel] = new Array[RoiLabel](batchSize)
private val isCrowdData: Array[Tensor[Float]] = new Array[Tensor[Float]](batchSize)
private val origSizeData: Array[(Int, Int, Int)] = new Array[(Int, Int, Int)](batchSize)
private var featureTensor: Tensor[Float] = null
private val imageBuffer = new Array[Tensor[Float]](batchSize)

private def getFrameSize(batchSize: Int): (Int, Int) = {
var maxHeight = 0
var maxWide = 0
for (i <- 0 until batchSize) {
maxHeight = math.max(maxHeight, imageBuffer(i).size(2))
maxWide = math.max(maxWide, imageBuffer(i).size(3))
}

if (sizeDivisible > 0) {
maxHeight = (math.ceil(maxHeight.toFloat / sizeDivisible) * sizeDivisible).toInt
maxWide = (math.ceil(maxWide.toFloat / sizeDivisible) * sizeDivisible).toInt
}
(maxHeight, maxWide)
}

override protected def processImageFeature(img: ImageFeature, position: Int): Unit = {
if (imageBuffer(position) == null) imageBuffer(position) = Tensor[Float]()
imageBuffer(position).resize(3, img.getHeight(), img.getWidth())
// save img to buffer
img.copyTo(imageBuffer(position).storage().array(), 0, toRGB = toRGB)
val isCrowd = img(RoiLabel.ISCROWD).asInstanceOf[Tensor[Float]]
val label = img.getLabel.asInstanceOf[RoiLabel]
require(label.bboxes.size(1) == isCrowd.size(1), "The number of detections" +
"in ImageFeature's ISCROWD should be equal to the number of detections in the RoiLabel")
isCrowdData(position) = isCrowd
labelData(position) = label
origSizeData(position) = img.getOriginalSize
}

override protected def createBatch(batchSize: Int): MiniBatch[Float] = {
val (height, wide) = getFrameSize(batchSize)
if (featureTensor == null) featureTensor = Tensor()
featureTensor.resize(batchSize, 3, height, wide).fill(0.0f)
// copy img buffer to feature tensor
for (i <- 0 to (batchSize - 1)) {
featureTensor.select(1, i + 1).narrow(2, 1, imageBuffer(i).size(2))
.narrow(3, 1, imageBuffer(i).size(3)).copy(imageBuffer(i))
}
RoiMiniBatch(featureTensor, labelData.view(0, batchSize),
isCrowdData.view(0, batchSize), origSizeData.view(0, batchSize))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@

package com.intel.analytics.bigdl.transform.vision.image

import com.intel.analytics.bigdl.dataset.DataSet
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.transform.vision.image.label.roi.RoiLabel
import com.intel.analytics.bigdl.utils.{Engine, Table}
import com.intel.analytics.bigdl.utils.{Engine, T, Table}
import org.apache.spark.SparkContext
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

Expand All @@ -36,6 +37,155 @@ class MTImageFeatureToBatchSpec extends FlatSpec with Matchers with BeforeAndAft
if (null != sc) sc.stop()
}

"RoiImageFeatureToBatchWithMaxSize" should "work well" in {
val imgData = (0 to 10).map(n => {
val data = Tensor[Float](T(T(
T(0.6336, 0.3563, 0.1053, 0.6912, 0.3791, 0.2707, 0.6270, 0.8446,
0.2008, 0.3051, 0.6324, 0.4001, 0.6439, 0.2275, 0.8395, 0.6917),
T(0.5191, 0.6917, 0.3929, 0.8765, 0.6981, 0.2679, 0.5423, 0.8095,
0.1022, 0.0215, 0.1976, 0.3040, 0.3436, 0.0894, 0.5207, 0.9173),
T(0.7829, 0.8493, 0.6865, 0.5468, 0.8769, 0.0055, 0.5274, 0.6638,
0.5623, 0.6986, 0.9963, 0.9332, 0.3322, 0.2322, 0.7539, 0.1027),
T(0.8297, 0.7903, 0.7254, 0.2109, 0.4015, 0.7729, 0.7242, 0.6415,
0.0452, 0.5547, 0.7091, 0.8217, 0.6968, 0.7594, 0.3986, 0.5862),
T(0.6075, 0.6215, 0.8243, 0.7298, 0.5886, 0.3655, 0.6750, 0.4722,
0.1140, 0.2483, 0.8853, 0.4583, 0.2110, 0.8364, 0.2063, 0.4120),
T(0.3350, 0.3226, 0.9264, 0.3657, 0.1387, 0.9268, 0.8490, 0.3405,
0.1999, 0.2797, 0.8620, 0.2984, 0.1121, 0.9285, 0.3487, 0.1860),
T(0.4850, 0.4671, 0.4069, 0.5200, 0.5928, 0.1164, 0.1781, 0.1367,
0.0951, 0.8707, 0.8220, 0.3016, 0.8646, 0.9668, 0.7803, 0.1323),
T(0.3663, 0.6169, 0.6257, 0.8451, 0.1146, 0.5394, 0.5738, 0.7960,
0.4786, 0.6590, 0.5803, 0.0800, 0.0975, 0.1009, 0.1835, 0.5978)),
T(T(0.6848, 0.7909, 0.0584, 0.5309, 0.5087, 0.3893, 0.5740, 0.8990,
0.9438, 0.7067, 0.3653, 0.1513, 0.8279, 0.6395, 0.6875, 0.8965),
T(0.8340, 0.4398, 0.5573, 0.2817, 0.1441, 0.7729, 0.0940, 0.9943,
0.9369, 0.3792, 0.1262, 0.7556, 0.5480, 0.6573, 0.5901, 0.0393),
T(0.1406, 0.5208, 0.4751, 0.6157, 0.5476, 0.9403, 0.0226, 0.6577,
0.4105, 0.6823, 0.2789, 0.5607, 0.0228, 0.4178, 0.7816, 0.5339),
T(0.6371, 0.0603, 0.3195, 0.6144, 0.2042, 0.1585, 0.1249, 0.9442,
0.9533, 0.1570, 0.8457, 0.1685, 0.2243, 0.3009, 0.2149, 0.1328),
T(0.7049, 0.6040, 0.5683, 0.3084, 0.2516, 0.1883, 0.0982, 0.7712,
0.5637, 0.5811, 0.1678, 0.3323, 0.9634, 0.5855, 0.4315, 0.8492),
T(0.6626, 0.1401, 0.7042, 0.3153, 0.6940, 0.5070, 0.6723, 0.6993,
0.7467, 0.6185, 0.8907, 0.3982, 0.6435, 0.5429, 0.2580, 0.7538),
T(0.3496, 0.3059, 0.1777, 0.7922, 0.9832, 0.5681, 0.6051, 0.1525,
0.7647, 0.6433, 0.8886, 0.8596, 0.6976, 0.1161, 0.0092, 0.1787),
T(0.0386, 0.8511, 0.4545, 0.1208, 0.2020, 0.7471, 0.7825, 0.3376,
0.5597, 0.6067, 0.8809, 0.6917, 0.1960, 0.4223, 0.9569, 0.6081)),
T(T(0.6848, 0.7909, 0.0584, 0.5309, 0.5087, 0.3893, 0.5740, 0.8990,
0.9438, 0.7067, 0.3653, 0.1513, 0.8279, 0.6395, 0.6875, 0.8965),
T(0.8340, 0.4398, 0.5573, 0.2817, 0.1441, 0.7729, 0.0940, 0.9943,
0.9369, 0.3792, 0.1262, 0.7556, 0.5480, 0.6573, 0.5901, 0.0393),
T(0.1406, 0.5208, 0.4751, 0.6157, 0.5476, 0.9403, 0.0226, 0.6577,
0.4105, 0.6823, 0.2789, 0.5607, 0.0228, 0.4178, 0.7816, 0.5339),
T(0.6371, 0.0603, 0.3195, 0.6144, 0.2042, 0.1585, 0.1249, 0.9442,
0.9533, 0.1570, 0.8457, 0.1685, 0.2243, 0.3009, 0.2149, 0.1328),
T(0.7049, 0.6040, 0.5683, 0.3084, 0.2516, 0.1883, 0.0982, 0.7712,
0.5637, 0.5811, 0.1678, 0.3323, 0.9634, 0.5855, 0.4315, 0.8492),
T(0.6626, 0.1401, 0.7042, 0.3153, 0.6940, 0.5070, 0.6723, 0.6993,
0.7467, 0.6185, 0.8907, 0.3982, 0.6435, 0.5429, 0.2580, 0.7538),
T(0.3496, 0.3059, 0.1777, 0.7922, 0.9832, 0.5681, 0.6051, 0.1525,
0.7647, 0.6433, 0.8886, 0.8596, 0.6976, 0.1161, 0.0092, 0.1787),
T(0.0386, 0.8511, 0.4545, 0.1208, 0.2020, 0.7471, 0.7825, 0.3376,
0.5597, 0.6067, 0.8809, 0.6917, 0.1960, 0.4223, 0.9569, 0.6081))))
.transpose(1, 2).transpose(2, 3).contiguous()

val imf = ImageFeature()
imf(ImageFeature.floats) = data.storage().array()
imf(ImageFeature.label) = RoiLabel(
Tensor(new Array[Float](2), Array(2)),
Tensor(new Array[Float](2*4), Array(2, 4)),
null
)
imf(RoiLabel.ISCROWD) = Tensor(Array(0f, 1f), Array(2))
imf(ImageFeature.originalSize) = (8, 16, 3)
imf
}).toArray

val transformer = MTImageFeatureToBatchWithResize(10, 3,
new FeatureTransformer {}, toRGB = false)
val miniBatch = transformer(DataSet.array(imgData).data(false))

val expectedOutput = Tensor[Float](T(T(
T(0.6336, 0.3563, 0.1053, 0.6912, 0.3791, 0.2707, 0.6270, 0.8446, 0.2008, 0.3051,
0.6324, 0.4001, 0.6439, 0.2275, 0.8395, 0.6917, 0.0000, 0.0000, 0.0000, 0.0000),
T(0.5191, 0.6917, 0.3929, 0.8765, 0.6981, 0.2679, 0.5423, 0.8095, 0.1022, 0.0215,
0.1976, 0.3040, 0.3436, 0.0894, 0.5207, 0.9173, 0.0000, 0.0000, 0.0000, 0.0000),
T(0.7829, 0.8493, 0.6865, 0.5468, 0.8769, 0.0055, 0.5274, 0.6638, 0.5623, 0.6986,
0.9963, 0.9332, 0.3322, 0.2322, 0.7539, 0.1027, 0.0000, 0.0000, 0.0000, 0.0000),
T(0.8297, 0.7903, 0.7254, 0.2109, 0.4015, 0.7729, 0.7242, 0.6415, 0.0452, 0.5547,
0.7091, 0.8217, 0.6968, 0.7594, 0.3986, 0.5862, 0.0000, 0.0000, 0.0000, 0.0000),
T(0.6075, 0.6215, 0.8243, 0.7298, 0.5886, 0.3655, 0.6750, 0.4722, 0.1140, 0.2483,
0.8853, 0.4583, 0.2110, 0.8364, 0.2063, 0.4120, 0.0000, 0.0000, 0.0000, 0.0000),
T(0.3350, 0.3226, 0.9264, 0.3657, 0.1387, 0.9268, 0.8490, 0.3405, 0.1999, 0.2797,
0.8620, 0.2984, 0.1121, 0.9285, 0.3487, 0.1860, 0.0000, 0.0000, 0.0000, 0.0000),
T(0.4850, 0.4671, 0.4069, 0.5200, 0.5928, 0.1164, 0.1781, 0.1367, 0.0951, 0.8707,
0.8220, 0.3016, 0.8646, 0.9668, 0.7803, 0.1323, 0.0000, 0.0000, 0.0000, 0.0000),
T(0.3663, 0.6169, 0.6257, 0.8451, 0.1146, 0.5394, 0.5738, 0.7960, 0.4786, 0.6590,
0.5803, 0.0800, 0.0975, 0.1009, 0.1835, 0.5978, 0.0000, 0.0000, 0.0000, 0.0000),
T(0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000),
T(0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000)),
T(T(0.6848, 0.7909, 0.0584, 0.5309, 0.5087, 0.3893, 0.5740, 0.8990, 0.9438, 0.7067,
0.3653, 0.1513, 0.8279, 0.6395, 0.6875, 0.8965, 0.0000, 0.0000, 0.0000, 0.0000),
T(0.8340, 0.4398, 0.5573, 0.2817, 0.1441, 0.7729, 0.0940, 0.9943, 0.9369, 0.3792,
0.1262, 0.7556, 0.5480, 0.6573, 0.5901, 0.0393, 0.0000, 0.0000, 0.0000, 0.0000),
T(0.1406, 0.5208, 0.4751, 0.6157, 0.5476, 0.9403, 0.0226, 0.6577, 0.4105, 0.6823,
0.2789, 0.5607, 0.0228, 0.4178, 0.7816, 0.5339, 0.0000, 0.0000, 0.0000, 0.0000),
T(0.6371, 0.0603, 0.3195, 0.6144, 0.2042, 0.1585, 0.1249, 0.9442, 0.9533, 0.1570,
0.8457, 0.1685, 0.2243, 0.3009, 0.2149, 0.1328, 0.0000, 0.0000, 0.0000, 0.0000),
T(0.7049, 0.6040, 0.5683, 0.3084, 0.2516, 0.1883, 0.0982, 0.7712, 0.5637, 0.5811,
0.1678, 0.3323, 0.9634, 0.5855, 0.4315, 0.8492, 0.0000, 0.0000, 0.0000, 0.0000),
T(0.6626, 0.1401, 0.7042, 0.3153, 0.6940, 0.5070, 0.6723, 0.6993, 0.7467, 0.6185,
0.8907, 0.3982, 0.6435, 0.5429, 0.2580, 0.7538, 0.0000, 0.0000, 0.0000, 0.0000),
T(0.3496, 0.3059, 0.1777, 0.7922, 0.9832, 0.5681, 0.6051, 0.1525, 0.7647, 0.6433,
0.8886, 0.8596, 0.6976, 0.1161, 0.0092, 0.1787, 0.0000, 0.0000, 0.0000, 0.0000),
T(0.0386, 0.8511, 0.4545, 0.1208, 0.2020, 0.7471, 0.7825, 0.3376, 0.5597, 0.6067,
0.8809, 0.6917, 0.1960, 0.4223, 0.9569, 0.6081, 0.0000, 0.0000, 0.0000, 0.0000),
T(0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000),
T(0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000)),
T(T(0.6848, 0.7909, 0.0584, 0.5309, 0.5087, 0.3893, 0.5740, 0.8990, 0.9438, 0.7067,
0.3653, 0.1513, 0.8279, 0.6395, 0.6875, 0.8965, 0.0000, 0.0000, 0.0000, 0.0000),
T(0.8340, 0.4398, 0.5573, 0.2817, 0.1441, 0.7729, 0.0940, 0.9943, 0.9369, 0.3792,
0.1262, 0.7556, 0.5480, 0.6573, 0.5901, 0.0393, 0.0000, 0.0000, 0.0000, 0.0000),
T(0.1406, 0.5208, 0.4751, 0.6157, 0.5476, 0.9403, 0.0226, 0.6577, 0.4105, 0.6823,
0.2789, 0.5607, 0.0228, 0.4178, 0.7816, 0.5339, 0.0000, 0.0000, 0.0000, 0.0000),
T(0.6371, 0.0603, 0.3195, 0.6144, 0.2042, 0.1585, 0.1249, 0.9442, 0.9533, 0.1570,
0.8457, 0.1685, 0.2243, 0.3009, 0.2149, 0.1328, 0.0000, 0.0000, 0.0000, 0.0000),
T(0.7049, 0.6040, 0.5683, 0.3084, 0.2516, 0.1883, 0.0982, 0.7712, 0.5637, 0.5811,
0.1678, 0.3323, 0.9634, 0.5855, 0.4315, 0.8492, 0.0000, 0.0000, 0.0000, 0.0000),
T(0.6626, 0.1401, 0.7042, 0.3153, 0.6940, 0.5070, 0.6723, 0.6993, 0.7467, 0.6185,
0.8907, 0.3982, 0.6435, 0.5429, 0.2580, 0.7538, 0.0000, 0.0000, 0.0000, 0.0000),
T(0.3496, 0.3059, 0.1777, 0.7922, 0.9832, 0.5681, 0.6051, 0.1525, 0.7647, 0.6433,
0.8886, 0.8596, 0.6976, 0.1161, 0.0092, 0.1787, 0.0000, 0.0000, 0.0000, 0.0000),
T(0.0386, 0.8511, 0.4545, 0.1208, 0.2020, 0.7471, 0.7825, 0.3376, 0.5597, 0.6067,
0.8809, 0.6917, 0.1960, 0.4223, 0.9569, 0.6081, 0.0000, 0.0000, 0.0000, 0.0000),
T(0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000),
T(0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000))))

miniBatch.foreach(batch => {
(batch.size() <= 3) should be (true)
val input = batch.getInput().asInstanceOf[Tensor[Float]]
val target = batch.getTarget().asInstanceOf[Table]
input.size() should be (Array(batch.size(), 3, 10, 20))
target.length() should be (batch.size())
for(i <- 1 to batch.size()) {
val in = input.select(1, i)
in should be(expectedOutput)
val t = target(i).asInstanceOf[Table]
t[Tensor[Float]](RoiLabel.ISCROWD) should be (Tensor(Array(0f, 1f), Array(2)))
t[(Int, Int, Int)](RoiLabel.ORIGSIZE) should be((8, 16, 3))
t[Tensor[Float]](RoiLabel.BBOXES).size() should be (Array(2, 4))
t[Tensor[Float]](RoiLabel.CLASSES).size() should be (Array(2))
}
})
}

// todo: There is a race-condition bug in MTImageFeatureToBatch
/*
"MTImageFeatureToBatch classification" should "work well" in {
Expand Down

0 comments on commit 5d71611

Please sign in to comment.