Skip to content

Commit

Permalink
Incremental Training for imagenet (intel-analytics#1391)
Browse files Browse the repository at this point in the history
  • Loading branch information
qiuxin2012 committed Jul 31, 2019
1 parent 211b3a8 commit 6184b43
Show file tree
Hide file tree
Showing 2 changed files with 348 additions and 5 deletions.
27 changes: 27 additions & 0 deletions layers/utils/KerasUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,33 @@ object KerasUtils {
method.invoke(obj, args: _*)
}

private[zoo] def invokeMethodWithEv[T: ClassTag](
obj: String,
methodName: String,
args: Object*)(implicit ev: TensorNumeric[T]): Object = {
val clazz = Class.forName(obj)
val method =
try {
clazz.getMethod(methodName, args.map(_.getClass): _*)
} catch {
case t: Throwable =>
val methods = clazz.getMethods().filter(_.getName() == methodName)
require(methods.length == 1,
s"We should only found one result, but got ${methodName}: ${methods.length}")
methods(0)
}
val argsWithTag = args ++ Seq(implicitly[reflect.ClassTag[T]], ev)
method.invoke(obj, argsWithTag: _*)
}

private[zoo] def invokeMethodWithEv[T: ClassTag](
obj: Object,
methodName: String,
args: Object*)(implicit ev: TensorNumeric[T]): Object = {
val argsWithTag = args ++ Seq(implicitly[reflect.ClassTag[T]], ev)
invokeMethod(obj, methodName, argsWithTag: _*)
}

/**
* Count the total number of parameters for a KerasLayer.
* Return a tuple (total params #, trainable params #)
Expand Down
Loading

0 comments on commit 6184b43

Please sign in to comment.