Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Add conv2d_backward_weight op (without topi) #9954

Merged
merged 27 commits into from
Jan 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 13 additions & 40 deletions python/tvm/relay/op/_tensor_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
reshape_like,
strided_slice,
take,
tile,
transpose,
where,
repeat,
Expand Down Expand Up @@ -399,15 +398,14 @@ def conv2d_grad(orig, grad):
data_shape = get_const_tuple(data.checked_type.shape)
weight_shape = get_const_tuple(weight.checked_type.shape)
_, _, grad_h, grad_w = get_const_tuple(orig.checked_type.shape)
batch, in_channel, in_h, in_w = data_shape
out_channel, _, filter_h, filter_w = weight_shape
_, _, in_h, in_w = data_shape
_, _, filter_h, filter_w = weight_shape

# infer output_padding
fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(
get_const_tuple(attrs.padding), (filter_h, filter_w)
)
stride_h, stride_w = get_const_tuple(attrs.strides)
dilation_h, dilation_w = get_const_tuple(attrs.dilation)
out_h = (grad_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h
out_w = (grad_w - 1) * stride_w - fpad_left - fpad_right + filter_w
output_padding = (in_h - out_h, in_w - out_w)
Expand All @@ -425,46 +423,21 @@ def conv2d_grad(orig, grad):
groups=attrs.groups,
output_padding=output_padding,
)
grad = tile(grad, [1, in_channel // attrs.groups, 1, 1])
grad = reshape(grad, [-1, 1, 0, 0]) # batch * oc * ic // groups, 1, oh, ow
data = reshape(data, [1, -1, 0, 0]) # 1, batch * ic, ih, iw

backward_weight = _nn.conv2d(
data,
backward_weight = _nn.conv2d_backward_weight(
grad,
strides=attrs.dilation,
data,
strides=attrs.strides,
padding=attrs.padding,
dilation=attrs.strides,
groups=in_channel * batch,
)
# infer shape of backward_weight
padded_weight_grad_h = (
in_h - (grad_h - 1) * stride_h - 1 + fpad_top + fpad_bottom
) // dilation_h + 1
padded_weight_grad_w = (
in_w - (grad_w - 1) * stride_w - 1 + fpad_left + fpad_right
) // dilation_w + 1
backward_weight = reshape(
backward_weight,
[
batch,
in_channel // attrs.groups,
out_channel,
padded_weight_grad_h,
padded_weight_grad_w,
],
dilation=attrs.dilation,
groups=attrs.groups,
channels=attrs.channels,
kernel_size=(filter_h, filter_w),
grad_layout=attrs.out_layout if attrs.out_layout else attrs.data_layout,
data_layout=attrs.data_layout,
kernel_layout=attrs.kernel_layout,
out_dtype=attrs.out_dtype,
)
backward_weight = _sum(backward_weight, axis=0)
backward_weight = transpose(backward_weight, [1, 0, 2, 3])

assert padded_weight_grad_h >= filter_h
assert padded_weight_grad_w >= filter_w
if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w:
backward_weight = strided_slice(
backward_weight,
begin=[0, 0, 0, 0],
end=[out_channel, in_channel // attrs.groups, filter_h, filter_w],
)

return [backward_data, backward_weight]

Expand Down
78 changes: 78 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from tvm.runtime import convert
from tvm.te.hybrid import script
from tvm.topi.utils import get_const_tuple
from tvm.topi.nn.utils import get_pad_tuple

from ....ir import container
from ....tir import expr
Expand Down Expand Up @@ -1061,6 +1062,83 @@ def compute_space_to_depth(attrs, inputs, out_dtype):
reg.register_injective_schedule("nn.batch_to_space_nd")


@reg.register_legalize("nn.conv2d_backward_weight")
def legalize_conv2d_backward_weight(attrs, inputs, types):
masahi marked this conversation as resolved.
Show resolved Hide resolved
"""Legalize conv2d_backward_weight op.

Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current op
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types

Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
grad, data = inputs
data_shape = get_const_tuple(data.checked_type.shape)
weight_shape = get_const_tuple(types[2].shape)
_, out_channel, grad_h, grad_w = get_const_tuple(grad.checked_type.shape)
batch, in_channel, in_h, in_w = data_shape
_, _, filter_h, filter_w = weight_shape
fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(
get_const_tuple(attrs.padding), (filter_h, filter_w)
)
stride_h, stride_w = get_const_tuple(attrs.strides)
dilation_h, dilation_w = get_const_tuple(attrs.dilation)

grad = relay.tile(grad, [1, in_channel // attrs.groups, 1, 1])
grad = relay.reshape(grad, [-1, 1, 0, 0]) # batch * oc * ic // groups, 1, oh, ow
data = relay.reshape(data, [1, -1, 0, 0]) # 1, batch * ic, ih, iw

backward_weight = relay.nn.conv2d(
data,
grad,
strides=attrs.dilation,
padding=attrs.padding,
dilation=attrs.strides,
groups=in_channel * batch,
)

# infer shape of backward_weight
padded_weight_grad_h = (
in_h - (grad_h - 1) * stride_h - 1 + fpad_top + fpad_bottom
) // dilation_h + 1
padded_weight_grad_w = (
in_w - (grad_w - 1) * stride_w - 1 + fpad_left + fpad_right
) // dilation_w + 1

backward_weight = relay.reshape(
backward_weight,
[
batch,
in_channel // attrs.groups,
out_channel,
padded_weight_grad_h,
padded_weight_grad_w,
],
)
backward_weight = relay.sum(backward_weight, axis=0)
backward_weight = relay.transpose(backward_weight, [1, 0, 2, 3])

assert padded_weight_grad_h >= filter_h
assert padded_weight_grad_w >= filter_w

if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w:
backward_weight = relay.strided_slice(
backward_weight,
begin=[0, 0, 0, 0],
end=[out_channel, in_channel // attrs.groups, filter_h, filter_w],
)

return backward_weight


#####################
# Shape functions #
#####################
Expand Down
51 changes: 51 additions & 0 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3770,3 +3770,54 @@ def batch_to_space_nd(data, block_shape, crops):
"""

return _make.batch_to_space_nd(data, block_shape, crops)


def conv2d_backward_weight(
grad,
data,
strides=(1, 1),
padding=(0, 0),
dilation=(1, 1),
groups=1,
channels=None,
kernel_size=None,
grad_layout="NCHW",
data_layout="NCHW",
kernel_layout="OIHW",
out_dtype="",
):
r"""The gradient of conv2d with respect to weight.

This operator takes the output gradient `grad` and convolves it with `data` as
the convolution kernel, to produce the gradient with respect to weight.

Note that the parameter `kernel_size` is the spatial size of the corresponding
forward convolution kernel, not that of `data`. `grad_layout` and
`kernel_layout` are the layouts of `grad` and the weight gradient respectively.

Other parameters are the same as the conv2d op. See its documentation for more
details.

"""
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
if isinstance(strides, int):
strides = (strides, strides)
if isinstance(dilation, int):
dilation = (dilation, dilation)
padding = get_pad_tuple2d(padding)

return _make.conv2d_backward_weight(
grad,
data,
strides,
padding,
dilation,
groups,
channels,
kernel_size,
grad_layout,
data_layout,
kernel_layout,
out_dtype,
)
1 change: 1 addition & 0 deletions python/tvm/relay/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def check_grad(

fwd_func = run_infer_type(func)
bwd_func = run_infer_type(gradient(fwd_func, mode=mode))
bwd_func = run_opt_pass(bwd_func, relay.transform.Legalize())

if scale is None:
scale = 10 * eps
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,4 @@
from .nll_loss import nll_loss
from .dense import dense
from .searchsorted import searchsorted_ref
from .conv2d_backcward_weight_python import conv2d_backward_weight_nchw_python
76 changes: 76 additions & 0 deletions python/tvm/topi/testing/conv2d_backcward_weight_python.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-nested-blocks
"""Gradient of conv2d with respect to weight in python"""
import numpy as np


# Reference: cutlass/tools/util/include/cutlass/util/reference/host/convolution.h
def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding):
"""Gradient of the conv2d op with respect to weight, in NCHW layout.

Parameters
----------
dy_np : numpy.ndarray
4-D with shape [batch, in_channel, out_height, out_width]

x_np : numpy.ndarray
4-D with shape [batch, in_channel, in_height, in_width]

kernel_size : tuple of two ints
Height and width of the weight

stride : tuple of two ints
Stride size, or [stride_height, stride_width]

padding : tuple of two ints
Spatial padding, or [pad_h, pad_w]

Returns
-------
b_np : np.ndarray
4-D with shape [num_filter, in_channel, filter_height, filter_width]

"""
N, C, H, W = x_np.shape
_, K, P, Q = dy_np.shape
R, S = kernel_size
pad_h, pad_w = padding
stride_h, stride_w = stride
dw = np.zeros((K, C, R, S)).astype(dy_np.dtype)

for k in range(K):
for r in range(R):
for s in range(S):
for c in range(C):
acc = 0
for n in range(N):
for p in range(P):
for q in range(Q):
coord = (n, c, p * stride_h - pad_h + r, q * stride_w - pad_w + s)

if (
coord[2] < H
and coord[2] >= 0
and coord[3] < W
and coord[3] >= 0
):
acc += dy_np[n, k, p, q] * x_np[coord]

dw[k, c, r, s] = acc

return dw
Loading