diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 79fe0d734b6d..dcdd4c5ce9a4 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -474,7 +474,10 @@ def _impl_v1(cls, inputs, attr, params):
             op_name="batch_norm",
             ignores=["spatial", "is_test", "consumed_inputs", "momentum", "training_mode"],
         )(inputs, attr, params)
-        return out[0]
+        # We only support test mode, so we return data, moving_mean, moving_var,
+        # and then moving_mean and moving_var again as placeholders for
+        # the expected "saved_mean", "saved_var".
+        return _expr.TupleWrapper(_expr.Tuple((*out, out[1], out[2])), 5)
 
 
 class InstanceNorm(OnnxOpConverter):
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 3cec70d3d182..e9ec29c1b5c7 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -152,6 +152,11 @@ def legalize_batch_matmul(attrs, inputs, types):
 reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
+# batch_norm
+reg.register_strategy("nn.batch_norm", strategy.batch_norm_strategy)
+reg.register_pattern("nn.batch_norm", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
+
+
 # sparse_dense
 @reg.register_compute("nn.sparse_dense")
 def compute_sparse_dense(attrs, inputs, out_type):
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 11ce22646c16..461e755f5212 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -848,6 +848,29 @@ def batch_matmul_strategy(attrs, inputs, out_type, target):
     return strategy
 
 
+# batch_norm
+def wrap_compute_batch_norm(topi_compute):
+    """wrap batch_norm topi compute"""
+
+    def _compute_batch_norm(attrs, inputs, out_type):
+        return topi_compute(*inputs, attrs.axis, attrs.epsilon, attrs.center, attrs.scale)
+
+    return _compute_batch_norm
+
+
+@override_native_generic_func("batch_norm_strategy")
+def batch_norm_strategy(attrs, inputs, out_type, target):
+    """batch_norm generic strategy"""
+    logger.warning("batch_norm is not optimized for this platform.")
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_batch_norm(topi.nn.batch_norm),
+        wrap_topi_schedule(topi.generic.schedule_batch_norm),
+        name="batch_norm.generic",
+    )
+    return strategy
+
+
 # sparse dense
 def wrap_compute_sparse_dense(topi_compute):
     """wrap sparse dense topi compute"""
diff --git a/python/tvm/topi/generic/nn.py b/python/tvm/topi/generic/nn.py
index 22a90aa2cd07..ba63c539133e 100644
--- a/python/tvm/topi/generic/nn.py
+++ b/python/tvm/topi/generic/nn.py
@@ -815,6 +815,23 @@ def schedule_batch_matmul(outs):
     return _default_schedule(outs, False)
 
 
+def schedule_batch_norm(outs):
+    """Schedule for batch_norm
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of batch_norm
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
+
 def schedule_correlation_nchw(outs):
     """Schedule for correlation_nchw
 
diff --git a/python/tvm/topi/nn/__init__.py b/python/tvm/topi/nn/__init__.py
index b5e766adbc12..d3d00305a17b 100644
--- a/python/tvm/topi/nn/__init__.py
+++ b/python/tvm/topi/nn/__init__.py
@@ -42,6 +42,7 @@
 from .bitserial_conv2d import *
 from .bitserial_dense import *
 from .batch_matmul import *
+from .batch_norm import *
 from .sparse import *
 from .pad import *
 from .fifo_buffer import *
diff --git a/python/tvm/topi/nn/batch_norm.py b/python/tvm/topi/nn/batch_norm.py
new file mode 100644
index 000000000000..1b4fad762568
--- /dev/null
+++ b/python/tvm/topi/nn/batch_norm.py
@@ -0,0 +1,110 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Batch normalization."""
+import typing
+
+from tvm import te
+from tvm import topi
+
+
+def batch_norm(
+    data: te.Tensor,
+    gamma: te.Tensor,
+    beta: te.Tensor,
+    moving_mean: te.Tensor,
+    moving_var: te.Tensor,
+    axis: typing.Optional[int] = None,
+    epsilon: typing.Optional[float] = None,
+    center: typing.Optional[bool] = None,
+    scale: typing.Optional[bool] = None,
+) -> typing.List[te.Tensor]:
+    """Batch normalization layer (Ioffe and Szegedy, 2014).
+
+    Normalizes the input at each batch, i.e. applies a transformation
+    that maintains the mean activation close to 0 and the activation
+    standard deviation close to 1.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        Input to be batch-normalized.
+
+    gamma : tvm.te.Tensor
+        Scale factor to be applied to the normalized tensor.
+
+    beta : tvm.te.Tensor
+        Offset to be applied to the normalized tensor.
+
+    moving_mean : tvm.te.Tensor
+        Running mean of input.
+
+    moving_var : tvm.te.Tensor
+        Running variance of input.
+
+    axis : int, optional, default=1
+        Specify along which shape axis the normalization should occur.
+
+    epsilon : float, optional, default=1e-5
+        Small float added to variance to avoid dividing by zero.
+
+    center : bool, optional, default=True
+        If True, add offset of beta to normalized tensor. If False,
+        beta is ignored.
+
+    scale : bool, optional, default=True
+        If True, scale normalized tensor by gamma. If False, gamma
+        is ignored.
+
+    Returns
+    -------
+    output : list of tvm.te.Tensor
+        Normalized data with same shape as input.
+
+    moving_mean : tvm.te.Tensor
+        Running mean of input.
+
+    moving_var : tvm.te.Tensor
+        Running variance of input.
+    """
+    if axis is None:
+        axis = 1
+
+    if epsilon is None:
+        epsilon = 1e-5
+
+    if center is None:
+        center = True
+
+    if scale is None:
+        scale = True
+
+    shape = [1] * len(data.shape)
+    shape[axis] = data.shape[axis]
+
+    moving_mean_rs = topi.reshape(moving_mean, shape)
+    moving_var_rs = topi.reshape(moving_var, shape)
+
+    out = (data - moving_mean_rs) / topi.math.sqrt(moving_var_rs + epsilon)
+
+    if scale:
+        out = out * topi.reshape(gamma, shape)
+    if center:
+        out = out + topi.reshape(beta, shape)
+
+    # Moving mean and var aren't updated during test. To avoid
+    # placeholder reuse, we multiply by 1 and return them.
+    return [out, moving_mean * 1, moving_var * 1]
diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py
index 2d7d0a4b9e11..8f78805fff3b 100644
--- a/python/tvm/topi/testing/__init__.py
+++ b/python/tvm/topi/testing/__init__.py
@@ -49,6 +49,7 @@
 from .gather_nd_python import gather_nd_python
 from .strided_slice_python import strided_slice_python, strided_set_python
 from .batch_matmul import batch_matmul
+from .batch_norm import batch_norm
 from .slice_axis_python import slice_axis_python
 from .sequence_mask_python import sequence_mask
 from .poolnd_python import poolnd_python
diff --git a/python/tvm/topi/testing/batch_norm.py b/python/tvm/topi/testing/batch_norm.py
new file mode 100644
index 000000000000..0a79b6849d4e
--- /dev/null
+++ b/python/tvm/topi/testing/batch_norm.py
@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Batch Normalization implemented in Numpy."""
+import numpy as np
+
+
+def batch_norm(
+    x: np.ndarray,
+    gamma: np.ndarray,
+    beta: np.ndarray,
+    moving_mean: np.ndarray,
+    moving_var: np.ndarray,
+    axis: int,
+    epsilon: float,
+    center: bool,
+    scale: bool,
+):
+    """Batch Normalization operator implemented in Numpy.
+
+    Parameters
+    ----------
+    x : np.ndarray
+        Input to be batch-normalized.
+
+    gamma : np.ndarray
+        Scale factor to be applied to the normalized tensor.
+
+    beta : np.ndarray
+        Offset to be applied to the normalized tensor.
+
+    moving_mean : np.ndarray
+        Running mean of input.
+
+    moving_var : np.ndarray
+        Running variance of input.
+
+    axis : int
+        Specify along which shape axis the normalization should occur.
+
+    epsilon : float
+        Small float added to variance to avoid dividing by zero.
+
+    center : bool
+        If True, add offset of beta to normalized tensor. If False,
+        beta is ignored.
+
+    scale : bool
+        If True, scale normalized tensor by gamma. If False, gamma
+        is ignored.
+
+    Returns
+    -------
+    output : np.ndarray
+        Normalized data with same shape as input.
+
+    moving_mean : np.ndarray
+        Running mean of input.
+
+    moving_var : np.ndarray
+        Running variance of input.
+    """
+    shape = [1] * len(x.shape)
+    shape[axis] = x.shape[axis]
+
+    moving_mean_rs = moving_mean.reshape(shape)
+    moving_var_rs = moving_var.reshape(shape)
+
+    out = (x - moving_mean_rs) / np.sqrt(moving_var_rs + epsilon)
+
+    if scale:
+        out = out * gamma.reshape(shape)
+    if center:
+        out = out + beta.reshape(shape)
+
+    return [out, moving_mean, moving_var]
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 574ecc0828dd..89ef2708ff27 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -745,7 +745,8 @@ bool BatchNormRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   reporter->Assign(types[4], TensorType({axis_size}, data->dtype));
 
   // output is a tuple of the normed data (same shape as input), new running mean,
