[CUDNN] Support gradient kernels #9986

Merged 21 commits on Jan 22, 2022

@masahi masahi commented Jan 19, 2022

This adds support for computing the conv2d gradients with respect to data and filter using cuDNN. See the changes in test_cudnn.py. The diff is large, but most of it is boilerplate.
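
For readers less familiar with the two gradients, here is a minimal pure-Python reference sketch of what the data gradient (dgrad) and the filter gradient (wgrad) compute. It assumes NCHW layout, stride 1, no padding, and no dilation. All function names are illustrative and are not part of TVM or cuDNN; the cuDNN kernels are implemented very differently.

```python
from itertools import product


def zeros(shape):
    """Nested-list tensor of zeros with the given 4-D shape."""
    a, b, c, d = shape
    return [[[[0.0] * d for _ in range(c)] for _ in range(b)] for _ in range(a)]


def out_shape(x_shape, w_shape):
    """Output shape of a stride-1, unpadded conv2d (NCHW data, KCRS filter)."""
    n, _, h, wd = x_shape
    k, _, r, s = w_shape
    return (n, k, h - r + 1, wd - s + 1)


def taps(x_shape, w_shape):
    """Every (n, k, c, p, q, r, s) multiply-accumulate of the convolution."""
    n, c, _, _ = x_shape
    k, _, r, s = w_shape
    _, _, p, q = out_shape(x_shape, w_shape)
    return product(*(range(e) for e in (n, k, c, p, q, r, s)))


def conv2d(x, w, x_shape, w_shape):
    """Forward pass: y accumulates x * w."""
    y = zeros(out_shape(x_shape, w_shape))
    for n, k, c, p, q, r, s in taps(x_shape, w_shape):
        y[n][k][p][q] += x[n][c][p + r][q + s] * w[k][c][r][s]
    return y


def conv2d_dgrad(dy, w, x_shape, w_shape):
    """Gradient wrt data: dx accumulates dy * w at the input position."""
    dx = zeros(x_shape)
    for n, k, c, p, q, r, s in taps(x_shape, w_shape):
        dx[n][c][p + r][q + s] += dy[n][k][p][q] * w[k][c][r][s]
    return dx


def conv2d_wgrad(dy, x, x_shape, w_shape):
    """Gradient wrt filter: dw accumulates dy * x at the filter position."""
    dw = zeros(w_shape)
    for n, k, c, p, q, r, s in taps(x_shape, w_shape):
        dw[k][c][r][s] += dy[n][k][p][q] * x[n][c][p + r][q + s]
    return dw
```

All three loops visit the same index set and differ only in which tensor is accumulated, which is why a slow reference like this is handy for checking a fast kernel on small shapes.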

This also enables offloading the Relay conv2d_backward_weight op introduced in #9954 directly to the cuDNN equivalent, without legalization. The strategy for that op is populated only when cudnn is enabled in the target; otherwise an error is emitted that suggests running the Legalize pass. This feature can go into a separate PR if desired (it is the last part of the commits).
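
The dispatch rule described above can be sketched roughly as follows. This is a simplified stand-in in plain Python, not the actual TVM strategy code: the function name, the `target_libs` argument, the returned implementation name, and the error text are all hypothetical.

```python
def select_conv2d_backward_weight_impl(target_libs):
    """Hypothetical stand-in for the op strategy described above.

    `target_libs` plays the role of the external libraries enabled in the
    target (an assumption made for this sketch).
    """
    if "cudnn" in target_libs:
        # cuDNN is enabled: offload the op directly, without legalization.
        return "conv2d_backward_weight.cudnn"
    # No native implementation is registered, so fail with a hint.
    raise RuntimeError(
        "conv2d_backward_weight is only supported with cudnn here; "
        "run the Legalize pass first."
    )
```

Failing loudly rather than silently falling back means a user who forgot to enable cudnn or to run Legalize gets an actionable message at compile time.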

After opening this PR, I realized that I hadn't enabled offloading the Relay conv2d_transpose op to the cuDNN dgrad kernel. I'll follow up later. (UPDATE: Done; this can also go into a separate PR.)

cc @vinx13 @comaniac @tkonolige @Laurawly @Hzfengsy @YuchenJin

@Hzfengsy Hzfengsy left a comment

LGTM

commit 426e5dc
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Tue Jan 18 11:48:53 2022 +0900

    black

commit 211a58b
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Tue Jan 18 11:43:52 2022 +0900

    fp16 also works

commit c2a34d4
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Tue Jan 18 11:36:36 2022 +0900

    nhwc test also worked

commit c0609ab
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Tue Jan 18 11:21:23 2022 +0900

    nchw test worked

commit 2bf68c7
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Tue Jan 18 10:41:35 2022 +0900

    add test stub

commit c86b128
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Tue Jan 18 10:32:09 2022 +0900

    add python definition stub

commit 3166952
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Tue Jan 18 06:57:18 2022 +0900

    bwd filter compiled

commit e311ba3
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Tue Jan 18 06:27:55 2022 +0900

    dgrad compiled

commit 47f35be
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Tue Jan 18 06:16:43 2022 +0900

    add dgrad stub

commit ebed032
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon Jan 17 17:01:56 2022 +0900

    cpplint

commit 834f54a
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon Jan 17 16:55:58 2022 +0900

    remove cudnn get output

commit dcbd9c9
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon Jan 17 16:28:07 2022 +0900

    more refactor

commit 146464e
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon Jan 17 15:57:35 2022 +0900

    Introduce SetConvdescriptors to refactor cudnn/conv_forward.cc

@YuchenJin YuchenJin left a comment

LGTM! Thanks @masahi!

@masahi masahi merged commit d35b858 into apache:main Jan 22, 2022

masahi commented Jan 22, 2022

Thanks @comaniac @Hzfengsy @YuchenJin

yuanfz98 pushed a commit to yuanfz98/tvm that referenced this pull request Jan 24, 2022
* Dgrad nchw, nhwc, fp16 working

* add python function for cudnn wgrad

* adding wgrad test

* black

* wgrad nchw and nhwc worked

* remove bwd algo name stuff

* compute output shape properly

* swap arg order in wgrad

* add kernel size arg in test

* black

* cleanup

* more fix

* fix dgrad test

* support running relay conv2d_backward_weight directly with cudnn

* black

* refactor reference function to support nhwc

* removed unused function

* lint

* enable offloading conv2d_transpose to cudnn dgrad

* relax tol

* name fix, remove print
ylc pushed a commit to ylc/tvm that referenced this pull request Feb 16, 2022