diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/dllib/feature/transform/vision/image/MTImageFeatureToBatch.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/dllib/feature/transform/vision/image/MTImageFeatureToBatch.scala
index 37c28cdb01a..20f1fb409a4 100644
--- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/dllib/feature/transform/vision/image/MTImageFeatureToBatch.scala
+++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/dllib/feature/transform/vision/image/MTImageFeatureToBatch.scala
@@ -49,6 +49,20 @@ object MTImageFeatureToBatch {
   }
 }
 
+object MTImageFeatureToBatchWithResize {
+  /**
+   * The transformer from ImageFeature to mini-batches, and extract ROI labels for segmentation
+   * if roi labels are set.
+   * @param sizeDivisible if > 0, height and width are rounded up to a multiple of this size
+   * @param batchSize global batch size
+   * @param transformer pipeline for pre-processing
+   * @param toRGB if converted to RGB, default format is BGR
+   */
+  def apply(sizeDivisible: Int = -1, batchSize: Int, transformer: FeatureTransformer,
+    toRGB: Boolean = false): MTImageFeatureToBatch =
+    new RoiImageFeatureToBatchWithResize(sizeDivisible, batchSize, transformer, toRGB)
+}
+
 /**
  * An abstract class to convert ImageFeature iterator to MiniBatches. This transformer will be run
  * on each image feature. "processImageFeature" will be called to buffer the image features. When
@@ -267,3 +281,80 @@ class RoiMTImageFeatureToBatch private[bigdl](width: Int, height: Int,
     RoiMiniBatch(featureTensor, labelData.view, isCrowdData.view, origSizeData.view)
   }
 }
+
+/**
+ * A transformer pipeline wrapper to create RoiMiniBatch in multiple threads.
+ * Image features may have different sizes, so firstly we need to calculate max size in one batch,
+ * then pad all features into one batch with that max size.
+ * @param sizeDivisible when it's greater than 0,
+ *                      height and width will be rounded up to a multiple of this size
+ * @param totalBatchSize global batch size
+ * @param transformer pipeline for pre-processing
+ * @param toRGB if converted to RGB, default format is BGR
+ */
+class RoiImageFeatureToBatchWithResize private[bigdl](sizeDivisible: Int = -1, totalBatchSize: Int,
+  transformer: FeatureTransformer, toRGB: Boolean = false)
+  extends MTImageFeatureToBatch(totalBatchSize, transformer) {
+
+  private val labelData: Array[RoiLabel] = new Array[RoiLabel](batchSize)
+  private val isCrowdData: Array[Tensor[Float]] = new Array[Tensor[Float]](batchSize)
+  private val origSizeData: Array[(Int, Int, Int)] = new Array[(Int, Int, Int)](batchSize)
+  private var featureTensor: Tensor[Float] = null
+  private val imageBuffer = new Array[Tensor[Float]](batchSize)
+
+  /**
+   * Frame size shared by the first `batchSize` buffered images: the largest height and width
+   * among them, each rounded up to a multiple of sizeDivisible when it is greater than 0.
+   * @return (height, width) of the padded frame
+   */
+  private def getFrameSize(batchSize: Int): (Int, Int) = {
+    var maxHeight = 0
+    var maxWidth = 0
+    for (i <- 0 until batchSize) {
+      maxHeight = math.max(maxHeight, imageBuffer(i).size(2))
+      maxWidth = math.max(maxWidth, imageBuffer(i).size(3))
+    }
+
+    if (sizeDivisible > 0) {
+      // integer ceiling division: exact, no round trip through floating point
+      maxHeight = (maxHeight + sizeDivisible - 1) / sizeDivisible * sizeDivisible
+      maxWidth = (maxWidth + sizeDivisible - 1) / sizeDivisible * sizeDivisible
+    }
+    (maxHeight, maxWidth)
+  }
+
+  /**
+   * Buffer one image (as a 3 x height x width tensor) together with its RoiLabel, ISCROWD
+   * tensor and original size at the given position of the current batch.
+   */
+  override protected def processImageFeature(img: ImageFeature, position: Int): Unit = {
+    if (imageBuffer(position) == null) imageBuffer(position) = Tensor[Float]()
+    imageBuffer(position).resize(3, img.getHeight(), img.getWidth())
+    // save img to buffer
+    img.copyTo(imageBuffer(position).storage().array(), 0, toRGB = toRGB)
+    val isCrowd = img(RoiLabel.ISCROWD).asInstanceOf[Tensor[Float]]
+    val label = img.getLabel.asInstanceOf[RoiLabel]
+    require(label.bboxes.size(1) == isCrowd.size(1), "The number of detections " +
+      "in ImageFeature's ISCROWD should be equal to the number of detections in the RoiLabel")
+    isCrowdData(position) = isCrowd
+    labelData(position) = label
+    origSizeData(position) = img.getOriginalSize
+  }
+
+  /**
+   * Zero-fill a batchSize x 3 x height x width tensor and copy every buffered image into its
+   * top-left corner, so smaller images are padded with zeros on the bottom and right.
+   */
+  override protected def createBatch(batchSize: Int): MiniBatch[Float] = {
+    val (height, width) = getFrameSize(batchSize)
+    if (featureTensor == null) featureTensor = Tensor()
+    featureTensor.resize(batchSize, 3, height, width).fill(0.0f)
+    // copy img buffer to feature tensor
+    for (i <- 0 until batchSize) {
+      featureTensor.select(1, i + 1).narrow(2, 1, imageBuffer(i).size(2))
+        .narrow(3, 1, imageBuffer(i).size(3)).copy(imageBuffer(i))
+    }
+    RoiMiniBatch(featureTensor, labelData.view(0, batchSize),
+      isCrowdData.view(0, batchSize), origSizeData.view(0, batchSize))
+  }
+}
diff --git a/spark/dl/src/test/scala/com/intel/analytics/bigdl/transform/vision/image/MTImageFeatureToBatchSpec.scala b/spark/dl/src/test/scala/com/intel/analytics/bigdl/transform/vision/image/MTImageFeatureToBatchSpec.scala
index 968e1603b84..adc9bd2e834 100644
--- a/spark/dl/src/test/scala/com/intel/analytics/bigdl/transform/vision/image/MTImageFeatureToBatchSpec.scala
+++ b/spark/dl/src/test/scala/com/intel/analytics/bigdl/transform/vision/image/MTImageFeatureToBatchSpec.scala
@@ -16,9 +16,10 @@
 
 package com.intel.analytics.bigdl.transform.vision.image
 