-  // and new running average (the latter two are both vectors of length dim)
+  // new running variance, saved mean and saved variance (the latter are all
+  // vectors of length dim)
   std::vector<Type> fields;
   auto vec_ty = TensorType(Array<IndexExpr>({data->shape[axis]}), data->dtype);
   fields.push_back(TensorType(data->shape, data->dtype));
diff --git a/src/topi/schedule.cc b/src/topi/schedule.cc
index 21f863bb2e70..0999f00ffd11 100644
--- a/src/topi/schedule.cc
+++ b/src/topi/schedule.cc
@@ -230,6 +230,9 @@ TVM_REGISTER_GENERIC_FUNC(schedule_dense)
 TVM_REGISTER_GENERIC_FUNC(schedule_batch_matmul)
     .set_default(WrapSchedule(topi::generic::default_schedule));
 
+TVM_REGISTER_GENERIC_FUNC(schedule_batch_norm)
+    .set_default(WrapSchedule(topi::generic::default_schedule));
+
 TVM_REGISTER_GENERIC_FUNC(schedule_pool)
     .set_default(WrapSchedule(topi::generic::default_schedule))
     .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule))
diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py
index 97406e7e0d48..2f3f86349a86 100644
--- a/tests/python/relay/test_op_level1.py
+++ b/tests/python/relay/test_op_level1.py
@@ -387,6 +387,7 @@ def test_batch_norm():
             )
         )
+        # axis=1
         beta = relay.var("beta", relay.TensorType((3,), dtype))
         gamma = relay.var("gamma", relay.TensorType((3,), dtype))
         moving_mean = relay.var("moving_mean", relay.TensorType((3,), dtype))
         moving_var = relay.var("moving_var", relay.TensorType((3,), dtype))
@@ -427,6 +428,53 @@
         )
 
 
+def test_batch_norm_fold_const():
+    axis = 1
+    dtype = "float32"
+    shape = [4, 5, 6]
+
+    data_np = np.random.random(shape).astype(dtype)
+    beta_np = np.random.random(shape[axis]).astype(dtype)
+    gamma_np = np.random.random(shape[axis]).astype(dtype)
+    moving_mean_np = np.random.random(shape[axis]).astype(dtype)
+    moving_var_np = np.random.random(shape[axis]).astype(dtype)
+
+    data = relay.var("data", relay.TensorType(shape, dtype))
+    beta = relay.var("beta", relay.TensorType((shape[1],), dtype))
+    gamma = relay.var("gamma", relay.TensorType((shape[1],), dtype))
+    moving_mean = relay.var("moving_mean", relay.TensorType((shape[1],), dtype))
+    moving_var = relay.var("moving_var", relay.TensorType((shape[1],), dtype))
+    out = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var, axis=axis).astuple()
+    func = relay.Function([data, gamma, beta, moving_mean, moving_var], out)
+
+    out_const = relay.nn.batch_norm(
+        relay.const(data_np),
+        relay.const(gamma_np),
+        relay.const(beta_np),
+        relay.const(moving_mean_np),
+        relay.const(moving_var_np),
+        axis=axis,
+    ).astuple()
+    func_const = relay.Function([], out_const)
+
+    # Build the module with constants to have FoldConstant transform batch_norm.
+    mod_const = tvm.IRModule.from_expr(func_const)
+    mod_const = relay.transform.FoldConstant()(mod_const)
+
+    const_data_out = mod_const["main"].body[0].data
+    const_moving_mean_out = mod_const["main"].body[1].data
+    const_moving_var_out = mod_const["main"].body[2].data
+
+    # Run the Relay func without constants. This will use SimplifyInference instead.
+    vm_data_out, vm_moving_mean_out, vm_moving_var_out = relay.create_executor(
+        "vm", device=tvm.device("llvm"), target="llvm"
+    ).evaluate(func)(data_np, gamma_np, beta_np, moving_mean_np, moving_var_np)
+
+    tvm.testing.assert_allclose(const_data_out.numpy(), vm_data_out.numpy())
+    tvm.testing.assert_allclose(const_moving_mean_out.numpy(), vm_moving_mean_out.numpy())
+    tvm.testing.assert_allclose(const_moving_var_out.numpy(), vm_moving_var_out.numpy())
+
+
 @pytest.mark.xfail
 def test_matmul_type_check():
     dtype = "float16"
