diff --git a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/SoftMax.scala b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/SoftMax.scala
index a5d233f3d16..74ceb8a8081 100644
--- a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/SoftMax.scala
+++ b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/SoftMax.scala
@@ -84,6 +84,10 @@ object SoftMax{
     (implicit ev: TensorNumeric[T]) : SoftMax[T] = {
     new SoftMax[T](pos)
   }
+  def apply[@specialized(Float, Double) T: ClassTag]
+    (implicit ev: TensorNumeric[T]) : SoftMax[T] = {
+    new SoftMax[T](1)
+  }
   // Notice: SoftMin will call this function
   private[nn] def updateOutput[T: ClassTag](input: Tensor[T], output: Tensor[T],
     results: Array[Future[Unit]], pos: Int = 1)
     (implicit ev: TensorNumeric[T]): Tensor[T] = {
diff --git a/scala/dllib/src/test/scala/com/intel/analytics/bigdl/dllib/torch/SoftMaxSpec.scala b/scala/dllib/src/test/scala/com/intel/analytics/bigdl/dllib/torch/SoftMaxSpec.scala
index 4a92954a1fd..71780722265 100644
--- a/scala/dllib/src/test/scala/com/intel/analytics/bigdl/dllib/torch/SoftMaxSpec.scala
+++ b/scala/dllib/src/test/scala/com/intel/analytics/bigdl/dllib/torch/SoftMaxSpec.scala
@@ -23,7 +23,7 @@ import scala.util.Random
 
 @com.intel.analytics.bigdl.tags.Serial
 class SoftMaxSpec extends TorchSpec {
-  "A SoftMax 1D input" should "generate correct output and grad" in {
+  "A SoftMax 1D input" should "generate correct output and grad" in {
     torchCheck()
     val layer = new SoftMax[Double]()
     val input = Tensor[Double](10)
@@ -52,6 +52,33 @@ class SoftMaxSpec extends TorchSpec {
     println("Test case : SoftMax, Torch : " + luaTime + " s, Scala : " + scalaTime / 1e9 + " s")
   }
 
+  "A SoftMax 1D input without argument" should "generate correct output and grad" in {
+    torchCheck()
+    val layer = new SoftMax[Double]
+    val input = Tensor[Double](10)
+    input.apply1(_ => Random.nextDouble())
+    val gradOutput = Tensor[Double](10)
+    gradOutput.apply1(_ => Random.nextDouble())
+
+    val start = System.nanoTime()
+    val output = layer.forward(input)
+    val gradInput = layer.backward(input, gradOutput)
+    val end = System.nanoTime()
+    val scalaTime = end - start
+
+    val code = "module = nn.SoftMax()\n" +
+      "output = module:forward(input)\n" +
+      "gradInput = module:backward(input,gradOutput)"
+
+    val (luaTime, torchResult) = TH.run(code, Map("input" -> input, "gradOutput" -> gradOutput),
+      Array("output", "gradInput"))
+    val luaOutput = torchResult("output").asInstanceOf[Tensor[Double]]
+    val luaGradInput = torchResult("gradInput").asInstanceOf[Tensor[Double]]
+
+    output should be (luaOutput)
+    gradInput should be (luaGradInput)
+  }
+
   "A SoftMax 2D input" should "generate correct output and grad" in {
     torchCheck()
     val layer = new SoftMax[Double]()
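
For reference, here is a minimal sketch of how the new zero-argument factory introduced by this patch is intended to be used alongside the pre-existing `apply(pos)` overload. The demo object name and the use of `println` to compare outputs are illustrative assumptions, not part of the patch:

```scala
import com.intel.analytics.bigdl.dllib.nn.SoftMax
import com.intel.analytics.bigdl.dllib.tensor.Tensor

import scala.util.Random

// Hypothetical demo object, not included in the patch.
object SoftMaxApplyDemo {
  def main(args: Array[String]): Unit = {
    val input = Tensor[Double](10)
    input.apply1(_ => Random.nextDouble())

    // New factory from this patch: no argument, hard-coded to pos = 1.
    val defaultLayer = SoftMax[Double]
    // Pre-existing factory with an explicit softmax dimension.
    val explicitLayer = SoftMax[Double](1)

    // Both layers apply softmax over dimension 1, so the two outputs
    // should match element for element.
    println(defaultLayer.forward(input))
    println(explicitLayer.forward(input))
  }
}
```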