Skip to content

Commit

Permalink
support multi input models for nnframes (intel-analytics#1553)
Browse files Browse the repository at this point in the history
* support multi input for nnframes

* update ut

* add doc and unit test

* doc update

* scala style
  • Loading branch information
hhbyyh committed Aug 20, 2019
1 parent ba1c11f commit 8d9c3d1
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,29 +30,44 @@ import scala.reflect.ClassTag
* @tparam F data type from feature column, E.g. Array[_] or Vector
* @tparam L data type from label column, E.g. Float, Double, Array[_] or Vector
*/
// Type parameter X is the output type of the feature step. At runtime it must be either a
// single Tensor[T] or an Array[Tensor[T]] (multi-input models); anything else is rejected
// in apply with an UnsupportedOperationException.
class FeatureLabelPreprocessing[F, X, L, T: ClassTag] private[zoo] (
    featureStep: Preprocessing[F, X],
    labelStep: Preprocessing[L, Tensor[T]]
  )(implicit ev: TensorNumeric[T]) extends Preprocessing[(F, Option[L]), Sample[T]] {

  /**
   * Converts each (feature, optional label) pair into a Sample.
   *
   * The feature is run through `featureStep`; when a label is present it is run through
   * `labelStep` and attached to the Sample, otherwise a feature-only Sample is produced.
   *
   * @param prev iterator of (feature, optional label) pairs
   * @return iterator with one Sample per input pair
   * @throws UnsupportedOperationException if `featureStep` yields something other than a
   *                                       Tensor or an Array of Tensors
   */
  override def apply(prev: Iterator[(F, Option[L])]): Iterator[Sample[T]] = {
    prev.map { case (feature, label) =>
      val featureTensors = featureStep(Iterator(feature)).next()

      // Declared as a def so the label step only runs once the feature output has been
      // recognised as a supported type (same evaluation order as the branch-local code
      // it replaces).
      def labelTensor: Option[Tensor[T]] = label.map(l => labelStep(Iterator(l)).next())

      // X is unconstrained at compile time, so the concrete shape is decided at runtime.
      // The element type T is erased and cannot be verified by these patterns.
      featureTensors match {
        case ft: Tensor[T] =>
          labelTensor match {
            case Some(lt) => Sample[T](ft, lt)
            case None => Sample[T](ft)
          }
        case fat: Array[Tensor[T]] =>
          labelTensor match {
            case Some(lt) => Sample[T](fat, lt)
            case None => Sample[T](fat)
          }
        case _ =>
          throw new UnsupportedOperationException(
            s"FeatureLabelPreprocessing expects table or tensor, but got $featureTensors")
      }
    }
  }
}

object FeatureLabelPreprocessing {

  /**
   * Creates a FeatureLabelPreprocessing from a feature step and a label step.
   *
   * @param featureStep preprocessing applied to the feature value; its output type X is
   *                    expected to be a Tensor[T] or an Array[Tensor[T]]
   * @param labelStep preprocessing applied to the label value, producing a Tensor[T]
   * @return a Preprocessing from (feature, optional label) pairs to Sample[T]
   */
  def apply[F, X, L, T: ClassTag](
      featureStep: Preprocessing[F, X],
      labelStep: Preprocessing[L, Tensor[T]]
    )(implicit ev: TensorNumeric[T]): FeatureLabelPreprocessing[F, X, L, T] =
    new FeatureLabelPreprocessing(featureStep, labelStep)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright 2018 Analytics Zoo Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.intel.analytics.zoo.feature.common

import com.intel.analytics.bigdl.dataset.Sample
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric

import scala.reflect.ClassTag

/**
 * A Preprocessing step that packs each incoming array of tensors into a single Sample.
 */
class MultiTensorsToSample[T: ClassTag]()(implicit ev: TensorNumeric[T])
  extends Preprocessing[Array[Tensor[T]], Sample[T]] {

  /**
   * @param prev iterator over arrays of tensors
   * @return iterator yielding one Sample per incoming array
   */
  override def apply(prev: Iterator[Array[Tensor[T]]]): Iterator[Sample[T]] = {
    for (tensors <- prev) yield Sample(tensors)
  }
}

object MultiTensorsToSample {

  /**
   * Factory for MultiTensorsToSample.
   *
   * The type parameter F is not used by the body; it is part of the existing signature.
   */
  def apply[F, T: ClassTag]()(implicit ev: TensorNumeric[T]): MultiTensorsToSample[T] = {
    new MultiTensorsToSample[T]()
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright 2018 Analytics Zoo Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.intel.analytics.zoo.feature.common

import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric

import scala.reflect.ClassTag

/**
 * a Preprocessing that splits one flat numeric sequence into multiple tensors
 * for multi-input models.
 *
 * Each incoming element must be a non-empty Seq whose elements are Int, Float or Double.
 * Its values are consumed in order: the first `multiSizes(0).product` values fill the first
 * tensor, the next `multiSizes(1).product` values fill the second tensor, and so on.
 *
 * @param multiSizes dimensions of target Tensors.
 */
class SeqToMultipleTensors[T: ClassTag](multiSizes: Array[Array[Int]])
  (implicit ev: TensorNumeric[T]) extends Preprocessing[Any, Array[Tensor[T]]] {

  /**
   * @param prev iterator of raw feature values; only Seq values are accepted
   * @return iterator with one array of tensors per input element
   * @throws IllegalArgumentException if an element is not a Seq, or matchSeq rejects it
   */
  override def apply(prev: Iterator[Any]): Iterator[Array[Tensor[T]]] = {
    prev.map {
      case sd: Seq[Any] => matchSeq(sd)
      case f => throw new IllegalArgumentException(
        "SeqToMultipleTensors only supports Seq[Int], Seq[Float] or Seq[Double], " +
          s"but got $f")
    }
  }

  /**
   * Converts a flat numeric sequence into one tensor per entry of `multiSizes`.
   *
   * @param list flat sequence of Int, Float or Double values; the element type is decided
   *             from the first element only
   * @return tensors shaped according to `multiSizes`, in the same order
   * @throws IllegalArgumentException if `list` is empty, if its first element is not an
   *                                  Int, Float or Double, or if its length differs from
   *                                  the total number of values required by `multiSizes`
   */
  def matchSeq(list: Seq[Any]): Array[Tensor[T]] = {
    // headOption instead of head: an empty sequence is reported as a clear argument error
    // rather than a bare NoSuchElementException.
    val rawData = list.headOption match {
      case Some(_: Double) => list.asInstanceOf[Seq[Double]].map(ev.fromType(_)).toArray
      case Some(_: Float) => list.asInstanceOf[Seq[Float]].map(ev.fromType(_)).toArray
      case Some(_: Int) => list.asInstanceOf[Seq[Int]].map(ev.fromType(_)).toArray
      case None => throw new IllegalArgumentException(
        "SeqToMultipleTensors cannot convert an empty sequence")
      case _ => throw new IllegalArgumentException(
        s"SeqToMultipleTensors only supports Array[Int], " +
          s"Array[Float] and Array[Double] for ArrayType, but got $list")
    }

    // Number of values each target tensor consumes.
    val lengths = multiSizes.map(_.product)

    require(lengths.sum == rawData.length, s"feature columns length " +
      s"${rawData.length} does not match with the sum of tensors" +
      s" ${multiSizes.map(a => a.mkString(",")).mkString("\n")}")

    // Running start offset of each tensor inside rawData; scanLeft yields one extra trailing
    // total, which zip discards.
    val offsets = lengths.scanLeft(0)(_ + _)
    multiSizes.zip(offsets).map { case (size, start) =>
      Tensor(rawData.slice(start, start + size.product), size)
    }
  }
}


object SeqToMultipleTensors {

  /**
   * Factory for SeqToMultipleTensors.
   *
   * @param multiSizes dimensions of the target tensors, one entry per tensor
   */
  def apply[T: ClassTag](multiSizes: Array[Array[Int]])
    (implicit ev: TensorNumeric[T]): SeqToMultipleTensors[T] = {
    new SeqToMultipleTensors(multiSizes)
  }
}

0 comments on commit 8d9c3d1

Please sign in to comment.