diff --git a/tests/python/topi/python/test_topi_batch_norm.py b/tests/python/topi/python/test_topi_batch_norm.py
new file mode 100644
index 000000000000..202b6214bc7a
--- /dev/null
+++ b/tests/python/topi/python/test_topi_batch_norm.py
@@ -0,0 +1,115 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Tests for the batch_norm operator."""
+import numpy as np
+import pytest
+
+import tvm
+from tvm import te
+from tvm import topi
+import tvm.testing
+import tvm.topi.testing
+
+
+_DEVICE = "llvm"
+_BATCH_NORM_IMPLEMENT = {
+    "generic": (topi.nn.batch_norm, topi.generic.schedule_batch_norm),
+}
+
+
+@pytest.mark.parametrize(
+    "shape, axis, epsilon, center, scale",
+    [
+        ((1,), 0, 0.1, True, True),
+        ((2, 3), 0, 0.1, True, True),
+        ((1, 2, 4), 0, 0.1, True, True),
+        ((1, 2, 3, 4), 0, 0.001, False, False),
+        ((2, 3, 4, 1), 1, 0.01, False, True),
+        ((3, 4, 1, 2), 2, 0.1, True, False),
+        ((4, 1, 2, 3), 3, 1.0, True, True),
+        ((1, 2, 4, 4, 5), 0, 0.1, True, True),
+    ],
+)
+def test_batch_norm(shape, axis, epsilon, center, scale):
+    x_np = np.random.random(shape).astype("float32")
+    gamma_np = np.random.random(shape[axis]).astype("float32")
+    beta_np = np.random.random(shape[axis]).astype("float32")
+    moving_mean_np = np.random.random(shape[axis]).astype("float32")
+    moving_var_np = np.random.random(shape[axis]).astype("float32")
+
+    out_x_np, out_moving_mean_np, out_moving_var_np = tvm.topi.testing.batch_norm(
+        x_np, gamma_np, beta_np, moving_mean_np, moving_var_np, axis, epsilon, center, scale
+    )
+
+    x_te = te.placeholder(shape, name="x", dtype="float32")
+    gamma_te = te.placeholder((shape[axis],), name="gamma", dtype="float32")
+    beta_te = te.placeholder((shape[axis],), name="beta", dtype="float32")
+    moving_mean_te = te.placeholder((shape[axis],), name="moving_mean", dtype="float32")
+    moving_var_te = te.placeholder((shape[axis],), name="moving_var", dtype="float32")
+
+    with tvm.target.Target(_DEVICE):
+        fcompute, fschedule = tvm.topi.testing.dispatch(_DEVICE, _BATCH_NORM_IMPLEMENT)
+        out_x, out_moving_mean, out_moving_var = fcompute(
+            x_te, gamma_te, beta_te, moving_mean_te, moving_var_te, axis, epsilon, center, scale
+        )
+        s = fschedule([out_x, out_moving_mean, out_moving_var])
+
+    dev = tvm.device(_DEVICE, 0)
+
+    x_tvm = tvm.nd.array(x_np, dev)
+    gamma_tvm = tvm.nd.array(gamma_np, dev)
+    beta_tvm = tvm.nd.array(beta_np, dev)
+    moving_mean_tvm = tvm.nd.array(moving_mean_np, dev)
+    moving_var_tvm = tvm.nd.array(moving_var_np, dev)
+    out_x_tvm = tvm.nd.array(np.zeros(shape, dtype=out_x.dtype), dev)
+    out_moving_mean_tvm = tvm.nd.array(
+        np.zeros((shape[axis],), dtype=out_moving_mean.dtype), dev
+    )
+    out_moving_var_tvm = tvm.nd.array(np.zeros((shape[axis],), dtype=out_moving_var.dtype), dev)
+
+    f = tvm.build(
+        s,
+        [
+            x_te,
+            gamma_te,
+            beta_te,
+            moving_mean_te,
+            moving_var_te,
+            out_x,
+            out_moving_mean,
+            out_moving_var,
+        ],
+        _DEVICE,
+    )
+    f(
+        x_tvm,
+        gamma_tvm,
+        beta_tvm,
+        moving_mean_tvm,
+        moving_var_tvm,
+        out_x_tvm,
+        out_moving_mean_tvm,
+        out_moving_var_tvm,
+    )
+
+    tvm.testing.assert_allclose(out_x_tvm.numpy(), out_x_np, rtol=1e-3)
+    tvm.testing.assert_allclose(out_moving_mean_tvm.numpy(), out_moving_mean_np, rtol=1e-3)
+    tvm.testing.assert_allclose(out_moving_var_tvm.numpy(), out_moving_var_np, rtol=1e-3)
+
+
+if __name__ == "__main__":
+    test_batch_norm()