+import com.intel.analytics.bigdl.dataset.DataSet
 import com.intel.analytics.bigdl.tensor.Tensor
 import com.intel.analytics.bigdl.transform.vision.image.label.roi.RoiLabel
-import com.intel.analytics.bigdl.utils.{Engine, Table}
+import com.intel.analytics.bigdl.utils.{Engine, T, Table}
 import org.apache.spark.SparkContext
 import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}
 
@@ -36,6 +37,157 @@ class MTImageFeatureToBatchSpec extends FlatSpec with Matchers with BeforeAndAft
     if (null != sc) sc.stop()
   }
 
+  "RoiImageFeatureToBatchWithResize" should "work well" in {
+    // 11 identical 8 x 16 images with 3 channels; written as CHW, then transposed to HWC
+    val imgData = (0 to 10).map(_ => {
+      val data = Tensor[Float](T(T(
+        T(0.6336, 0.3563, 0.1053, 0.6912, 0.3791, 0.2707, 0.6270, 0.8446,
+          0.2008, 0.3051, 0.6324, 0.4001, 0.6439, 0.2275, 0.8395, 0.6917),
+        T(0.5191, 0.6917, 0.3929, 0.8765, 0.6981, 0.2679, 0.5423, 0.8095,
+          0.1022, 0.0215, 0.1976, 0.3040, 0.3436, 0.0894, 0.5207, 0.9173),
+        T(0.7829, 0.8493, 0.6865, 0.5468, 0.8769, 0.0055, 0.5274, 0.6638,
+          0.5623, 0.6986, 0.9963, 0.9332, 0.3322, 0.2322, 0.7539, 0.1027),
+        T(0.8297, 0.7903, 0.7254, 0.2109, 0.4015, 0.7729, 0.7242, 0.6415,
+          0.0452, 0.5547, 0.7091, 0.8217, 0.6968, 0.7594, 0.3986, 0.5862),
+        T(0.6075, 0.6215, 0.8243, 0.7298, 0.5886, 0.3655, 0.6750, 0.4722,
+          0.1140, 0.2483, 0.8853, 0.4583, 0.2110, 0.8364, 0.2063, 0.4120),
+        T(0.3350, 0.3226, 0.9264, 0.3657, 0.1387, 0.9268, 0.8490, 0.3405,
+          0.1999, 0.2797, 0.8620, 0.2984, 0.1121, 0.9285, 0.3487, 0.1860),
+        T(0.4850, 0.4671, 0.4069, 0.5200, 0.5928, 0.1164, 0.1781, 0.1367,
+          0.0951, 0.8707, 0.8220, 0.3016, 0.8646, 0.9668, 0.7803, 0.1323),
+        T(0.3663, 0.6169, 0.6257, 0.8451, 0.1146, 0.5394, 0.5738, 0.7960,
+          0.4786, 0.6590, 0.5803, 0.0800, 0.0975, 0.1009, 0.1835, 0.5978)),
+        T(T(0.6848, 0.7909, 0.0584, 0.5309, 0.5087, 0.3893, 0.5740, 0.8990,
+          0.9438, 0.7067, 0.3653, 0.1513, 0.8279, 0.6395, 0.6875, 0.8965),
+          T(0.8340, 0.4398, 0.5573, 0.2817, 0.1441, 0.7729, 0.0940, 0.9943,
+            0.9369, 0.3792, 0.1262, 0.7556, 0.5480, 0.6573, 0.5901, 0.0393),
+          T(0.1406, 0.5208, 0.4751, 0.6157, 0.5476, 0.9403, 0.0226, 0.6577,
+            0.4105, 0.6823, 0.2789, 0.5607, 0.0228, 0.4178, 0.7816, 0.5339),
+          T(0.6371, 0.0603, 0.3195, 0.6144, 0.2042, 0.1585, 0.1249, 0.9442,
+            0.9533, 0.1570, 0.8457, 0.1685, 0.2243, 0.3009, 0.2149, 0.1328),
+          T(0.7049, 0.6040, 0.5683, 0.3084, 0.2516, 0.1883, 0.0982, 0.7712,
+            0.5637, 0.5811, 0.1678, 0.3323, 0.9634, 0.5855, 0.4315, 0.8492),
+          T(0.6626, 0.1401, 0.7042, 0.3153, 0.6940, 0.5070, 0.6723, 0.6993,
+            0.7467, 0.6185, 0.8907, 0.3982, 0.6435, 0.5429, 0.2580, 0.7538),
+          T(0.3496, 0.3059, 0.1777, 0.7922, 0.9832, 0.5681, 0.6051, 0.1525,
+            0.7647, 0.6433, 0.8886, 0.8596, 0.6976, 0.1161, 0.0092, 0.1787),
+          T(0.0386, 0.8511, 0.4545, 0.1208, 0.2020, 0.7471, 0.7825, 0.3376,
+            0.5597, 0.6067, 0.8809, 0.6917, 0.1960, 0.4223, 0.9569, 0.6081)),
+        T(T(0.6848, 0.7909, 0.0584, 0.5309, 0.5087, 0.3893, 0.5740, 0.8990,
+          0.9438, 0.7067, 0.3653, 0.1513, 0.8279, 0.6395, 0.6875, 0.8965),
+          T(0.8340, 0.4398, 0.5573, 0.2817, 0.1441, 0.7729, 0.0940, 0.9943,
+            0.9369, 0.3792, 0.1262, 0.7556, 0.5480, 0.6573, 0.5901, 0.0393),
+          T(0.1406, 0.5208, 0.4751, 0.6157, 0.5476, 0.9403, 0.0226, 0.6577,
+            0.4105, 0.6823, 0.2789, 0.5607, 0.0228, 0.4178, 0.7816, 0.5339),
+          T(0.6371, 0.0603, 0.3195, 0.6144, 0.2042, 0.1585, 0.1249, 0.9442,
+            0.9533, 0.1570, 0.8457, 0.1685, 0.2243, 0.3009, 0.2149, 0.1328),
+          T(0.7049, 0.6040, 0.5683, 0.3084, 0.2516, 0.1883, 0.0982, 0.7712,
+            0.5637, 0.5811, 0.1678, 0.3323, 0.9634, 0.5855, 0.4315, 0.8492),
+          T(0.6626, 0.1401, 0.7042, 0.3153, 0.6940, 0.5070, 0.6723, 0.6993,
+            0.7467, 0.6185, 0.8907, 0.3982, 0.6435, 0.5429, 0.2580, 0.7538),
+          T(0.3496, 0.3059, 0.1777, 0.7922, 0.9832, 0.5681, 0.6051, 0.1525,
+            0.7647, 0.6433, 0.8886, 0.8596, 0.6976, 0.1161, 0.0092, 0.1787),
+          T(0.0386, 0.8511, 0.4545, 0.1208, 0.2020, 0.7471, 0.7825, 0.3376,
+            0.5597, 0.6067, 0.8809, 0.6917, 0.1960, 0.4223, 0.9569, 0.6081))))
+        .transpose(1, 2).transpose(2, 3).contiguous()
+
+      val imf = ImageFeature()
+      imf(ImageFeature.floats) = data.storage().array()
+      imf(ImageFeature.label) = RoiLabel(
+        Tensor(new Array[Float](2), Array(2)),
+        Tensor(new Array[Float](2*4), Array(2, 4)),
+        null
+      )
+      imf(RoiLabel.ISCROWD) = Tensor(Array(0f, 1f), Array(2))
+      imf(ImageFeature.originalSize) = (8, 16, 3)
+      imf
+    }).toArray
+
+    // sizeDivisible = 10, batchSize = 3: each 8 x 16 image is zero-padded to 10 x 20
+    val transformer = MTImageFeatureToBatchWithResize(10, 3,
+      new FeatureTransformer {}, toRGB = false)
+    val miniBatch = transformer(DataSet.array(imgData).data(false))
+
+    val expectedOutput = Tensor[Float](T(T(
+      T(0.6336, 0.3563, 0.1053, 0.6912, 0.3791, 0.2707, 0.6270, 0.8446, 0.2008, 0.3051,
+        0.6324, 0.4001, 0.6439, 0.2275, 0.8395, 0.6917, 0.0000, 0.0000, 0.0000, 0.0000),
+      T(0.5191, 0.6917, 0.3929, 0.8765, 0.6981, 0.2679, 0.5423, 0.8095, 0.1022, 0.0215,
+        0.1976, 0.3040, 0.3436, 0.0894, 0.5207, 0.9173, 0.0000, 0.0000, 0.0000, 0.0000),
+      T(0.7829, 0.8493, 0.6865, 0.5468, 0.8769, 0.0055, 0.5274, 0.6638, 0.5623, 0.6986,
+        0.9963, 0.9332, 0.3322, 0.2322, 0.7539, 0.1027, 0.0000, 0.0000, 0.0000, 0.0000),
+      T(0.8297, 0.7903, 0.7254, 0.2109, 0.4015, 0.7729, 0.7242, 0.6415, 0.0452, 0.5547,
+        0.7091, 0.8217, 0.6968, 0.7594, 0.3986, 0.5862, 0.0000, 0.0000, 0.0000, 0.0000),
+      T(0.6075, 0.6215, 0.8243, 0.7298, 0.5886, 0.3655, 0.6750, 0.4722, 0.1140, 0.2483,
+        0.8853, 0.4583, 0.2110, 0.8364, 0.2063, 0.4120, 0.0000, 0.0000, 0.0000, 0.0000),
+      T(0.3350, 0.3226, 0.9264, 0.3657, 0.1387, 0.9268, 0.8490, 0.3405, 0.1999, 0.2797,
+        0.8620, 0.2984, 0.1121, 0.9285, 0.3487, 0.1860, 0.0000, 0.0000, 0.0000, 0.0000),
+      T(0.4850, 0.4671, 0.4069, 0.5200, 0.5928, 0.1164, 0.1781, 0.1367, 0.0951, 0.8707,
+        0.8220, 0.3016, 0.8646, 0.9668, 0.7803, 0.1323, 0.0000, 0.0000, 0.0000, 0.0000),
+      T(0.3663, 0.6169, 0.6257, 0.8451, 0.1146, 0.5394, 0.5738, 0.7960, 0.4786, 0.6590,
+        0.5803, 0.0800, 0.0975, 0.1009, 0.1835, 0.5978, 0.0000, 0.0000, 0.0000, 0.0000),
+      T(0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
+        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000),
+      T(0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
+        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000)),
+      T(T(0.6848, 0.7909, 0.0584, 0.5309, 0.5087, 0.3893, 0.5740, 0.8990, 0.9438, 0.7067,
+        0.3653, 0.1513, 0.8279, 0.6395, 0.6875, 0.8965, 0.0000, 0.0000, 0.0000, 0.0000),
+        T(0.8340, 0.4398, 0.5573, 0.2817, 0.1441, 0.7729, 0.0940, 0.9943, 0.9369, 0.3792,
+          0.1262, 0.7556, 0.5480, 0.6573, 0.5901, 0.0393, 0.0000, 0.0000, 0.0000, 0.0000),
+        T(0.1406, 0.5208, 0.4751, 0.6157, 0.5476, 0.9403, 0.0226, 0.6577, 0.4105, 0.6823,
+          0.2789, 0.5607, 0.0228, 0.4178, 0.7816, 0.5339, 0.0000, 0.0000, 0.0000, 0.0000),
+        T(0.6371, 0.0603, 0.3195, 0.6144, 0.2042, 0.1585, 0.1249, 0.9442, 0.9533, 0.1570,
+          0.8457, 0.1685, 0.2243, 0.3009, 0.2149, 0.1328, 0.0000, 0.0000, 0.0000, 0.0000),
+        T(0.7049, 0.6040, 0.5683, 0.3084, 0.2516, 0.1883, 0.0982, 0.7712, 0.5637, 0.5811,
+          0.1678, 0.3323, 0.9634, 0.5855, 0.4315, 0.8492, 0.0000, 0.0000, 0.0000, 0.0000),
+        T(0.6626, 0.1401, 0.7042, 0.3153, 0.6940, 0.5070, 0.6723, 0.6993, 0.7467, 0.6185,
+          0.8907, 0.3982, 0.6435, 0.5429, 0.2580, 0.7538, 0.0000, 0.0000, 0.0000, 0.0000),
+        T(0.3496, 0.3059, 0.1777, 0.7922, 0.9832, 0.5681, 0.6051, 0.1525, 0.7647, 0.6433,
+          0.8886, 0.8596, 0.6976, 0.1161, 0.0092, 0.1787, 0.0000, 0.0000, 0.0000, 0.0000),
+        T(0.0386, 0.8511, 0.4545, 0.1208, 0.2020, 0.7471, 0.7825, 0.3376, 0.5597, 0.6067,
+          0.8809, 0.6917, 0.1960, 0.4223, 0.9569, 0.6081, 0.0000, 0.0000, 0.0000, 0.0000),
+        T(0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
+          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000),
+        T(0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
+          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000)),
+      T(T(0.6848, 0.7909, 0.0584, 0.5309, 0.5087, 0.3893, 0.5740, 0.8990, 0.9438, 0.7067,
+        0.3653, 0.1513, 0.8279, 0.6395, 0.6875, 0.8965, 0.0000, 0.0000, 0.0000, 0.0000),
+        T(0.8340, 0.4398, 0.5573, 0.2817, 0.1441, 0.7729, 0.0940, 0.9943, 0.9369, 0.3792,
+          0.1262, 0.7556, 0.5480, 0.6573, 0.5901, 0.0393, 0.0000, 0.0000, 0.0000, 0.0000),
+        T(0.1406, 0.5208, 0.4751, 0.6157, 0.5476, 0.9403, 0.0226, 0.6577, 0.4105, 0.6823,
+          0.2789, 0.5607, 0.0228, 0.4178, 0.7816, 0.5339, 0.0000, 0.0000, 0.0000, 0.0000),
+        T(0.6371, 0.0603, 0.3195, 0.6144, 0.2042, 0.1585, 0.1249, 0.9442, 0.9533, 0.1570,
+          0.8457, 0.1685, 0.2243, 0.3009, 0.2149, 0.1328, 0.0000, 0.0000, 0.0000, 0.0000),
+        T(0.7049, 0.6040, 0.5683, 0.3084, 0.2516, 0.1883, 0.0982, 0.7712, 0.5637, 0.5811,
+          0.1678, 0.3323, 0.9634, 0.5855, 0.4315, 0.8492, 0.0000, 0.0000, 0.0000, 0.0000),
+        T(0.6626, 0.1401, 0.7042, 0.3153, 0.6940, 0.5070, 0.6723, 0.6993, 0.7467, 0.6185,
+          0.8907, 0.3982, 0.6435, 0.5429, 0.2580, 0.7538, 0.0000, 0.0000, 0.0000, 0.0000),
+        T(0.3496, 0.3059, 0.1777, 0.7922, 0.9832, 0.5681, 0.6051, 0.1525, 0.7647, 0.6433,
+          0.8886, 0.8596, 0.6976, 0.1161, 0.0092, 0.1787, 0.0000, 0.0000, 0.0000, 0.0000),
+        T(0.0386, 0.8511, 0.4545, 0.1208, 0.2020, 0.7471, 0.7825, 0.3376, 0.5597, 0.6067,
+          0.8809, 0.6917, 0.1960, 0.4223, 0.9569, 0.6081, 0.0000, 0.0000, 0.0000, 0.0000),
+        T(0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
+          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000),
+        T(0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
+          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000))))
+
+    miniBatch.foreach(batch => {
+      batch.size() should be <= 3
+      val input = batch.getInput().asInstanceOf[Tensor[Float]]
+      val target = batch.getTarget().asInstanceOf[Table]
+      input.size() should be (Array(batch.size(), 3, 10, 20))
+      target.length() should be (batch.size())
+      for (i <- 1 to batch.size()) {
+        val in = input.select(1, i)
+        in should be(expectedOutput)
+        val t = target(i).asInstanceOf[Table]
+        t[Tensor[Float]](RoiLabel.ISCROWD) should be (Tensor(Array(0f, 1f), Array(2)))
+        t[(Int, Int, Int)](RoiLabel.ORIGSIZE) should be((8, 16, 3))
+        t[Tensor[Float]](RoiLabel.BBOXES).size() should be (Array(2, 4))
+        t[Tensor[Float]](RoiLabel.CLASSES).size() should be (Array(2))
+      }
+    })
+  }
+
   // todo: There is a race-condition bug in MTImageFeatureToBatch
   /*
   "MTImageFeatureToBatch classification" should "work well" in {