diff --git a/python/singa/initializer.py b/python/singa/initializer.py
index cb2f5a02b..90ec85497 100644
--- a/python/singa/initializer.py
+++ b/python/singa/initializer.py
@@ -17,44 +17,171 @@
 # =============================================================================
 '''Popular initialization methods for parameter values (Tensor objects).
 
+Credit: this module is adapted from Keras:
+https://github.com/keras-team/keras/blob/master/keras/initializers.py
+
+All functions in this module change the input tensor in-place.
+
 Example usages::
 
     from singa import tensor
     from singa import initializer
 
     x = tensor.Tensor((3, 5))
-    initializer.uniform(x, 3, 5) # use both fan_in and fan_out
-    initializer.uniform(x, 3, 0) # use only fan_in
+    initializer.he_uniform(x)
+    initializer.glorot_normal(x)
 '''
+
 from __future__ import division
 import math
+import numpy as np
+from deprecated import deprecated
+
+
+def eye(t):
+    """Initialize the tensor with ones on the diagonal and zeros elsewhere.
+
+    Note: it is implemented by calling numpy.
+    Do not call it within forward propagation when the computation graph is
+    enabled.
+
+    # Arguments
+        t(Tensor): the matrix to be filled in.
+    """
+    if len(t.shape) != 2:
+        raise ValueError("Only tensors with 2 dimensions are supported")
+    a = np.eye(t.shape[0], t.shape[1], dtype=np.float32)
+    t.copy_from(a)
+
+
+def orthogonal(t, gain=1.0):
+    """Initializer that generates a random orthogonal matrix.
+
+    Note: it is implemented by calling numpy.
+    Do not call it within forward propagation when the computation graph is
+    enabled.
+
+    # Arguments
+        t(Tensor): the matrix to be filled in.
+        gain: Multiplicative factor to apply to the orthogonal matrix.
+
+    # References
+        - [Exact solutions to the nonlinear dynamics of learning in deep
+          linear neural networks](http://arxiv.org/abs/1312.6120)
+    """
+    if len(t.shape) != 2:
+        raise ValueError("Only tensors with 2 dimensions are supported")
+
+    a = np.random.normal(0.0, 1.0, t.shape).astype(np.float32)
+    u, _, v = np.linalg.svd(a, full_matrices=False)
+    # Pick the one with the correct shape.
+    q = u if u.shape == t.shape else v
+    q *= gain
+    t.copy_from(q)
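+
+
+# A minimal usage sketch for the two numpy-backed initializers above; the
+# tensor name and shape are illustrative only, and the calls should happen
+# outside of graph mode as noted in the docstrings:
+#
+#     from singa import tensor
+#     from singa import initializer
+#
+#     w = tensor.Tensor((4, 4))
+#     initializer.eye(w)         # identity matrix
+#     initializer.orthogonal(w)  # random orthogonal matrix, scaled by `gain`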
+
+
+def lecun_uniform(t):
+    """LeCun uniform initializer.
+
+    It draws samples from a uniform distribution within [-limit, limit]
+    where `limit` is `sqrt(3 / fan_in)`
+    where `fan_in` is the number of input units in the weight tensor.
+
+    # Arguments
+        t(Tensor): the tensor to be filled in.
+
+    # References
+        - [Efficient BackProp](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf)
+    """
+    _random_fill(t, scale=1., mode='fan_in', distribution='uniform')
+
+
+def glorot_normal(t):
+    """Glorot normal initializer, also called Xavier normal initializer.
+
+    It draws samples from a normal distribution centered on 0
+    with `stddev = sqrt(2 / (fan_in + fan_out))`
+    where `fan_in` is the number of input units in the weight tensor
+    and `fan_out` is the number of output units in the weight tensor.
+
+    # Arguments
+        t(Tensor): the tensor to be filled in.
+
+    # References
+        - [Understanding the difficulty of training deep feedforward neural
+          networks](http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf)
+    """
+    _random_fill(t, scale=1., mode='fan_avg', distribution='normal')
+
+
+def glorot_uniform(t):
+    """Glorot uniform initializer, also called Xavier uniform initializer.
+
+    It draws samples from a uniform distribution within [-limit, limit]
+    where `limit` is `sqrt(6 / (fan_in + fan_out))`
+    where `fan_in` is the number of input units in the weight tensor
+    and `fan_out` is the number of output units in the weight tensor.
+
+    # Arguments
+        t(Tensor): the tensor to be filled in.
+
+    # References
+        - [Understanding the difficulty of training deep feedforward neural
+          networks](http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf)
+    """
+    _random_fill(t, scale=1., mode='fan_avg', distribution='uniform')
+
+
+def he_normal(t):
+    """He normal initializer.
+
+    It draws samples from a normal distribution centered on 0
+    with `stddev = sqrt(2 / fan_in)`
+    where `fan_in` is the number of input units in the weight tensor.
+
+    # Arguments
+        t(Tensor): the tensor to be filled in.
+
+    # References
+        - [Delving Deep into Rectifiers: Surpassing Human-Level Performance on
+          ImageNet Classification](http://arxiv.org/abs/1502.01852)
+    """
+    _random_fill(t, scale=2., mode='fan_in', distribution='normal')
+
+
+def lecun_normal(t):
+    """LeCun normal initializer.
 
-def uniform(t, fan_in=0, fan_out=0):
+    It draws samples from a normal distribution centered on 0
+    with `stddev = sqrt(1 / fan_in)`
+    where `fan_in` is the number of input units in the weight tensor.
+
+    # Arguments
+        t(Tensor): the tensor to be filled in.
+
+    # References
+        - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
+        - [Efficient Backprop](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf)
+    """
+    _random_fill(t, scale=1., mode='fan_in', distribution='normal')
+
+
+def he_uniform(t):
     '''Initialize the values of the input tensor following a uniform
     distribution with specific bounds.
 
-    Args:
-        fan_in(int): for the weight Tensor of a convolution layer,
-            fan_in = nb_channel * kh * kw; for dense layer,
-            fan_in = input_feature_length
-        fan_out(int): for the convolution layer weight Tensor,
-            fan_out = nb_filter * kh * kw; for the weight Tensor of a dense
-            layer, fan_out = output_feature_length
+    It draws samples from a uniform distribution within [-limit, limit]
+    where `limit` is `sqrt(6 / fan_in)`
+    where `fan_in` is the number of input units in the weight tensor.
 
-    Ref: [Bengio and Glorot 2010]: Understanding the difficulty of
-    training deep feedforward neuralnetworks.
+    # Arguments
+        t(Tensor): the tensor to be filled in.
 
+    # References
+        - [Delving Deep into Rectifiers: Surpassing Human-Level Performance on
+          ImageNet Classification](http://arxiv.org/abs/1502.01852)
     '''
-    assert fan_in > 0 or fan_out > 0, \
-        'fan_in and fan_out cannot be 0 at the same time'
-    avg = 2
-    if fan_in * fan_out == 0:
-        avg = 1
-    x = math.sqrt(3.0 * avg / (fan_in + fan_out))
-    t.uniform(-x, x)
+    _random_fill(t, scale=2., mode='fan_in', distribution='uniform')
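+
+
+# A minimal sketch of choosing among the initializers above; the shape is
+# illustrative (fan_in=256, fan_out=128 for a dense-layer weight):
+#
+#     from singa import tensor
+#     from singa import initializer
+#
+#     w = tensor.Tensor((256, 128))
+#     initializer.he_uniform(w)      # suits ReLU activations
+#     initializer.glorot_uniform(w)  # suits tanh/sigmoid activations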
 
 
+@deprecated(reason="Use he_normal or glorot_normal")
 def gaussian(t, fan_in=0, fan_out=0):
     '''Initialize the values of the input tensor following a Gaussian
     distribution with specific std.
@@ -79,12 +206,11 @@ def gaussian(t, fan_in=0, fan_out=0):
     t.gaussian(0, std)
 
 
+@deprecated(reason="Use glorot_normal")
 def xavier(t):
     '''Initialize the matrix parameter follow a Uniform distribution from
     [-sqrt(6/(fan_in + fan_out)), sqrt(6/(fan_in + fan_out))].
 
-    Deprecated. Please use uniform()
-
     Args:
         t (Tensor): the parater tensor
     '''
@@ -93,12 +219,11 @@ def xavier(t):
     t.uniform(-scale, scale)
 
 
+@deprecated(reason="Use glorot_uniform")
 def glorot(t):
     '''Initialize the matrix parameter follow a Gaussian distribution
     with mean = 0 and std = sqrt(2.0 / (nb_row + nb_col))
 
-    Deprecated. Please use gaussian()
-
     Args:
         t (Tensor): the parater tensor
     '''
@@ -107,12 +232,11 @@ def glorot(t):
     t *= scale
 
 
+@deprecated(reason="Use he_normal")
 def msra(t):
     '''Initialize the matrix parameter follow a Guassian distribution
     with mean = 0, std = math.sqrt(2.0 / nb_row).
 
-    Deprecated. Please use gaussian()
-
     Ref [He, Zhang, Ren and Sun 2015]: Specifically accounts for ReLU
     nonlinearities.
 
@@ -120,3 +244,96 @@
         t (Tensor): the parater tensor
     '''
     t.gaussian(0, math.sqrt(2.0 / t.shape[0]))
+
+
+def _compute_fans(shape, data_format='channels_first'):
+    """Computes the number of input and output units for a weight shape.
+
+    # Arguments
+        shape: Integer shape tuple.
+        data_format: Image data format to use for convolution kernels,
+            either `channels_first` (the default here) or `channels_last`.
+
+    # Returns
+        A tuple of scalars, `(fan_in, fan_out)`.
+
+    # Raises
+        ValueError: in case of an invalid `data_format` argument.
+    """
+    if len(shape) == 2:
+        fan_in = shape[0]
+        fan_out = shape[1]
+    elif len(shape) in {3, 4, 5}:
+        # Assuming convolution kernels (1D, 2D or 3D).
+        # TH kernel shape: (depth, input_depth, ...)
+        # TF kernel shape: (..., input_depth, depth)
+        if data_format == 'channels_first':
+            receptive_field_size = np.prod(shape[2:])
+            fan_in = shape[1] * receptive_field_size
+            fan_out = shape[0] * receptive_field_size
+        elif data_format == 'channels_last':
+            receptive_field_size = np.prod(shape[:-2])
+            fan_in = shape[-2] * receptive_field_size
+            fan_out = shape[-1] * receptive_field_size
+        else:
+            raise ValueError('Invalid data_format: ' + data_format)
+    else:
+        # No specific assumptions.
+        fan_in = np.sqrt(np.prod(shape))
+        fan_out = np.sqrt(np.prod(shape))
+    return fan_in, fan_out
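+
+
+# A worked example of the fan computation above (shapes illustrative):
+#
+#     _compute_fans((3, 4))          # -> (3, 4): 2-D weight, (in, out)
+#     _compute_fans((64, 32, 3, 3))  # -> (288, 576): channels_first conv
+#                                    #    kernel; fan_in = 32 * 3 * 3,
+#                                    #    fan_out = 64 * 3 * 3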
+
+
+def _random_fill(t, scale, mode, distribution):
+    """Fill the tensor with values sampled from a distribution.
+
+    With `distribution="normal"`, samples are drawn from a normal
+    distribution centered on zero, with `stddev = sqrt(scale / n)` where n is:
+    - number of input units in the weight tensor, if mode = "fan_in"
+    - number of output units, if mode = "fan_out"
+    - average of the numbers of input and output units, if mode = "fan_avg"
+
+    With `distribution="uniform"`, samples are drawn from a uniform
+    distribution within [-limit, limit], with `limit = sqrt(3 * scale / n)`.
+
+    Args:
+        t (Tensor): Tensor to be filled
+        scale (float): scale factor
+        mode (str): "fan_in" or "fan_out" or "fan_avg"
+        distribution (str): "normal" or "uniform"
+
+    Raises:
+        ValueError: In case of an invalid value for scale, mode or distribution
+    """
+    if scale <= 0.:
+        raise ValueError('`scale` must be a positive float. Got: ' +
+                         str(scale))
+    mode = mode.lower()
+    if mode not in {'fan_in', 'fan_out', 'fan_avg'}:
+        raise ValueError('Invalid `mode` argument: '
+                         'expected one of {"fan_in", "fan_out", "fan_avg"} '
+                         'but got ' + str(mode))
+    distribution = distribution.lower()
+    if distribution not in {'normal', 'uniform'}:
+        raise ValueError('Invalid `distribution` argument: '
+                         'expected one of {"normal", "uniform"} '
+                         'but got ' + str(distribution))
+
+    fan_in, fan_out = _compute_fans(t.shape)
+    if mode == 'fan_in':
+        scale /= max(1., fan_in)
+    elif mode == 'fan_out':
+        scale /= max(1., fan_out)
+    else:
+        scale /= max(1., float(fan_in + fan_out) / 2)
+    if distribution == 'normal':
+        # 0.879... = scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
+        # stddev = np.sqrt(scale) / .87962566103423978
+        t.gaussian(0., np.sqrt(scale))
+    else:
+        limit = np.sqrt(3. * scale)
+        t.uniform(-limit, limit)
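+
+
+# For example, he_normal on a (256, 128) weight reduces to the following
+# (a sketch of the computation above, not additional behavior):
+#
+#     fan_in = 256                    # mode='fan_in'
+#     scale = 2. / max(1., fan_in)    # -> 0.0078125
+#     t.gaussian(0., np.sqrt(scale))  # stddev ~= 0.0884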
+
+
diff --git a/python/singa/metric.py b/python/singa/metric.py
index 73d57df16..ed6a7c9bb 100644
--- a/python/singa/metric.py
+++ b/python/singa/metric.py
@@ -19,6 +19,8 @@
 performance. The specific metric classes could be converted from C++
 implmentation or implemented directly using Python.
 
+Note: This module is deprecated. Please convert the prediction into a numpy
+array and use sklearn to compute the metrics.
 
 Example usage::
diff --git a/python/singa/optimizer.py b/python/singa/optimizer.py
index 8c252c7cd..639943a4b 100644
--- a/python/singa/optimizer.py
+++ b/python/singa/optimizer.py
@@ -17,6 +17,8 @@
 # =============================================================================
 '''This module includes a set of optimizers for updating model parameters.
 
+Note: This module is deprecated. Please use the opt module.
+
 Example usage::
 
   from singa import optimizer
diff --git a/python/singa/snapshot.py b/python/singa/snapshot.py
index cae151d81..67f246bc5 100644
--- a/python/singa/snapshot.py
+++ b/python/singa/snapshot.py
@@ -18,6 +18,9 @@
 '''
 This script includes io::snapshot class and its methods.
 
+Note: This module is deprecated. Please use the model module for
+checkpointing and restoring.
+
 Example usages::
 
   from singa import snapshot
diff --git a/test/python/test_initializer.py b/test/python/test_initializer.py
new file mode 100644
index 000000000..cbd082e60
--- /dev/null
+++ b/test/python/test_initializer.py
@@ -0,0 +1,123 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from singa import initializer
+from singa import tensor
+from singa import singa_wrap
+
+from cuda_helper import gpu_dev, cpu_dev
+
+import unittest
+import numpy as np
+
+
+class TestInitializer(unittest.TestCase):
+
+    def setUp(self):
+        self.t1 = tensor.Tensor((40, 90))
+        self.t2 = tensor.Tensor((30, 50, 8))
+        self.t3 = tensor.Tensor((30, 50, 4, 8))
+
+    def compute_fan(self, shape):
+        if len(shape) == 2:
+            fan_in = shape[0]
+            fan_out = shape[1]
+        elif len(shape) in {3, 4, 5}:
+            fan_in = shape[1] * np.prod(shape[2:])
+            fan_out = shape[0] * np.prod(shape[2:])
+        else:
+            fan_in = fan_out = np.sqrt(np.prod(shape))
+
+        return fan_in, fan_out
+
+    def he_uniform(self, dev):
+
+        def init(shape):
+            fan_in, _ = self.compute_fan(shape)
+            limit = np.sqrt(6 / fan_in)
+            return limit
+
+        self.t1.to_device(dev)
+        initializer.he_uniform(self.t1)
+        np_t1 = tensor.to_numpy(self.t1)
+        limit = init(self.t1.shape)
+        self.assertAlmostEqual(np_t1.max(), limit, delta=limit/10)
+        self.assertAlmostEqual(np_t1.min(), -limit, delta=limit/10)
+        self.assertAlmostEqual(np_t1.mean(), 0, delta=limit/10)
+
+        self.t2.to_device(dev)
+        initializer.he_uniform(self.t2)
+        np_t2 = tensor.to_numpy(self.t2)
+        limit = init(self.t2.shape)
+        self.assertAlmostEqual(np_t2.max(), limit, delta=limit/10)
+        self.assertAlmostEqual(np_t2.min(), -limit, delta=limit/10)
+        self.assertAlmostEqual(np_t2.mean(), 0, delta=limit/10)
+
+        self.t3.to_device(dev)
+        initializer.he_uniform(self.t3)
+        np_t3 = tensor.to_numpy(self.t3)
+        limit = init(self.t3.shape)
+        self.assertAlmostEqual(np_t3.max(), limit, delta=limit/10)
+        self.assertAlmostEqual(np_t3.min(), -limit, delta=limit/10)
+        self.assertAlmostEqual(np_t3.mean(), 0, delta=limit/10)
+
+    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
+    def test_he_uniform_gpu(self):
+        self.he_uniform(gpu_dev)
+
+    def test_he_uniform_cpu(self):
+        self.he_uniform(cpu_dev)
+
+    def he_normal(self, dev):
+
+        def init(shape):
+            fan_in, _ = self.compute_fan(shape)
+            stddev = np.sqrt(2 / fan_in)
+            return stddev
+
+        self.t1.to_device(dev)
+        initializer.he_normal(self.t1)
+        np_t1 = tensor.to_numpy(self.t1)
+        stddev = init(self.t1.shape)
+        self.assertAlmostEqual(np_t1.mean(), 0, delta=stddev/10)
+        self.assertAlmostEqual(np_t1.std(), stddev, delta=stddev/10)
+
+        self.t2.to_device(dev)
+        initializer.he_normal(self.t2)
+        np_t2 = tensor.to_numpy(self.t2)
+        stddev = init(self.t2.shape)
+        self.assertAlmostEqual(np_t2.mean(), 0, delta=stddev/10)
+        self.assertAlmostEqual(np_t2.std(), stddev, delta=stddev/10)
+
+        self.t3.to_device(dev)
+        initializer.he_normal(self.t3)
+        np_t3 = tensor.to_numpy(self.t3)
+        stddev = init(self.t3.shape)
+        self.assertAlmostEqual(np_t3.mean(), 0, delta=stddev/10)
+        self.assertAlmostEqual(np_t3.std(), stddev, delta=stddev/10)
+
+    @unittest.skipIf(not singa_wrap.USE_CUDA, 'CUDA is not enabled')
+    def test_he_normal_gpu(self):
+        self.he_normal(gpu_dev)
+
+    def test_he_normal_cpu(self):
+        self.he_normal(cpu_dev)
+
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file
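
Note: the assertions above compare sampled statistics (max/min/mean/std)
against the analytic limits with a tolerance of one tenth of the limit, so
the tests are statistical rather than exact. A sketch of the arithmetic
behind the first check, using the shapes from setUp::

    import numpy as np
    limit = np.sqrt(6 / 40)  # t1 is (40, 90), so fan_in = 40; limit ~= 0.387
    # he_uniform samples lie in [-limit, limit]; over 40 * 90 = 3600 samples
    # the observed max/min approach +/-limit and the mean approaches 0.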