[TOPI] Add generic batch norm (apache#9694)
* Add topi batch norm and tests

* Handle none values correctly

* Return correct num outputs for onnx

* Use moving var/mean and update tests

* Add a test for batch norm folding

* Fix comment

* Format with black

* Re-order test args to match interface

* Call fold constant manually
michalpiszczek authored and baoxinqi committed Dec 27, 2021
1 parent 5c5bab9 commit fcf37f6
Showing 12 changed files with 418 additions and 2 deletions.
5 changes: 4 additions & 1 deletion python/tvm/relay/frontend/onnx.py
@@ -474,7 +474,10 @@ def _impl_v1(cls, inputs, attr, params):
op_name="batch_norm",
ignores=["spatial", "is_test", "consumed_inputs", "momentum", "training_mode"],
)(inputs, attr, params)
return out[0]
# We only support test mode, so we return data, moving_mean, moving_var,
# and then moving_mean and moving_var again as placeholders for
# the expected "saved_mean", "saved_var".
return _expr.TupleWrapper(_expr.Tuple((*out, out[1], out[2])), 5)
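For context, a minimal sketch (not part of the change; variable names and shapes are illustrative) of what the converter's 5-tuple looks like when built by hand. relay.nn.batch_norm itself produces three outputs, and the running statistics are repeated to stand in for the ONNX saved_mean/saved_var outputs:

from tvm import relay
from tvm.relay import expr as _expr

data = relay.var("data", shape=(1, 3, 8, 8))
gamma, beta = relay.var("gamma", shape=(3,)), relay.var("beta", shape=(3,))
mean, var = relay.var("mean", shape=(3,)), relay.var("var", shape=(3,))

# relay.nn.batch_norm yields a TupleWrapper of 3: (normalized data, moving_mean, moving_var)
out = relay.nn.batch_norm(data, gamma, beta, mean, var)

# Re-pack as 5 outputs (Y, mean, var, saved_mean, saved_var), reusing the running
# statistics as placeholders for the saved ones, as the converter above does.
five = _expr.TupleWrapper(_expr.Tuple((*out, out[1], out[2])), 5)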


class InstanceNorm(OnnxOpConverter):
5 changes: 5 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -152,6 +152,11 @@ def legalize_batch_matmul(attrs, inputs, types):
reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


# batch_norm
reg.register_strategy("nn.batch_norm", strategy.batch_norm_strategy)
reg.register_pattern("nn.batch_norm", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


# sparse_dense
@reg.register_compute("nn.sparse_dense")
def compute_sparse_dense(attrs, inputs, out_type):
23 changes: 23 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -848,6 +848,29 @@ def batch_matmul_strategy(attrs, inputs, out_type, target):
return strategy


# batch_norm
def wrap_compute_batch_norm(topi_compute):
"""wrap batch_norm topi compute"""

def _compute_batch_norm(attrs, inputs, out_type):
return topi_compute(*inputs, attrs.axis, attrs.epsilon, attrs.center, attrs.scale)

return _compute_batch_norm


@override_native_generic_func("batch_norm_strategy")
def batch_norm_strategy(attrs, inputs, out_type, target):
"""batch_norm generic strategy"""
logger.warning("batch_norm is not optimized for this platform.")
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_batch_norm(topi.nn.batch_norm),
wrap_topi_schedule(topi.generic.schedule_batch_norm),
name="batch_norm.generic",
)
return strategy
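If a backend later provides a tuned schedule, it could override the generic strategy along these lines (a sketch only; my_schedule_batch_norm is a hypothetical schedule, not something this commit adds):

@batch_norm_strategy.register(["cuda", "gpu"])
def batch_norm_strategy_cuda(attrs, inputs, out_type, target):
    """batch_norm cuda strategy (hypothetical override)"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_batch_norm(topi.nn.batch_norm),
        wrap_topi_schedule(my_schedule_batch_norm),  # hypothetical tuned schedule
        name="batch_norm.cuda",
    )
    return strategy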


# sparse dense
def wrap_compute_sparse_dense(topi_compute):
"""wrap sparse dense topi compute"""
17 changes: 17 additions & 0 deletions python/tvm/topi/generic/nn.py
@@ -815,6 +815,23 @@ def schedule_batch_matmul(outs):
return _default_schedule(outs, False)


def schedule_batch_norm(outs):
"""Schedule for batch_norm
Parameters
----------
outs: Array of Tensor
The computation graph description of batch_norm
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)


def schedule_correlation_nchw(outs):
"""Schedule for correlation_nchw
1 change: 1 addition & 0 deletions python/tvm/topi/nn/__init__.py
@@ -42,6 +42,7 @@
from .bitserial_conv2d import *
from .bitserial_dense import *
from .batch_matmul import *
from .batch_norm import *
from .sparse import *
from .pad import *
from .fifo_buffer import *
110 changes: 110 additions & 0 deletions python/tvm/topi/nn/batch_norm.py
@@ -0,0 +1,110 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Batch normalization."""
import typing

from tvm import te
from tvm import topi


def batch_norm(
data: te.Tensor,
gamma: te.Tensor,
beta: te.Tensor,
moving_mean: te.Tensor,
moving_var: te.Tensor,
axis: typing.Optional[int] = None,
epsilon: typing.Optional[float] = None,
center: typing.Optional[bool] = None,
scale: typing.Optional[bool] = None,
) -> typing.List[te.Tensor]:
"""Batch normalization layer (Ioffe and Szegedy, 2014).
Normalizes the input at each batch, i.e. applies a transformation
that maintains the mean activation close to 0 and the activation
standard deviation close to 1.
Parameters
----------
data : tvm.te.Tensor
Input to be batch-normalized.
gamma : tvm.te.Tensor
Scale factor to be applied to the normalized tensor.
beta : tvm.te.Tensor
Offset to be applied to the normalized tensor.
moving_mean : tvm.te.Tensor
Running mean of input.
moving_var : tvm.te.Tensor
Running variance of input.
axis : int, optional, default=1
Specify along which shape axis the normalization should occur.
epsilon : float, optional, default=1e-5
Small float added to variance to avoid dividing by zero.
center : bool, optional, default=True
If True, add offset of beta to the normalized tensor. If False,
beta is ignored.
scale : bool, optional, default=True
If True, scale the normalized tensor by gamma. If False, gamma
is ignored.
Returns
-------
output : list of tvm.te.Tensor
Normalized data with same shape as input
moving_mean : tvm.te.Tensor
Running mean of input.
moving_var : tvm.te.Tensor
Running variance of input.
"""
if axis is None:
axis = 1

if epsilon is None:
epsilon = 1e-5

if center is None:
center = True

if scale is None:
scale = True

shape = [1] * len(data.shape)
shape[axis] = data.shape[axis]

moving_mean_rs = topi.reshape(moving_mean, shape)
moving_var_rs = topi.reshape(moving_var, shape)

out = (data - moving_mean_rs) / topi.math.sqrt(moving_var_rs + epsilon)

if scale:
out = out * topi.reshape(gamma, shape)
if center:
out = out + topi.reshape(beta, shape)

# Moving mean and var aren't updated during inference. Multiply by one
# before returning them so we don't return the input placeholders directly.
return [out, moving_mean * 1, moving_var * 1]
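A minimal usage sketch of the new compute paired with the generic schedule (assumes a TVM build with the llvm backend; shapes are illustrative):

import tvm
from tvm import te, topi

data = te.placeholder((1, 3, 8, 8), name="data")
gamma = te.placeholder((3,), name="gamma")
beta = te.placeholder((3,), name="beta")
mean = te.placeholder((3,), name="moving_mean")
var = te.placeholder((3,), name="moving_var")

# Returns [normalized data, moving_mean * 1, moving_var * 1]
outs = topi.nn.batch_norm(data, gamma, beta, mean, var)
with tvm.target.Target("llvm"):
    s = topi.generic.schedule_batch_norm(outs)
func = tvm.build(s, [data, gamma, beta, mean, var] + outs, target="llvm")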
1 change: 1 addition & 0 deletions python/tvm/topi/testing/__init__.py
@@ -49,6 +49,7 @@
from .gather_nd_python import gather_nd_python
from .strided_slice_python import strided_slice_python, strided_set_python
from .batch_matmul import batch_matmul
from .batch_norm import batch_norm
from .slice_axis_python import slice_axis_python
from .sequence_mask_python import sequence_mask
from .poolnd_python import poolnd_python
89 changes: 89 additions & 0 deletions python/tvm/topi/testing/batch_norm.py
@@ -0,0 +1,89 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Batch Normalization implemented in Numpy."""
import numpy as np


def batch_norm(
x: np.ndarray,
gamma: np.ndarray,
beta: np.ndarray,
moving_mean: np.ndarray,
moving_var: np.ndarray,
axis: int,
epsilon: float,
center: bool,
scale: bool,
):
"""Batch Normalization operator implemented in Numpy.
Parameters
----------
x : np.ndarray
Input to be batch-normalized.
gamma : np.ndarray
Scale factor to be applied to the normalized tensor.
beta : np.ndarray
Offset to be applied to the normalized tensor.
moving_mean : np.ndarray
Running mean of input.
moving_var : np.ndarray
Running variance of input.
axis : int
Specify along which shape axis the normalization should occur.
epsilon : float
Small float added to variance to avoid dividing by zero.
center : bool
If True, add offset of beta to the normalized tensor. If False,
beta is ignored.
scale : bool
If True, scale normalized tensor by gamma. If False, gamma
is ignored.
Returns
-------
output : np.ndarray
Normalized data with same shape as input
moving_mean : np.ndarray
Running mean of input.
moving_var : np.ndarray
Running variance of input.
"""
shape = [1] * len(x.shape)
shape[axis] = x.shape[axis]

moving_mean_rs = moving_mean.reshape(shape)
moving_var_rs = moving_var.reshape(shape)

out = (x - moving_mean_rs) / np.sqrt(moving_var_rs + epsilon)

if scale:
out = out * gamma.reshape(shape)
if center:
out = out + beta.reshape(shape)

return [out, moving_mean, moving_var]
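A sketch of exercising the NumPy reference on its own (illustrative shapes; the checks added by this commit live in the test files):

import numpy as np
import tvm.topi.testing

x = np.random.uniform(size=(2, 3, 4, 5)).astype("float32")
gamma = np.random.uniform(size=(3,)).astype("float32")
beta = np.random.uniform(size=(3,)).astype("float32")
mean = np.random.uniform(size=(3,)).astype("float32")
var = np.random.uniform(size=(3,)).astype("float32")

ref_out, ref_mean, ref_var = tvm.topi.testing.batch_norm(
    x, gamma, beta, mean, var, axis=1, epsilon=1e-5, center=True, scale=True
)
assert ref_out.shape == x.shape  # normalization preserves the input shape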
3 changes: 2 additions & 1 deletion src/relay/op/nn/nn.cc
@@ -745,7 +745,8 @@ bool BatchNormRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
reporter->Assign(types[4], TensorType({axis_size}, data->dtype));

// output is a tuple of the normed data (same shape as input), new running mean,
// and new running average (the latter two are both vectors of length dim)
// new running variance, saved mean and saved variance (the latter are all
// vectors of length dim)
std::vector<Type> fields;
auto vec_ty = TensorType(Array<IndexExpr>({data->shape[axis]}), data->dtype);
fields.push_back(TensorType(data->shape, data->dtype));
3 changes: 3 additions & 0 deletions src/topi/schedule.cc
@@ -230,6 +230,9 @@ TVM_REGISTER_GENERIC_FUNC(schedule_dense)
TVM_REGISTER_GENERIC_FUNC(schedule_batch_matmul)
.set_default(WrapSchedule(topi::generic::default_schedule));

TVM_REGISTER_GENERIC_FUNC(schedule_batch_norm)
.set_default(WrapSchedule(topi::generic::default_schedule));

TVM_REGISTER_GENERIC_FUNC(schedule_pool)
.set_default(WrapSchedule(topi::generic::default_schedule))
.register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule))
48 changes: 48 additions & 0 deletions tests/python/relay/test_op_level1.py
@@ -387,6 +387,7 @@ def test_batch_norm():
)
)

# axis=1
beta = relay.var("beta", relay.TensorType((3,), dtype))
gamma = relay.var("gamma", relay.TensorType((3,), dtype))
moving_mean = relay.var("moving_mean", relay.TensorType((3,), dtype))
@@ -427,6 +428,53 @@
)


def test_batch_norm_fold_const():
axis = 1
dtype = "float32"
shape = [4, 5, 6]

data_np = np.random.random(shape).astype(dtype)
beta_np = np.random.random(shape[axis]).astype(dtype)
gamma_np = np.random.random(shape[axis]).astype(dtype)
moving_mean_np = np.random.random(shape[axis]).astype(dtype)
moving_var_np = np.random.random(shape[axis]).astype(dtype)

data = relay.var("data", relay.TensorType(shape, dtype))
beta = relay.var("beta", relay.TensorType((shape[1],), dtype))
gamma = relay.var("gamma", relay.TensorType((shape[1],), dtype))
moving_mean = relay.var("moving_mean", relay.TensorType((shape[1],), dtype))
moving_var = relay.var("moving_var", relay.TensorType((shape[1],), dtype))
out = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var, axis=axis).astuple()
func = relay.Function([data, gamma, beta, moving_mean, moving_var], out)

out_const = relay.nn.batch_norm(
relay.const(data_np),
relay.const(gamma_np),
relay.const(beta_np),
relay.const(moving_mean_np),
relay.const(moving_var_np),
axis=axis,
).astuple()
func_const = relay.Function([], out_const)

# Build the module with constants to have FoldConstant transform batch_norm.
mod_const = tvm.IRModule.from_expr(func_const)
mod_const = relay.transform.FoldConstant()(mod_const)

const_data_out = mod_const["main"].body[0].data
const_moving_mean_out = mod_const["main"].body[1].data
const_moving_var_out = mod_const["main"].body[2].data

# Run the Relay func without constants. This will use SimplifyInference instead.
vm_data_out, vm_moving_mean_out, vm_moving_var_out = relay.create_executor(
"vm", device=tvm.device("llvm"), target="llvm"
).evaluate(func)(data_np, gamma_np, beta_np, moving_mean_np, moving_var_np)

tvm.testing.assert_allclose(const_data_out.numpy(), vm_data_out.numpy())
tvm.testing.assert_allclose(const_moving_mean_out.numpy(), vm_moving_mean_out.numpy())
tvm.testing.assert_allclose(const_moving_var_out.numpy(), vm_moving_var_out.numpy())


@pytest.mark.xfail
def test_matmul_type_check():
dtype = "float16"
