Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Add conv2d_backward_weight op (without topi) #9954

Merged
merged 27 commits into from
Jan 19, 2022

Conversation

masahi
Copy link
Member

@masahi masahi commented Jan 18, 2022

This PR adds a Relay op for the gradient of conv2d op with respect to weight (wgrad for short). It is implemented simply using the existing equivalent expressions in

grad = tile(grad, [1, in_channel // attrs.groups, 1, 1])
grad = reshape(grad, [-1, 1, 0, 0]) # batch * oc * ic // groups, 1, oh, ow
data = reshape(data, [1, -1, 0, 0]) # 1, batch * ic, ih, iw
backward_weight = _nn.conv2d(
data,
grad,
strides=attrs.dilation,
padding=attrs.padding,
dilation=attrs.strides,
groups=in_channel * batch,
)
# infer shape of backward_weight
padded_weight_grad_h = (
in_h - (grad_h - 1) * stride_h - 1 + fpad_top + fpad_bottom
) // dilation_h + 1
padded_weight_grad_w = (
in_w - (grad_w - 1) * stride_w - 1 + fpad_left + fpad_right
) // dilation_w + 1
backward_weight = reshape(
backward_weight,
[
batch,
in_channel // attrs.groups,
out_channel,
padded_weight_grad_h,
padded_weight_grad_w,
],
)
backward_weight = _sum(backward_weight, axis=0)
backward_weight = transpose(backward_weight, [1, 0, 2, 3])
assert padded_weight_grad_h >= filter_h
assert padded_weight_grad_w >= filter_w
if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w:
backward_weight = strided_slice(
backward_weight,
begin=[0, 0, 0, 0],
end=[out_channel, in_channel // attrs.groups, filter_h, filter_w],
)
and make it the target for conv2d_backward_weight op legalization. So no topi op has been added for now.

The motivation for introducing this op is threefold:

  • Both cutlass and cuDNN have a dedicated op for wgrad. Having an op that maps one-to-one makes it easy to offload this op to these backends.
  • The existing implementation, as a composition of nn.con2d and other ops, works but it is likely to be inefficient since it involves tile(grad, [1, in_channel // attrs.groups, 1, 1]) , a larger group conv2d workload and other post-processing (sum, transpose, slice etc). A direct implementation would likely be much faster.
  • The third reason is more subtle but this was what necessitated this PR. If I want to use cuDNN or cutlass wgrad with NHWC layout, I'd run convert_layout pass on a backward graph, resulting in.
  %1 = tile(%dy, reps=[1, 4, 1, 1]) /* ty=Tensor[(2, 32, 32, 32), float32] */;
  %2 = reshape(%1, newshape=[-1, 1, 0, 0]) /* ty=Tensor[(64, 1, 32, 32), float32] */;
  %3 = layout_transform(%0, src_layout="NCHW", dst_layout="NHWC") /* ty=Tensor[(1, 32, 32, 8), float32] */;
  %4 = layout_transform(%2, src_layout="OIHW", dst_layout="OHWI") /* ty=Tensor[(64, 32, 32, 1), float32] */;
  %5 = nn.conv2d(%3, %4, padding=[1, 1, 1, 1], groups=8, data_layout="NHWC", kernel_layout="OHWI") /* ty=Tensor[(1, 3, 3, 64), float32] */;
  %6 = layout_transform(%5, src_layout="NHWC", dst_layout="NCHW") /* ty=Tensor[(1, 64, 3, 3), float32] */;
  %7 = reshape(%6, newshape=[2, 4, 8, 3, 3]) /* ty=Tensor[(2, 4, 8, 3, 3), float32] */;
  %8 = sum(%7, axis=[0]) /* ty=Tensor[(4, 8, 3, 3), float32] */;
  transpose(%8, axes=[1, 0, 2, 3]) /* ty=Tensor[(8, 4, 3, 3), float32] */

I cannot pattern match this graph and extract NHWC conv2d wgrad, since layout_transform ops are "too close" to nn.conv2d. I need them to happen before tile and after transpose. So this anti-behavior is the strong reason to want a dedicated wgrad op representation in Relay.

cc @vinx13 @tkonolige @comaniac @YuchenJin

Copy link
Contributor

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@tkonolige tkonolige left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @masahi!

@masahi masahi merged commit fd5915a into apache:main Jan 19, 2022
yuanfz98 pushed a commit to yuanfz98/tvm that referenced this pull request Jan 24, 2022
* python plumbing

* add cpp def

* legalize worked

* clean up

* layout conversion doesnt work

* extract wgrad body

* fix convert layout

* black

* fix kernel size

* revert irrelevant change

* add doc, clarify the meanings of parameters

* update layout convert

* test passed

* fixed layout conversion

* update convert layout

* remove print

* remove layout convert for now

* minor fix

* removed unused import

* add wgrad python reference

* add test stub

* add doc

* test other stride and pad

* tweak

* more pylint filter

* fix typo in doc

* swap arg order (data, grad) to be consistent with conv2d_transpose(dgrad)
crazydemo pushed a commit to crazydemo/tvm that referenced this pull request Jan 27, 2022
* python plumbing

* add cpp def

* legalize worked

* clean up

* layout conversion doesnt work

* extract wgrad body

* fix convert layout

* black

* fix kernel size

* revert irrelevant change

* add doc, clarify the meanings of parameters

* update layout convert

* test passed

* fixed layout conversion

* update convert layout

* remove print

* remove layout convert for now

* minor fix

* removed unused import

* add wgrad python reference

* add test stub

* add doc

* test other stride and pad

* tweak

* more pylint filter

* fix typo in doc

* swap arg order (data, grad) to be consistent with conv2d_transpose(dgrad)
ylc pushed a commit to ylc/tvm that referenced this pull request Feb 16, 2022
* python plumbing

* add cpp def

* legalize worked

* clean up

* layout conversion doesnt work

* extract wgrad body

* fix convert layout

* black

* fix kernel size

* revert irrelevant change

* add doc, clarify the meanings of parameters

* update layout convert

* test passed

* fixed layout conversion

* update convert layout

* remove print

* remove layout convert for now

* minor fix

* removed unused import

* add wgrad python reference

* add test stub

* add doc

* test other stride and pad

* tweak

* more pylint filter

* fix typo in doc

* swap arg order (data, grad) to be consistent with conv2d_transpose(dgrad)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants