[Hexagon] Slice op relu #11449

Merged (8 commits) on Jul 19, 2022
Conversation

@arangasa (Contributor) commented on May 25, 2022

Raising PR on behalf of rasagna-quic (author)

Thanks for contributing to TVM! Please refer to guideline https://tvm.apache.org/docs/contribute/ for useful information and tips. After the pull request is submitted, please request code reviews from Reviewers by @ them in the pull request thread.

cc @mehrdadh

@Lunderberg changed the title from "Slice op relu" to "[Hexagon] Slice op relu" on May 25, 2022

def relu_te_compute(Input, out_shape, dtype):
    x = tvm.tir.const(0, dtype)
    Output = te.compute(
        out_shape, lambda n, h, w, c: tvm.te.max(Input[n, h, w, c], x), name="reluf16"
    )
    return Output

Contributor:

Instead of using out_shape as an argument to te.compute, I'd recommend using Input.shape. That way, the out_shape parameter could be removed, and the user wouldn't need to specify it independent of the Input.
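
For reference, a revision along those lines could look like the sketch below (hypothetical, shown only to illustrate the suggestion; the code that actually landed may differ), deriving the output shape from Input.shape so out_shape is no longer needed:

import tvm
from tvm import te


def relu_te_compute(Input, dtype):
    # The output shape is taken from the input tensor rather than a separate out_shape argument.
    x = tvm.tir.const(0, dtype)
    Output = te.compute(
        Input.shape, lambda n, h, w, c: tvm.te.max(Input[n, h, w, c], x), name="reluf16"
    )
    return Output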

Contributor:

Thank you for pointing this out.

# Split loop j2 by a factor of 2 and reorder the loop nest.
j3, j4 = sch.split(j2, [None, 2])
sch.reorder(n, i1, j1, k1, i2, j3, k2, j4)
# Apply the crouton layout transform to the block's first read buffer, then place an
# axis separator after the fourth transformed axis so the buffer becomes a 2D allocation.
sch.transform_layout(block, 0, "read", transform_crouton_activation)
sch.set_axis_separator(block, 0, "read", [4])

Contributor:

FYI, after #11269 lands, the set_axis_separator call will be made automatically based on IndexMap.AXIS_SEPARATOR, similar to how it is handled in TE-based schedules.
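
As an illustration of that approach, an index map can embed the separator directly in the mapping. The following is a hypothetical sketch (the actual transform_crouton_activation used in this PR is defined elsewhere in the patch), assuming the fp16 nhwc-8h2w32c2w crouton layout:

from tvm import te


# Hypothetical crouton index map with the axis separator embedded in the mapping;
# axes after te.AXIS_SEPARATOR form the second dimension of the 2D allocation.
def transform_crouton_activation(n, h, w, c):
    return [n, h // 8, w // 4, c // 32, te.AXIS_SEPARATOR, h % 8, (w % 4) // 2, c % 32, w % 2]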

@mehrdadh (Member) commented:

@arangasa there are a couple of lint issues. Please fix those so we can see the test pipeline.

@Lunderberg (Contributor) commented:

Looks like there are still some lint errors. The CI's lint can be reproduced locally with tests/scripts/ci.py lint, which runs the linter using the same docker image and script as is used in the pre-commit CI.

@rasagna-quic (Contributor) commented:

@mehrdadh @Lunderberg: Hi, can you please let me know if any more changes are needed? If not, can you please approve the PR?

@mehrdadh (Member) commented:

LGTM! I'll wait for @Lunderberg to take another look.

@rasagna-quic (Contributor) commented:

@Lunderberg: Hi Eric, can you please review this patch again? Thanks.

@Lunderberg (Contributor) left a comment:

LGTM!

@Lunderberg merged commit a1f27e5 into apache:main on Jul 19, 2022
xinetzone pushed a commit to daobook/tvm that referenced this pull request on Nov 25, 2022:
* Add support for relu slice op.

* Format code

* removing out_shape in relu def and lint issues

* removing out_shape in relu def and lint issues

* Changes as per the new format

Co-authored-by: Venkat Rasagna Komatireddy <89959097+rasagna-quic@users.noreply.github.com>
Co-authored-by: Venkat Rasagna Reddy Komatireddy <rasagna@hu-rasagna-hyd.qualcomm.com>
mikeseven pushed a commit to mikeseven/tvm that referenced this pull request on Sep 27, 2023 (same commit message as above)