diff --git a/.github/workflows/cc_bot.yml b/.github/workflows/cc_bot.yml new file mode 100644 index 0000000000000..873fafa82a601 --- /dev/null +++ b/.github/workflows/cc_bot.yml @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# GH actions. +# We use it to cover windows and mac builds +# Jenkins is still the primary CI + +name: PR + +on: + # See https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#pull_request_target + pull_request_target: + types: [assigned, opened, synchronize, reopened, edited, ready_for_review] + +concurrency: + group: PR-${{ github.event.pull_request.number }} + cancel-in-progress: true + +jobs: + cc-reviewers: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + with: + submodules: "recursive" + - name: Add cc'ed reviewers + env: + PR: ${{ toJson(github.event.pull_request) }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + set -eux + python tests/scripts/github_cc_reviewers.py || echo step failed diff --git a/Jenkinsfile b/Jenkinsfile index 64b8465a1ea1b..1cd8575fa3fdc 100755 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -46,11 +46,11 @@ import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. --> ci_lint = "tlcpack/ci-lint:v0.67" -ci_gpu = "tlcpack/ci-gpu:v0.79" +ci_gpu = "tlcpack/ci-gpu:v0.80" ci_cpu = "tlcpack/ci-cpu:v0.80" ci_wasm = "tlcpack/ci-wasm:v0.71" ci_i386 = "tlcpack/ci-i386:v0.74" -ci_qemu = "tlcpack/ci-qemu:v0.08" +ci_qemu = "tlcpack/ci-qemu:v0.09" ci_arm = "tlcpack/ci-arm:v0.06" // <--- End of regex-scanned config. diff --git a/cmake/modules/Hexagon.cmake b/cmake/modules/Hexagon.cmake index d982601fc1bb1..ab6f1d2ea52ac 100644 --- a/cmake/modules/Hexagon.cmake +++ b/cmake/modules/Hexagon.cmake @@ -301,9 +301,6 @@ if(USE_HEXAGON_RPC) ${HEXAGON_RPC_OUTPUT} COPYONLY) set_directory_properties(PROPERTIES ADDITIONAL_MAKE_CLEAN_FILES "${HEXAGON_RPC_OUTPUT}") - - # Used in `src/target/llvm/llvm_common.h` - add_definitions(-DTVM_USE_HEXAGON_LLVM) endif() if(USE_HEXAGON_DEVICE STREQUAL "${PICK_SIM}") diff --git a/cmake/modules/HexagonSDK.cmake b/cmake/modules/HexagonSDK.cmake index e9bc2d2f873f5..42785116214e1 100644 --- a/cmake/modules/HexagonSDK.cmake +++ b/cmake/modules/HexagonSDK.cmake @@ -31,7 +31,7 @@ function(find_hexagon_sdk_root HEXAGON_SDK_PATH HEXAGON_ARCH) # Initial verification of the Hexagon SDK. message(STATUS "Checking Hexagon SDK root: ${HEXAGON_SDK_PATH}") - tvm_file_glob(GLOB_RECURSE VERSION_HEADERS "${HEXAGON_SDK_PATH}/*/version.h") + file(GLOB_RECURSE VERSION_HEADERS "${HEXAGON_SDK_PATH}/*/version.h") if(VERSION_HEADERS) foreach(HEADER IN LISTS VERSION_HEADERS) if(HEADER MATCHES "incs/version.h$") diff --git a/docker/Dockerfile.ci_arm b/docker/Dockerfile.ci_arm index 974998b9d6fe1..73ff0aee7d805 100644 --- a/docker/Dockerfile.ci_arm +++ b/docker/Dockerfile.ci_arm @@ -26,6 +26,17 @@ RUN apt-get install -y ca-certificates gnupg2 COPY install/ubuntu_install_core.sh /install/ubuntu_install_core.sh RUN bash /install/ubuntu_install_core.sh +# Rust env +COPY install/ubuntu_install_rust.sh /install/ubuntu_install_rust.sh +RUN bash /install/ubuntu_install_rust.sh +ENV RUSTUP_HOME /opt/rust +ENV CARGO_HOME /opt/rust +ENV PATH $PATH:$CARGO_HOME/bin + +# sccache +COPY install/ubuntu_install_sccache.sh /install/ubuntu_install_sccache.sh +RUN bash /install/ubuntu_install_sccache.sh + COPY install/ubuntu_install_llvm.sh /install/ubuntu_install_llvm.sh RUN bash /install/ubuntu_install_llvm.sh @@ -33,7 +44,7 @@ COPY install/ubuntu1804_install_python.sh /install/ubuntu1804_install_python.sh RUN bash /install/ubuntu1804_install_python.sh # Globally disable pip cache -RUN pip config set global.cache-dir false +RUN pip config set global.no-cache-dir false COPY install/ubuntu_install_cmake_source.sh /install/ubuntu_install_cmake_source.sh RUN bash /install/ubuntu_install_cmake_source.sh diff --git a/docker/Dockerfile.ci_cpu b/docker/Dockerfile.ci_cpu index c24df99de6ecc..962d738a9fc2e 100644 --- a/docker/Dockerfile.ci_cpu +++ b/docker/Dockerfile.ci_cpu @@ -28,7 +28,7 @@ COPY install/ubuntu1804_install_python.sh /install/ubuntu1804_install_python.sh RUN bash /install/ubuntu1804_install_python.sh # Globally disable pip cache -RUN pip config set global.cache-dir false +RUN pip config set global.no-cache-dir false COPY install/ubuntu_install_python_package.sh /install/ubuntu_install_python_package.sh RUN bash /install/ubuntu_install_python_package.sh @@ -53,6 +53,10 @@ ENV RUSTUP_HOME /opt/rust ENV CARGO_HOME /opt/rust ENV PATH $PATH:$CARGO_HOME/bin +# wasmtime +COPY install/ubuntu_install_wasmtime.sh /install/ubuntu_install_wasmtime.sh +RUN bash /install/ubuntu_install_wasmtime.sh + # AutoTVM deps COPY install/ubuntu_install_redis.sh /install/ubuntu_install_redis.sh RUN bash /install/ubuntu_install_redis.sh @@ -126,3 +130,7 @@ ENV PATH /opt/arm/gcc-arm-none-eabi/bin:/opt/arm/FVP_Corstone_SSE-300/models/Lin # PaddlePaddle deps COPY install/ubuntu_install_paddle.sh /install/ubuntu_install_paddle.sh RUN bash /install/ubuntu_install_paddle.sh + +# sccache +COPY install/ubuntu_install_sccache.sh /install/ubuntu_install_sccache.sh +RUN bash /install/ubuntu_install_sccache.sh diff --git a/docker/Dockerfile.ci_gpu b/docker/Dockerfile.ci_gpu index 3d189e3858b0a..d9ca255502fa5 100644 --- a/docker/Dockerfile.ci_gpu +++ b/docker/Dockerfile.ci_gpu @@ -29,7 +29,7 @@ COPY install/ubuntu1804_install_python.sh /install/ubuntu1804_install_python.sh RUN bash /install/ubuntu1804_install_python.sh # Globally disable pip cache -RUN pip config set global.cache-dir false +RUN pip config set global.no-cache-dir false COPY install/ubuntu1804_install_llvm.sh /install/ubuntu1804_install_llvm.sh RUN bash /install/ubuntu1804_install_llvm.sh @@ -101,6 +101,10 @@ ENV RUSTUP_HOME /opt/rust ENV CARGO_HOME /opt/rust ENV PATH $PATH:$CARGO_HOME/bin +# wasmtime +COPY install/ubuntu_install_wasmtime.sh /install/ubuntu_install_wasmtime.sh +RUN bash /install/ubuntu_install_wasmtime.sh + # AutoTVM deps COPY install/ubuntu_install_redis.sh /install/ubuntu_install_redis.sh RUN bash /install/ubuntu_install_redis.sh @@ -117,6 +121,10 @@ RUN bash /install/ubuntu_install_universal.sh COPY install/ubuntu_install_papi.sh /install/ubuntu_install_papi.sh RUN bash /install/ubuntu_install_papi.sh "cuda rocm" +# sccache +COPY install/ubuntu_install_sccache.sh /install/ubuntu_install_sccache.sh +RUN bash /install/ubuntu_install_sccache.sh + # Environment variables ENV PATH=/usr/local/nvidia/bin:${PATH} ENV PATH=/usr/local/cuda/bin:${PATH} diff --git a/docker/Dockerfile.ci_i386 b/docker/Dockerfile.ci_i386 index 564731c12d36e..d4ce370c42051 100644 --- a/docker/Dockerfile.ci_i386 +++ b/docker/Dockerfile.ci_i386 @@ -32,7 +32,7 @@ COPY install/ubuntu_install_python.sh /install/ubuntu_install_python.sh RUN bash /install/ubuntu_install_python.sh # Globally disable pip cache -RUN pip config set global.cache-dir false +RUN pip config set global.no-cache-dir false COPY install/ubuntu_install_cmake_source.sh /install/ubuntu_install_cmake_source.sh RUN bash /install/ubuntu_install_cmake_source.sh diff --git a/docker/Dockerfile.ci_lint b/docker/Dockerfile.ci_lint index 20bcfe6de903a..08d3ebf14e14d 100644 --- a/docker/Dockerfile.ci_lint +++ b/docker/Dockerfile.ci_lint @@ -28,7 +28,7 @@ COPY install/ubuntu1804_install_python.sh /install/ubuntu1804_install_python.sh RUN bash /install/ubuntu1804_install_python.sh # Globally disable pip cache -RUN pip config set global.cache-dir false +RUN pip config set global.no-cache-dir false RUN apt-get update && apt-get install -y doxygen graphviz curl @@ -41,6 +41,10 @@ ENV RUSTUP_HOME /opt/rust ENV CARGO_HOME /opt/rust ENV PATH $PATH:$CARGO_HOME/bin +# wasmtime +COPY install/ubuntu_install_wasmtime.sh /install/ubuntu_install_wasmtime.sh +RUN bash /install/ubuntu_install_wasmtime.sh + # java deps for rat COPY install/ubuntu_install_java.sh /install/ubuntu_install_java.sh RUN bash /install/ubuntu_install_java.sh diff --git a/docker/Dockerfile.ci_qemu b/docker/Dockerfile.ci_qemu index f4f774697f2fe..2cae59c35d672 100644 --- a/docker/Dockerfile.ci_qemu +++ b/docker/Dockerfile.ci_qemu @@ -29,7 +29,7 @@ RUN bash /install/ubuntu1804_install_python_venv.sh ENV PATH=/opt/tvm-venv/bin:/opt/zephyr-sdk/sysroots/x86_64-pokysdk-linux/usr/bin:$PATH # Globally disable pip cache -RUN pip config set global.cache-dir false +RUN pip config set global.no-cache-dir false COPY install/ubuntu_install_python_package.sh /install/ubuntu_install_python_package.sh RUN bash /install/ubuntu_install_python_package.sh @@ -42,6 +42,7 @@ COPY install/ubuntu_install_rust.sh /install/ubuntu_install_rust.sh RUN bash /install/ubuntu_install_rust.sh ENV RUSTUP_HOME /opt/rust ENV CARGO_HOME /opt/rust +ENV PATH $PATH:$CARGO_HOME/bin # AutoTVM deps COPY install/ubuntu_install_redis.sh /install/ubuntu_install_redis.sh @@ -59,6 +60,10 @@ RUN bash /install/ubuntu_install_tensorflow.sh COPY install/ubuntu_install_tflite.sh /install/ubuntu_install_tflite.sh RUN bash /install/ubuntu_install_tflite.sh +# sccache +COPY install/ubuntu_install_sccache.sh /install/ubuntu_install_sccache.sh +RUN bash /install/ubuntu_install_sccache.sh + # Zephyr SDK deps COPY install/ubuntu_install_zephyr.sh /install/ubuntu_install_zephyr.sh COPY install/ubuntu_init_zephyr_project.sh /install/ubuntu_init_zephyr_project.sh diff --git a/docker/Dockerfile.ci_wasm b/docker/Dockerfile.ci_wasm index 03a1ded5572fa..1c901f12a2ec6 100644 --- a/docker/Dockerfile.ci_wasm +++ b/docker/Dockerfile.ci_wasm @@ -25,7 +25,7 @@ COPY install/ubuntu1804_install_python.sh /install/ubuntu1804_install_python.sh RUN bash /install/ubuntu1804_install_python.sh # Globally disable pip cache -RUN pip config set global.cache-dir false +RUN pip config set global.no-cache-dir false COPY install/ubuntu_install_python_package.sh /install/ubuntu_install_python_package.sh RUN bash /install/ubuntu_install_python_package.sh diff --git a/docker/build.sh b/docker/build.sh index 3ac74835f0abb..39fe7a0242461 100755 --- a/docker/build.sh +++ b/docker/build.sh @@ -199,7 +199,9 @@ if [[ -n ${COMMAND} ]]; then echo ${DOCKER_BINARY} ${DOCKER_BINARY} run --rm --pid=host \ -v ${WORKSPACE}:/workspace \ + ${SSH_AUTH_SOCK:+-v $SSH_AUTH_SOCK:/ssh-agent} \ -w /workspace \ + ${SSH_AUTH_SOCK:+-e "SSH_AUTH_SOCK=/ssh-agent"} \ -e "CI_BUILD_HOME=/workspace" \ -e "CI_BUILD_USER=$(id -u -n)" \ -e "CI_BUILD_UID=$(id -u)" \ diff --git a/docker/install/ubuntu_install_rust.sh b/docker/install/ubuntu_install_rust.sh index c07c0038a36b4..58f8256b03b3b 100755 --- a/docker/install/ubuntu_install_rust.sh +++ b/docker/install/ubuntu_install_rust.sh @@ -16,10 +16,7 @@ # specific language governing permissions and limitations # under the License. -set -e -set -u -set -o pipefail - +set -euxo pipefail export RUSTUP_HOME=/opt/rust export CARGO_HOME=/opt/rust @@ -29,12 +26,5 @@ export PATH=$CARGO_HOME/bin:$PATH rustup component add rustfmt rustup component add clippy -# install wasmtime -apt-get install -y --no-install-recommends libc6-dev-i386 -export WASMTIME_HOME=/opt/wasmtime -curl https://wasmtime.dev/install.sh -sSf | bash -export PATH="${WASMTIME_HOME}/bin:${PATH}" -rustup target add wasm32-wasi - # make rust usable by all users chmod -R a+w /opt/rust diff --git a/docker/install/ubuntu_install_sccache.sh b/docker/install/ubuntu_install_sccache.sh new file mode 100644 index 0000000000000..79ce1535c71ec --- /dev/null +++ b/docker/install/ubuntu_install_sccache.sh @@ -0,0 +1,28 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -e +set -u +set -o pipefail + +cargo install sccache + +# The docs specifically recommend hard links: https://github.com/mozilla/sccache#known-caveats +mkdir /opt/sccache +ln "$(which sccache)" /opt/sccache/cc +ln "$(which sccache)" /opt/sccache/c++ diff --git a/docker/install/ubuntu_install_wasmtime.sh b/docker/install/ubuntu_install_wasmtime.sh new file mode 100644 index 0000000000000..d1285b36b4291 --- /dev/null +++ b/docker/install/ubuntu_install_wasmtime.sh @@ -0,0 +1,26 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -euxo pipefail + +# install wasmtime (note: requires ubuntu_install_rust.sh to run first) +apt-get install -y --no-install-recommends libc6-dev-i386 +export WASMTIME_HOME=/opt/wasmtime +curl https://wasmtime.dev/install.sh -sSf | bash +export PATH="${WASMTIME_HOME}/bin:${PATH}" +rustup target add wasm32-wasi diff --git a/docs/conf.py b/docs/conf.py index e74df6cf1e0ed..2f650a88c9361 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -255,6 +255,7 @@ def git_describe_version(original_version): "introduction.py", "install.py", "tvmc_command_line_driver.py", + "tvmc_python.py", "autotvm_relay_x86.py", "tensor_expr_get_started.py", "autotvm_matmul_x86.py", diff --git a/docs/contribute/ci.rst b/docs/contribute/ci.rst new file mode 100644 index 0000000000000..1e78e9139eb50 --- /dev/null +++ b/docs/contribute/ci.rst @@ -0,0 +1,176 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. _ci_guide: + +Using TVM's CI +============== + +TVM uses Jenkins for running Linux continuous integration (CI) tests on +`branches `_ and +`pull requests `_ through a +build configuration specified in a `Jenkinsfile `_. +Non-critical jobs run in GitHub Actions for Windows and MacOS jobs. + +A standard CI run looks something like this viewed in `Jenkins' BlueOcean viewer `_. +CI runs usually take several hours to complete and pull requests (PRs) cannot be merged before CI +has successfully completed. To diagnose failing steps, click through to the failing +pipeline stage then to the failing step to see the output logs. + +.. image:: https://github.com/tlc-pack/web-data/raw/main/images/contribute/ci.png + :width: 800 + :alt: The Jenkins UI for a CI run + + +Debugging Failures +****************** + +When CI fails for some reason, there are several methods to diagnose the issue. + +Jenkins Logs +------------ + +.. |pytest| replace:: ``pytest`` +.. _pytest: https://docs.pytest.org/en/6.2.x/ + +The first place to look for a failure is in the CI logs, follow the red Xs on +the failing job to view the logs. Note: + +* Jenkins does not display the full log by default, at the top of the log viewer + is a button "Show complete log" which will take you to a plaintext version of the log +* |pytest|_ failures are summarized at the bottom of the log but you will likely + need to scroll up to view the actual failure. + +Reproduce Failures +------------------ + +Most TVM Python tests run under |pytest|_ and +can be run as described in :ref:`pr-testing`. For a closer environment to the one +than runs in CI you can run the docker images directly, build TVM, and execute +tests inside the container. See :ref:`docker_images` for details. + +Keeping CI Green +**************** + +Developers rely on the TVM CI to get signal on their PRs before merging. +Occasionally breakages slip through and break ``main``, which in turn causes +the same error to show up on an PR that is based on the broken commit(s). Broken +commits can be identified `through GitHub `_ +via the commit status icon or via `Jenkins `_. +In these situations it is possible to either revert the offending commit or +submit a forward fix to address the issue. It is up to the committer and commit +author which option to choose, keeping in mind that a broken CI affects all TVM +developers and should be fixed as soon as possible. + +Skip CI for Reverts +------------------- + +For reverts and trivial forward fixes, adding ``[skip ci]`` to the revert's +commit message will cause CI to shortcut and only run lint. Committers should +take care that they only merge CI-skipped PRs to fix a failure on ``main`` and +not in cases where the submitter wants to shortcut CI to merge a change faster. + +.. code:: bash + + # Revert HEAD commit, make sure to insert '[skip ci]' at the beginning of + # the commit subject + git revert HEAD + git checkout -b my_fix + # After you have pushed your branch, create a PR as usual. + git push my_repo + # Example: Skip CI on a branch with an existing PR + # Adding this commit to an existing branch will cause a new CI run where + # Jenkins is skipped + git commit --allow-empty --message "[skip ci] Trigger skipped CI" + git push my_repo + +Handling Flaky Failures +*********************** + +.. https://stackoverflow.com/questions/4743845/format-text-in-a-link-in-restructuredtext/4836544#4836544 +.. |pytest's @xfail decorator| replace:: pytest's ``@xfail`` decorator +.. _pytest's @xfail decorator: https://docs.pytest.org/en/6.2.x/skipping.html#xfail-mark-test-functions-as-expected-to-fail +.. |strict=True| replace:: ``strict=True`` +.. _strict=True: https://docs.pytest.org/en/6.2.x/skipping.html#strict-parameter + +If you notice a failure on your PR that seems unrelated to your change, you should +search `recent GitHub issues related to flaky tests `_ and +`file a new issue `_ +if you don't see any reports of the failure. If a certain test or class of tests affects +several PRs or commits on ``main`` with flaky failures, the test should be disabled via +|pytest's @xfail decorator|_ with |strict=True|_ and the relevant issue linked in the +disabling PR. + +.. code:: python + + @pytest.mark.xfail(strict=False, reason="Flaky test: https://github.com/apache/tvm/issues/1234 + def test_something_flaky(): + pass + +``ci-docker-staging`` +********************* + +The `ci-docker-staging `_ +branch is used to test updates to Docker images and ``Jenkinsfile`` changes. When +running a build for a normal PR from a forked repository, Jenkins uses the code +from the PR except for the ``Jenkinsfile`` itself, which comes from the base branch. +When branches are built, the ``Jenkinsfile`` in the branch is used, so a committer +with write access must push PRs to a branch in apache/tvm to properly test +``Jenkinsfile`` changes. If your PR makes changes to the ``Jenkinsfile``, make sure +to @ a `committer `_ +and ask them to push your PR as a branch to test the changes. + +.. _docker_images: + +Docker Images +************* + +.. |top_of_the_Jenkinsfile| replace:: top of the ``Jenkinsfile`` +.. _top_of_the_Jenkinsfile: https://github.com/apache/tvm/blob/7481a297740f073b193a3f09b3e27f056e8c7f2e/Jenkinsfile#L48-L54 + +Each CI job runs most of its work inside a Docker container, built from files +in the `docker/ `_ folder. These +files are built nightly in Jenkins via the `docker-images-ci `_ job. +The images for these containers are hosted in the `tlcpack Docker Hub `_ +and referenced at the |top_of_the_Jenkinsfile|_. These can be inspected and run +locally via standard Docker commands. + +.. code:: bash + + # Beware: CI images can be several GB in size + # Get a bare docker shell in the ci-gpu container + docker run -it tlcpack/ci-gpu:v0.78 /bin/bash + +``docker/bash.sh`` will automatically grab the latest image from the ``Jenkinsfile`` +and help in mounting your current directory. + +.. code:: bash + + # Run the ci_cpu image specified in Jenkinsfile + cd tvm + bash docker/bash.sh ci_cpu + # the tvm directory is automatically mounted + # example: build tvm (note: this will overrwrite build/) + $ ./tests/scripts/task_config_build_cpu.sh + $ ./tests/scripts/task_build.sh build -j32 + + +Reporting Issues +**************** + +Issues with CI should be `reported on GitHub `_ +with a link to the relevant jobs, commits, or PRs. diff --git a/docs/contribute/committer_guide.rst b/docs/contribute/committer_guide.rst index 68885b6b927a3..3dc5bf07f3cdd 100644 --- a/docs/contribute/committer_guide.rst +++ b/docs/contribute/committer_guide.rst @@ -92,7 +92,7 @@ when they actively manage outstanding PRs, but watch the community less frequently in the rest of the time. Remember that your merit will never go away, so please -take your time and pace when contributing to the project:) +take your time and pace when contributing to the project :) Broad Collaboration @@ -101,37 +101,3 @@ Sometimes, we tend to only interact with people we know. However, broad collaborations are necessary to the success of the project. Try to keep that in mind, shepherd PRs for, and request code reviews from community members who you do not interact physically. - - -Keeping CI Green ----------------- -Developers rely on the TVM CI to get signal on their PRs before merging. -Occasionally breakges slip through and break ``main``, which in turn causes -the same error to show up on an PR that is based on the broken commit(s). -In these situations it is possible to either revert the offending commit or -submit a forward fix to address the issue. It is up to the committer and commit -author which option to choose, keeping in mind that a broken CI affects all TVM -developers and should be fixed as soon as possible. - -For reverts and trivial forward fixes, adding ``[skip ci]`` to the revert's -commit message will cause CI to shortcut and only run lint. Committers should -take care that they only merge CI-skipped PRs to fix a failure on ``main`` and -not in cases where the submitter wants to shortcut CI to merge a change faster. - -.. code:: bash - - # Example: Skip CI on a revert - # Revert HEAD commit, make sure to insert '[skip ci]' at the beginning of - # the commit subject - git revert HEAD - - git checkout -b my_fix - # After you have pushed your branch, create a PR as usual. - git push my_repo - - # Example: Skip CI on a branch with an existing PR - # Adding this commit to an existing branch will cause a new CI run where - # Jenkins is skipped - git commit --allow-empty --message "[skip ci] Trigger skipped CI" - git push my_repo - diff --git a/docs/contribute/community.rst b/docs/contribute/community.rst index 8867202a674c3..c41c7f394dd50 100644 --- a/docs/contribute/community.rst +++ b/docs/contribute/community.rst @@ -17,8 +17,8 @@ .. _community_guide: -TVM Community Guideline -======================= +TVM Community Guidelines +======================== TVM adopts the Apache style model and governs by merit. We believe that it is important to create an inclusive community where everyone can use, contribute to, and influence the direction of the project. See `CONTRIBUTORS.md `_ for the current list of contributors. @@ -42,7 +42,7 @@ Committers are individuals who are granted the write access to the project. A co - Quality of contributions: High-quality, readable code contributions indicated by pull requests that can be merged without a substantial code review. History of creating clean, maintainable code and including good test cases. Informative code reviews to help other contributors that adhere to a good standard. - Community involvement: active participation in the discussion forum, promote the projects via tutorials, talks and outreach. We encourage committers to collaborate broadly, e.g. do code reviews and discuss designs with community members that they do not interact physically. -The Project Management Committee(PMC) consists group of active committers that moderate the discussion, manage the project release, and proposes new committer/PMC members. Potential candidates are usually proposed via an internal discussion among PMCs, followed by a consensus approval, i.e. least 3 +1 votes, and no vetoes. Any veto must be accompanied by reasoning. PMCs should serve the community by upholding the community practices and guidelines TVM a better community for everyone. PMCs should strive to only nominate new candidates outside of their own organization. +The `Project Management Committee (PMC) `_ consists group of active committers that moderate the discussion, manage the project release, and proposes new committer/PMC members. Potential candidates are usually proposed via an internal discussion among PMCs, followed by a consensus approval, (i.e. at least 3 +1 votes, and no vetoes). Any veto must be accompanied by reasoning. PMCs should serve the community by upholding the community practices and guidelines TVM a better community for everyone. PMCs should strive to only nominate new candidates outside of their own organization. Reviewers diff --git a/docs/contribute/document.rst b/docs/contribute/document.rst index e3d12e83865d6..8658c0fea5062 100644 --- a/docs/contribute/document.rst +++ b/docs/contribute/document.rst @@ -26,7 +26,8 @@ it is a "simple, comprehensive and nearly universally-applicable scheme. It is proven in practice across a wide variety of fields and applications." This document describes the organization of TVM documentation, and how to write -new documentation. +new documentation. See `docs/README.md `_ +for instructions on building the docs. The Four Document Types *********************** @@ -40,8 +41,8 @@ without necessarily explaining why the software works the way it does. Those explanations can be saved for other document types. An introductory tutorial focuses on a successful first experience. These are the most important docs to turning newcomers into new users and developers. A fully end-to-end -tutorial— from installing TVM and supporting ML software, to creating and -training a model, to compiling to different architectures—will give a new +tutorial — from installing TVM and supporting ML software, to creating and +training a model, to compiling to different architectures — will give a new user the opportunity to use TVM in the most efficient way possible. A tutorial teaches a beginner something they need to know. This is in contrast with a how-to, which is meant to be an answer to a question that a user with some @@ -92,7 +93,7 @@ Within these documents you can explore contradictory and conflicting position, and help the reader make sense of how and why the software was built the way it is. It's not the place for how-tos and descriptions on how to accomplish tasks. They instead focus on higher level concepts that help with the understanding of -the project. Generally these are written by the architects and developers of +the project. Generally these are written by the architects and developers of the project, but can useful to help both users and developers to have a deeper understanding of why the software works the way it does, and how to contribute to it in ways that are consistent with the underlying design principles. @@ -124,18 +125,22 @@ Technical Details ***************** We use the `Sphinx `_ for the main documentation. -Sphinx support both the reStructuredText and markdown. When possible, we -encourage to use reStructuredText as it has richer features. Note that the -python doc-string and tutorials allow you to embed reStructuredText syntax. +Sphinx supports both reStructuredText and markdown. When possible, we +encourage reStructuredText as it has richer features. Note that the +Python doc-string and tutorials allow you to embed reStructuredText syntax. + +See +`docs/README.md `_ +for instructions on building the docs. Python Reference Documentation ------------------------------ -We use `numpydoc `_ format to -document the function and classes. The following snippet gives an example -docstring. We always document all the public functions, when necessary, -provide an usage example of the features we support(as shown below). +We use the `numpydoc `_ format to +document the function and classes. The following snippet gives an example +docstring. We always document all the public functions, when necessary, +provide an usage example of the features we support (as shown below). .. code:: python @@ -167,19 +172,19 @@ provide an usage example of the features we support(as shown below). """ return rv1 -Be careful to leave blank lines between sections of your documents. In the -above case, there has to be a blank line before `Parameters`, `Returns` and -`Examples` in order for the doc to be built correctly. To add a new function to -the doc, we need to add the `sphinx.autodoc -`_ rules to the -`docs/api/python `_). +Be careful to leave blank lines between sections of your documents. In the +above case, there has to be a blank line before ``Parameters``, ``Returns`` and +``Examples`` in order for the doc to be built correctly. To add a new function to +the docs, we need to add the `sphinx.autodoc +`_ rules to +`docs/reference/api/python `_). You can refer to the existing files under this folder on how to add the functions. C++ Reference Documentation --------------------------- -We use the doxgen format to document c++ functions. The following snippet +We use the doxygen format to document c++ functions. The following snippet shows an example of c++ docstring. .. code:: c++ @@ -200,15 +205,15 @@ add comments about code logics to improve readability. Sphinx Gallery How-Tos ---------------------- -We use the `sphinx-gallery `_ to build many -python how-tos. You can find the source code under `gallery -`_ quite self explanatory. +We use `sphinx-gallery `_ to build many +Python how-tos. You can find the source code under `gallery +`_. One thing that worth noting is that the comment blocks are written in reStructuredText instead of markdown so be aware of the syntax. -The how-to code will run on our build server to generate the document page. So +The how-to code will run on our build server to generate the document page. So we may have a restriction like not being able to access a remote Raspberry Pi, -in such case add a flag variable to the tutorial (e.g. `use_rasp`) and allow +in such case add a flag variable to the tutorial (e.g. ``use_rasp``) and allow users to easily switch to the real device by changing one flag. Then use the existing environment to demonstrate the usage. @@ -218,7 +223,7 @@ If you add a new categorization of how-to, you will need to add references to Refer to Another Location in the Document ----------------------------------------- -Please use sphinx's `:ref:` markup to refer to another location in the same doc. +Please use sphinx's ``:ref:`` markup to refer to another location in the same doc. .. code-block:: rst diff --git a/docs/contribute/git_howto.rst b/docs/contribute/git_howto.rst index 765153be220be..1271aad8a2684 100644 --- a/docs/contribute/git_howto.rst +++ b/docs/contribute/git_howto.rst @@ -23,39 +23,39 @@ Git Usage Tips Here are some tips for git workflow. -How to resolve a conflict with `main` -------------------------------------- +How to resolve a conflict with ``main`` +--------------------------------------- - First rebase to most recent main -.. code:: bash + .. code:: bash - # The first two steps can be skipped after you do it once. - git remote add upstream [url to tvm repo] - git fetch upstream - git rebase upstream/main + # The first two steps can be skipped after you do it once. + git remote add upstream [url to tvm repo] + git fetch upstream + git rebase upstream/main -- The git may show some conflicts it cannot merge, say `conflicted.py`. +- The git may show some conflicts it cannot merge, say ``conflicted.py``. - Manually modify the file to resolve the conflict. - After you resolved the conflict, mark it as resolved by -.. code:: bash + .. code:: bash - git add conflicted.py + git add conflicted.py - Then you can continue rebase by -.. code:: bash + .. code:: bash - git rebase --continue + git rebase --continue - Finally push to your fork, you may need to force push here. -.. code:: bash + .. code:: bash - git push --force + git push --force How to combine multiple commits into one @@ -66,35 +66,36 @@ to create a PR with set of meaningful commits. You can do it by following steps. - Before doing so, configure the default editor of git if you haven't done so before. -.. code:: bash + .. code:: bash - git config core.editor the-editor-you-like + git config core.editor the-editor-you-like - Assume we want to merge last 3 commits, type the following commands -.. code:: bash + .. code:: bash - git rebase -i HEAD~3 + git rebase -i HEAD~3 -- It will pop up an text editor. Set the first commit as `pick`, and change later ones to `squash`. +- It will pop up an text editor. Set the first commit as ``pick``, and change later ones to ``squash``. - After you saved the file, it will pop up another text editor to ask you modify the combined commit message. - Push the changes to your fork, you need to force push. -.. code:: bash + .. code:: bash - git push --force + git push --force Reset to the most recent main branch ------------------------------------ You can always use git reset to reset your version to the most recent main. -Note that all your ***local changes will get lost***. +Note that **all your local changes will get lost**. So only do it when you do not have local changes or when your pull request just get merged. .. code:: bash - git reset --hard [hash tag of main] + git fetch origin main + git reset --hard FETCH_HEAD Recover a Previous Commit after Reset diff --git a/docs/contribute/index.rst b/docs/contribute/index.rst index acacfdc8a6e26..aa893dbccb72b 100644 --- a/docs/contribute/index.rst +++ b/docs/contribute/index.rst @@ -48,4 +48,5 @@ Here are guidelines for contributing to various aspect of the project: error_handling pull_request git_howto - release_process + ci + release_process \ No newline at end of file diff --git a/docs/contribute/pull_request.rst b/docs/contribute/pull_request.rst index 23d2b1441ce8f..226e693e2c724 100644 --- a/docs/contribute/pull_request.rst +++ b/docs/contribute/pull_request.rst @@ -86,6 +86,8 @@ Here is the protocol to update CI image: - Tag the new version as the latest. - Periodically cleanup the old versions on local workers +.. _pr-testing: + Testing ------- Even though we have hooks to run unit tests automatically for each pull request, it's always recommended to run unit tests diff --git a/docs/legacy_redirect.py b/docs/legacy_redirect.py index 0f1dee6dbf240..56e8d26d0ba38 100644 --- a/docs/legacy_redirect.py +++ b/docs/legacy_redirect.py @@ -242,6 +242,10 @@ "tutorials/get_started/tvmc_command_line_driver.html", "../../tutorial/tvmc_command_line_driver.html", ], + [ + "tutorials/get_started/tvmc_python.html", + "../../tutorial/tvmc_python.html", + ], ] redirect_template = """ diff --git a/gallery/tutorial/tvmc_python.py b/gallery/tutorial/tvmc_python.py new file mode 100644 index 0000000000000..1f685589730fe --- /dev/null +++ b/gallery/tutorial/tvmc_python.py @@ -0,0 +1,292 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Getting Starting using TVMC Python: a high-level API for TVM +============================================================= +**Author**: +`Jocelyn Shiue `_ + +Hi! Here we explain the scripting tool designed for the complete TVM beginner. 🙂 + +Before we get started let's get an example model if you don't already have one. +Follow the steps to download a resnet model via the terminal: + + .. code-block:: python + + mkdir myscripts + cd myscripts + wget https://github.com/onnx/models/raw/master/vision/classification/resnet/model/resnet50-v2-7.onnx + mv resnet50-v2-7.onnx my_model.onnx + touch tvmcpythonintro.py + +Let's start editing the python file in your favorite text editor. +""" + +################################################################################ +# Step 0: Imports +# ~~~~~~~~~~~~~~~ +# +# .. code-block:: python +# +# from tvm.driver import tvmc +# +# + +################################################################################ +# Step 1: Load a model +# ~~~~~~~~~~~~~~~~~~~~ +# +# Let's import our model into tvmc. This step converts a machine learning model from +# a supported framework into TVM's high level graph representation language called Relay. +# This is to have a unified starting point for all models in tvm. The frameworks we currently +# support are: Keras, ONNX, Tensorflow, TFLite, and PyTorch. +# +# .. code-block:: python +# +# model = tvmc.load('my_model.onnx') #Step 1: Load +# +# If you'd like to see the Relay, you can run: +# ``model.summary()`` +# +# All frameworks support overwriting the input shapes with a shape_dict argument. +# For most frameworks this is optional, but for Pytorch this is necessary as +# TVM cannot automatically search for it. +# +# .. code-block:: python +# +# #model = tvmc.load(my_model, shape_dict={'input1' : [1, 2, 3, 4], 'input2' : [1, 2, 3, 4]}) #Step 1: Load + shape_dict +# +# A suggested way to see the model's input/shape_dict is via `netron `_. After opening the model, +# click the first node to see the name(s) and shape(s) in the inputs section. + + +################################################################################ +# Step 2: Compile +# ~~~~~~~~~~~~~~~ +# +# Now that our model is in Relay, our next step is to compile it to a desired +# hardware to run on. We refer to this hardware as a target. This compilation process +# translates the model from Relay into a lower-level language that the +# target machine can understand. +# +# In order to compile a model a tvm.target string is required. +# To learn more about tvm.targets and their options look at the `documentation `_. +# Some examples include: +# +# 1. cuda (Nvidia GPU) +# 2. llvm (CPU) +# 3. llvm -mcpu=cascadelake (Intel CPU) +# +# .. code-block:: python +# +# package = tvmc.compile(model, target="llvm") #Step 2: Compile +# +# +# The compilation step returns a package. +# + +################################################################################ +# Step 3: Run +# ~~~~~~~~~~~ +# +# The compiled package can now be run on the hardware target. The device +# input options are: CPU, Cuda, CL, Metal, and Vulkan. +# +# .. code-block:: python +# +# result = tvmc.run(package, device="cpu") #Step 3: Run +# +# And you can print the results: +# ``print(results)`` +# + +################################################################################ +# Step 1.5: Tune [Optional & Recommended] +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Run speed can further be improved by tuning. This optional step uses +# machine learning to look at each operation within a model (a function) and +# tries to find a faster way to run it. We do this through a cost model, and +# benchmarking possible schedules. +# +# The target is the same as compile. +# +# .. code-block:: python +# +# tvmc.tune(model, target="llvm") #Step 1.5: Optional Tune +# +# The terminal output should look like: +# +# .. code-block:: python +# +# [Task 1/13] Current/Best: 82.00/ 106.29 GFLOPS | Progress: (48/769) | 18.56 s +# [Task 1/13] Current/Best: 54.47/ 113.50 GFLOPS | Progress: (240/769) | 85.36 s +# ..... +# +# There may be UserWarnings that can be ignored. +# This should make the end result faster, but it can take hours to tune. +# +# See the section 'Saving the Tuning Results' below. Be sure to pass the tuning +# results into compile if you want the results to apply. +# +# .. code-block:: python +# +# #tvmc.compile(model, target="llvm", tuning_records = "records.log") #Step 2: Compile + +################################################################################ +# Save and then start the process in the terminal: +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# .. code-block:: python +# +# python my_tvmc_script.py +# +# Note: Your fans may become very active +# + +################################################################################ +# Example results: +# ~~~~~~~~~~~~~~~~ +# +# .. code-block:: python +# +# Time elapsed for training: 18.99 s +# Execution time summary: +# mean (ms) max (ms) min (ms) std (ms) +# 25.24 26.12 24.89 0.38 +# +# +# Output Names: +# ['output_0'] +# + + +################################################################################ +# Additional TVMC Functionalities +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +################################################################################ +# Saving the model +# ~~~~~~~~~~~~~~~~ +# +# To make things faster for later, after loading the model (Step 1) save the Relay version. +# The model will then appear where you saved it for later in the coverted syntax. +# +# .. code-block:: python +# +# model = tvmc.load('my_model.onnx') #Step 1: Load +# model.save(desired_model_path) +# +# + +################################################################################ +# Saving the package +# ~~~~~~~~~~~~~~~~~~ +# +# After the model has been compiled (Step 2) the package also is also saveable. +# +# .. code-block:: python +# +# tvmc.compile(model, target="llvm", package_path="whatever") +# +# new_package = tvmc.TVMCPackage(package_path="whatever") +# result = tvmc.run(new_package) #Step 3: Run +# +# + +################################################################################ +# Using Autoscheduler +# ~~~~~~~~~~~~~~~~~~~ +# +# Use the next generation of tvm to enable potentially faster run speed results. +# The search space of the schedules is automatically generated unlike +# previously where they needed to be hand written. (Learn more: +# `1 `_, +# `2 `_) +# +# .. code-block:: python +# +# tvmc.tune(model, target="llvm", enable_autoscheduler = True) +# +# + +################################################################################ +# Saving the tuning results +# ~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# The tuning results can be saved in a file for later reuse. +# +# Method 1: +# .. code-block:: python +# +# log_file = "hello.json" +# +# # Run tuning +# tvmc.tune(model, target="llvm",tuning_records=log_file) +# +# ... +# +# # Later run tuning and reuse tuning results +# tvmc.tune(model, target="llvm",tuning_records=log_file) +# +# Method 2: +# .. code-block:: python +# +# # Run tuning +# tuning_records = tvmc.tune(model, target="llvm") +# +# ... +# +# # Later run tuning and reuse tuning results +# tvmc.tune(model, target="llvm",tuning_records=tuning_records) +# + +################################################################################ +# Tuning a more complex model: +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# If you notice T's printing that look like ``.........T.T..T..T..T.T.T.T.T.T.`` +# increase the searching time frame: +# +# .. code-block:: python +# +# tvmc.tune(model,trials=10000,timeout=10,) +# + +################################################################################ +# Compiling a model for a remote device: +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# A remote procedural call (RPC) is useful when you would like to compile for hardware +# that is not on your local machine. The tvmc methods support this. +# To set up the RPC server take a look at the 'Set up RPC Server on Device' +# section in this `document `_. +# +# Within the TVMC Script include the following and adjust accordingly: +# +# .. code-block:: python +# +# tvmc.tune( +# model, +# target=target, # Compilation target as string // Device to compile for +# target_host=target_host, # Host processor +# hostname=host_ip_address, #The IP address of an RPC tracker, used when benchmarking remotely. +# port=port_number, # The port of the RPC tracker to connect to. Defaults to 9090. +# rpc_key=your_key, # The RPC tracker key of the target device. Required when rpc_tracker is provided +# ) +# diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index 8313da067f09f..95fce13df02f6 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -115,7 +115,6 @@ class ScheduleRule : public runtime::ObjectRef { * \brief Create an auto-inline rule that inlines spatial blocks if it satisfies some conditions * \param into_producer If allows to inline a block into its producer * \param into_consumer If allows to inline a block into its consumer - * \param into_cache_only If it only allows to inline into a block generated by cache_read/write * \param inline_const_tensor Always inline constant tensors * \param disallow_if_then_else Always disallow if-then-else-like constructs * \param require_ordered Always require the read-to-write mapping to be ordered @@ -125,7 +124,6 @@ class ScheduleRule : public runtime::ObjectRef { */ TVM_DLL static ScheduleRule AutoInline(bool into_producer, // bool into_consumer, // - bool into_cache_only, // bool inline_const_tensor, // bool disallow_if_then_else, // bool require_injective, // @@ -154,6 +152,16 @@ class ScheduleRule : public runtime::ObjectRef { Optional vector_load_max_len, // Optional> reuse_read, // Optional> reuse_write); + /*! + * \brief Create a rule: add-rfactor to some blocks if needed + * \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the + * uplimit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`. Use -1 to disable + * parallelism. + * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit + * \return The schedule rule created + */ + TVM_DLL static ScheduleRule AddRFactor(int max_jobs_per_core, // + Optional max_innermost_factor); /*! * \brief A rule that randomly select a compute-at location for a free block * \return The rule created diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index 71be8d218d2d3..7b5326a449211 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -156,6 +156,11 @@ class TVM_DLL ModuleNode : public Object { * \return Possible source code when available. */ virtual std::string GetSource(const std::string& format = ""); + /*! + * \brief Get the format of the module, when available. + * \return Possible format when available. + */ + virtual std::string GetFormat(); /*! * \brief Get packed function from current module by name. * diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 210ed53a7904e..43f2379a0b562 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -210,6 +210,14 @@ class ScheduleNode : public runtime::Object { */ virtual Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, Optional> decision = NullOpt) = 0; + /*! + * \brief Sample a compute-at location of the given block + * \param block_rv The block whose compute-at location is to be sampled + * \param decision The sampling decision + * \return The sampled loop where the input block is to be computed at + */ + virtual LoopRV SampleComputeLocation(const BlockRV& block_rv, + Optional decision = NullOpt) = 0; /******** Schedule: Get blocks & loops ********/ /*! diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index d3fbf0f91214e..7bc3b697945db 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1361,6 +1361,13 @@ constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_ */ constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint"; +/*! \brief Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling */ +constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure"; + +/*! \brief Mark the block whose producer needs to be applied by rule Random-Compute-Location */ +constexpr const char* meta_schedule_random_compute_producer = + "meta_schedule.random_compute_producer"; + /*! * \brief Check if attr_key is a pragma key extension * \param attr_key The attr key to be compared diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 7a6cfa3644478..3a964eb77d1ba 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -484,6 +484,14 @@ TVM_DLL Pass MergeDynamicSharedMemoryAllocations(); */ TVM_DLL Pass ConvertForLoopsToSerial(); +/*! + * \brief This is the unified static memory planner pass that will + * plan for memory intra- and inter- PrimFuncs together. The pass + * requires all the function to be PrimFuncs including the main. + * \return The pass. + */ +TVM_DLL Pass UnifiedStaticMemoryPlanner(); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/include/tvm/tir/usmp/algo/greedy.h b/include/tvm/tir/usmp/algo/greedy.h new file mode 100644 index 0000000000000..8f0ed873593e2 --- /dev/null +++ b/include/tvm/tir/usmp/algo/greedy.h @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file include/tvm/tir/usmp/algo/greedy.h + * \brief This header file contains helper methods used in greedy algorithms + * for planning memory for USMP + */ +#pragma once +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace tir { +namespace usmp { +namespace algo { + +/*! + * \brief This is the base class for Greedy Algorithms where the sorting + * is specialized in the extended classes based on the greedy criteria. + */ +class GreedyBase { + public: + GreedyBase() {} + /*! + * \brief This function should be implemented by the extended classes to sort the BufferInfo + * objects based on a criteria and then calling PostSortAllocation. + */ + virtual Map PlanMemory(const Array& buffer_info_arr) = 0; + + protected: + /*! + * \brief Rounds up the offset to satisfy the alignement requirement + */ + size_t round_up_to_byte_alignment(const size_t& non_aligned_byte_offset, + const int& byte_alignment); + + /*! + * \brief A helper function check whether a offset is valid given the constraints + */ + bool IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset, + const size_t& size_bytes); + + /*! + * \brief Selects a pool for placement in the given set of ordered pool candidates + */ + PoolInfo SelectPlacementPool( + const BufferInfo& buf_info, + const std::unordered_map& pool_offsets); + + /*! + * \brief This is the base allocation function that works on sorted BufferInfo objects based + * on the greedy heuristic. The sorting algorithm has to be called before calling this. + */ + Map PostSortAllocation( + const std::vector& buffer_info_vec); +}; + +} // namespace algo +} // namespace usmp +} // namespace tir +} // namespace tvm diff --git a/include/tvm/tir/usmp/algorithms.h b/include/tvm/tir/usmp/algorithms.h new file mode 100644 index 0000000000000..77276a2c931c5 --- /dev/null +++ b/include/tvm/tir/usmp/algorithms.h @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/usmp/algorithms.h + * \brief The memory planning algorithm for USMP + */ + +#ifndef TVM_TIR_USMP_ALGORITHMS_H_ +#define TVM_TIR_USMP_ALGORITHMS_H_ + +#include + +namespace tvm { +namespace tir { +namespace usmp { +namespace algo { + +/*! + * \brief The Greedy-by-Size algorithm to plan memory + * + * This will perform a greedy algorithm in deciding the offsets + * within provided Pools, using the size of the buffer. + * + * \return A Map of BufferInfo objects and their associated PoolAllocation + */ +Map GreedyBySize(const Array& buffer_info_arr, + const Integer& memory_pressure); + +/*! + * \brief The Greedy-by-Conflicts algorithm to plan memory + * + * This will perform a greedy algorithm in deciding the offsets + * within provided Pools, using the number of liveness conflicts of the buffer. + * + * \return A Map of BufferInfo objects and their associated PoolAllocation + */ +Map GreedyByConflicts(const Array& buffer_info_arr, + const Integer& memory_pressure); + +} // namespace algo +} // namespace usmp +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_USMP_ALGORITHMS_H_ diff --git a/include/tvm/tir/usmp/analysis.h b/include/tvm/tir/usmp/analysis.h new file mode 100644 index 0000000000000..a24851d33182f --- /dev/null +++ b/include/tvm/tir/usmp/analysis.h @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/usmp/analysis.h + * \brief The analysis passes for TIR-based Unified Static Memory Planner + */ + +#ifndef TVM_TIR_USMP_ANALYSIS_H_ +#define TVM_TIR_USMP_ANALYSIS_H_ + +#include +#include + +namespace tvm { +namespace tir { +namespace usmp { + +/*! + * \brief Extract BufferInfo objects from a TIR IRModule + * + * This pass would extract the buffer information of allocate nodes + * including liveness conflict with other buffer info objects. + * + * \return A Map of BufferInfo objects and their associated Stmts + */ +BufferInfoAnalysis ExtractBufferInfo(const PrimFunc& main_func, const IRModule& mod); + +} // namespace usmp +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_USMP_ANALYSIS_H_ diff --git a/include/tvm/tir/usmp/transform.h b/include/tvm/tir/usmp/transform.h new file mode 100644 index 0000000000000..6de64704bd8ba --- /dev/null +++ b/include/tvm/tir/usmp/transform.h @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/usmp/transform.h + * \brief The transform passes for TIR-based Unified Static Memory Planner + */ + +#ifndef TVM_TIR_USMP_TRANSFORM_H_ +#define TVM_TIR_USMP_TRANSFORM_H_ + +#include + +namespace tvm { +namespace tir { +namespace usmp { +namespace transform { + +using Pass = tvm::transform::Pass; + +/*! + * \brief Convert the analyzed PoolAllocation to offsets from pool variables + * + * This pass would convert the main function to accept pool variables as an input + * that get passed onto the operator PrimFuncs. Furthermore, the static allocations + * will be converted to offsets within the pool variable. + * + * \return the pass + */ +TVM_DLL Pass ConvertPoolAllocationsToOffsets(const Map& pool_allocations, + Bool emit_tvmscript_printable = Bool(false)); + +/*! + * \brief Assign PoolInfo objects to tir.allocate nodes depending on the PrimFunc's target + * + * This pass would assign default PoolInfo objects to allocate nodes that are not otherwise + * annotated, depending on pool info supplied for each target. + * + * \return the pass + */ +TVM_DLL Pass AssignPoolInfo(); + +} // namespace transform +} // namespace usmp +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_USMP_TRANSFORM_H_ diff --git a/include/tvm/tir/usmp/utils.h b/include/tvm/tir/usmp/utils.h index 30c8f2ddea492..582399865d6fd 100644 --- a/include/tvm/tir/usmp/utils.h +++ b/include/tvm/tir/usmp/utils.h @@ -26,10 +26,21 @@ #define TVM_TIR_USMP_UTILS_H_ #include +#include #include #include namespace tvm { + +/*! + * \brief PassContext option to enable the USMP + */ +constexpr const char* kUSMPEnableOption = "tir.usmp.enable"; +/*! + * \brief PassContext option to select the memory planning algorithm in USMP + */ +constexpr const char* kUSMPAlgorithmOption = "tir.usmp.algorithm"; + namespace tir { namespace usmp { @@ -59,22 +70,28 @@ struct PoolInfoNode : public Object { Integer size_hint_bytes; /*! \brief The accessibility from each Target*/ Map target_access; // 'rw' or 'ro' + /*! \brief Whether pool is internally generated. + * The internal pools will be generated as part of + * the entry point code generation of the executor*/ + bool is_internal = false; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pool_name", &pool_name); v->Visit("size_hint_bytes", &size_hint_bytes); v->Visit("target_access", &target_access); + v->Visit("is_internal", &is_internal); } bool SEqualReduce(const PoolInfoNode* other, SEqualReducer equal) const { return equal(pool_name, other->pool_name) && equal(size_hint_bytes, other->size_hint_bytes) && - equal(target_access, other->target_access); + equal(target_access, other->target_access) && equal(is_internal, other->is_internal); } void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(pool_name); hash_reduce(size_hint_bytes); hash_reduce(target_access); + hash_reduce(is_internal); } static constexpr const char* _type_key = "tir.usmp.PoolInfo"; @@ -89,7 +106,8 @@ static const int kUnrestrictedPoolSizeHint = -1; class PoolInfo : public ObjectRef { public: TVM_DLL PoolInfo(String pool_name, Map target_access, - Integer size_hint_bytes = kUnrestrictedPoolSizeHint); + Integer size_hint_bytes = kUnrestrictedPoolSizeHint, + Bool is_internal = Bool(false)); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PoolInfo, ObjectRef, PoolInfoNode); }; @@ -268,7 +286,14 @@ class AllocatedPoolInfo : public ObjectRef { * * \param buffer_info_map IR-bound BufferInfo map */ -Array CreateArrayBufferInfo(const Map& buffer_info_map); +Array CreateArrayBufferInfo(const Map& buffer_info_map); + +/*! + * \brief Calculate workspace required to execute a IRModule with main expressed in TIR + * + * \param mod the IRModule with TIR-based main function + */ +Integer CalculateModuleWorkspaceSize(const IRModule& mod); /*! * \brief The allocate node attribute to indicate candidate memory pools. @@ -284,6 +309,16 @@ static constexpr const char* kPoolCandidatesAllocateAttr = "candidate_memory_poo */ Integer CalculateExtentsSize(const AllocateNode* op); +/*! + * \brief Joins the Stmt nodes with PoolAllocation objects + * + * \param buffer_info_to_stmt the map of BufferInfo objects to Stmt nodes + * \param buffer_info_to_pool_allocation the map of BufferInfo objects to PoolAllocation objects + */ +Map AssignStmtPoolAllocations( + const Map& buffer_info_to_stmt, + const Map& buffer_info_to_pool_allocation); + } // namespace usmp } // namespace tir @@ -294,6 +329,12 @@ namespace attr { */ static constexpr const char* kPoolArgs = "pool_args"; +/*! + * \brief This is a IRModule attribute that contains all the PoolInfo objects + * as an Array. + */ +static constexpr const char* kPoolInfoIRModuleAttr = "pool_infos"; + } // namespace attr } // namespace tvm diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py index 2227d440126df..9b92c7cc2773f 100644 --- a/python/tvm/contrib/cudnn.py +++ b/python/tvm/contrib/cudnn.py @@ -285,68 +285,6 @@ def conv_output_shape( return output -def _conv_output_shape_from_cudnn( - tensor_format, pad, stride, dilation, x_shape, w_shape, data_dtype, conv_dtype, groups=1 -): - """Get output shape of 2D or 3D convolution. The output of this - function should be identical to that of conv_output_shape, but - requires a GPU with CuDNN to be present. This is maintained for - testing purposes to validate the output of conv_output_shape. - - Paramters - --------- - tensor_format: int - 0: CUDNN_TENSOR_NCHW - 1: CUDNN_TENSOR_NHWC - 2: CUDNN_TENSOR_NCHW_VECT_C - pad: int or list - padding - stride: int or list - stride - dilation: int or list - dilation - x_shape: list - input shape - w_shape: list - weight shape - data_dtype: str - data type - conv_dtype: str - convolution type - groups: int - number of groups - - Returns - ------- - oshape: list - output shape - - """ - dims = len(x_shape) - assert dims in (4, 5) - - pad, stride, dilation, xshape, wshape = _prepare_global_func_params( - dims - 2, pad, stride, dilation, x_shape, w_shape - ) - oshape = np.zeros((dims), dtype=np.int32) - - func = tvm._ffi.get_global_func("tvm.contrib.cudnn.conv.output_shape_from_cudnn") - func( - tensor_format, - dims - 2, - _get_np_int32_array_handle(pad), - _get_np_int32_array_handle(stride), - _get_np_int32_array_handle(dilation), - _get_np_int32_array_handle(xshape), - _get_np_int32_array_handle(wshape), - _get_np_int32_array_handle(oshape), - data_dtype, - conv_dtype, - groups, - ) - return list(oshape) - - def conv_find_algo( tensor_format, pad, diff --git a/python/tvm/contrib/ethosu/cascader/__init__.py b/python/tvm/contrib/ethosu/cascader/__init__.py index 72f1667c61519..03753b4049bb6 100644 --- a/python/tvm/contrib/ethosu/cascader/__init__.py +++ b/python/tvm/contrib/ethosu/cascader/__init__.py @@ -20,6 +20,7 @@ for both performance and memory usage on Arm(R) Ethos(TM)-U NPUs. """ from .stripe_config import StripeConfig +from .block_config import BlockConfig from .propagator import Propagator from .graph import ( PerformanceInfo, @@ -27,7 +28,9 @@ Part, TESubgraph, CascaderGraph, + BufferMode, register_matcher, create_cascader_graph, ) from .parts import InlinePart, EthosuPart +from .device_config import EthosuDeviceConfig diff --git a/python/tvm/contrib/ethosu/cascader/block_config.py b/python/tvm/contrib/ethosu/cascader/block_config.py new file mode 100644 index 0000000000000..3281b8a3606fa --- /dev/null +++ b/python/tvm/contrib/ethosu/cascader/block_config.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Block config to hold an output block shape and a corresponding input block shape""" +from typing import List +import tvm._ffi + +from tvm.runtime import Object + +from . import _ffi_api + + +@tvm._ffi.register_object("contrib.ethosu.cascader.BlockConfig") +class BlockConfig(Object): + """BlockConfig class""" + + def __init__(self, output_shape: List[int], compute_cycles: int, output_cycles: int): + self.__init_handle_by_constructor__( + _ffi_api.BlockConfig, output_shape, compute_cycles, output_cycles + ) + + @property + def output_shape(self) -> List[int]: + return list(self._output_shape) + + @property + def compute_cycles(self) -> int: + return int(self._compute_cycles) + + @property + def output_cycles(self) -> int: + return int(self._output_cycles) + + def __repr__(self) -> str: + return f"BlockConfig(output_shape={self.output_shape})" diff --git a/python/tvm/contrib/ethosu/cascader/device_config.py b/python/tvm/contrib/ethosu/cascader/device_config.py new file mode 100644 index 0000000000000..68a218da26163 --- /dev/null +++ b/python/tvm/contrib/ethosu/cascader/device_config.py @@ -0,0 +1,793 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Device config class to hold information about the target hardware""" +from typing import Tuple, List, Dict, Optional +from functools import reduce + +import math + +from . import BlockConfig +from . import StripeConfig +from . import Propagator + + +def _round_up(a: int, b: int) -> int: + """Round up to a multiple of b""" + return ((a + b - 1) // b) * b + + +def _round_up_div(a: int, b: int) -> int: + """Divide by b and round up to a multiple of b""" + return (a + b - 1) // b + + +class _Shape: + """Helper class for dealing with Tensor shapes of different layouts""" + + def __init__(self, shape: List[int], layout="NHWC"): + if layout == "NHCWB16": + self.height = int(shape[1]) + self.width = int(shape[3]) + self.depth = int(shape[2]) * int(shape[4]) + else: + self.height = int(shape[1]) + self.width = int(shape[2]) + self.depth = int(shape[3]) + + def round_up(self, other: "_Shape"): + self.height = _round_up(self.height, other.height) + self.width = _round_up(self.width, other.width) + self.depth = _round_up(self.depth, other.depth) + + def area(self) -> int: + return self.height * self.width + + def as_list(self): + return [1, self.height, self.width, self.depth] + + +class EthosuDeviceConfig: + """Arm(R) Ethos(TM)-U NPU config class""" + + def __init__(self, device: str): + self._device = device + self._subkernel_limits = (8, 8) + self._output_cycles = (1, 2, 3, 4, 6) + self._split_depth = 16 + self._max_block_shape = _Shape([1, 32, 64, 128]) + self._bank_size_bytes = 1024 + if self._device == "ethos-u55-256": + self._micro_block = _Shape([1, 2, 2, 8]) + self._input_micro_block = _Shape([1, 2, 2, 8]) + self._delay_cycles = (2, 2) + self._activation_cycles = (0.25, 1) + self._output_units = 8 + + self._total_banks = 48 + self._reserved_banks = 4 + self._input_granularity = 8 + self._accumulator_granularity = {4: 16, 5: 20} + self._lut_reserved = True + elif self._device == "ethos-u55-128": + self._micro_block = _Shape([1, 1, 2, 8]) + self._input_micro_block = _Shape([1, 1, 2, 8]) + self._delay_cycles = (2, 3) + self._activation_cycles = (0.5, 1) + self._output_units = 4 + + self._total_banks = 24 + self._reserved_banks = 4 + self._input_granularity = 4 + self._accumulator_granularity = {4: 8, 5: 12} + self._lut_reserved = True + elif self._device == "ethos-u55-64": + self._micro_block = _Shape([1, 1, 1, 8]) + self._input_micro_block = _Shape([1, 1, 1, 8]) + self._delay_cycles = (2, 3) + self._activation_cycles = (1, 1) + self._output_units = 2 + + self._total_banks = 16 + self._reserved_banks = 2 + self._input_granularity = 2 + self._accumulator_granularity = {4: 4, 5: 8} + self._lut_reserved = False + elif self._device == "ethos-u55-32": + self._micro_block = _Shape([1, 1, 1, 4]) + self._input_micro_block = _Shape([1, 1, 1, 8]) + self._delay_cycles = (3, 7) + self._activation_cycles = (1, 2) + self._output_units = 1 + + self._total_banks = 16 + self._reserved_banks = 2 + self._input_granularity = 2 + self._accumulator_granularity = {4: 4, 5: 8} + self._lut_reserved = False + + def _get_output_cycles( + self, op_type: str, op_str: str, ifm_dtype: str, ofm_dtype: str, activation: str + ) -> float: + """Estimate cycles per output element for an NPU operator + + Parameters + ---------- + op_type : str + The NPU primitive operator + "ethosu_pooling" + op_str : str + The type of NPU operator. + "MAX" + ifm_dtype: str + Datatype of the Input Feature Map tensor (IFM) + ofm_dtype: str + Datatype of the Ouput Feature Map tensor (OFM) + activation : str + The activation function to use. + "NONE" - no activation function. + "CLIP" - clip the output between clip_min and clip_max. + "TANH" - tanh activation function. + "SIGMOID" - sigmoid activation function. + "LUT" - use a look-up table to perform the activation function. + + Returns + ------- + float + The cycles per output element + """ + cycles = 0 + bw_limit = 0 + if op_type == "ethosu_pooling" and op_str == "MAX": + cycles = self._output_cycles[0] + elif op_type in ("ethosu_pooling", "ethosu_conv2d", "ethosu_depthwise_conv2d"): + cycles = self._output_cycles[1] if ifm_dtype == "int8" else self._output_cycles[2] + elif op_type == "ethosu_binary_elementwise": + # Binary Bandwidth Limitations + if ifm_dtype == "int8": + bw_limit = 0.125 if ofm_dtype == "int8" else 0.75 + elif ifm_dtype == "int16": + bw_limit = 0.75 if ofm_dtype == "int16" else 1 + else: + bw_limit = 1.5 + + if op_str in ("MIN", "MAX"): + cycles = self._output_cycles[1] + elif op_str == "MUL": + cycles = self._output_cycles[2] + if op_str in ("ADD", "SUB"): + if ofm_dtype == "int32": + cycles = ( + self._output_cycles[2] if ifm_dtype == "int32" else self._output_cycles[3] + ) + else: + cycles = self._output_cycles[4] + + elif op_type == "ethosu_unary_elementwise": + # Unary Bandwidth Limitations + if ifm_dtype == "int16": + bw_limit = 0.25 + elif ifm_dtype == "int32": + bw_limit = 1 + + if op_str == "CLZ": + cycles = self._output_cycles[1] + elif op_str in ("SHL", "SHR"): + cycles = self._output_cycles[2] + elif op_str in ("LRELU", "ABS"): + cycles = self._output_cycles[1] + if ifm_dtype == "int16": + bw_limit = 0.5 + + act_cycles = 0 + if activation == "CLIP": + act_cycles = self._activation_cycles[0] + elif activation in ("LUT", "TANH", "SIGMOID"): + act_cycles = self._activation_cycles[1] + + return max((cycles / self._output_units), act_cycles, bw_limit) + + def _get_delay_cycles(self, op_type: str, ifm_dtype: str) -> int: + """Get the number of delay cycles during a bubble + + Parameters + ---------- + op_type : str + The NPU primitive operator + "ethosu_pooling" + op_str : str + The type of NPU operator. + "MAX" + ifm_dtype: str + Datatype of the Input Feature Map tensor (IFM) + + Returns + ---------- + int + The amount of delay cycles + """ + if op_type in ("ethosu_conv2d", "ethosu_depthwise2d", "ethosu_pooling"): + if ifm_dtype == "int16": + return self._delay_cycles[1] + + return self._delay_cycles[0] + + return 0 + + def _get_weight_decoder_cycles(self, op_type: str) -> int: + """Get cycle estimate for weight decoding + + Parameters + ---------- + op_type: str + The NPU primitive operator + "ethosu_pooling" + + Returns + ---------- + int + Estimated cycles for weight decoding + """ + if op_type in ("ethosu_conv2d", "ethosu_depthwise2d"): + return 32 * self._micro_block.depth // 8 + + return 0 + + def get_output_quantum(self, ofm_layout: str) -> Tuple[int]: + """Get the atomic output volume + + Parameters + ---------- + ofm_layout : str + The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16". + + Returns + ---------- + Tuple[int] + The atomic output volume formatted to the ofm_layout parameter + """ + if ofm_layout == "NHCWB16": + return [ + 1, + self._micro_block.height, + 1, + self._micro_block.width, + self._micro_block.depth, + ] + + return self._micro_block.as_list() + + def _align(self, x: int, n: int) -> int: + return int(math.ceil(x / n) * n) + + def _get_input_size( + self, output_size: int, kernel_stride: int, border: int, upscaling_factor: int + ) -> int: + return int(math.ceil(((output_size - 1) * kernel_stride + border)) / upscaling_factor) + + def _get_dilated_kernel_size(self, kernel_size: int, dilation: int) -> int: + return (kernel_size - 1) * dilation + 1 + + def _get_input_block( + self, + output_block: _Shape, + input_shape: _Shape, + dtype: str, + op_type: str, + is_partkernel: bool, + stride_h: int, + stride_w: int, + dilated_kernel_h: int, + dilated_kernel_w: int, + upscaling_factor: int, + ) -> _Shape: + height = self._get_input_size( + output_block.height, + stride_h, + min(dilated_kernel_h, self._subkernel_limits[0]), + upscaling_factor, + ) + width = self._get_input_size( + output_block.width, + stride_w, + min(dilated_kernel_w, self._subkernel_limits[1]), + upscaling_factor, + ) + + if op_type == "ethosu_conv2d": + if dtype == "int8": + if is_partkernel: + depth = self._align(min(32, input_shape.depth), 8) + else: + depth = self._align(min(16, input_shape.depth), 8) + elif dtype == "int16": + depth = self._align(min(16, input_shape.depth), 4) + else: + depth = self._align(min(8, input_shape.depth), 2) + else: + depth = output_block.depth + + return _Shape( + [ + 1, + self._align(height, self._micro_block.height), + self._align(width, self._micro_block.width), + depth, + ] + ) + + def get_kernel_steps( + self, + op_type: str, + dilated_kernel_h: int, + dilated_kernel_w: int, + ifm_dtype: str, + is_partkernel: bool = False, + ) -> List[int]: + """Calculate the total number of subkernels and their sizes + + Parameters + ---------- + op_type : str + The NPU primitive operator + "ethosu_pooling" + dilated_kernel_h: int + Height of dilated kernel + dilated_kernel_w: int + Width of dilated kernel + ifm_dtype: str + Datatype of the Input Feature Map tensor (IFM) + is_partkernel: bool + Flag showing whether part-kernel first traversal is used + + Returns + ---------- + List[int] + List where each entry contains the amount of elements in one of the subkernels + """ + if op_type == "ethosu_binary_elementwise": + return [1] + + subkernels = self._get_subkernels(dilated_kernel_h, dilated_kernel_w) + + # Determine the number of kernel steps per subkernel + kernel_steps = [] + for y, x in subkernels: + subkernel_elements = x * y + if op_type == "ethosu_conv2d" and is_partkernel: + # Part-kernel-first traversal conv2d + divisor = 4 if ifm_dtype == "int8" else 2 + kernel_steps.append(int(_round_up_div(subkernel_elements, divisor))) + elif op_type == "ethosu_depthwise_conv2d": + kernel_steps.append(int(_round_up_div(subkernel_elements, 4))) + else: + # Depth-first traversal conv2d or pooling + kernel_steps.append(int(subkernel_elements)) + + return kernel_steps + + def _get_subkernels(self, dilated_kernel_h: int, dilated_kernel_w: int): + num_subkernels_y = _round_up_div(dilated_kernel_h, self._subkernel_limits[0]) + num_subkernels_x = _round_up_div(dilated_kernel_w, self._subkernel_limits[1]) + subkernels_y = [ + min((dilated_kernel_h - i * self._subkernel_limits[0]), self._subkernel_limits[0]) + for i in range(num_subkernels_y) + ] + subkernels_x = [ + min((dilated_kernel_w - i * self._subkernel_limits[1]), self._subkernel_limits[1]) + for i in range(num_subkernels_x) + ] + + subkernels = [] + for y in subkernels_y: + for x in subkernels_x: + subkernels.append((y, x)) + + return subkernels + + def _get_accumulator_width(self, op_type: str, ifm_dtype: str): + if ifm_dtype == "int16" and op_type != "ethosu_pooling": + return 5 + + return 4 + + def is_partkernel( + self, op_type: str, ifm_channels: int, ifm_dtype: str, kernel_elements: int + ) -> bool: + """Determine which block traversal strategy has better DPU utilization + + Parameters + ---------- + op_type: str + The NPU primitive operator + "ethosu_pooling" + ifm_channels: int + Number of input channels + ifm_dtype: str + Datatype of the Input Feature Map tensor (IFM) + kernel_elements: int + Total number of elements in the kernel + + Returns + ---------- + bool + True if partkernel first has best DPU utilization + """ + if op_type != "ethosu_conv2d": + return False + + depth_first_utilization = ifm_channels / _round_up( + ifm_channels, 32 if ifm_dtype == "int8" else 16 + ) + part_kernel_first_utilization = (ifm_channels / _round_up(ifm_channels, 8)) * ( + kernel_elements / _round_up(kernel_elements, 4 if ifm_dtype == "int8" else 2) + ) + + return part_kernel_first_utilization > depth_first_utilization or ifm_channels <= 8 + + def get_elementwise_block_config( + self, + ifm_propagator: Propagator, + ifm2_propagator: Optional[Propagator], + op_attrs: Dict, + ofm_shape: List[int], + output_layout: str, + input_layout: str, + input2_layout: Optional[str], + ifm_dtype: str, + ofm_dtype: str, + ) -> List[BlockConfig]: + """Get a suitable block config for an elementwise operator + + Parameters + ---------- + ifm_propagator: Propagator, + The propagator containing the data dependencies between input and output + ifm2_propagator: Propagator, + The propagator containing the data dependencies between input2 and output + op_attrs: Dict, + Dictionary containing operator attributes + ofm_shape: List[int], + Shape of the output tensor + output_layout: str, + The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16". + input_layout: str, + The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16". + input2_layout: str, + The layout of the Input2 Feature Map tensor. Can be "NHWC" or "NHCWB16". + ifm_dtype: str, + Datatype of the Input Feature Map tensor (IFM) + ofm_dtype: str, + Datatype of the Output Feature Map tensor (OFM) + + Returns + ---------- + List[BlockConfig] + List containing a single suitable block config + """ + block_config = [] + output_shape = [int(a) for a in ofm_shape] + + op_type = op_attrs.get("op") + op_str = op_attrs.get("op_str") + activation = op_attrs.get("activation", "NONE") + + input_bytewidth = 1 if ifm_dtype == "int8" else 2 if ifm_dtype == "int16" else 4 + banks_available = self._total_banks - self._reserved_banks + if activation == "LUT" and not self._lut_reserved: + banks_available -= 2 + + # Split the block in half until it fits into SHRAM + if output_layout == "NHCWB16": + split_order = (a for a in [1, 3, 2]) + output_block = [ + output_shape[0], + min(output_shape[1], self._max_block_shape.height), + min(output_shape[2] * output_shape[4], self._max_block_shape.depth), + min(output_shape[3], self._max_block_shape.width), + 16, + ] + else: + split_order = (a for a in [1, 2, 3]) + output_block = [ + output_shape[0], + min(output_shape[1], self._max_block_shape.height), + min(output_shape[2], self._max_block_shape.width), + min(output_shape[3], self._max_block_shape.depth), + ] + split_axis = next(split_order) + while True: + # Create stripe config for output block + offset = [0] * len(output_block) + stripes = [1] * len(output_block) + order = [1, 2, 4, 3, 0] if output_layout == "NHCWB16" else [1, 2, 3, 4] + output_stripe_config = StripeConfig( + output_block, output_block, output_block, order, stripes, offset + ) + + # Propagate the output to obtain the two input blocks + input_block = _Shape(ifm_propagator.propagate(output_stripe_config).shape, input_layout) + if ifm2_propagator: + input2_block = _Shape( + ifm2_propagator.propagate(output_stripe_config).shape, input2_layout + ) + else: + # Unary elementwise + input2_block = _Shape([0, 0, 0, 0]) + + input_block.round_up(self._input_micro_block) + input2_block.round_up(self._input_micro_block) + + # Banks required for input block + input_bytes = input_block.area() * self._align(input_block.depth * input_bytewidth, 8) + input_banks = _round_up_div(input_bytes, self._bank_size_bytes) * 2 + input_banks = _round_up(input_banks, self._input_granularity) + + # Banks required for input2 block + input2_bytes = input2_block.area() * self._align( + input2_block.depth * input_bytewidth, 8 + ) + input2_banks = _round_up_div(input2_bytes, self._bank_size_bytes) * 2 + input2_banks = _round_up(input2_banks, self._input_granularity) + + # Check whether or not both IFMs fit into SHRAM + if (input_banks + input2_banks) <= banks_available: + output_cycles = self._get_output_cycles( + op_type, op_str, ifm_dtype, ofm_dtype, activation + ) + output_cycles *= reduce(lambda a, b: a * b, output_block, 1) + output_cycles = int(math.ceil(output_cycles)) + block_config.append(BlockConfig(output_block, 0, output_cycles)) + break + + if output_block[split_axis] == 1: + split_axis = next(split_order) + + output_block[split_axis] = _round_up_div(output_block[split_axis], 2) + + return block_config + + def get_valid_block_configs( + self, + ifm_propagator: Propagator, + op_attrs: Dict, + ofm_shape: List[int], + ofm_channels: int, + ifm_channels: int, + output_layout: str, + input_layout: str, + ifm_dtype: str, + ofm_dtype: str, + kernel_h: int = 1, + kernel_w: int = 1, + ) -> List[BlockConfig]: + """Get all of the valid block configs + + Parameters + ---------- + ifm_propagator: Propagator, + The propagator containing the data dependencies between input and output + op_attrs: Dict, + Dictionary containing operator attributes + ofm_shape: List[int], + Shape of the output tensor + ofm_channels: int, + Number of output channels + ifm_channels: int, + Number of input channels + output_layout: str, + The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16". + input_layout: str, + The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16". + ifm_dtype: str, + Datatype of the Input Feature Map tensor (IFM) + ofm_dtype: str, + Datatype of the Output Feature Map tensor (OFM) + kernel_h: int, + Height of kernel + kernel_h: int + Width of kernel + + Returns + ---------- + List[BlockConfig] + List containing all of the valid block configs + """ + valid_block_configs = [] + + op_type = op_attrs.get("op") + op_str = op_attrs.get("op_str") + activation = op_attrs.get("activation", "NONE") + stride_h = int(op_attrs.get("stride_h", 1)) + stride_w = int(op_attrs.get("stride_w", 1)) + upscaling_factor = 1 if op_attrs.get("upscale", "NONE") == "NONE" else 2 + + subkernel_transform = ifm_propagator.transform + if output_layout == "NHCWB16": + output_shape = _Shape([1, ofm_shape[1], ofm_shape[3], ofm_channels]) + else: + output_shape = _Shape(ofm_shape) + + if input_layout == "NHCWB16": + subkernel_transform[1][-1] = min( + subkernel_transform[1][-1], self._subkernel_limits[0] - stride_h + ) + subkernel_transform[3][-1] = min( + subkernel_transform[3][-1], self._subkernel_limits[1] - stride_w + ) + else: + subkernel_transform[1][-1] = min( + subkernel_transform[1][-1], self._subkernel_limits[0] - stride_h + ) + subkernel_transform[2][-1] = min( + subkernel_transform[2][-1], self._subkernel_limits[1] - stride_w + ) + + subkernel_propagator = Propagator(subkernel_transform, ifm_propagator.offset) + + # Define search space + max_height = min(output_shape.height, self._max_block_shape.height) + min_height = max(self._micro_block.height, upscaling_factor) + + max_width = min(output_shape.width, self._max_block_shape.width) + min_width = max(self._micro_block.width, upscaling_factor) + + max_depth = min(ofm_channels, self._max_block_shape.depth) + min_depth = max(self._micro_block.depth, upscaling_factor) + + input_bytewidth = 1 if ifm_dtype == "int8" else 2 + acc_bytewidth = self._get_accumulator_width(op_type, ifm_dtype) + banks_available = self._total_banks - self._reserved_banks + if activation == "LUT" and not self._lut_reserved: + banks_available -= 2 + + # Input block depth has additional limitations for Operators that require full input depth + input_block_depth = 0 + is_partkernel = self.is_partkernel(op_type, ifm_channels, ifm_dtype, kernel_h * kernel_w) + if op_type == "ethosu_conv2d": + if is_partkernel: + input_block_depth = min(ifm_channels, 16) + else: + input_block_depth = min(ifm_channels, 32) + + for depth in range(min_depth, max_depth + min_depth, min_depth): + if (depth < output_shape.depth) and (depth % self._split_depth != 0): + # Block depth has to be less than full depth or a multiple of the split depth + continue + + for width in range(min_width, max_width + min_width, min_width): + for height in range(min_height, max_height + min_height, min_height): + if output_layout == "NHCWB16": + output_block = ( + 1, + height, + 1 + ((depth - 1) // 16), + width, + _round_up( + min(16, max(ofm_channels, min_depth)), self._micro_block.depth + ), + ) + order = [1, 2, 4, 3, 0] + else: + output_block = (1, height, width, depth) + order = [1, 2, 3, 4] + + offset = [0] * len(output_block) + stripes = [1] * len(output_block) + block_stripe_config = StripeConfig( + output_block, + output_block, + output_block, + order, + stripes, + offset, + ) + + # Propagate output block + input_block = subkernel_propagator.propagate(block_stripe_config) + + input_block_shape = _Shape(input_block.shape, input_layout) + input_block_shape.round_up(self._input_micro_block) + + output_block_shape = _Shape(output_block, output_layout) + + if op_type == "ethosu_conv2d": + input_block_shape.depth = input_block_depth + + # Banks required for input block + input_bytes = input_block_shape.area() * self._align( + input_block_shape.depth * input_bytewidth, 8 + ) + input_banks = _round_up_div(input_bytes, self._bank_size_bytes) * 2 + input_banks = _round_up(input_banks, self._input_granularity) + + # Banks required for accumulation + acc_depth = _round_up(min(output_block_shape.depth, ofm_channels), 8) + acc_bytes = ( + output_block_shape.area() * self._align(acc_depth, 8) * acc_bytewidth + ) + acc_banks = _round_up_div(acc_bytes, self._bank_size_bytes) * 2 + acc_banks = _round_up(acc_banks, self._accumulator_granularity[acc_bytewidth]) + + if (input_banks + acc_banks) <= banks_available: + output_cycles = self._get_output_cycles( + op_type, op_str, ifm_dtype, ofm_dtype, activation + ) + output_cycles *= reduce(lambda a, b: a * b, output_block, 1) + output_cycles = int(math.ceil(output_cycles)) + compute_cycles = self._estimate_compute_cycles_per_block( + op_type, + output_block_shape, + input_block_shape, + kernel_h, + kernel_w, + ifm_channels, + is_partkernel, + ) + valid_block_configs.append( + BlockConfig(output_block, compute_cycles, output_cycles) + ) + else: + # Block config does not fit into SHRAM + # Any Block config that is strictly larger than this one will also fail + break + + return valid_block_configs + + def _estimate_compute_cycles_per_block( + self, + op_type: str, + block_shape: _Shape, + input_block_shape: _Shape, + kernel_h: int, + kernel_w: int, + input_channels: int, + ifm_dtype: str, + is_partkernel: bool = False, + ) -> Tuple[int, int]: + # Calculate the amount of micro blocks per block, per axis + num_quantum_x = _round_up_div(block_shape.width, self._micro_block.width) + num_quantum_y = _round_up_div(block_shape.height, self._micro_block.height) + num_quantum_z = _round_up_div(block_shape.depth, self._micro_block.depth) + num_quantum_xy = num_quantum_x * num_quantum_y + + kernel_steps = self.get_kernel_steps(op_type, kernel_h, kernel_w, ifm_dtype, is_partkernel) + + wd_cycles = self._get_weight_decoder_cycles(op_type) + delay_cycles = self._get_delay_cycles(op_type, ifm_dtype) + cycle_quantum = 4 + + compute_cycles = 0 + for subkernel_steps in kernel_steps: + subkernel_cycles = 1 if op_type == "ethosu_pooling" else subkernel_steps + compute_cycles += ( + max(wd_cycles, cycle_quantum * num_quantum_xy) * subkernel_cycles * num_quantum_z + ) + + if num_quantum_xy == 1: + if num_quantum_z == 1: + compute_cycles += delay_cycles * subkernel_steps + elif subkernel_steps > 1: + compute_cycles += delay_cycles * (subkernel_steps - 1) * num_quantum_z + + if is_partkernel: + compute_cycles *= _round_up_div(input_block_shape.depth, 8) + + if op_type == "ethosu_conv2d": + compute_cycles *= _round_up_div(input_channels, input_block_shape.depth) + + return compute_cycles diff --git a/python/tvm/contrib/ethosu/cascader/graph.py b/python/tvm/contrib/ethosu/cascader/graph.py index 9b22e632ff892..7aa4a26513cd2 100644 --- a/python/tvm/contrib/ethosu/cascader/graph.py +++ b/python/tvm/contrib/ethosu/cascader/graph.py @@ -16,6 +16,7 @@ # under the License. """Graph objects to define compute graphs for the NPU cascader.""" from typing import List, Dict +from enum import IntEnum from collections import namedtuple import numpy as np @@ -24,6 +25,7 @@ from tvm.runtime import Object from .stripe_config import StripeConfig +from .device_config import EthosuDeviceConfig from . import _ffi_api @@ -34,6 +36,11 @@ TESubgraph = namedtuple("TESubgraph", ["input_tensors", "output_tensor"]) +class BufferMode(IntEnum): + RECOMPUTE = 0 + ROLLING = 1 + + @tvm._ffi.register_object("contrib.ethosu.cascader.PerformanceInfo") class PerformanceInfo(Object): """PerformanceInfo class""" @@ -113,9 +120,9 @@ def get_stripe_align_hint(self) -> List[int]: return list(_ffi_api.PartGetStripeAlignHint(self)) def get_performance_info( - self, stripe_config: StripeConfig, is_rolling: bool + self, stripe_config: StripeConfig, buffer_mode: BufferMode ) -> PerformanceInfo: - return _ffi_api.PartGetPerformanceInfo(self, stripe_config, is_rolling) + return _ffi_api.PartGetPerformanceInfo(self, stripe_config, buffer_mode) @property def input_tensors(self): @@ -188,7 +195,9 @@ def register_matcher(matcher): return matcher -def create_cascader_graph(te_graph: TESubgraph, const_dict: Dict[int, np.ndarray]) -> CascaderGraph: +def create_cascader_graph( + te_graph: TESubgraph, const_dict: Dict[int, np.ndarray], device_config: EthosuDeviceConfig +) -> CascaderGraph: """Create a CascaderGraph from a Tensor Expression graph and constant dictionary. Parameters @@ -197,6 +206,8 @@ def create_cascader_graph(te_graph: TESubgraph, const_dict: Dict[int, np.ndarray The Tensor Expression graph. const_dict : Dict[int, np.ndarray] The constant dictionary. + device_config : EthosuDeviceConfig + Target device configuration. Returns ------- @@ -227,7 +238,7 @@ def _visit_tensor(tensor): input_tensors = [] # Check whether any of the registered matchers match the current tensor for matcher in REGISTERED_MATCHERS: - part = matcher(tensor) + part = matcher(tensor, device_config) if part: input_tensors = part.subgraph.input_tensors break diff --git a/python/tvm/contrib/ethosu/cascader/parts.py b/python/tvm/contrib/ethosu/cascader/parts.py index 9cc67d5760dd5..12588799a66a6 100644 --- a/python/tvm/contrib/ethosu/cascader/parts.py +++ b/python/tvm/contrib/ethosu/cascader/parts.py @@ -20,6 +20,8 @@ from .propagator import Propagator from .graph import Part, TESubgraph +from .block_config import BlockConfig +from .stripe_config import StripeConfig from . import _ffi_api @@ -52,7 +54,9 @@ def __init__( te_subgraph: TESubgraph, propagators: List[Propagator], output_quantum: List[int], - quantum_cycles: int, + subkernels: int, + valid_block_configs: List[BlockConfig], + weight_tensor_idx: int = -1, ): self.__init_handle_by_constructor__( _ffi_api.EthosuPart, @@ -60,5 +64,10 @@ def __init__( te_subgraph.output_tensor, propagators, output_quantum, - quantum_cycles, + subkernels, + valid_block_configs, + weight_tensor_idx, ) + + def get_block_config(self, stripe_config: StripeConfig) -> BlockConfig: + return _ffi_api.EthosuPartGetBlockConfig(self, stripe_config) diff --git a/python/tvm/contrib/hexagon/build.py b/python/tvm/contrib/hexagon/build.py index ef081f2d79e3e..e640aad89231c 100644 --- a/python/tvm/contrib/hexagon/build.py +++ b/python/tvm/contrib/hexagon/build.py @@ -51,11 +51,11 @@ def get_hexagon_rpc_dir() -> pathlib.Path: for path in libinfo.find_lib_path(): rpc_dir = os.path.join(os.path.dirname(path), "hexagon_rpc") if os.path.isdir(rpc_dir): - HEXAGON_RPC_DIR = pathlib.Path(rpc_dir) + HEXAGON_RPC_DIR = rpc_dir break else: raise "hexagon_rpc was not found." - return HEXAGON_RPC_DIR + return pathlib.Path(HEXAGON_RPC_DIR) class HexagonLauncher: diff --git a/python/tvm/contrib/pipeline_executor.py b/python/tvm/contrib/pipeline_executor.py index 7cabb8b3d2edc..c75aa3dad43b9 100644 --- a/python/tvm/contrib/pipeline_executor.py +++ b/python/tvm/contrib/pipeline_executor.py @@ -54,6 +54,8 @@ def build(pipe_configs): raise RuntimeError('"module_connection" is missing') if "input_connection" not in config: raise RuntimeError('"input_connection" is missing') + if "param_connection" not in config: + raise RuntimeError('"param_connection" is missing') mod_n_configs = config["module_connection"] config_len = len(mod_n_configs) @@ -91,6 +93,7 @@ def build(pipe_configs): # map of global input and subgraph input, and the "module_connection" is used to # record module dependency. string_config = {} + string_config["param_connection"] = config["param_connection"] string_config["input_connection"] = config["input_connection"] string_config["module_connection"] = module_string_config @@ -114,6 +117,8 @@ def __init__(self, module): # Get the packed functions from the pipeline executor. self._get_num_outputs = self.module["get_num_outputs"] self._get_input_pipeline_map = self.module["get_input_pipeline_map"] + self._get_params_group_pipeline_map = self.module["get_params_group_pipeline_map"] + self._set_param = self.module["set_param"] def get_input_pipeline_map(self, name): """Using the "name" to get the corresponding subgraph index and also get the "input name" @@ -125,6 +130,39 @@ def get_input_pipeline_map(self, name): """ return self._get_input_pipeline_map(name) + def get_params_group_pipeline_map(self, name): + """Use the name of the parameters group to get the corresponding runtime module index. + + Parameters + ---------- + name: str + The parameter group name. + + Returns + ------- + module_index: int + The index of the runtime module. + """ + return self._get_params_group_pipeline_map(name) + + def set_params(self, params_group_name, params_data): + """Set the parameter group value given the parameter group name. Note that the parameter + group name is declared in the pipeline executor config. + + Parameters + ---------- + params_group_name : str + The parameters group name. + + params_data : Dict[str, NDArray] + A map from parameter name to data. + """ + if not params_data: + raise RuntimeError('"params_data is empty!"') + + for key, val in params_data.items(): + self._set_param(params_group_name, key, val) + @property def num_outputs(self): """Get the number of outputs. @@ -311,9 +349,19 @@ def connect(self, binding): if self.io_owner == binding.io_owner: raise RuntimeError("Can not bind itself.") + if self.io_type == "param" and not self.is_pipeline_executor_interface(): + raise RuntimeError( + 'The "param" binding can only be used by a pipeline executor interface!' + ) + if not self.is_pipeline_executor_interface() and self.io_type == "input": raise RuntimeError("Module can only bind from output interface!") + if self.io_type == "param" and binding.io_type != "param": + raise RuntimeError( + 'A global "param" interface can only be bind with a module "param" interface!' + ) + if ( not self.is_pipeline_executor_interface() and not binding.is_pipeline_executor_interface() @@ -412,6 +460,7 @@ def __init__(self, mod=None): self.output_type = InferType()(mod)["main"].checked_type.ret_type self.input_bindings = PipelineConfig.BindingList(self, "input") self.output_bindings = PipelineConfig.BindingList(self, "output") + self.param_binding = PipelineConfig.Binding(self, "param", "param") def __eq__(self, other): if isinstance(other, PipelineConfig.ModuleWrapper): @@ -427,6 +476,9 @@ def __getitem__(self, key): if key == "output": return self.output_bindings + if key == "param": + return self.param_binding + raise RuntimeError(f"{key} not found!") raise RuntimeError('The data type of "key" is not supported!') @@ -483,14 +535,21 @@ def __init__(self): self.mod_wrapper = {} self.input_bindings = self.BindingList(self, "input") self.output_bindings = self.BindingList(self, "output") + # There is a map of global parameters group and module index. + self.param_group_bindings = self.BindingList(self, "param") def __str__(self): # Get configuration information as a string. # Use topological sort to get correct module order. self.dag_topology_sort() + # Getting the parameters dependencies. + param_dump = "Params\n" + for param_name in self.param_group_bindings.bindings: + inf = self.param_group_bindings.bindings[param_name] + param_dump += str(inf) + "\n" # Get the input dependencies. - input_dump = "Inputs\n" + input_dump = "\nInputs\n" for input_name in self.input_bindings.bindings: inf = self.input_bindings.bindings[input_name] input_dump += str(inf) + "\n" @@ -516,7 +575,7 @@ def __str__(self): for name in sorted(output.keys()): output_dump += f" |output({name}) : {output[name]}\n" - return input_dump + output_dump + connections_dump + return param_dump + input_dump + output_dump + connections_dump def __getitem__(self, key): if isinstance(key, tvm.ir.module.IRModule): @@ -529,8 +588,12 @@ def __getitem__(self, key): return self.input_bindings if key == "output": return self.output_bindings + if key == "param_group": + return self.param_group_bindings + + raise RuntimeError(f"{key} not found!") - raise RuntimeError(f"{key} not found.") + raise RuntimeError(f'The key type "{type(key)}" is not supported!') def get_config(self): """Get the configuration information in dictionary form, this configuration @@ -541,7 +604,6 @@ def get_config(self): self.dag_topology_sort() mconfig = {} module_connection = {} - input_connection = {} for mod in self.mod_wrapper: # Generate pipeline configuration. mconf = {} @@ -579,22 +641,33 @@ def get_config(self): "dev": module.dev, } - # Create a map of pipeline input and subgraph input. - input_connection = [] - for input_name in self.input_bindings.bindings: - input_dict = self.input_bindings.bindings[input_name].get_binding_dict() - if "interface_name" not in input_dict["connection"][0]: - raise RuntimeError("interface_name is missing in connection config!") - # Creating the map of global interface and subgraph interface. - input_map = { - "global_interface_name": input_dict["interface_name"], - "mod_idx": input_dict["connection"][0]["mod_idx"], - "module_interface_name": input_dict["connection"][0]["interface_name"], - } - input_connection.append(input_map) + # Creating a map including pipeline inputs and subgraph inputs. + input_connection = [] + for input_name in self.input_bindings.bindings: + input_dict = self.input_bindings.bindings[input_name].get_binding_dict() + if "interface_name" not in input_dict["connection"][0]: + raise RuntimeError("interface_name is missing in connection config!") + # Creating the map including global interfaces and subgraph interfaces. + input_map = { + "global_interface_name": input_dict["interface_name"], + "mod_idx": input_dict["connection"][0]["mod_idx"], + "module_interface_name": input_dict["connection"][0]["interface_name"], + } + input_connection.append(input_map) + + # Create a map including global parameters groups and modules. + param_connection = [] + for param_name in self.param_group_bindings.bindings: + param_dict = self.param_group_bindings.bindings[param_name].get_binding_dict() + param_map = { + "global_param_name": param_dict["interface_name"], + "mod_idx": param_dict["connection"][0]["mod_idx"], + } + param_connection.append(param_map) mconfig["module_connection"] = module_connection mconfig["input_connection"] = input_connection + mconfig["param_connection"] = param_connection return mconfig def dag_topology_sort(self): @@ -613,8 +686,12 @@ def dag_topology_sort(self): mlist += temp_list + mod_wrapper_sort = {} for mod, i in zip(mlist, range(len(mlist))): self.mod_wrapper[mod].set_idx_name(i) + mod_wrapper_sort[mod] = self.mod_wrapper[mod] + + self.mod_wrapper = mod_wrapper_sort def get_mod_idx(self, mod): # Return the module index. diff --git a/python/tvm/meta_schedule/postproc/__init__.py b/python/tvm/meta_schedule/postproc/__init__.py index 6ee052ecba687..2e8fa1e777d68 100644 --- a/python/tvm/meta_schedule/postproc/__init__.py +++ b/python/tvm/meta_schedule/postproc/__init__.py @@ -16,3 +16,4 @@ # under the License. """The tvm.meta_schedule.postproc package.""" from .postproc import Postproc, PyPostproc +from .verify_gpu_code import VerifyGPUCode diff --git a/python/tvm/meta_schedule/postproc/verify_gpu_code.py b/python/tvm/meta_schedule/postproc/verify_gpu_code.py new file mode 100644 index 0000000000000..501e4423196c1 --- /dev/null +++ b/python/tvm/meta_schedule/postproc/verify_gpu_code.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A postprocessor that verifies if the GPU code is correct""" + +from tvm._ffi.registry import register_object +from .. import _ffi_api +from .postproc import Postproc + + +@register_object("meta_schedule.VerifyGPUCode") +class VerifyGPUCode(Postproc): + """A postprocessor that verifies if the GPU code is correct""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.PostprocVerifyGPUCode, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py index b90780d5bfdb5..475c43a3fda17 100644 --- a/python/tvm/meta_schedule/schedule_rule/__init__.py +++ b/python/tvm/meta_schedule/schedule_rule/__init__.py @@ -16,4 +16,7 @@ Meta Schedule schedule rules are used for modification of blocks in a schedule. See also PostOrderApply. """ +from .add_rfactor import AddRFactor +from .auto_inline import AutoInline from .schedule_rule import PyScheduleRule, ScheduleRule +from .random_compute_location import RandomComputeLocation diff --git a/python/tvm/meta_schedule/schedule_rule/add_rfactor.py b/python/tvm/meta_schedule/schedule_rule/add_rfactor.py new file mode 100644 index 0000000000000..72f9fc92f96e4 --- /dev/null +++ b/python/tvm/meta_schedule/schedule_rule/add_rfactor.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Add-rfactor Rule that add-rfactor to some blocks if needed""" +from typing import Optional + +from tvm._ffi import register_object + +from .. import _ffi_api +from .schedule_rule import ScheduleRule + + +@register_object("meta_schedule.AddRFactor") +class AddRFactor(ScheduleRule): + """Rules for add-rfactor to some blocks if needed. + + Parameters + ---------- + max_jobs_per_core: int + The maximum number of jobs to be launched per CPU core. It sets the uplimit of CPU + parallelism, i.e. `num_cores * max_jobs_per_core`. + Use -1 to disable parallelism. + max_innermost_factor: Optional[int] = None + The maximum size of the innermost factor. None means no limit. + """ + + def __init__( + self, + max_jobs_per_core: int = 16, + max_innermost_factor: Optional[int] = None, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRuleAddRFactor, # type: ignore # pylint: disable=no-member + max_jobs_per_core, + max_innermost_factor, + ) diff --git a/python/tvm/meta_schedule/schedule_rule/auto_inline.py b/python/tvm/meta_schedule/schedule_rule/auto_inline.py new file mode 100644 index 0000000000000..22206f3fcc248 --- /dev/null +++ b/python/tvm/meta_schedule/schedule_rule/auto_inline.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Auto-Inline. Rule that inlines spatial blocks if it satisfies some conditions""" +from typing import List, Optional + +from tvm._ffi import register_object + +from .. import _ffi_api +from .schedule_rule import ScheduleRule + + +@register_object("meta_schedule.AutoInline") +class AutoInline(ScheduleRule): + """Rule that inlines spatial blocks if it satisfies some conditions + + Parameters + ---------- + into_producer : bool + If allows to inline a block into its producer + into_consumer : bool + If allows to inline a block into its consumer + inline_const_tensor : bool + Always inline constant tensors + disallow_if_then_else : bool + Always disallow if-then-else-like constructs + require_injective : bool + Always require the read-to-write mapping to be ordered + require_ordered : bool + Always require the read-to-write mapping to be injective + disallow_op : Optional[List[str]] + The operators that are disallowed in auto inline + """ + + def __init__( + self, + into_producer: bool, + into_consumer: bool, + inline_const_tensor: bool, + disallow_if_then_else: bool, + require_injective: bool, + require_ordered: bool, + disallow_op: Optional[List[str]] = None, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRuleAutoInline, # type: ignore # pylint: disable=no-member + into_producer, + into_consumer, + inline_const_tensor, + disallow_if_then_else, + require_injective, + require_ordered, + disallow_op, + ) diff --git a/python/tvm/meta_schedule/schedule_rule/random_compute_location.py b/python/tvm/meta_schedule/schedule_rule/random_compute_location.py new file mode 100644 index 0000000000000..2355b0bfa8e54 --- /dev/null +++ b/python/tvm/meta_schedule/schedule_rule/random_compute_location.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Rule that randomly select a compute-at location for a free block""" +from tvm._ffi import register_object + +from .. import _ffi_api +from .schedule_rule import ScheduleRule + + +@register_object("meta_schedule.RandomComputeLocation") +class RandomComputeLocation(ScheduleRule): + """A rule that randomly select a compute-at location for a free block""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRuleRandomComputeLocation, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py new file mode 100644 index 0000000000000..020869da4b10d --- /dev/null +++ b/python/tvm/meta_schedule/testing/schedule_rule.py @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Default schedule rules""" +from tvm.meta_schedule.schedule_rule import ( + AddRFactor, + AutoInline, + ScheduleRule, +) +from tvm.target import Target + + +def auto_inline(target: Target) -> ScheduleRule: + """Default schedule rules for auto inline""" + if target.kind.name == "llvm": + return AutoInline( + into_producer=False, + into_consumer=True, + inline_const_tensor=True, + disallow_if_then_else=True, + require_injective=True, + require_ordered=True, + disallow_op=["tir.exp"], + ) + if target.kind.name == "cuda": + return AutoInline( + into_producer=True, + into_consumer=True, + inline_const_tensor=True, + disallow_if_then_else=False, + require_injective=False, + require_ordered=False, + disallow_op=None, + ) + raise NotImplementedError(f"{target.kind.name} is not supported") + + +def add_rfactor(target: Target) -> ScheduleRule: + """Default schedule rules for with add_rfactor""" + if target.kind.name == "llvm": + return AddRFactor(max_jobs_per_core=16, max_innermost_factor=64) + raise NotImplementedError(f"{target.kind.name} is not supported") diff --git a/python/tvm/meta_schedule/testing/space_generation.py b/python/tvm/meta_schedule/testing/space_generation.py new file mode 100644 index 0000000000000..10e31e7213cbd --- /dev/null +++ b/python/tvm/meta_schedule/testing/space_generation.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +from typing import List + +from tvm.tir import Schedule +from tvm.tir.schedule import Trace + + +def check_trace(spaces: List[Schedule], expected: List[List[str]]): + expected_traces = {"\n".join(t) for t in expected} + actual_traces = set() + for space in spaces: + trace = Trace(space.trace.insts, {}) + trace = trace.simplified(remove_postproc=True) + str_trace = "\n".join(str(trace).strip().splitlines()) + actual_traces.add(str_trace) + assert str_trace in expected_traces, "\n" + str_trace + assert len(expected_traces) == len(actual_traces) diff --git a/python/tvm/meta_schedule/testing/te_workload.py b/python/tvm/meta_schedule/testing/te_workload.py new file mode 100644 index 0000000000000..49a60a27526aa --- /dev/null +++ b/python/tvm/meta_schedule/testing/te_workload.py @@ -0,0 +1,877 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Workloads in TE""" +# pylint: disable=missing-docstring +from typing import Tuple + +from tvm import te, tir, topi + + +def batch_matmul_nkkm( # pylint: disable=invalid-name,missing-docstring + B: int, + N: int, + M: int, + K: int, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + x = te.placeholder((B, N, K), name="X") + y = te.placeholder((B, K, M), name="Y") + k = te.reduce_axis((0, K), name="k") + z = te.compute( # pylint: disable=invalid-name + (B, N, M), + lambda b, i, j: te.sum(x[b][i][k] * y[b][k][j], axis=[k]), + name="Z", + ) + return (x, y, z) + + +def conv1d_nlc( # pylint: disable=invalid-name,missing-docstring + N: int, + L: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, L, CI), name="inputs") + weight = te.placeholder((kernel_size, CI // groups, CO), name="weight") + + batch_size, in_len, _ = inputs.shape + k_len, channel_per_group, out_channel = weight.shape + out_channel_per_group = out_channel // groups + out_len = (in_len + 2 * padding - dilation * (k_len - 1) - 1) // stride + 1 + rc = te.reduce_axis((0, channel_per_group), name="rc") + rl = te.reduce_axis((0, k_len), name="rl") + + padded = topi.nn.pad(inputs, [0, padding, 0]) + output = te.compute( + (batch_size, out_len, out_channel), + lambda n, l, co: te.sum( + ( + padded[ + n, + l * stride + rl * dilation, + co // out_channel_per_group * channel_per_group + rc, + ] + * weight[rl, rc, co] + ), + axis=[rl, rc], + ), + name="conv1d_nlc", + ) + return (inputs, weight, output) + + +def conv2d_nhwc( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, H, W, CI), name="inputs") + weight = te.placeholder((kernel_size, kernel_size, CI // groups, CO), name="weight") + batch_size, in_h, in_w, _ = inputs.shape + k_h, k_w, channel_per_group, out_channel = weight.shape + out_channel_per_group = out_channel // groups + + out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1 + out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1 + rh = te.reduce_axis((0, k_h), name="rh") + rw = te.reduce_axis((0, k_w), name="rw") + rc = te.reduce_axis((0, channel_per_group), name="rc") + + padded = topi.nn.pad(inputs, [0, padding, padding, 0]) + output = te.compute( + (batch_size, out_h, out_w, out_channel), + lambda n, h, w, co: te.sum( + ( + padded[ + n, + h * stride + rh * dilation, + w * stride + rw * dilation, + co // out_channel_per_group * channel_per_group + rc, + ] + * weight[rh, rw, rc, co] + ), + axis=[rh, rw, rc], + ), + name="conv2d_nhwc", + ) + return (inputs, weight, output) + + +def conv3d_ndhwc( # pylint: disable=invalid-name,missing-docstring + N: int, + D: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, D, H, W, CI), name="inputs") + weight = te.placeholder( + (kernel_size, kernel_size, kernel_size, CI // groups, CO), name="weight" + ) + batch_size, in_d, in_h, in_w, _ = inputs.shape + k_d, k_h, k_w, channel_per_group, out_channel = weight.shape + out_channel_per_group = out_channel // groups + + out_d = (in_d + 2 * padding - dilation * (k_d - 1) - 1) // stride + 1 + out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1 + out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1 + rd = te.reduce_axis((0, k_d), name="rd") + rh = te.reduce_axis((0, k_h), name="rh") + rw = te.reduce_axis((0, k_w), name="rw") + rc = te.reduce_axis((0, channel_per_group), name="rc") + + padded = topi.nn.pad(inputs, [0, padding, padding, padding, 0]) + output = te.compute( + (batch_size, out_d, out_h, out_w, out_channel), + lambda n, d, h, w, co: te.sum( + ( + padded[ + n, + d * stride + rd * dilation, + h * stride + rh * dilation, + w * stride + rw * dilation, + co // out_channel_per_group * channel_per_group + rc, + ] + * weight[rd, rh, rw, rc, co] + ), + axis=[rd, rh, rw, rc], + ), + name="conv3d_ndhwc", + ) + return (inputs, weight, output) + + +def depthwise_conv2d_nhwc( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + C: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + factor: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, H, W, C)) + weight = te.placeholder((factor, kernel_size, kernel_size, C)) + batch_size, in_h, in_w, in_channel = inputs.shape + factor, k_h, k_w, in_channel = weight.shape + out_channel = in_channel * factor + assert int(factor) == 1, "Not optimized for factor != 1" + out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1 + out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1 + rh = te.reduce_axis((0, k_h), name="rh") + rw = te.reduce_axis((0, k_w), name="rw") + padded = topi.nn.pad(inputs, [0, padding, padding, 0]) + output = te.compute( + (batch_size, out_h, out_w, out_channel), + lambda n, h, w, c: te.sum( + ( + padded[ + n, + h * stride + rh * dilation, + w * stride + rw * dilation, + c // factor, + ] + * weight[c % factor, rh, rw, c // factor] + ), + axis=[rh, rw], + ), + name="depth_conv2d_nhwc", + ) + return (inputs, weight, output) + + +def conv2d_transpose_nhwc( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, H, W, CI), name="inputs") + weight = te.placeholder((kernel_size, kernel_size, CI, CO), name="weight") + + batch, in_h, in_w, in_c = inputs.shape + filter_h, filter_w, in_c, out_c = weight.shape + stride_h, stride_w = (stride, stride) + + # compute padding + fpad_top, fpad_left, fpad_bottom, fpad_right = topi.nn.get_pad_tuple( + padding, (filter_h, filter_w) + ) + bpad_top = filter_h - 1 - fpad_top + bpad_bottom = filter_h - 1 - fpad_bottom + bpad_left = filter_w - 1 - fpad_left + bpad_right = filter_w - 1 - fpad_right + + # padding stage + padded = topi.nn.pad( + inputs, + [ + 0, + (bpad_top + stride_h - 1) // stride_h, + (bpad_left + stride_w - 1) // stride_w, + 0, + ], + [ + 0, + (bpad_bottom + stride_h - 1) // stride_h, + (bpad_right + stride_w - 1) // stride_w, + 0, + ], + ) + + # remove extra padding introduced by dilatation + idx_div = te.indexdiv + idx_mod = te.indexmod + border_h = idx_mod(stride_h - idx_mod(bpad_top, stride_h), stride_h) + border_w = idx_mod(stride_w - idx_mod(bpad_left, stride_w), stride_w) + + # dilation stage + strides = [1, stride_h, stride_w, 1] + n = len(padded.shape) + + # We should embed this dilation directly into te.compute rather than creating a new te.compute. + # Only in this way can we use unroll to eliminate the multiplication of zeros. + def _dilate(*indices): + not_zero = [] + index_tuple = [] + for i in range(n): + if not strides[i] == 1: + index_tuple.append(idx_div(indices[i], strides[i])) + not_zero.append(idx_mod(indices[i], strides[i]).equal(0)) + else: + index_tuple.append(indices[i]) + if not_zero: + not_zero = te.all(*not_zero) + return te.if_then_else(not_zero, padded(*index_tuple), tir.const(0.0, padded.dtype)) + return padded(*index_tuple) + + # convolution stage + out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h + out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w + rc = te.reduce_axis((0, in_c), name="rc") + rh = te.reduce_axis((0, filter_h), name="rh") + rw = te.reduce_axis((0, filter_w), name="rw") + + output = te.compute( + (batch, out_h, out_w, out_c), + lambda n, h, w, co: te.sum( + _dilate(n, h + rh + border_h, w + rw + border_w, rc) + * weight[filter_h - 1 - rh, filter_w - 1 - rw, rc, co], + axis=[rh, rw, rc], + ), + name="conv2d_transpose_nhwc", + ) + return (inputs, weight, output) + + +def conv2d_capsule_nhwijc( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + capsule_size: int = 4, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, H, W, capsule_size, capsule_size, CI), name="inputs") + weight = te.placeholder( + (kernel_size, kernel_size, capsule_size, capsule_size, CI, CO), name="weight" + ) + batch_size, in_h, in_w, _, _, in_channel = inputs.shape + k_h, k_w, _, _, _, out_channel = weight.shape + + out_h = (in_h + 2 * padding - kernel_size) // stride + 1 + out_w = (in_w + 2 * padding - kernel_size) // stride + 1 + + rh = te.reduce_axis((0, k_h), name="rh") + rw = te.reduce_axis((0, k_w), name="rw") + cap_k = te.reduce_axis((0, capsule_size), name="cap_k") + rc = te.reduce_axis((0, in_channel), name="rc") + + padded = topi.nn.pad(inputs, [0, padding, padding, 0, 0, 0]) + output = te.compute( + (batch_size, out_h, out_w, capsule_size, capsule_size, out_channel), + lambda n, h, w, cap_i, cap_j, co: te.sum( + ( + padded[n, h * stride + rh, w * stride + rw, cap_i, cap_k, rc] + * weight[rh, rw, cap_k, cap_j, rc, co] + ), + axis=[rh, rw, cap_k, rc], + ), + name="conv2d_capsule_nhwijc", + ) + return (inputs, weight, output) + + +def norm_bmn( # pylint: disable=invalid-name,missing-docstring + B: int, + M: int, + N: int, +) -> Tuple[te.Tensor, te.Tensor]: + a = te.placeholder((B, M, N), name="A") + i = te.reduce_axis((0, M), name="i") + j = te.reduce_axis((0, N), name="j") + c = te.compute( + (B,), + lambda b: te.sum(a[b][i][j] * a[b][i][j], axis=[i, j]), + name="C", + ) + d = te.compute((B,), lambda b: te.sqrt(c[b]), name="D") + return (a, d) + + +def conv2d_nhwc_without_layout_rewrite( # pylint: disable=invalid-name + Input: int, + Filter: int, + stride: int, + padding: int, + dilation: int, + out_dtype="float32", +): + """A copy of `topi.nn.conv2d_nhwc` but without the 'layout_free` attribute. + We use this in single op and subgraph evaluation + because we don't want to introduce graph level optimization. + """ + assert isinstance(stride, int) or len(stride) == 2 + assert isinstance(dilation, int) or len(dilation) == 2 + + if isinstance(stride, int): + stride_h = stride_w = stride + else: + stride_h, stride_w = stride + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + batch, in_height, in_width, in_channel = Input.shape # type: ignore + kernel_h, kernel_w, _channel, num_filter = Filter.shape # type: ignore + + # compute the output shape + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + pad_top, pad_left, pad_down, pad_right = topi.nn.get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w) + ) + out_channel = num_filter + out_height = topi.utils.simplify( + (in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1 + ) + out_width = topi.utils.simplify( + (in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1 + ) + pad_before = [0, pad_top, pad_left, 0] + pad_after = [0, pad_down, pad_right, 0] + PaddedInput = topi.nn.pad(Input, pad_before, pad_after, name="PaddedInput") + rc = te.reduce_axis((0, in_channel), name="rc") + ry = te.reduce_axis((0, kernel_h), name="ry") + rx = te.reduce_axis((0, kernel_w), name="rx") + Output = te.compute( + (batch, out_height, out_width, out_channel), + lambda nn, yy, xx, ff: te.sum( + PaddedInput[ + nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rc + ].astype(out_dtype) + * Filter[ry, rx, rc, ff].astype(out_dtype), # type: ignore + axis=[ry, rx, rc], + ), + name="Conv2dOutput", + tag="conv2d_nhwc", + ) + return Output + + +def conv2d_nhwc_bn_relu( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + strides: int, + padding: int, + dilation: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor, te.Tensor, te.Tensor, te.Tensor]: + data = te.placeholder((N, H, W, CI), name="data") + kernel = te.placeholder((kernel_size, kernel_size, CI, CO), name="kernel") + bias = te.placeholder((CO,), name="bias") + bn_scale = te.placeholder((CO,), name="bn_scale") + bn_offset = te.placeholder((CO,), name="bn_offset") + OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 + OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 + conv = conv2d_nhwc_without_layout_rewrite(data, kernel, strides, padding, dilation) + conv = te.compute( + (N, OH, OW, CO), lambda i, j, k, l: conv[i, j, k, l] + bias[l], name="bias_add" + ) + conv = te.compute( + (N, OH, OW, CO), lambda i, j, k, l: conv[i, j, k, l] * bn_scale[l], name="bn_mul" + ) + conv = te.compute( + (N, OH, OW, CO), lambda i, j, k, l: conv[i, j, k, l] + bn_offset[l], name="bn_add" + ) + out = topi.nn.relu(conv) + return (data, kernel, bias, bn_offset, bn_scale, out) + + +def transpose_batch_matmul( # pylint: disable=invalid-name,missing-docstring + batch: int, + seq_len: int, + n_head: int, + n_dim: int, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + query = te.placeholder((batch, seq_len, n_head, n_dim), name="query") + value = te.placeholder((batch, seq_len, n_head, n_dim), name="value") + query_T = te.compute( + (batch, n_head, seq_len, n_dim), + lambda b, h, l, d: query[b, l, h, d], + name="query_T", + ) + value_T = te.compute( + (batch, n_head, n_dim, seq_len), + lambda b, h, d, l: value[b, l, h, d], + name="value_T", + ) + k = te.reduce_axis((0, n_dim), name="k") + out = te.compute( + (batch, n_head, seq_len, seq_len), + lambda b, h, i, j: te.sum(query_T[b, h, i, k] * value_T[b, h, k, j], axis=[k]), + name="C", + ) + return (query, value, out) + + +def conv2d_winograd_nhwc( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + tile_size = 4 # _infer_tile_size(data, kernel) + inputs = te.placeholder((N, H, W, CI), name="inputs") + N, H, W, CI = topi.utils.get_const_tuple(inputs.shape) + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + assert (dilation_h, dilation_w) == (1, 1), "Does not support dilation" + + KH = KW = kernel_size + HPAD, WPAD, _, _ = topi.nn.get_pad_tuple(padding, (KH, KW)) + HSTR, WSTR = (stride, stride) if isinstance(stride, int) else stride + assert HSTR == 1 and WSTR == 1 and KH == KW + + data_pad = topi.nn.pad(inputs, (0, HPAD, WPAD, 0), (0, HPAD, WPAD, 0), name="data_pad") + + r = KW + m = tile_size + alpha = m + r - 1 + A, B, _G = topi.nn.winograd_util.winograd_transform_matrices(m, r, "float32") + + H = (H + 2 * HPAD - KH) // HSTR + 1 + W = (W + 2 * WPAD - KW) // WSTR + 1 + nH, nW = (H + m - 1) // m, (W + m - 1) // m + P = N * nH * nW + _rkh = te.reduce_axis((0, KH), name="r_kh") + _rkw = te.reduce_axis((0, KW), name="r_kw") + kshape = (alpha, alpha, CI, CO) + kernel_pack = te.placeholder(kshape, inputs.dtype, name="weight") + + idxdiv = te.indexdiv + idxmod = te.indexmod + # pack input tile + input_tile = te.compute( + (alpha, alpha, P, CI), + lambda eps, nu, p, ci: data_pad[idxdiv(p, (nH * nW))][idxmod(idxdiv(p, nW), nH) * m + eps][ + idxmod(p, nW) * m + nu + ][ci], + name="input_tile", + ) + + # transform data + r_a = te.reduce_axis((0, alpha), "r_a") + r_b = te.reduce_axis((0, alpha), "r_b") + data_pack = te.compute( + (alpha, alpha, P, CI), + lambda eps, nu, p, ci: te.sum( + input_tile[r_a][r_b][p][ci] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b] + ), + name="data_pack", + attrs={"auto_scheduler_simplify_const_tensor_indices": ["eps", "nu", "r_a", "r_b"]}, + ) + + # do batch gemm + ci = te.reduce_axis((0, CI), name="ci") + bgemm = te.compute( + (alpha, alpha, P, CO), + lambda eps, nu, p, co: te.sum( + data_pack[eps][nu][p][ci] * kernel_pack[eps][nu][ci][co], axis=[ci] + ), + name="bgemm", + ) + + # inverse transform + r_a = te.reduce_axis((0, alpha), "r_a") + r_b = te.reduce_axis((0, alpha), "r_b") + inverse = te.compute( + (m, m, P, CO), + lambda vh, vw, p, co: te.sum( + bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b] + ), + name="inverse", + attrs={"auto_scheduler_simplify_const_tensor_indices": ["vh", "vw", "r_a", "r_b"]}, + ) + + # output + output = te.compute( + (N, H, W, CO), + lambda n, h, w, co: inverse[ + idxmod(h, m), idxmod(w, m), n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m), co + ], + name="conv2d_winograd", + ) + + return (inputs, kernel_pack, output) + + +def matmul(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + a = te.placeholder((n, k), name="A") + b = te.placeholder((k, m), name="B") + k = te.reduce_axis((0, k), name="k") + c = te.compute( + (n, m), + lambda i, j: te.sum(a[i, k] * b[k, j], axis=[k]), + name="C", + ) + return (a, b, c) + + +def matmul_fp16(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + a = te.placeholder((n, k), name="A", dtype="float16") + b = te.placeholder((k, m), name="B", dtype="float16") + k = te.reduce_axis((0, k), name="k") + + def f_compute(i, j): + v_a = tir.Cast(dtype="float32", value=a[i, k]) + v_b = tir.Cast(dtype="float32", value=b[k, j]) + return te.sum(v_a * v_b, axis=[k]) + + c = te.compute((n, m), f_compute, name="C") + return (a, b, c) + + +def matmul_relu(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + a = te.placeholder((n, k), name="A") + b = te.placeholder((m, k), name="B") + k = te.reduce_axis((0, k), name="k") + c = te.compute( + (n, m), + lambda i, j: te.sum(a[i, k] * b[k, j], axis=[k]), + name="C", + ) + d = topi.nn.relu(c) # pylint: disable=invalid-name + return (a, b, d) + + +def matmul_relu_fp16(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + a = te.placeholder((n, k), name="A", dtype="float16") + b = te.placeholder((k, m), name="B", dtype="float16") + k = te.reduce_axis((0, k), name="k") + + def f_compute(i, j): + v_a = tir.Cast(dtype="float32", value=a[i, k]) + v_b = tir.Cast(dtype="float32", value=b[k, j]) + return te.sum(v_a * v_b, axis=[k]) + + c = te.compute((n, m), f_compute, name="C") + d = topi.nn.relu(c) # pylint: disable=invalid-name + return (a, b, d) + + +def conv2d_nchw( # pylint: disable=invalid-name + n: int, + h: int, + w: int, + ci: int, + co: int, + kh: int, + kw: int, + stride: int, + padding: int, + dilation: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + x = te.placeholder((n, ci, h, w), name="X") + w = te.placeholder((co, ci, kh, kw), name="W") + y = topi.nn.conv2d_nchw(Input=x, Filter=w, stride=stride, padding=padding, dilation=dilation) + return (x, w, y) + + +def conv2d_nchw_bias_bn_relu( # pylint: disable=invalid-name + n: int, + h: int, + w: int, + ci: int, + co: int, + kh: int, + kw: int, + stride: int, + padding: int, + dilation: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor, te.Tensor, te.Tensor, te.Tensor]: + oh = (h + 2 * padding - (kh - 1) * dilation - 1) // stride + 1 # pylint: disable=invalid-name + ow = (w + 2 * padding - (kw - 1) * dilation - 1) // stride + 1 # pylint: disable=invalid-name + x = te.placeholder((n, ci, h, w), name="X") + w = te.placeholder((co, ci, kh, kw), name="W") + b = te.placeholder((co, 1, 1), name="B") + bn_scale = te.placeholder((co, 1, 1), name="bn_scale") + bn_offset = te.placeholder((co, 1, 1), name="bn_offset") + y = topi.nn.conv2d_nchw(Input=x, Filter=w, stride=stride, padding=padding, dilation=dilation) + y = te.compute((n, co, oh, ow), lambda i, j, k, l: y[i, j, k, l] + b[j, 0, 0], name="bias_add") + y = te.compute( + (n, co, oh, ow), lambda i, j, k, l: y[i, j, k, l] * bn_scale[j, 0, 0], name="bn_mul" + ) + y = te.compute( + (n, co, oh, ow), lambda i, j, k, l: y[i, j, k, l] + bn_offset[j, 0, 0], name="bn_add" + ) + y = topi.nn.relu(y) + return (x, w, b, bn_scale, bn_offset, y) + + +def max_pool2d_nchw( # pylint: disable=invalid-name + n: int, + h: int, + w: int, + ci: int, + padding: int, +) -> Tuple[te.Tensor, te.Tensor]: # pylint: disable=invalid-name + x = te.placeholder((n, ci, h, w), name="X") + y = topi.nn.pool2d(x, [2, 2], [1, 1], [1, 1], [padding, padding, padding, padding], "max") + return (x, y) + + +def softmax_mn(m, n) -> Tuple[te.Tensor, te.Tensor]: # pylint: disable=invalid-name + a = te.placeholder((m, n), name="A") + b = topi.nn.softmax(a, axis=1) + + return (a, b) + + +def create_te_workload(name: str, idx: int) -> tir.PrimFunc: + workload_func, params = CONFIGS[name] + return te.create_prim_func(workload_func(*params[idx])) # type: ignore + + +CONFIGS = { + "C1D": ( + conv1d_nlc, + [ + # derived from conv2d_shapes + (1, 256, 64, 128, 3, 2, 1), + # (1, 256, 64, 128, 1, 2, 0), + # (1, 256, 64, 64, 1, 1, 0), + # (1, 128, 128, 256, 3, 2, 1), + (1, 128, 128, 256, 1, 2, 0), + # (1, 128, 128, 128, 3, 1, 1), + # (1, 64, 256, 512, 3, 2, 1), + # (1, 64, 256, 512, 1, 2, 0), + (1, 64, 256, 256, 5, 1, 2), + (1, 32, 512, 512, 3, 1, 1), + ], + ), + "C2D": ( + conv2d_nhwc, + [ + # all conv2d layers in resnet-18 + (1, 224, 224, 3, 64, 7, 2, 3), + # (1, 56, 56, 64, 128, 3, 2, 1), + # (1, 56, 56, 64, 128, 1, 2, 0), + # (1, 56, 56, 64, 64, 3, 1, 1), + (1, 56, 56, 64, 64, 1, 1, 0), + # (1, 28, 28, 128, 256, 3, 2, 1), + # (1, 28, 28, 128, 256, 1, 2, 0), + # (1, 28, 28, 128, 128, 3, 1, 1), + # (1, 14, 14, 256, 512, 3, 2, 1), + # (1, 14, 14, 256, 512, 1, 2, 0), + (1, 14, 14, 256, 256, 3, 1, 1), + (1, 7, 7, 512, 512, 3, 1, 1), + ], + ), + "C3D": ( + conv3d_ndhwc, + [ + # Derived from conv2d_shapes. Use depth=16 for all configurations + (1, 16, 224, 224, 3, 64, 7, 2, 3), + # (1, 16, 56, 56, 64, 128, 3, 2, 1), + # (1, 16, 56, 56, 64, 128, 1, 2, 0), + # (1, 16, 56, 56, 64, 64, 3, 1, 1), + (1, 16, 56, 56, 64, 64, 1, 1, 0), + # (1, 16, 28, 28, 128, 256, 3, 2, 1), + # (1, 16, 28, 28, 128, 256, 1, 2, 0), + # (1, 16, 28, 28, 128, 128, 3, 1, 1), + # (1, 16, 14, 14, 256, 512, 3, 2, 1), + # (1, 16, 14, 14, 256, 512, 1, 2, 0), + (1, 16, 14, 14, 256, 256, 3, 1, 1), + (1, 16, 7, 7, 512, 512, 3, 1, 1), + ], + ), + "GMM": ( + batch_matmul_nkkm, + [ + (1, 128, 128, 128), + (1, 512, 32, 512), + (1, 512, 512, 512), + (1, 1024, 1024, 1024), + ], + ), + "GRP": ( + conv2d_nhwc, + [ + # Derived from conv2d_shapes. Use group=4 for all configurations + (1, 56, 56, 64, 128, 3, 2, 1, 1, 4), + # (1, 56, 56, 64, 128, 1, 2, 0 , 1, 4), + # (1, 56, 56, 64, 64, 3, 1, 1 , 1, 4), + (1, 56, 56, 64, 64, 1, 1, 0, 1, 4), + # (1, 28, 28, 128, 256, 3, 2, 1, 1, 4), + # (1, 28, 28, 128, 256, 1, 2, 0, 1, 4), + # (1, 28, 28, 128, 128, 3, 1, 1, 1, 4), + # (1, 14, 14, 256, 512, 3, 2, 1, 1, 4), + # (1, 14, 14, 256, 512, 1, 2, 0, 1, 4), + (1, 14, 14, 256, 256, 3, 1, 1, 1, 4), + (1, 7, 7, 512, 512, 3, 1, 1, 1, 4), + ], + ), + "DIL": ( + conv2d_nhwc, + [ + # Derived from conv2d_shapes. Use dilation=2 for all configurations + (1, 224, 224, 3, 64, 7, 2, 3, 2), + # (1, 56, 56, 64, 128, 3, 2, 1 , 2), + # (1, 56, 56, 64, 128, 1, 2, 0 , 2), + # (1, 56, 56, 64, 64, 3, 1, 1 , 2), + (1, 56, 56, 64, 64, 1, 1, 0, 2), + # (1, 28, 28, 128, 256, 3, 2, 1, 2), + # (1, 28, 28, 128, 256, 1, 2, 0, 2), + # (1, 28, 28, 128, 128, 3, 1, 1, 2), + # (1, 14, 14, 256, 512, 3, 2, 1, 2), + # (1, 14, 14, 256, 512, 1, 2, 0, 2), + (1, 14, 14, 256, 256, 3, 1, 1, 2), + (1, 7, 7, 512, 512, 3, 1, 1, 2), + ], + ), + "DEP": ( + depthwise_conv2d_nhwc, + [ + # all depthwise conv2d layers in mobilenet + (1, 112, 112, 32, 3, 1, 1), + (1, 112, 112, 64, 3, 2, 1), + # (1, 56, 56, 128, 3, 1, 1), + # (1, 56, 56, 128, 3, 2, 1), + # (1, 28, 28, 256, 3, 1, 1), + # (1, 28, 28, 256, 3, 2, 1), + # (1, 14, 14, 512, 3, 1, 1), + (1, 14, 14, 512, 3, 2, 1), + (1, 7, 7, 1024, 3, 1, 1), + ], + ), + "T2D": ( + conv2d_transpose_nhwc, + [ + # all conv2d tranpose layers in DCGAN + (1, 4, 4, 512, 256, 4, 2, 1), + (1, 8, 8, 256, 128, 4, 2, 1), + (1, 16, 16, 128, 64, 4, 2, 1), + (1, 32, 32, 64, 3, 4, 2, 1), + ], + ), + "CAP": ( + conv2d_capsule_nhwijc, + [ + # all conv2d capsule layers in matrix capsules withemrouting (ICLR 2018) + (1, 16, 16, 32, 32, 3, 2, 1), + (1, 8, 8, 32, 32, 3, 1, 1), + (1, 16, 16, 8, 16, 3, 2, 1), + (1, 8, 8, 16, 16, 3, 1, 1), + ], + ), + "NRM": ( + norm_bmn, + [ + (1, 256, 256), + (1, 512, 512), + (1, 1024, 1024), + (1, 4096, 1024), + ], + ), + "SFM": ( + softmax_mn, + [ + (256, 256), + (512, 512), + (1024, 1024), + (2048, 2048), + ], + ), + "C2d-BN-RELU": ( + conv2d_nhwc_bn_relu, + [ + (1, 224, 224, 3, 64, 7, 2, 3), + (1, 56, 56, 64, 128, 3, 2, 1), + (1, 28, 28, 128, 256, 1, 2, 0), + (1, 7, 7, 512, 512, 3, 1, 1), + ], + ), + "TBG": ( + transpose_batch_matmul, + [ + (1, 128, 12, 64), + (1, 128, 16, 64), + (1, 64, 12, 128), + (1, 128, 12, 128), + ], + ), +} diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py index bf54956503c2e..9f65a0bef1096 100644 --- a/python/tvm/micro/model_library_format.py +++ b/python/tvm/micro/model_library_format.py @@ -86,10 +86,12 @@ def _populate_codegen_dir(mod, codegen_dir: str, module_name: str = None): for dso_mod in dso_modules: if dso_mod.type_key == "c": + assert dso_mod.format in ["c", "cc", "cpp"] + ext = dso_mod.format index = mod_indices["src"] mod_indices["src"] += 1 parent_dir = os.path.join(host_codegen_dir, "src") - file_name = os.path.join(parent_dir, f"{lib_name}{index}.c") + file_name = os.path.join(parent_dir, f"{lib_name}{index}.{ext}") elif dso_mod.type_key == "llvm": index = mod_indices["lib"] mod_indices["lib"] += 1 diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py index 0884b249df488..7666691aa19f4 100644 --- a/python/tvm/relay/backend/contrib/ethosu/codegen.py +++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py @@ -18,13 +18,13 @@ import tvm from tvm import relay +from tvm import ir from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants from tvm.relay.backend.contrib.ethosu.legalize import LegalizeEthosU from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator from tvm.relay.backend.contrib.ethosu import util from tvm.relay.expr_functor import ExprMutator -from tvm.ir.transform import Pass # pylint: disable=unused-import from tvm.relay.backend.contrib.ethosu.op import op_attrs @@ -109,13 +109,11 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call: return new_call -@relay.transform.function_pass(opt_level=1, name="LUTsOptimizer") -class LUTsOptimizer(Pass): +@ir.transform.module_pass(opt_level=1, name="LUTsOptimizer") +class LUTsOptimizer: """Register LUTsOptimizer as a relay pass.""" - def transform_function( - self, func: tvm.relay.function.Function, mod: tvm.IRModule, _ - ) -> tvm.IRModule: + def transform_module(self, mod: tvm.ir.IRModule, _) -> tvm.IRModule: """Visit relay nodes in the given module. Parameters @@ -131,7 +129,13 @@ def transform_function( New module with optimized LUTs. """ assert len(mod.functions.items()) == 1, "Module can only contain one function." - return OptimizeLUTs().visit(func) + global_var, func = mod.functions.items()[0] + optimized_func = OptimizeLUTs().visit(func) + mod.update_func(global_var, optimized_func) + return mod + + def __call__(self, *args, **kwargs): + pass class LayoutOptimization(ExprMutator): @@ -247,19 +251,23 @@ def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call: return super().visit_call(call) -@relay.transform.function_pass(opt_level=1, name="LayoutOptimizer") -class LayoutOptimizer(Pass): +@ir.transform.module_pass(opt_level=1, name="LayoutOptimizer") +class LayoutOptimizer: """Register LayoutOptimizer as a Relay pass.""" - def transform_function( - self, func: tvm.relay.function.Function, mod: tvm.IRModule, _ - ) -> tvm.IRModule: + def transform_module(self, mod: tvm.ir.IRModule, _) -> tvm.IRModule: """A pass to optimize the layout of NPU operations. If both the producer and consumer of a tensor are NPU operators, then the layout is converted from NHWC to NHCWB16 as this is the layout NPU uses internally.""" assert len(mod.functions.items()) == 1, "Module can only contain one function." - return LayoutOptimization().visit(func) + global_var, func = mod.functions.items()[0] + optimized_func = LayoutOptimization().visit(func) + mod.update_func(global_var, optimized_func) + return mod + + def __call__(self, *args, **kwargs): + pass @tvm._ffi.register_func("relay.ext.ethos-u.constant_updater") diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index ede9cd46371e4..d52f3ba6eca55 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -230,7 +230,7 @@ def __call__(self, *args, **kwargs): def sigmoid_calc_func(x: float) -> float: """Function to calculate the values for sigmoid""" - # Thse limits are inherited from TFLite + # These limits are inherited from TFLite upper_limit = 8.0 lower_limit = -8.0 @@ -290,8 +290,6 @@ def callback( "OHWI": params.weights.shape[1:3], "HWOI": params.weights.shape[0:2], } - if str(params.weights.layout) not in kernel_size_map.keys(): - raise UnsupportedLayout(str(params.weights.layout)) activation_map = {"clip": "CLIP"} weight_to_ohwi_transform_map = {"HWIO": [3, 0, 1, 2]} weights_values = params.weights.values @@ -375,13 +373,9 @@ def callback( channels_map = { "NHWC": 3, } - if str(params.ofm.layout) not in channels_map.keys(): - raise UnsupportedLayout(str(params.ofm.layout)) kernel_shape_map = { "HWOI": params.weights.shape[0:2], } - if str(params.weights.layout) not in kernel_shape_map.keys(): - raise UnsupportedLayout(str(params.weights.layout)) weights_values = params.weights.values weights_values_ohwi = np.moveaxis(weights_values, [0, 1, 2, 3], [1, 2, 0, 3]) @@ -470,8 +464,6 @@ def callback( channels_map = { "NHWC": 3, } - if str(params.ofm.layout) not in channels_map.keys(): - raise UnsupportedLayout(str(params.ofm.layout)) activation_map = {"clip": "CLIP"} if params.activation: @@ -632,11 +624,6 @@ def callback( params = self.params_class(post.op.body) params.ifm.tensor = post.args[1] if params.reversed_operands else post.args[0] params.ifm2.tensor = post.args[0] if params.reversed_operands else post.args[1] - channels_map = { - "NHWC": 3, - } - if str(params.ofm.layout) not in channels_map.keys(): - raise UnsupportedLayout(str(params.ofm.layout)) activation_map = {"clip": "CLIP"} if params.activation: @@ -665,8 +652,8 @@ def callback( ifm2_zero_point=int(params.ifm2.q_params.zero_point), ofm_scale=float(params.ofm.q_params.scale_f32), ofm_zero_point=int(params.ofm.q_params.zero_point), - ifm_channels=params.ifm.shape[-1], - ifm2_channels=params.ifm2.shape[-1], + ifm_channels=params.ifm.shape[-1] if params.ifm.shape else 1, + ifm2_channels=params.ifm2.shape[-1] if params.ifm2.shape else 1, reversed_operands=params.reversed_operands, ofm_dtype=params.ofm.dtype, activation=activation, @@ -963,9 +950,6 @@ def callback( params = self.params_class(post.op.body) params.ifm.tensor = post.args[0] - if str(params.ofm.layout) != "NHWC": - raise UnsupportedLayout(str(params.ofm.layout)) - activation_map = {"clip": "CLIP"} if params.activation: activation = activation_map[params.activation.op.name] diff --git a/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py index c1d39556d11d2..8446b0c2e4ad7 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py @@ -17,7 +17,10 @@ # pylint: disable=invalid-name,unused-argument """Tensor Expressions for binary_elementwise""" import operator +import numpy as np from tvm import te +from tvm.contrib.ethosu.cascader import TESubgraph, EthosuPart, Propagator, register_matcher + from .dma import dma_ofm_compute, dma_ifm_compute @@ -123,6 +126,12 @@ def binary_elementwise_compute( te.Tensor The Output Feature Map tensor. """ + assert ifm.shape[0] == 1 + assert ifm2.shape[0] == 1 + assert ifm_layout in {"NHWC", "NHCWB16"} + assert ifm2_layout in {"NHWC", "NHCWB16"} + assert ofm_layout in {"NHWC", "NHCWB16"} + # Compute operation for the IFM DMA pipeline dmaed_ifm = dma_ifm_compute( ifm, ifm_layout, ifm_zero_point, ifm_scale, ifm_channels, (0, 0, 0, 0) @@ -187,5 +196,147 @@ def binary_elementwise_compute( attrs=binary_elementwise_attrs, ) + nhwc_to_nhcwb16 = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1 / 16, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 16], + [0, 0, 0, 0, 1], + ] + nhcwb16_to_nhwc = [ + [1, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 16, 0, 1, -16], + [0, 0, 0, 0, 0, 1], + ] + ifm_matrix = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + ] + ifm2_matrix = [ + [1, 0, 0, 0, 0], + [0, (1 - int(broadcast[1])), 0, 0, int(broadcast[1])], + [0, 0, (1 - int(broadcast[2])), 0, int(broadcast[2])], + [0, 0, 0, (1 - int(broadcast[3])), int(broadcast[3])], + [0, 0, 0, 0, 1], + ] + if ofm_layout == "NHCWB16": + ifm_matrix = np.matmul(ifm_matrix, nhcwb16_to_nhwc).tolist() + ifm2_matrix = np.matmul(ifm2_matrix, nhcwb16_to_nhwc).tolist() + if ifm_layout == "NHCWB16": + ifm_matrix = np.matmul(nhwc_to_nhcwb16, ifm_matrix).tolist() + if ifm2_layout == "NHCWB16": + ifm2_matrix = np.matmul(nhwc_to_nhcwb16, ifm2_matrix).tolist() + ifm_propagator = Propagator( + ifm_matrix, + [0, 0, 0, 0] if ifm_layout == "NHWC" else [0, 0, 0, 0, 0], + ) + ifm2_propagator = Propagator( + ifm2_matrix, + [0, 0, 0, 0] if ifm2_layout == "NHWC" else [0, 0, 0, 0, 0], + ) + propagator_attrs = { + "ifm_propagator": ifm_propagator, + "ifm2_propagator": ifm2_propagator, + } + # Compute operation for the OFM DMA pipeline - return dma_ofm_compute(binary_elementwise, ofm_layout, ofm_zero_point, ofm_scale, ifm_channels) + return dma_ofm_compute( + binary_elementwise, + ofm_layout, + ofm_zero_point, + ofm_scale, + ifm_channels, + attrs=propagator_attrs, + ) + + +@register_matcher +def match_ethosu_binary_elementwise(output_tensor, device_config): + """Match a Tensor Expression corresponding to an NPU Binary Elementwise. + + If the Tensor Expression matches, an EthosuPart will be created that models the + matched Tensor Expression. Otherwise, None will be returned. + + Parameters + ---------- + output_tensor : tvm.te.Tensor + The tensor to attempt to match with. + device_config : EthosuDeviceConfig + Target device configuration + + Returns + ------- + Union[None, EthosuPart] + The created EthosuPart if there was a match, otherwise None. + + """ + write = output_tensor + if write.op.name != "ethosu_write": + return None + convert_to_nhcwb16 = write.op.input_tensors[0] + if convert_to_nhcwb16.op.name != "ethosu_convert_to_nhcwb16": + return None + binary_elementwise = convert_to_nhcwb16.op.input_tensors[0] + if binary_elementwise.op.name != "ethosu_binary_elementwise": + return None + pad = binary_elementwise.op.input_tensors[0] + if pad.op.name != "ethosu_pad": + return None + convert_to_nhwc = pad.op.input_tensors[0] + if convert_to_nhwc.op.name != "ethosu_convert_to_nhwc": + return None + read = convert_to_nhwc.op.input_tensors[0] + if read.op.name != "ethosu_read": + return None + pad2 = binary_elementwise.op.input_tensors[1] + if pad2.op.name != "ethosu_pad": + return None + convert_to_nhwc2 = pad2.op.input_tensors[0] + if convert_to_nhwc2.op.name != "ethosu_convert_to_nhwc": + return None + read2 = convert_to_nhwc2.op.input_tensors[0] + if read2.op.name != "ethosu_read": + return None + + input_tensors = [ + read.op.input_tensors[0], + read2.op.input_tensors[0], + ] + subgraph = TESubgraph(input_tensors, output_tensor) + propagators = [ + write.op.attrs["ifm_propagator"], + write.op.attrs["ifm2_propagator"], + ] + ifm_dtype = input_tensors[0].dtype + ofm_dtype = output_tensor.dtype + + output_layout = convert_to_nhcwb16.op.attrs["layout"] + input_layout = convert_to_nhwc.op.attrs["layout"] + input2_layout = convert_to_nhwc2.op.attrs["layout"] + output_quantum = device_config.get_output_quantum(output_layout) + + block_config = device_config.get_elementwise_block_config( + propagators[0], + propagators[1], + binary_elementwise.op.attrs, + output_tensor.shape, + output_layout, + input_layout, + input2_layout, + ifm_dtype, + ofm_dtype, + ) + + return EthosuPart( + subgraph, + propagators, + output_quantum, + 1, + block_config, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py index 766af0dbbeef1..ea2290ef1e5fe 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py @@ -180,7 +180,7 @@ def conv2d_compute( [1, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0], - [0, 0, 16, 0, 0, 0], + [0, 0, 16, 0, 1, -16], [0, 0, 0, 0, 0, 1], ] ifm_matrix = [ @@ -236,7 +236,7 @@ def conv2d_compute( @register_matcher -def match_ethosu_conv2d(output_tensor): +def match_ethosu_conv2d(output_tensor, device_config): """Match a Tensor Expression corresponding to an NPU Conv2D. If the Tensor Expression matches, an EthosuPart will be created that models the @@ -246,6 +246,8 @@ def match_ethosu_conv2d(output_tensor): ---------- output_tensor : tvm.te.Tensor The tensor to attempt to match with. + device_config : EthosuDeviceConfig + Target device configuration Returns ------- @@ -277,17 +279,52 @@ def match_ethosu_conv2d(output_tensor): conv2d.op.input_tensors[1], conv2d.op.input_tensors[2], ] + subgraph = TESubgraph(input_tensors, output_tensor) propagators = [ write.op.attrs["ifm_propagator"], write.op.attrs["weights_propagator"], write.op.attrs["bias_propagator"], ] - # TODO(@jacobbohlin) Both the output_quantum and quantum_cycles here are placeholders, - # needs true implementation. - if convert_to_nhcwb16.op.attrs["layout"] == "NHWC": - output_quantum = [1, 2, 2, 1] - else: - output_quantum = [1, 2, 1, 2, 1] - quantum_cycles = 1000 - return EthosuPart(subgraph, propagators, output_quantum, quantum_cycles) + ifm_dtype = input_tensors[0].dtype + ofm_dtype = output_tensor.dtype + + ifm_channels = int(input_tensors[0].shape[3]) + ofm_channels, kernel_height, kernel_width = (int(axis) for axis in input_tensors[1].shape[0:3]) + kernel_elements = kernel_height * kernel_width + + is_part_kernel = device_config.is_partkernel( + conv2d.op.name, ifm_channels, ifm_dtype, kernel_elements + ) + subkernels = len( + device_config.get_kernel_steps( + conv2d.op.name, kernel_height, kernel_width, ifm_dtype, is_part_kernel + ) + ) + + output_layout = convert_to_nhcwb16.op.attrs["layout"] + input_layout = convert_to_nhwc.op.attrs["layout"] + output_quantum = device_config.get_output_quantum(output_layout) + + valid_block_configs = device_config.get_valid_block_configs( + propagators[0], + conv2d.op.attrs, + output_tensor.shape, + ofm_channels, + ifm_channels, + output_layout, + input_layout, + ifm_dtype, + ofm_dtype, + kernel_height, + kernel_width, + ) + + return EthosuPart( + subgraph, + propagators, + output_quantum, + subkernels, + valid_block_configs, + 1, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py index f54f2f3654e27..ff09662cc14af 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py @@ -17,8 +17,11 @@ # pylint: disable=invalid-name,unused-argument """Tensor Expressions for depthwise convolutions""" from typing import Tuple, Union, List +import numpy as np from tvm import te +from tvm.contrib.ethosu.cascader import TESubgraph, EthosuPart, Propagator, register_matcher + from .dma import dma_ofm_compute, dma_ifm_compute @@ -110,9 +113,10 @@ def depthwise_conv2d_compute( assert ifm_layout in {"NHWC", "NHCWB16"} assert ofm_layout in {"NHWC", "NHCWB16"} - stride_h, stride_w = strides - dilation_h, dilation_w = dilation - channels, kernel_h, kernel_w, _ = weight.shape + padding = [int(v) for v in padding] + stride_h, stride_w = [int(v) for v in strides] + dilation_h, dilation_w = [int(v) for v in dilation] + channels, kernel_h, kernel_w, _ = [int(v) for v in weight.shape] # Compute operation for the IFM DMA pipeline dmaed_ifm = dma_ifm_compute(ifm, ifm_layout, ifm_zero_point, ifm_scale, channels, padding) @@ -165,5 +169,155 @@ def depthwise_conv2d_compute( attrs=depthwise_conv2d_attrs, ) + nhwc_to_nhcwb16 = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1 / 16, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 16], + [0, 0, 0, 0, 1], + ] + nhcwb16_to_nhwc = [ + [1, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 16, 0, 1, -16], + [0, 0, 0, 0, 0, 1], + ] + ifm_matrix = [ + [1, 0, 0, 0, 0], + [0, stride_h, 0, 0, (dilated_kernel_h - stride_h)], + [0, 0, stride_w, 0, (dilated_kernel_w - stride_w)], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + ] + weights_matrix = [ + [0, 0, 0, 1, 0], + [0, 0, 0, 0, kernel_h], + [0, 0, 0, 0, kernel_w], + [0, 0, 0, 0, 1], + [0, 0, 0, 0, 1], + ] + bias_matrix = [ + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 10], + [0, 0, 0, 0, 1], + ] + if ofm_layout == "NHCWB16": + ifm_matrix = np.matmul(ifm_matrix, nhcwb16_to_nhwc).tolist() + weights_matrix = np.matmul(weights_matrix, nhcwb16_to_nhwc).tolist() + bias_matrix = np.matmul(bias_matrix, nhcwb16_to_nhwc).tolist() + if ifm_layout == "NHCWB16": + ifm_matrix = np.matmul(nhwc_to_nhcwb16, ifm_matrix).tolist() + ifm_propagator = Propagator( + ifm_matrix, + [0, -padding[0], -padding[1], 0] + if ifm_layout == "NHWC" + else [0, -padding[0], 0, -padding[1], 0], + ) + weights_propagator = Propagator( + weights_matrix, + [0, 0, 0, 0], + ) + bias_propagator = Propagator( + bias_matrix, + [0, 0], + ) + propagator_attrs = { + "ifm_propagator": ifm_propagator, + "weights_propagator": weights_propagator, + "bias_propagator": bias_propagator, + } + # Compute operation for the OFM DMA pipeline - return dma_ofm_compute(depthwise, ofm_layout, ofm_zero_point, ofm_scale, channels) + return dma_ofm_compute( + depthwise, ofm_layout, ofm_zero_point, ofm_scale, channels, attrs=propagator_attrs + ) + + +@register_matcher +def match_ethosu_depthwise_conv2d(output_tensor, device_config): + """Match a Tensor Expression corresponding to an NPU Depthwise Conv2D. + + If the Tensor Expression matches, an EthosuPart will be created that models the + matched Tensor Expression. Otherwise, None will be returned. + + Parameters + ---------- + output_tensor : tvm.te.Tensor + The tensor to attempt to match with. + device_config : EthosuDeviceConfig + Target device configuration. + + Returns + ------- + Union[None, EthosuPart] + The created EthosuPart if there was a match, otherwise None. + + """ + write = output_tensor + if write.op.name != "ethosu_write": + return None + convert_to_nhcwb16 = write.op.input_tensors[0] + if convert_to_nhcwb16.op.name != "ethosu_convert_to_nhcwb16": + return None + depthwise2d = convert_to_nhcwb16.op.input_tensors[0] + if depthwise2d.op.name != "ethosu_depthwise_conv2d": + return None + pad = depthwise2d.op.input_tensors[0] + if pad.op.name != "ethosu_pad": + return None + convert_to_nhwc = pad.op.input_tensors[0] + if convert_to_nhwc.op.name != "ethosu_convert_to_nhwc": + return None + read = convert_to_nhwc.op.input_tensors[0] + if read.op.name != "ethosu_read": + return None + + input_tensors = [ + read.op.input_tensors[0], + depthwise2d.op.input_tensors[1], + depthwise2d.op.input_tensors[2], + ] + subgraph = TESubgraph(input_tensors, output_tensor) + propagators = [ + write.op.attrs["ifm_propagator"], + write.op.attrs["weights_propagator"], + write.op.attrs["bias_propagator"], + ] + ifm_dtype = input_tensors[0].dtype + ofm_dtype = output_tensor.dtype + + ifm_channels = int(input_tensors[0].shape[3]) + ofm_channels, kernel_height, kernel_width = (int(axis) for axis in input_tensors[1].shape[0:3]) + + subkernels = len( + device_config.get_kernel_steps(depthwise2d.op.name, kernel_height, kernel_width, ifm_dtype) + ) + + output_layout = convert_to_nhcwb16.op.attrs["layout"] + input_layout = convert_to_nhwc.op.attrs["layout"] + output_quantum = device_config.get_output_quantum(output_layout) + + valid_block_configs = device_config.get_valid_block_configs( + propagators[0], + depthwise2d.op.attrs, + output_tensor.shape, + ofm_channels, + ifm_channels, + output_layout, + input_layout, + ifm_dtype, + ofm_dtype, + kernel_height, + kernel_width, + ) + + return EthosuPart( + subgraph, + propagators, + output_quantum, + subkernels, + valid_block_configs, + 1, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/te/inline.py b/python/tvm/relay/backend/contrib/ethosu/te/inline.py index 95e7342d5e827..79631f4b8c1c7 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/inline.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/inline.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=unused-argument """Tensor Expressions for operations that will be inlined""" import numpy as np # type: ignore @@ -24,7 +25,7 @@ @register_matcher -def match_ethosu_inline(output_tensor): +def match_ethosu_inline(output_tensor, device_config): """Match a Tensor Expression corresponding to an operator that will be inlined. If the Tensor Expression matches, an InlinePart will be created that models the @@ -37,6 +38,8 @@ def match_ethosu_inline(output_tensor): ---------- output_tensor : tvm.te.Tensor The tensor to attempt to match with. + device_config : EthosuDeviceConfig + Target device configuration Returns ------- diff --git a/python/tvm/relay/backend/contrib/ethosu/te/pooling.py b/python/tvm/relay/backend/contrib/ethosu/te/pooling.py index e98a72db7f02e..aaf79e8a8c8d4 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/pooling.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/pooling.py @@ -18,7 +18,10 @@ """Tensor Expressions for poolings""" from typing import Tuple +import numpy as np from tvm import te +from tvm.contrib.ethosu.cascader import TESubgraph, EthosuPart, Propagator, register_matcher + from .dma import dma_ofm_compute, dma_ifm_compute @@ -99,8 +102,13 @@ def pooling_compute( te.Tensor The OFM tensor. """ - stride_h, stride_w = strides - pool_shape_h, pool_shape_w = pool_shape + assert ifm.shape[0] == 1 + assert ifm_layout in {"NHWC", "NHCWB16"} + assert ofm_layout in {"NHWC", "NHCWB16"} + + padding = [int(v) for v in padding] + stride_h, stride_w = [int(v) for v in strides] + pool_shape_h, pool_shape_w = [int(v) for v in pool_shape] # Compute operation for the IFM DMA pipeline dmaed_ifm = dma_ifm_compute(ifm, ifm_layout, ifm_zero_point, ifm_scale, ofm_channels, padding) @@ -114,6 +122,8 @@ def pooling_compute( pooling_attrs = { "op": "ethosu_pooling", "pooling_type": pooling_type, + "pool_shape_h": pool_shape_h, + "pool_shape_w": pool_shape_w, "stride_h": stride_h, "stride_w": stride_w, "activation": activation, @@ -144,5 +154,128 @@ def pooling_compute( attrs=pooling_attrs, ) + nhwc_to_nhcwb16 = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1 / 16, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 16], + [0, 0, 0, 0, 1], + ] + nhcwb16_to_nhwc = [ + [1, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 16, 0, 1, -16], + [0, 0, 0, 0, 0, 1], + ] + ifm_matrix = [ + [1, 0, 0, 0, 0], + [0, stride_h, 0, 0, (pool_shape_h - stride_h)], + [0, 0, stride_w, 0, (pool_shape_w - stride_w)], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + ] + if ofm_layout == "NHCWB16": + ifm_matrix = np.matmul(ifm_matrix, nhcwb16_to_nhwc).tolist() + if ifm_layout == "NHCWB16": + ifm_matrix = np.matmul(nhwc_to_nhcwb16, ifm_matrix).tolist() + ifm_propagator = Propagator( + ifm_matrix, + [0, -padding[0], -padding[1], 0] + if ifm_layout == "NHWC" + else [0, -padding[0], 0, -padding[1], 0], + ) + propagator_attrs = { + "ifm_propagator": ifm_propagator, + } + # Compute operation for the OFM DMA pipeline - return dma_ofm_compute(pooling, ofm_layout, ofm_zero_point, ofm_scale, ofm_channels) + return dma_ofm_compute( + pooling, ofm_layout, ofm_zero_point, ofm_scale, ofm_channels, attrs=propagator_attrs + ) + + +@register_matcher +def match_ethosu_pooling(output_tensor, device_config): + """Match a Tensor Expression corresponding to an NPU Pooling. + + If the Tensor Expression matches, an EthosuPart will be created that models the + matched Tensor Expression. Otherwise, None will be returned. + + Parameters + ---------- + output_tensor : tvm.te.Tensor + The tensor to attempt to match with. + device_config : EthosuDeviceConfig + Target device configuration + + Returns + ------- + Union[None, EthosuPart] + The created EthosuPart if there was a match, otherwise None. + + """ + write = output_tensor + if write.op.name != "ethosu_write": + return None + convert_to_nhcwb16 = write.op.input_tensors[0] + if convert_to_nhcwb16.op.name != "ethosu_convert_to_nhcwb16": + return None + pool2d = convert_to_nhcwb16.op.input_tensors[0] + if pool2d.op.name != "ethosu_pooling": + return None + pad = pool2d.op.input_tensors[0] + if pad.op.name != "ethosu_pad": + return None + convert_to_nhwc = pad.op.input_tensors[0] + if convert_to_nhwc.op.name != "ethosu_convert_to_nhwc": + return None + read = convert_to_nhwc.op.input_tensors[0] + if read.op.name != "ethosu_read": + return None + + input_tensors = [ + read.op.input_tensors[0], + ] + subgraph = TESubgraph(input_tensors, output_tensor) + propagators = [ + write.op.attrs["ifm_propagator"], + ] + ifm_dtype = input_tensors[0].dtype + ofm_dtype = output_tensor.dtype + + ifm_channels = int(input_tensors[0].shape[3]) + ofm_channels = ifm_channels + pool_shape_h = int(pool2d.op.attrs["pool_shape_h"]) + pool_shape_w = int(pool2d.op.attrs["pool_shape_w"]) + + subkernels = len( + device_config.get_kernel_steps(pool2d.op.name, pool_shape_h, pool_shape_w, ifm_dtype) + ) + + output_layout = convert_to_nhcwb16.op.attrs["layout"] + input_layout = convert_to_nhwc.op.attrs["layout"] + output_quantum = device_config.get_output_quantum(output_layout) + + valid_block_configs = device_config.get_valid_block_configs( + propagators[0], + pool2d.op.attrs, + output_tensor.shape, + ofm_channels, + ifm_channels, + output_layout, + input_layout, + ifm_dtype, + ofm_dtype, + pool_shape_h, + pool_shape_w, + ) + + return EthosuPart( + subgraph, + propagators, + output_quantum, + subkernels, + valid_block_configs, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/te/unary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/te/unary_elementwise.py index 0aefc1c35d4c1..68d1c603ad98d 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/unary_elementwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/unary_elementwise.py @@ -17,7 +17,9 @@ # pylint: disable=invalid-name,unused-argument """Tensor Expressions for unary_elementwise for the NPU""" +import numpy as np from tvm import te +from tvm.contrib.ethosu.cascader import TESubgraph, EthosuPart, Propagator, register_matcher from .dma import dma_ofm_compute, dma_ifm_compute @@ -127,5 +129,119 @@ def clz_imp(inp): attrs=unary_elementwise_attrs, ) + nhwc_to_nhcwb16 = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1 / 16, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 16], + [0, 0, 0, 0, 1], + ] + nhcwb16_to_nhwc = [ + [1, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 16, 0, 1, -16], + [0, 0, 0, 0, 0, 1], + ] + ifm_matrix = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + ] + if ofm_layout == "NHCWB16": + ifm_matrix = np.matmul(ifm_matrix, nhcwb16_to_nhwc).tolist() + if ifm_layout == "NHCWB16": + ifm_matrix = np.matmul(nhwc_to_nhcwb16, ifm_matrix).tolist() + + ifm_propagator = Propagator( + ifm_matrix, + [0, 0, 0, 0] if ifm_layout == "NHWC" else [0, 0, 0, 0, 0], + ) + propagator_attrs = {"ifm_propagator": ifm_propagator} + # Compute operation for the OFM DMA pipeline - return dma_ofm_compute(unary_elementwise, ofm_layout, ofm_zero_point, ofm_scale, ofm_channels) + return dma_ofm_compute( + unary_elementwise, + ofm_layout, + ofm_zero_point, + ofm_scale, + ofm_channels, + attrs=propagator_attrs, + ) + + +@register_matcher +def match_ethosu_unary_elementwise(output_tensor, device_config): + """Match a Tensor Expression corresponding to an NPU Unary Elementwise. + + If the Tensor Expression matches, an EthosuPart will be created that models the + matched Tensor Expression. Otherwise, None will be returned. + + Parameters + ---------- + output_tensor : tvm.te.Tensor + The tensor to attempt to match with. + device_config : EthosuDeviceConfig + Target device configuration + + Returns + ------- + Union[None, EthosuPart] + The created EthosuPart if there was a match, otherwise None. + + """ + write = output_tensor + if write.op.name != "ethosu_write": + return None + convert_to_nhcwb16 = write.op.input_tensors[0] + if convert_to_nhcwb16.op.name != "ethosu_convert_to_nhcwb16": + return None + unary_elementwise = convert_to_nhcwb16.op.input_tensors[0] + if unary_elementwise.op.name != "ethosu_unary_elementwise": + return None + pad = unary_elementwise.op.input_tensors[0] + if pad.op.name != "ethosu_pad": + return None + convert_to_nhwc = pad.op.input_tensors[0] + if convert_to_nhwc.op.name != "ethosu_convert_to_nhwc": + return None + read = convert_to_nhwc.op.input_tensors[0] + if read.op.name != "ethosu_read": + return None + + input_tensors = [ + read.op.input_tensors[0], + ] + subgraph = TESubgraph(input_tensors, output_tensor) + propagators = [ + write.op.attrs["ifm_propagator"], + ] + ifm_dtype = input_tensors[0].dtype + ofm_dtype = output_tensor.dtype + + output_layout = convert_to_nhcwb16.op.attrs["layout"] + input_layout = convert_to_nhwc.op.attrs["layout"] + output_quantum = device_config.get_output_quantum(output_layout) + + block_config = device_config.get_elementwise_block_config( + propagators[0], + None, + unary_elementwise.op.attrs, + output_tensor.shape, + output_layout, + input_layout, + None, + ifm_dtype, + ofm_dtype, + ) + + return EthosuPart( + subgraph, + propagators, + output_quantum, + 1, + block_config, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py index c5bfa5cf92ef7..bcd785ddbbd8d 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py @@ -123,12 +123,11 @@ def __init__(self): def visit_constant(self, const): if isinstance(const.checked_type, relay.ty.TensorType): - if const.checked_type.concrete_shape != (): - self.constants.append(const.data.asnumpy()) - name = "p" + str(len(self.constants)) - var = relay.var(type_annotation=const.checked_type, name_hint=name) - self.const_vars.append(var) - return var + self.constants.append(const.data.asnumpy()) + name = "p" + str(len(self.constants)) + var = relay.var(type_annotation=const.checked_type, name_hint=name) + self.const_vars.append(var) + return var return const diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py b/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py index 5720574526026..20a8ff85ee2f7 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py @@ -136,7 +136,10 @@ def _visit(tensor, reader, lut): if tensor not in planned: planned.add(tensor) if isinstance(tensor.op, tvm.te.PlaceholderOp) and tensor != lut: - index = list(cached_func.inputs).index(tensor) + # Find index of input using 'same_as' check to prevent equality + # ambiguity when encountering a scalar. + is_same = [var.same_as(tensor) for var in cached_func.inputs] + index = is_same.index(True) if index in const_dict: sch.cache_read(tensor, "global", [reader]) diff --git a/python/tvm/relay/frontend/caffe.py b/python/tvm/relay/frontend/caffe.py index 68bf767557d5d..36caf32340280 100644 --- a/python/tvm/relay/frontend/caffe.py +++ b/python/tvm/relay/frontend/caffe.py @@ -83,14 +83,13 @@ def convert_flatten(self, op): def convert_eltwise(self, op): """Convert Eltwise layer""" inputs = op.bottom - assert len(inputs) == 2, "input tensors length should be 2" + assert len(inputs) >= 2, "input tensors length should be larger than 2" + # gethering initial 2 input expressions lhs_expr = self.exp_tab.get_expr(inputs[0]) rhs_expr = self.exp_tab.get_expr(inputs[1]) - lhs_shape = _infer_shape(lhs_expr) rhs_shape = _infer_shape(rhs_expr) - assert lhs_shape == rhs_shape, "input tensors shape should be equal" eltwise_params = op.eltwise_param @@ -100,6 +99,11 @@ def convert_eltwise(self, op): if eltwise_type_dict[eltwise_type] == "PROD": out = _op.multiply(lhs_expr, rhs_expr) + # for rest inputs + for i in range(len(inputs) - 2): + extra_expr = self.exp_tab.get_expr(inputs[i + 2]) + assert _infer_shape(out) == _infer_shape(extra_expr) + out = _op.multiply(out, extra_expr) elif eltwise_type_dict[eltwise_type] == "SUM": if coeff: left_coeff_expr = self.exp_tab.new_const(np.asarray(coeff[0], np.float32)) @@ -109,8 +113,23 @@ def convert_eltwise(self, op): out = _op.add(lhs_expr_scale, rhs_expr_scale) else: out = _op.add(lhs_expr, rhs_expr) + # for rest inputs + for i in range(len(inputs) - 2): + extra_expr = self.exp_tab.get_expr(inputs[i + 2]) + assert _infer_shape(out) == _infer_shape(extra_expr) + if coeff: + coeff_expr = self.exp_tab.new_const(np.asarray(coeff[i + 2], np.float32)) + extra_expr_scale = _op.multiply(extra_expr, coeff_expr) + out = _op.add(out, extra_expr_scale) + else: + out = _op.add(out, extra_expr) elif eltwise_type_dict[eltwise_type] == "MAX": out = _op.maximum(lhs_expr, rhs_expr) + # for rest inputs + for i in range(len(inputs) - 2): + extra_expr = self.exp_tab.get_expr(inputs[i + 2]) + assert _infer_shape(out) == _infer_shape(extra_expr) + out = _op.maximum(out, extra_expr) else: raise tvm.error.OpNotImplemented( "eltwise_type {} is not supported for frontend Caffe.".format(eltwise_type) @@ -515,21 +534,76 @@ def convert_deconv(self, op): if weight: kh, kw = params["kernel_size"] weight_shape = [-1, conv_params.num_output, kh, kw] - weight_value = np.asarray(weight.data, np.float32) + if not weight.data: + if conv_params.weight_filler: + _filler = conv_params.weight_filler.value + weight_value = np.full(weight.shape.dim, _filler, np.float32) + else: + raise tvm.error.OpAttributeInvalid("At least weight_filler must be given") + else: + weight_value = np.asarray(weight.data, np.float32) weight_value = np.reshape(weight_value, weight_shape) # weight shape is in relay's IOHW format rn, we need it to be OIHW weight_value = np.transpose(weight_value, [1, 0, 2, 3]) else: - raise Exception("No weight value of layer {} in caffemodel".format(op.name)) + raise tvm.error.OpAttributeRequired( + "No weight value of layer {} in caffemodel".format(op.name) + ) weight_expr = self.exp_tab.new_const(weight_value, dtype="float32") in_expr = self.exp_tab.get_expr(inputs[0]) - out = _op.nn.conv2d_transpose(data=in_expr, weight=weight_expr, **params) + + groups = params["groups"] + channels = params["channels"] + if bias: bias_value = np.asarray(bias.data, np.float32) bias_expr = self.exp_tab.new_const(bias_value, dtype="float32") - out = _op.nn.bias_add(out, bias_expr) + + if groups > channels: + raise tvm.error.OpAttributeInvalid( + "Groups cannot be larger than the number of input channels" + ) + + if groups == channels: + inputs_expr = _op.split(in_expr, groups, axis=1) + # changing split axis to 0, according to PR #9336 + weights_expr = _op.split(weight_expr, groups, axis=0) + # Preventing to create Concat layer with too many tensors(> 16) + q = groups >> 4 + r = groups % 16 + + params["groups"] = 1 + params["channels"] = 1 + out = [] + for lc in range(q): + _outputs = [] + _inputs = [inputs_expr[i] for i in range(lc << 4, (lc << 4) + 16)] + _weights = [weights_expr[i] for i in range(lc << 4, (lc << 4) + 16)] + for (i, w) in zip(_inputs, _weights): + _out = _op.nn.conv2d_transpose(data=i, weight=w, **params) + if bias: + _out = _op.nn.bias_add(_out, bias_expr) + _outputs.append(_out) + out.append(_op.concatenate(_outputs, axis=1)) + if r != 0: + _outputs = [] + _inputs = [inputs_expr[i] for i in range(groups - r, groups)] + _weights = [weights_expr[i] for i in range(groups - r, groups)] + for (i, w) in zip(_inputs, _weights): + _out = _op.nn.conv2d_transpose(data=i, weight=w, **params) + if bias: + _out = _op.nn.bias_add(_out, bias_expr) + _outputs.append(_out) + out.append(_op.concatenate(_outputs, axis=1)) + out = _op.concatenate(out, axis=1) + elif groups == 1: + out = _op.nn.conv2d_transpose(data=in_expr, weight=weight_expr, **params) + if bias: + out = _op.nn.bias_add(out, bias_expr) + else: + raise tvm.error.OpAttributeInvalid("Unable to handle.") return out def convert_slice(self, op): diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 1d4c8ad757624..f8c12ff334dba 100755 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -716,11 +716,11 @@ def gru_cell( b_inp, b_hid : relay.Expr bias matrices. The same order of internal parts as for weights. shape = (3 * hidden_size) r_act : relay.op - activation funtion for reset gate. it is sigmoid by default + activation function for reset gate. it is sigmoid by default z_act : relay.op - activation funtion for update gate. it is sigmoid by default + activation function for update gate. it is sigmoid by default n_act : relay.op - activation funtion for new gate. it is tanh by default + activation function for new gate. it is tanh by default backwards : bool Flag for reverse pass of GRU @@ -812,7 +812,7 @@ def lstm_cell( p_i, p_f, p_o : relay.Expr peephole LSTM matrices. shape = (batch, hidden_size) f_act, g_act, h_act : relay.op - activation funtions + activation functions backwards : bool Flag for reverse pass of LSTM diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 263eb851f867c..234beec244ba3 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -38,9 +38,9 @@ from .. import ty as _ty from .. import vision as _vision from .common import ( - autopad, AttrCvt, Renamer, + autopad, ensure_scalar_shape, fold_constant, get_name, @@ -238,18 +238,6 @@ def flatten_to_nd(x, x_shape, nd=3): out = _op.reshape(x, fold_constant(newshape)) return out - b_type = infer_type(inputs[1]) - # Convert to dense if the second matrix is 2d and non-dynamic - if b_rank == 2 and not _ty.is_dynamic(b_type.checked_type): - a = flatten_to_nd(inputs[0], a_shape, 2) - b = _op.transpose(inputs[1]) - output = _op.nn.dense(a, b, out_dtype=out_dtype) - else: - # Convert a and b into 3 dimensional tensors. - a = flatten_to_nd(inputs[0], a_shape, 3) - b = flatten_to_nd(inputs[1], b_shape, 3) - # Perform a NN batch matmul. - output = _op.nn.batch_matmul(a, b, out_dtype=out_dtype, transpose_b=False) # Determine the output batch dimension. if a_rank > b_rank: out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2]) @@ -268,6 +256,42 @@ def flatten_to_nd(x, x_shape, nd=3): ], 0, ) + + b_type = infer_type(inputs[1]) + # Convert to dense if the second matrix is 2d and non-dynamic + if b_rank == 2 and not _ty.is_dynamic(b_type.checked_type): + a = flatten_to_nd(inputs[0], a_shape, 2) + b = _op.transpose(inputs[1]) + output = _op.nn.dense(a, b, out_dtype=out_dtype) + else: + # broadcast a and b + a_broadcasted_shape = _op.concatenate( + [ + out_batch, + _op.strided_slice(a_shape, [a_rank - 2], [a_rank]), + ], + 0, + ) + b_broadcasted_shape = _op.concatenate( + [ + out_batch, + _op.strided_slice(b_shape, [b_rank - 2], [b_rank]), + ], + 0, + ) + a = _op.transform.broadcast_to(inputs[0], fold_constant(a_broadcasted_shape)) + b = _op.transform.broadcast_to(inputs[1], fold_constant(b_broadcasted_shape)) + # Convert a and b into 3 dimensional tensors. + a = flatten_to_nd(a, shape_of(a), 3) + b = flatten_to_nd(b, shape_of(b), 3) + if ONNX_DEFAULT_CONFIGS["use_nt_batch_matmul"]: + # Transpose matrix dimensions of b. + bt = _op.transpose(b, [0, 2, 1]) + # Perform a NT batch matmul. + output = _op.nn.batch_matmul(a, bt, out_dtype=out_dtype) + else: + # Perform a NN batch matmul. + output = _op.nn.batch_matmul(a, b, out_dtype=out_dtype, transpose_b=False) # Reshape output to original dimensions. final_shape = _op.concatenate( [ @@ -551,13 +575,13 @@ class ConvTranspose(OnnxOpConverter): def _impl_v1(cls, inputs, attr, params): # get number of channels out_type = infer_type(inputs[1]) - out_shapes = [get_const_tuple(out_type.checked_type.shape)] - channels = out_shapes[0][1] - attr["channels"] = channels + kernel_shape = [get_const_tuple(out_type.checked_type.shape)] + out_channels = kernel_shape[0][1] * attr.get("group", 1) + attr["channels"] = out_channels groups = attr.get("group", 1) if "kernel_shape" not in attr: - attr["kernel_shape"] = out_shapes[0][2:] + attr["kernel_shape"] = kernel_shape[0][2:] attr["groups"] = groups # infer pads for auto_pad @@ -606,13 +630,13 @@ def _impl_v1(cls, inputs, attr, params): def _impl_v11(cls, inputs, attr, params): # get number of channels out_type = infer_type(inputs[1]) - out_shapes = [get_const_tuple(out_type.checked_type.shape)] - channels = out_shapes[0][1] - attr["channels"] = channels + kernel_shape = [get_const_tuple(out_type.checked_type.shape)] + out_channels = kernel_shape[0][1] * attr.get("group", 1) + attr["channels"] = out_channels groups = attr.get("group", 1) if "kernel_shape" not in attr: - attr["kernel_shape"] = out_shapes[0][2:] + attr["kernel_shape"] = kernel_shape[0][2:] attr["groups"] = groups # infer pads for auto_pad diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index b0b2bd3e75d9a..12296bd505422 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -1898,6 +1898,7 @@ def convert_fully_connected(self, op): fully_connected_options = FullyConnectedOptions() fully_connected_options.Init(op_options.Bytes, op_options.Pos) fused_activation_fn = fully_connected_options.FusedActivationFunction() + keep_num_dims = fully_connected_options.KeepNumDims() # weight tensor type should be INT8/UINT8 (quantization) or FLOAT32 weight_tensor_type = weight_tensor.tensor.Type() @@ -1975,6 +1976,13 @@ def convert_fully_connected(self, op): else: out = self.convert_fused_activation_function(out, fused_activation_fn) + # Change the output shape calculation based on keep_dim option + if keep_num_dims: + input_shape = _infer_shape(self.get_tensor_expr(input_tensor)) + output_shape = input_shape + output_shape[-1] = weight_tensor_shape[0] + out = _op.reshape(out, output_shape) + return out def convert_squeeze(self, op): diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index 3793f947c5cc1..a3e5f110d365e 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -52,7 +52,6 @@ reshape_like, strided_slice, take, - tile, transpose, where, repeat, @@ -399,15 +398,14 @@ def conv2d_grad(orig, grad): data_shape = get_const_tuple(data.checked_type.shape) weight_shape = get_const_tuple(weight.checked_type.shape) _, _, grad_h, grad_w = get_const_tuple(orig.checked_type.shape) - batch, in_channel, in_h, in_w = data_shape - out_channel, _, filter_h, filter_w = weight_shape + _, _, in_h, in_w = data_shape + _, _, filter_h, filter_w = weight_shape # infer output_padding fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple( get_const_tuple(attrs.padding), (filter_h, filter_w) ) stride_h, stride_w = get_const_tuple(attrs.strides) - dilation_h, dilation_w = get_const_tuple(attrs.dilation) out_h = (grad_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h out_w = (grad_w - 1) * stride_w - fpad_left - fpad_right + filter_w output_padding = (in_h - out_h, in_w - out_w) @@ -425,46 +423,21 @@ def conv2d_grad(orig, grad): groups=attrs.groups, output_padding=output_padding, ) - grad = tile(grad, [1, in_channel // attrs.groups, 1, 1]) - grad = reshape(grad, [-1, 1, 0, 0]) # batch * oc * ic // groups, 1, oh, ow - data = reshape(data, [1, -1, 0, 0]) # 1, batch * ic, ih, iw - backward_weight = _nn.conv2d( - data, + backward_weight = _nn.conv2d_backward_weight( grad, - strides=attrs.dilation, + data, + strides=attrs.strides, padding=attrs.padding, - dilation=attrs.strides, - groups=in_channel * batch, - ) - # infer shape of backward_weight - padded_weight_grad_h = ( - in_h - (grad_h - 1) * stride_h - 1 + fpad_top + fpad_bottom - ) // dilation_h + 1 - padded_weight_grad_w = ( - in_w - (grad_w - 1) * stride_w - 1 + fpad_left + fpad_right - ) // dilation_w + 1 - backward_weight = reshape( - backward_weight, - [ - batch, - in_channel // attrs.groups, - out_channel, - padded_weight_grad_h, - padded_weight_grad_w, - ], + dilation=attrs.dilation, + groups=attrs.groups, + channels=attrs.channels, + kernel_size=(filter_h, filter_w), + grad_layout=attrs.out_layout if attrs.out_layout else attrs.data_layout, + data_layout=attrs.data_layout, + kernel_layout=attrs.kernel_layout, + out_dtype=attrs.out_dtype, ) - backward_weight = _sum(backward_weight, axis=0) - backward_weight = transpose(backward_weight, [1, 0, 2, 3]) - - assert padded_weight_grad_h >= filter_h - assert padded_weight_grad_w >= filter_w - if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w: - backward_weight = strided_slice( - backward_weight, - begin=[0, 0, 0, 0], - end=[out_channel, in_channel // attrs.groups, filter_h, filter_w], - ) return [backward_data, backward_weight] diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 9aa883d9b750b..2a941cc8c28ac 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -23,6 +23,7 @@ from tvm.runtime import convert from tvm.te.hybrid import script from tvm.topi.utils import get_const_tuple +from tvm.topi.nn.utils import get_pad_tuple from ....ir import container from ....tir import expr @@ -1061,6 +1062,83 @@ def compute_space_to_depth(attrs, inputs, out_dtype): reg.register_injective_schedule("nn.batch_to_space_nd") +@reg.register_legalize("nn.conv2d_backward_weight") +def legalize_conv2d_backward_weight(attrs, inputs, types): + """Legalize conv2d_backward_weight op. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current op + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + grad, data = inputs + data_shape = get_const_tuple(data.checked_type.shape) + weight_shape = get_const_tuple(types[2].shape) + _, out_channel, grad_h, grad_w = get_const_tuple(grad.checked_type.shape) + batch, in_channel, in_h, in_w = data_shape + _, _, filter_h, filter_w = weight_shape + fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple( + get_const_tuple(attrs.padding), (filter_h, filter_w) + ) + stride_h, stride_w = get_const_tuple(attrs.strides) + dilation_h, dilation_w = get_const_tuple(attrs.dilation) + + grad = relay.tile(grad, [1, in_channel // attrs.groups, 1, 1]) + grad = relay.reshape(grad, [-1, 1, 0, 0]) # batch * oc * ic // groups, 1, oh, ow + data = relay.reshape(data, [1, -1, 0, 0]) # 1, batch * ic, ih, iw + + backward_weight = relay.nn.conv2d( + data, + grad, + strides=attrs.dilation, + padding=attrs.padding, + dilation=attrs.strides, + groups=in_channel * batch, + ) + + # infer shape of backward_weight + padded_weight_grad_h = ( + in_h - (grad_h - 1) * stride_h - 1 + fpad_top + fpad_bottom + ) // dilation_h + 1 + padded_weight_grad_w = ( + in_w - (grad_w - 1) * stride_w - 1 + fpad_left + fpad_right + ) // dilation_w + 1 + + backward_weight = relay.reshape( + backward_weight, + [ + batch, + in_channel // attrs.groups, + out_channel, + padded_weight_grad_h, + padded_weight_grad_w, + ], + ) + backward_weight = relay.sum(backward_weight, axis=0) + backward_weight = relay.transpose(backward_weight, [1, 0, 2, 3]) + + assert padded_weight_grad_h >= filter_h + assert padded_weight_grad_w >= filter_w + + if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w: + backward_weight = relay.strided_slice( + backward_weight, + begin=[0, 0, 0, 0], + end=[out_channel, in_channel // attrs.groups, filter_h, filter_w], + ) + + return backward_weight + + ##################### # Shape functions # ##################### diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index c7b376ec3d647..857e4c3eb9bae 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -3770,3 +3770,54 @@ def batch_to_space_nd(data, block_shape, crops): """ return _make.batch_to_space_nd(data, block_shape, crops) + + +def conv2d_backward_weight( + grad, + data, + strides=(1, 1), + padding=(0, 0), + dilation=(1, 1), + groups=1, + channels=None, + kernel_size=None, + grad_layout="NCHW", + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="", +): + r"""The gradient of conv2d with respect to weight. + + This operator takes the output gradient `grad` and convolves it with `data` as + the convolution kernel, to produce the gradient with respect to weight. + + Note that the parameter `kernel_size` is the spatial size of the corresponding + forward convolution kernel, not that of `data`. `grad_layout` and + `kernel_layout` are the layouts of `grad` and the weight gradient respectively. + + Other parameters are the same as the conv2d op. See its documentation for more + details. + + """ + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + if isinstance(strides, int): + strides = (strides, strides) + if isinstance(dilation, int): + dilation = (dilation, dilation) + padding = get_pad_tuple2d(padding) + + return _make.conv2d_backward_weight( + grad, + data, + strides, + padding, + dilation, + groups, + channels, + kernel_size, + grad_layout, + data_layout, + kernel_layout, + out_dtype, + ) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 542980561e780..f21e3eaf2c3cd 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -73,7 +73,7 @@ def get_tensor_array_shape(expr, dtype, prelude): return None -def _get_name_static(canonical, dtype, shape, batch_dim=None): +def _get_name_static(canonical, dtype, shape, batch_dim=None, extra_shapes=None): """Get name for static shape tensor array op By design, static ADT tensor in TVM has type name in the format @@ -100,14 +100,12 @@ def _get_name_static(canonical, dtype, shape, batch_dim=None): name : String The tensor array op name """ - dim_names = [] - for dim in shape: - if isinstance(dim, Any): - dim_names.append("any") - else: - dim_names.append(str(dim)) + shape_str = _to_str(shape) - shape_str = "_".join(dim_names) + if extra_shapes is not None: + for n, s in extra_shapes.items(): + extra_shape_str = "_{}_{}".format(n, _to_str(s)) + shape_str += extra_shape_str if len(shape_str) == 0: shape_str = "scalar" @@ -120,6 +118,16 @@ def _get_name_static(canonical, dtype, shape, batch_dim=None): return "{}_{}_batch{}_{}".format(canonical, dtype, str(batch_dim), shape_str) +def _to_str(shape): + dim_names = [] + for dim in shape: + if isinstance(dim, Any): + dim_names.append("any") + else: + dim_names.append(str(dim)) + return "_".join(dim_names) + + class StaticTensorArrayOps(object): """Contains tensor array related ops for fixed rank tensor array""" @@ -131,9 +139,9 @@ def __init__(self, prelude, dtype, shape, batch_dim=None): self.batch_dim = batch_dim self.list, self.cons, self.nil = self.prelude.mod.get_type("List") - def get_name(self, canonical): + def get_name(self, canonical, extra_shapes=None): """Get name corresponding to the canonical name""" - return _get_name_static(canonical, self.dtype, self.shape, self.batch_dim) + return _get_name_static(canonical, self.dtype, self.shape, self.batch_dim, extra_shapes) def get_global_var(self, canonical): """Get global corresponding to the canonical name""" @@ -408,11 +416,16 @@ def define_tensor_array_scatter(self, indices_shape=None, force_update=False): # When this operator has already been registered, only update # when force_update is set. This should be used only when we need to # redefine this op for static indices shape. - tensor_array_scatter_name = self.get_name("tensor_array_scatter") + + extra_shapes = {"indices": indices_shape} if indices_shape is not None else None + tensor_array_scatter_name = self.get_name("tensor_array_scatter", extra_shapes) if hasattr(self.prelude, tensor_array_scatter_name) and not force_update: return - tensor_array_scatter_helper_name = self.get_name("tensor_array_scatter_helper") + tensor_array_scatter_helper_name = self.get_name( + "tensor_array_scatter_helper", extra_shapes + ) + tensor_array_scatter_helper_var = self._create_global_var(tensor_array_scatter_helper_name) ta = Var("ta", self.list(self.tensor_type_var())) current = Var("current", scalar_type("int32")) diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index 4749b63b858f3..b69afd69b8c54 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -105,14 +105,18 @@ def quantize(data, output_scale, output_zero_point, axis=-1, out_dtype="int8"): ---------- data : tvm.relay.Expr The input tensor to be quantized. Can be of type float32. - output_zero_point : tvm.relay.Expr - The output zero_point. + output_scale : tvm.relay.Expr The output scale. + + output_zero_point : tvm.relay.Expr + The output zero_point. + axis : int The channel axis for quantization. Default value is -1 which corresponds to the last axis. out_dtype : str, optional The data type of the input tensor. Can be [int8, uint8, int32] + Returns ------- result : tvm.relay.Expr @@ -132,14 +136,18 @@ def simulated_quantize(data, output_scale, output_zero_point, axis=-1, out_dtype ---------- data : tvm.relay.Expr The input tensor to be quantized. Can be of type float32. - output_zero_point : tvm.relay.Expr - The output zero_point. + + out_dtype : string or tvm.relay.Expr + A string or tensor indicating which datatype to quantize to. + output_scale : tvm.relay.Expr The output scale. + + output_zero_point : tvm.relay.Expr + The output zero_point. + axis : int The channel axis for quantization. Default value is -1 which corresponds to the last axis. - out_dtype : string or tvm.relay.Expr - A string or tensor indicating which datatype to quantize to. Returns ------- @@ -166,12 +174,16 @@ def dequantize(data, input_scale, input_zero_point, axis=-1): ---------- data : tvm.relay.Expr The input tensor to be dequantized. Can be of type [int8, uint8, int32]. - input_zero_point : tvm.relay.Expr - The input zero_point. + input_scale : tvm.relay.Expr The input scale. + + input_zero_point : tvm.relay.Expr + The input zero_point. + axis : int The channel axis for quantization. Default value is -1 which corresponds to the last axis. + Returns ------- result : tvm.relay.Expr @@ -191,14 +203,18 @@ def simulated_dequantize(data, input_scale, input_zero_point, axis=-1, in_dtype= ---------- data : tvm.relay.Expr The input tensor to be dequantized. - input_zero_point : tvm.relay.Expr - The input zero_point. + + in_dtype : string or tvm.relay.Expr + A string or tensor indicating which datatype to dequantize from. + input_scale : tvm.relay.Expr The input scale. + + input_zero_point : tvm.relay.Expr + The input zero_point. + axis : int The channel axis for quantization. Default value is -1 which corresponds to the last axis. - in_dtype : string or tvm.relay.Expr - A string or tensor indicating which datatype to dequantize from. Returns ------- diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index 9fc75199bdf50..909712511061f 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -127,6 +127,7 @@ def check_grad( fwd_func = run_infer_type(func) bwd_func = run_infer_type(gradient(fwd_func, mode=mode)) + bwd_func = run_opt_pass(bwd_func, relay.transform.Legalize()) if scale is None: scale = 10 * eps diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index 25a57bbb1c368..37bab4a9b714a 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -185,6 +185,11 @@ def type_key(self): """Get type key of the module.""" return _ffi_api.ModuleGetTypeKey(self) + @property + def format(self): + """Get the format of the module.""" + return _ffi_api.ModuleGetFormat(self) + def get_source(self, fmt=""): """Get source code from module, if available. @@ -402,7 +407,12 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No for index, module in enumerate(modules): if fcompile is not None and hasattr(fcompile, "object_format"): if module.type_key == "c": - object_format = "c" + assert module.format in [ + "c", + "cc", + "cpp", + ], "The module.format needs to be either c, cc or cpp" + object_format = module.format has_c_module = True else: object_format = fcompile.object_format @@ -411,7 +421,15 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No object_format = "o" else: assert module.type_key == "c" - object_format = "c" + if len(module.format) > 0: + assert module.format in [ + "c", + "cc", + "cpp", + ], "The module.format needs to be either c, cc or cpp" + object_format = module.format + else: + object_format = "c" if "cc" in kwargs: if kwargs["cc"] == "nvcc": object_format = "cu" diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index 2b9f7f9446baf..8400a5998e398 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -127,7 +127,7 @@ def __setitem__(self, in_slice, value): raise TypeError("type %s not supported" % str(type(value))) def copyfrom(self, source_array): - """Peform an synchronize copy from the array. + """Perform an synchronize copy from the array. Parameters ---------- diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 56072d7b00c2f..dbaba46fdc9ca 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -433,9 +433,7 @@ def _get_targets(target_str=None): DEFAULT_TEST_TARGETS = [ "llvm", - "llvm -device=arm_cpu", "cuda", - "cuda -model=unknown -libs=cudnn", "nvptx", "vulkan -from_device=0", "opencl", diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index b261fd0a75189..7d352f156a313 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -369,6 +369,32 @@ def sample_perfect_tile( ) ) + @type_checked + def sample_compute_location( + self, + block: BlockRV, + decision: Optional[int] = None, + ) -> LoopRV: + """Sample a compute-at location of the given block + + Parameters + ---------- + block : BlockRV + The block whose compute-at location is to be sampled + decision : Optional[int] + The sampling decision + + Returns + ------- + result : LoopRV + The sampled loop where the input block is to be computed at + """ + return _ffi_api.ScheduleSampleComputeLocation( # type: ignore # pylint: disable=no-member + self, + block, + decision, + ) + ########## Schedule: Get blocks & loops ########## @type_checked def get_block( diff --git a/python/tvm/topi/cuda/batch_matmul_tensorcore.py b/python/tvm/topi/cuda/batch_matmul_tensorcore.py index d05984b913938..ac16dd7b65b4f 100644 --- a/python/tvm/topi/cuda/batch_matmul_tensorcore.py +++ b/python/tvm/topi/cuda/batch_matmul_tensorcore.py @@ -203,7 +203,7 @@ def _schedule(cfg, s, C): s[BF].reorder(bs, o, i, o_ii, i_ii) # Schedule for A's(B's) shared memory load - def shared_shedule(stage, strides): + def shared_schedule(stage, strides): s[stage].compute_at(s[CF], ko) bs, xo, yo = stage.op.axis s[stage].storage_align(xo, strides - 1, strides) @@ -217,8 +217,8 @@ def shared_shedule(stage, strides): s[stage].bind(tx, thread_x) s[stage].vectorize(vi) - shared_shedule(AS, AS_align) - shared_shedule(BS, BS_align) + shared_schedule(AS, AS_align) + shared_schedule(BS, BS_align) shape = (wmma_m, wmma_n, wmma_k) AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=data_dtype) diff --git a/python/tvm/topi/cuda/conv2d_nhwc_winograd.py b/python/tvm/topi/cuda/conv2d_nhwc_winograd.py index 1e368f585354e..698beeac6dc46 100644 --- a/python/tvm/topi/cuda/conv2d_nhwc_winograd.py +++ b/python/tvm/topi/cuda/conv2d_nhwc_winograd.py @@ -165,7 +165,7 @@ def schedule_bgemm_tensorcore(cfg, s, bgemm, data_pack, kernel_pack): s[BF].reorder(i, o, i_ii, o_ii) # Schedule for A's(B's) shared memory load - def shared_shedule(stage, strides): + def shared_schedule(stage, strides): s[stage].compute_at(s[CF], ko) _, _, xo, yo = stage.op.axis s[stage].storage_align(xo, strides - 1, strides) @@ -179,8 +179,8 @@ def shared_shedule(stage, strides): s[stage].bind(tx, thread_x) s[stage].vectorize(vi) - shared_shedule(AS, AS_align) - shared_shedule(BS, BS_align) + shared_schedule(AS, AS_align) + shared_schedule(BS, BS_align) shape = (wmma_m, wmma_n, wmma_k) in_dtype = "float16" diff --git a/python/tvm/topi/cuda/dense.py b/python/tvm/topi/cuda/dense.py index ea37364becde9..c3c5b6e7cf850 100644 --- a/python/tvm/topi/cuda/dense.py +++ b/python/tvm/topi/cuda/dense.py @@ -42,10 +42,11 @@ def _matmul_cublas_common( assert len(bias.shape) == 1 if out_dtype is None: out_dtype = tensor_a.dtype - assert out_dtype == tensor_a.dtype, "Mixed precision not supported." + if out_dtype not in [tensor_a.dtype, "int32"]: + assert out_dtype == tensor_a.dtype, "Mixed precision other than int8 + int32 not supported." batch, in_dim = get_const_tuple(tensor_a.shape) out_dim, _ = get_const_tuple(tensor_b.shape) - matmul = cublas.matmul(tensor_a, tensor_b, transpose_a, transpose_b) + matmul = cublas.matmul(tensor_a, tensor_b, transpose_a, transpose_b, dtype=out_dtype) if all(isinstance(d, int) for d in [batch, in_dim, out_dim]): cfg.add_flop(batch * in_dim * out_dim * 2) if bias is not None: diff --git a/python/tvm/topi/cuda/dense_tensorcore.py b/python/tvm/topi/cuda/dense_tensorcore.py index 20ff1aaccc5fc..7acc1307f84c9 100644 --- a/python/tvm/topi/cuda/dense_tensorcore.py +++ b/python/tvm/topi/cuda/dense_tensorcore.py @@ -238,7 +238,7 @@ def _schedule_dense_tensorcore(cfg, s, C): s[BF].reorder(o, i, o_ii, i_ii) # Schedule for A's(B's) shared memory load - def shared_shedule(stage, strides): + def shared_schedule(stage, strides): s[stage].compute_at(s[CF], ko) xo, yo = stage.op.axis s[stage].storage_align(xo, strides - 1, strides) @@ -252,8 +252,8 @@ def shared_shedule(stage, strides): s[stage].bind(tx, thread_x) s[stage].vectorize(vi) - shared_shedule(AS, AS_align) - shared_shedule(BS, BS_align) + shared_schedule(AS, AS_align) + shared_schedule(BS, BS_align) shape = (wmma_m, wmma_n, wmma_k) AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=data_dtype) diff --git a/python/tvm/topi/generic/conv2d.py b/python/tvm/topi/generic/conv2d.py index 640c13f4372f2..dc70e0ed89f90 100644 --- a/python/tvm/topi/generic/conv2d.py +++ b/python/tvm/topi/generic/conv2d.py @@ -228,8 +228,8 @@ def schedule_conv_NCHWc_cpu_common_int8( batch, oc_chunk, oh, ow, oc_block = s[O].op.axis ow_chunk, ow_block = s[O].split(ow, factor=reg_n) s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) + s[C].compute_at(s[O], ow_block) parallel_axis = s[O].fuse(batch, oc_chunk, oh) - s[C].compute_at(s[O], parallel_axis) s[O].vectorize(oc_block) s[O].parallel(parallel_axis) elif out_ndim == 4: @@ -237,8 +237,8 @@ def schedule_conv_NCHWc_cpu_common_int8( ow_chunk, ow_block = s[O].split(ow, factor=reg_n) oc_chunk, oc_block = s[O].split(oc, factor=oc_bn) s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) + s[C].compute_at(s[O], ow_block) parallel_axis = s[O].fuse(batch, oc_chunk, oh) - s[C].compute_at(s[O], parallel_axis) s[O].vectorize(oc_block) s[O].parallel(parallel_axis) else: @@ -301,8 +301,8 @@ def schedule_conv_NCHWc_cpu_1x1_int8( s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) s[C].vectorize(oc_block) + s[CC].compute_at(s[C], ow_inner) parallel_axis = s[C].fuse(batch, oc_chunk, oh_outer) - s[CC].compute_at(s[C], parallel_axis) if C == O: s[C].parallel(parallel_axis) @@ -346,8 +346,8 @@ def schedule_conv_NCHWc_cpu_1x1_int8( ow_outer, ow_inner = s[O].split(ow, factor=ow_factor) s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) + s[C].compute_at(s[O], ow_inner) parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer) - s[C].compute_at(s[O], parallel_axis) s[O].vectorize(oc_block) s[O].parallel(parallel_axis) elif out_ndim == 4: @@ -357,8 +357,8 @@ def schedule_conv_NCHWc_cpu_1x1_int8( ow_outer, ow_inner = s[O].split(ow, factor=ow_factor) s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) + s[C].compute_at(s[O], ow_inner) parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer) - s[C].compute_at(s[O], parallel_axis) s[O].vectorize(oc_block) s[O].parallel(parallel_axis) else: diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py index 345886c2be91b..75eabffc957a5 100644 --- a/python/tvm/topi/testing/__init__.py +++ b/python/tvm/topi/testing/__init__.py @@ -75,3 +75,4 @@ from .nll_loss import nll_loss from .dense import dense from .searchsorted import searchsorted_ref +from .conv2d_backcward_weight_python import conv2d_backward_weight_nchw_python diff --git a/python/tvm/topi/testing/conv2d_backcward_weight_python.py b/python/tvm/topi/testing/conv2d_backcward_weight_python.py new file mode 100644 index 0000000000000..587cd45b49c11 --- /dev/null +++ b/python/tvm/topi/testing/conv2d_backcward_weight_python.py @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, too-many-nested-blocks +"""Gradient of conv2d with respect to weight in python""" +import numpy as np + + +# Reference: cutlass/tools/util/include/cutlass/util/reference/host/convolution.h +def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding): + """Gradient of the conv2d op with respect to weight, in NCHW layout. + + Parameters + ---------- + dy_np : numpy.ndarray + 4-D with shape [batch, in_channel, out_height, out_width] + + x_np : numpy.ndarray + 4-D with shape [batch, in_channel, in_height, in_width] + + kernel_size : tuple of two ints + Height and width of the weight + + stride : tuple of two ints + Stride size, or [stride_height, stride_width] + + padding : tuple of two ints + Spatial padding, or [pad_h, pad_w] + + Returns + ------- + b_np : np.ndarray + 4-D with shape [num_filter, in_channel, filter_height, filter_width] + + """ + N, C, H, W = x_np.shape + _, K, P, Q = dy_np.shape + R, S = kernel_size + pad_h, pad_w = padding + stride_h, stride_w = stride + dw = np.zeros((K, C, R, S)).astype(dy_np.dtype) + + for k in range(K): + for r in range(R): + for s in range(S): + for c in range(C): + acc = 0 + for n in range(N): + for p in range(P): + for q in range(Q): + coord = (n, c, p * stride_h - pad_h + r, q * stride_w - pad_w + s) + + if ( + coord[2] < H + and coord[2] >= 0 + and coord[3] < W + and coord[3] >= 0 + ): + acc += dy_np[n, k, p, q] * x_np[coord] + + dw[k, c, r, s] = acc + + return dw diff --git a/python/tvm/topi/x86/utils.py b/python/tvm/topi/x86/utils.py index 658a92966257b..50c5c848ee0aa 100644 --- a/python/tvm/topi/x86/utils.py +++ b/python/tvm/topi/x86/utils.py @@ -81,8 +81,8 @@ def target_has_avx512(target): # explicit enumeration of VNNI capable due to collision with alderlake "cascadelake", "icelake-client", + "icelake-server", "rocketlake", - "icelake", "tigerlake", "cooperlake", "sapphirerapids", @@ -93,8 +93,8 @@ def target_has_vnni(target): return target in { "cascadelake", "icelake-client", + "icelake-server", "rocketlake", - "icelake", "tigerlake", "cooperlake", "sapphirerapids", diff --git a/src/arith/int_operator.h b/src/arith/int_operator.h index eff52308f389e..6dec9a5502e1a 100644 --- a/src/arith/int_operator.h +++ b/src/arith/int_operator.h @@ -78,7 +78,7 @@ inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, } /*! - * \brief Peform trunc division of two integers. + * \brief Perform trunc division of two integers. * \param x The left operand. * \param y The right operand. * \return the result. @@ -94,7 +94,7 @@ inline int64_t truncdiv(int64_t x, int64_t y) { return x / y; } inline int64_t truncmod(int64_t x, int64_t y) { return x % y; } /*! - * \brief Peform floor division of two integers. + * \brief Perform floor division of two integers. * \param x The left operand. * \param y The right operand. * \return the result. diff --git a/src/contrib/ethosu/cascader/block_config.cc b/src/contrib/ethosu/cascader/block_config.cc new file mode 100644 index 0000000000000..fe698aa17aaca --- /dev/null +++ b/src/contrib/ethosu/cascader/block_config.cc @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "block_config.h" + +#include +#include +#include + +#include +#include + +#include "common.h" + +namespace tvm { +namespace contrib { +namespace ethosu { +namespace cascader { + +void BlockConfigNode::VisitAttrs(AttrVisitor* v) { + Array tmp_arr = make_array(output_shape_); + v->Visit("_output_shape", &tmp_arr); +} + +BlockConfig::BlockConfig(const std::vector& output_shape, int compute_cycles, + int output_cycles) { + auto n = make_object(); + n->output_shape_ = std::move(output_shape); + n->compute_cycles_ = compute_cycles; + n->output_cycles_ = output_cycles; + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.BlockConfig") + .set_body_typed([](Array output_shape, int compute_cycles, int output_cycles) { + std::vector voutput_shape = make_vector(output_shape); + return BlockConfig(voutput_shape, compute_cycles, output_cycles); + }); + +TVM_REGISTER_NODE_TYPE(BlockConfigNode); + +} // namespace cascader +} // namespace ethosu +} // namespace contrib +} // namespace tvm diff --git a/src/contrib/ethosu/cascader/block_config.h b/src/contrib/ethosu/cascader/block_config.h new file mode 100644 index 0000000000000..d7da1d90e82ee --- /dev/null +++ b/src/contrib/ethosu/cascader/block_config.h @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/contrib/ethosu/cascader/block_config.h + * \brief BlockConfig object for the NPU cascader + */ +#ifndef TVM_CONTRIB_ETHOSU_CASCADER_BLOCK_CONFIG_H_ +#define TVM_CONTRIB_ETHOSU_CASCADER_BLOCK_CONFIG_H_ + +#include +#include + +#include +#include + +namespace tvm { +namespace contrib { +namespace ethosu { +namespace cascader { + +class BlockConfig; + +/*! \brief Node to represent a BlockConfig */ +class BlockConfigNode : public Object { + public: + void VisitAttrs(AttrVisitor* v); + + /*! + * \brief Get the shape of output block. + * \return The output shape of the block config. + */ + inline std::vector GetOutputBlockShape() const { return output_shape_; } + + /*! + * \brief Get the number of cycles required to output this block + * \return The output cycles + */ + inline int GetOutputCycles() const { return output_cycles_; } + + /*! + * \brief Get the number of cycles required to compute this block + * \return The compute cycles + */ + inline int GetComputeCycles() const { return compute_cycles_; } + + static constexpr const char* _type_key = "contrib.ethosu.cascader.BlockConfig"; + TVM_DECLARE_FINAL_OBJECT_INFO(BlockConfigNode, Object); + + protected: + friend class BlockConfig; + + /*! \brief The shape of the output block */ + std::vector output_shape_; + /*! \brief Cycles required to compute this block */ + int compute_cycles_; + /*! \brief Cycles required to output this block */ + int output_cycles_; +}; + +/*! + * \brief An object that contains a an output block shape as well as the output and compute cycles + * required to compute this block + */ +class BlockConfig : public ObjectRef { + public: + BlockConfig(const std::vector& output_shape, int compute_cycles, int output_cycles); + + TVM_DEFINE_OBJECT_REF_METHODS(BlockConfig, ObjectRef, BlockConfigNode); +}; + +} // namespace cascader +} // namespace ethosu +} // namespace contrib +} // namespace tvm + +#endif // TVM_CONTRIB_ETHOSU_CASCADER_BLOCK_CONFIG_H_ diff --git a/src/contrib/ethosu/cascader/common.h b/src/contrib/ethosu/cascader/common.h index ec62861049a31..b4b5664e04b98 100644 --- a/src/contrib/ethosu/cascader/common.h +++ b/src/contrib/ethosu/cascader/common.h @@ -68,6 +68,22 @@ inline Array make_array(const std::vector& vec) { return arr; } +/*! + * \brief Make a tvm::Array from an int64_t vector. + * \param vec The int64_t vector. + * \return The IntImm Array. + * \note Array(std::vector) doesn't work as this implicit + * type conversion fails. This is why this helper is required. + */ +inline Array make_array(const std::vector& vec) { + Array arr; + arr.resize(vec.size()); + for (unsigned int i = 0; i < vec.size(); ++i) { + arr.Set(i, IntImm(DataType::Int(64), vec[i])); + } + return arr; +} + /*! * \brief Make a tvm::Array from an float vector. * \param vec The float vector. @@ -82,6 +98,16 @@ inline Array make_array(const std::vector& vec) { return arr; } +/*! + * \brief Calculate the ceil of an Integer division + * \param dividend The dividend of the division + * \param divisor The divisor of the division + * \return The quotient + */ +inline int round_up_divide(int dividend, int divisor) { + return dividend / divisor + (dividend % divisor != 0); +} + /*! * \brief Make a vector from a tvm::Array. * \param arr The Array. diff --git a/src/contrib/ethosu/cascader/graph.cc b/src/contrib/ethosu/cascader/graph.cc index a930c2606e185..ce28f728d838f 100644 --- a/src/contrib/ethosu/cascader/graph.cc +++ b/src/contrib/ethosu/cascader/graph.cc @@ -38,12 +38,10 @@ namespace ethosu { namespace cascader { void PerformanceInfoNode::VisitAttrs(AttrVisitor* v) { - int compute_cycles_int = static_cast(compute_cycles); - v->Visit("_compute_cycles", &compute_cycles_int); - Array tmp_reads = make_array(read_bytes); + v->Visit("_compute_cycles", &compute_cycles); + Array tmp_reads = make_array(read_bytes); v->Visit("_read_bytes", &tmp_reads); - int write_bytes_int = static_cast(write_bytes); - v->Visit("_write_bytes", &write_bytes_int); + v->Visit("_write_bytes", &write_bytes); } TVM_REGISTER_NODE_TYPE(PerformanceInfoNode); @@ -147,8 +145,9 @@ TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.PartGetStripeAlignHint").set_body_t return make_array(align_hint); }); TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.PartGetPerformanceInfo") - .set_body_typed([](Part part, StripeConfig stripe_config, bool is_rolling) { - return part->GetPerformanceInfo(stripe_config, is_rolling); + .set_body_typed([](Part part, StripeConfig stripe_config, int buffer_mode) { + BufferMode ebuffer_mode = static_cast(buffer_mode); + return part->GetPerformanceInfo(stripe_config, ebuffer_mode); }); CascaderGraphNode::CascaderGraphNode(std::vector input_tensors, diff --git a/src/contrib/ethosu/cascader/graph.h b/src/contrib/ethosu/cascader/graph.h index 2bea890c722b2..81cbd1c9da5f3 100644 --- a/src/contrib/ethosu/cascader/graph.h +++ b/src/contrib/ethosu/cascader/graph.h @@ -44,6 +44,14 @@ class Tensor; class Part; class StripeConfig; +/*! + * \brief The buffering mode to use when realizing a tensor. + * RECOMPUTE - The 'default' behaviour of TVM. Overlapping stripes will be recomputed. + * ROLLING - Apply both the sliding window and storage folding optimizations to the tensor + * realization. + */ +enum BufferMode { RECOMPUTE, ROLLING }; + /*! \brief A struct to hold a Tensor Expression subgraph */ struct TESubgraph { /*! \brief The input te::Tensors to the subgraph */ @@ -58,11 +66,11 @@ class PerformanceInfoNode : public Object { void VisitAttrs(AttrVisitor* v); /*! \brief The cycles to compute a block */ - size_t compute_cycles; + int64_t compute_cycles; /*! \brief The number of bytes read per input tensor */ - std::vector read_bytes; + std::vector read_bytes; /*! \brief The number of bytes written to the output tensor */ - size_t write_bytes; + int64_t write_bytes; static constexpr const char* _type_key = "contrib.ethosu.cascader.PerformanceInfo"; TVM_DECLARE_FINAL_OBJECT_INFO(PerformanceInfoNode, Object); @@ -77,7 +85,7 @@ class PerformanceInfoNode : public Object { */ class PerformanceInfo : public ObjectRef { public: - PerformanceInfo(size_t compute_cycles, std::vector read_bytes, size_t write_bytes) { + PerformanceInfo(int64_t compute_cycles, std::vector read_bytes, int64_t write_bytes) { auto n = make_object(); n->compute_cycles = compute_cycles; n->read_bytes = std::move(read_bytes); @@ -190,7 +198,7 @@ class PartNode : public Object { * \return The performance information containing the compute cycles and read/write bytes. */ virtual const PerformanceInfo GetPerformanceInfo(const StripeConfig& output_stripe_config, - bool is_rolling) = 0; + BufferMode buffer_mode) = 0; static constexpr const char* _type_key = "contrib.ethosu.cascader.Part"; TVM_DECLARE_BASE_OBJECT_INFO(PartNode, Object); diff --git a/src/contrib/ethosu/cascader/parts/ethosu.cc b/src/contrib/ethosu/cascader/parts/ethosu.cc index 29b43269c7b69..cdbbda18c142c 100644 --- a/src/contrib/ethosu/cascader/parts/ethosu.cc +++ b/src/contrib/ethosu/cascader/parts/ethosu.cc @@ -21,6 +21,9 @@ #include #include +#include +#include +#include #include #include @@ -32,62 +35,114 @@ namespace contrib { namespace ethosu { namespace cascader { -const std::vector EthosuPartNode::GetBlockShape(const StripeConfig& output_stripe_config, - bool is_rollling) { - std::vector block_shape; - for (int axis : output_stripe_config->GetShape()) { - block_shape.push_back(std::min(axis, 4)); - } - return block_shape; -} +const std::vector EthosuPartNode::GetBytesRead(const std::vector& block_shape, + const std::vector& full_shape) { + std::vector bytes_per_input(propagators_.size(), 0); -const std::vector EthosuPartNode::GetBlockInputBytes_(const std::vector& block_shape) { - std::vector bytes_per_input; - std::vector strides; std::vector order; std::vector stripes; std::vector offset; + std::vector strides; for (size_t i = 0; i < block_shape.size(); i++) { - strides.push_back(1.0); order.push_back(1); - stripes.push_back(1); + stripes.push_back(round_up_divide(full_shape[i], block_shape[i])); offset.push_back(0); + strides.push_back(static_cast(block_shape[i])); } - StripeConfig output_block_config(block_shape, block_shape, strides, order, stripes, offset); + + StripeConfig output_block_config(block_shape, full_shape, strides, order, stripes, offset); auto input_block_configs = CalculateInputStripeConfigs(output_block_config); + + int i = 0; for (const auto& input_block_config : input_block_configs) { - bytes_per_input.push_back(mul_reduce(input_block_config->GetShape())); + std::map, int> input_blocks = CountStripes(input_block_config, false); + for (const auto& block : input_blocks) { + bytes_per_input[i] += mul_reduce(block.first) * block.second; + } + i++; } + + if (weight_tensor_idx_ != -1) { + bytes_per_input[weight_tensor_idx_] *= (stripes[height_idx_] * stripes[width_idx_]); + } + return bytes_per_input; } +const BlockConfig EthosuPartNode::GetBlockConfig(const StripeConfig& output_stripe_config) { + BlockConfig best_block_config; + float best_cost = std::numeric_limits::infinity(); + std::vector output_stripe_shape = output_stripe_config->GetShape(); + + for (const auto& block_config : valid_block_configs_) { + std::vector output_block = block_config->GetOutputBlockShape(); + + std::vector bytes_per_input = GetBytesRead(output_block, output_stripe_shape); + bytes_per_input[0] *= subkernels_; + + // Calculate bytes read per output element + float relative_cost = static_cast(bytes_per_input[0] + bytes_per_input[1]) / + mul_reduce(output_stripe_shape); + + // Single buffering hardware optimization + if (mul_reduce(output_stripe_shape) <= 2 * mul_reduce(output_block)) { + relative_cost /= 2; + } + + if (relative_cost < best_cost) { + best_block_config = block_config; + best_cost = relative_cost; + } + } + + return best_block_config; +} + const PerformanceInfo EthosuPartNode::GetPerformanceInfo(const StripeConfig& output_stripe_config, - bool is_rolling) { - std::vector block_shape = GetBlockShape(output_stripe_config, is_rolling); - std::vector bytes_per_input = GetBlockInputBytes_(block_shape); - int bytes_per_output = mul_reduce(block_shape); - int num_blocks = 1; + BufferMode buffer_mode) { + BlockConfig block_config = GetBlockConfig(output_stripe_config); + std::vector block_shape = block_config->GetOutputBlockShape(); + + std::vector bytes_per_input = + GetBytesRead(block_shape, output_stripe_config->GetShape()); + + int elements_per_block = mul_reduce(block_shape); + int bytes_per_output = elements_per_block; + float num_blocks = 1.0f; for (size_t i = 0; i < block_shape.size(); i++) { - if (!is_rolling) { - num_blocks *= output_stripe_config->GetShape()[i] * output_stripe_config->GetStripes()[i] / + if (buffer_mode == BufferMode::RECOMPUTE) { + num_blocks *= static_cast(output_stripe_config->GetShape()[i] * + output_stripe_config->GetStripes()[i]) / block_shape[i]; } else { - num_blocks *= output_stripe_config->GetExtent()[i] / block_shape[i]; + num_blocks *= + std::max(static_cast(output_stripe_config->GetExtent()[i]) / block_shape[i], 1.0f); } } - int num_stripes = mul_reduce(output_stripe_config->GetStripes()) - 1; - std::vector read_bytes; + float num_stripes = mul_reduce(output_stripe_config->GetStripes()) - 1.0f; + std::vector read_bytes; for (int block_bytes : bytes_per_input) { read_bytes.push_back((num_blocks + num_stripes) * block_bytes); } - int write_bytes = (num_blocks + num_stripes) * bytes_per_output; - auto shape = output_stripe_config->GetShape(); - PerformanceInfo info(0, read_bytes, write_bytes); + int64_t write_bytes = (num_blocks + num_stripes) * bytes_per_output; + + int block_output_cycles = block_config->GetOutputCycles(); + int block_compute_cycles = block_config->GetComputeCycles(); + + int64_t total_cycles = 0; + if (block_output_cycles > block_compute_cycles) { + total_cycles = (block_output_cycles * num_blocks) + block_compute_cycles; + } else { + total_cycles = (block_compute_cycles * num_blocks) + block_output_cycles; + } + + PerformanceInfo info(total_cycles, read_bytes, write_bytes); return info; } EthosuPart::EthosuPart(const TESubgraph& subgraph, const std::vector propagators, - const std::vector output_quantum, int quantum_cycles) { + const std::vector& output_quantum, int subkernels, + const std::vector& valid_block_configs, int weight_tensor_idx) { auto n = make_object(); ICHECK_GT(propagators.size(), 0) << "The Part must include at least one Propagator."; n->subgraph_ = subgraph; @@ -95,21 +150,40 @@ EthosuPart::EthosuPart(const TESubgraph& subgraph, const std::vector n->in_line_ = false; n->input_tensors_.resize(propagators.size()); n->output_quantum_ = output_quantum; - n->quantum_cycles_ = quantum_cycles; + n->valid_block_configs_ = valid_block_configs; + n->subkernels_ = subkernels; + n->weight_tensor_idx_ = weight_tensor_idx; + if (output_quantum.size() == 5) { + // NHCWB16 Format + n->height_idx_ = 1; + n->width_idx_ = 3; + } else { + // NHWC Format + n->height_idx_ = 1; + n->width_idx_ = 2; + } data_ = std::move(n); } TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.EthosuPart") .set_body_typed([](Array subgraph_inputs, te::Tensor subgraph_output, - Array propagators, Array output_quantum, - int quantum_cycles) { + Array propagators, Array output_quantum, int subkernels, + Array valid_block_configs, int weight_tensor_idx) { std::vector vsubgraph_inputs(subgraph_inputs.begin(), subgraph_inputs.end()); std::vector vpropagators(propagators.begin(), propagators.end()); + std::vector voutput_quantum(output_quantum.begin(), output_quantum.end()); TESubgraph subgraph; subgraph.input_tensors = vsubgraph_inputs; subgraph.output_tensor = subgraph_output; - std::vector voutput_quantum = make_vector(output_quantum); - return EthosuPart(subgraph, vpropagators, voutput_quantum, quantum_cycles); + std::vector vvalid_block_configs(valid_block_configs.begin(), + valid_block_configs.end()); + return EthosuPart(subgraph, vpropagators, voutput_quantum, subkernels, vvalid_block_configs, + weight_tensor_idx); + }); + +TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.EthosuPartGetBlockConfig") + .set_body_typed([](EthosuPart part, StripeConfig stripe_config) { + return part->GetBlockConfig(stripe_config); }); TVM_REGISTER_NODE_TYPE(EthosuPartNode); diff --git a/src/contrib/ethosu/cascader/parts/ethosu.h b/src/contrib/ethosu/cascader/parts/ethosu.h index ab3ca69d27170..cd8fa84eca2b0 100644 --- a/src/contrib/ethosu/cascader/parts/ethosu.h +++ b/src/contrib/ethosu/cascader/parts/ethosu.h @@ -28,6 +28,7 @@ #include +#include "../block_config.h" #include "../graph.h" namespace tvm { @@ -39,11 +40,10 @@ namespace cascader { class EthosuPartNode : public PartNode { public: /*! - * \brief Get the optimal block shape to use. + * \brief Get the optimal BlockConfig to use given a StripeConfig * \param output_stripe_config The output StripeConfig. - * \param is_rolling Whether the output config should be computed as a rolling buffer. */ - const std::vector GetBlockShape(const StripeConfig& output_stripe_config, bool is_rolling); + const BlockConfig GetBlockConfig(const StripeConfig& output_stripe_config); /*! * \brief Get the preferred alignment in each axis for a stripe of the Part. * \note This is used to bias the selection of StripeConfigs towards those that are integer @@ -53,11 +53,11 @@ class EthosuPartNode : public PartNode { /*! * \brief Get the performance information for a given output stripe config. * \param output_stripe_config The output stripe config to compute the performance for. - * \param is_rolling Whether the output config should be computed as a rolling buffer. + * \param buffer_mode The mode of buffering, rolling or recompute. * \return The performance information containing the compute cycles and read/write bytes. */ const PerformanceInfo GetPerformanceInfo(const StripeConfig& output_stripe_config, - bool is_rolling) final; + BufferMode buffer_mode) final; static constexpr const char* _type_key = "contrib.ethosu.cascader.EthosuPart"; TVM_DECLARE_FINAL_OBJECT_INFO(EthosuPartNode, PartNode); @@ -66,16 +66,27 @@ class EthosuPartNode : public PartNode { friend class EthosuPart; /*! - * \brief Get the size of input required (per input tensor) to compute a block. - * \param block_shape The shape of the block to compute. + * \brief Get the size of input required (per input tensor) to compute a stripe given a + * block_shape + * \param block_shape The shape of the block(s) the stripe is split into + * \param stripe_shape The shape of the full stripe to compute. * \return The bytes required per input tensor. */ - const std::vector GetBlockInputBytes_(const std::vector& block_shape); + const std::vector GetBytesRead(const std::vector& block_shape, + const std::vector& full_shape); + /*! \brief List of block configs that are valid for this part */ + std::vector valid_block_configs_; /*! \brief The output volume that is atomically computed */ std::vector output_quantum_; - /*! \brief The cycles taken to compute a single output quantum */ - int quantum_cycles_; + /*! \brief Index for output height dimension */ + int height_idx_; + /*! \brief Index for output width dimension */ + int width_idx_; + /*! \brief Index of weight tensor, -1 if the Part has no weights */ + int weight_tensor_idx_; + /*! \brief Number of sub-kernels the kernel has been split into */ + int subkernels_; }; /*! @@ -86,7 +97,8 @@ class EthosuPartNode : public PartNode { class EthosuPart : public Part { public: EthosuPart(const TESubgraph& subgraph, const std::vector propagators, - const std::vector output_quantum, int quantum_cycles); + const std::vector& output_quantum, int subkernels, + const std::vector& valid_block_configs, int weight_tensor_idx); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EthosuPart, Part, EthosuPartNode); }; diff --git a/src/contrib/ethosu/cascader/parts/inline.cc b/src/contrib/ethosu/cascader/parts/inline.cc index ff5e055084e6d..cb216e7d14543 100644 --- a/src/contrib/ethosu/cascader/parts/inline.cc +++ b/src/contrib/ethosu/cascader/parts/inline.cc @@ -31,8 +31,8 @@ namespace ethosu { namespace cascader { const PerformanceInfo InlinePartNode::GetPerformanceInfo(const StripeConfig& output_stripe_config, - bool is_rolling) { - std::vector read_bytes(input_tensors_.size()); + BufferMode buffer_mode) { + std::vector read_bytes(input_tensors_.size()); PerformanceInfo info(0, read_bytes, 0); return info; } diff --git a/src/contrib/ethosu/cascader/parts/inline.h b/src/contrib/ethosu/cascader/parts/inline.h index 44f2762319fbb..11d94f17397d3 100644 --- a/src/contrib/ethosu/cascader/parts/inline.h +++ b/src/contrib/ethosu/cascader/parts/inline.h @@ -45,7 +45,7 @@ class InlinePartNode : public PartNode { * \return The performance information containing the compute cycles and read/write bytes. */ const PerformanceInfo GetPerformanceInfo(const StripeConfig& output_stripe_config, - bool is_rolling) final; + BufferMode buffer_mode) final; static constexpr const char* _type_key = "contrib.ethosu.cascader.InlinePart"; TVM_DECLARE_FINAL_OBJECT_INFO(InlinePartNode, PartNode); diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 7dc7b28b968bf..e750344f4f0ca 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -44,6 +44,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array>); TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool); @@ -191,6 +192,8 @@ Array CreatePassList(bool disable_loop_partition) { transform::PassContext pass_ctx = transform::PassContext::Current(); bool disable_vectorize = pass_ctx->GetConfig("tir.disable_vectorize", Bool(false)).value(); + bool disable_storage_rewrite = + pass_ctx->GetConfig("tir.disable_storage_rewrite", Bool(false)).value(); bool instrument_bound_checkers = pass_ctx->GetConfig("tir.instrument_bound_checkers", Bool(false)).value(); @@ -260,7 +263,9 @@ Array CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize)); pass_list.push_back(tir::transform::InjectVirtualThread()); pass_list.push_back(tir::transform::InjectDoubleBuffer()); - pass_list.push_back(tir::transform::StorageRewrite()); + if (!disable_storage_rewrite) { + pass_list.push_back(tir::transform::StorageRewrite()); + } pass_list.push_back(tir::transform::UnrollLoop()); // Add user-defined phase-2 passes diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc new file mode 100644 index 0000000000000..edf13e36bef44 --- /dev/null +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/*! \brief Extract attribute from a target. */ +Integer Extract(const Target& target, const char* name) { + ICHECK(target.defined()); + if (Optional v = target->GetAttr(name)) { + return v.value(); + } + LOG(FATAL) << "AttributedError: \"" << name << "\" is not defined in the target"; + throw; +} + +/*! \brief Verify the correctness of the generated GPU code. */ +class VerifyGPUCodeNode : public PostprocNode { + public: + Map target_constraints_{nullptr}; + + void InitializeWithTuneContext(const TuneContext& context) final { + ICHECK(context->target.defined()); + Target target = context->target.value(); + this->target_constraints_ = Map{ + {"max_shared_memory_per_block", Extract(target, "shared_memory_per_block")}, + {"max_local_memory_per_block", Extract(target, "registers_per_block")}, + {"max_threads_per_block", Extract(target, "max_threads_per_block")}, + {"max_vthread", Integer(8)}, + {"max_vector_bytes", Integer(16)}}; + } + + bool Verify(const IRModule& mod) const { + for (const auto& kv : mod->functions) { + if (const auto* prim_func = kv.second.as()) { + if (!tir::VerifyGPUCode(GetRef(prim_func), this->target_constraints_)) { + return false; + } + } + } + return true; + } + + bool Apply(const tir::Schedule& sch) final { + IRModule mod = sch->mod(); + for (const auto& kv : mod->functions) { + const GlobalVar& g_var = kv.first; + const BaseFunc& base_func = kv.second; + if (const auto* prim_func = base_func.as()) { + IRModule lowered{nullptr}; + try { + auto pass_list = Array(); + // Phase 1 + // First three passes are not needed in TIR schedule. + // pass_list.push_back(tir::transform::InjectPrefetch()); + // pass_list.push_back(tir::transform::TextureFlatten()); + // pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); + pass_list.push_back(tir::transform::LowerCrossThreadReduction()); + pass_list.push_back(tir::transform::LowerInitBlock()); + pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); + pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); + pass_list.push_back(tir::transform::UnifyThreadBinding()); + pass_list.push_back(tir::transform::CompactBufferAllocation()); + pass_list.push_back(tir::transform::LowerMatchBuffer()); + pass_list.push_back(tir::transform::FlattenBuffer()); + pass_list.push_back(tir::transform::BF16Legalize()); + pass_list.push_back(tir::transform::NarrowDataType(32)); + pass_list.push_back(tir::transform::Simplify()); + + // Phase 2 + pass_list.push_back(tir::transform::VectorizeLoop(true)); + pass_list.push_back(tir::transform::InjectVirtualThread()); + pass_list.push_back(tir::transform::InjectDoubleBuffer()); + pass_list.push_back(tir::transform::StorageRewrite()); + pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); + + // Convert Function to IRModule + transform::PassContext pass_ctx = transform::PassContext::Current(); + tir::PrimFunc f = WithAttr(GetRef(prim_func), "global_symbol", + runtime::String(g_var->name_hint)); + bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); + if (noalias) { + f = WithAttr(std::move(f), "tir.noalias", Bool(true)); + } + IRModule mod = IRModule(Map({{GlobalVar(g_var->name_hint), f}})); + lowered = tvm::transform::Sequential(pass_list)(std::move(mod)); + } catch (const dmlc::Error& e) { + return false; + } + if (!Verify(lowered)) { + return false; + } + } + } + return true; + } + + static constexpr const char* _type_key = "meta_schedule.VerifyGPUCode"; + TVM_DECLARE_FINAL_OBJECT_INFO(VerifyGPUCodeNode, PostprocNode); +}; + +Postproc Postproc::VerifyGPUCode() { + ObjectPtr n = make_object(); + return Postproc(n); +} + +TVM_REGISTER_NODE_TYPE(VerifyGPUCodeNode); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocVerifyGPUCode").set_body_typed(Postproc::VerifyGPUCode); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/schedule_rule/add_rfactor.cc b/src/meta_schedule/schedule_rule/add_rfactor.cc new file mode 100644 index 0000000000000..5ef2ac3aad367 --- /dev/null +++ b/src/meta_schedule/schedule_rule/add_rfactor.cc @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +class AddRFactorNode : public ScheduleRuleNode { + public: + // Inherited from ScheduleRuleNode + void InitializeWithTuneContext(const TuneContext& context) final { + ICHECK(context->target.defined()); + Target target = context->target.value(); + this->max_parallel_basic_ = GetTargetNumCores(target); + if (this->max_jobs_per_core != -1) { + this->max_parallel_extent_ = max_parallel_basic_ * max_jobs_per_core; + } + } + + // Inherited from ScheduleRuleNode + Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv); + + public: + /*! + * \brief The maximum number of jobs to be launched per core. + * It sets the uplimit of parallelism, i.e. `num_cores * max_jobs_per_core`. + * Use -1 to disable parallelism. + */ + int max_jobs_per_core; + /*! \brief The maximum size of the innermost factor */ + int max_innermost_factor; + /*! \brief The number of uplimit of parallelism. */ + int max_parallel_extent_; + /*! \brief The number of cores. */ + int max_parallel_basic_; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("max_jobs_per_core", &max_jobs_per_core); + v->Visit("max_innermost_factor", &max_innermost_factor); + // `max_parallel_extent_` is not visited + // `max_parallel_basic_` is not visited + } + + static constexpr const char* _type_key = "meta_schedule.AddRFactor"; + TVM_DECLARE_FINAL_OBJECT_INFO(AddRFactorNode, ScheduleRuleNode); +}; + +ScheduleRule ScheduleRule::AddRFactor(int max_jobs_per_core, + Optional max_innermost_factor) { + ObjectPtr n = make_object(); + n->max_jobs_per_core = max_jobs_per_core; + n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value; + n->max_parallel_extent_ = -1; + n->max_parallel_basic_ = -1; + return ScheduleRule(n); +} + +Array AddRFactorNode::Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) { + tir::StmtSRef block_sref = sch->GetSRef(block_rv); + if (!NeedsRFactorOrCrossThreadReduction(sch->state(), block_sref, max_parallel_extent_, + max_parallel_basic_)) { + return {sch}; + } + + // Make a copy of the original schedule. + tir::Schedule ori_sch = sch->Copy(); + ori_sch->Seed(sch->ForkSeed()); + + // Reorder the loop axes if reduction loops are not innermost. + // After the reordering, fuse all the reduction loops. + size_t num_spatial_loops; + tir::LoopRV fused_reduce_loop; + ReorderAndFuseReductionLoops(sch, block_rv, &fused_reduce_loop, &num_spatial_loops); + + // Split the fused reduction loop. + Array factors = sch->SamplePerfectTile(fused_reduce_loop, 2, max_innermost_factor); + const Array& split_loops = + sch->Split(fused_reduce_loop, {factors.begin(), factors.end()}); + + Array res; + for (const tir::LoopRV& split_loop : split_loops) { + tir::Schedule sch_tmp = sch->Copy(); + sch_tmp->Seed(sch->ForkSeed()); + try { + const tir::BlockRV& block_rf = sch_tmp->RFactor(split_loop, num_spatial_loops); + Array axes = sch_tmp->GetLoops(block_rf); + ICHECK_GT(axes.size(), num_spatial_loops); + + // Annotate that the rfactor block, which is now the producer of the original block, needs to + // be considered by the rule Random-Compute-Location. + sch_tmp->Annotate(block_rv, tir::attr::meta_schedule_random_compute_producer, Bool(true)); + res.push_back(sch_tmp); + } catch (const tvm::runtime::Error& e) { + } + } + + res.push_back(ori_sch); + return res; +} + +TVM_REGISTER_NODE_TYPE(AddRFactorNode); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAddRFactor") + .set_body_typed(ScheduleRule::AddRFactor); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc new file mode 100644 index 0000000000000..38156f86e6cbf --- /dev/null +++ b/src/meta_schedule/schedule_rule/auto_inline.cc @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/*! \brief The type of inline to be performed on a specific block */ +enum class InlineType : int32_t { + /*! \brief No inline opportunity */ + kNoInline = 0, + /*! \brief Inline the block into its consumer */ + kInlineIntoConsumer = 1, + /*! \brief Inline the block into its producer */ + kInlineIntoProducer = 2, +}; + +/*! \brief The rule that inlines spatial blocks if it satisfies some conditions. */ +class AutoInlineNode : public ScheduleRuleNode { + public: + /*! \brief Checks if the specific block should be inlined */ + inline InlineType CheckInline(const tir::Schedule& sch, const tir::BlockRV& block_rv); + + // Inherited from ScheduleRuleNode + void InitializeWithTuneContext(const TuneContext& context) final {} + + // Inherited from ScheduleRuleNode + Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { + InlineType inline_type = CheckInline(sch, block_rv); + if (inline_type == InlineType::kInlineIntoConsumer) { + sch->ComputeInline(block_rv); + } else if (inline_type == InlineType::kInlineIntoProducer) { + sch->ReverseComputeInline(block_rv); + } + return {sch}; + } + + public: + /*! \brief If allows to inline a block into its producer */ + bool into_producer; + /*! \brief If allows to inline a block into its consumer */ + bool into_consumer; + /*! \brief Always inline constant tensors */ + bool inline_const_tensor; + /*! \brief Always disallow if-then-else-like constructs */ + bool disallow_if_then_else; + /*! \brief Always require the read-to-write mapping to be injective to do auto inline */ + bool require_injective; + /*! \brief Always require the read-to-write mapping to be ordered to do auto inline */ + bool require_ordered; + /*! \brief The operators that are disallowed in auto inline */ + Array disallow_op; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("into_producer", &into_producer); + v->Visit("into_consumer", &into_consumer); + v->Visit("inline_const_tensor", &inline_const_tensor); + v->Visit("disallow_if_then_else", &disallow_if_then_else); + v->Visit("require_injective", &require_injective); + v->Visit("require_ordered", &require_ordered); + v->Visit("disallow_op", &disallow_op); + } + + static constexpr const char* _type_key = "meta_schedule.AutoInline"; + TVM_DECLARE_FINAL_OBJECT_INFO(AutoInlineNode, ScheduleRuleNode); +}; + +inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch, + const tir::BlockRV& block_rv) { + using namespace tvm::tir; + StmtSRef block_sref = sch->GetSRef(block_rv); + ScheduleState state = sch->state(); + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + BlockRealize realize = GetBlockRealize(state, block_sref); + // Cond 1. The block has only one write buffer + if (block->writes.size() != 1) { + return InlineType::kNoInline; + } + // Cond 2. For a block that generates a constant tensor, ignore all other conditions + if (inline_const_tensor && block->reads.empty()) { + return InlineType::kInlineIntoConsumer; + } + // Cond 3. The block doesn't contain any disallowed operators + if (!disallow_op.empty() && HasOp(realize, disallow_op)) { + return InlineType::kNoInline; + } + // Cond 4. The block doesn't have any if-then-else-like constructs + if (disallow_if_then_else && HasIfThenElse(realize)) { + return InlineType::kNoInline; + } + // Cond 5. The mapping from read indices to write indices are injective and ordered + if (require_injective || require_ordered) { + const BufferRegion& write_region = block->writes[0]; + for (const BufferRegion& read_region : block->reads) { + bool injective, ordered; + auto _ = std::ignore; + std::tie(/*exists=*/_, /*surjective=*/_, injective, ordered, /*no_const_read=*/_, + /*no_shift_read=*/_) = AnalyzeReadWritePattern(read_region, write_region); + if (require_injective && injective == false) { + return InlineType::kNoInline; + } + if (require_ordered && ordered == false) { + return InlineType::kNoInline; + } + } + } + // Last cond: Check inline into the consumers or the spatial producer + tir::StmtSRef scope_block = tir::GetScopeRoot(sch->state(), block_sref, // + /*require_stage_pipeline=*/false, // + /*require_subtree_compact_dataflow=*/false); + if (into_consumer) { + Array consumer_srefs = GetConsumers(state, block_sref); + if (!consumer_srefs.empty() && CanComputeInline(state, block_sref)) { + return InlineType::kInlineIntoConsumer; + } + } + if (into_producer) { + Array producer_srefs = GetProducers(state, block_sref); + if (producer_srefs.size() == 1 && + tir::IsCompleteBlock(sch->state(), producer_srefs[0], scope_block) && + CanReverseComputeInline(state, block_sref)) { + return InlineType::kInlineIntoProducer; + } + } + return InlineType::kNoInline; +} + +ScheduleRule ScheduleRule::AutoInline(bool into_producer, // + bool into_consumer, // + bool inline_const_tensor, // + bool disallow_if_then_else, // + bool require_injective, // + bool require_ordered, // + Optional> disallow_op) { + ObjectPtr n = make_object(); + n->into_producer = into_producer; + n->into_consumer = into_consumer; + n->inline_const_tensor = inline_const_tensor; + n->disallow_if_then_else = disallow_if_then_else; + n->require_injective = require_injective; + n->require_ordered = require_ordered; + n->disallow_op.clear(); + if (disallow_op.defined()) { + Array op_names = disallow_op.value(); + n->disallow_op.reserve(op_names.size()); + for (const String& op_name : op_names) { + n->disallow_op.push_back(Op::Get(op_name)); + } + } + return ScheduleRule(n); +} + +TVM_REGISTER_NODE_TYPE(AutoInlineNode); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoInline") + .set_body_typed(ScheduleRule::AutoInline); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/schedule_rule/random_compute_location.cc b/src/meta_schedule/schedule_rule/random_compute_location.cc new file mode 100644 index 0000000000000..957ad89af106d --- /dev/null +++ b/src/meta_schedule/schedule_rule/random_compute_location.cc @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +class RandomComputeLocationNode : public ScheduleRuleNode { + public: + // Inherited from ScheduleRuleNode + void InitializeWithTuneContext(const TuneContext& context) final {} + + // Inherited from ScheduleRuleNode + Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { + if (!CheckConditions(sch, block_rv)) { + return {sch}; + } + + // Step 1. If the producer of the input block needs a random compute-at location (specified by + // the annotation), we collect the producer first, and transform the producer block later. + // - The reason we collect the producer before transforming the input block is that, if the + // decision of Sample-Compute-Location is "compute-inline" for the input block, we can no longer + // access the input block. Hence we collect its producer ahead of time. + // - Note that only single producer is allowed in this case. + Array producers{nullptr}; + if (tir::HasAnn(sch->GetSRef(block_rv), tir::attr::meta_schedule_random_compute_producer, + true)) { + producers = sch->GetProducers(block_rv); + sch->Unannotate(block_rv, tir::attr::meta_schedule_random_compute_producer); + ICHECK_EQ(producers.size(), 1); + } + + // Step 2. Transform the input block. + tir::Schedule res = RandomlyComputeAt(sch, block_rv); + + // Step 3. Transform the producer block if compute-location sampling is needed. + if (producers.defined()) { + res = RandomlyComputeAt(res, producers[0]); + } + + return {res}; + } + + private: + bool CheckConditions(const tir::Schedule sch, const tir::BlockRV& block_rv) const { + tir::StmtSRef block_sref = sch->GetSRef(block_rv); + const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + + // Cond 1. The block is not the root block. + if (block_sref->parent == nullptr) { + return false; + } + // Cond 2. The block should be the direct child block of the root block. + if (GetScopeRoot(sch->state(), block_sref, // + /*require_stage_pipeline=*/false, // + /*require_subtree_compact_dataflow=*/false) + ->parent != nullptr) { + return false; + } + // Cond 3 & 4. The block has at least one outer loop, and the outermost loop has only one child + // block. + Array loop_srefs = tir::GetLoops(block_sref); + if (loop_srefs.empty()) { + return false; + } + if (tir::GetChildBlockSRefOnSRefTree(sch->state(), loop_srefs[0]).size() > 1) { + return false; + } + // Cond 5. The block is not tiled. We check this condition by examine the block's annotation. + if (tir::HasBeenMultiLevelTiled(block_sref)) { + return false; + } + // Cond 6. The block has at lease one consumer. + if (tir::GetConsumers(sch->state(), sch->GetSRef(block_rv)).empty()) { + return false; + } + return true; + } + + /*! + * \brief Keep sampling a compute-at location for the input block until success. + * \param sch The TIR schedule + * \param block_rv The block whose compute-at location is to be sampled + * \return The TIR schedule after transformation + */ + tir::Schedule RandomlyComputeAt(const tir::Schedule& sch, const tir::BlockRV& block_rv) { + tir::LoopRV compute_at_loc = sch->SampleComputeLocation(block_rv); + sch->ComputeAt(block_rv, compute_at_loc, true); + return sch; + } + + public: + void VisitAttrs(tvm::AttrVisitor* v) {} + + static constexpr const char* _type_key = "meta_schedule.RandomComputeLocation"; + TVM_DECLARE_FINAL_OBJECT_INFO(RandomComputeLocationNode, ScheduleRuleNode); +}; + +ScheduleRule ScheduleRule::RandomComputeLocation() { + return ScheduleRule(make_object()); +} + +TVM_REGISTER_NODE_TYPE(RandomComputeLocationNode); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleRandomComputeLocation") + .set_body_typed(ScheduleRule::RandomComputeLocation); +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index ef15f49955418..5b497695400a2 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -318,6 +318,26 @@ struct ThreadedTraceApply { Item* items_; }; +/*! + * \brief Get the number of cores in CPU + * \param target The target + * \return The number of cores. + */ +inline int GetTargetNumCores(const Target& target) { + int num_cores = target->GetAttr("num-cores").value_or(-1); + if (num_cores == -1) { + static const auto* f_cpu_count = runtime::Registry::Get("meta_schedule.cpu_count"); + ICHECK(f_cpu_count) + << "ValueError: Cannot find the packed function \"meta_schedule._cpu_count\""; + num_cores = (*f_cpu_count)(false); + LOG(FATAL) + << "Target does not have attribute \"num-cores\", physical core number must be " + "defined! For example, on the local machine, the target must be \"llvm -num-cores " + << num_cores << "\""; + } + return num_cores; +} + } // namespace meta_schedule } // namespace tvm diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 277a6042cbebf..023bb0d3ef00d 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -222,7 +222,9 @@ class TVMScriptPrinter : public StmtFunctor, void TryDeallocVar(const Var& var); bool ContainsOptionalInfo(const Stmt& stmt); /*! - * \brief check if a buffer declaration has only 'shape' and 'dtype' arguments specified + * \brief Check if a buffer declaration satisfies: + * 1. has only 'shape' and 'dtype' arguments specified, + * 2. the shape and strides are not dynamic. * \param buffer The match buffer to be checked */ bool IsSimpleBuffer(const Buffer& buffer); @@ -481,6 +483,7 @@ Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) { // check if all arguments, except the first two, are specified for T.match_buffer // if not, then this match buffer is printed out as T.buffer in prim_func arguments +// and check whether there are undefined variables in the shape/strides. bool TVMScriptPrinter::IsSimpleBuffer(const Buffer& buf) { if (memo_var_.find(buf->data) != memo_var_.end()) { return false; @@ -488,7 +491,17 @@ bool TVMScriptPrinter::IsSimpleBuffer(const Buffer& buf) { if (!buf->strides.empty()) { return false; } - if (buf->elem_offset->IsInstance()) { + for (const PrimExpr& shp_i : buf->shape) { + if (!UndefinedVars(shp_i).empty()) { + return false; + } + } + for (const PrimExpr& stride_i : buf->strides) { + if (!UndefinedVars(stride_i).empty()) { + return false; + } + } + if (!UndefinedVars(buf->elem_offset).empty()) { return false; } else if (buf->elem_offset->IsInstance()) { IntImm elem_offset = Downcast(buf->elem_offset); @@ -1302,6 +1315,7 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { // check if this param is a T.handle if (it != op->buffer_map.end()) { // check if this match_buffer has only the first two arguments specified + // and whether the match_buffer is a dynamic buffer. const Buffer& buf = (*it).second; if (IsSimpleBuffer(buf)) { simple_buf.insert(buf); diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index d901f8a26c4f2..f076efeb4ac50 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -35,12 +35,14 @@ #include #include #include +#include #include #include #include #include +#include "../../target/source/codegen_source_base.h" #include "../op/annotation/annotation.h" #include "../op/call/call.h" #include "../op/memory/device_copy.h" @@ -290,7 +292,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid); if (output_iter != return_sid_.end()) { int output_index = std::distance(return_sid_.begin(), output_iter); - buffer_vars.push_back(main_signature_[input_vars_.size() + output_index]); + buffer_vars.push_back(GetBufferVarForIO(input_vars_.size() + output_index)); continue; } @@ -308,7 +310,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { if (input_iter != input_vars_.end()) { // Input variable int main_index = std::distance(input_vars_.begin(), input_iter); - return {main_signature_[main_index]}; + return {GetBufferVarForIO(main_index)}; } else { // Storage identifier (i.e., intermediate memory) return PackSid(arg); @@ -331,7 +333,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { if (params_by_expr_.find(arg) != params_by_expr_.end()) { auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), {tir::StringImm(params_by_expr_[arg])}); - args.push_back(param_handle); + args.push_back(tvm::tir::Cast(DataType::Handle(), param_handle)); } else { auto var_arg = FindExpr(arg); for (const auto& var : var_arg) { @@ -405,26 +407,11 @@ class AOTExecutorCodegen : public MixedModeVisitor { auto tmp1 = te::Var("tmp1", DataType::Handle()); te::Var loop_idx("i", DataType::Int(32)); auto retval_i = tir::Load(DataType::UInt(8), tmp0, loop_idx, tir::const_true()); - - PrimExpr retval_get = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(), - {in, 0, tir::builtin::kArrData}); - PrimExpr tostore = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(), - {out, 0, tir::builtin::kArrData}); - if (use_unpacked_api_) { - tostore = out; - } - - // Do not pack the input if the flag is set or the caller - // explicitly asked to do so (e.g., copying a param to the output) - if (use_unpacked_api_ || !pack_input) { - retval_get = in; - } - // Copy the variable from the input to the output - tir::Stmt copy = tir::For( - loop_idx, 0, ConstInt32(size), tir::ForKind::kSerial, - tir::Store(tmp1, tir::Let(tmp0, retval_get, retval_i), loop_idx, tir::const_true())); - stmts_.push_back(tir::LetStmt(tmp1, tostore, copy)); + tir::Stmt copy = + tir::For(loop_idx, 0, ConstInt32(size), tir::ForKind::kSerial, + tir::Store(tmp1, tir::Let(tmp0, in, retval_i), loop_idx, tir::const_true())); + stmts_.push_back(tir::LetStmt(tmp1, out, copy)); } /* @@ -546,12 +533,12 @@ class AOTExecutorCodegen : public MixedModeVisitor { if (params_by_expr_.find(expr) != params_by_expr_.end()) { auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), {tir::StringImm(params_by_expr_[expr])}); - CopyToOutput(main_signature_[input_vars_.size() + output_index], param_handle, - /*pack_input*/ true, sinfo->storage_sizes_in_bytes[0]); + CopyToOutput(GetBufferVarForIO(input_vars_.size() + output_index), param_handle, + /*pack_input*/ false, sinfo->storage_sizes_in_bytes[0]); } else { auto var_expr = FindExpr(expr); - CopyToOutput(main_signature_[input_vars_.size() + output_index], var_expr[0], - /*pack_input*/ true, sinfo->storage_sizes_in_bytes[0]); + CopyToOutput(GetBufferVarForIO(input_vars_.size() + output_index), var_expr[0], + /*pack_input*/ false, sinfo->storage_sizes_in_bytes[0]); } } } @@ -572,7 +559,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { int output_index = std::distance(return_sid_.begin(), output_iter); auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), {tir::StringImm(params_by_expr_[expr])}); - CopyToOutput(main_signature_[input_vars_.size() + output_index], param_handle, false, + CopyToOutput(GetBufferVarForIO(input_vars_.size() + output_index), param_handle, false, sinfo->storage_sizes_in_bytes[0]); } } @@ -645,32 +632,131 @@ class AOTExecutorCodegen : public MixedModeVisitor { // TODO(giuseros): we should allocate this once outside the PrimFunc // so we don't pay the price of allocation for every inference if (!allocated[sid]) { - body = tir::Allocate(sids_table_[sid], DataType::Int(8), {size}, tir::const_true(), body); + PointerType ptype = Downcast(sids_table_[sid]->type_annotation); + DataType element_type = Downcast(ptype->element_type)->dtype; + body = tir::Allocate(sids_table_[sid], element_type, {size}, tir::const_true(), body); } allocated[sid] = true; } } - // Define the attributes - body = tir::AttrStmt(PrimExpr(), tvm::tir::attr::device_type, 1, body); - body = tir::AttrStmt(PrimExpr(), tvm::tir::attr::device_id, 0, body); - // Define the PrimFunc attributes Map dict_attrs; String run_func_name = runtime::get_name_mangled(mod_name, runtime::symbol::tvm_run_func_suffix); dict_attrs.Set("global_symbol", run_func_name); dict_attrs.Set("runner_function", Bool(true)); + dict_attrs.Set(tvm::attr::kTarget, target_host_); tir::Stmt device_activations = GenerateAllDeviceHook("Activate"); tir::Stmt device_deactivations = GenerateAllDeviceHook("Deactivate"); tir::Stmt final_body = tir::SeqStmt({device_activations, body, device_deactivations}); // Make the PrimFunc - return tir::PrimFunc(main_signature_, final_body, VoidType(), Map(), + return tir::PrimFunc(main_signature_, final_body, VoidType(), main_buffer_map_, DictAttrs(dict_attrs)); } + /*! + * brief Access IO vars using the buffer vars and + * not the actual var. + */ + tir::Var GetBufferVarForIO(int index) { return main_buffer_map_[main_signature_[index]]->data; } + + /*! + * brief Create tir::Var for input/output while updating + * the buffer_maps. + */ + void CreateIOVar(const Expr& expr, std::string name) { + if (expr->IsInstance()) { + Tuple tuple = Downcast(expr); + for (unsigned i = 0; i < tuple->fields.size(); i++) { + CreateIOVar(tuple->fields[i], name + std::to_string(i) + "_"); + } + } else { + tir::Var var = tir::Var(name, DataType::Handle()); + main_signature_.push_back(var); + auto tensor_type = expr->checked_type().as(); + DataType elem_type = tensor_type->dtype; + tir::Var buffer_var = + tir::Var(name + "_buffer_var", PointerType(PrimType(elem_type), "global")); + tir::Buffer buffer = tir::Buffer(buffer_var, elem_type, tensor_type->shape, {}, 0, + name + "_buffer", 16, 1, tir::BufferType::kDefault); + main_buffer_map_.Set(var, buffer); + } + } + + /*! + * brief Run USMP to plan memory for lowered IRModule + */ + IRModule PlanMemoryWithUSMP(const IRModule& mod) { + Executor executor_config = mod->GetAttr(tvm::attr::kExecutor).value(); + Integer workspace_byte_alignment = + executor_config->GetAttr("workspace-byte-alignment").value_or(16); + IRModule lowered_mod = mod->ShallowCopy(); + lowered_mod = tir::transform::UnifiedStaticMemoryPlanner()(lowered_mod); + // Update workspace size based on the pool allocations. + for (const auto& kv : function_metadata_) { + if (lowered_mod->ContainGlobalVar(kv.first) && + lowered_mod->Lookup(kv.first)->IsInstance()) { + tir::PrimFunc pfunc = Downcast(lowered_mod->Lookup(kv.first)); + Target tgt = pfunc->GetAttr(tvm::attr::kTarget).value(); + const auto& ws = CalculateWorkspaceBytes(pfunc, workspace_byte_alignment); + kv.second->workspace_sizes.Set(tgt, ws); + } + } + Optional> allocated_pool_infos = + lowered_mod->GetAttr>(tvm::attr::kPoolArgs); + backend::FunctionInfo main_func_info = + lowered_mod->GetAttr("main_func_info").value(); + main_func_info->workspace_sizes.clear(); + if (allocated_pool_infos) { + for (const tir::usmp::AllocatedPoolInfo& allocated_pool_info : allocated_pool_infos.value()) { + for (const auto& kv : allocated_pool_info->pool_info->target_access) { + Target tgt = kv.first; + if (main_func_info->workspace_sizes.find(tgt) == main_func_info->workspace_sizes.end()) { + main_func_info->workspace_sizes.Set(tgt, allocated_pool_info->allocated_size); + } else { + main_func_info->workspace_sizes.Set(tgt, + main_func_info->workspace_sizes[tgt]->value + + allocated_pool_info->allocated_size->value); + } + } + } + } + function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info); + return lowered_mod; + } + + /*! + * brief Run StorageRewrite to plan memory for lowered IRModule + */ + IRModule PlanMemoryWithStorageRewrite(const IRModule& mod) { + Executor executor_config = mod->GetAttr(tvm::attr::kExecutor).value(); + Integer workspace_byte_alignment = + executor_config->GetAttr("workspace-byte-alignment").value_or(16); + IRModule lowered_mod = mod->ShallowCopy(); + // Running StorageRewrite just on the main function + tir::PrimFunc tir_main_func = + Downcast(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix)); + IRModule main_func_mod; + main_func_mod->Update(lowered_mod->GetGlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), + tir_main_func); + main_func_mod = tir::transform::StorageRewrite()(main_func_mod); + lowered_mod->Update(lowered_mod->GetGlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), + main_func_mod->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix)); + tir_main_func = + Downcast(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix)); + // Use the PrimFunc to calculate the workspace required to service the allocates + Integer main_workspace_size_bytes = + CalculateWorkspaceBytes(tir_main_func, workspace_byte_alignment); + backend::FunctionInfo main_func_info = + lowered_mod->GetAttr("main_func_info").value(); + main_func_info->workspace_sizes.Set(target_host_, main_workspace_size_bytes); + function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info); + return lowered_mod; + } + protected: /*! \brief mod */ runtime::Module* mod_; @@ -682,6 +768,8 @@ class AOTExecutorCodegen : public MixedModeVisitor { Map device_contexts_; /*! \brief input and output variables belonging to the main function signature */ Array main_signature_; + /*! \brief input and output variables belonging to the main function signature */ + Map main_buffer_map_; /*! \brief target device */ tec::TargetMap targets_; /*! \brief target host */ @@ -773,7 +861,8 @@ class AOTExecutorCodegen : public MixedModeVisitor { for (auto input : lowered_main_func->params) { input_vars_.push_back(input); - main_signature_.push_back(tir::Var("input", DataType::Handle())); + std::string input_name = SanitizeName(input->name_hint()); + CreateIOVar(input, input_name); } // Define the storage allocator ids @@ -792,9 +881,8 @@ class AOTExecutorCodegen : public MixedModeVisitor { // Retrieve the return sids return_sid_ = final_aot_allocator.GetReturnIds(); - for (unsigned int output_index = 0; output_index < return_sid_.size(); output_index++) { - main_signature_.push_back(tir::Var("output", DataType::Handle())); - } + // Insert outputs to main func signature + CreateIOVar(lowered_main_func->body, "output"); CollectDeviceVariables(lowered_mod->GetAttr>("device_contexts").value()); VisitExpr(lowered_main_func->body); @@ -802,7 +890,6 @@ class AOTExecutorCodegen : public MixedModeVisitor { // Create the runner function. Please note that the function is not legal yet // because the packed calls arguments are not wrapped in TVMValues. To make this happen we need // to run the LegalizePackedCalls pass. - auto prim_func = CreateMainFunc(mod_name, lowered_main_func->params.size()); LoweredOutput ret; ret.params = std::unordered_map>(); @@ -812,36 +899,30 @@ class AOTExecutorCodegen : public MixedModeVisitor { std::make_pair(static_cast(param_storage_ids_[param.first]), param.second))); } - // Build the TIR IRModule for the main AOT function - Map symbol_map; - symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func); - IRModule mod_run(symbol_map, {}, {}, {}, mod->attrs); - VLOG(1) << "main module:" << std::endl << PrettyPrint(mod_run); - - // Apply storage rewrite pass to the runner function to do memory planning - auto storage_rewrite = tir::transform::StorageRewrite(); - mod_run = storage_rewrite(mod_run); - // The workspace for main function should be calculated after performing storage_rewrite for - // the top level TIR function. - Integer main_workspace_size = CalculateWorkspaceBytes( - Downcast(mod_run->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix)), - workspace_byte_alignment); - - Optional main_func_info = - lowered_mod->GetAttr("main_func_info"); - - main_func_info.value()->workspace_sizes.Set(target_host_, main_workspace_size); - function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value()); + // AoT Executor codegen works completely on TIR beyond this point, hence removing relay main + // function and replacing it with its TIR version. We should try to make this a Pass. + lowered_mod->Remove(lowered_mod->GetGlobalVar("main")); + auto prim_func = CreateMainFunc(mod_name, lowered_main_func->params.size()); + lowered_mod->Update(GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func); + // Parallel for loops are not supported in AoT codegen. + lowered_mod = tir::transform::ConvertForLoopsToSerial()(lowered_mod); + + transform::PassContext pass_ctx = transform::PassContext::Current(); + bool enable_usmp = pass_ctx->GetConfig(kUSMPEnableOption, Bool(false)).value(); + if (enable_usmp) { + lowered_mod = PlanMemoryWithUSMP(lowered_mod); + } else { + lowered_mod = PlanMemoryWithStorageRewrite(lowered_mod); + } + ret.function_metadata = std::move(function_metadata_); // Legalize AOT if needed. This means that all the packed calls // need to be wrapped in TVMValues (unless use_unpacked_api is set) if (!use_unpacked_api_) { auto pack_calls = tir::transform::LegalizePackedCalls(); - mod_run = pack_calls(mod_run); + lowered_mod = pack_calls(lowered_mod); } - ret.function_metadata = std::move(function_metadata_); - Optional> external_modules = lowered_mod->GetAttr>("external_mods"); ICHECK(external_modules) << "Attribute \"external_mods\" should be set at this point."; @@ -859,20 +940,26 @@ class AOTExecutorCodegen : public MixedModeVisitor { ret.external_mods = external_modules.value(); - if (ret.lowered_funcs.find(target_host_) != ret.lowered_funcs.end()) { - VLOG(1) << "merging main into existing module for host target"; - ret.lowered_funcs[target_host_]->Update(mod_run); - } else { - VLOG(1) << "adding main into new module for host target"; - ret.lowered_funcs.Set(target_host_, mod_run); + Map pool_var_info; + std::vector pool_vars; + tir::PrimFunc tir_main_func = + Downcast(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix)); + Optional> allocated_pool_infos = + tir_main_func->GetAttr>(tvm::attr::kPoolArgs); + if (allocated_pool_infos) { + for (const tir::usmp::AllocatedPoolInfo& allocated_pool_info : allocated_pool_infos.value()) { + pool_vars.push_back(allocated_pool_info->pool_var.value()); + pool_var_info.Set(allocated_pool_info->pool_var.value(), allocated_pool_info); + } } - - std::vector input_var_names(input_vars_.size()); - std::transform(input_vars_.begin(), input_vars_.end(), input_var_names.begin(), - [](Var input_var) -> String { return input_var->name_hint(); }); - ret.metadata = - runtime::Metadata(input_var_names, ListDevices(), return_sid_.size(), - runtime::kTvmExecutorAot, mod_name, interface_api, use_unpacked_api_); + Array devices = ListDevices(); + Array inputs = + Array(tir_main_func->params.begin(), + tir_main_func->params.begin() + tir_main_func->params.size() - + return_sid_.size() - pool_vars.size() - devices.size()); + ret.metadata = ExecutorCodegenMetadata(inputs, pool_vars, devices, return_sid_.size(), + runtime::kTvmExecutorAot, mod_name, interface_api, + use_unpacked_api_, pool_var_info); return ret; } diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index ccfd30476f670..2f986669e758e 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -103,7 +103,9 @@ struct ExecutorCodegen { Array ListDevices() { return CallFunc>("get_devices"); } - runtime::Metadata GetMetadata() { return CallFunc("get_metadata"); } + relay::backend::ExecutorCodegenMetadata GetMetadata() { + return CallFunc("get_metadata"); + } virtual ~ExecutorCodegen() {} protected: diff --git a/src/relay/backend/contrib/ethosu/source_module.cc b/src/relay/backend/contrib/ethosu/source_module.cc index f56544aee99a9..66955f8b201f8 100644 --- a/src/relay/backend/contrib/ethosu/source_module.cc +++ b/src/relay/backend/contrib/ethosu/source_module.cc @@ -89,6 +89,8 @@ class EthosUModuleNode : public ModuleNode { std::string GetSource(const std::string& format) final { return c_source; } + std::string GetFormat() { return "c"; } + Array GetArtifacts() { return compilation_artifacts_; } /*! diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 5ed93914ac534..4b77cb14d48bb 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -192,7 +192,7 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator } } - // Use TOPI schdule if user specificed, or the function has no auto_scheduler schedule. + // Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule. if (!schedule.defined() && !prim_func.defined()) { ICHECK(anchor_implementation_.defined()); schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_); diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 658283b5dc368..cb019083a9d5a 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -33,6 +33,7 @@ #include #include #include +#include #include #include @@ -53,6 +54,63 @@ class TECompiler; namespace backend { using Pass = tvm::transform::Pass; +/*! + * \brief Structure that can be optionally used by the executor codegen + */ +class ExecutorCodegenMetadataNode : public Object { + public: + /*! \brief input information for the main function */ + Array inputs; + /*! \brief pool information for the main function */ + Array pools; + /*! \brief number of outputs of the main function */ + unsigned int num_outputs = 1; + /*! \brief device contexts information for the main function */ + Array devices; + /*! \brief the executor to be used to run the model */ + String executor = runtime::kTvmExecutorGraph; + /*! \brief The external API (packed or c) in use */ + String interface_api; + /*! \brief The internal API (packed or unpacked) in use */ + bool unpacked_api; + /*! \brief the input var names that correspond to pool_inputs */ + Optional> pool_inputs; + + String mod_name = ""; + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "MetadataObj"; + TVM_DECLARE_FINAL_OBJECT_INFO(ExecutorCodegenMetadataNode, Object); +}; + +/*! + * \brief Managed reference to ExecutorCodegenMetadataNode. + */ +class ExecutorCodegenMetadata : public ObjectRef { + public: + TVM_DLL ExecutorCodegenMetadata(Array inputs, Array pools, + Array devices, int num_outputs, String executor, + String mod_name, String interface_api = "packed", + bool unpacked_api = false, + Map pool_inputs = + Map()) { + auto n = make_object(); + n->inputs = inputs; + n->pools = pools; + n->devices = devices; + n->num_outputs = num_outputs; + n->executor = executor; + n->interface_api = interface_api; + n->unpacked_api = unpacked_api; + n->mod_name = mod_name; + n->pool_inputs = pool_inputs; + data_ = std::move(n); + } + + TVM_DEFINE_OBJECT_REF_METHODS(ExecutorCodegenMetadata, ObjectRef, ExecutorCodegenMetadataNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ExecutorCodegenMetadataNode); +}; + /*! * \brief The static storage information for each Tensor in the result of a Relay expression * (as per relay::FlattenTupleType). @@ -147,7 +205,7 @@ struct LoweredOutput { Array external_mods; Map function_metadata; std::unordered_map> params; - runtime::Metadata metadata; + ExecutorCodegenMetadata metadata; }; /*! diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 73f4b672a81cb..f68dd9f8d2dff 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1162,7 +1162,8 @@ void VMCompiler::Codegen() { } lib = codegen::CreateMetadataModule(params_, lib, ext_mods, config_->host_target, - Runtime::Create("cpp"), runtime::Metadata()); + Runtime::Create("cpp"), + relay::backend::ExecutorCodegenMetadata()); exec_->SetLib(lib); } diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 2d6f75ae39487..b710c2791acf4 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -166,7 +166,7 @@ Expr ExprMutator::VisitExpr_(const VarNode* var_node) { if (var_node->type_annotation.defined()) { type_annotation = this->VisitType(var_node->type_annotation); } - return WithFields(GetRef(var_node), std::move(var_node->vid), std::move(type_annotation)); + return WithFields(GetRef(var_node), var_node->vid, type_annotation); } Expr ExprMutator::VisitExpr_(const ConstantNode* op) { return GetRef(op); } @@ -183,7 +183,7 @@ Expr ExprMutator::VisitExpr_(const TupleNode* tuple_node) { auto new_field = this->Mutate(field); fields.push_back(new_field); } - return WithFields(GetRef(tuple_node), std::move(fields)); + return WithFields(GetRef(tuple_node), fields); } Expr ExprMutator::VisitExpr_(const FunctionNode* func_node) { @@ -203,8 +203,7 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* func_node) { auto ret_type = this->VisitType(func_node->ret_type); auto body = this->Mutate(func_node->body); - return WithFields(GetRef(func_node), std::move(params), std::move(body), - std::move(ret_type), std::move(ty_params)); + return WithFields(GetRef(func_node), params, body, ret_type, ty_params); } Expr ExprMutator::VisitExpr_(const CallNode* call_node) { @@ -225,8 +224,7 @@ Expr ExprMutator::VisitExpr_(const CallNode* call_node) { call_args.push_back(new_arg); } - return WithFields(GetRef(call_node), std::move(new_op), std::move(call_args), {}, - std::move(ty_args)); + return WithFields(GetRef(call_node), new_op, call_args, {}, ty_args); } Expr ExprMutator::VisitExpr_(const LetNode* let_node) { @@ -234,7 +232,7 @@ Expr ExprMutator::VisitExpr_(const LetNode* let_node) { auto value = this->Mutate(let_node->value); auto body = this->Mutate(let_node->body); - return WithFields(GetRef(let_node), std::move(var), std::move(value), std::move(body)); + return WithFields(GetRef(let_node), var, value, body); } Expr ExprMutator::VisitExpr_(const IfNode* if_node) { @@ -242,28 +240,28 @@ Expr ExprMutator::VisitExpr_(const IfNode* if_node) { auto true_b = this->Mutate(if_node->true_branch); auto false_b = this->Mutate(if_node->false_branch); - return WithFields(GetRef(if_node), std::move(cond), std::move(true_b), std::move(false_b)); + return WithFields(GetRef(if_node), cond, true_b, false_b); } Expr ExprMutator::VisitExpr_(const TupleGetItemNode* get_item) { Expr tuple = this->Mutate(get_item->tuple); - return WithFields(GetRef(get_item), std::move(tuple)); + return WithFields(GetRef(get_item), tuple); } Expr ExprMutator::VisitExpr_(const RefCreateNode* ref_create) { Expr value = this->Mutate(ref_create->value); - return WithFields(GetRef(ref_create), std::move(value)); + return WithFields(GetRef(ref_create), value); } Expr ExprMutator::VisitExpr_(const RefReadNode* ref_read) { Expr ref = this->Mutate(ref_read->ref); - return WithFields(GetRef(ref_read), std::move(ref)); + return WithFields(GetRef(ref_read), ref); } Expr ExprMutator::VisitExpr_(const RefWriteNode* ref_write) { Expr ref = this->Mutate(ref_write->ref); Expr value = this->Mutate(ref_write->value); - return WithFields(GetRef(ref_write), std::move(ref), std::move(value)); + return WithFields(GetRef(ref_write), ref, value); } Expr ExprMutator::VisitExpr_(const ConstructorNode* c) { return GetRef(c); } @@ -275,13 +273,13 @@ Expr ExprMutator::VisitExpr_(const MatchNode* match_node) { } Expr data = Mutate(match_node->data); - return WithFields(GetRef(match_node), std::move(data), std::move(clauses)); + return WithFields(GetRef(match_node), data, clauses); } Clause ExprMutator::VisitClause(const Clause& clause) { Pattern lhs = VisitPattern(clause->lhs); Expr rhs = Mutate(clause->rhs); - return WithFields(std::move(clause), std::move(lhs), std::move(rhs)); + return WithFields(clause, lhs, rhs); } Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; } @@ -462,7 +460,7 @@ class ExprBinder : public MixedModeMutator, PatternMutator { Clause VisitClause(const Clause& clause) final { Pattern lhs = VisitPattern(clause->lhs); - return WithFields(std::move(clause), std::move(lhs), VisitExpr(clause->rhs)); + return WithFields(clause, lhs, VisitExpr(clause->rhs)); } Var VisitVar(const Var& v) final { diff --git a/src/relay/ir/indexed_graph.cc b/src/relay/ir/indexed_graph.cc index df648abba57c7..4efe57b491db0 100644 --- a/src/relay/ir/indexed_graph.cc +++ b/src/relay/ir/indexed_graph.cc @@ -45,6 +45,8 @@ IndexedGraph CreateIndexedGraph(const Expr& expr) { } protected: + using MixedModeVisitor::VisitExpr_; + void VisitLeaf(const Expr& expr) override { MixedModeVisitor::VisitLeaf(expr); auto node = std::make_shared::Node>(expr, index_++); diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 92164481807ae..f1d4eb3d87ea6 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -579,5 +579,94 @@ TVM_REGISTER_GLOBAL("relay.op.nn._make.deformable_conv2d") kernel_size, data_layout, kernel_layout, out_layout, out_dtype, "nn.deformable_conv2d"); }); +inline Expr MakeConv2dBackwardWeight(Expr grad, Expr data, Array strides, + Array padding, Array dilation, + int groups, IndexExpr channels, Array kernel_size, + std::string grad_layout, std::string data_layout, + std::string kernel_layout, DataType out_dtype) { + auto attrs = make_object(); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->dilation = std::move(dilation); + attrs->groups = groups; + attrs->channels = std::move(channels); + attrs->kernel_size = std::move(kernel_size); + attrs->out_dtype = std::move(out_dtype); + attrs->data_layout = std::move(grad_layout); + attrs->kernel_layout = std::move(data_layout); + attrs->out_layout = std::move(kernel_layout); + const Op& op = Op::Get("nn.conv2d_backward_weight"); + return Call(op, {grad, data}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d_backward_weight") + .set_body_typed([](Expr grad, Expr data, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, String grad_layout, String data_layout, + String kernel_layout, DataType out_dtype) { + return MakeConv2dBackwardWeight(grad, data, strides, padding, dilation, groups, channels, + kernel_size, grad_layout, data_layout, kernel_layout, + out_dtype); + }); + +bool Conv2DBackwardWeightRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 3); + const auto* grad = types[0].as(); + const auto* data = types[1].as(); + if (data == nullptr) return false; + + static const Layout kNCHW("NCHW"); + static const Layout kOIHW("OIHW"); + + const auto* param = attrs.as(); + ICHECK(param != nullptr); + // Require kernel_size to be passed, to simplify the output shape determination. + ICHECK(param->kernel_size.defined()) << "kernel_size attribute needs to be specified"; + + // We repurpose Conv2dAttrs for Conv2DBackwardWeight, note the meanings of layouts. + const Layout grad_layout(param->data_layout); + const Layout in_layout(param->kernel_layout); + const Layout kernel_layout(param->out_layout); + + const auto trans_grad_layout = tir::BijectiveLayout(grad_layout, kNCHW); + const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW); + const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW); + + Array dshape_nchw = trans_in_layout.ForwardShape(data->shape); + Array grad_shape_nchw = trans_grad_layout.ForwardShape(grad->shape); + + auto in_channels = dshape_nchw[1]; + auto out_channels = grad_shape_nchw[1]; + + Array wshape_oihw( + {out_channels, in_channels, param->kernel_size[0], param->kernel_size[1]}); + + auto wshape = trans_kernel_layout.BackwardShape(wshape_oihw); + reporter->Assign(types[2], TensorType(wshape, data->dtype)); + return true; +} + +RELAY_REGISTER_OP("nn.conv2d_backward_weight") + .describe(R"code(The gradient of the 2D convolution layer with respect to the weight. + +This layer computes the gradient of the conv2d op with respect to weight, +given the original input data and the output gradient. + +- **grad**: (batch, channels, out_height, out_width) if `layout` is `NCHW`. +- **data**: This depends on the `layout` parameter. Input is 4D array of shape + (batch_size, in_channels, height, width) if `layout` is `NCHW`. +- **out**: This depends on the `layout` parameter. Output is 4D array of shape + (channels, in_channels, kernel_size[0], kernel_size[1]) if `layout` is `NCHW`. +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(2) + .add_argument("grad", "Tensor", "The gradient tensor.") + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("Conv2DBackwardWeight", Conv2DBackwardWeightRel) + .set_attr("TNonComputational", true) + .set_attr("FInferCorrectLayout", ConvInferCorrectLayout); + } // namespace relay } // namespace tvm diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index ecdd36ddb7914..03fa770e404f3 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -140,31 +140,37 @@ WorkloadType GetWorkload(const Array& arg_types, const Conv2DA int out_channels, kernel_h, kernel_w; int channel_multiplier = -1; bool depthwise = is_depthwise(param); - if (param->kernel_layout == "OIHW") { - out_channels = get_const_int(kernel_shape[0]); - kernel_h = get_const_int(kernel_shape[2]); - kernel_w = get_const_int(kernel_shape[3]); - if (depthwise) { - channel_multiplier = get_const_int(kernel_shape[1]); - } - } else if (param->kernel_layout == "HWIO") { - kernel_h = get_const_int(kernel_shape[0]); - kernel_w = get_const_int(kernel_shape[1]); - out_channels = get_const_int(kernel_shape[3]); - if (depthwise) { - channel_multiplier = get_const_int(kernel_shape[2]); - } + int index_k = 0; + int index_h = 2; + int index_w = 3; + int index_c = 1; + if (param->kernel_layout == "HWIO") { + index_k = 3; + index_h = 0; + index_w = 1; + index_c = 2; } else if (param->kernel_layout == "HWOI") { - kernel_h = get_const_int(kernel_shape[0]); - kernel_w = get_const_int(kernel_shape[1]); - out_channels = get_const_int(kernel_shape[2]); - if (depthwise) { - channel_multiplier = get_const_int(kernel_shape[3]); - } - } else { + index_k = 2; + index_h = 0; + index_w = 1; + index_c = 3; + } else if (param->kernel_layout == "OHWI") { + index_k = 0; + index_h = 1; + index_w = 2; + index_c = 3; + } else if (param->kernel_layout != "OIHW") { LOG(FATAL) << "qnn.conv2d does not support " << param->kernel_layout << " layout"; } + kernel_h = get_const_int(kernel_shape[index_h]); + kernel_w = get_const_int(kernel_shape[index_w]); + out_channels = get_const_int(kernel_shape[index_k]); + + if (depthwise) { + channel_multiplier = get_const_int(kernel_shape[index_c]); + } + return std::make_tuple(batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier); } @@ -519,6 +525,8 @@ Expr Conv2DThirdTerm(const Expr& weight, const Expr& input_zero_point, const Con axes_t3 = {0, 1, 2}; } else if (param->kernel_layout == "HWOI") { axes_t3 = {0, 1, 3}; + } else if (param->kernel_layout == "OHWI") { + axes_t3 = {1, 2, 3}; } else { LOG(FATAL) << "qnn.conv2d does not support " << param->kernel_layout << " layout"; } @@ -701,8 +709,8 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array& new_args, ICHECK(param->data_layout == "NCHW" || param->data_layout == "NHWC") << "qnn.conv2d supports only NCHW/NHWC input data layout."; ICHECK(param->kernel_layout == "OIHW" || param->kernel_layout == "HWIO" || - param->kernel_layout == "HWOI") - << "qnn.conv2d supports only OIHW/HWIO/HWOI kernel data layout."; + param->kernel_layout == "HWOI" || param->kernel_layout == "OHWI") + << "qnn.conv2d supports only OIHW/HWIO/HWOI/OHWI kernel data layout."; ICHECK(param->kernel_size.defined()) << "qnn.conv2d requires kernel size to be specified."; int batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier; diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index e731d2bb658ec..3f1985b7ddfa5 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -270,7 +270,7 @@ class AnnotateTargetRewriter : public ExprRewriter { auto tuple = Downcast(post); auto target_n_args = AnnotateArgs(tuple->fields); - auto new_expr = WithFields(std::move(tuple), std::move(std::get<1>(target_n_args))); + auto new_expr = WithFields(tuple, std::get<1>(target_n_args)); op_expr_to_target_[new_expr] = std::get<0>(target_n_args); return std::move(new_expr); } @@ -378,7 +378,7 @@ class CallOpsTargetRewriter : public AnnotateTargetRewriter { for (auto f : tuple->fields) { new_fields.push_back(InsertCompilerEndAndPropogateTarget(f)); } - return WithFields(std::move(tuple), std::move(new_fields)); + return WithFields(tuple, new_fields); } Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) override { diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc index d40dd6c950891..c7ca2227fd908 100644 --- a/src/relay/transforms/device_planner.cc +++ b/src/relay/transforms/device_planner.cc @@ -334,7 +334,7 @@ class RewriteOnDevices : public ExprMutator { Expr tuple = VisitExpr(tuple_get_item_node->tuple); OnDeviceProps props = GetOnDeviceProps(tuple); - Expr tuple_get_item = WithFields(GetRef(tuple_get_item_node), std::move(tuple)); + Expr tuple_get_item = WithFields(GetRef(tuple_get_item_node), tuple); if (props.body.defined() && props.is_normal()) { VLOG(2) << "wrapping tuple get item:" << std::endl << PrettyPrint(GetRef(tuple_get_item_node)) << std::endl @@ -363,8 +363,8 @@ class RewriteOnDevices : public ExprMutator { } expr = VisitExpr(expr); for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { - expr = WithFields(/*let=*/std::move(std::get<0>(*itr)), /*opt_var=*/{}, - /*opt_value=*/std::move(std::get<1>(*itr)), /*opt_body=*/std::move(expr)); + expr = WithFields(/*let=*/std::get<0>(*itr), /*opt_var=*/{}, + /*opt_value=*/std::get<1>(*itr), /*opt_body=*/expr); } return expr; } @@ -378,7 +378,7 @@ class RewriteOnDevices : public ExprMutator { << "to be fixed to VirtualDevice " << props.virtual_device; body = MaybeOnDeviceFixed(props.body, props.virtual_device); } - return WithFields(GetRef(function_node), function_node->params, std::move(body)); + return WithFields(GetRef(function_node), function_node->params, body); } Expr VisitExpr_(const CallNode* call_node) final { @@ -990,7 +990,7 @@ class DeviceCapturer : public ExprMutator { for (const auto& field : tuple_node->fields) { fields.push_back(VisitChild(tuple, field)); } - return WithFields(std::move(tuple), std::move(fields)); + return WithFields(tuple, fields); } Expr VisitExpr_(const FunctionNode* function_node) final { @@ -1025,8 +1025,7 @@ class DeviceCapturer : public ExprMutator { /*expected_virtual_device=*/result_virtual_device, /*child_virtual_device=*/GetVirtualDevice(function_node->body), function_node->body); - Function func = WithFields(GetRef(function_node), std::move(function_node->params), - std::move(body)); + Function func = WithFields(GetRef(function_node), function_node->params, body); return FunctionOnDevice(func, std::move(param_virtual_devices), std::move(result_virtual_device)); } @@ -1102,9 +1101,9 @@ class DeviceCapturer : public ExprMutator { if (call_node->op == CallLoweredOp()) { Call new_call = CallLowered(Downcast(op), args, /*call_lowered_attrs=*/{}, /*span=*/{}); - return WithFields(call, std::move(new_call->op), std::move(new_call->args)); + return WithFields(call, new_call->op, new_call->args); } else { - return WithFields(call, std::move(op), std::move(args)); + return WithFields(call, op, args); } } @@ -1145,33 +1144,32 @@ class DeviceCapturer : public ExprMutator { Expr cond = VisitChild(ife, if_node->cond); Expr true_branch = VisitChild(ife, if_node->true_branch); Expr false_branch = VisitChild(ife, if_node->false_branch); - return WithFields(std::move(ife), std::move(cond), std::move(true_branch), - std::move(false_branch)); + return WithFields(ife, cond, true_branch, false_branch); } Expr VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final { auto tuple_get_item = GetRef(tuple_get_item_node); Expr tuple = VisitChild(tuple_get_item, tuple_get_item_node->tuple); - return WithFields(std::move(tuple_get_item), std::move(tuple)); + return WithFields(tuple_get_item, tuple); } Expr VisitExpr_(const RefCreateNode* ref_create_node) final { auto ref_create = GetRef(ref_create_node); Expr value = VisitChild(ref_create, ref_create_node->value); - return WithFields(std::move(ref_create), std::move(value)); + return WithFields(ref_create, value); } Expr VisitExpr_(const RefReadNode* ref_read_node) final { auto ref_read = GetRef(ref_read_node); Expr ref = VisitChild(ref_read, ref_read_node->ref); - return WithFields(std::move(ref_read), std::move(ref)); + return WithFields(ref_read, ref); } Expr VisitExpr_(const RefWriteNode* ref_write_node) final { auto ref_write = GetRef(ref_write_node); Expr ref = VisitChild(ref_write, ref_write_node->ref); Expr value = VisitChild(ref_write, ref_write_node->value); - return WithFields(std::move(ref_write), std::move(ref), std::move(value)); + return WithFields(ref_write, ref, value); } Expr VisitExpr_(const MatchNode* match_node) final { @@ -1184,7 +1182,7 @@ class DeviceCapturer : public ExprMutator { Expr rhs = VisitChild(match, clause->rhs); clauses.push_back(Clause(lhs, rhs)); } - return WithFields(std::move(match), std::move(data), std::move(clauses)); + return WithFields(match, data, clauses); } VirtualDevice GetVirtualDevice(const Expr& expr) { diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index f3c53cfc8bc0a..bafdbd3591414 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ b/src/relay/transforms/dynamic_to_static.cc @@ -240,6 +240,14 @@ class DynamicToStaticMutator : public MixedModeMutator { gv_ = vars[func_]; } + Expr GetCurExpr(const Expr& original_expr) { + if (original_expr.as()) { + return mod_->Lookup(gv_); + } else { + return mod_->Lookup(gv_).as()->body; + } + } + Expr PrepareInput(const Expr& expr) { BaseFunc func; if (auto* func_node = expr.as()) { @@ -249,10 +257,12 @@ class DynamicToStaticMutator : public MixedModeMutator { relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod_), {}); } mod_->Update(gv_, func); + mod_ = transform::FoldConstant()(mod_); - mod_ = transform::InferType()(mod_); + transform::InferTypeLocal(GetCurExpr(expr)); mod_ = transform::FoldConstant()(mod_); - mod_ = transform::InferType()(mod_); + transform::InferTypeLocal(GetCurExpr(expr)); + Expr out; if (expr.as()) { out = mod_->Lookup(gv_); diff --git a/src/relay/transforms/first_order_gradient.cc b/src/relay/transforms/first_order_gradient.cc index d41d7dbd631fe..f530d61e0d999 100644 --- a/src/relay/transforms/first_order_gradient.cc +++ b/src/relay/transforms/first_order_gradient.cc @@ -211,7 +211,7 @@ struct FirstOrderReverseAD : ExprFunctor { field_bindings.push_back(f_ad->get().forward); } // reconstruct tuple using let-bound variables to avoid duplication - auto orig = WithFields(GetRef(tuple_node), std::move(field_bindings)); + auto orig = WithFields(GetRef(tuple_node), field_bindings); orig->checked_type_ = tt; auto ret = std::make_shared(ll, orig, diag_ctx); // for orig = tuple(x1, ..., xn), tuple_grad(x1, ..., xn, G) = [pi(G, 1), ..., pi(G, n)] diff --git a/src/relay/transforms/forward_rewrite.cc b/src/relay/transforms/forward_rewrite.cc index 23c45a90a5e31..0e7e9076ae0ee 100644 --- a/src/relay/transforms/forward_rewrite.cc +++ b/src/relay/transforms/forward_rewrite.cc @@ -122,7 +122,7 @@ class ForwardRewriter : private MixedModeMutator { fields.push_back(this->GetTempExpr(tuple_node->fields[i], post_tuple_node->fields[i])); } - return WithFields(GetRef(tuple_node), std::move(fields)); + return WithFields(GetRef(tuple_node), fields); } Expr Rewrite_(const CallNode* call_node, const Expr& post) final { diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index f2fc0af4f9c16..5037b32ce615e 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -905,7 +905,7 @@ class FuseMutator : private MixedModeMutator { } // This tuple is an intermediate node in the group Array new_fields = GetNewArguments(tuple_node->fields, ret_group); - return WithFields(GetRef(tuple_node), std::move(new_fields)); + return WithFields(GetRef(tuple_node), new_fields); } Expr Rewrite_(const TupleGetItemNode* tuple_get, const Expr& post) { diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc index d1f0f69c5e932..900442e9b9a8d 100644 --- a/src/relay/transforms/memory_alloc.cc +++ b/src/relay/transforms/memory_alloc.cc @@ -92,7 +92,7 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { } new_fields.push_back(new_field); } - return WithFields(GetRef(tuple_node), std::move(new_fields)); + return WithFields(GetRef(tuple_node), new_fields); } void PreVisitLetBlock_(const LetNode* let_node) final { scopes_.emplace_back(); } diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index 8fe67c39f1681..bc1ed518d4736 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -469,8 +469,7 @@ IRModule FlattenTupleOutputs(IRModule module) { // Return a tuple of compiler_ends in the place of the tuple that was // annotated with a compiler_end. - auto func = WithFields(GetRef(tuple_node), new_fields); - return func; + return WithFields(GetRef(tuple_node), new_fields); } } return post; diff --git a/src/relay/transforms/split_args.cc b/src/relay/transforms/split_args.cc index fbb2d73d1db09..a5266df8b057e 100644 --- a/src/relay/transforms/split_args.cc +++ b/src/relay/transforms/split_args.cc @@ -59,12 +59,12 @@ class ArgumentSplitter : public ExprRewriter { for (int j = 0; j < argsCount; ++j) { args.push_back(tuple_node->fields[j + startIdx]); } - Tuple new_tuple = WithFields(GetRef(tuple_node), std::move(args)); + Tuple new_tuple = WithFields(GetRef(tuple_node), args); Expr body = MakeConcatenate(new_tuple, param->axis); splitted[i] = StopFusion(body); } tvm::Array tuple_args(splitted); - Tuple new_tuple = WithFields(GetRef(tuple_node), std::move(tuple_args)); + Tuple new_tuple = WithFields(GetRef(tuple_node), tuple_args); return MakeConcatenate(new_tuple, param->axis); } return post; diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index 5e32d290d2860..a0841ec44faec 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -255,7 +255,7 @@ class Fill : ExprFunctor, private transform::Lexi for (const auto& a : tuple_node->fields) { fields.push_back(VisitExpr(a)); } - return Compound(e, WithFields(GetRef(tuple_node), std::move(fields)), v); + return Compound(e, WithFields(GetRef(tuple_node), fields), v); } Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final { diff --git a/src/relay/transforms/to_cps.cc b/src/relay/transforms/to_cps.cc index 963f365d978ec..6d8fe67847f64 100644 --- a/src/relay/transforms/to_cps.cc +++ b/src/relay/transforms/to_cps.cc @@ -216,7 +216,7 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm, VarMap* vm, std::function next; next = [&]() { return (fields.size() == tuple_node->fields.size()) - ? k(WithFields(GetRef(tuple_node), std::move(fields))) + ? k(WithFields(GetRef(tuple_node), fields)) : VisitExpr(tuple_node->fields[fields.size()], [&](const Expr& v) { fields.push_back(v); return next(); diff --git a/src/relay/transforms/transform_layout.h b/src/relay/transforms/transform_layout.h index 56affb581fd11..3dbf10e0611b9 100644 --- a/src/relay/transforms/transform_layout.h +++ b/src/relay/transforms/transform_layout.h @@ -300,7 +300,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array& new_args, const Obj Expr tmp = push_back_one_arg(x); fields.push_back(tmp); } - normal_new_args.push_back(WithFields(tuple_new_arg, std::move(fields))); + normal_new_args.push_back(WithFields(tuple_new_arg, fields)); } else { Expr tmp = push_back_one_arg(new_arg); normal_new_args.push_back(tmp); @@ -383,7 +383,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array& new_args, const Obj transformed_tuple_arg.push_back(memorizer.Transform(arg_item, new_in[pt], new_in2[pt])); pt++; } - transformed_args.push_back(WithFields(tuple_arg, std::move(transformed_tuple_arg))); + transformed_args.push_back(WithFields(tuple_arg, transformed_tuple_arg)); } else { transformed_args.push_back(memorizer.Transform(arg, new_in[pt], new_in2[pt])); pt++; diff --git a/src/runtime/contrib/cudnn/conv_forward.cc b/src/runtime/contrib/cudnn/conv_forward.cc index 97704986792d7..b7476e5106fa7 100644 --- a/src/runtime/contrib/cudnn/conv_forward.cc +++ b/src/runtime/contrib/cudnn/conv_forward.cc @@ -37,89 +37,12 @@ void ConvolutionForward(int mode, int format, int algo, int dims, int groups, co CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); // Set Mode entry_ptr->conv_entry.mode = static_cast(mode); - // Set Format - entry_ptr->conv_entry.tensor_format = static_cast(format); - // Set Algo - entry_ptr->conv_entry.fwd_algo = static_cast(algo); + SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x->shape, w->shape, + y->shape, x->dtype, conv_dtype); // Set Device entry_ptr->conv_entry.device = x->device; - // Set Data Type - entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2DLDataType(conv_dtype)); - cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype); - // Dims includes N and C - int full_dims = dims + 2; - - std::vector dim(full_dims); - std::vector tensor_stride(full_dims); - - // Note: For 2D tenor, using ND setters causes CUDNN_STATUS_NOT_SUPPORTED error - // in following cudnnGetConvolutionForwardWorkspaceSize() when data type is fp16, int - - CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups)); - if (dims == 2) { - // Set Desc - CUDNN_CALL(cudnnSetConvolution2dDescriptor( - entry_ptr->conv_entry.conv_desc, pad[0], pad[1], stride[0], stride[1], dilation[0], - dilation[1], entry_ptr->conv_entry.mode, entry_ptr->conv_entry.data_type)); - int ni, ci, hi, wi; - if (entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) { - ni = 0; - ci = 3; - hi = 1; - wi = 2; - } else { - ni = 0; - ci = 1; - hi = 2; - wi = 3; - } - - // Set Filter - CUDNN_CALL(cudnnSetFilter4dDescriptor( - entry_ptr->conv_entry.filter_desc, data_type, entry_ptr->conv_entry.tensor_format, - static_cast(w->shape[ni]), static_cast(w->shape[ci]), - static_cast(w->shape[hi]), static_cast(w->shape[wi]))); - // Set Input - CUDNN_CALL(cudnnSetTensor4dDescriptor( - entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.tensor_format, data_type, - static_cast(x->shape[ni]), static_cast(x->shape[ci]), - static_cast(x->shape[hi]), static_cast(x->shape[wi]))); - // Set Output - CUDNN_CALL(cudnnSetTensor4dDescriptor( - entry_ptr->conv_entry.output_desc, entry_ptr->conv_entry.tensor_format, data_type, - static_cast(y->shape[ni]), static_cast(y->shape[ci]), - static_cast(y->shape[hi]), static_cast(y->shape[wi]))); - } else { - CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride, - dilation, entry_ptr->conv_entry.mode, - entry_ptr->conv_entry.data_type)); - - // Set Filter - for (int i = 0; i < full_dims; i++) { - dim[i] = static_cast(w->shape[i]); - } - CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type, - entry_ptr->conv_entry.tensor_format, full_dims, - dim.data())); - // Set Input - for (int i = 0; i < full_dims; i++) { - dim[i] = static_cast(x->shape[i]); - } - GetCudnnStride(full_dims, dim.data(), tensor_stride.data()); - CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims, - dim.data(), tensor_stride.data())); - // Set Output - for (int i = 0; i < full_dims; i++) { - dim[i] = static_cast(y->shape[i]); - } - GetCudnnStride(full_dims, dim.data(), tensor_stride.data()); - CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, data_type, full_dims, - dim.data(), tensor_stride.data())); - } - - if (cudnnGetVersion() > 7000) { - CUDNN_CALL(cudnnSetConvolutionMathType(entry_ptr->conv_entry.conv_desc, CUDNN_TENSOR_OP_MATH)) - } + // Set Algo + entry_ptr->conv_entry.fwd_algo = static_cast(algo); // Set workspace size_t workspace_size = 0; @@ -137,125 +60,22 @@ void ConvolutionForward(int mode, int format, int algo, int dims, int groups, co entry_ptr->conv_entry.output_desc, y->data)); } -void OutputShape(int format, int dims, int groups, const int pad[], const int stride[], - const int dilation[], const int x_dim[], const int w_dim[], void* out_shape, - const std::string& data_dtype, const std::string& conv_dtype) { - CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); - - // Set Data Type - entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2DLDataType(conv_dtype)); - cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(String2DLDataType(data_dtype)); - // Set Format - entry_ptr->conv_entry.tensor_format = static_cast(format); - // Dims includes N and C - int full_dims = dims + 2; - - // conv desc - CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups)); - CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride, - dilation, CUDNN_CROSS_CORRELATION, - entry_ptr->conv_entry.data_type)); - - if (entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) { - ICHECK_EQ(full_dims, 4) << "Use of layout CUDNN_TENSOR_NHWC is only supported for 4d tensors"; - - // Set Input - CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.tensor_format, data_type, x_dim[0], - x_dim[3], x_dim[1], x_dim[2])); - - // filter desc - CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc, data_type, - entry_ptr->conv_entry.tensor_format, w_dim[0], w_dim[3], - w_dim[1], w_dim[2])); - - CUDNN_CALL(cudnnGetConvolution2dForwardOutputDim( - entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.filter_desc, static_cast(out_shape), - static_cast(out_shape) + 3, static_cast(out_shape) + 1, - static_cast(out_shape) + 2)); - } else { - // Set Input - std::vector tensor_stride(full_dims); - GetCudnnStride(full_dims, x_dim, tensor_stride.data()); - - CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims, - x_dim, tensor_stride.data())); - // filter desc - CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type, - entry_ptr->conv_entry.tensor_format, full_dims, w_dim)); - - CUDNN_CALL(cudnnGetConvolutionNdForwardOutputDim( - entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.filter_desc, full_dims, static_cast(out_shape))); - } -} - void FindAlgo(int format, int dims, int groups, const int pad[], const int stride[], const int dilation[], const int x_dim[], const int w_dim[], const int y_dim[], const std::string& data_dtype, const std::string& conv_dtype, TVMRetValue* ret) { CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); - - // Set Data Type - entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2DLDataType(conv_dtype)); - cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(String2DLDataType(data_dtype)); - // Set Format - entry_ptr->conv_entry.tensor_format = static_cast(format); - // Dims includes N and C - int full_dims = dims + 2; - - // conv desc - CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups)); - - if (format == 1) { - ICHECK_EQ(full_dims, 4) << "Use of layout CUDNN_TENSOR_NHWC is only supported for 4d tensors"; - int ni = 0; - int ci = 3; - int hi = 1; - int wi = 2; - - // Set Input - CUDNN_CALL(cudnnSetTensor4dDescriptor( - entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.tensor_format, data_type, - static_cast(x_dim[ni]), static_cast(x_dim[ci]), static_cast(x_dim[hi]), - static_cast(x_dim[wi]))); - - CUDNN_CALL(cudnnSetFilter4dDescriptor( - entry_ptr->conv_entry.filter_desc, data_type, entry_ptr->conv_entry.tensor_format, - static_cast(w_dim[ni]), static_cast(w_dim[ci]), static_cast(w_dim[hi]), - static_cast(w_dim[wi]))); - // Set Output - CUDNN_CALL(cudnnSetTensor4dDescriptor( - entry_ptr->conv_entry.output_desc, entry_ptr->conv_entry.tensor_format, data_type, - static_cast(y_dim[ni]), static_cast(y_dim[ci]), static_cast(y_dim[hi]), - static_cast(y_dim[wi]))); - - CUDNN_CALL(cudnnSetConvolution2dDescriptor( - entry_ptr->conv_entry.conv_desc, pad[0], pad[1], stride[0], stride[1], dilation[0], - dilation[1], entry_ptr->conv_entry.mode, entry_ptr->conv_entry.data_type)); - } else { - CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride, - dilation, CUDNN_CROSS_CORRELATION, - entry_ptr->conv_entry.data_type)); - - std::vector tensor_stride(full_dims); - // input desc - GetCudnnStride(full_dims, x_dim, tensor_stride.data()); - CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims, - x_dim, tensor_stride.data())); - // filter desc - CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type, - entry_ptr->conv_entry.tensor_format, full_dims, w_dim)); - - // output desc - GetCudnnStride(full_dims, y_dim, tensor_stride.data()); - CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, data_type, full_dims, - y_dim, tensor_stride.data())); - } - - if (cudnnGetVersion() > 7000) { - CUDNN_CALL(cudnnSetConvolutionMathType(entry_ptr->conv_entry.conv_desc, CUDNN_TENSOR_OP_MATH)) + const int full_dims = dims + 2; + std::vector x_dim_int64(full_dims); + std::vector w_dim_int64(full_dims); + std::vector y_dim_int64(full_dims); + for (int i = 0; i < full_dims; ++i) { + x_dim_int64[i] = x_dim[i]; + w_dim_int64[i] = w_dim[i]; + y_dim_int64[i] = y_dim[i]; } + SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x_dim_int64.data(), + w_dim_int64.data(), y_dim_int64.data(), String2DLDataType(data_dtype), + conv_dtype); int returned_algo_count = 0; cudnnConvolutionFwdAlgoPerf_t perf_results[CUDNN_CONVOLUTION_FWD_ALGO_COUNT]; @@ -327,24 +147,6 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward") conv_dtype); }); -TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.output_shape_from_cudnn") - .set_body([](TVMArgs args, TVMRetValue* ret) { - int format = args[0]; - int dims = args[1]; - int* pad = static_cast(static_cast(args[2])); - int* stride = static_cast(static_cast(args[3])); - int* dilation = static_cast(static_cast(args[4])); - int* x_dim = static_cast(static_cast(args[5])); - int* w_dim = static_cast(static_cast(args[6])); - void* out_shape = args[7]; - std::string data_dtype = args[8]; - std::string conv_dtype = args[9]; - int groups = args[10]; - - OutputShape(format, dims, groups, pad, stride, dilation, x_dim, w_dim, out_shape, data_dtype, - conv_dtype); - }); - TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.find_algo") .set_body([](TVMArgs args, TVMRetValue* ret) { int format = args[0]; diff --git a/src/runtime/contrib/cudnn/cudnn_utils.cc b/src/runtime/contrib/cudnn/cudnn_utils.cc index 297cd9e7a3610..e39c47339c7fb 100644 --- a/src/runtime/contrib/cudnn/cudnn_utils.cc +++ b/src/runtime/contrib/cudnn/cudnn_utils.cc @@ -20,11 +20,16 @@ /*! * \file Use external cudnn utils function */ + #include "cudnn_utils.h" #include +#include #include +#include +#include + namespace tvm { namespace contrib { @@ -160,6 +165,96 @@ void ConvEntry::CleanWorkspace() { workspace_size = 0; } +void SetConvDescriptors(CuDNNThreadEntry* entry_ptr, int format, int dims, int groups, + const int pad[], const int stride[], const int dilation[], int64_t x_dim[], + int64_t w_dim[], int64_t y_dim[], DLDataType data_dtype, + const std::string& conv_dtype) { + // Set Format + entry_ptr->conv_entry.tensor_format = static_cast(format); + // Set Data Type + entry_ptr->conv_entry.data_type = + CuDNNDataType::DLTypeToCuDNNType(runtime::String2DLDataType(conv_dtype)); + + cudnnDataType_t cudnn_data_type = CuDNNDataType::DLTypeToCuDNNType(data_dtype); + + // Dims includes N and C + int full_dims = dims + 2; + + std::vector dim(full_dims); + std::vector tensor_stride(full_dims); + + // Note: For 2D tenor, using ND setters causes CUDNN_STATUS_NOT_SUPPORTED error + // in following cudnnGetConvolutionForwardWorkspaceSize() when data type is fp16, int + + CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups)); + if (dims == 2) { + // Set Desc + CUDNN_CALL(cudnnSetConvolution2dDescriptor( + entry_ptr->conv_entry.conv_desc, pad[0], pad[1], stride[0], stride[1], dilation[0], + dilation[1], entry_ptr->conv_entry.mode, entry_ptr->conv_entry.data_type)); + int ni, ci, hi, wi; + if (entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) { + ni = 0; + ci = 3; + hi = 1; + wi = 2; + } else { + ni = 0; + ci = 1; + hi = 2; + wi = 3; + } + + // Set Input + CUDNN_CALL(cudnnSetTensor4dDescriptor( + entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.tensor_format, cudnn_data_type, + static_cast(x_dim[ni]), static_cast(x_dim[ci]), static_cast(x_dim[hi]), + static_cast(x_dim[wi]))); + // Set Filter + CUDNN_CALL(cudnnSetFilter4dDescriptor( + entry_ptr->conv_entry.filter_desc, cudnn_data_type, entry_ptr->conv_entry.tensor_format, + static_cast(w_dim[ni]), static_cast(w_dim[ci]), static_cast(w_dim[hi]), + static_cast(w_dim[wi]))); + // Set Output + CUDNN_CALL(cudnnSetTensor4dDescriptor( + entry_ptr->conv_entry.output_desc, entry_ptr->conv_entry.tensor_format, cudnn_data_type, + static_cast(y_dim[ni]), static_cast(y_dim[ci]), static_cast(y_dim[hi]), + static_cast(y_dim[wi]))); + } else { + ICHECK_EQ(format, 0) << "Use of layout CUDNN_TENSOR_NHWC is supported only for 4-D tensors."; + + CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride, + dilation, entry_ptr->conv_entry.mode, + entry_ptr->conv_entry.data_type)); + + // Set Filter + for (int i = 0; i < full_dims; i++) { + dim[i] = static_cast(w_dim[i]); + } + CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, cudnn_data_type, + entry_ptr->conv_entry.tensor_format, full_dims, + dim.data())); + // Set Input + for (int i = 0; i < full_dims; i++) { + dim[i] = static_cast(x_dim[i]); + } + GetCudnnStride(full_dims, dim.data(), tensor_stride.data()); + CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, cudnn_data_type, + full_dims, dim.data(), tensor_stride.data())); + // Set Output + for (int i = 0; i < full_dims; i++) { + dim[i] = static_cast(y_dim[i]); + } + GetCudnnStride(full_dims, dim.data(), tensor_stride.data()); + CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, cudnn_data_type, + full_dims, dim.data(), tensor_stride.data())); + } + + if (cudnnGetVersion() > 7000) { + CUDNN_CALL(cudnnSetConvolutionMathType(entry_ptr->conv_entry.conv_desc, CUDNN_TENSOR_OP_MATH)) + } +} + // SoftmaxEntry SoftmaxEntry::SoftmaxEntry() { CUDNN_CALL(cudnnCreateTensorDescriptor(&shape_desc)); } diff --git a/src/runtime/contrib/cudnn/cudnn_utils.h b/src/runtime/contrib/cudnn/cudnn_utils.h index 01b92d61e66e4..89de0e90df907 100644 --- a/src/runtime/contrib/cudnn/cudnn_utils.h +++ b/src/runtime/contrib/cudnn/cudnn_utils.h @@ -28,6 +28,8 @@ #include #include +#include + #include "../../cuda/cuda_common.h" namespace tvm { @@ -64,7 +66,7 @@ inline void GetCudnnStride(int nbdim, const int* dims, int* strides) { struct ConvEntry { cudnnConvolutionDescriptor_t conv_desc; - cudnnConvolutionMode_t mode; + cudnnConvolutionMode_t mode{CUDNN_CROSS_CORRELATION}; cudnnFilterDescriptor_t filter_desc; cudnnDataType_t data_type; cudnnTensorFormat_t tensor_format; @@ -103,6 +105,11 @@ struct CuDNNThreadEntry { static CuDNNThreadEntry* ThreadLocal(bool check_exists = true); }; // CuDNNThreadEntry +void SetConvDescriptors(CuDNNThreadEntry* entry_ptr, int format, int dims, int groups, + const int pad[], const int stride[], const int dilation[], int64_t x_dim[], + int64_t w_dim[], int64_t y_dim[], DLDataType data_dtype, + const std::string& conv_dtype); + } // namespace contrib } // namespace tvm diff --git a/src/runtime/hexagon/android/sim/hexagon_device_sim.cc b/src/runtime/hexagon/android/sim/hexagon_device_sim.cc index 250259832597e..05559a1d1a98e 100644 --- a/src/runtime/hexagon/android/sim/hexagon_device_sim.cc +++ b/src/runtime/hexagon/android/sim/hexagon_device_sim.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include diff --git a/src/runtime/hexagon/hexagon/hexagon_buffer.cc b/src/runtime/hexagon/hexagon/hexagon_buffer.cc index 3bd1bbf784e37..dfb499c9a31ae 100644 --- a/src/runtime/hexagon/hexagon/hexagon_buffer.cc +++ b/src/runtime/hexagon/hexagon/hexagon_buffer.cc @@ -38,7 +38,8 @@ namespace runtime { namespace hexagon { struct Allocation { - Allocation(size_t nbytes, size_t alignment) : nbytes_(nbytes), alignment_(alignment) {} + Allocation(size_t allocation_nbytes, size_t alignment) + : allocation_nbytes_(allocation_nbytes), alignment_(alignment) {} virtual ~Allocation() {} Allocation(const Allocation&) = delete; Allocation& operator=(const Allocation&) = delete; @@ -46,7 +47,7 @@ struct Allocation { Allocation& operator=(Allocation&&) = delete; void* data_{nullptr}; - size_t nbytes_; + size_t allocation_nbytes_; size_t alignment_; }; @@ -190,9 +191,11 @@ void HexagonBuffer::SetStorageScope(Optional scope) { void HexagonBuffer::CopyTo(void* data, size_t nbytes) const { CHECK_LE(nbytes, nbytes_); + CHECK(managed_allocations_.size() && "CopyTo not supported on unmanaged `external` allocations"); + size_t copied = 0; for (size_t i = 0; i < nallocs_; ++i) { - size_t bytes_to_copy = std::min(nbytes - copied, managed_allocations_[i]->nbytes_); + size_t bytes_to_copy = std::min(nbytes - copied, managed_allocations_[i]->allocation_nbytes_); if (bytes_to_copy == 0) break; memcpy(static_cast(data) + copied, @@ -204,9 +207,12 @@ void HexagonBuffer::CopyTo(void* data, size_t nbytes) const { void HexagonBuffer::CopyFrom(void* data, size_t nbytes) { CHECK_LE(nbytes, nbytes_); + CHECK(managed_allocations_.size() && + "CopyFrom not supported on unmanaged `external` allocations"); + size_t copied = 0; for (size_t i = 0; i < nallocs_; ++i) { - size_t bytes_to_copy = std::min(nbytes - copied, managed_allocations_[i]->nbytes_); + size_t bytes_to_copy = std::min(nbytes - copied, managed_allocations_[i]->allocation_nbytes_); if (bytes_to_copy == 0) break; memcpy(static_cast(managed_allocations_[i]->data_), @@ -219,14 +225,19 @@ void HexagonBuffer::CopyFrom(void* data, size_t nbytes) { void HexagonBuffer::CopyFrom(const HexagonBuffer& other, size_t nbytes) { CHECK_LE(nbytes, nbytes_); CHECK_LE(nbytes, other.nbytes_); + CHECK(managed_allocations_.size() && + "CopyFrom not supported on unmanaged `external` allocations"); + CHECK(other.managed_allocations_.size() && + "CopyFrom not supported on unmanaged `external` allocations"); if (nallocs_ == other.nallocs_) { size_t copied = 0; for (size_t i = 0; i < nallocs_; ++i) { - size_t bytes_to_copy = std::min(nbytes - copied, managed_allocations_[i]->nbytes_); + size_t bytes_to_copy = std::min(nbytes - copied, managed_allocations_[i]->allocation_nbytes_); if (bytes_to_copy == 0) break; - CHECK_LE(other.managed_allocations_[i]->nbytes_, managed_allocations_[i]->nbytes_); + CHECK_LE(other.managed_allocations_[i]->allocation_nbytes_, + managed_allocations_[i]->allocation_nbytes_); memcpy(static_cast(managed_allocations_[i]->data_), static_cast(other.managed_allocations_[i]->data_), bytes_to_copy); diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index 8996d1b76e1fb..e83e1a3a7629f 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -49,53 +49,6 @@ inline String get_name_mangled(const String& module_name, const String& name) { return ss.str(); } -/*! - * \brief Structure that can be optionally used by the executor codegen - */ -class MetadataNode : public Object { - public: - /*! \brief input information for the main function */ - Array inputs; - /*! \brief number of outputs of the main function */ - int num_outputs = 1; - /*! \brief device contexts information for the main function */ - Array devices; - /*! \brief the executor to be used to run the model */ - String executor = kTvmExecutorGraph; - /*! \brief The external API (packed or c) in use */ - String interface_api; - /*! \brief The internal API (packed or unpacked) in use */ - bool unpacked_api; - - String mod_name = ""; - - static constexpr const uint32_t _type_index = TypeIndex::kDynamic; - static constexpr const char* _type_key = "MetadataObj"; - TVM_DECLARE_FINAL_OBJECT_INFO(MetadataNode, Object); -}; - -/*! - * \brief Managed reference to MetadataNode. - */ -class Metadata : public ObjectRef { - public: - TVM_DLL Metadata(Array inputs, Array devices, int num_outputs, String executor, - String mod_name, String interface_api = "packed", bool unpacked_api = false) { - auto n = make_object(); - n->inputs = inputs; - n->devices = devices; - n->num_outputs = num_outputs; - n->executor = executor; - n->interface_api = interface_api; - n->unpacked_api = unpacked_api; - n->mod_name = mod_name; - data_ = std::move(n); - } - - TVM_DEFINE_OBJECT_REF_METHODS(Metadata, ObjectRef, MetadataNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(MetadataNode); -}; - /*! * \brief Create a metadata module object. * diff --git a/src/runtime/module.cc b/src/runtime/module.cc index 393d269222c16..097d6a2f53e78 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -127,6 +127,11 @@ const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) { } } +std::string ModuleNode::GetFormat() { + LOG(FATAL) << "Module[" << type_key() << "] does not support GetFormat"; + return ""; +} + bool RuntimeEnabled(const std::string& target) { std::string f_name; if (target == "cpu") { @@ -179,6 +184,10 @@ TVM_REGISTER_GLOBAL("runtime.ModuleGetTypeKey").set_body_typed([](Module mod) { return std::string(mod->type_key()); }); +TVM_REGISTER_GLOBAL("runtime.ModuleGetFormat").set_body_typed([](Module mod) { + return mod->GetFormat(); +}); + TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile").set_body_typed(Module::LoadFromFile); TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile") diff --git a/src/runtime/pipeline/pipeline_executor.cc b/src/runtime/pipeline/pipeline_executor.cc index 32414c607df6f..0ca291a2fbbe7 100644 --- a/src/runtime/pipeline/pipeline_executor.cc +++ b/src/runtime/pipeline/pipeline_executor.cc @@ -37,11 +37,27 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name, } else if (name == "get_input_pipeline_map") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { if (String::CanConvertFrom(args[0])) { - *rv = this->GetInputPipeplineMapping(args[0].operator String()); + *rv = this->GetInputPipeplineMap(args[0].operator String()); } else { LOG(FATAL) << "Function only support the input name value in the form of string"; } }); + } else if (name == "get_params_group_pipeline_map") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + if (String::CanConvertFrom(args[0])) { + *rv = this->GetParamsGroupPipelineMap(args[0].operator String()); + } else { + LOG(FATAL) << "Function only support the input name value in the form of string"; + } + }); + } else if (name == "set_param") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + if (String::CanConvertFrom(args[0]) && String::CanConvertFrom(args[1])) { + this->SetParam(args[0].operator String(), args[1].operator String(), args[2]); + } else { + LOG(FATAL) << "Function only support the parameter name and the key in the form of string"; + } + }); } else { LOG(FATAL) << "Unknown packed function: " << name; return PackedFunc(); @@ -55,11 +71,20 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name, * \param The global input name. * \return Returning the index and the input interface name of corresponding subgraph. */ -Array PipelineExecutor::GetInputPipeplineMapping(std::string input_name) { +Array PipelineExecutor::GetInputPipeplineMap(std::string input_name) { std::pair map = input_connection_config[input_name]; return {std::to_string(map.first), map.second}; } +/*! + * \brief Return the module index for the parameters group name. + * \param name The parameters group name. + * \return int The module index. + */ +int PipelineExecutor::GetParamsGroupPipelineMap(const std::string& name) { + return param_connection_config[name]; +} + /*! * \brief Use the mod_config information to create a graph runtime list. * \param mod_config The config information that generates by the export library function call. @@ -115,7 +140,18 @@ std::vector PipelineExecutor::CreateGraphModules(const ModuleConfig& mod } return ret; } - +/*! + * \brief Set a parameter into a graph module. + * \param param_group_name The parameters group name. + * \param param_key_name The parameter key name. + * \param data_in The parameter data. + */ +void PipelineExecutor::SetParam(std::string param_group_name, std::string param_key_name, + DLTensor* data_in) { + // Get the module index from the param name. + int module_index = this->GetParamsGroupPipelineMap(param_group_name); + // TODO(huajsj): set the parameters into runtime module. +} /*! * \brief Initialize the pipeline executor with a list of modules to be pipelined * and config in JSON format. diff --git a/src/runtime/pipeline/pipeline_executor.h b/src/runtime/pipeline/pipeline_executor.h index 1ae52e07c2607..6d4c7ba1fa4fd 100644 --- a/src/runtime/pipeline/pipeline_executor.h +++ b/src/runtime/pipeline/pipeline_executor.h @@ -75,14 +75,27 @@ class TVM_DLL PipelineExecutor : public ModuleNode { * \param The global input name. * \return Returning the index and the input interface name of corresponding subgraph. */ - Array GetInputPipeplineMapping(std::string input_name); + Array GetInputPipeplineMap(std::string input_name); + /*! + * \brief This function return a module index for the global parameters group name. + * \param name The parameters group name. + * \return Returning a runtime module index. + */ + int GetParamsGroupPipelineMap(const std::string& name); + /*! + * \brief Use the parameters group name to get the specific backend runtime then use + * the param_key_name to set param data for the said backend runtime. + * \param param_group_name The parameters group name. + * \param param_key_name The parameter key name. + * \param data_in The parameter value. + */ + void SetParam(std::string param_group_name, std::string param_key_name, DLTensor* data_in); /*! * \brief Get the number of outputs. * * \return The number of outputs. */ int NumOutputs() const { return num_outputs_; } - /*!\brief Load the module files information.*/ ModuleConfig& LoadModuleConfig(dmlc::JSONReader* reader) { reader->BeginArray(); @@ -126,6 +139,8 @@ class TVM_DLL PipelineExecutor : public ModuleNode { ConfigPipelineExecution pipeline_config_; /*!\brief The map of global input and subgraph input.*/ InputConnectionConfig input_connection_config; + /*!\brief The map includes global parameters groups and runtime modules.*/ + ParamConnectionConfig param_connection_config; /*!\brief The module information used to create the graph runtimes.*/ ModuleConfig mod_config_; /*!\brief How many outputs are in this pipeline executor.*/ @@ -139,6 +154,8 @@ class TVM_DLL PipelineExecutor : public ModuleNode { reader->Read(&pipeline_config_); } else if (key == "input_connection") { reader->Read(&input_connection_config); + } else if (key == "param_connection") { + reader->Read(¶m_connection_config); } else { LOG(FATAL) << "do not support key " << key; } diff --git a/src/runtime/pipeline/pipeline_struct.h b/src/runtime/pipeline/pipeline_struct.h index 52422b7645642..aa831070ccdb6 100644 --- a/src/runtime/pipeline/pipeline_struct.h +++ b/src/runtime/pipeline/pipeline_struct.h @@ -251,6 +251,48 @@ struct InputConnectionConfig { } }; +/*! + * \brief A map includes global module parameters groups and graph modudles. + */ +struct ParamConnectionConfig { + /*!\brief Mapping from the name of a global module parameters group to the index of a runtime + * module. + */ + std::unordered_map param_connection; + bool Empty() { return param_connection.empty(); } + int operator[](const std::string key) { + if (param_connection.find(key) == param_connection.end()) { + LOG(FATAL) << "do not support key " << key; + } + return param_connection[key]; + } + /*! + * \brief Load from JSONReader. + * \param reader Json reader. + */ + void Load(dmlc::JSONReader* reader) { + reader->BeginArray(); + while (reader->NextArrayItem()) { + reader->BeginObject(); + std::string key; + std::string global_param_name; + int mod_idx = -1; + while (reader->NextObjectItem(&key)) { + if (key == "global_param_name") { + reader->Read(&global_param_name); + } else if (key == "mod_idx") { + reader->Read(&mod_idx); + } else { + LOG(FATAL) << "do not support key " << key; + } + } + ICHECK(mod_idx >= 0) << "Invalid module index value " << mod_idx; + ICHECK(!global_param_name.empty()) << "Invalid global parameter group name value"; + param_connection[global_param_name] = mod_idx; + } + } +}; + /*! * \brief The information used to initialize the graph executor module, the information * come from the export library function call. diff --git a/src/runtime/rpc/rpc_local_session.h b/src/runtime/rpc/rpc_local_session.h index d1b54d5be65bf..a081cf97db4a4 100644 --- a/src/runtime/rpc/rpc_local_session.h +++ b/src/runtime/rpc/rpc_local_session.h @@ -60,7 +60,7 @@ class LocalSession : public RPCSession { protected: /*! - * \brief internal encode return fucntion. + * \brief internal encode return function. * \param rv The return value. * \param encode_return The encoding function. */ diff --git a/src/support/ring_buffer.h b/src/support/ring_buffer.h index af814158f7b6f..1c6a6f8b4350a 100644 --- a/src/support/ring_buffer.h +++ b/src/support/ring_buffer.h @@ -87,7 +87,7 @@ class RingBuffer { } /*! - * \brief Peform a non-blocking read from buffer + * \brief Perform a non-blocking read from buffer * size must be smaller than this->bytes_available() * \param data the data pointer. * \param size The number of bytes to read. diff --git a/src/support/socket.h b/src/support/socket.h index a83a67c85d761..42d5d9004c156 100644 --- a/src/support/socket.h +++ b/src/support/socket.h @@ -516,7 +516,7 @@ class TCPSocket : public Socket { [&]() { return recv(sockfd, buf, static_cast(len), flags); }); } /*! - * \brief peform block write that will attempt to send all data out + * \brief perform block write that will attempt to send all data out * can still return smaller than request when error occurs * \param buf_ the pointer to the buffer * \param len the size of the buffer @@ -538,7 +538,7 @@ class TCPSocket : public Socket { return ndone; } /*! - * \brief peform block read that will attempt to read all data + * \brief perform block read that will attempt to read all data * can still return smaller than request when error occurs * \param buf_ the buffer pointer * \param len length of data to recv @@ -654,7 +654,7 @@ struct PollHelper { } /*! - * \brief peform poll on the set defined, read, write, exception + * \brief perform poll on the set defined, read, write, exception * \param timeout specify timeout in milliseconds(ms) if negative, means poll will block * \return */ diff --git a/src/target/llvm/llvm_common.h b/src/target/llvm/llvm_common.h index 485fc9f67dbd7..556f05d2e33ae 100644 --- a/src/target/llvm/llvm_common.h +++ b/src/target/llvm/llvm_common.h @@ -71,7 +71,7 @@ #include #include #include -#if TVM_LLVM_VERSION >= 140 && !defined(TVM_USE_HEXAGON_LLVM) +#if TVM_LLVM_VERSION >= 140 #include #else #include diff --git a/src/target/metadata_module.cc b/src/target/metadata_module.cc index 2b190e5d66ed0..2facf1de64d56 100644 --- a/src/target/metadata_module.cc +++ b/src/target/metadata_module.cc @@ -37,7 +37,7 @@ namespace codegen { runtime::Module CreateMetadataModule( const std::unordered_map& params, tvm::runtime::Module target_module, const Array& ext_modules, Target target, - tvm::relay::Runtime runtime, runtime::Metadata metadata) { + tvm::relay::Runtime runtime, relay::backend::ExecutorCodegenMetadata metadata) { // Here we split modules into two groups: // 1. Those modules which can be exported to C-runtime. These are DSO-exportable // (i.e. llvm or c) modules which return nothing from get_const_vars(). diff --git a/src/target/metadata_module.h b/src/target/metadata_module.h index ee6f7231b3a19..2afcf3497ab8a 100644 --- a/src/target/metadata_module.h +++ b/src/target/metadata_module.h @@ -33,7 +33,7 @@ #include #include -#include "../runtime/meta_data.h" +#include "../relay/backend/utils.h" namespace tvm { namespace codegen { @@ -54,7 +54,7 @@ namespace codegen { runtime::Module CreateMetadataModule( const std::unordered_map& params, runtime::Module target_module, const Array& ext_modules, Target target, tvm::relay::Runtime runtime, - runtime::Metadata metadata); + relay::backend::ExecutorCodegenMetadata metadata); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index ac6f3adad6066..e6f81646242d6 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -362,61 +362,6 @@ void CodeGenC::PrintStorageScope(const std::string& scope, std::ostream& os) { ICHECK_EQ(scope, "global"); } -void CodeGenC::PrintType(DataType t, std::ostream& os) { // NOLINT(*) - ICHECK_EQ(t.lanes(), 1) << "do not yet support vector types"; - if (t.is_handle()) { - os << "void*"; - return; - } - if (t.is_float()) { - if (t.bits() == 32) { - os << "float"; - return; - } - if (t.bits() == 64) { - os << "double"; - return; - } - } else if (t.is_uint()) { - switch (t.bits()) { - case 8: - case 16: - case 32: - case 64: { - os << "uint" << t.bits() << "_t"; - return; - } - case 1: - os << "int"; - return; - } - } else if (t.is_int()) { - switch (t.bits()) { - case 8: - case 16: - case 32: - case 64: { - os << "int" << t.bits() << "_t"; - return; - } - } - } - LOG(FATAL) << "Cannot convert type " << t << " to C type"; -} - -void CodeGenC::PrintType(const Type& type, std::ostream& os) { // NOLINT(*) - if (auto* ptr = type.as()) { - return PrintType(ptr->dtype, os); - } else if (auto* ptr = type.as()) { - PrintType(ptr->element_type, os); - os << '*'; - } else if (IsVoidType(type)) { - os << "void"; - } else { - LOG(FATAL) << "Type " << type << " does not have a corresponding C Type"; - } -} - inline void PrintConst(const IntImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) if (op->dtype == DataType::Int(32)) { std::ostringstream temp; diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 299f7e0a9ceff..3b042b9fbd2c5 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -163,18 +163,7 @@ class CodeGenC : public ExprFunctor, void VisitStmt_(const AssertStmtNode* op) override; void VisitStmt_(const EvaluateNode* op) override; void VisitStmt_(const SeqStmtNode* op) override; - /*! - * Print Type represetnation of type t. - * \param t The type representation. - * \param os The stream to print the ctype into - */ - virtual void PrintType(DataType t, std::ostream& os); // NOLINT(*) - /*! - * Print Type represetnation of type type. - * \param type The type representation. - * \param os The stream to print the ctype into - */ - virtual void PrintType(const Type& type, std::ostream& os); // NOLINT(*) + /*! * \brief Print expr representing the thread tag * \param IterVar iv The thread index to be binded; diff --git a/src/target/source/codegen_source_base.cc b/src/target/source/codegen_source_base.cc index 9f0cf9a70b615..5dcf1587bdb97 100644 --- a/src/target/source/codegen_source_base.cc +++ b/src/target/source/codegen_source_base.cc @@ -113,5 +113,60 @@ void CodeGenSourceBase::EndScope(int scope_id) { indent_ -= 2; } +void CodeGenSourceBase::PrintType(DataType type, std::ostream& os) { // NOLINT(*) + ICHECK_EQ(type.lanes(), 1) << "do not yet support vector types"; + if (type.is_handle()) { + os << "void*"; + return; + } + if (type.is_float()) { + if (type.bits() == 32) { + os << "float"; + return; + } + if (type.bits() == 64) { + os << "double"; + return; + } + } else if (type.is_uint()) { + switch (type.bits()) { + case 8: + case 16: + case 32: + case 64: { + os << "uint" << type.bits() << "_t"; + return; + } + case 1: + os << "int"; + return; + } + } else if (type.is_int()) { + switch (type.bits()) { + case 8: + case 16: + case 32: + case 64: { + os << "int" << type.bits() << "_t"; + return; + } + } + } + LOG(FATAL) << "Cannot convert type " << type << " to C type"; +} + +void CodeGenSourceBase::PrintType(const Type& type, std::ostream& os) { // NOLINT(*) + if (auto* ptr = type.as()) { + return PrintType(ptr->dtype, os); + } else if (auto* ptr = type.as()) { + PrintType(ptr->element_type, os); + os << '*'; + } else if (IsVoidType(type)) { + os << "void"; + } else { + LOG(FATAL) << "Type " << type << " does not have a corresponding C Type"; + } +} + } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_source_base.h b/src/target/source/codegen_source_base.h index d938469b89698..8f8f9e1b8bf25 100644 --- a/src/target/source/codegen_source_base.h +++ b/src/target/source/codegen_source_base.h @@ -52,6 +52,18 @@ class CodeGenSourceBase { * \param value The constant value. */ void MarkConst(std::string value); + /*! + * Print Type representation of type type. + * \param t The type representation. + * \param os The stream to print the ctype into + */ + virtual void PrintType(DataType type, std::ostream& os); // NOLINT(*) + /*! + * Print Type representation of type type. + * \param type The type representation. + * \param os The stream to print the ctype into + */ + virtual void PrintType(const Type& type, std::ostream& os); // NOLINT(*) protected: /*! \brief entry in ssa assign map */ diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index e01a3d93d087d..8faac3f1d9660 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -31,6 +31,7 @@ #include #include +#include "../../relay/backend/name_transforms.h" #include "../../runtime/file_utils.h" #include "../../support/str_escape.h" #include "../func_registry_generator.h" @@ -62,6 +63,8 @@ class SourceModuleNode : public runtime::ModuleNode { std::string GetSource(const std::string& format) final { return code_; } + std::string GetFormat() { return fmt_; } + protected: std::string code_; std::string fmt_; @@ -101,10 +104,12 @@ class CSourceModuleNode : public runtime::ModuleNode { std::string GetSource(const std::string& format) final { return code_; } + std::string GetFormat() { return fmt_; } + void SaveToFile(const std::string& file_name, const std::string& format) final { std::string fmt = GetFileFormat(file_name, format); std::string meta_file = GetMetaFilePath(file_name); - if (fmt == "c" || fmt == "cu") { + if (fmt == "c" || fmt == "cc" || fmt == "cpp" || fmt == "cu") { ICHECK_NE(code_.length(), 0); SaveBinaryToFile(file_name, code_); } else { @@ -127,10 +132,27 @@ runtime::Module CSourceModuleCreate(const String& code, const String& fmt, return runtime::Module(n); } +/*! + * \brief A concrete class to get access to base methods of CodegenSourceBase. + * + * This class exist to get access to methods of CodegenSourceBase without duplicating + * them. Therefore, keeping alignment with how codegen and source_module here generates + * code. + */ +class ConcreteCodegenSourceBase : public CodeGenSourceBase { + /*! + * \brief Do nothing as this class exist to get access to methods of CodeGenSourceBase + */ + void PrintSSAAssign(const std::string& target, const std::string& src, DataType t) final { + return; + } +}; + class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { public: CSourceCrtMetadataModuleNode(const Array& func_names, const std::string& fmt, - Target target, relay::Runtime runtime, runtime::Metadata metadata) + Target target, relay::Runtime runtime, + relay::backend::ExecutorCodegenMetadata metadata) : fmt_(fmt), func_names_(func_names), target_(target), @@ -142,6 +164,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { std::string GetSource(const std::string& format) final { return code_.str(); } + std::string GetFormat() { return fmt_; } PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { return PackedFunc(nullptr); } @@ -149,7 +172,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { void SaveToFile(const std::string& file_name, const std::string& format) final { std::string fmt = GetFileFormat(file_name, format); std::string meta_file = GetMetaFilePath(file_name); - if (fmt == "c") { + if (fmt == "c" || fmt == "cc" || fmt == "cpp") { auto code_str = code_.str(); ICHECK_NE(code_str.length(), 0); SaveBinaryToFile(file_name, code_str); @@ -164,7 +187,8 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { Array func_names_; Target target_; relay::Runtime runtime_; - runtime::Metadata metadata_; + relay::backend::ExecutorCodegenMetadata metadata_; + ConcreteCodegenSourceBase codegen_c_base_; void CreateFuncRegistry() { code_ << "#include \n"; @@ -197,45 +221,161 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { << "}\n"; } + String GenerateDLTensorStructWrapper(String reference_arg) { + code_ << "DLTensor " << reference_arg << "_dltensor = {\n"; + code_ << ".data = &" << reference_arg << "\n"; + code_ << "};\n"; + code_ << "TVMValue " << reference_arg << "_tvm_value = {\n"; + code_ << ".v_handle = &" << reference_arg << "_dltensor\n"; + code_ << "};\n"; + return reference_arg + "_tvm_value"; + } + + void GenerateInternalWorkspaceBuffers() { + if (metadata_->pool_inputs.defined()) { + for (const auto& kv : metadata_->pool_inputs.value()) { + tir::usmp::AllocatedPoolInfo allocated_pool_info = kv.second; + if (allocated_pool_info->pool_info->is_internal) { + code_ << "__attribute__((section(\".data.tvm\"), "; + code_ << "aligned(" << 16 << ")))\n"; + code_ << "static uint8_t " << allocated_pool_info->pool_info->pool_name << "[" + << allocated_pool_info->allocated_size->value << "];\n"; + } + } + } + } + + bool IsInternalWorkspaceBuffer(const tir::Var& pool_var) { + if (metadata_->pool_inputs.defined()) { + Map allocated_pool_infos = + metadata_->pool_inputs.value(); + if (allocated_pool_infos.find(pool_var) != allocated_pool_infos.end()) { + tir::usmp::AllocatedPoolInfo allocate_pool_info = allocated_pool_infos[pool_var]; + if (allocate_pool_info->pool_info->is_internal) { + return true; + } + } + } + return false; + } + void GenerateEntrypointForUnpackedAPI(const std::string& entrypoint_name, const std::string& run_func) { code_ << "TVM_DLL int32_t " << run_func << "("; - unsigned int total_args = (metadata_->inputs.size() + metadata_->num_outputs); - for (unsigned int i = 0; i < total_args; ++i) { - code_ << "void* arg" << i; - if (i + 1 != total_args) { - code_ << ","; + + { + std::stringstream call_args_ss; + for (const tir::Var& input_var : metadata_->inputs) { + if (input_var->type_annotation.defined()) { + codegen_c_base_.PrintType(input_var->type_annotation, call_args_ss); + } else { + codegen_c_base_.PrintType(input_var.dtype(), call_args_ss); + } + call_args_ss << " " << input_var->name_hint << ","; + } + for (unsigned int i = 0; i < metadata_->num_outputs; ++i) { + call_args_ss << "void* output" << i << ","; } + for (const tir::Var& pool_var : metadata_->pools) { + if (pool_var->type_annotation.defined()) { + codegen_c_base_.PrintType(pool_var->type_annotation, call_args_ss); + } else { + codegen_c_base_.PrintType(pool_var.dtype(), call_args_ss); + } + call_args_ss << " " << pool_var->name_hint << ","; + } + std::string call_args_str = call_args_ss.str(); + call_args_str.pop_back(); + code_ << call_args_str; } + code_ << ");\n"; code_ << "int32_t " << entrypoint_name; code_ << "(void* args, void* type_code, int num_args, void* out_value, void* " "out_type_code, void* resource_handle) {\n"; code_ << "return " << run_func << "("; - for (unsigned int i = 0; i < metadata_->inputs.size(); ++i) { - code_ << "((DLTensor*)(((TVMValue*)args)[" << i << "].v_handle))[0].data,"; + + { + std::stringstream call_args_ss; + for (unsigned int i = 0; i < metadata_->inputs.size(); ++i) { + call_args_ss << "((DLTensor*)(((TVMValue*)args)[" << i << "].v_handle))[0].data,"; + } + for (unsigned int i = 0; i < metadata_->num_outputs; ++i) { + int j = metadata_->inputs.size() + i; + call_args_ss << "((DLTensor*)(((TVMValue*)args)[" << j << "].v_handle))[0].data,"; + } + for (const tir::Var& pool_var : metadata_->pools) { + if (IsInternalWorkspaceBuffer(pool_var)) { + call_args_ss << "&" << metadata_->pool_inputs.value()[pool_var]->pool_info->pool_name + << ","; + } + } + std::string call_args_str = call_args_ss.str(); + call_args_str.pop_back(); + code_ << call_args_str; + code_ << ");\n"; + code_ << "}\n"; } - for (int i = 0; i < metadata_->num_outputs; ++i) { - int j = metadata_->inputs.size() + i; - code_ << "((DLTensor*)(((TVMValue*)args)[" << j << "].v_handle))[0].data"; - if (i + 1 != metadata_->num_outputs) { - code_ << ","; + } + + std::unordered_map GenerateRunFuncToEntryPointArgMap() { + std::unordered_map run_func_to_entry_point_args; + int entrypoint_arg_count = 0; + int run_func_arg_count = 0; + + for (unsigned int i = 0; i < metadata_->inputs.size(); i++) { + run_func_to_entry_point_args[run_func_arg_count] = Integer(entrypoint_arg_count); + entrypoint_arg_count++; + run_func_arg_count++; + } + for (unsigned int i = 0; i < metadata_->num_outputs; i++) { + run_func_to_entry_point_args[run_func_arg_count] = Integer(entrypoint_arg_count); + entrypoint_arg_count++; + run_func_arg_count++; + } + for (const tir::Var& pool_var : metadata_->pools) { + if (IsInternalWorkspaceBuffer(pool_var)) { + tir::usmp::AllocatedPoolInfo allocated_pool_info = metadata_->pool_inputs.value()[pool_var]; + run_func_to_entry_point_args[run_func_arg_count] = + allocated_pool_info->pool_info->pool_name; + run_func_arg_count++; } } - code_ << ");\n"; - code_ << "}\n"; + return run_func_to_entry_point_args; } void GenerateEntrypointForPackedAPI(const std::string& entrypoint_name, const std::string& run_func) { code_ << "TVM_DLL int32_t " << run_func; code_ << "(void* args, void* type_code, int num_args, void* out_value, void* " - "out_type_code, void* resource_handle);\n"; + "out_type_code, void* resource_handle);\n\n"; + code_ << "int32_t " << entrypoint_name; code_ << "(void* args, void* type_code, int num_args, void* out_value, void* " "out_type_code, void* resource_handle) {\n"; + + // We are creating a copy of the set of pointers + size_t number_of_io_tensors = + metadata_->inputs.size() + metadata_->num_outputs + metadata_->pools.size(); + code_ << "TVMValue tensors[" << number_of_io_tensors << "];\n"; + + std::unordered_map run_func_to_entry_point_args = + GenerateRunFuncToEntryPointArgMap(); + for (unsigned int i = 0; i < number_of_io_tensors; i++) { + if (run_func_to_entry_point_args.find(i) != run_func_to_entry_point_args.end()) { + if (run_func_to_entry_point_args[i]->IsInstance()) { + String pool_name = Downcast(run_func_to_entry_point_args[i]); + String pool_name_tvmv = GenerateDLTensorStructWrapper(pool_name); + code_ << "tensors[" << i << "] = " << pool_name_tvmv << ";\n"; + } else { + code_ << "tensors[" << i << "] = ((TVMValue*)args)[" + << run_func_to_entry_point_args[Integer(i)] << "];\n"; + } + } + } + code_ << "return " << run_func; - code_ << "(args, type_code, num_args, out_value, out_type_code, resource_handle);\n"; + code_ << "((void*)tensors, type_code, num_args, out_value, out_type_code, resource_handle);\n"; code_ << "}\n"; } @@ -245,14 +385,35 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { const std::string& mod_name) { code_ << "#include <" << mod_name << ".h>\n"; code_ << "TVM_DLL int32_t " << run_func << "("; - unsigned int total_args = - (metadata_->inputs.size() + metadata_->devices.size() + metadata_->num_outputs); - for (unsigned int i = 0; i < total_args; ++i) { - code_ << "void* arg" << i; - if (i + 1 != total_args) { - code_ << ","; + { + std::stringstream call_args_ss; + for (const tir::Var& input_var : metadata_->inputs) { + if (input_var->type_annotation.defined()) { + codegen_c_base_.PrintType(input_var->type_annotation, call_args_ss); + } else { + codegen_c_base_.PrintType(input_var.dtype(), call_args_ss); + } + call_args_ss << " " << relay::backend::SanitizeName(input_var->name_hint) << ","; + } + for (unsigned int i = 0; i < metadata_->num_outputs; ++i) { + call_args_ss << "void* output" << i << ","; + } + for (const tir::Var& pool_var : metadata_->pools) { + if (pool_var->type_annotation.defined()) { + codegen_c_base_.PrintType(pool_var->type_annotation, call_args_ss); + } else { + codegen_c_base_.PrintType(pool_var.dtype(), call_args_ss); + } + call_args_ss << " " << pool_var->name_hint << ","; } + for (const String& device : metadata_->devices) { + call_args_ss << "void* " << device << ","; + } + std::string call_args_str = call_args_ss.str(); + call_args_str.pop_back(); + code_ << call_args_str; } + code_ << ");\n"; code_ << "int32_t " << entrypoint_name << "("; code_ << "struct " << runtime::get_name_mangled(mod_name, "inputs") << "* inputs,"; @@ -265,32 +426,32 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { code_ << ") {" << "return " << run_func << "("; - for (const auto& input : metadata_->inputs) { - std::string sanitised_input = input; - std::replace_if(sanitised_input.begin(), sanitised_input.end(), isNotAlnum, '_'); - code_ << "inputs->" << sanitised_input << ","; - } - if (metadata_->num_outputs == 1) { - code_ << "outputs->output"; - } else { - for (int i = 0; i < metadata_->num_outputs; ++i) { - code_ << "outputs->output" << i; - if (i + 1 != metadata_->num_outputs) { - code_ << ","; + + { + std::stringstream call_args_ss; + for (const auto& input : metadata_->inputs) { + call_args_ss << "inputs->" << relay::backend::SanitizeName(input->name_hint) << ","; + } + if (metadata_->num_outputs == 1) { + call_args_ss << "outputs->output,"; + } else { + for (unsigned int i = 0; i < metadata_->num_outputs; ++i) { + call_args_ss << "outputs->output" << i << ","; } } - } - - if (!metadata_->devices.empty()) { - code_ << ","; - for (const String& device : metadata_->devices) { - code_ << "devices->" << device; - if (device != metadata_->devices.back()) { - code_ << ","; + for (const tir::Var& pool_var : metadata_->pools) { + if (IsInternalWorkspaceBuffer(pool_var)) { + call_args_ss << "&" << metadata_->pool_inputs.value()[pool_var]->pool_info->pool_name + << ","; } } + for (const String& device : metadata_->devices) { + call_args_ss << "devices->" << device << ","; + } + std::string call_args_str = call_args_ss.str(); + call_args_str.pop_back(); + code_ << call_args_str; } - code_ << ");\n"; code_ << "}\n"; } @@ -309,6 +470,8 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { code_ << "extern \"C\" {\n"; code_ << "#endif\n"; + GenerateInternalWorkspaceBuffers(); + if (metadata_->unpacked_api) { if (metadata_->interface_api == "c") { GenerateCInterfaceEntrypoint(entrypoint_mangled, run_func_mangled, metadata_->mod_name); @@ -339,7 +502,8 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { }; runtime::Module CreateCSourceCrtMetadataModule(const Array& modules, Target target, - relay::Runtime runtime, runtime::Metadata metadata) { + relay::Runtime runtime, + relay::backend::ExecutorCodegenMetadata metadata) { Array func_names; for (runtime::Module mod : modules) { auto pf_funcs = mod.GetFunction("get_func_names"); @@ -350,7 +514,7 @@ runtime::Module CreateCSourceCrtMetadataModule(const Array& mod } } } - auto n = make_object(func_names, "cc", target, runtime, metadata); + auto n = make_object(func_names, "c", target, runtime, metadata); auto csrc_metadata_module = runtime::Module(n); for (const auto& mod : modules) { csrc_metadata_module.Import(mod); @@ -423,7 +587,8 @@ TVM_REGISTER_GLOBAL("runtime.CreateCSourceCrtMetadataModule") .set_body_typed([](const Array& modules, Target target, relay::Runtime runtime) { // Note that we don't need metadata when we compile a single operator - return CreateCSourceCrtMetadataModule(modules, target, runtime, runtime::Metadata()); + return CreateCSourceCrtMetadataModule(modules, target, runtime, + relay::backend::ExecutorCodegenMetadata()); }); } // namespace codegen diff --git a/src/target/source/source_module.h b/src/target/source/source_module.h index fde363c1198a5..3b482a107600f 100644 --- a/src/target/source/source_module.h +++ b/src/target/source/source_module.h @@ -29,6 +29,7 @@ #include #include +#include "../../relay/backend/utils.h" #include "../../runtime/meta_data.h" namespace tvm { @@ -43,7 +44,8 @@ namespace codegen { * \return The wrapped module. */ runtime::Module CreateCSourceCrtMetadataModule(const Array& modules, Target target, - relay::Runtime runtime, runtime::Metadata metadata); + relay::Runtime runtime, + relay::backend::ExecutorCodegenMetadata metadata); } // namespace codegen } // namespace tvm diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index e4bf48b2a51e0..c562c78bd1874 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -254,6 +254,7 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) .add_attr_option("mabi") .add_attr_option("system-lib") .add_attr_option("runtime") + .add_attr_option("num-cores") .add_attr_option("link-params", Bool(false)) .add_attr_option("unpacked-api") .add_attr_option("interface-api") diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 5de0538960fce..94dd0b044d710 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -282,7 +282,7 @@ PrimFunc CreatePrimFunc(const Array& arg_list) { << "Only te.placeholder and te.compute are allowed for now."; } - // Infomations used in CreatePrimFunc and its sub-funtions. + // Infomations used in CreatePrimFunc and its sub-functions. CreateFuncInfo info(arg_list); // Root body stmts. Array root_stmts; diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc index 9e2d3d0e725f0..b31b61b739c19 100644 --- a/src/te/operation/tensorize.cc +++ b/src/te/operation/tensorize.cc @@ -35,7 +35,7 @@ namespace te { using namespace tir; -// Detect the region of input and output to be tensrized. +// Detect the region of input and output to be tensorized. // out_dom: the domain of root iter vars in output op // in_region: region of each input tensor. // return The location of the tensorized scope start. diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index 776538adbc0fa..07dcace0b381a 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -56,6 +56,8 @@ class BlockReadWriteDetector : public StmtExprVisitor { private: /*! \brief Iteration range for loop_vars */ std::unordered_map dom_map_; + /*! \brief Extra iteration range hint for free vars */ + std::unordered_map hint_map_; /*! \brief The buffers that the current block reads */ std::vector read_buffers_; /*! \brief The buffers that the current block writes */ @@ -96,6 +98,9 @@ class BlockReadWriteDetector : public StmtExprVisitor { /*! \brief Helper function to update a opaque access. */ void UpdateOpaque(const Var& buffer_var); + /*! \brief Helper function to relax the buffer indices */ + arith::IntSet RelaxAccessIndex(const PrimExpr& index); + void VisitStmt_(const ForNode* op) override; void VisitStmt_(const IfThenElseNode* op) override; void VisitStmt_(const BlockRealizeNode* op) override; @@ -140,10 +145,22 @@ void BlockReadWriteDetector::VisitExpr_(const LoadNode* op) { ExprVisitor::VisitExpr_(op); } +arith::IntSet BlockReadWriteDetector::RelaxAccessIndex(const PrimExpr& index) { + arith::IntSet relaxed = arith::EvalSet(index, dom_map_); + if (!hint_map_.empty()) { + // take non-relaxed var bound hints into considerations + // eg, if i * 4 + j with i >= 10 and j in [0, 4), only j in domain scope + // then the index region can be relaxed to [i*4, i*4+4) ^ [40, inf) + arith::IntSet hint_bound = arith::EvalSet(relaxed, hint_map_); + relaxed = arith::Intersect({relaxed, hint_bound}); + } + return relaxed; +} + void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) { std::vector relaxed_region; for (const PrimExpr& index : op->indices) { - relaxed_region.push_back(arith::EvalSet(index, dom_map_)); + relaxed_region.push_back(RelaxAccessIndex(index)); } Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region); ExprVisitor::VisitExpr_(op); @@ -160,12 +177,12 @@ void BlockReadWriteDetector::VisitStmt_(const IfThenElseNode* op) { VisitExpr(op->condition); { // Visit then branch - With ctx(op->condition, &dom_map_, true); + With ctx(op->condition, &dom_map_, &hint_map_, true); StmtExprVisitor::VisitStmt(op->then_case); } if (op->else_case.defined()) { // Visit else branch - With ctx(op->condition, &dom_map_, false); + With ctx(op->condition, &dom_map_, &hint_map_, false); StmtExprVisitor::VisitStmt(op->else_case); } } @@ -175,12 +192,12 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) { VisitExpr(op->args[0]); { // Visit then branch - With ctx(op->args[0], &dom_map_, true); + With ctx(op->args[0], &dom_map_, &hint_map_, true); StmtExprVisitor::VisitExpr(op->args[1]); } { // Visit else branch - With ctx(op->args[0], &dom_map_, false); + With ctx(op->args[0], &dom_map_, &hint_map_, false); StmtExprVisitor::VisitExpr(op->args[2]); } return; @@ -196,7 +213,7 @@ void BlockReadWriteDetector::VisitStmt_(const StoreNode* op) { void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) { std::vector relaxed_region; for (const PrimExpr& index : op->indices) { - relaxed_region.push_back(arith::EvalSet(index, dom_map_)); + relaxed_region.push_back(RelaxAccessIndex(index)); } Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region); StmtVisitor::VisitStmt_(op); diff --git a/src/tir/analysis/device_constraint_utils.cc b/src/tir/analysis/device_constraint_utils.cc index ae2b9500649e0..26cf66c4d4c01 100644 --- a/src/tir/analysis/device_constraint_utils.cc +++ b/src/tir/analysis/device_constraint_utils.cc @@ -255,7 +255,7 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { if (!new_buffer_var.same_as(new_load->buffer_var)) { return Load(load_node->dtype, new_buffer_var, load_node->index, load_node->predicate); } - return new_load; + return std::move(new_load); } PrimExpr VisitExpr_(const BufferLoadNode* buffer_load_node) final { @@ -265,7 +265,7 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { if (!new_buffer.same_as(new_buffer_load->buffer)) { return BufferLoad(new_buffer, new_buffer_load->indices, new_buffer_load->span); } - return new_buffer_load; + return std::move(new_buffer_load); } Stmt VisitStmt_(const LetStmtNode* let_stmt_node) final { @@ -284,7 +284,7 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { new_attr_stmt->body); } } - return new_attr_stmt; + return std::move(new_attr_stmt); } // ForNode default ok since loop_var never of PointerType @@ -302,7 +302,7 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { if (!new_buffer_var.same_as(new_store->buffer_var)) { Store(new_buffer_var, new_store->value, new_store->index, new_store->predicate); } - return new_store; + return std::move(new_store); } Stmt VisitStmt_(const BufferStoreNode* buffer_store_node) final { @@ -313,7 +313,7 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { return BufferStore(new_buffer, new_buffer_store->value, new_buffer_store->indices, new_buffer_store->span); } - return new_buffer_store; + return std::move(new_buffer_store); } Stmt VisitStmt_(const BufferRealizeNode* buffer_realize_node) final { @@ -324,7 +324,7 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { return BufferRealize(new_buffer, new_buffer_realize->bounds, new_buffer_realize->condition, new_buffer_realize->body, new_buffer_realize->span); } - return new_buffer_realize; + return std::move(new_buffer_realize); } // IfThenElseNode default ok @@ -338,7 +338,7 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { if (!new_buffer.same_as(new_prefetch->buffer)) { return Prefetch(new_buffer, prefetch_node->bounds, prefetch_node->span); } - return new_prefetch; + return std::move(new_prefetch); } // SeqStmtNode default ok @@ -390,7 +390,7 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { new_block->name_hint, new_block->body, new_block->init, new_block->alloc_buffers, std::move(new_match_buffers), new_block->annotations, new_block->span); } - return new_block; + return std::move(new_block); } // BlockRealizeNode default ok diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index ae72d592339f8..636cc7d0a5dbf 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -20,6 +20,7 @@ #define TVM_TIR_SCHEDULE_ANALYSIS_H_ #include +#include #include #include @@ -266,6 +267,39 @@ BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self */ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sref); +/*! + * \brief Get the IterVarType of the specific loop, according to the blocks it's bound to + * \param loop_sref The loop to be checked + * \return The IterVarType of the specific loop + */ +IterVarType GetLoopIterType(const StmtSRef& loop_sref); + +/*! + * \brief Get the lowest common ancestor of an array of blocks or loops on the sref tree + * \param srefs The block srefs or loop srefs whose lowest common ancestor is to be queried + * \return The lowest common ancestor of the input block srefs or loop srefs + * \note The input array is required to have at least one sref + */ +StmtSRef GetSRefLowestCommonAncestor(const Array& srefs); + +/*! + * \brief Checks if the given block has been applied by multi-level tiling. We check this by + * examine the block's annotation. + * \param block_sref The block to be checked + * \return A boolean indicating whether the block has been multi-level tiled. + */ +bool HasBeenMultiLevelTiled(const StmtSRef& block_sref); + +/*! + * \brief Collect all the feasible compute-at locations of the input block + * \param self The schedule state + * \param block_sref The block whose compute-at locations are to be collected + * \return All the feasible compute-at locations of the input block, given as an array of loop srefs + * and an array of their indices among the outer loops of the input block + */ +std::pair, std::vector> CollectComputeLocation(const ScheduleState& self, + const StmtSRef& block_sref); + /******** Producer-consumer relation ********/ /*! @@ -442,6 +476,88 @@ bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const S bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref, bool preserve_unit_loops); +/*! + * \brief Checks if the given AST contains the specific operators + * \param stmt The AST statement to be checked + * \param ops The list of operators to be checked + * \return A boolean indicating whether the AST contains the specific operators + */ +bool HasOp(const Stmt& stmt, const Array& ops); + +/*! + * \brief Checks if the given AST statement contains if-then-else, including + * 1) IfThenElse statement + * 2) Select expression + * 3) The operator `tir.if_then_else` + * 4) non-constant-true Block predicates + * \param stmt The AST statement to be checked + * \return A boolean indicating whether the statement contains the if-then-else pattern + */ +bool HasIfThenElse(const Stmt& stmt); + +/*! + * \brief Given the read/write region, extract the pattern of their index correspondence + * namely, the mapping from read index to the write index. + * \param read_region The read region + * \param write_region The write region + * \return A tuple of booleans, the extracted pattern + * 0) exists: if the pattern is found + * 1) surjective: if the pattern is surjective, i.e. each write index is mapped at least once + * e.g. A[i, j] = B[i, i, j] + * 2) injective: if the pattern is injective, i.e. each write index is mapped at most once. + * e.g. A[i, j] = B[i] + * 3) ordered: if the mapping is ordered + * 4) no_const_read: if there is no constant indexing in the read indices, + * e.g. A[i, j] = B[0, i, j] + * 5) no_shift_read: if there is no constant shift in the read indices, + * e.g. A[i, j] = B[i + 1, j] + */ +std::tuple +AnalyzeReadWritePattern(const BufferRegion& read_region, const BufferRegion& write_region); + +/*! + * \brief Check if the block is a data parallel block, i.e. all the block vars are data parallel + * \param block_sref The block to be checked + * \return A boolean flag indicating if the block is a data parallel block + */ +bool IsSpatial(const StmtSRef& block_sref); + +/*! + * \brief Check whether a block has a trivial binding, i.e. each block var is bound to a outer loop, + * from outer to inner. + * \param self The schedule state + * \param block_sref The block to be checked + * \return A boolean flag indicating if the block has a trivial binding + */ +bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref); + +/*! + * \brief Checks if the given block has data reuse opportunity and thus multi-level tiling is + * beneficial. + * \param self The schedule state + * \param block_sref The block to be checked + * \return A boolean indicating whether the block has data reuse opportunity + */ +bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref); + +/*! + * \brief Checks if the rfactor or cross thread reduction is beneficial to the given block. + * \param self The schedule state. + * \param block_sref The block to be checked. + * \param max_parallel_extent The maximum parallel jobs on the target. + * \param max_parallel_basic The maximum cores on the target. + * \return A boolean indicating whether the operation is beneficial. + */ +bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // + const tir::StmtSRef& block_sref, // + int64_t max_parallel_extent, // + int64_t max_parallel_basic); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 0a7d57effd0df..2053f8ddde934 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -646,6 +646,158 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr } } +IterVarType GetLoopIterType(const StmtSRef& loop_sref) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const Var& loop_var = loop->loop_var; + int n_spatial = 0; + int n_reduce = 0; + int n_other = 0; + auto f_visit = [&loop_var, &n_spatial, &n_reduce, &n_other](const ObjectRef& obj) -> bool { + if (const auto* realize = obj.as()) { + const BlockNode* block = realize->block.get(); + // Number of block vars and their bindings + ICHECK_EQ(realize->iter_values.size(), block->iter_vars.size()); + size_t n = realize->iter_values.size(); + for (size_t i = 0; i < n; ++i) { + const IterVar& iter_var = block->iter_vars[i]; + const PrimExpr& binding = realize->iter_values[i]; + // Categorize the current block var + int* ref = nullptr; + if (iter_var->iter_type == IterVarType::kDataPar) { + ref = &n_spatial; + } else if (iter_var->iter_type == IterVarType::kCommReduce) { + ref = &n_reduce; + } else { + ref = &n_other; + } + // Visit the binding to see if `loop_var` appears + PostOrderVisit(binding, [&ref, &loop_var](const ObjectRef& obj) -> void { + if (obj.same_as(loop_var)) { + (*ref) += 1; + } + }); + } + return false; + } + return true; + }; + PreOrderVisit(loop->body, f_visit); + if (n_other) { + return IterVarType::kOpaque; + } else if (n_spatial && n_reduce) { + return IterVarType::kOpaque; + } else if (n_reduce) { + return IterVarType::kCommReduce; + } else { + return IterVarType::kDataPar; + } +} + +StmtSRef GetSRefLowestCommonAncestor(const Array& srefs) { + CHECK(!srefs.empty()) << "ValueError: The input array is required to have at least one sref"; + + std::unordered_map sref_visited_cnt; + for (const StmtSRef& sref : srefs) { + const StmtSRefNode* p = sref.get(); + while (p != nullptr) { + ++sref_visited_cnt[p]; + p = p->parent; + } + } + size_t n_sref = srefs.size(); + const StmtSRefNode* p = srefs[0].get(); + while (p != nullptr && sref_visited_cnt[p] != n_sref) { + p = p->parent; + } + ICHECK(p != nullptr); + return GetRef(p); +} + +bool HasBeenMultiLevelTiled(const StmtSRef& block_sref) { + return tir::GetAnn(block_sref, tir::attr::meta_schedule_tiling_structure).defined(); +} + +std::pair, std::vector> CollectComputeLocation(const ScheduleState& self, + const StmtSRef& block_sref) { + Array location_srefs; + std::vector location_indices; + + // Step 1. Add the "compute-root" candidate. Add the "compute-inline" candidate if the block can + // be inlined. + if (CanComputeInline(self, block_sref)) { + location_srefs.push_back(StmtSRef::InlineMark()); + location_indices.push_back(-2); + } + location_srefs.push_back(StmtSRef::RootMark()); + location_indices.push_back(-1); + + // Step 2. If the block has no consumer, there is no more candidate. + Array consumers = GetConsumers(self, block_sref); + if (consumers.empty()) { + return std::make_pair(location_srefs, location_indices); + } + + // Step 3. Get the deepest loop that the input block can be computed at (namely "boundary"). If + // such a loop cannot be found, there is no more candidate and we just return. + StmtSRef loop_boundary = consumers.size() > 1 ? GetSRefLowestCommonAncestor(consumers) + : GetRef(consumers[0]->parent); + if (loop_boundary->StmtAs() == nullptr) { + return std::make_pair(location_srefs, location_indices); + } + + // Step 4. Collect the loops outside the first consumer and locate the boundary loop. The position + // of the boundary loop reveals the number of possible additional candidates. + Array loop_srefs = GetLoops(consumers[0]); + size_t lca_pos = + std::find(loop_srefs.begin(), loop_srefs.end(), loop_boundary) - loop_srefs.begin(); + ICHECK_LT(lca_pos, loop_srefs.size()); + size_t n_candidate = lca_pos + 1; + + // Step 5. Find the position of the deepest data-parallel loop among the candidate loops. This + // position is used for removing the unwanted candidates from the perspective of performance. + std::vector loop_iter_types; + loop_iter_types.reserve(n_candidate); + int i_last_datapar = -1; + for (size_t i = 0; i < n_candidate; ++i) { + // TODO(siyuan): improve the performance + IterVarType iter_type = GetLoopIterType(loop_srefs[i]); + loop_iter_types.push_back(iter_type); + if (iter_type == IterVarType::kDataPar) { + i_last_datapar = i; + } + } + // Step 6. Check and add the candidates in turn according to the following rules: + // - skip the unit loops (loops with extent 1); + // - do not consider the data-parallel loops after a not-data-parallel loop; + // - do not consider the trailing not-data-parallel loops. + location_srefs.reserve(n_candidate + 2); + location_indices.reserve(n_candidate + 2); + bool visited_reduce = false; + for (size_t i = 0; i < n_candidate; ++i) { + const int64_t* loop_extent = GetLoopIntExtent(loop_srefs[i]); + if (loop_extent != nullptr && *loop_extent == 1) { + continue; + } + + if (loop_iter_types[i] == IterVarType::kDataPar) { + if (visited_reduce) { + break; + } + } else { + visited_reduce = true; + if (static_cast(i) > i_last_datapar) { + break; + } + } + if (CanComputeAt(self, block_sref, loop_srefs[i], true)) { + location_srefs.push_back(loop_srefs[i]); + location_indices.push_back(i); + } + } + + return std::make_pair(location_srefs, location_indices); +} + /******** Producer-consumer relation ********/ Array GetProducers(const StmtSRef& block_sref, const BlockScope& scope) { @@ -1345,6 +1497,139 @@ StmtSRef GetSRefTreeRoot(const StmtSRef& sref) { return GetRef(p); } +/******** Misc ********/ + +bool HasOp(const Stmt& stmt, const Array& ops) { + std::unordered_set op_set; + op_set.reserve(ops.size()); + for (const Op& op : ops) { + op_set.insert(op.operator->()); + } + bool found = false; + PreOrderVisit(stmt, [&found, &op_set](const ObjectRef& obj) -> bool { + if (found) { + return false; + } + if (const auto* call = obj.as()) { + if (op_set.count(call->op.operator->())) { + found = true; + } + } + return !found; + }); + return found; +} + +bool HasIfThenElse(const Stmt& stmt) { + bool has_branch = false; + auto f_visit = [&has_branch](const ObjectRef& obj) -> bool { + if (has_branch) { + // stop visiting + return false; + } + if (const auto* realize = obj.as()) { + // Case 1: BlockRealize + if (!is_one(realize->predicate)) { + has_branch = true; + } + } else if (obj->IsInstance() || obj->IsInstance()) { + // Case 2: IfThenElse / Select + has_branch = true; + } else if (const auto* call = obj.as()) { + // Case 3: Call the `if_then_else` operator + static const Op& op_if_then_else = Op::Get("tir.if_then_else"); + if (call->op.same_as(op_if_then_else)) { + has_branch = true; + } + } + return !has_branch; + }; + PreOrderVisit(stmt, f_visit); + return has_branch; +} + +std::tuple +AnalyzeReadWritePattern(const BufferRegion& read_region, const BufferRegion& write_region) { + static constexpr const std::tuple kNotExist = + std::make_tuple(false, false, false, false, false, false); + // Step 1. Extract the write indices + int w_dim = write_region->buffer->shape.size(); + std::unordered_map var2idx; + var2idx.reserve(w_dim); + for (int i = 0; i < w_dim; ++i) { + const Range& dom = write_region->region[i]; + if (as_const_int(dom->extent) == nullptr) { + return kNotExist; + } + if (const auto* v = dom->min.as()) { + var2idx.emplace(v, i); + } else { + return kNotExist; + } + } + // Step 2. Map each read index to a write index + bool no_const_read = true; + bool no_shift_read = true; + int r_dim = read_region->buffer->shape.size(); + std::vector mapped(r_dim, -1); + for (int i = 0; i < r_dim; ++i) { + const Range& dom = read_region->region[i]; + if (as_const_int(dom->extent) == nullptr) { + return kNotExist; + } + // Case 1. Read index is a constant + if (as_const_int(dom->min) != nullptr) { + no_const_read = false; + continue; + } + // Case 2. Read index cannot be recognized as `var +/- const` + // where `var` is a write index and `const` is an optional constant shift + Optional opt_const = NullOpt; + const VarNode* var = + static_cast(AnalyzeVarWithShift(dom->min, &opt_const).get()); + if (var == nullptr || !var2idx.count(var)) { + return kNotExist; + } + // Case 3. Read index is `var +/- const` + mapped[i] = var2idx.at(var); + if (opt_const.defined()) { + no_shift_read = false; + } + } + // Step 3. Check if the mapping is ordered, and count how many times each var is mapped + std::vector mapped_counter(w_dim, 0); + bool ordered = true; + int last_mapped = -1; + for (int i : mapped) { + if (i != -1) { + ++mapped_counter[i]; + if (last_mapped != -1 && last_mapped > i) { + ordered = false; + } + last_mapped = i; + } + } + // Step 4. Check if the mapping is surjective or injective + // Surjective: each write index is mapped at least once + // Injective: each write index is mapped at most once + bool surjective = true; + bool injective = true; + for (int cnt : mapped_counter) { + if (cnt == 0) { + surjective = false; + } else if (cnt >= 2) { + injective = false; + } + } + return std::make_tuple(/*exist=*/true, surjective, injective, ordered, no_const_read, + no_shift_read); +} + /******** Storage Scope ********/ void CheckStorageScope(const ScheduleState& self, String storage_scope) { @@ -1376,5 +1661,191 @@ void CheckStorageScope(const ScheduleState& self, String storage_scope) { } } +bool IsSpatial(const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + for (const IterVar& iter_var : block->iter_vars) { + if (iter_var->iter_type != IterVarType::kDataPar) { + return false; + } + } + return true; +} + +bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + Array loops = GetLoops(block_sref); + Array binds = GetBlockRealize(self, block_sref)->iter_values; + if (loops.size() != binds.size()) { + return false; + } + for (int i = 0, n = loops.size(); i < n; ++i) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loops[i]); + if (binds[i].get() != loop->loop_var.get()) { + return false; + } + } + return true; +} + +bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + if (block->writes.size() != 1 || block->reads.empty() || IsSpatial(block_sref) || + !IsTrivialBinding(self, block_sref)) { + return false; + } + const BufferNode* write_buffer = block->writes[0]->buffer.get(); + // Step 1. Sort out spatial block variables + std::vector spatial_block_vars; + spatial_block_vars.reserve(block->iter_vars.size()); + for (const IterVar& block_var : block->iter_vars) { + if (block_var->iter_type == IterVarType::kDataPar) { + spatial_block_vars.push_back(block_var->var.get()); + } + } + // Step 2. Enumerate each read region, check the number of block vars that are not used + // to index the read region + int total_unused_block_vars = 0; + std::unordered_set read_buffers; + read_buffers.reserve(block->reads.size()); + for (const BufferRegion& buffer_region : block->reads) { + const BufferNode* buffer = buffer_region->buffer.get(); + const Array& regions = buffer_region->region; + // Step 2.1. Duplication of read buffers are not allowed + if (read_buffers.insert(buffer).second == false) { + return false; + } + // Step 2.2. Skip the reduction buffer + if (buffer == write_buffer) { + continue; + } + // Step 2.3. Collect the block vars that are used to index the read region + std::unordered_set vars; + for (const Range& range : regions) { + if (as_const_int(range->extent) == nullptr) { + return false; + } + for (const Var& var : UndefinedVars(range->min)) { + vars.insert(var.get()); + } + } + // Step 2.4. Check if the block vars are not used to index the read region + int n_unused_block_vars = 0; + for (const VarNode* block_var : spatial_block_vars) { + if (vars.count(block_var) == 0) { + ++n_unused_block_vars; + } + } + total_unused_block_vars += n_unused_block_vars; + } + return total_unused_block_vars >= 1; +} + +std::pair GetCumulativeSpaceAndReductionLength(const tir::ScheduleState& self, + const tir::StmtSRef& block_sref) { + Array loops = tir::GetLoops(block_sref); + int64_t cum_space_len = 1, cum_reduce_len = 1; + /* + * Return (-1, -1) if + * 1. there is some loop with type other than kDataPar and kCommReduce; + * 2. there is some loop which is dynamic. + */ + for (const tir::StmtSRef& loop_sref : loops) { + tir::IterVarType type = GetLoopIterType(loop_sref); + if (type == tir::kDataPar) { + const int64_t* extent = GetLoopIntExtent(loop_sref); + if (*extent != -1) { + cum_space_len *= *extent; + } else { + return std::make_pair(-1, -1); + } + } else if (type == tir::kCommReduce) { + const int64_t* extent = GetLoopIntExtent(loop_sref); + if (*extent != -1) { + cum_reduce_len *= *extent; + } else { + return std::make_pair(-1, -1); + } + } else { + return std::make_pair(-1, -1); + } + } + return std::make_pair(cum_space_len, cum_reduce_len); +} + +bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // + const tir::StmtSRef& block_sref, // + int64_t max_parallel_extent, // + int64_t max_parallel_basic) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + Array loops = tir::GetLoops(block_sref); + + // Cond 1. The block has only one write buffer + if (block->writes.size() != 1) { + return false; + } + + // Cond 2. The block is a reduction block and has trivial binding. + const StmtSRef& scope_sref = GetScopeRoot(self, block_sref, // + /*require_stage_pipeline=*/false, // + /*require_subtree_compact_dataflow=*/false); + if (!(IsReductionBlock(self, block_sref, scope_sref) && // + IsTrivialBinding(self, block_sref))) { + return false; + } + + // Cond 3. Every the loop axis must be either spatial axis or reduction axis. + for (const tir::StmtSRef& loop_sref : loops) { + const tir::IterVarType& type = GetLoopIterType(loop_sref); + if (type != tir::kDataPar && type != tir::kCommReduce) { + return false; + } + } + + // Cond 4. Whether there is at least one reduction loop. + // Cond 5. The loops are continuous, and the body of the innermost loop is exactly the block. + bool has_reduction_loop = false; + for (size_t i = 0; i < loops.size(); ++i) { + // Cond 4. + if (GetLoopIterType(loops[i]) == tir::kCommReduce) { + has_reduction_loop = true; + } + + // Cond 5. + const ForNode* loop_i = TVM_SREF_TO_FOR(loop_i, loops[i]); + if (i < loops.size() - 1) { + const ForNode* loop_i1 = TVM_SREF_TO_FOR(loop_i1, loops[i + 1]); + if (loop_i->body.get() != loop_i1) { + return false; + } + } else { + const auto* block_realize = loop_i->body.as(); + if (!block_realize || block_realize->block.get() != block) { + return false; + } + } + } + if (!has_reduction_loop) { + return false; + } + + // Cond 6. Can successfully calculating the cumulative loop length. + int64_t cum_space_len, cum_reduce_len; + std::tie(cum_space_len, cum_reduce_len) = GetCumulativeSpaceAndReductionLength(self, block_sref); + if (cum_space_len == -1 || cum_reduce_len == -1) { + return false; + } + + // Cond 7. + if (NeedsMultiLevelTiling(self, block_sref)) { + // Do not use rfactor/cross-thread-reduction if we have enough parallelism on spatial loops. + return !(cum_space_len >= cum_reduce_len || cum_space_len > max_parallel_extent); + } else if (cum_reduce_len > 1) { + // Always try rfactor/cross-thread-reduction for other reduction blocks. + return cum_reduce_len > max_parallel_basic; + } else { + return false; + } +} + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 9e5b6f949feb9..9f8dc6dd2dafd 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -242,6 +242,15 @@ Array ConcreteScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int throw; } +LoopRV ConcreteScheduleNode::SampleComputeLocation(const BlockRV& block_rv, + Optional decision) { + TVM_TIR_SCHEDULE_BEGIN(); + return CreateRV( + tir::SampleComputeLocation(state_, &this->rand_state_, this->GetSRef(block_rv), &decision)); + TVM_TIR_SCHEDULE_END("sample-compute-location", this->error_render_level_); + throw; +} + /******** Schedule: Get blocks & loops ********/ BlockRV ConcreteScheduleNode::GetBlock(const String& name, const String& func_name) { diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index d420728a9e3c0..96cb0f728835e 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -86,6 +86,8 @@ class ConcreteScheduleNode : public ScheduleNode { Optional decision = NullOpt) override; Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, Optional> decision = NullOpt) override; + LoopRV SampleComputeLocation(const BlockRV& block_rv, + Optional decision = NullOpt) override; /******** Schedule: Get blocks & loops ********/ BlockRV GetBlock(const String& name, const String& func_name = "main") override; Array GetLoops(const BlockRV& block_rv) override; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 45efd9f76cefa..f0b38af01b5f7 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -98,6 +98,17 @@ TVM_DLL std::vector SamplePerfectTile( support::LinearCongruentialEngine::TRandState* rand_state, // const tir::StmtSRef& loop_sref, int32_t n_split, int32_t max_innermost_factor, Optional>* decision); +/*! + * \brief Sample a compute-at location of the given block + * \param self The schedule state + * \param rand_state The random state + * \param block_sref The sref of the block whose compute-at location is to be sampled + * \param decision The sampling decision + * \return The sampled loop where the input block is to be computed at + */ +TVM_DLL tir::StmtSRef SampleComputeLocation( + tir::ScheduleState self, support::LinearCongruentialEngine::TRandState* rand_state, + const tir::StmtSRef& block_sref, Optional* decision); /******** Schedule: Get blocks & loops ********/ /*! diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc index 181e5a6cfa697..418e770a5c932 100644 --- a/src/tir/schedule/primitive/block_annotate.cc +++ b/src/tir/schedule/primitive/block_annotate.cc @@ -278,7 +278,7 @@ class StorageScopeMutator : StmtExprMutator { ptr->buffer = it->second; return PrimExpr(ptr); } else { - return res; + return std::move(res); } } @@ -291,7 +291,7 @@ class StorageScopeMutator : StmtExprMutator { ptr->buffer = it->second; return Stmt(ptr); } else { - return res; + return std::move(res); } } @@ -348,7 +348,7 @@ class StorageScopeMutator : StmtExprMutator { Block new_block(n); block_sref_reuse_->Set(GetRef(block), new_block); - return new_block; + return std::move(new_block); } } diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index 159171ecae317..4a80279d97cb0 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -635,7 +635,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff /*require_subtree_compact_dataflow=*/false); const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref); - // Step 2. Creat CacheStageInfo + // Step 2. Create CacheStageInfo CacheStageInfo info; info.read_buffer = read_buffer; // Create the corresponding buffer to be written, i.e. result of cache_read diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 6d944b38d46a2..0e767825573ff 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -354,6 +354,40 @@ std::vector SamplePerfectTile( return result; } +tir::StmtSRef SampleComputeLocation(tir::ScheduleState self, + support::LinearCongruentialEngine::TRandState* rand_state, + const StmtSRef& block_sref, Optional* decision) { + // Step 1. Collect all possible compute-at locations. + Array location_srefs; + std::vector location_indices; + std::tie(location_srefs, location_indices) = CollectComputeLocation(self, block_sref); + ICHECK_EQ(location_srefs.size(), location_indices.size()); + + // Step 2. If there was a previous decision, keep the decision unchanged if it exists in the + // location candidates. Otherwise, pick the location before the previous decision. + // Step 3. If there was not a previous decision, sample a decision from the collected locations. + if (decision->defined()) { + int64_t old_decision = Downcast(*decision)->value; + auto it = std::lower_bound(location_indices.begin(), location_indices.end(), old_decision); + int idx = it - location_indices.begin(); + + if (it != location_indices.end() && *it == old_decision) { + *decision = Integer(old_decision); + return location_srefs[idx]; + } else if (it != location_indices.begin()) { + *decision = Integer(location_indices[idx - 1]); + return location_srefs[idx - 1]; + } else { + *decision = Integer(-1); + return StmtSRef::RootMark(); + } + } else { + int sampled_idx = SampleInt(rand_state, 0, location_indices.size()); + *decision = Integer(location_indices[sampled_idx]); + return location_srefs[sampled_idx]; + } +} + /******** InstructionKind Registration ********/ struct SampleCategoricalTraits : public UnpackedInstTraits { @@ -418,8 +452,38 @@ struct SamplePerfectTileTraits : public UnpackedInstTraits { + static constexpr const char* kName = "SampleComputeLocation"; + static constexpr bool kIsPure = true; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 1; + + static LoopRV UnpackedApplyToSchedule(Schedule sch, // + BlockRV block_rv, // + Optional decision) { + return sch->SampleComputeLocation(block_rv, decision); + } + + static String UnpackedAsPython(Array outputs, // + String block_rv, // + Optional decision) { + PythonAPICall py("sample_compute_location"); + py.Input("block", block_rv); + py.Decision(decision); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + TVM_REGISTER_INST_KIND_TRAITS(SampleCategoricalTraits); TVM_REGISTER_INST_KIND_TRAITS(SamplePerfectTileTraits); +TVM_REGISTER_INST_KIND_TRAITS(SampleComputeLocationTraits); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 75939f00b8f43..6e33862c07cae 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -125,6 +125,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSampleCategorical") .set_body_method(&ScheduleNode::SampleCategorical); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSamplePerfectTile") .set_body_method(&ScheduleNode::SamplePerfectTile); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSampleComputeLocation") + .set_body_method(&ScheduleNode::SampleComputeLocation); /******** (FFI) Get blocks & loops ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetBlock") .set_body_method(&ScheduleNode::GetBlock); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index b4d1ba01e93e4..da7a2641b1627 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -73,6 +73,20 @@ Array TracedScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n return results; } +LoopRV TracedScheduleNode::SampleComputeLocation(const BlockRV& block_rv, + Optional decision) { + LoopRV result = CreateRV(tir::SampleComputeLocation(this->state_, &this->rand_state_, + this->GetSRef(block_rv), &decision)); + + static const InstructionKind& kind = InstructionKind::Get("SampleComputeLocation"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // + /*inputs=*/{block_rv}, + /*attrs=*/{}, + /*outputs=*/{result}), + /*decision=*/decision); + return result; +} + /******** Schedule: Get blocks & loops ********/ BlockRV TracedScheduleNode::GetBlock(const String& name, const String& func_name) { diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 5ce4763f117f1..b35f1b6e17bb2 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -51,6 +51,7 @@ class TracedScheduleNode : public ConcreteScheduleNode { Optional decision = NullOpt) final; Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, Optional> decision = NullOpt) final; + LoopRV SampleComputeLocation(const BlockRV& block_rv, Optional decision = NullOpt) final; /******** Schedule: Get blocks & loops ********/ BlockRV GetBlock(const String& name, const String& func_name = "main") final; Array GetLoops(const BlockRV& block_rv) final; diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index 860b3f64b5dca..be6d5a18a47f1 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -229,6 +229,34 @@ inline const int64_t* GetLoopIntExtent(const StmtSRef& loop_sref) { return as_const_int(loop->extent); } +/*! + * \brief Check if an expression consists of a single variable, + * or a variable plus/minus an constant integer shift + * \param expr The expression to be checked + * \return The single variable in the expression, or NullOpt if the expression is neither a variable + * or a constant shift from a variable + */ +inline Optional AnalyzeVarWithShift(const PrimExpr& expr, Optional* constant) { + if (const auto* var = expr.as()) { + *constant = NullOpt; + return GetRef(var); + } + arith::PVar var; + arith::PVar shift; + // match: "var + shift" + if ((var + shift).Match(expr) || (shift + var).Match(expr)) { + *constant = shift.Eval(); + return var.Eval(); + } + // match: "var - shift" + if ((var - shift).Match(expr)) { + IntImm result = shift.Eval(); + *constant = IntImm(result->dtype, -result->value); + return var.Eval(); + } + return NullOpt; +} + /******** Annotation ********/ /*! @@ -280,6 +308,72 @@ inline bool HasAnn(const StmtSRef& sref, const String& ann_key, const String& an return result.defined() && result.value() == ann_val; } +/*! + * \brief Check if a Block/For has a specific pair of annotation key and values + * \param sref The sref to the block or the for loop + * \param ann_key The annotation key to be checked + * \param ann_val The boolean annotation value to be checked + * \return Whether a Block/For has a specific pair of annotation key and values + */ +inline bool HasAnn(const StmtSRef& sref, const String& ann_key, bool ann_val) { + Optional result = GetAnn(sref, ann_key); + return result.defined() && result.value()->value == ann_val; +} + +/********** Helper Functions for RuleAddRFactor and RuleCrossThreadReduction **********/ + +/*! + * \brief Reorder the reduction loops to innermost positions if needed. + * \param sch The schedule + * \param block_rv The block where to apply the reorder + * \param fused_reduce_loop The fusion-generated loop to return. + * \param num_spatial_loops The number of spatial loops to return. + * \note Before invoking this helper function, make sure that the block has only spatial and + * reduction loop axes. + */ +inline void ReorderAndFuseReductionLoops(const tir::Schedule& sch, const tir::BlockRV& block_rv, + tir::LoopRV* fused_reduce_loop, + size_t* num_spatial_loops) { + Array loops = sch->GetLoops(block_rv); + Array loop_srefs; + for (const tir::LoopRV& loop_rv : loops) { + loop_srefs.push_back(sch->GetSRef(loop_rv)); + } + + Array new_order; + // Step 1. Add spatial loops. + *num_spatial_loops = 0; + for (size_t i = 0; i < loops.size(); ++i) { + if (GetLoopIterType(loop_srefs[i]) == tir::kDataPar) { + new_order.push_back(loops[i]); + (*num_spatial_loops)++; + } + } + // Step 2. Add reduction loops. + Array reduction_loops; + for (size_t i = 0; i < loops.size(); ++i) { + if (GetLoopIterType(loop_srefs[i]) == tir::kCommReduce) { + new_order.push_back(loops[i]); + reduction_loops.push_back(loops[i]); + } + } + // Step 3. Apply reordering if new_order differs from the original order. + ICHECK_EQ(new_order.size(), loops.size()); + for (size_t i = 0; i < loops.size(); ++i) { + if (!new_order[i].same_as(loops[i])) { + sch->Reorder(new_order); + break; + } + } + // Step 4. Fuse all the reduction loops if there are multiple reduction loops. + CHECK(!reduction_loops.empty()) << "ValueError: There should be at least one reduction loop"; + if (reduction_loops.size() > 1) { + *fused_reduce_loop = sch->Fuse(reduction_loops); + } else { + *fused_reduce_loop = reduction_loops[0]; + } +} + } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index 07f977860d933..20ddd7f84a35d 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -123,12 +123,12 @@ class BufferAccessRegionCollector : public StmtExprVisitor { StmtExprVisitor::VisitExpr(op->condition); { // Visit then branch - With ctx(op->condition, &dom_map_, true); + With ctx(op->condition, &dom_map_, &hint_map_, true); StmtExprVisitor::VisitStmt(op->then_case); } if (op->else_case.defined()) { // Visit else branch - With ctx(op->condition, &dom_map_, false); + With ctx(op->condition, &dom_map_, &hint_map_, false); StmtExprVisitor::VisitStmt(op->else_case); } } @@ -139,12 +139,12 @@ class BufferAccessRegionCollector : public StmtExprVisitor { StmtExprVisitor::VisitExpr(op->args[0]); { // Visit then branch - With ctx(op->args[0], &dom_map_, true); + With ctx(op->args[0], &dom_map_, &hint_map_, true); StmtExprVisitor::VisitExpr(op->args[1]); } { // Visit else branch - With ctx(op->args[0], &dom_map_, false); + With ctx(op->args[0], &dom_map_, &hint_map_, false); StmtExprVisitor::VisitExpr(op->args[2]); } return; @@ -282,6 +282,8 @@ class BufferAccessRegionCollector : public StmtExprVisitor { /*! \brief The map from loop vars to their iter range. */ std::unordered_map dom_map_; + /*! \brief Extra map from free vars to their iter range hints. */ + std::unordered_map hint_map_; /*! \brief The analyzer aware of loop domains. */ arith::Analyzer dom_analyzer_; /*! \brief The map from Buffer to it's relaxed access set. */ diff --git a/src/tir/transforms/coproc_sync.cc b/src/tir/transforms/coproc_sync.cc index 424a1bbb0ae6d..7a6d2d37c3760 100644 --- a/src/tir/transforms/coproc_sync.cc +++ b/src/tir/transforms/coproc_sync.cc @@ -450,7 +450,7 @@ class CoProcInstDepDetector : public StmtVisitor { std::unordered_set exit_ctx; // existing pop performed at enter std::vector > enter_pop; - // existing push peformed at exit + // existing push performed at exit std::vector > exit_push; // clear the state void clear() { diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 2423b09d4fb7c..4eb9cc5b1a90d 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -258,8 +258,8 @@ Map ConditionalBoundsContext::GetVarBoundsFromCondition() { arith::Analyzer analyzer; PrimExpr condition = is_true_branch_ ? condition_ : analyzer.Simplify(!condition_); Array equations; - std::unordered_set var_set; - std::function fvisit = [&equations, &var_set, &fvisit](const PrimExpr& e) { + Array vars; + std::function fvisit = [&equations, &vars, &fvisit](const PrimExpr& e) { if (e->IsInstance() || e->IsInstance() || e->IsInstance() || e->IsInstance() || e->IsInstance() || e->IsInstance()) { bool is_simple = true; @@ -278,7 +278,12 @@ Map ConditionalBoundsContext::GetVarBoundsFromCondition() { } }); if (is_simple && !cand_vars.empty()) { - for (const Var& var : cand_vars) var_set.insert(var); + for (const Var& new_var : cand_vars) { + if (!std::any_of(vars.begin(), vars.end(), + [&new_var](const Var& v) { return v.same_as(new_var); })) { + vars.push_back(new_var); + } + } equations.push_back(Downcast(e)); } } else if (e->IsInstance()) { @@ -293,18 +298,24 @@ Map ConditionalBoundsContext::GetVarBoundsFromCondition() { } }; fvisit(condition); - if (equations.empty() || var_set.empty()) { + if (equations.empty() || vars.empty()) { return Map(); } // build dom ranges for related vars - Array vars = Array(var_set.begin(), var_set.end()); Map ranges; for (const Var& v : vars) { - auto it = dom_map_->find(v.get()); - if (it != dom_map_->end()) { - const auto& int_set = it->second; - ranges.Set(v, Range::FromMinExtent(int_set.min(), - analyzer.Simplify(int_set.max() - int_set.min() + 1))); + arith::IntSet dom; + auto relax_it = relax_map_->find(v.get()); + if (relax_it != relax_map_->end()) { + dom = relax_it->second; + } else { + auto hint_it = hint_map_->find(v.get()); + if (hint_it != hint_map_->end()) { + dom = hint_it->second; + } + } + if (dom.defined()) { + ranges.Set(v, Range::FromMinExtent(dom.min(), analyzer.Simplify(dom.max() - dom.min() + 1))); } } // solve constraints @@ -314,24 +325,53 @@ Map ConditionalBoundsContext::GetVarBoundsFromCondition() { } ConditionalBoundsContext::ConditionalBoundsContext( - const PrimExpr& condition, std::unordered_map* dom_map, - bool is_true_branch) - : condition_(condition), dom_map_(dom_map), is_true_branch_(is_true_branch) {} + const PrimExpr& condition, std::unordered_map* relax_map, + std::unordered_map* hint_map, bool is_true_branch) + : condition_(condition), + relax_map_(relax_map), + hint_map_(hint_map), + is_true_branch_(is_true_branch) {} void ConditionalBoundsContext::EnterWithScope() { for (const auto& p : GetVarBoundsFromCondition()) { const auto* var = p.first.get(); - auto it = dom_map_->find(var); - if (it != dom_map_->end()) { - origin_map_.emplace(var, it->second); - it->second = arith::Intersect({it->second, arith::IntSet::FromRange(p.second)}); + arith::IntSet new_dom = arith::IntSet::FromRange(p.second); + auto relax_it = relax_map_->find(var); + if (relax_it != relax_map_->end()) { + // this is a bound for relaxed var + origin_map_.emplace(var, relax_it->second); + relax_it->second = arith::Intersect({relax_it->second, new_dom}); + } else { + // this is a bound for free var + auto hint_it = hint_map_->find(var); + if (hint_it != hint_map_->end()) { + origin_map_.emplace(var, hint_it->second); + hint_it->second = arith::Intersect({hint_it->second, new_dom}); + } else { + origin_map_.emplace(var, arith::IntSet::Nothing()); + hint_map_->insert(hint_it, {var, new_dom}); + } } } } void ConditionalBoundsContext::ExitWithScope() { for (const auto& p : origin_map_) { - (*dom_map_)[p.first] = p.second; + const auto* var = p.first; + auto relax_it = relax_map_->find(var); + if (relax_it != relax_map_->end()) { + // recover bound for relaxed var + relax_it->second = p.second; + } else { + // recover bound for free var + auto hint_it = hint_map_->find(var); + ICHECK(hint_it != hint_map_->end()); + if (p.second.IsNothing()) { + hint_map_->erase(hint_it); + } else { + hint_it->second = p.second; + } + } } } diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index 7b1d34c8162de..da52a82a2f087 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -231,9 +231,9 @@ Bool IsFromLegacyTESchedule(PrimFunc f); *\brief Context helper to update domain map within conditional scope. * * Assume the condition is `0 <= i && i < 9` and global domain of i is [0, 20], thus `bounds[i]` is - *[0, 8]. Then `With ctx(&dom_map, bounds, true)` step into scope where - *dom_map[i] is [0, 8] and `With ctx(&dom_map, bounds, false)` step into - *scope where dom_map[i] is [9, 20] + * [0, 8]. Then `With ctx(condition, &relax_map, &hint_map, true)` step + *into scope where dom_map[i] is [0, 8] and `With ctx(condition, + *&relax_map, &hint_map, false)` step into scope where dom_map[i] is [9, 20] */ class ConditionalBoundsContext { private: @@ -241,11 +241,13 @@ class ConditionalBoundsContext { /*! * \brief Construct a condition bounds context. * \param condition The condition holds on true branch. - * \param dom_map The global domain map to be updated. + * \param relax_map The domain map for relaxed vars to update. + * \param hint_map The domain map for free vars to update. * \param is_true_branch Whether step into the branch where condition bounds holds. */ ConditionalBoundsContext(const PrimExpr& condition, - std::unordered_map* dom_map, + std::unordered_map* relax_map, + std::unordered_map* hint_map, bool is_true_branch); void EnterWithScope(); void ExitWithScope(); @@ -255,8 +257,10 @@ class ConditionalBoundsContext { /*! \brief the condition holds on true branch. */ const PrimExpr& condition_; - /*! \brief global domain map to updated */ - std::unordered_map* dom_map_; + /*! \brief domain map for relaxed vars to update */ + std::unordered_map* relax_map_; + /*! \brief domain map for free vars to update */ + std::unordered_map* hint_map_; /*! \brief whether is on true branch */ bool is_true_branch_; /*! \brief used to record and restore original var bounds */ diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index 169983a525dfd..6365e09246fc7 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -57,33 +57,26 @@ PrimFunc MakeUnpackedAPI(PrimFunc&& func) { const Stmt nop = Evaluate(0); std::vector device_init; - // Create arg to buffer binder - std::unordered_map vmap; - ArgBinder binder(&vmap); // Collect variables and buffers to map between Array args; - std::vector> var_def; - bool buffer_map_found = false; - - for (int i = 0; i < static_cast(func_ptr->params.size()); ++i) { - Var param = func_ptr->params[i]; - - auto it = func_ptr->buffer_map.find(param); - if (it != func_ptr->buffer_map.end()) { - args.push_back((*it).second->data); - buffer_map_found = true; + for (const Var& param : func->params) { + // Ideally all func params should have Buffers defined in the buffer_map + // We should look to insert buffer_maps for all PrimFuncs that are returned + // to the core compiler. + if (func->buffer_map.find(param) != func->buffer_map.end()) { + args.push_back(func->buffer_map[param]->data); } else { args.push_back(param); } } - if (buffer_map_found) { + if (func->buffer_map.size()) { device_init.push_back(AttrStmt(node, attr::device_id, device_id, nop)); device_init.push_back(AttrStmt(node, attr::device_type, device_type, nop)); } - func_ptr->body = MergeNest({device_init, binder.init_nest(), binder.asserts()}, func_ptr->body); + func_ptr->body = MergeNest(device_init, func_ptr->body); func_ptr->params = args; func_ptr->ret_type = PrimType(DataType::Int(32)); diff --git a/src/tir/transforms/unify_thread_binding.cc b/src/tir/transforms/unify_thread_binding.cc index 9c1aab6344961..d9b5f529a35c1 100644 --- a/src/tir/transforms/unify_thread_binding.cc +++ b/src/tir/transforms/unify_thread_binding.cc @@ -66,7 +66,7 @@ class ThreadBindingUnifier : public StmtExprMutator { } For new_loop = Downcast(stmt); new_loop.CopyOnWrite()->annotations = std::move(annotations); - return new_loop; + return std::move(new_loop); } template diff --git a/src/tir/usmp/algo/greedy.cc b/src/tir/usmp/algo/greedy.cc index 5e1ce5f289c13..324474c569d4a 100644 --- a/src/tir/usmp/algo/greedy.cc +++ b/src/tir/usmp/algo/greedy.cc @@ -39,6 +39,8 @@ #include #include #include +#include +#include #include namespace tvm { @@ -47,109 +49,93 @@ namespace usmp { namespace algo { /*! - * \brief This is the base class for Greedy Algorithms where the sorting - * is specialized in the extended classes based on the greedy criteria. + * \brief Rounds up the offset to satisfy the alignement requirement */ -class GreedyBase { - public: - GreedyBase() {} - /*! - * \brief This function should be implemented by the extended classes to sort the BufferInfo - * objects based on a criteria and then calling PostSortAllocation. - */ - virtual Map PlanMemory(const Array& buffer_info_arr) = 0; +size_t GreedyBase::round_up_to_byte_alignment(const size_t& non_aligned_byte_offset, + const int& byte_alignment) { + return ((non_aligned_byte_offset + byte_alignment - 1) / byte_alignment) * byte_alignment; +} - protected: - /*! - * \brief Rounds up the offset to satisfy the alignement requirement - */ - size_t round_up_to_byte_alignment(const size_t& non_aligned_byte_offset, - const int& byte_alignment) { - return ((non_aligned_byte_offset + byte_alignment - 1) / byte_alignment) * byte_alignment; +/*! + * \brief A helper function check whether a offset is valid given the constraints + */ +bool GreedyBase::IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset, + const size_t& size_bytes) { + if (candidate_pool->size_hint_bytes == -1) { + // this means pool is not bounded + return true; + } + auto pool_size = static_cast(candidate_pool->size_hint_bytes->value); + auto max_address = next_offset + size_bytes; + if (max_address <= pool_size) { + return true; } + return false; +} - /*! - * \brief A helper function check whether a offset is valid given the constraints - */ - bool IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset, - const size_t& size_bytes) { - if (candidate_pool->size_hint_bytes == -1) { - // this means pool is not bounded - return true; - } - auto pool_size = static_cast(candidate_pool->size_hint_bytes->value); - auto max_address = next_offset + size_bytes; - if (max_address <= pool_size) { - return true; +/*! + * \brief Selects a pool for placement in the given set of ordered pool candidates + */ +PoolInfo GreedyBase::SelectPlacementPool( + const BufferInfo& buf_info, + const std::unordered_map& pool_offsets) { + // Here the pool candidates are ordered when it is consumed by the algorithm. + // This could be from order the user has specified. However, schedulers are + // welcome to change the order for performance reasons. + for (const auto& pool_info : buf_info->pool_candidates) { + if (pool_offsets.count(pool_info)) { + return pool_info; } - return false; } + CHECK(false) << "TVM USMP Error: the space available in the provided pools exceeded when " + "trying to allocate the buffer : " + << buf_info << "\n. Please increase the size_hints for memory pools."; + return PoolInfo(); +} - /*! - * \brief Selects a pool for placement in the given set of ordered pool candidates - */ - PoolInfo SelectPlacementPool( - const BufferInfo& buf_info, - const std::unordered_map& pool_offsets) { - // Here the pool candidates are ordered when it is consumed by the algorithm. - // This could be from order the user has specified. However, schedulers are - // welcome to change the order for performance reasons. +/*! + * \brief This is the base allocation function that works on sorted BufferInfo objects based + * on the greedy heuristic. The sorting algorithm has to be called before calling this. + */ +Map GreedyBase::PostSortAllocation( + const std::vector& buffer_info_vec) { + Map pool_allocations; + for (const auto& buf_info : buffer_info_vec) { + std::unordered_map pool_offset_candidates; for (const auto& pool_info : buf_info->pool_candidates) { - if (pool_offsets.count(pool_info)) { - return pool_info; + // Mark pool candidates that satisfy the size constraints. + if (IsValidPlacement(pool_info, 0, buf_info->size_bytes->value)) { + pool_offset_candidates[pool_info] = 0; } } - CHECK(false) << "TVM USMP Error: the space available in the provided pools exceeded when " - "trying to allocate the buffer : " - << buf_info << "\n. Please increase the size_hints for memory pools."; - return PoolInfo(); - } - /*! - * \brief This is the base allocation function that works on sorted BufferInfo objects based - * on the greedy heuristic. The sorting algorithm has to be called before calling this. - */ - Map PostSortAllocation( - const std::vector& buffer_info_vec) { - Map pool_allocations; - for (const auto& buf_info : buffer_info_vec) { - std::unordered_map pool_offset_candidates; - for (const auto& pool_info : buf_info->pool_candidates) { - // Mark pool candidates that satisfy the size constraints. - if (IsValidPlacement(pool_info, 0, buf_info->size_bytes->value)) { - pool_offset_candidates[pool_info] = 0; - } - } - - for (const auto& conflict_buf_info_obj : buf_info->conflicts) { - auto conflict_buf_info = Downcast(conflict_buf_info_obj); - size_t next_offset = 0; - // We only look at already allocated BufferInfo in-terms of conflicts. - if (pool_allocations.count(conflict_buf_info)) { - auto pool_allocation = pool_allocations[conflict_buf_info]; - next_offset = pool_allocation->byte_offset + conflict_buf_info->size_bytes; - next_offset = - round_up_to_byte_alignment(next_offset, conflict_buf_info->alignment->value); - // Checks whether the next offset in the same pool as the conflicting BufferInfo is valid. - if (IsValidPlacement(pool_allocation->pool_info, next_offset, - buf_info->size_bytes->value)) { - // There could be multiple conflicting BufferInfo in the same pool. - // Thus, we need to make sure we pick the largest offset of them all. - if (next_offset > pool_offset_candidates[pool_allocation->pool_info]) { - pool_offset_candidates[pool_allocation->pool_info] = next_offset; - } - } else { - pool_offset_candidates.erase(pool_allocation->pool_info); + for (const auto& conflict_buf_info_obj : buf_info->conflicts) { + auto conflict_buf_info = Downcast(conflict_buf_info_obj); + size_t next_offset = 0; + // We only look at already allocated BufferInfo in-terms of conflicts. + if (pool_allocations.count(conflict_buf_info)) { + auto pool_allocation = pool_allocations[conflict_buf_info]; + next_offset = pool_allocation->byte_offset + conflict_buf_info->size_bytes; + next_offset = round_up_to_byte_alignment(next_offset, conflict_buf_info->alignment->value); + // Checks whether the next offset in the same pool as the conflicting BufferInfo is valid. + if (IsValidPlacement(pool_allocation->pool_info, next_offset, + buf_info->size_bytes->value)) { + // There could be multiple conflicting BufferInfo in the same pool. + // Thus, we need to make sure we pick the largest offset of them all. + if (next_offset > pool_offset_candidates[pool_allocation->pool_info]) { + pool_offset_candidates[pool_allocation->pool_info] = next_offset; } + } else { + pool_offset_candidates.erase(pool_allocation->pool_info); } } - auto selected_pool = SelectPlacementPool(buf_info, pool_offset_candidates); - pool_allocations.Set( - buf_info, PoolAllocation(selected_pool, Integer(pool_offset_candidates[selected_pool]))); } - return pool_allocations; + auto selected_pool = SelectPlacementPool(buf_info, pool_offset_candidates); + pool_allocations.Set( + buf_info, PoolAllocation(selected_pool, Integer(pool_offset_candidates[selected_pool]))); } -}; + return pool_allocations; +} /*! * \brief This class implements Greedy by the size of BufferInfo diff --git a/src/tir/usmp/algo/hill_climb.cc b/src/tir/usmp/algo/hill_climb.cc new file mode 100644 index 0000000000000..c4ed73eb2feb2 --- /dev/null +++ b/src/tir/usmp/algo/hill_climb.cc @@ -0,0 +1,339 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/analysis/usmp/algo/hill_climb.cc + * \brief Implement greedy by size memory planning algorithm + */ +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace tvm { +namespace tir { +namespace usmp { +namespace algo { + +/* + * Simulated annealing / Hill climb + * + * Works by continiously invoking 'greedy-by-size' allocation, + * assessing the result, and introducing permutations to the allocation + * order which hopefully will led to more 'compact' memory allocation. + */ +class HillClimbAllocator : public GreedyBase { + private: + size_t memory_pressure_ = 0; + + public: + explicit HillClimbAllocator(size_t memory_pressure) + : GreedyBase(), memory_pressure_(memory_pressure) {} + + protected: + using alloc_map_t = std::unordered_map; + + /* + * Initial sorting routine + */ + void sort_vector(std::vector* buffer_info_vec) { + std::sort(buffer_info_vec->begin(), buffer_info_vec->end(), + [](const BufferInfo& a, const BufferInfo& b) { + if (a->size_bytes->value == b->size_bytes->value) { + if (a->conflicts.size() == b->conflicts.size()) { + return std::string(a->name_hint->data) > std::string(b->name_hint->data); + } else { + return a->conflicts.size() > b->conflicts.size(); + } + } + return a->size_bytes->value > b->size_bytes->value; + }); + } + + /* + * HillClimb's version of greedy allocation + * \param buffer_info_vec - buffers in specific order for allocation + */ + alloc_map_t greedy(const std::vector& buffer_info_vec) { + alloc_map_t pool_allocations(buffer_info_vec.size()); + for (const auto& buf_info : buffer_info_vec) { + std::unordered_map pool_offset_candidates; + for (const auto& pool_info : buf_info->pool_candidates) { + if (IsValidPlacement(pool_info, 0, buf_info->size_bytes->value)) { + pool_offset_candidates[pool_info] = 0; + } + } + + std::vector buf_conf; + for (const auto& conflict_buf_info_obj : buf_info->conflicts) { + const BufferInfoNode* conflict_buf_info = conflict_buf_info_obj.as(); + if (pool_allocations.end() != pool_allocations.find(conflict_buf_info)) { + buf_conf.push_back(conflict_buf_info); + } + } + + // extra sorting for pool offsets + std::sort(buf_conf.begin(), buf_conf.end(), + [&pool_allocations](const auto* a, const auto* b) { + return pool_allocations[a]->byte_offset->value < + pool_allocations[b]->byte_offset->value; + }); + + for (const auto* conflict_buf_info : buf_conf) { + size_t next_offset = 0; + auto pool_allocation = pool_allocations[conflict_buf_info]; + next_offset = pool_allocation->byte_offset + conflict_buf_info->size_bytes; + next_offset = round_up_to_byte_alignment(next_offset, conflict_buf_info->alignment->value); + if (!pool_offset_candidates.count(pool_allocation->pool_info)) { + continue; + } + if (IsValidPlacement(pool_allocation->pool_info, next_offset, + buf_info->size_bytes->value)) { + if (next_offset > pool_offset_candidates[pool_allocation->pool_info] && + pool_offset_candidates[pool_allocation->pool_info] + + static_cast(buf_info->size_bytes) > + static_cast(pool_allocation->byte_offset)) { + pool_offset_candidates[pool_allocation->pool_info] = next_offset; + } + } else { + pool_offset_candidates.erase(pool_allocation->pool_info); + } + } + auto selected_pool = SelectPlacementPool(buf_info, pool_offset_candidates); + pool_allocations[buf_info.as()] = + PoolAllocation(selected_pool, Integer(pool_offset_candidates[selected_pool])); + } + return pool_allocations; + } + + /* + * Finds highest allocated memory address for each pool + */ + std::unordered_map find_highest( + alloc_map_t* pool_allocations) { + std::unordered_map pool_sizes; + for (const auto& it : *pool_allocations) { + const BufferInfoNode* buf = it.first; + const PoolAllocation& pa = it.second; + size_t high_sz = pa->byte_offset + buf->size_bytes; + if (pool_sizes[pa->pool_info] <= high_sz) { + pool_sizes[pa->pool_info] = high_sz; + } + } + return pool_sizes; + } + + /* + * Collects lists of first and secind level neigbors for provided buf. + * First level are the immediate neighbors of the buf and + * second level are the immediate neighbors of the first level nodes + */ + template + void collect_neighbor_lists(const BufferInfoNode* buf, + std::vector* first_level, + std::vector* second_level, const TPos& _pos) { + std::unordered_map first_level_set; + std::unordered_map second_level_set; + + auto buf_pos = _pos(buf); + for (const auto& c1 : buf->conflicts) { + const auto* c1_buf = c1.as(); + int c1_pos = _pos(c1_buf); + if (buf_pos > c1_pos) { + first_level_set[c1_pos] = c1_buf; + } + int c2_pos = -1; + for (const auto& c2 : c1_buf->conflicts) { + const auto c2_buf = c2.as(); + if (c1_pos > (c2_pos = _pos(c2_buf))) { + second_level_set[c2_pos] = c2_buf; + } + } + } + + // std::vector first_level; + for (const auto& i : first_level_set) { + first_level->push_back(i.second); + } + // std::vector second_level; + for (const auto& i : second_level_set) { + second_level->push_back(i.second); + } + } + + public: + Map PlanMemory(const Array& buffer_info_arr) { +// rand_r does not exist on Windows platform +#if defined(__linux__) || defined(__ANDROID__) + unsigned int _seedp = 0; +#define rnd_func() rand_r(&_seedp) +#else +#define rnd_func() rand() +#endif + + std::vector buffer_info_vec; + for (const auto& buffer_info : buffer_info_arr) { + ICHECK(buffer_info->pool_candidates.size()) + << "Cannot process buffer \"" << buffer_info->name_hint << "\" with no pool candidates"; + buffer_info_vec.push_back(std::move(buffer_info)); + } + + sort_vector(&buffer_info_vec); + + // populate positional index map + std::unordered_map _pos_map; + for (size_t index = 0; index < buffer_info_vec.size(); ++index) { + _pos_map[buffer_info_vec[index].as()] = index; + } + + size_t total_size = 0; + int attempts = 0; + + int swap_i1 = -1; + int swap_i2 = -1; + size_t desired_bytes_ = memory_pressure_; + constexpr auto _max_attempts = 500; + alloc_map_t rollback_pool_allocations; + alloc_map_t result_pool_allocations; + alloc_map_t pool_allocations; + + auto swap_buffers = [&buffer_info_vec, &_pos_map](int i1, int i2) { + if (i1 == i2) return; + auto b1 = buffer_info_vec[i1]; + auto b2 = buffer_info_vec[i2]; + buffer_info_vec[i1] = b2; + buffer_info_vec[i2] = b1; + + _pos_map[b1.as()] = i2; + _pos_map[b2.as()] = i1; + }; + + auto _pos = [&_pos_map](const auto* e) { + auto it = _pos_map.find(e); + if (it != _pos_map.end()) { + return it->second; + } + LOG(FATAL) << "node is not indexed in the _pos_map"; + return -1; + }; + + for (; attempts < _max_attempts; ++attempts) { + rollback_pool_allocations = std::move(pool_allocations); + pool_allocations = std::move(greedy(buffer_info_vec)); + + // estimate result buffers + std::unordered_map pool_sizes = + find_highest(&pool_allocations); + // calculate summary + size_t total = 0; + for (const auto& el : pool_sizes) { + total += el.second; + } + // accept/reject result heuristic + if (!total_size || /* first run */ + (total_size > total || /* always accept if better or with some probability */ + rnd_func() % 100 < static_cast(50 * (total - total_size) / total / attempts))) { + // remember winning combination + result_pool_allocations = pool_allocations; + total_size = total; + + // reached desired size + if (total_size <= desired_bytes_) { + break; + } + + } else { + // rollback + swap_buffers(swap_i2, swap_i1); + pool_allocations = std::move(rollback_pool_allocations); + pool_sizes = find_highest(&pool_allocations); + } + + std::vector max_pool_buf; + + for (const auto& it : pool_allocations) { + const auto* buf = it.first; + const auto pa = it.second; + size_t high_sz = pa->byte_offset + buf->size_bytes; + if (pool_sizes[pa->pool_info] == high_sz) { + max_pool_buf.push_back(buf); + } + } + + // pick highest + const BufferInfoNode* node = max_pool_buf[rnd_func() % max_pool_buf.size()]; + std::vector first_level; + std::vector second_level; + collect_neighbor_lists(node, &first_level, &second_level, _pos); + + // retry if no first level neightbors were collected + if (!first_level.size()) { + continue; + } + + // pick the buffers + const BufferInfoNode* swap_buf1 = first_level[rnd_func() % first_level.size()]; + const BufferInfoNode* swap_buf2 = swap_buf1; + while (swap_buf2 == swap_buf1) { + swap_buf2 = second_level.size() && (!first_level.size() || (rnd_func() % 100 > 25)) + ? second_level[rnd_func() % second_level.size()] + : first_level[rnd_func() % first_level.size()]; + + if (second_level.size() < 2 && first_level.size() < 2) break; + } + if (swap_buf1 == swap_buf2) { + continue; + } + + swap_i1 = _pos(swap_buf1); + swap_i2 = _pos(swap_buf2); + // do swap + swap_buffers(swap_i1, swap_i2); + } + + Map result; + // return winning combination + for (auto it : result_pool_allocations) { + result.Set(GetRef(it.first), it.second); + } + return result; + } +}; + +Map HillClimb(const Array& buffer_info_arr, + const Integer& memory_pressure) { + return HillClimbAllocator(memory_pressure).PlanMemory(buffer_info_arr); +} + +TVM_REGISTER_GLOBAL("tir.usmp.algo.hill_climb") + .set_body_typed([](Array buffer_info_arr, Integer memory_pressure) { + return HillClimb(buffer_info_arr, memory_pressure); + }); + +} // namespace algo +} // namespace usmp +} // namespace tir +} // namespace tvm diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc index ea53f27e5558a..fb4fb52c507e1 100644 --- a/src/tir/usmp/analysis/extract_buffer_info.cc +++ b/src/tir/usmp/analysis/extract_buffer_info.cc @@ -31,6 +31,7 @@ #include #include #include +#include #include #include @@ -58,7 +59,9 @@ class BufferInfoExtractor : public StmtExprVisitor { public: explicit BufferInfoExtractor(const IRModule& module) : module_(module) { for (const auto& gv_func : module_->functions) { - functions_.Set(gv_func.first->name_hint, Downcast(gv_func.second)); + if (gv_func.second->IsInstance()) { + functions_.Set(gv_func.first->name_hint, Downcast(gv_func.second)); + } } // Pushing a scope info for the initial body of the main function scope_stack_.push(ScopeInfo()); @@ -342,16 +345,24 @@ void BufferInfoExtractor::VisitExpr_(const VarNode* op) { Array static GetMatchedBuffers(const PrimFunc& func) { Array buffer_vars; - for (const auto& param : func->params) { + for (unsigned int i = 0; i < func->params.size() - 1; i++) { + Var param = func->params[i]; buffer_vars.push_back(func->buffer_map[param]->data); } + Var last_param = func->params.back(); + // Checks whether last var is present in the buffer map + // because it could be the resource handle + if (func->buffer_map.find(last_param) != func->buffer_map.end()) { + buffer_vars.push_back(func->buffer_map[last_param]->data); + } return buffer_vars; } void BufferInfoExtractor::UpdateAliases(const Array& args, const PrimFunc& func) { auto param_buffers = GetMatchedBuffers(func); - ICHECK(args.size() == param_buffers.size()); - for (size_t i = 0; i < args.size(); i++) { + // Last var could be a resource handle that does not have a Buffer + ICHECK(args.size() == param_buffers.size() || args.size() - 1 == param_buffers.size()); + for (size_t i = 0; i < param_buffers.size(); i++) { auto arg = args[i]; auto param_buf = param_buffers[i]; // If tir.allocates are passed in to functions diff --git a/src/tir/usmp/transform/assign_pool_info.cc b/src/tir/usmp/transform/assign_pool_info.cc new file mode 100644 index 0000000000000..516ddd1a241bf --- /dev/null +++ b/src/tir/usmp/transform/assign_pool_info.cc @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace tir { +namespace usmp { + +/*! \brief Assign PoolInfo objects to allocate that does not have any. + * The schedulers have the oppurtunity to assign PoolInfo objects to + * allocate nodes. However, each allocate node is expected to have + * at least one PoolInfo node assigned to it. If it was not the case, + * this Pass will assign all PoolInfo objects that the target could + * access.*/ +class PoolInfoAssigner : public StmtExprMutator { + public: + explicit PoolInfoAssigner(const IRModule& module) { + PrimFunc main_func = + Downcast(module->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix)); + ICHECK(main_func.defined()) << "main function is not in the module"; + Optional target_host = main_func->GetAttr(tvm::attr::kTarget); + ICHECK(target_host) << "main function does not have a target attr"; + Array pool_infos = + module->GetAttr>(tvm::attr::kPoolInfoIRModuleAttr) + .value_or({usmp::PoolInfo("global_workspace", + {{target_host.value(), usmp::kTargetPoolReadWriteAccess}}, + usmp::kUnrestrictedPoolSizeHint, Bool(true))}); + for (const usmp::PoolInfo& pool_info : pool_infos) { + for (const auto& kv : pool_info->target_access) { + Target tgt = kv.first; + if (target_pool_infos_.find(tgt) == target_pool_infos_.end()) { + target_pool_infos_.Set(tgt, Array()); + } + Array pool_info_arr = target_pool_infos_[tgt]; + pool_info_arr.push_back(pool_info); + target_pool_infos_.Set(tgt, pool_info_arr); + } + } + mod_ = module->ShallowCopy(); + } + + IRModule operator()(); + + private: + Stmt VisitStmt_(const AllocateNode* op) override; + + IRModule mod_; + Map> target_pool_infos_; + PrimFunc func_; +}; + +Stmt PoolInfoAssigner::VisitStmt_(const AllocateNode* op) { + Optional tgt = func_->GetAttr(tvm::attr::kTarget).value(); + ICHECK(tgt) << "The following PrimFunc does not have a target attr: \n" << func_; + Map annotations = Map(op->annotations); + if (op->annotations.find(kPoolCandidatesAllocateAttr) == op->annotations.end()) { + annotations.Set(kPoolCandidatesAllocateAttr, target_pool_infos_[tgt.value()]); + } + Stmt body = VisitStmt(op->body); + auto allocate = + Allocate(op->buffer_var, op->dtype, op->extents, op->condition, body, annotations); + return allocate; +} + +IRModule PoolInfoAssigner::operator()() { + for (const auto& kv : mod_->functions) { + GlobalVar gv = kv.first; + if (kv.second->IsInstance()) { + func_ = Downcast(kv.second); + Stmt body = this->VisitStmt(func_->body); + PrimFunc new_prim_func = + PrimFunc(func_->params, body, func_->ret_type, func_->buffer_map, func_->attrs); + mod_->Update(gv, new_prim_func); + } + } + return mod_; +} + +namespace transform { + +tvm::transform::Pass AssignPoolInfo() { + auto pass_func = [=](IRModule m, tvm::transform::PassContext ctx) { + return PoolInfoAssigner(m)(); + }; + return tvm::transform::CreateModulePass(pass_func, 0, "tir.usmp.AssignPoolInfo", {}); +} + +TVM_REGISTER_GLOBAL("tir.usmp.transform.AssignPoolInfo").set_body_typed(AssignPoolInfo); + +} // namespace transform + +} // namespace usmp +} // namespace tir +} // namespace tvm diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc index 5ebf3c557b06a..cd797681d4743 100644 --- a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc +++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc @@ -28,6 +28,7 @@ #include #include #include +#include #include #include @@ -120,12 +121,15 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator { /*! \brief This is a helper to append the pool args to * the callsite of the function. */ - Array AppendPoolParamsToArgs(const Array& args); + Array AppendPoolParamsToArgs(Array args, const PrimFunc& func); /*! \brief Some arguments that used to be Allocate nodes * should be replaced by Let nodes in the pass that loads * the space from a pool variable. */ Array ReplaceAllocateArgsWithLetArgs(const Array& args); + /*! \brief Obtain a resource handle if its there + */ + Optional GetResourceHandle(const PrimFunc& func); /*! \brief The tir::Var map to PoolInfo objects */ Map primfunc_args_to_pool_info_map_; @@ -151,10 +155,23 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator { std::unordered_set visited_primfuncs; }; +Optional PoolAllocationToOffsetConverter::GetResourceHandle(const PrimFunc& func) { + if (func->buffer_map.find(func->params.back()) == func->buffer_map.end()) { + return func->params.back(); + } + return Optional(); +} + PoolAllocationToOffsetConverter::ScopeInfo PoolAllocationToOffsetConverter::UpdateFunctionScopeInfo( const PrimFunc& original_func) { ScopeInfo si; + + Optional resource_handle = GetResourceHandle(original_func); si.params = original_func->params; + if (resource_handle) { + si.params.pop_back(); + ICHECK(si.params.size() == original_func->params.size() - 1); + } si.buffer_map = original_func->buffer_map; Map ret; for (const AllocatedPoolInfo& allocated_pool_info : allocated_pool_ordering_) { @@ -179,6 +196,9 @@ PoolAllocationToOffsetConverter::ScopeInfo PoolAllocationToOffsetConverter::Upda si.buffer_map.Set(pool_var, Buffer(buffer_var, elem_dtype, {pool_size}, {1}, 1, buffer_var_name, 16, 1, BufferType::kDefault)); } + if (resource_handle) { + si.params.push_back(resource_handle.value()); + } return si; } @@ -199,7 +219,7 @@ PrimFunc PoolAllocationToOffsetConverter::CreatePrimFuncWithPoolParams( PrimFunc ret = PrimFunc(si.params, new_body, original_primfunc->ret_type, si.buffer_map, original_attrs); if (!emit_tvmscript_printable_) { - return WithAttr(ret, tvm::attr::kPoolArgs, si.allocated_pool_params); + ret = WithAttr(ret, tvm::attr::kPoolArgs, si.allocated_pool_params); } visited_primfuncs.insert(ret); return ret; @@ -207,9 +227,14 @@ PrimFunc PoolAllocationToOffsetConverter::CreatePrimFuncWithPoolParams( return original_primfunc; } -Array PoolAllocationToOffsetConverter::AppendPoolParamsToArgs( - const Array& args) { +Array PoolAllocationToOffsetConverter::AppendPoolParamsToArgs(Array args, + const PrimFunc& func) { Array new_args; + PrimExpr resource_handle_arg; + if (args.size() == func->params.size() + 1) { + resource_handle_arg = args.back(); + args.pop_back(); + } for (const auto& arg : args) { new_args.push_back(VisitExpr(arg)); } @@ -219,6 +244,9 @@ Array PoolAllocationToOffsetConverter::AppendPoolParamsToArgs( Buffer buffer_var = top_scope.buffer_map[pool_var]; new_args.push_back(buffer_var->data); } + if (resource_handle_arg.defined()) { + new_args.push_back(resource_handle_arg); + } return new_args; } @@ -240,12 +268,13 @@ PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const CallNode* op) { if (op->op.same_as(builtin::call_extern()) || op->op.same_as(builtin::tvm_call_cpacked())) { String func_name = Downcast(op->args[0])->value; Array new_args; - if (module_->ContainGlobalVar(func_name)) { + if (module_->ContainGlobalVar(func_name) && + module_->Lookup(func_name)->IsInstance()) { GlobalVar gv = module_->GetGlobalVar(func_name); PrimFunc func = Downcast(module_->Lookup(gv)); PrimFunc prim_func = CreatePrimFuncWithPoolParams(func); module_->Update(gv, prim_func); - new_args = AppendPoolParamsToArgs(op->args); + new_args = AppendPoolParamsToArgs(op->args, prim_func); new_args = ReplaceAllocateArgsWithLetArgs(new_args); } else { new_args = ReplaceAllocateArgsWithLetArgs(op->args); @@ -255,8 +284,7 @@ PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const CallNode* op) { if (op->op->IsInstance()) { PrimFunc func = Downcast(op->op); PrimFunc prim_func = CreatePrimFuncWithPoolParams(func); - Array new_args = AppendPoolParamsToArgs(op->args); - new_args = AppendPoolParamsToArgs(new_args); + Array new_args = AppendPoolParamsToArgs(op->args, prim_func); new_args = ReplaceAllocateArgsWithLetArgs(new_args); return Call(op->dtype, prim_func, new_args); } @@ -329,8 +357,7 @@ IRModule PoolAllocationToOffsetConverter::operator()() { namespace transform { tvm::transform::Pass ConvertPoolAllocationsToOffsets( - const Map& pool_allocations, - Bool emit_tvmscript_printable = Bool(false)) { + const Map& pool_allocations, Bool emit_tvmscript_printable) { auto pass_func = [=](IRModule m, tvm::transform::PassContext ctx) { return Downcast(PoolAllocationToOffsetConverter( m, pool_allocations, emit_tvmscript_printable->value != 0)()); diff --git a/src/tir/usmp/unified_static_memory_planner.cc b/src/tir/usmp/unified_static_memory_planner.cc new file mode 100644 index 0000000000000..5a2125077566b --- /dev/null +++ b/src/tir/usmp/unified_static_memory_planner.cc @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/analysis/usmp/unified_static_memory_planner.cc + * \brief This is the pass that integrates the USMP passes to + * a single composite pass. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace tvm { + +TVM_REGISTER_PASS_CONFIG_OPTION(kUSMPEnableOption, Bool); +TVM_REGISTER_PASS_CONFIG_OPTION(kUSMPAlgorithmOption, String); + +namespace tir { +namespace usmp { + +static constexpr const char* kDefaultAlgo = "greedy_by_size"; + +static std::unordered_map( + const Array&, const Integer&)>> + algorithms{{"greedy_by_size", algo::GreedyBySize}, + {"greedy_by_conflicts", algo::GreedyByConflicts}}; + +IRModule PlanMemory(const IRModule& mod, String algo) { + VLOG(1) << "workspace required = " << CalculateModuleWorkspaceSize(mod); + PrimFunc main_func = Downcast(mod->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix)); + BufferInfoAnalysis buffer_info_analysis = ExtractBufferInfo(main_func, mod); + Array buffer_info_arr = + CreateArrayBufferInfo(buffer_info_analysis->buffer_info_stmts); + CHECK(algorithms.count(algo)) << "The selected USMP algorithm : " << algo + << "is not defined. Please define it in the above algorithms map."; + Map buffer_info_pool_allocations = + algorithms[algo](buffer_info_arr, buffer_info_analysis->memory_pressure); + Map stmt_pool_allocations = AssignStmtPoolAllocations( + buffer_info_analysis->buffer_info_stmts, buffer_info_pool_allocations); + IRModule ret = transform::ConvertPoolAllocationsToOffsets(stmt_pool_allocations)(mod); + tir::PrimFunc tir_main_func = + Downcast(ret->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix)); + Optional> allocated_pool_infos = + tir_main_func->GetAttr>(tvm::attr::kPoolArgs); + if (allocated_pool_infos) { + for (const tir::usmp::AllocatedPoolInfo& allocated_pool_info : allocated_pool_infos.value()) { + VLOG(1) << "pool_size = " << allocated_pool_info->allocated_size; + } + } + return ret; +} + +} // namespace usmp + +namespace transform { + +tvm::transform::Pass UnifiedStaticMemoryPlanner() { + auto usmp_main_pass_func = [=](IRModule m, tvm::transform::PassContext ctx) { + auto algorithm_str = ctx->GetConfig(kUSMPAlgorithmOption, String(usmp::kDefaultAlgo)); + return Downcast( + usmp::PlanMemory(m, algorithm_str.value_or(String(usmp::kDefaultAlgo)))); + }; + + return tvm::transform::Sequential( + {tvm::tir::usmp::transform::AssignPoolInfo(), + tvm::transform::CreateModulePass(usmp_main_pass_func, 0, + "tir.transform.UnifiedStaticMemoryPlanner", {})}); +} + +TVM_REGISTER_GLOBAL("tir.transform.UnifiedStaticMemoryPlanner") + .set_body_typed(UnifiedStaticMemoryPlanner); + +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/src/tir/usmp/utils.cc b/src/tir/usmp/utils.cc index 14b3d26641a3e..1fff70f5892e9 100644 --- a/src/tir/usmp/utils.cc +++ b/src/tir/usmp/utils.cc @@ -24,7 +24,11 @@ #include #include +#include +#include +#include #include +#include #include namespace tvm { @@ -88,11 +92,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ",\n memory_pressure=" << node->memory_pressure << ")"; }); -PoolInfo::PoolInfo(String pool_name, Map target_access, Integer size_hint_bytes) { +PoolInfo::PoolInfo(String pool_name, Map target_access, Integer size_hint_bytes, + Bool is_internal) { auto poolinfo_node = make_object(); poolinfo_node->pool_name = pool_name; poolinfo_node->size_hint_bytes = size_hint_bytes; poolinfo_node->target_access = target_access; + poolinfo_node->is_internal = is_internal; data_ = std::move(poolinfo_node); } @@ -195,6 +201,66 @@ Integer CalculateExtentsSize(const AllocateNode* op) { return Integer(num_elements * element_size_bytes); } +class ModuleWorkspaceSizeCalculator : public StmtExprVisitor { + public: + explicit ModuleWorkspaceSizeCalculator(const IRModule& module) : mod_(module) { + for (const auto& gv_func : mod_->functions) { + functions_.Set(gv_func.first->name_hint, Downcast(gv_func.second)); + } + main_func_ = Downcast(module->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix)); + ICHECK(main_func_.defined()) << "main function is not in the module"; + Optional target_host = main_func_->GetAttr(tvm::attr::kTarget); + ICHECK(target_host) << "main function does not have a target attr"; + target_host_ = target_host.value(); + } + + Integer operator()() { + UpdateWorkspaceData(main_func_); + return Integer(max_workspace_size); + } + + private: + void UpdateWorkspaceData(const PrimFunc& func) { + Target tgt = func->GetAttr(tvm::attr::kTarget).value_or(target_host_); + Integer workspace_byte_alignment = + tgt->GetAttr("workspace-byte-alignment").value_or(16); + Integer workspace_req = CalculateWorkspaceBytes(func, workspace_byte_alignment); + if (workspace_req) { + current_workspace_size_ += workspace_req->value; + } + if (max_workspace_size < current_workspace_size_) { + max_workspace_size = current_workspace_size_; + } + this->VisitStmt(func->body); + if (workspace_req) { + current_workspace_size_ -= workspace_req->value; + } + } + + void VisitExpr_(const CallNode* op) override { + if (op->op.same_as(builtin::call_extern())) { + PrimFunc func = functions_.at(Downcast(op->args[0])->value); + UpdateWorkspaceData(func); + } else if (op->op->IsInstance()) { + PrimFunc func = Downcast(op->op); + UpdateWorkspaceData(func); + } else { + StmtExprVisitor::VisitExpr_(op); + } + } + + IRModule mod_; + Target target_host_; + PrimFunc main_func_; + Map functions_; + size_t current_workspace_size_ = 0; + size_t max_workspace_size = 0; +}; + +Integer CalculateModuleWorkspaceSize(const IRModule& mod) { + return ModuleWorkspaceSizeCalculator(mod)(); +} + TVM_REGISTER_GLOBAL("tir.usmp.CreateArrayBufferInfo") .set_body_typed([](Map buffer_info_map) { return (CreateArrayBufferInfo(buffer_info_map)); diff --git a/tests/cpp/build_module_test.cc b/tests/cpp/build_module_test.cc index d5a4c91a3c431..ff3641cd6982b 100644 --- a/tests/cpp/build_module_test.cc +++ b/tests/cpp/build_module_test.cc @@ -107,18 +107,22 @@ TEST(BuildModule, Heterogeneous) { auto elemwise_sub = compute( C->shape, [©, &C](PrimExpr i) { return copy[i] - C[i]; }, "elemwise_sub"); - With cuda_scope(target_cuda); - auto s1 = topi::cuda::schedule_injective(target_cuda, {elemwise_add}); + auto fcreate_s1 = [=]() { + With cuda_scope(target_cuda); + return topi::cuda::schedule_injective(target_cuda, {elemwise_add}); + }; - With llvm_scope(target_llvm); - auto s2 = create_schedule({elemwise_sub->op}); + auto fcreate_s2 = [=]() { + With llvm_scope(target_llvm); + return create_schedule({elemwise_sub->op}); + }; auto args1 = Array({A, B, elemwise_add}); auto args2 = Array({copy, C, elemwise_sub}); std::unordered_map binds; - auto lowered_s1 = LowerSchedule(s1, args1, "elemwise_add", binds); - auto lowered_s2 = LowerSchedule(s2, args2, "elemwise_sub", binds); + auto lowered_s1 = LowerSchedule(fcreate_s1(), args1, "elemwise_add", binds); + auto lowered_s2 = LowerSchedule(fcreate_s2(), args2, "elemwise_sub", binds); Map inputs = {{target_cuda, lowered_s1}, {target_llvm, lowered_s2}}; auto module = build(inputs, Target()); diff --git a/tests/cpp/runtime/hexagon_buffer.cc b/tests/cpp/runtime/hexagon_buffer.cc index 9a07056f46f1a..6b777bbd93021 100644 --- a/tests/cpp/runtime/hexagon_buffer.cc +++ b/tests/cpp/runtime/hexagon_buffer.cc @@ -259,3 +259,17 @@ TEST(HexagonBuffer, external) { Optional invalid("invalid"); EXPECT_THROW(HexagonBuffer hb_vtcm(data.data(), data.size(), invalid), InternalError); } + +TEST(HexagonBuffer, external_copy) { + std::vector data1{0, 1, 2, 3, 4, 5, 6, 7}; + Optional global("global"); + HexagonBuffer hb_ext(data1.data(), data1.size(), global); + + std::vector data2{0, 1, 2, 3, 4, 5, 6, 7}; + EXPECT_THROW(hb_ext.CopyTo(data2.data(), data2.size()), InternalError); + EXPECT_THROW(hb_ext.CopyFrom(data2.data(), data2.size()), InternalError); + + HexagonBuffer hb(8 /* nbytes */, 8 /* alignment */, global); + EXPECT_THROW(hb.CopyFrom(hb_ext, 8), InternalError); + EXPECT_THROW(hb_ext.CopyFrom(hb, 8), InternalError); +} diff --git a/tests/python/contrib/test_cmsisnn/test_conv2d.py b/tests/python/contrib/test_cmsisnn/test_conv2d.py index d8c559cec6e05..e9eb6fb3a1453 100644 --- a/tests/python/contrib/test_cmsisnn/test_conv2d.py +++ b/tests/python/contrib/test_cmsisnn/test_conv2d.py @@ -124,22 +124,118 @@ def make_model( @tvm.testing.requires_cmsisnn -@pytest.mark.parametrize("ifm_shape", [(1, 25, 25, 12), (1, 64, 100, 4)]) -@pytest.mark.parametrize("kernel_size", [(5, 5)]) @pytest.mark.parametrize("padding", ["SAME", "VALID"]) -@pytest.mark.parametrize("strides, dilation", [((2, 2), (1, 1))]) @pytest.mark.parametrize("relu_type", ["RELU"]) @pytest.mark.parametrize("enable_bias", [True, False]) @pytest.mark.parametrize( "input_zero_point, input_scale, kernel_scale, out_channels", [(10, 0.0128, [0.11, 0.22], 2), (-64, 1, [1, 0.0256, 1.37], 3)], ) -def test_conv2d_int8( - ifm_shape, - kernel_size, +def test_conv2d_symmetric_padding_int8( + padding, + enable_bias, + relu_type, + input_zero_point, + input_scale, + kernel_scale, + out_channels, +): + interface_api = "c" + use_unpacked_api = True + test_runner = AOT_CORSTONE300_RUNNER + + ifm_shape = (1, 64, 100, 4) + kernel_size = (3, 3) + strides = (1, 1) + dilation = (1, 1) + dtype = "int8" + groups = 1 + weight_format = "HWIO" + kernel_h = kernel_size[0] + kernel_w = kernel_size[1] + kernel_shape = (kernel_h, kernel_w, ifm_shape[3] // groups, out_channels) + kernel_zero_point = 0 + in_min, in_max = get_range_for_dtype_str(dtype) + + output_scale, output_zero_point = get_conv2d_qnn_params( + kernel_shape, + input_scale, + input_zero_point, + kernel_scale, + kernel_zero_point, + dtype, + dtype, + dtype, + ) + + model, params = make_model( + ifm_shape, + kernel_shape, + input_zero_point, + input_scale, + kernel_zero_point, + kernel_scale, + output_zero_point, + output_scale, + padding, + strides, + dilation, + groups, + dtype, + dtype, + out_channels, + weight_format, + enable_bias, + relu_type, + ) + orig_mod = make_module(model) + cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params) + + # validate pattern matching + attrs = [ + cmsisnn_mod[var.name_hint].attrs + for var in cmsisnn_mod.get_global_vars() + if cmsisnn_mod[var.name_hint].attrs + ] + assert any(attrs), "At least one function with external attributes was expected." + + compilers = [ + key == "Compiler" and value == "cmsis-nn" for attr in attrs for key, value in attr.items() + ] + assert any(compilers), "Module does not contain function for cmsis-nn target." + + assert count_num_calls(orig_mod) == count_num_calls( + cmsisnn_mod + ), "Number of calls changed during partitioning" + + # validate the output + rng = np.random.default_rng(12345) + inputs = {"input": rng.integers(in_min, high=in_max, size=ifm_shape, dtype=dtype)} + output_list = generate_ref_data(orig_mod["main"], inputs, params) + compile_and_run( + AOTTestModel( + module=cmsisnn_mod, + inputs=inputs, + outputs=output_list, + params=params, + output_tolerance=1, + ), + test_runner, + interface_api, + use_unpacked_api, + ) + + +@tvm.testing.requires_cmsisnn +@pytest.mark.parametrize("padding", ["SAME", "VALID"]) +@pytest.mark.parametrize("relu_type", ["RELU", "NONE"]) +@pytest.mark.parametrize("enable_bias", [True, False]) +@pytest.mark.parametrize( + "input_zero_point, input_scale, kernel_scale, out_channels", + [(10, 0.0128, [0.11, 0.22], 2), (-64, 1, [1, 0.0256, 1.37], 3)], +) +def test_conv2d_asymmetric_padding_int8( padding, - strides, - dilation, enable_bias, relu_type, input_zero_point, @@ -151,6 +247,10 @@ def test_conv2d_int8( use_unpacked_api = True test_runner = AOT_CORSTONE300_RUNNER + ifm_shape = (1, 25, 25, 12) + kernel_size = (5, 5) + strides = (2, 2) + dilation = (1, 1) dtype = "int8" groups = 1 weight_format = "HWIO" diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py index 7f504fcc1ed7b..bc2cc80f362d4 100644 --- a/tests/python/contrib/test_cudnn.py +++ b/tests/python/contrib/test_cudnn.py @@ -29,7 +29,7 @@ requires_cudnn = pytest.mark.skipif( - tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape_from_cudnn", True) is None, + tvm.get_global_func("tvm.contrib.cudnn.conv2d.forward", True) is None, reason="CuDNN is not enabled", ) @@ -307,13 +307,5 @@ def conv_output_shape_kwargs(request): return request.param -@tvm.testing.requires_gpu -@requires_cudnn -def test_conv_output_shape(conv_output_shape_kwargs): - shape_from_cudnn = cudnn._conv_output_shape_from_cudnn(**conv_output_shape_kwargs) - shape_from_python = cudnn.conv_output_shape(**conv_output_shape_kwargs) - assert shape_from_cudnn == shape_from_python - - if __name__ == "__main__": sys.exit(pytest.main(sys.argv)) diff --git a/tests/python/contrib/test_ethosu/cascader/conftest.py b/tests/python/contrib/test_ethosu/cascader/conftest.py index 58ffb51a59675..eacf57c251a84 100644 --- a/tests/python/contrib/test_ethosu/cascader/conftest.py +++ b/tests/python/contrib/test_ethosu/cascader/conftest.py @@ -29,7 +29,11 @@ from tvm.relay.testing import run_opt_pass from .infra import create_te_graph - from ..infra import make_ethosu_conv2d + from ..infra import ( + make_ethosu_conv2d, + make_ethosu_depthwise_conv2d, + make_ethosu_binary_elementwise, + ) def make_TwoConv2DWithSliceTE(): def _get_func(): @@ -71,3 +75,62 @@ def _get_func(): @pytest.fixture def TwoConv2DWithSliceTE(): return make_TwoConv2DWithSliceTE() + + def make_MobileNetv2DiamondTE(): + def _get_func(): + ifm = relay.var("ifm", shape=(1, 56, 56, 96), dtype="int8") + conv1 = make_ethosu_conv2d( + ifm=ifm, + ifm_channels=96, + ofm_channels=24, + kernel_shape=(1, 1), + padding=(0, 0, 0, 0), + strides=(1, 1), + dilation=(1, 1), + ) + conv2 = make_ethosu_conv2d( + ifm=conv1, + ifm_channels=24, + ofm_channels=144, + kernel_shape=(1, 1), + padding=(0, 0, 0, 0), + strides=(1, 1), + dilation=(1, 1), + ) + depth1 = make_ethosu_depthwise_conv2d( + ifm=conv2, + channels=144, + kernel_shape=(3, 3), + padding=(1, 1, 1, 1), + strides=(1, 1), + dilation=(1, 1), + ) + conv3 = make_ethosu_conv2d( + ifm=depth1, + ifm_channels=144, + ofm_channels=24, + kernel_shape=(1, 1), + padding=(0, 0, 0, 0), + strides=(1, 1), + dilation=(1, 1), + ) + add1 = make_ethosu_binary_elementwise( + ifm=conv1, + ifm2=conv3, + ifm_channels=24, + ifm2_channels=24, + operator_type="ADD", + ofm_dtype="int8", + ) + func = relay.Function(relay.analysis.free_vars(add1), add1) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + func = _get_func() + te_graph, const_dict = create_te_graph(func) + sch = tvm.te.create_schedule([t.op for t in te_graph.outputs]) + return sch, te_graph, const_dict + + @pytest.fixture + def MobileNetv2DiamondTE(): + return make_MobileNetv2DiamondTE() diff --git a/tests/python/contrib/test_ethosu/cascader/infra.py b/tests/python/contrib/test_ethosu/cascader/infra.py index baf398dc36022..5f41dce30147f 100644 --- a/tests/python/contrib/test_ethosu/cascader/infra.py +++ b/tests/python/contrib/test_ethosu/cascader/infra.py @@ -18,6 +18,8 @@ from tvm import relay from tvm.relay.backend.contrib.ethosu.tir.compiler import extract_constants, lower_to_te +import numpy as np + def create_te_graph(func): func, consts = extract_constants(func) @@ -25,3 +27,100 @@ def create_te_graph(func): func = relay.transform.InferType()(mod)["main"] te_graph = lower_to_te(func) return te_graph, consts + + +def make_matrices( + op_type, kernel, stride, padding, ifm_layout, ofm_layout, dilation=(1, 1), ifm_channels=1 +): + kernel_h, kernel_w = kernel + stride_h, stride_w = stride + dilation_h, dilation_w = dilation + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + nhwc_to_nhcwb16 = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1 / 16, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 16], + [0, 0, 0, 0, 1], + ] + nhcwb16_to_nhwc = [ + [1, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 16, 0, 1, -16], + [0, 0, 0, 0, 0, 1], + ] + if op_type == "ethosu_conv2d": + ifm_matrix = [ + [1, 0, 0, 0, 0], + [0, stride_h, 0, 0, (dilated_kernel_h - stride_h)], + [0, 0, stride_w, 0, (dilated_kernel_w - stride_w)], + [0, 0, 0, 0, ifm_channels], + [0, 0, 0, 0, 1], + ] + weight_matrix = [ + [0, 0, 0, 1, 0], + [0, 0, 0, 0, kernel_h], + [0, 0, 0, 0, kernel_w], + [0, 0, 0, 0, ifm_channels], + [0, 0, 0, 0, 1], + ] + elif op_type == "ethosu_depthwise_conv2d": + ifm_matrix = [ + [1, 0, 0, 0, 0], + [0, stride_h, 0, 0, (dilated_kernel_h - stride_h)], + [0, 0, stride_w, 0, (dilated_kernel_w - stride_w)], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + ] + weight_matrix = [ + [0, 0, 0, 1, 0], + [0, 0, 0, 0, kernel_h], + [0, 0, 0, 0, kernel_w], + [0, 0, 0, 0, 1], + [0, 0, 0, 0, 1], + ] + elif op_type == "ethosu_pooling": + ifm_matrix = [ + [1, 0, 0, 0, 0], + [0, stride_h, 0, 0, (dilated_kernel_h - stride_h)], + [0, 0, stride_w, 0, (dilated_kernel_w - stride_w)], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + ] + weight_matrix = [ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + ] + scale_bias_matrix = [ + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 10], + [0, 0, 0, 0, 1], + ] + if ofm_layout == "NHCWB16": + ifm_matrix = np.matmul(ifm_matrix, nhcwb16_to_nhwc).tolist() + weight_matrix = np.matmul(weight_matrix, nhcwb16_to_nhwc).tolist() + scale_bias_matrix = np.matmul(scale_bias_matrix, nhcwb16_to_nhwc).tolist() + if ifm_layout == "NHCWB16": + ifm_matrix = np.matmul(nhwc_to_nhcwb16, ifm_matrix).tolist() + + ifm_offset = ( + [0, -padding[0], -padding[1], 0] + if ifm_layout == "NHWC" + else [0, -padding[0], 0, -padding[1], 0] + ) + weight_offset = [0, 0, 0, 0] + scale_bias_offset = [0, 0] + return ( + ifm_matrix, + ifm_offset, + weight_matrix, + weight_offset, + scale_bias_matrix, + scale_bias_offset, + ) diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_binary_elementwise_matcher.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_binary_elementwise_matcher.py new file mode 100644 index 0000000000000..bb1be7b8e251d --- /dev/null +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_binary_elementwise_matcher.py @@ -0,0 +1,209 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") + +import numpy as np +import math + +from tvm import te +import tvm.contrib.ethosu.cascader as cs +from tvm.relay.backend.contrib.ethosu.te.binary_elementwise import ( + match_ethosu_binary_elementwise, + binary_elementwise_compute, +) + + +def _make_matrices(broadcast, ifm_layout, ifm2_layout, ofm_layout): + broadcast_h, broadcast_w, broadcast_c = broadcast + nhwc_to_nhcwb16 = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1 / 16, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 16], + [0, 0, 0, 0, 1], + ] + nhcwb16_to_nhwc = [ + [1, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 16, 0, 1, -16], + [0, 0, 0, 0, 0, 1], + ] + ifm_matrix = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + ] + ifm2_matrix = [ + [1, 0, 0, 0, 0], + [0, (1 - broadcast_h), 0, 0, broadcast_h], + [0, 0, (1 - broadcast_w), 0, broadcast_w], + [0, 0, 0, (1 - broadcast_c), broadcast_c], + [0, 0, 0, 0, 1], + ] + if ofm_layout == "NHCWB16": + ifm_matrix = np.matmul(ifm_matrix, nhcwb16_to_nhwc).tolist() + ifm2_matrix = np.matmul(ifm2_matrix, nhcwb16_to_nhwc).tolist() + if ifm_layout == "NHCWB16": + ifm_matrix = np.matmul(nhwc_to_nhcwb16, ifm_matrix).tolist() + if ifm2_layout == "NHCWB16": + ifm2_matrix = np.matmul(nhwc_to_nhcwb16, ifm2_matrix).tolist() + + return (ifm_matrix, ifm2_matrix) + + +@pytest.mark.parametrize( + "ofm_shape", + [ + [1, 12, 15, 128], + [1, 16, 16, 16], + [1, 1, 1, 1024], + [1, 73, 51, 20], + [1, 124, 172, 5], + ], +) +@pytest.mark.parametrize("ifm2_broadcast", [[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]]) +@pytest.mark.parametrize("ifm_layout", ["NHWC", "NHCWB16"]) +@pytest.mark.parametrize("ifm2_layout", ["NHWC", "NHCWB16"]) +@pytest.mark.parametrize("ofm_layout", ["NHWC", "NHCWB16"]) +@pytest.mark.parametrize("op_type", ["MUL", "ADD", "MIN"]) +def test_ethosu_binary_elementwise_matcher( + ofm_shape, ifm2_broadcast, ifm_layout, ifm2_layout, ofm_layout, op_type +): + ifm_shape = ofm_shape.copy() + ifm2_shape = [1] + [1 if (b == 1) else a for a, b in zip(ofm_shape[1:], ifm2_broadcast)] + ifm_channels = ifm_shape[3] + ifm2_channels = ifm2_shape[3] + nhwc_to_nhcwb16 = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1 / 16, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 16], + [0, 0, 0, 0, 1], + ] + broadcast = [1 if a == 1 else 0 for a in ifm2_shape[1:]] + if ifm_layout == "NHCWB16": + ifm_shape = [ + int(math.ceil(n)) + for n in np.matmul( + nhwc_to_nhcwb16, + ifm_shape + + [ + 1, + ], + ).tolist()[:-1] + ] + if ifm2_layout == "NHCWB16": + ifm2_shape = [ + int(math.ceil(n)) + for n in np.matmul( + nhwc_to_nhcwb16, + ifm2_shape + + [ + 1, + ], + ).tolist()[:-1] + ] + if ofm_layout == "NHCWB16": + ofm_shape = [ + int(math.ceil(n)) + for n in np.matmul( + nhwc_to_nhcwb16, + ofm_shape + + [ + 1, + ], + ).tolist()[:-1] + ] + order = [1, 2, 4, 3, 0] + else: + order = [1, 2, 3, 4] + + ifm = te.placeholder(ifm_shape, dtype="int8") + ifm2 = te.placeholder(ifm2_shape, dtype="int8") + lut = te.placeholder((), dtype="uint8") + out = binary_elementwise_compute( + ifm=ifm, + ifm2=ifm2, + lut=lut, + operator_type=op_type, + ifm_scale=1, + ifm_zero_point=0, + ifm2_scale=1, + ifm2_zero_point=0, + ofm_scale=1, + ofm_zero_point=0, + ifm_channels=ifm_channels, + ifm2_channels=ifm2_channels, + reversed_operands=False, + activation="NONE", + clip_min=0, + clip_max=0, + rounding_mode="TFL", + ifm_layout=ifm_layout, + ifm2_layout=ifm2_layout, + ofm_layout=ofm_layout, + ofm_dtype="int8", + ) + ifm_propagator = out.op.attrs["ifm_propagator"] + ifm2_propagator = out.op.attrs["ifm2_propagator"] + + offset = [0] * len(ofm_shape) + stripes = [0] * len(ofm_shape) + output_stripe_config = cs.StripeConfig(ofm_shape, ofm_shape, ofm_shape, order, stripes, offset) + + (ifm_transform, ifm2_transform) = _make_matrices( + broadcast, + ifm_layout, + ifm2_layout, + ofm_layout, + ) + + device_config = cs.EthosuDeviceConfig("ethos-u55-256") + part = match_ethosu_binary_elementwise(out, device_config) + + assert isinstance(part, cs.EthosuPart) + assert len(part.propagators) == 2 + assert part.propagators[0].transform == ifm_transform + assert part.propagators[1].transform == ifm2_transform + + propagated_ifm = ifm_propagator.propagate(output_stripe_config).shape + propagated_ifm2 = ifm2_propagator.propagate(output_stripe_config).shape + + # Layout conversions will align the propagated IFMs to the brick, i.e. 16 + # so the expected ifm(2)_shape needs to be rounded up to 16 + if ifm_layout != ofm_layout: + assert ifm_shape[:-1] == propagated_ifm[:-1] + assert ((ifm_shape[-1] + 16 - 1) // 16) * 16 == propagated_ifm[-1] + else: + assert ifm_shape == propagated_ifm + + if ifm2_layout != ofm_layout: + assert ifm2_shape[:-1] == propagated_ifm2[:-1] + assert ((ifm2_shape[-1] + 16 - 1) // 16) * 16 == propagated_ifm2[-1] + else: + assert ifm2_shape == propagated_ifm2 + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_block_config.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_block_config.py new file mode 100644 index 0000000000000..3f3935fff1f9a --- /dev/null +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_block_config.py @@ -0,0 +1,335 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") + +import numpy as np +import math + +import tvm.contrib.ethosu.cascader as cs + +from .infra import make_matrices + + +@pytest.mark.parametrize( + "test_id, op_type, activation, kernel, stride, dilation, padding, in_shape, out_shape", + [ + # Conv2D + ( + 0, + "ethosu_conv2d", + "NONE", + (34, 19), + (2, 2), + (1, 1), + (0, 0, 0, 0), + (1, 266, 111, 15), + (1, 117, 47, 15), + ), + ( + 1, + "ethosu_conv2d", + "NONE", + (14, 14), + (1, 1), + (1, 1), + (0, 0, 0, 0), + (1, 125, 63, 64), + (1, 112, 50, 128), + ), + ( + 2, + "ethosu_conv2d", + "NONE", + (7, 1), + (2, 1), + (1, 1), + (0, 0, 0, 0), + (1, 13, 4, 12), + (1, 4, 4, 511), + ), + ( + 3, + "ethosu_conv2d", + "NONE", + (5, 5), + (1, 1), + (1, 1), + (0, 0, 0, 0), + (1, 96, 16, 276), + (1, 92, 12, 16), + ), + ( + 4, + "ethosu_conv2d", + "NONE", + (5, 5), + (1, 1), + (1, 1), + (0, 0, 0, 0), + (1, 96, 16, 276), + (1, 92, 12, 1), + ), + ( + 5, + "ethosu_conv2d", + "NONE", + (3, 3), + (1, 1), + (2, 2), + (0, 0, 0, 0), + (1, 62, 94, 32), + (1, 58, 90, 16), + ), + # Depthwise Conv2D + ( + 6, + "ethosu_depthwise_conv2d", + "NONE", + (3, 5), + (1, 1), + (1, 1), + (0, 0, 0, 0), + (1, 77, 23, 18), + (1, 75, 19, 18), + ), + ( + 7, + "ethosu_depthwise_conv2d", + "NONE", + (3, 3), + (2, 2), + (1, 1), + (1, 1, 1, 1), + (1, 25, 10, 276), + (1, 13, 5, 276), + ), + # Pooling + ( + 8, + "ethosu_pooling", + "NONE", + (13, 5), + (1, 1), + (1, 1), + (0, 0, 0, 0), + (1, 13, 5, 276), + (1, 1, 1, 276), + ), + ( + 9, + "ethosu_pooling", + "NONE", + (7, 3), + (2, 1), + (1, 1), + (0, 0, 0, 0), + (1, 317, 14, 21), + (1, 156, 12, 21), + ), + ], +) +@pytest.mark.parametrize( + "layouts", + [ + ("NHWC", "NHWC"), + ("NHCWB16", "NHCWB16"), + ("NHWC", "NHCWB16"), + ("NHCWB16", "NHWC"), + ], +) +@pytest.mark.parametrize( + "acc_config, expected_block_configs", + [ + ( + "ethos-u55-32", + [ + # Conv2D + ((1, 8, 4, 16), (1, 8, 1, 4, 16)), + ((1, 6, 5, 16), (1, 6, 1, 5, 16)), + ((1, 4, 4, 16), (1, 4, 1, 4, 16)), + ((1, 8, 4, 16), (1, 8, 1, 4, 16)), + ((1, 10, 6, 4), (1, 5, 1, 12, 4), (1, 16, 1, 4, 4)), + ((1, 6, 5, 16), (1, 6, 1, 5, 16)), + # Depthwise Conv2D + ((1, 6, 10, 16), (1, 6, 1, 10, 16)), + ((1, 7, 5, 16), (1, 7, 1, 5, 16)), + # Pooling + ((1, 1, 1, 16), (1, 1, 1, 1, 16)), + ((1, 9, 6, 16), (1, 9, 1, 6, 16)), + ], + ), + ( + "ethos-u55-64", + [ + # Conv2D + ((1, 8, 4, 16), (1, 8, 1, 4, 16)), + ((1, 6, 5, 16), (1, 6, 1, 5, 16)), + ((1, 4, 4, 16), (1, 4, 1, 4, 16)), + ((1, 8, 4, 16), (1, 8, 1, 4, 16)), + ((1, 10, 6, 8), (1, 16, 1, 4, 8)), + ((1, 6, 5, 16), (1, 6, 1, 5, 16)), + # Depthwise Conv2D + ((1, 6, 10, 16), (1, 6, 1, 10, 16)), + ((1, 7, 5, 16), (1, 7, 1, 5, 16)), + # Pooling + ((1, 1, 1, 16), (1, 1, 1, 1, 16)), + ((1, 9, 6, 16), (1, 9, 1, 6, 16)), + ], + ), + ( + "ethos-u55-128", + [ + # Conv2D + ((1, 7, 6, 16), (1, 7, 1, 6, 16)), + ((1, 5, 8, 16), (1, 5, 1, 8, 16)), + ((1, 4, 4, 16), (1, 4, 1, 4, 16)), + ((1, 16, 4, 16), (1, 16, 1, 4, 16)), + ((1, 8, 12, 8), (1, 8, 1, 12, 8)), + ((1, 10, 6, 16), (1, 10, 1, 6, 16)), + # Depthwise Conv2D + ((1, 7, 10, 16), (1, 7, 1, 10, 16)), + ((1, 7, 6, 16), (1, 7, 1, 6, 16)), + # Pooling + ((1, 1, 2, 80), (1, 1, 5, 2, 16)), + ((1, 10, 6, 16), (1, 10, 1, 6, 16)), + ], + ), + ( + "ethos-u55-256", + [ + # Conv2D + ((1, 14, 8, 16), (1, 14, 1, 8, 16)), + ((1, 16, 8, 16), (1, 16, 1, 8, 16)), + ((1, 4, 4, 16), (1, 4, 1, 4, 16)), + ((1, 32, 4, 16), (1, 10, 12, 16), (1, 32, 1, 4, 16), (1, 10, 1, 12, 16)), + ((1, 20, 12, 8), (1, 20, 1, 12, 8)), + ((1, 12, 10, 16), (1, 12, 1, 10, 16)), + # Depthwise Conv2D + ((1, 8, 20, 16), (1, 8, 1, 20, 16)), + ((1, 14, 6, 16), (1, 14, 1, 6, 16)), + # Pooling + ((1, 2, 2, 48), (1, 2, 3, 2, 16)), + ((1, 10, 12, 16), (1, 10, 1, 12, 16)), + ], + ), + ], +) +def test_best_block_config( + test_id, + op_type, + activation, + kernel, + stride, + dilation, + padding, + in_shape, + out_shape, + layouts, + acc_config, + expected_block_configs, +): + nhwc_to_nhcwb16 = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1 / 16, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 16], + [0, 0, 0, 0, 1], + ] + nhcwb16_to_nhwc = [ + [1, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 16, 0, 1, -16], + [0, 0, 0, 0, 0, 1], + ] + ifm_matrix, ifm_offset, weight_matrix, weight_offset, _, _ = make_matrices( + op_type, kernel, stride, padding, layouts[0], layouts[1], dilation, in_shape[3] + ) + + ofm_channels = out_shape[3] + ifm_channels = in_shape[3] + + if layouts[0] == "NHCWB16": + in_shape = [ + int(math.ceil(n)) for n in np.matmul(nhwc_to_nhcwb16, in_shape + (1,)).tolist()[:-1] + ] + if layouts[1] == "NHCWB16": + out_shape = [ + int(math.ceil(n)) for n in np.matmul(nhwc_to_nhcwb16, out_shape + (1,)).tolist()[:-1] + ] + + propagator = cs.Propagator(ifm_matrix, ifm_offset) + weight_propagator = cs.Propagator(weight_matrix, weight_offset) + + subkernels = ((kernel[0] + 7) // 8) * ((kernel[1] + 7) // 8) + + op_attrs = { + "op": op_type, + "activation": activation, + "stride_h": stride[0], + "stride_w": stride[1], + "dilation_h": dilation[0], + "dilation_w": dilation[1], + } + + device_config = cs.EthosuDeviceConfig(acc_config) + block_configs = device_config.get_valid_block_configs( + propagator, + op_attrs, + out_shape, + ofm_channels, + ifm_channels, + layouts[1], + layouts[0], + "int8", + "int8", + kernel[0], + kernel[1], + ) + + output_quantum = [1, 1, 2, 8] + if layouts[1] == "NHCWB16": + output_quantum = [1, 1, 1, 2, 8] + + # Create EthosUPart + te_subgraph = cs.TESubgraph([], None) + part = cs.EthosuPart( + te_subgraph, + [propagator, weight_propagator], + output_quantum, + subkernels, + block_configs, + 1, + ) + + order = [1, 2, 3, 4] if layouts[1] == "NHCWB16" else [1, 2, 4, 3, 0] + stripes = [1] * len(output_quantum) + offset = [0] * len(output_quantum) + + stripe_config = cs.StripeConfig(out_shape, out_shape, out_shape, order, stripes, offset) + + block = part.get_block_config(stripe_config) + block_shape = tuple(int(a) for a in block.output_shape) + + assert block_shape in expected_block_configs[test_id] + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_conv2d_matcher.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_conv2d_matcher.py index 79a139594b3ec..5bd2be49f6204 100644 --- a/tests/python/contrib/test_ethosu/cascader/test_ethosu_conv2d_matcher.py +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_conv2d_matcher.py @@ -22,71 +22,7 @@ import tvm.contrib.ethosu.cascader as cs from tvm.relay.backend.contrib.ethosu.te.convolution import match_ethosu_conv2d, conv2d_compute -import numpy as np - - -def _make_matrices(kernel, stride, dilation, padding, ifm_channels, ifm_layout, ofm_layout): - kernel_h, kernel_w = kernel - stride_h, stride_w = stride - dilation_h, dilation_w = dilation - dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 - dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 - nhwc_to_nhcwb16 = [ - [1, 0, 0, 0, 0], - [0, 1, 0, 0, 0], - [0, 0, 0, 1 / 16, 0], - [0, 0, 1, 0, 0], - [0, 0, 0, 0, 16], - [0, 0, 0, 0, 1], - ] - nhcwb16_to_nhwc = [ - [1, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0], - [0, 0, 0, 1, 0, 0], - [0, 0, 16, 0, 0, 0], - [0, 0, 0, 0, 0, 1], - ] - ifm_matrix = [ - [1, 0, 0, 0, 0], - [0, stride_h, 0, 0, (dilated_kernel_h - stride_h)], - [0, 0, stride_w, 0, (dilated_kernel_w - stride_w)], - [0, 0, 0, 0, ifm_channels], - [0, 0, 0, 0, 1], - ] - weight_matrix = [ - [0, 0, 0, 1, 0], - [0, 0, 0, 0, kernel_h], - [0, 0, 0, 0, kernel_w], - [0, 0, 0, 0, ifm_channels], - [0, 0, 0, 0, 1], - ] - scale_bias_matrix = [ - [0, 0, 0, 1, 0], - [0, 0, 0, 0, 10], - [0, 0, 0, 0, 1], - ] - if ofm_layout == "NHCWB16": - ifm_matrix = np.matmul(ifm_matrix, nhcwb16_to_nhwc).tolist() - weight_matrix = np.matmul(weight_matrix, nhcwb16_to_nhwc).tolist() - scale_bias_matrix = np.matmul(scale_bias_matrix, nhcwb16_to_nhwc).tolist() - if ifm_layout == "NHCWB16": - ifm_matrix = np.matmul(nhwc_to_nhcwb16, ifm_matrix).tolist() - - ifm_offset = ( - [0, -padding[0], -padding[1], 0] - if ifm_layout == "NHWC" - else [0, -padding[0], 0, -padding[1], 0] - ) - weight_offset = [0, 0, 0, 0] - scale_bias_offset = [0, 0] - return ( - ifm_matrix, - ifm_offset, - weight_matrix, - weight_offset, - scale_bias_matrix, - scale_bias_offset, - ) +from .infra import make_matrices @pytest.mark.parametrize("kernel", [(3, 3), (2, 1), (3, 5)]) @@ -137,17 +73,19 @@ def test_ethosu_conv2d_matcher( weight_offset, scale_bias_transform, scale_bias_offset, - ) = _make_matrices( + ) = make_matrices( + "ethosu_conv2d", kernel, stride, - dilation, padding, - ifm_channels, ifm_layout, ofm_layout, + dilation, + ifm_channels, ) - part = match_ethosu_conv2d(out) + device_config = cs.EthosuDeviceConfig("ethos-u55-256") + part = match_ethosu_conv2d(out, device_config) assert isinstance(part, cs.EthosuPart) assert len(part.propagators) == 3 diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_depthwise2d_matcher.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_depthwise2d_matcher.py new file mode 100644 index 0000000000000..c2c45b6524f1b --- /dev/null +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_depthwise2d_matcher.py @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") + +import numpy as np + +from tvm import te +import tvm.contrib.ethosu.cascader as cs +from tvm.relay.backend.contrib.ethosu.te.depthwise import ( + match_ethosu_depthwise_conv2d, + depthwise_conv2d_compute, +) +from .infra import make_matrices + + +@pytest.mark.parametrize("kernel", [(3, 3), (2, 1), (3, 5)]) +@pytest.mark.parametrize("stride", [(1, 1), (2, 1), (3, 2)]) +@pytest.mark.parametrize("dilation", [(1, 1), (2, 1), (3, 2)]) +@pytest.mark.parametrize("padding", [(0, 0, 0, 0), (3, 2, 3, 2), (2, 1, 0, 1)]) +@pytest.mark.parametrize("ifm_layout", ["NHWC", "NHCWB16"]) +@pytest.mark.parametrize("ofm_layout", ["NHWC", "NHCWB16"]) +def test_ethosu_depthwise2d_matcher(kernel, stride, dilation, padding, ifm_layout, ofm_layout): + ofm_channels = 57 + if ifm_layout == "NHWC": + ifm_shape = (1, 12, 15, ofm_channels) + else: + ifm_shape = (1, 12, 1 + ((ofm_channels - 1) // 16), 15, 16) + kernel_h, kernel_w = kernel + ifm = te.placeholder(ifm_shape, dtype="int8") + weight = te.placeholder((ofm_channels, kernel_h, kernel_w, 1), dtype="int8") + scale_bias = te.placeholder((ofm_channels, 10), dtype="uint8") + lut = te.placeholder((), dtype="uint8") + out = depthwise_conv2d_compute( + ifm=ifm, + weight=weight, + scale_bias=scale_bias, + lut=lut, + ifm_scale=1, + ifm_zero_point=0, + ofm_scale=1, + ofm_zero_point=0, + weight_zero_point=0, + strides=stride, + padding=padding, + dilation=dilation, + activation="NONE", + clip_min=0, + clip_max=0, + rounding_mode="TFL", + upscale="NONE", + ifm_layout=ifm_layout, + ofm_layout=ofm_layout, + ofm_dtype=ifm.dtype, + ) + ( + ifm_transform, + ifm_offset, + weight_transform, + weight_offset, + scale_bias_transform, + scale_bias_offset, + ) = make_matrices( + "ethosu_depthwise_conv2d", + kernel, + stride, + padding, + ifm_layout, + ofm_layout, + dilation, + ) + + device_config = cs.EthosuDeviceConfig("ethos-u55-256") + part = match_ethosu_depthwise_conv2d(out, device_config) + + assert isinstance(part, cs.EthosuPart) + assert len(part.propagators) == 3 + assert part.propagators[0].transform == ifm_transform + assert part.propagators[0].offset == ifm_offset + assert part.propagators[1].transform == weight_transform + assert part.propagators[1].offset == weight_offset + assert part.propagators[2].transform == scale_bias_transform + assert part.propagators[2].offset == scale_bias_offset + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_inline_matcher.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_inline_matcher.py index a3639ba030775..1eebbe40c1b32 100644 --- a/tests/python/contrib/test_ethosu/cascader/test_ethosu_inline_matcher.py +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_inline_matcher.py @@ -37,7 +37,8 @@ def test_ethosu_inline_matcher(): ] ifm_offset = [0, 0, 0] - part = match_ethosu_inline(out) + device_config = cs.EthosuDeviceConfig("ethos-u55-256") + part = match_ethosu_inline(out, device_config) assert isinstance(part, cs.InlinePart) assert len(part.propagators) == 1 diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_part.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_part.py index ef449a49976ca..fca136cf4ab4c 100644 --- a/tests/python/contrib/test_ethosu/cascader/test_ethosu_part.py +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_part.py @@ -18,28 +18,40 @@ pytest.importorskip("ethosu.vela") -import tvm.contrib.ethosu.cascader as pl +import tvm.contrib.ethosu.cascader as cs +from tvm.contrib.ethosu.cascader.graph import BufferMode from tvm.contrib.ethosu.cascader.parts import EthosuPart def test_ethosu_part(): - te_subgraph = pl.TESubgraph([], None) + te_subgraph = cs.TESubgraph([], None) output_quantum = [1, 2, 2, 8] - quantum_cycles = 32 - propagator = pl.Propagator( + propagator = cs.Propagator( [[1, 0, 0, 0, 2], [0, 1, 0, 0, 2], [0, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]], [0, 0, 0, 0], ) - stripe_config = pl.StripeConfig( + stripe_config = cs.StripeConfig( [1, 4, 4, 16], [1, 64, 72, 96], [1, 4, 4, 16], [1, 2, 3, 4], [1, 16, 13, 6], [0, 0, 0, 0] ) + subkernels = 3 - part = EthosuPart(te_subgraph, [propagator], output_quantum, quantum_cycles) + valid_block_configs = [cs.BlockConfig([1, 2, 4, 16], 15000, 7500)] + + part = EthosuPart( + te_subgraph, + [propagator], + output_quantum, + subkernels, + valid_block_configs, + 1, + ) + input_tensor = cs.Tensor(shape=[1, 66, 74, 16], dtype="int8") + part.set_input(0, input_tensor) assert part.get_stripe_align_hint() == output_quantum # Check that the performance model runs, don't verify output - part.get_performance_info(stripe_config, False) - part.get_performance_info(stripe_config, True) + part.get_performance_info(stripe_config, BufferMode.ROLLING) + part.get_performance_info(stripe_config, BufferMode.RECOMPUTE) if __name__ == "__main__": diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_part_performance.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_part_performance.py new file mode 100644 index 0000000000000..ba6346afa5d54 --- /dev/null +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_part_performance.py @@ -0,0 +1,229 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") + +from functools import reduce +import numpy as np +import math + +import tvm.contrib.ethosu.cascader as cs +from tvm.contrib.ethosu.cascader.device_config import _Shape + +from .infra import make_matrices + + +@pytest.mark.parametrize( + "acc_config, expected", + [ + ("ethos-u55-256", (1, 0.125, 0.75, 0.375, 0.75)), + ("ethos-u55-128", (1, 0.25, 1.5, 0.75, 0.75)), + ("ethos-u55-64", (1, 0.5, 3, 1.5, 1.5)), + ("ethos-u55-32", (2, 1, 6, 3, 3)), + ], +) +def test_device_config_cycles(acc_config, expected): + device_config = cs.EthosuDeviceConfig(acc_config) + + conv_type = "ethosu_conv2d" + conv_str = None + conv_ifm_dtype = "int8" + conv_ofm_dtype = "int8" + conv_activation = "LUT" + conv_cycles = device_config._get_output_cycles( + conv_type, conv_str, conv_ifm_dtype, conv_ofm_dtype, conv_activation + ) + assert conv_cycles == expected[0] + + pool_type = "ethosu_pooling" + pool_str = "MAX" + pool_ifm_dtype = "int8" + pool_ofm_dtype = "int8" + pool_activation = "NONE" + pool_cycles = device_config._get_output_cycles( + pool_type, pool_str, pool_ifm_dtype, pool_ofm_dtype, pool_activation + ) + assert pool_cycles == expected[1] + + add_type = "ethosu_binary_elementwise" + add_str = "ADD" + add_ifm_dtype = "int8" + add_ofm_dtype = "int8" + add_activation = "NONE" + add_cycles = device_config._get_output_cycles( + add_type, add_str, add_ifm_dtype, add_ofm_dtype, add_activation + ) + assert add_cycles == expected[2] + + mul_type = "ethosu_binary_elementwise" + mul_str = "MUL" + mul_ifm_dtype = "int8" + mul_ofm_dtype = "int8" + mul_activation = "NONE" + mul_cycles = device_config._get_output_cycles( + mul_type, mul_str, mul_ifm_dtype, mul_ofm_dtype, mul_activation + ) + assert mul_cycles == expected[3] + + mul_32_type = "ethosu_binary_elementwise" + mul_32_str = "MUL" + mul_32_ifm_dtype = "int8" + mul_32_ofm_dtype = "int32" + mul_32_activation = "NONE" + mul_32_cycles = device_config._get_output_cycles( + mul_32_type, mul_32_str, mul_32_ifm_dtype, mul_32_ofm_dtype, mul_32_activation + ) + assert mul_32_cycles == expected[4] + + +@pytest.mark.parametrize( + "accelerator, op_type, activation, kernel, stride, dilation, padding, in_shape, out_shape, block_shape, input_block_shape, expected", + [ + ( + "ethos-u55-128", + "ethosu_conv2d", + "NONE", + (3, 3), + (1, 1), + (1, 1), + (0, 0, 0, 0), + (1, 16, 16, 96), + (1, 16, 16, 96), + (1, 8, 8, 16), + (1, 10, 10, 32), + 167733, + ), + ( + "ethos-u55-128", + "ethosu_conv2d", + "NONE", + (10, 4), + (2, 1), + (1, 1), + (0, 0, 0, 0), + (1, 58, 13, 1), + (1, 25, 10, 276), + (1, 6, 10, 32), + (1, 18, 14, 8), + 174105, + ), + ( + "ethos-u55-128", + "ethosu_depthwise_conv2d", + "NONE", + (3, 3), + (2, 2), + (1, 1), + (1, 1, 1, 1), + (1, 25, 10, 276), + (1, 13, 5, 276), + (1, 7, 6, 16), + (1, 15, 14, 16), + 17590, + ), + ( + "ethos-u55-128", + "ethosu_depthwise_conv2d", + "NONE", + (4, 9), + (1, 1), + (1, 1), + (0, 0, 0, 0), + (1, 28, 81, 42), + (1, 25, 73, 41), + (1, 4, 16, 16), + (1, 7, 24, 16), + 173414, + ), + ], +) +def test_conv_performance( + accelerator, + op_type, + activation, + kernel, + stride, + dilation, + padding, + in_shape, + out_shape, + block_shape, + input_block_shape, + expected, +): + ifm_channels = in_shape[3] + ifm_matrix, ifm_offset, weight_matrix, weight_offset, _, _ = make_matrices( + op_type, + kernel, + stride, + padding, + "NHWC", + "NHWC", + dilation, + ifm_channels, + ) + + propagator = cs.Propagator(ifm_matrix, ifm_offset) + weight_propagator = cs.Propagator(weight_matrix, weight_offset) + + subkernels = ((kernel[0] + 7) // 8) * ((kernel[1] + 7) // 8) + + device_config = cs.EthosuDeviceConfig(accelerator) + + output_cycles = device_config._get_output_cycles(op_type, "", "int8", "int8", activation) + output_cycles *= reduce(lambda a, b: a * b, block_shape, 1) + is_partkernel = device_config.is_partkernel( + op_type, ifm_channels, "int8", kernel[0] * kernel[1] + ) + compute_cycles = device_config._estimate_compute_cycles_per_block( + op_type, + _Shape(block_shape), + _Shape(input_block_shape), + kernel[0], + kernel[1], + ifm_channels, + "int8", + is_partkernel, + ) + block_configs = [cs.BlockConfig(block_shape, compute_cycles, int(output_cycles))] + + output_quantum = [1, 1, 2, 8] + te_subgraph = cs.TESubgraph([], None) + part = cs.EthosuPart( + te_subgraph, + [propagator, weight_propagator], + output_quantum, + subkernels, + block_configs, + 1, + ) + + stripes = [1] * len(output_quantum) + offset = [0] * len(output_quantum) + order = [1, 2, 3, 4] + + stripe_config = cs.StripeConfig(out_shape, out_shape, out_shape, order, stripes, offset) + + compute_cycles = part.get_performance_info(stripe_config, cs.BufferMode.ROLLING).compute_cycles + tolerance = expected * 0.1 + + assert expected - tolerance <= compute_cycles <= expected + tolerance + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_pooling_matcher.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_pooling_matcher.py new file mode 100644 index 0000000000000..6ce8ee9a2986d --- /dev/null +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_pooling_matcher.py @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") + +import numpy as np + +from tvm import te +import tvm.contrib.ethosu.cascader as cs +from tvm.relay.backend.contrib.ethosu.te.pooling import match_ethosu_pooling, pooling_compute +from .infra import make_matrices + + +@pytest.mark.parametrize("pool_shape", [(3, 3), (2, 1), (3, 5)]) +@pytest.mark.parametrize("stride", [(1, 1), (2, 1), (3, 2)]) +@pytest.mark.parametrize("padding", [(0, 0, 0, 0), (3, 2, 3, 2), (2, 1, 0, 1)]) +@pytest.mark.parametrize("ifm_layout", ["NHWC", "NHCWB16"]) +@pytest.mark.parametrize("ofm_layout", ["NHWC", "NHCWB16"]) +def test_ethosu_pooling_matcher(pool_shape, stride, padding, ifm_layout, ofm_layout): + ofm_channels = 21 + if ifm_layout == "NHWC": + ifm_shape = (1, 12, 15, ofm_channels) + else: + ifm_shape = (1, 12, 1 + ((ofm_channels - 1) // 16), 15, 16) + ifm = te.placeholder(ifm_shape, dtype="int8") + lut = te.placeholder((), dtype="uint8") + out = pooling_compute( + ifm=ifm, + lut=lut, + pooling_type="MAX", + ifm_scale=1, + ifm_zero_point=0, + ofm_scale=1, + ofm_zero_point=0, + pool_shape=pool_shape, + ofm_channels=ofm_channels, + strides=stride, + padding=padding, + activation="NONE", + clip_min=0, + clip_max=0, + rounding_mode="TFL", + upscale="NONE", + ifm_layout=ifm_layout, + ofm_layout=ofm_layout, + ) + (ifm_transform, ifm_offset, _, _, _, _) = make_matrices( + "ethosu_pooling", + pool_shape, + stride, + padding, + ifm_layout, + ofm_layout, + ) + + device_config = cs.EthosuDeviceConfig("ethos-u55-256") + part = match_ethosu_pooling(out, device_config) + + assert isinstance(part, cs.EthosuPart) + assert len(part.propagators) == 1 + assert part.propagators[0].transform == ifm_transform + assert part.propagators[0].offset == ifm_offset + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/cascader/test_ethosu_unary_elementwise_matcher.py b/tests/python/contrib/test_ethosu/cascader/test_ethosu_unary_elementwise_matcher.py new file mode 100644 index 0000000000000..0570524e09073 --- /dev/null +++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_unary_elementwise_matcher.py @@ -0,0 +1,158 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") + +import numpy as np +import math + +from tvm import te +import tvm.contrib.ethosu.cascader as cs +from tvm.relay.backend.contrib.ethosu.te.unary_elementwise import ( + match_ethosu_unary_elementwise, + unary_elementwise_compute, +) + + +def _make_matrices(ifm_layout, ofm_layout): + nhwc_to_nhcwb16 = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1 / 16, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 16], + [0, 0, 0, 0, 1], + ] + nhcwb16_to_nhwc = [ + [1, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 16, 0, 1, -16], + [0, 0, 0, 0, 0, 1], + ] + ifm_matrix = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + ] + if ofm_layout == "NHCWB16": + ifm_matrix = np.matmul(ifm_matrix, nhcwb16_to_nhwc).tolist() + if ifm_layout == "NHCWB16": + ifm_matrix = np.matmul(nhwc_to_nhcwb16, ifm_matrix).tolist() + + return ifm_matrix + + +@pytest.mark.parametrize( + "ofm_shape", + [ + [1, 12, 15, 128], + [1, 16, 16, 16], + [1, 1, 1, 1024], + [1, 53, 91, 7], + [1, 182, 12, 72], + ], +) +@pytest.mark.parametrize("ifm_layout", ["NHWC", "NHCWB16"]) +@pytest.mark.parametrize("ofm_layout", ["NHWC", "NHCWB16"]) +@pytest.mark.parametrize("op_type", ["ABS", "CLZ"]) +def test_ethosu_unary_elementwise_matcher(ofm_shape, ifm_layout, ofm_layout, op_type): + ifm_shape = ofm_shape.copy() + ofm_channels = ofm_shape[3] + nhwc_to_nhcwb16 = [ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1 / 16, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 16], + [0, 0, 0, 0, 1], + ] + if ifm_layout == "NHCWB16": + ifm_shape = [ + int(math.ceil(n)) + for n in np.matmul( + nhwc_to_nhcwb16, + ifm_shape + + [ + 1, + ], + ).tolist()[:-1] + ] + if ofm_layout == "NHCWB16": + ofm_shape = [ + int(math.ceil(n)) + for n in np.matmul( + nhwc_to_nhcwb16, + ofm_shape + + [ + 1, + ], + ).tolist()[:-1] + ] + order = [1, 2, 4, 3, 0] + else: + order = [1, 2, 3, 4] + + ifm = te.placeholder(ifm_shape, dtype="int8") + lut = te.placeholder((), dtype="uint8") + out = unary_elementwise_compute( + ifm=ifm, + lut=lut, + operator_type=op_type, + ifm_scale=1, + ifm_zero_point=0, + ofm_scale=1, + ofm_zero_point=0, + ofm_channels=ofm_channels, + activation="NONE", + clip_min=0, + clip_max=0, + rounding_mode="TFL", + ifm_layout=ifm_layout, + ofm_layout=ofm_layout, + ) + ifm_propagator = out.op.attrs["ifm_propagator"] + + offset = [0] * len(ofm_shape) + stripes = [0] * len(ofm_shape) + output_stripe_config = cs.StripeConfig(ofm_shape, ofm_shape, ofm_shape, order, stripes, offset) + + ifm_transform = _make_matrices(ifm_layout, ofm_layout) + + device_config = cs.EthosuDeviceConfig("ethos-u55-256") + part = match_ethosu_unary_elementwise(out, device_config) + + assert isinstance(part, cs.EthosuPart) + assert len(part.propagators) == 1 + assert part.propagators[0].transform == ifm_transform + + propagated_ifm = ifm_propagator.propagate(output_stripe_config).shape + + # Layout conversions will align the propagated IFMs to the brick, i.e. 16 + # so the expected ifm_shape needs to be rounded up to 16 + if ifm_layout != ofm_layout: + assert ifm_shape[:-1] == propagated_ifm[:-1] + assert ((ifm_shape[-1] + 16 - 1) // 16) * 16 == propagated_ifm[-1] + else: + assert ifm_shape == propagated_ifm + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/cascader/test_graph.py b/tests/python/contrib/test_ethosu/cascader/test_graph.py index 3bab83f241439..616800f69d7ee 100644 --- a/tests/python/contrib/test_ethosu/cascader/test_graph.py +++ b/tests/python/contrib/test_ethosu/cascader/test_graph.py @@ -54,7 +54,7 @@ def test_inline_part(): assert len(part.propagators) == 1 assert part.in_line == True assert part.get_stripe_align_hint() == [1, 1] - performance_info = part.get_performance_info(output_stripe_config, is_rolling=False) + performance_info = part.get_performance_info(output_stripe_config, cs.BufferMode.RECOMPUTE) assert performance_info.compute_cycles == 0 assert performance_info.read_bytes == [0] assert performance_info.write_bytes == 0 @@ -127,7 +127,8 @@ def test_small_graph(): def test_create_cascader_graph(TwoConv2DWithSliceTE): _, te_graph, const_dict = TwoConv2DWithSliceTE - graph = cs.create_cascader_graph(te_graph, const_dict) + device_config = cs.EthosuDeviceConfig("ethos-u55-256") + graph = cs.create_cascader_graph(te_graph, const_dict, device_config) output_tensor = graph.output_tensors[0] assert output_tensor.shape == [1, 6, 1, 6, 16] @@ -175,5 +176,29 @@ def test_create_cascader_graph(TwoConv2DWithSliceTE): assert conv1_part.input_tensors[2].is_constant +def test_create_diamond_graph(MobileNetv2DiamondTE): + _, te_graph, const_dict = MobileNetv2DiamondTE + device_config = cs.EthosuDeviceConfig("ethos-u55-256") + graph = cs.create_cascader_graph(te_graph, const_dict, device_config) + + output_tensor = graph.output_tensors[0] + assert output_tensor.shape == [1, 56, 56, 24] + assert len(output_tensor.producers) == 1 + assert not output_tensor.is_constant + + add1_part = output_tensor.producers[0] + assert isinstance(add1_part, cs.EthosuPart) + assert len(add1_part.input_tensors) == 2 + assert graph.get_part_id(add1_part) == 0 + + assert add1_part.input_tensors[0].shape == [1, 56, 56, 24] + assert len(add1_part.input_tensors[0].producers) == 1 + assert not add1_part.input_tensors[0].is_constant + + assert add1_part.input_tensors[1].shape == [1, 56, 56, 24] + assert len(add1_part.input_tensors[0].producers) == 1 + assert not add1_part.input_tensors[0].is_constant + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py index 96c8433a63848..0b058a94fb608 100644 --- a/tests/python/contrib/test_ethosu/infra.py +++ b/tests/python/contrib/test_ethosu/infra.py @@ -214,7 +214,7 @@ def create_test_runner(accel="ethos-u55-256"): pass_config={ "relay.ext.ethos-u.options": { "accelerator_config": accel, - } + }, }, ) diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 23fd74dc486dd..4042bb057bd37 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -629,12 +629,13 @@ def create_mod_from_relay(): @pytest.mark.parametrize("accel_type", ACCEL_TYPES) @pytest.mark.parametrize("dtype", ["int8", "uint8"]) -def test_elementwise_add_from_constant_scalar(accel_type, dtype): +@pytest.mark.parametrize("constant", [np.ones((1, 1, 1, 1)), np.array(1)]) +def test_elementwise_add_from_constant_scalar(accel_type, dtype, constant): ifm_shape = (1, 4, 4, 8) def create_relay_graph(): inp = relay.var("input", shape=ifm_shape, dtype=dtype) - scalar = relay.const(np.ones((1, 1, 1, 1), dtype=dtype), dtype=dtype) + scalar = relay.const(constant, dtype=dtype) add = relay.qnn.op.add( inp, scalar, diff --git a/tests/python/contrib/test_ethosu/test_compiler.py b/tests/python/contrib/test_ethosu/test_compiler.py index 4df6311a230c7..e1688b8aa512e 100644 --- a/tests/python/contrib/test_ethosu/test_compiler.py +++ b/tests/python/contrib/test_ethosu/test_compiler.py @@ -34,8 +34,7 @@ def test_lower_to_tir(): kernel_layout="HWIO", out_dtype="int32", ) - multiply = relay.multiply(relay.const(-22, dtype="int32"), p2) - tile = relay.tile(multiply, reps=(1, 1, 1, 1001)) + tile = relay.tile(p2, reps=(1, 1, 1, 1001)) subtract = relay.subtract(conv, tile) func = subtract expr = relay.Function(relay.analysis.free_vars(func), func) diff --git a/tests/python/contrib/test_ethosu/test_layout_optimizer.py b/tests/python/contrib/test_ethosu/test_layout_optimizer.py index aafae1497ea40..62a1fabe0b98f 100644 --- a/tests/python/contrib/test_ethosu/test_layout_optimizer.py +++ b/tests/python/contrib/test_ethosu/test_layout_optimizer.py @@ -33,14 +33,17 @@ from tvm import relay from tvm.relay.op.contrib.ethosu import partition_for_ethosu from tvm.relay.backend.contrib.ethosu.codegen import LayoutOptimizer +from tvm.relay.backend.contrib.ethosu.codegen import relay_to_tir_func from . import infra -def _run_pass(expr, relay_pass): - """Create IRModule and run Relay pass.""" +def _optimize(expr, optimize=True): + """Create IRModule and run layout optimizer pass.""" mod = tvm.IRModule.from_expr(expr) - mod = relay_pass(mod) + mod = relay.transform.InferType()(mod) + if optimize: + mod = LayoutOptimizer()(mod) entry = mod["main"] return entry if isinstance(expr, relay.Function) else entry.body @@ -111,8 +114,8 @@ def get_graph(): ) return relay.Function(relay.analysis.free_vars(x), x) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(), optimize=False) _assert_structural_equal(a, b) @@ -144,8 +147,8 @@ def get_graph(get_expected=False): ) return relay.Function(relay.analysis.free_vars(x), x) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(get_expected=True), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(get_expected=True), optimize=False) _assert_structural_equal(a, b) @@ -176,8 +179,8 @@ def get_graph(get_expected=False): ) return relay.Function(relay.analysis.free_vars(x), x) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(get_expected=True), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(get_expected=True), optimize=False) _assert_structural_equal(a, b) @@ -222,8 +225,8 @@ def get_graph(): ) return relay.Function(relay.analysis.free_vars(conv_2), conv_2) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(), optimize=False) _assert_structural_equal(a, b) @@ -268,8 +271,8 @@ def get_graph(): ) return relay.Function(relay.analysis.free_vars(conv_2), conv_2) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(), optimize=False) _assert_structural_equal(a, b) @@ -322,8 +325,8 @@ def get_graph(): ) return relay.Function(relay.analysis.free_vars(pool_3), pool_3) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(), optimize=False) _assert_structural_equal(a, b) @@ -368,8 +371,8 @@ def get_graph(): ) return relay.Function(relay.analysis.free_vars(conv), conv) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(), optimize=False) _assert_structural_equal(a, b) @@ -413,8 +416,8 @@ def get_graph(get_expected=False): concat = relay.concatenate(poolings, axis=0) return relay.Function(relay.analysis.free_vars(concat), concat) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(get_expected=True), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(get_expected=True), optimize=False) _assert_structural_equal(a, b) @@ -467,8 +470,8 @@ def get_graph(get_expected=False): ) return relay.Function(relay.analysis.free_vars(add_3), add_3) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(get_expected=True), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(get_expected=True), optimize=False) _assert_structural_equal(a, b) @@ -500,8 +503,8 @@ def get_graph(get_expected=False): ) return relay.Function(relay.analysis.free_vars(x), x) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(get_expected=True), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(get_expected=True), optimize=False) _assert_structural_equal(a, b) @@ -530,8 +533,8 @@ def get_graph(get_expected=False): ) return relay.Function(relay.analysis.free_vars(x), x) - a = _run_pass(get_graph(), LayoutOptimizer()) - b = _run_pass(get_graph(get_expected=True), relay.transform.InferType()) + a = _optimize(get_graph()) + b = _optimize(get_graph(get_expected=True), optimize=False) _assert_structural_equal(a, b) @@ -619,5 +622,32 @@ def representative_dataset(): _compile_and_compare_model(create_model(), ifm_shape, dtype) +def test_layout_optimizer_runs_in_compilation_pipeline(): + """Checks that the layout optimization pass runs as part of the NPU compilation + pipeline.""" + + def get_graph(): + x = relay.var("x", shape=(1, 4, 4, 4), dtype="int8") + for _ in range(2): + x = relay.nn.max_pool2d(x, layout="NHWC") + + func = relay.Function(relay.analysis.free_vars(x), x) + return tvm.IRModule.from_expr(func) + + mod = get_graph() + mod = partition_for_ethosu(mod) + + external_gv_name = mod["main"].body.op.name_hint + external_func = mod[external_gv_name] + prim_func = relay_to_tir_func(external_func) + + # Check for hints in the TIR prim func that the layout optimization pass has ran + ops = prim_func.body.body.seq + max_pool1, max_pool2 = ops + + assert str(max_pool1.value.args[31]) == '"NHCWB16"' + assert str(max_pool2.value.args[14]) == '"NHCWB16"' + + if __name__ == "__main__": pytest.main([__file__] + sys.argv[1:]) diff --git a/tests/python/contrib/test_ethosu/test_lut_optimizer.py b/tests/python/contrib/test_ethosu/test_lut_optimizer.py index 16835ce94ed77..d9a543c1a7716 100644 --- a/tests/python/contrib/test_ethosu/test_lut_optimizer.py +++ b/tests/python/contrib/test_ethosu/test_lut_optimizer.py @@ -21,9 +21,16 @@ pytest.importorskip("ethosu.vela") +import tensorflow as tf +import numpy as np + import tvm from tvm import relay from tvm.relay.backend.contrib.ethosu.codegen import LUTsOptimizer +from tvm.relay.backend.contrib.ethosu.codegen import relay_to_tir_func +from tvm.relay.op.contrib.ethosu import partition_for_ethosu + +from .test_codegen import _get_tflite_graph from . import infra @@ -59,6 +66,7 @@ def after(): return mod mod = LUTsOptimizer()(before()) + mod = relay.transform.InferType()(mod) assert tvm.ir.structural_equal(mod, after()) @@ -91,5 +99,35 @@ def after(): return mod mod = LUTsOptimizer()(before()) + mod = relay.transform.InferType()(mod) assert tvm.ir.structural_equal(mod, after()) + + +def test_lut_optimizer_runs_in_compilation_pipeline(): + """Test that the LUT optimization pass runs as part of the NPU compilation pipeline.""" + ifm_shape = (1, 4, 4, 4) + + @tf.function + def get_graph(x): + weight1 = tf.constant(np.random.uniform(size=(1, 1, 4, 4)), dtype=tf.float32) + op = tf.nn.conv2d(x, weight1, (1, 1), "VALID") + op = tf.nn.tanh(op) + weight2 = tf.constant(np.random.uniform(size=(1, 1, 4, 1)), dtype=tf.float32) + op = tf.nn.depthwise_conv2d(op, weight2, (1, 1, 1, 1), "VALID") + return tf.nn.tanh(op) + + mod, _ = _get_tflite_graph(get_graph, [ifm_shape]) + mod = partition_for_ethosu(mod) + + external_gv_name = mod["main"].body.op.name_hint + external_func = mod[external_gv_name] + prim_func = relay_to_tir_func(external_func) + + # Check for hints in the TIR prim func that the LUT optimization pass has ran. + # If the module was optimized, there should be no identity operations. + def check_identity(stmt): + if isinstance(stmt, tvm.tir.expr.Call): + assert stmt.args[0] != "ethosu_identity" + + tvm.tir.stmt_functor.post_order_visit(prim_func.body, check_identity) diff --git a/tests/python/frontend/caffe/test_forward.py b/tests/python/frontend/caffe/test_forward.py index 0027a6b417364..90c3f566b34ed 100644 --- a/tests/python/frontend/caffe/test_forward.py +++ b/tests/python/frontend/caffe/test_forward.py @@ -35,6 +35,7 @@ from caffe.proto import caffe_pb2 as pb import tvm +import tvm.testing from tvm import relay from tvm.contrib import utils, graph_executor from tvm.contrib.download import download_testdata @@ -451,6 +452,35 @@ def test_forward_Deconvolution(): bias_filler=dict(type="xavier"), ), ) + _test_deconvolution( + data, + convolution_param=dict( + num_output=16, + bias_term=False, + pad=0, + kernel_size=2, + stride=2, + dilation=1, + group=16, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"), + ), + ) + data = np.random.rand(1, 100, 32, 32).astype(np.float32) + _test_deconvolution( + data, + convolution_param=dict( + num_output=100, + bias_term=False, + pad=0, + kernel_size=2, + stride=2, + dilation=1, + group=100, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"), + ), + ) ####################################################################### @@ -511,6 +541,45 @@ def test_forward_Eltwise(): operation=1, coeff=[0.5, 1], ) + _test_eltwise( + [ + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32), + ], + operation=0, + ) + _test_eltwise( + [ + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32), + ], + operation=1, + ) + _test_eltwise( + [ + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32), + ], + operation=2, + ) + _test_eltwise( + [ + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32), + ], + operation=1, + coeff=[0.5, 1, 0.2, 1.8, 3.1, 0.1], + ) ####################################################################### diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 701906e4be40d..7485faa5f8c76 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -189,6 +189,7 @@ def verify_with_ort_with_inputs( opt_level=opt_level, convert_config=convert_config, ) + if not isinstance(tvm_out, list): tvm_out = [tvm_out] if not isinstance(ort_out, list): @@ -1272,6 +1273,7 @@ def verify_batch_matmul(a_shape, b_shape, out_shape, convert_config=None): verify_batch_matmul((1, 4, 3), (2, 3, 4), (2, 4, 4)) verify_batch_matmul((4, 32, 16), (16, 32), (4, 32, 32)) verify_batch_matmul((4, 32, 16, 32), (32, 16), (4, 32, 16, 16)) + verify_batch_matmul((4, 32, 16, 32), (1, 32, 32, 16), (4, 32, 16, 16)) # Test transb=False verify_batch_matmul( (2, 3, 4, 3), @@ -1281,6 +1283,39 @@ def verify_batch_matmul(a_shape, b_shape, out_shape, convert_config=None): ) +@tvm.testing.parametrize_targets +def test_use_nt_batch_matmul(target, dev): + a_shape = (2, 3, 4) + b_shape = (2, 4, 3) + out_shape = [2, 3, 3] + a_array = np.random.uniform(size=a_shape).astype("float32") + b_array = np.random.uniform(size=b_shape).astype("float32") + + for use_nt_batch_matmul in [True, False]: + mul_node = helper.make_node("MatMul", ["a", "b"], ["out"]) + + graph = helper.make_graph( + [mul_node], + "matmul_test", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)), + helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape)), + ], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))], + ) + + model = helper.make_model(graph, producer_name="matmul_test") + _, shape_dict = get_input_data_shape_dict(model, [a_array, b_array]) + + mod, _ = relay.frontend.from_onnx( + model, shape_dict, convert_config={"use_nt_batch_matmul": use_nt_batch_matmul} + ) + has_transpose_op = "transpose" in str(mod) + # use_nt_batch_matmul implies, TVM converts qualified onnx `matmul` + # to `transpose(weight) + nn.batch_matmul_NT`, otherwise to `nn.batch_matmul` + assert has_transpose_op == use_nt_batch_matmul + + @tvm.testing.parametrize_targets def test_matmulinteger16(target, dev): def verify_matmulinteger16(a_shape, b_shape, out_shape): @@ -2859,6 +2894,14 @@ def verify_convtranspose(x_shape, w_shape, y_shape, p, group=1): # Test undefined groups. verify_convtranspose((1, 1, 3, 3), (1, 2, 3, 3), (1, 2, 7, 3), [1, 2, 1, 2], group=None) + if "llvm" in target: + # GPU does not support groups != 1 for convtranspose, so only test llvm + # Test depthwise-convolution + verify_convtranspose((1, 10, 3, 3), (10, 1, 3, 3), (1, 10, 7, 3), [1, 2, 1, 2], group=10) + + # Test grouped-convolution + verify_convtranspose((1, 10, 3, 3), (10, 1, 3, 3), (1, 5, 7, 3), [1, 2, 1, 2], group=5) + def repeat(N, D): return tuple([N for _ in range(D)]) @@ -5047,6 +5090,7 @@ def verify_eyelike(indata): "test_reduce_sum_keepdims_random", "test_reduce_sum_negative_axes_keepdims_example", "test_reduce_sum_negative_axes_keepdims_random", + "test_reshape_allowzero_reordered", "test_rnn_seq_length", "test_round", "test_sequence_insert_at_back", @@ -6287,6 +6331,7 @@ def verify_scan( test_random_uniform() test_convinteger() test_batch_matmul() + test_use_nt_batch_matmul() test_global_lppool() test_scan() test_random_uniform_like() diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py index 15da955722588..fb5b155e78528 100644 --- a/tests/python/frontend/pytorch/qnn_test.py +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -40,7 +40,7 @@ def torch_version_check(): return version.parse(torch.__version__) > version.parse("1.4.0") -def get_tvm_runtime(script_module, input_name, ishape, keep_quantized_weight=False): +def get_tvm_runtime(script_module, input_name, ishape, keep_quantized_weight=False, target="llvm"): input_shapes = [(input_name, ishape)] mod, params = relay.frontend.from_pytorch( script_module, input_shapes, keep_quantized_weight=keep_quantized_weight @@ -53,9 +53,9 @@ def get_tvm_runtime(script_module, input_name, ishape, keep_quantized_weight=Fal with tvm.transform.PassContext(opt_level=3): # test on only cpu for now, torch cannot run quant models on cuda # also not to make CI too slow - lib = relay.build(mod, target="llvm", params=params) + lib = relay.build(mod, target=target, params=params) - runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](tvm.cpu(0))) + runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](tvm.device(target, 0))) return runtime @@ -329,6 +329,15 @@ def test_quantized_modules(): print(module_name, max_abs_diff, mean_abs_diff, match_ratio) + if "linear" in module_name and tvm.get_global_func("tvm.contrib.cublas.matmul", True): + runtime = get_tvm_runtime(script_module, input_name, ishape, target="cuda -libs=cublas") + runtime.set_input(input_name, inp.numpy().copy()) + runtime.run() + cublas_result = runtime.get_output(0).numpy() + # It is generally safe to enable this assertion, but disabled for CI + # tvm.testing.assert_allclose(cublas_result, pt_result, atol=1e-5, rtol=1e-5) + print(np.max(np.abs(cublas_result - pt_result))) + # sample outputs """ relu 0.0039215684 2.6052087e-08 0.9999933567176871 diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index be32ca308ba1b..c76803b8fb3c4 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1537,19 +1537,29 @@ def run(dtype_str, infer_shape): element_shape = tf.TensorShape([tf.Dimension(None)]) else: element_shape = None - t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str), dtype=dtype) - indices = tf.constant([2, 1, 0]) - ta1 = tf.TensorArray( - dtype=dtype, size=3, infer_shape=infer_shape, element_shape=element_shape - ) - ta2 = ta1.scatter(indices, t) - out0 = ta2.read(0) - out1 = ta2.read(1) - out2 = ta2.read(2) + ta0 = _construct_scatter(dtype, dtype_str, element_shape, infer_shape, 3) + out0 = ta0.read(0) + out1 = ta0.read(1) + out2 = ta0.read(2) + ta1 = _construct_scatter(dtype, dtype_str, element_shape, infer_shape, 4) + out4 = ta1.read(0) g = tf.get_default_graph() compare_tf_with_tvm([], [], ["TensorArrayReadV3:0"], mode="vm") compare_tf_with_tvm([], [], ["TensorArrayReadV3_1:0"], mode="vm") compare_tf_with_tvm([], [], ["TensorArrayReadV3_2:0"], mode="vm") + compare_tf_with_tvm([], [], ["TensorArrayReadV3_2:0", out4.name], mode="vm") + + def _construct_scatter(dtype, dtype_str, element_shape, infer_shape, size): + arr = [[float(i)] for i in range(size)] + indices_arr = [i for i in range(size - 1, -1, -1)] + + t = tf.constant(np.array(arr).astype(dtype_str), dtype=dtype) + indices = tf.constant(indices_arr) + ta1 = tf.TensorArray( + dtype=dtype, size=size, infer_shape=infer_shape, element_shape=element_shape + ) + ta2 = ta1.scatter(indices, t) + return ta2 for dtype in ["float32", "int8"]: run(dtype, False) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 6bc9d9ca43e99..77acce459fc9d 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -4037,6 +4037,7 @@ def _test_fully_connected( # reshape N H W C into N H*W*C in_data_reshape = array_ops.reshape(in_data, [tensor_in_sizes[0], -1]) out = math_ops.mat_mul(in_data_reshape, in_filter) + # TODO : Need to construct a fc op with (keep_num_dims == True) # if we have bias if bias_in_size: diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py index 3f448ca4a7d9c..6900bdc2e6e10 100644 --- a/tests/python/relay/aot/aot_test_utils.py +++ b/tests/python/relay/aot/aot_test_utils.py @@ -702,8 +702,9 @@ def run_and_check( t = tarfile.open(tar_file) t.extractall(base_path) - workspace_bytes += model.extra_memory_in_bytes - if interface_api == "packed": + workspace_bytes = model.extra_memory_in_bytes + use_usmp = runner.pass_config.get("tir.usmp.enable", False) + if interface_api == "packed" and not use_usmp: workspace_bytes += mlf_extract_workspace_size_bytes(tar_file) for key in model.inputs: @@ -815,6 +816,7 @@ def compile_and_run( target=target, target_opts=target_opts, ) + run_and_check( models=compiled_test_mods, runner=runner, diff --git a/tests/python/relay/aot/corstone300.ld b/tests/python/relay/aot/corstone300.ld index ddf55b8687806..a825da74c1dba 100644 --- a/tests/python/relay/aot/corstone300.ld +++ b/tests/python/relay/aot/corstone300.ld @@ -141,6 +141,8 @@ SECTIONS . = ALIGN (16); *(.rodata.tvm) . = ALIGN (16); + *(.data.tvm) + . = ALIGN (16); } > DDR .text : diff --git a/tests/python/relay/aot/corstone300.mk b/tests/python/relay/aot/corstone300.mk index 8ff11e29e5e2f..5a734f646d28f 100644 --- a/tests/python/relay/aot/corstone300.mk +++ b/tests/python/relay/aot/corstone300.mk @@ -69,8 +69,10 @@ QUIET ?= @ $(endif) CRT_SRCS = $(shell find $(CRT_ROOT)) -CODEGEN_SRCS = $(shell find $(abspath $(CODEGEN_ROOT)/host/src/*.c)) -CODEGEN_OBJS = $(subst .c,.o,$(CODEGEN_SRCS)) +C_CODEGEN_SRCS = $(shell find $(abspath $(CODEGEN_ROOT)/host/src/*.c)) +CC_CODEGEN_SRCS = $(shell find $(abspath $(CODEGEN_ROOT)/host/src/*.cc)) +C_CODEGEN_OBJS = $(subst .c,.o,$(C_CODEGEN_SRCS)) +CC_CODEGEN_OBJS = $(subst .cc,.o,$(CC_CODEGEN_SRCS)) CMSIS_STARTUP_SRCS = $(shell find ${CMSIS_PATH}/Device/ARM/${ARM_CPU}/Source/*.c) UART_SRCS = $(shell find ${PLATFORM_PATH}/*.c) @@ -96,9 +98,9 @@ $(build_dir)/tvm_ethosu_runtime.o: $(TVM_ROOT)/src/runtime/contrib/ethosu/bare_m $(QUIET)mkdir -p $(@D) $(QUIET)$(CC) -c $(PKG_CFLAGS) -o $@ $^ -$(build_dir)/libcodegen.a: $(CODEGEN_SRCS) - $(QUIET)cd $(abspath $(CODEGEN_ROOT)/host/src) && $(CC) -c $(PKG_CFLAGS) $(CODEGEN_SRCS) - $(QUIET)$(AR) -cr $(abspath $(build_dir)/libcodegen.a) $(CODEGEN_OBJS) +$(build_dir)/libcodegen.a: $(C_CODEGEN_SRCS) $(CC_CODEGEN_SRCS) + $(QUIET)cd $(abspath $(CODEGEN_ROOT)/host/src) && $(CC) -c $(PKG_CFLAGS) $(C_CODEGEN_SRCS) $(CC_CODEGEN_SRCS) + $(QUIET)$(AR) -cr $(abspath $(build_dir)/libcodegen.a) $(C_CODEGEN_OBJS) $(CC_CODEGEN_OBJS) $(QUIET)$(RANLIB) $(abspath $(build_dir)/libcodegen.a) ${build_dir}/libcmsis_startup.a: $(CMSIS_STARTUP_SRCS) diff --git a/tests/python/relay/aot/default.mk b/tests/python/relay/aot/default.mk index f5edcb3d64223..b7258a3c6df8c 100644 --- a/tests/python/relay/aot/default.mk +++ b/tests/python/relay/aot/default.mk @@ -22,6 +22,7 @@ ENABLE_TVM_PLATFORM_ABORT_BACKTRACE = 0 DMLC_CORE=$(TVM_ROOT)/3rdparty/dmlc-core PKG_COMPILE_OPTS = -g CC = gcc +#CC = g++ AR = ar RANLIB = ranlib CC_OPTS = CC=$(CC) AR=$(AR) RANLIB=$(RANLIB) @@ -39,10 +40,12 @@ $(endif) aot_test_runner: $(build_dir)/aot_test_runner -source_libs= $(wildcard $(build_dir)/../codegen/host/src/*.c) -lib_objs =$(source_libs:.c=.o) +c_source_libs= $(wildcard $(build_dir)/../codegen/host/src/*.c) +cc_source_libs= $(wildcard $(build_dir)/../codegen/host/src/*.cc) +c_lib_objs =$(c_source_libs:.c=.o) +cc_lib_objs =$(cc_source_libs:.cc=.o) -$(build_dir)/aot_test_runner: $(build_dir)/test.c $(source_libs) $(build_dir)/stack_allocator.o $(build_dir)/crt_backend_api.o +$(build_dir)/aot_test_runner: $(build_dir)/test.c $(c_source_libs) $(cc_source_libs) $(build_dir)/stack_allocator.o $(build_dir)/crt_backend_api.o $(QUIET)mkdir -p $(@D) $(QUIET)$(CC) $(CFLAGS) $(PKG_CFLAGS) -o $@ $^ $(PKG_LDFLAGS) $(BACKTRACE_LDFLAGS) $(BACKTRACE_CFLAGS) -lm @@ -50,6 +53,10 @@ $(build_dir)/%.o: $(build_dir)/../codegen/host/src/%.c $(QUIET)mkdir -p $(@D) $(QUIET)$(CC) $(CFLAGS) -c $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS) +$(build_dir)/%.o: $(build_dir)/../codegen/host/src/%.cc + $(QUIET)mkdir -p $(@D) + $(QUIET)$(CC) $(CFLAGS) -c $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS) + $(build_dir)/stack_allocator.o: $(STANDALONE_CRT_DIR)/src/runtime/crt/memory/stack_allocator.c $(QUIET)mkdir -p $(@D) $(QUIET)$(CC) $(CFLAGS) -c $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS) diff --git a/tests/python/relay/aot/test_c_device_api.py b/tests/python/relay/aot/test_c_device_api.py index 473b8d5ee3005..d369fd0a4a30c 100644 --- a/tests/python/relay/aot/test_c_device_api.py +++ b/tests/python/relay/aot/test_c_device_api.py @@ -92,7 +92,7 @@ def compile_to_main_func(interface_api="c", use_unpacked_api=True): workspace_byte_alignment=16, pass_config=test_runner.pass_config, ) - main_ir_module = compiled_models[0].executor_factory.lowered_ir_mods.items()[1][1] + main_ir_module = compiled_models[0].executor_factory.lowered_ir_mods.items()[0][1] main_func = main_ir_module["run_model"] return main_func @@ -136,44 +136,28 @@ def test_device_api_hooks_unpacked_api(device_api_main_func): # Activate Device assert ( - str(main_func.body[0][0].value) - == "@tir.call_extern(" - + '"TVMDeviceEthosUActivate",' - + " device_context_ethos_u: handle," - + " dtype=int32)" + str(main_func.body[0]) + == "tir.call_extern(" + '"TVMDeviceEthosUActivate",' + " device_context_ethos_u)\n" ) # Open Device assert ( - str(main_func.body[1].body.body[0][0][0].value) - == "@tir.call_extern(" - + '"TVMDeviceEthosUOpen",' - + " device_context_ethos_u: handle," - + " dtype=int32)" + str(main_func.body[1][0][0][0]) + == "tir.call_extern(" + '"TVMDeviceEthosUOpen",' + " device_context_ethos_u)\n" ) # Device Call assert ( - str(main_func.body[1].body.body[0][0][1].value) - == "@tir.call_extern(" - + '"tvmgen_default_ethos_u_main_0",' - + " input: handle, output: handle," - + " device_context_ethos_u: handle," - + " dtype=int32)" + str(main_func.body[1][0][0][1]) + == 'tir.call_extern("tvmgen_default_ethos_u_main_0", x_int8_buffer_var, output_buffer_var, device_context_ethos_u)\n' ) # Close Device assert ( - str(main_func.body[1].body.body[0][0][2].value) - == "@tir.call_extern(" - + '"TVMDeviceEthosUClose",' - + " device_context_ethos_u: handle," - + " dtype=int32)" + str(main_func.body[1][0][0][2]) + == "tir.call_extern(" + '"TVMDeviceEthosUClose",' + " device_context_ethos_u)\n" ) # Deactivate Device assert ( - str(main_func.body[2][0].value) - == "@tir.call_extern(" - + '"TVMDeviceEthosUDeactivate",' - + " device_context_ethos_u: handle," - + " dtype=int32)" + str(str(main_func.body[2])) + == "tir.call_extern(" + '"TVMDeviceEthosUDeactivate",' + " device_context_ethos_u)\n" ) @@ -231,13 +215,9 @@ def test_without_device_api_unpacked_api(non_device_api_main_func): """Test a graph without the Device API with the unpacked internal calls""" main_func = non_device_api_main_func(interface_api="c", use_unpacked_api=True) - assert ( - str(main_func.body[1].body.body[0][0].value) - == "@tir.call_extern(" - + '"tvmgen_default_fused_multiply",' - + " input: handle, input_1: handle, output: handle," - + " dtype=int32)" + str(main_func.body) + == 'tir.call_extern("tvmgen_default_fused_multiply", x_buffer_var, y_buffer_var, output_buffer_var)\n' ) @@ -245,12 +225,17 @@ def test_without_device_api_packed_api(non_device_api_main_func): """Test a graph without the Device API with the packed internal calls""" main_func = non_device_api_main_func(interface_api="packed", use_unpacked_api=False) - assert ( - str(main_func.body[1].body.body[0][0]) - == 'let tvm_value_0 = tir.tvm_stack_alloca("array", 1)\n' - + "tir.tvm_struct_set(tvm_value_0, 0, 1, tir.reinterpret((uint64)0))\n" - + 'tir.tvm_call_cpacked("tvmgen_default_fused_multiply", input, input, output, tvm_value_0)\n' + str(main_func.body) + == 'let tvm_value_3 = tir.tvm_stack_alloca("array", 1)\n' + + 'let tvm_value_2 = tir.tvm_stack_alloca("array", 1)\n' + + 'let tvm_value_1 = tir.tvm_stack_alloca("array", 1)\n' + + 'let tvm_value_0 = tir.tvm_stack_alloca("array", 1)\n' + + "tir.tvm_struct_set(tvm_value_0, 0, 1, x_buffer_var)\n" + + "tir.tvm_struct_set(tvm_value_1, 0, 1, y_buffer_var)\n" + + "tir.tvm_struct_set(tvm_value_2, 0, 1, output_buffer_var)\n" + + "tir.tvm_struct_set(tvm_value_3, 0, 1, tir.reinterpret((uint64)0))\n" + + 'tir.tvm_call_cpacked("tvmgen_default_fused_multiply", tvm_value_0, tvm_value_1, tvm_value_2, tvm_value_3)\n' ) diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index 8a2b1f1bb84d2..566566da1dcef 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -515,8 +515,8 @@ def test_quant_mobilenet_tfl(): import tvm.relay.testing.tf as tf_testing - interface_api = "packed" - use_unpacked_api = False + use_unpacked_api = True + interface_api = "c" test_runner = AOT_DEFAULT_RUNNER tflite_model_file = tf_testing.get_workload_official( @@ -657,42 +657,6 @@ def test_deprecated_target_arguments(capsys): ) -@pytest.mark.parametrize( - "workspace_byte_alignment,main_workspace_size,sum_workspace_size", - [ - (8, 10368, 15200), - (16, 10368, 15232), - (256, 10752, 17408), - ], -) -def test_memory_planning(workspace_byte_alignment, main_workspace_size, sum_workspace_size): - mod, params = tvm.relay.testing.synthetic.get_workload() - target = "c" - runtime = Runtime("crt") - executor = Executor( - "aot", - { - "workspace-byte-alignment": workspace_byte_alignment, - }, - ) - with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - lib = tvm.relay.build(mod, target, executor=executor, runtime=runtime, params=params) - - assert ( - sum(lib.function_metadata["__tvm_main__"].workspace_sizes.values()) == main_workspace_size - ) - assert ( - sum( - [ - size - for metadata in lib.function_metadata.values() - for size in metadata.workspace_sizes.values() - ] - ) - == sum_workspace_size - ) - - def test_aot_codegen_backend_alloc_workspace_calls(): """This test checks whether AoT lowering creates TVMBackendAllocWorkspace calls""" diff --git a/tests/python/relay/aot/test_crt_aot_usmp.py b/tests/python/relay/aot/test_crt_aot_usmp.py new file mode 100644 index 0000000000000..b88e0905dba5a --- /dev/null +++ b/tests/python/relay/aot/test_crt_aot_usmp.py @@ -0,0 +1,266 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" This file contains test that use USMP + AoT using C runtime APIs""" + +from collections import OrderedDict +import sys + +import numpy as np +import pytest + +import tvm +from tvm import relay, TVMError +from tvm.ir.module import IRModule +from tvm.relay import testing, transform +from tvm.relay.testing import byoc +from tvm.relay.op.annotation import compiler_begin, compiler_end +from tvm.relay.backend import Executor, Runtime +from aot_test_utils import ( + AOTTestModel, + AOTTestRunner, + generate_ref_data, + convert_to_relay, + compile_and_run, + compile_models, + parametrize_aot_options, + run_and_check, +) + + +def check_for_no_tvm_backendallocworkspace_calls(mod: tvm.runtime.module): + """This checker checks whether any c-source produced has TVMBackendAllocWorkspace calls. + If USMP is invoked, none of them should have TVMBAW calls""" + dso_modules = mod._collect_dso_modules() + for dso_mod in dso_modules: + assert ( + dso_mod.type_key == "c" + ), 'Current CRT AoT codegen flow should only produce type "c" runtime modules' + source = dso_mod.get_source() + source.count( + "TVMBackendAllocWorkspace" + ) == 0, "This is failing because USMP was unable to plan for every tir.allocate node" + + +@pytest.mark.parametrize( + "workspace_byte_alignment,main_workspace_size", + [ + (8, 17280), + (16, 17280), + (256, 17792), + ], +) +def test_memory_planning(workspace_byte_alignment, main_workspace_size): + mod, params = tvm.relay.testing.synthetic.get_workload() + target = "c" + runtime = Runtime("crt") + executor = Executor( + "aot", + { + "workspace-byte-alignment": workspace_byte_alignment, + }, + ) + with tvm.transform.PassContext( + opt_level=3, + config={ + "tir.disable_vectorize": True, + "tir.disable_storage_rewrite": True, + "tir.usmp.enable": True, + "tir.usmp.algorithm": "greedy_by_conflicts", + }, + ): + lib = tvm.relay.build(mod, target, executor=executor, runtime=runtime, params=params) + assert ( + sum(lib.function_metadata["__tvm_main__"].workspace_sizes.values()) == main_workspace_size + ) + + +@parametrize_aot_options +@pytest.mark.parametrize("groups,weight_shape", [(1, 32), (32, 1)]) +def test_conv2d(interface_api, use_unpacked_api, test_runner, groups, weight_shape): + """Test a subgraph with a single conv2d operator.""" + dtype = "float32" + ishape = (1, 32, 14, 14) + wshape = (32, weight_shape, 3, 3) + pass_config = {"tir.usmp.enable": True} + test_runner = AOTTestRunner( + makefile=test_runner.makefile, + prologue=test_runner.prologue, + epilogue=test_runner.epilogue, + includes=test_runner.includes, + parameters=test_runner.parameters, + pass_config=pass_config, + ) + + data0 = relay.var("data", shape=ishape, dtype=dtype) + weight0 = relay.var("weight", shape=wshape, dtype=dtype) + out = relay.nn.conv2d(data0, weight0, kernel_size=(3, 3), padding=(1, 1), groups=groups) + main_f = relay.Function([data0, weight0], out) + mod = tvm.IRModule() + mod["main"] = main_f + mod = transform.InferType()(mod) + + i_data = np.random.uniform(0, 1, ishape).astype(dtype) + w1_data = np.random.uniform(0, 1, wshape).astype(dtype) + + inputs = OrderedDict([("data", i_data), ("weight", w1_data)]) + + output_list = generate_ref_data(mod, inputs) + compile_and_run( + AOTTestModel(module=mod, inputs=inputs, outputs=output_list), + test_runner, + interface_api, + use_unpacked_api, + ) + compiled_test_mods = compile_models( + models=AOTTestModel(module=mod, inputs=inputs, outputs=output_list), + interface_api=interface_api, + use_unpacked_api=use_unpacked_api, + pass_config=test_runner.pass_config, + ) + + for compiled_model in compiled_test_mods: + check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib) + + run_and_check( + models=compiled_test_mods, + runner=test_runner, + interface_api=interface_api, + ) + + +@pytest.mark.parametrize("merge_compiler_regions", [False, True]) +def test_byoc_microtvm(merge_compiler_regions): + """This is a simple test to check BYOC capabilities of AOT - with and without merging compiler regions to test for https://github.com/apache/tvm/issues/9036""" + use_unpacked_api = False + interface_api = "packed" + test_runner = AOTTestRunner(pass_config={"tir.usmp.enable": True}) + + x = relay.var("x", shape=(10, 10)) + w0 = relay.var("w0", shape=(10, 10)) + w1 = relay.var("w1", shape=(10, 10)) + + # z0 = x + w0 + x_ = compiler_begin(x, "ccompiler") + w0_ = compiler_begin(w0, "ccompiler") + z0_ = relay.add(x_, w0_) + z0 = compiler_end(z0_, "ccompiler") + + # z1 = z0 + w1 + z0__ = compiler_begin(z0, "ccompiler") + w1_ = compiler_begin(w1, "ccompiler") + z1_ = relay.add(z0__, w1_) + z1 = compiler_end(z1_, "ccompiler") + + # z2 = z0 + z1 + z2 = relay.add(z0, z1) + + f = relay.Function([x, w0, w1], z2) + mod = tvm.IRModule() + mod["main"] = f + + if merge_compiler_regions: + mod = transform.MergeCompilerRegions()(mod) + + mod = transform.PartitionGraph("mod_name")(mod) + mod = transform.InferType()(mod) + + x_data = [("x", np.random.rand(10, 10).astype("float32"))] + w_data = [("w{}".format(i), np.random.rand(10, 10).astype("float32")) for i in range(2)] + + map_inputs = OrderedDict(x_data + w_data) + output_list = generate_ref_data(mod, map_inputs) + + compiled_test_mods = compile_models( + AOTTestModel(name="my_mod", module=mod, inputs=map_inputs, outputs=output_list), + interface_api=interface_api, + use_unpacked_api=use_unpacked_api, + pass_config=test_runner.pass_config, + ) + + for compiled_model in compiled_test_mods: + check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib) + + run_and_check( + models=compiled_test_mods, + runner=test_runner, + interface_api=interface_api, + ) + + +MOBILENET_V1_URL = ( + "https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz", + "mobilenet_v1_1.0_224_quant.tflite", +) + + +@pytest.mark.parametrize( + "model_url, usmp_algo, workspace_size,", + [ + (MOBILENET_V1_URL, "greedy_by_size", 4845696), + (MOBILENET_V1_URL, "greedy_by_conflicts", 4845696), + ], +) +def test_tflite_model(model_url, usmp_algo, workspace_size): + """This checks for ML models and the memory used by them when using USMP with different algorithms""" + pytest.importorskip("tflite") + + import tvm.relay.testing.tf as tf_testing + + use_unpacked_api = True + interface_api = "c" + test_runner = AOTTestRunner( + pass_config={"tir.usmp.enable": True, "tir.usmp.algorithm": usmp_algo} + ) + + tflite_model_file = tf_testing.get_workload_official( + model_url[0], + model_url[1], + ) + with open(tflite_model_file, "rb") as f: + tflite_model_buf = f.read() + data_shape = (1, 224, 224, 3) + in_min, in_max = (0, 255) + data = np.random.randint(in_min, high=in_max, size=data_shape, dtype="uint8") + mod, params = convert_to_relay(tflite_model_buf, data, "input") + inputs = {"input": data} + output_list = generate_ref_data(mod, inputs, params) + + compiled_test_mods = compile_models( + AOTTestModel(module=mod, inputs=inputs, outputs=output_list, params=params), + interface_api=interface_api, + use_unpacked_api=use_unpacked_api, + pass_config=test_runner.pass_config, + ) + + for compiled_model in compiled_test_mods: + check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib) + + # Checking the workspace size + assert ( + sum( + compiled_model.executor_factory.function_metadata[ + "__tvm_main__" + ].workspace_sizes.values() + ) + == workspace_size + ) + + run_and_check( + models=compiled_test_mods, + runner=test_runner, + interface_api=interface_api, + ) diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 9ac04d4933a6d..da36bba965568 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -541,7 +541,7 @@ def verify_any_conv2d( kernel_np = np.random.uniform(size=kernel_shape).astype(dtype) targets = None - if use_cudnn and tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape_from_cudnn", True): + if use_cudnn and tvm.get_global_func("tvm.contrib.cudnn.conv2d.forward", True): targets = [("cuda -libs=cudnn", tvm.cuda(0))] check_result([data_np, kernel_np], mod, ref_out_shape, assert_shape=True, targets=targets) diff --git a/tests/python/relay/test_op_grad_level2.py b/tests/python/relay/test_op_grad_level2.py index 115ed48d5888a..1efdb262245f5 100644 --- a/tests/python/relay/test_op_grad_level2.py +++ b/tests/python/relay/test_op_grad_level2.py @@ -15,13 +15,13 @@ # specific language governing permissions and limitations # under the License. import numpy as np - +import pytest from tvm import topi import tvm.topi.testing import tvm from tvm import te from tvm import relay -from tvm.relay.testing import check_grad, run_infer_type +from tvm.relay.testing import check_grad, run_infer_type, run_opt_pass from tvm.relay.transform import gradient import tvm.testing @@ -229,11 +229,37 @@ def test_batch_flatten_grad(): verify_batch_flatten_grad((1, 8)) +def verify_conv2d_backward_weight(dy_shape, x_shape, kernel_size, stride, padding): + dtype = "float32" + dy = relay.var("dy", shape=dy_shape, dtype=dtype) + x = relay.var("x", shape=x_shape, dtype=dtype) + dw = relay.nn.conv2d_backward_weight( + dy, x, strides=stride, padding=padding, kernel_size=kernel_size + ) + dw_func = relay.Function([dy, x], dw) + dw_func_legalized = run_opt_pass(dw_func, relay.transform.Legalize()) + + target = "llvm" + dev = tvm.device(target, 0) + dy_np = np.random.randn(*dy_shape).astype(dtype) + x_np = np.random.randn(*x_shape).astype(dtype) + + dw_np = ( + relay.create_executor(device=dev, target=target) + .evaluate(dw_func_legalized)(dy_np, x_np) + .numpy() + ) + ref_dw_np = tvm.topi.testing.conv2d_backward_weight_nchw_python( + dy_np, x_np, kernel_size, stride, padding + ) + + np.testing.assert_allclose(dw_np, ref_dw_np, rtol=1e-4, atol=1e-4) + + +def test_conv2d_backward_weight(): + verify_conv2d_backward_weight((2, 8, 32, 32), (2, 4, 32, 32), (3, 3), (1, 1), (1, 1)) + verify_conv2d_backward_weight((2, 16, 15, 15), (2, 3, 32, 32), (3, 3), (2, 2), (0, 0)) + + if __name__ == "__main__": - test_max_pool2d_grad() - test_avg_pool2d_grad() - test_global_avg_pool2d_grad() - test_conv2d_grad() - test_dense_grad() - test_matmul_grad() - test_batch_flatten_grad() + pytest.main([__file__]) diff --git a/tests/python/relay/test_pipeline_executor.py b/tests/python/relay/test_pipeline_executor.py index 4e51f873b3fad..83cf237dbfcc6 100644 --- a/tests/python/relay/test_pipeline_executor.py +++ b/tests/python/relay/test_pipeline_executor.py @@ -126,44 +126,76 @@ def get_manual_conf(mods, target): return mod_config -def test_pipe_config_check(): - # This function is used to trigger runtime error by applying wrong logic connection. +def recreate_parameters(mod): + # Get the binding parameters from a module, then create the same parameters with different data. + # This function is used to test the "parameter" connection. + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(mod, "llvm") - # Get three pipeline modules here. - (mod1, mod2, mod3), dshape = get_mannual_mod() + mod_customized_params = {} + for key, value in lib.params.items(): + new_value = value.numpy() + np.full(value.shape, 10).astype(value.dtype) + mod_customized_params[key] = tvm.nd.array(new_value) + return mod_customized_params - # The input or output name is illegal and expects a runtime error. - pipe_error = pipeline_executor.PipelineConfig() - with pytest.raises(RuntimeError): - pipe_error[mod1]["output"][9] - with pytest.raises(RuntimeError): - pipe_error[mod1]["input"]["data_9"] +def test_pipe_runtime_error_check(): + # This function is used to trigger runtime error by applying wrong logic. + if pipeline_executor.pipeline_executor_enabled(): + # Get three pipeline modules here. + (mod1, mod2, mod3), dshape = get_mannual_mod() + + # The input or output name is illegal and expects a runtime error. + pipe_error = pipeline_executor.PipelineConfig() + with pytest.raises(RuntimeError): + pipe_error[mod1]["output"][9] + + with pytest.raises(RuntimeError): + pipe_error[mod1]["input"]["data_9"] + + # The module connection will cause a cycle in DAG and expects runtime error. + with pytest.raises(RuntimeError): + pipe_error[mod1]["output"][0].connect(pipe_error[mod2]["input"]["data_0"]) + pipe_error[mod2]["output"][0].connect(pipe_error[mod1]["input"]["data_0"]) + + # The module connection is illegal and expects runtime error. + + with pytest.raises(RuntimeError): + pipe_error[mod1]["output"][0].connect(pipe_error[mod1]["input"]["data_0"]) - # The module connection will cause a cycle in DAG and expects runtime error. - with pytest.raises(RuntimeError): - pipe_error[mod1]["output"][0].connect(pipe_error[mod2]["input"]["data_0"]) - pipe_error[mod2]["output"][0].connect(pipe_error[mod1]["input"]["data_0"]) + with pytest.raises(RuntimeError): + pipe_error[mod1]["input"]["data_0"].connect(pipe_error[mod1]["input"]["data_0"]) - # The module connection is illegal and expects runtime error. + with pytest.raises(RuntimeError): + pipe_error[mod1]["input"]["data_0"].connect(pipe_error[mod2]["input"]["data_0"]) - with pytest.raises(RuntimeError): - pipe_error[mod1]["output"][0].connect(pipe_error[mod1]["input"]["data_0"]) + with pytest.raises(RuntimeError): + pipe_error[mod1]["output"][0].connect(pipe_error["input"]["data_0"]) - with pytest.raises(RuntimeError): - pipe_error[mod1]["input"]["data_0"].connect(pipe_error[mod1]["input"]["data_0"]) + with pytest.raises(RuntimeError): + pipe_error["input"]["data_0"].connect(pipe_error[mod1]["output"][0]) - with pytest.raises(RuntimeError): - pipe_error[mod1]["input"]["data_0"].connect(pipe_error[mod2]["input"]["data_0"]) + with pytest.raises(RuntimeError): + pipe_error["output"]["0"].connect(pipe_error[mod1]["output"][0]) - with pytest.raises(RuntimeError): - pipe_error[mod1]["output"][0].connect(pipe_error["input"]["data_0"]) + # Create pipeline executor to check the executor runtime errors. + pipe_config = pipeline_executor.PipelineConfig() + pipe_config[mod1].target = "llvm" + pipe_config[mod1].dev = tvm.cpu(0) + pipe_config["param_group"]["param_0"].connect(pipe_config[mod1]["param"]) + pipe_config[mod1]["output"][0].connect(pipe_config["output"]["0"]) + # Build and create a pipeline module. + with tvm.transform.PassContext(opt_level=3): + pipeline_mod_factory = pipeline_executor.build(pipe_config) + pipeline_module = pipeline_executor.PipelineModule(pipeline_mod_factory) + customized_parameters = recreate_parameters(mod1) - with pytest.raises(RuntimeError): - pipe_error["input"]["data_0"].connect(pipe_error[mod1]["output"][0]) + # Checking the pipeline executor runtime errors. + with pytest.raises(RuntimeError): + pipeline_module.set_params("param_0", None) - with pytest.raises(RuntimeError): - pipe_error["output"]["0"].connect(pipe_error[mod1]["output"][0]) + with pytest.raises(RuntimeError): + pipeline_module.set_params("param_1", customized_parameters) def test_pipeline(): @@ -180,6 +212,9 @@ def test_pipeline(): pipe_config = pipeline_executor.PipelineConfig() + customized_parameters = recreate_parameters(mod2) + # The global parameters group named "param_0" will be connected to "mod1" as parameters. + pipe_config["param_group"]["param_0"].connect(pipe_config[mod2]["param"]) # The pipeline input named "data_0" will be connected to a input named "data_0" # of mod1. pipe_config["input"]["data_a"].connect(pipe_config[mod1]["input"]["data_0"]) @@ -202,6 +237,7 @@ def test_pipeline(): # The mod3 output[0] will be connected to pipeline output[1]. pipe_config[mod3]["output"][0].connect(pipe_config["output"]["1"]) + print(pipe_config) # Print configueration (print(pipe_config)), the result looks like following. # # Inputs @@ -254,6 +290,10 @@ def test_pipeline(): assert input_map[0] == "1" and input_map[1] == "data_1" input_map = pipeline_module_test.get_input_pipeline_map("data_a") assert input_map[0] == "0" and input_map[1] == "data_0" + module_index = pipeline_module_test.get_params_group_pipeline_map("param_0") + assert module_index == 1 + # Use the parameters group name to set parameters. + pipeline_module_test.set_params("param_0", customized_parameters) if __name__ == "__main__": diff --git a/tests/python/unittest/test_ci.py b/tests/python/unittest/test_ci.py index ac7e6cdd7c29c..0c80617985ee0 100644 --- a/tests/python/unittest/test_ci.py +++ b/tests/python/unittest/test_ci.py @@ -18,6 +18,7 @@ import pathlib import subprocess import sys +import json import tempfile import pytest @@ -25,6 +26,33 @@ REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent +def test_cc_reviewers(): + reviewers_script = REPO_ROOT / "tests" / "scripts" / "github_cc_reviewers.py" + + def run(pr_body, expected_reviewers): + proc = subprocess.run( + [str(reviewers_script), "--dry-run"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env={"PR": json.dumps({"number": 1, "body": pr_body})}, + encoding="utf-8", + ) + if proc.returncode != 0: + raise RuntimeError(f"Process failed:\nstdout:\n{proc.stdout}\n\nstderr:\n{proc.stderr}") + + assert proc.stdout.strip().endswith(f"Adding reviewers: {expected_reviewers}") + + run(pr_body="abc", expected_reviewers=[]) + run(pr_body="cc @abc", expected_reviewers=["abc"]) + run(pr_body="cc @", expected_reviewers=[]) + run(pr_body="cc @abc @def", expected_reviewers=["abc", "def"]) + run(pr_body="some text cc @abc @def something else", expected_reviewers=["abc", "def"]) + run( + pr_body="some text cc @abc @def something else\n\n another cc @zzz z", + expected_reviewers=["abc", "def", "zzz"], + ) + + def test_skip_ci(): skip_ci_script = REPO_ROOT / "tests" / "scripts" / "git_skip_ci.py" diff --git a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py new file mode 100644 index 0000000000000..bebfec6122b35 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py @@ -0,0 +1,231 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +import sys +import pytest +import tvm +from tvm import tir +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.postproc import VerifyGPUCode +from tvm.script import tir as T +from tvm.target import Target + + +def _target() -> Target: + return Target("nvidia/geforce-rtx-3080") + + +def _create_context(mod, target) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + postprocs=[ + VerifyGPUCode(), + ], + task_name="test", + ) + for rule in ctx.postprocs: + rule.initialize_with_tune_context(ctx) + return ctx + + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant +# fmt: off + +@tvm.script.ir_module +class Conv2dCuda0: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "T.noalias": True}) + # var definition + threadIdx_x = T.env_thread("threadIdx.x") + threadIdx_y = T.env_thread("threadIdx.y") + blockIdx_x = T.env_thread("blockIdx.x") + blockIdx_y = T.env_thread("blockIdx.y") + blockIdx_z = T.env_thread("blockIdx.z") + A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32") + B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32") + # body + T.launch_thread(blockIdx_z, 196) + B_local = T.allocate([64], "float32", "local") + Apad_shared = T.allocate([512], "float32", "shared") + Apad_shared_local = T.allocate([8], "float32", "local") + T.launch_thread(blockIdx_y, 8) + T.launch_thread(blockIdx_x, 4) + T.launch_thread(threadIdx_y, 8) + T.launch_thread(threadIdx_x, 8) + for ff_c_init, nn_c_init in T.grid(8, 8): + T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True) + for rc_outer, ry, rx in T.grid(32, 3, 3): + for ax3_inner_outer in T.serial(0, 2): + T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4)) + for rc_inner in T.serial(0, 8): + for ax3 in T.serial(0, 8): + T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True) + for ff_c, nn_c in T.grid(8, 8): + T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True) + for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8): + T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on + + +@tvm.script.ir_module +class Conv2dCuda1: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "T.noalias": True}) + # var definition + threadIdx_x = T.env_thread("threadIdx.x") + threadIdx_y = T.env_thread("threadIdx.y") + blockIdx_x = T.env_thread("blockIdx.x") + blockIdx_y = T.env_thread("blockIdx.y") + blockIdx_z = T.env_thread("blockIdx.z") + A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32") + B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32") + # body + T.launch_thread(blockIdx_z, 196) + B_local = T.allocate([6400000], "float32", "local") + Apad_shared = T.allocate([512], "float32", "shared") + Apad_shared_local = T.allocate([8], "float32", "local") + T.launch_thread(blockIdx_y, 8) + T.launch_thread(blockIdx_x, 4) + T.launch_thread(threadIdx_y, 8) + T.launch_thread(threadIdx_x, 8) + for ff_c_init, nn_c_init in T.grid(8, 8): + T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True) + for rc_outer, ry, rx in T.grid(32, 3, 3): + for ax3_inner_outer in T.serial(0, 2): + T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4)) + for rc_inner in T.serial(0, 8): + for ax3 in T.serial(0, 8): + T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True) + for ff_c, nn_c in T.grid(8, 8): + T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True) + for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8): + T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on + + +@tvm.script.ir_module +class Conv2dCuda2: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "T.noalias": True}) + # var definition + threadIdx_x = T.env_thread("threadIdx.x") + threadIdx_y = T.env_thread("threadIdx.y") + blockIdx_x = T.env_thread("blockIdx.x") + blockIdx_y = T.env_thread("blockIdx.y") + blockIdx_z = T.env_thread("blockIdx.z") + A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32") + B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32") + # body + T.launch_thread(blockIdx_z, 196) + B_local = T.allocate([64], "float32", "local") + Apad_shared = T.allocate([512000], "float32", "shared") + Apad_shared_local = T.allocate([8], "float32", "local") + T.launch_thread(blockIdx_y, 8) + T.launch_thread(blockIdx_x, 4) + T.launch_thread(threadIdx_y, 8) + T.launch_thread(threadIdx_x, 8) + for ff_c_init, nn_c_init in T.grid(8, 8): + T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True) + for rc_outer, ry, rx in T.grid(32, 3, 3): + for ax3_inner_outer in T.serial(0, 2): + T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4)) + for rc_inner in T.serial(0, 8): + for ax3 in T.serial(0, 8): + T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True) + for ff_c, nn_c in T.grid(8, 8): + T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True) + for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8): + T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on + + +@tvm.script.ir_module +class Conv2dCuda3: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "T.noalias": True}) + # var definition + threadIdx_x = T.env_thread("threadIdx.x") + threadIdx_y = T.env_thread("threadIdx.y") + blockIdx_x = T.env_thread("blockIdx.x") + blockIdx_y = T.env_thread("blockIdx.y") + blockIdx_z = T.env_thread("blockIdx.z") + A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32") + B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32") + # body + T.launch_thread(blockIdx_z, 196) + B_local = T.allocate([64], "float32", "local") + Apad_shared = T.allocate([512], "float32", "shared") + Apad_shared_local = T.allocate([8], "float32", "local") + T.launch_thread(blockIdx_y, 8) + T.launch_thread(blockIdx_x, 4) + T.launch_thread(threadIdx_y, 8) + T.launch_thread(threadIdx_x, 800000) + for ff_c_init, nn_c_init in T.grid(8, 8): + T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True) + for rc_outer, ry, rx in T.grid(32, 3, 3): + for ax3_inner_outer in T.serial(0, 2): + T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4)) + for rc_inner in T.serial(0, 8): + for ax3 in T.serial(0, 8): + T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True) + for ff_c, nn_c in T.grid(8, 8): + T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True) + for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8): + T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on + + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant + + +def test_postproc_verify_gpu_0(): + mod = Conv2dCuda0 + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert ctx.postprocs[0].apply(sch) + + +def test_postproc_verify_gpu_1(): + mod = Conv2dCuda1 + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert not ctx.postprocs[0].apply(sch) + + +def test_postproc_verify_gpu_2(): + mod = Conv2dCuda2 + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert not ctx.postprocs[0].apply(sch) + + +def test_postproc_verify_gpu_3(): + mod = Conv2dCuda3 + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert not ctx.postprocs[0].apply(sch) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py b/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py new file mode 100644 index 0000000000000..5a80312203542 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply +from tvm.meta_schedule.testing import te_workload +from tvm.meta_schedule.testing.schedule_rule import add_rfactor +from tvm.meta_schedule.testing.space_generation import check_trace +from tvm.meta_schedule.tune_context import TuneContext +from tvm.target import Target +from tvm.te.operation import create_prim_func + + +def _create_context(mod, target, rule) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + space_generator=PostOrderApply(), + sch_rules=[rule], + task_name="test", + ) + ctx.space_generator.initialize_with_tune_context(ctx) + for sch_rule in ctx.sch_rules: + sch_rule.initialize_with_tune_context(ctx) + return ctx + + +def test_cpu_matmul(): + expected = [ + [], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "v4, v5 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", + "l6, l7 = sch.split(loop=l3, factors=[v4, v5])", + "b8 = sch.rfactor(loop=l7, factor_axis=2)", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "v4, v5 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", + "l6, l7 = sch.split(loop=l3, factors=[v4, v5])", + "b8 = sch.rfactor(loop=l6, factor_axis=2)", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + ], + ] + target = Target("llvm --num-cores=32") + ctx = _create_context( + create_prim_func( + te_workload.matmul( + n=4, + m=4, + k=512, + ) + ), + target=target, + rule=add_rfactor(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 3 + check_trace(spaces, expected) + + +if __name__ == "__main__": + test_cpu_matmul() diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py new file mode 100644 index 0000000000000..e206fcc4502ce --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py @@ -0,0 +1,300 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import tvm +from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply +from tvm.meta_schedule.testing.schedule_rule import auto_inline +from tvm.meta_schedule.tune_context import TuneContext +from tvm.script import tir as T +from tvm.target import Target + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks + +@tvm.script.ir_module +class Conv2DBiasBnReLU: + @T.prim_func + def main(var_X: T.handle, var_W: T.handle, var_B: T.handle, var_bn_scale: T.handle, var_bn_offset: T.handle, var_compute: T.handle) -> None: + X = T.match_buffer(var_X, [1, 512, 56, 56], dtype="float32") + W = T.match_buffer(var_W, [512, 512, 3, 3], dtype="float32") + B = T.match_buffer(var_B, [512, 1, 1], dtype="float32") + bn_scale = T.match_buffer(var_bn_scale, [512, 1, 1], dtype="float32") + bn_offset = T.match_buffer(var_bn_offset, [512, 1, 1], dtype="float32") + compute = T.match_buffer(var_compute, [1, 512, 56, 56], dtype="float32") + pad_temp = T.alloc_buffer([1, 512, 58, 58], dtype="float32") + compute_1 = T.alloc_buffer([1, 512, 56, 56], dtype="float32") + bias_add = T.alloc_buffer([1, 512, 56, 56], dtype="float32") + bn_mul = T.alloc_buffer([1, 512, 56, 56], dtype="float32") + bn_add = T.alloc_buffer([1, 512, 56, 56], dtype="float32") + for i0, i1, i2, i3 in T.grid(1, 512, 58, 58): + with T.block("pad_temp"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + pad_temp[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(i2_1 >= 1 and i2_1 < 57 and i3_1 >= 1 and i3_1 < 57, X[i0_1, i1_1, i2_1 - 1, i3_1 - 1], T.float32(0), dtype="float32") + for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 512, 56, 56, 512, 3, 3): + with T.block("compute"): + nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + with T.init(): + compute_1[nn, ff, yy, xx] = T.float32(0) + compute_1[nn, ff, yy, xx] = compute_1[nn, ff, yy, xx] + pad_temp[nn, rc, yy + ry, xx + rx] * W[ff, rc, ry, rx] + for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): + with T.block("bias_add"): + i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3]) + bias_add[i, j, k, l] = compute_1[i, j, k, l] + B[j, 0, 0] + for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): + with T.block("bn_mul"): + i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3]) + bn_mul[i, j, k, l] = bias_add[i, j, k, l] * bn_scale[j, 0, 0] + for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): + with T.block("bn_add"): + i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3]) + bn_add[i, j, k, l] = bn_mul[i, j, k, l] + bn_offset[j, 0, 0] + for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): + with T.block("compute_1"): + i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + compute[i0_2, i1_2, i2_2, i3_2] = T.max(bn_add[i0_2, i1_2, i2_2, i3_2], T.float32(0)) + + +@tvm.script.ir_module +class Conv2DBiasBnReLUInlined: + @T.prim_func + def main(var_X: T.handle, var_W: T.handle, var_B: T.handle, var_bn_scale: T.handle, var_bn_offset: T.handle, var_compute: T.handle) -> None: + X = T.match_buffer(var_X, [1, 512, 56, 56], dtype="float32") + W = T.match_buffer(var_W, [512, 512, 3, 3], dtype="float32") + B = T.match_buffer(var_B, [512, 1, 1], dtype="float32") + bn_scale = T.match_buffer(var_bn_scale, [512, 1, 1], dtype="float32") + bn_offset = T.match_buffer(var_bn_offset, [512, 1, 1], dtype="float32") + compute = T.match_buffer(var_compute, [1, 512, 56, 56], dtype="float32") + pad_temp = T.alloc_buffer([1, 512, 58, 58], dtype="float32") + compute_1 = T.alloc_buffer([1, 512, 56, 56], dtype="float32") + for i0, i1, i2, i3 in T.grid(1, 512, 58, 58): + with T.block("pad_temp"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + pad_temp[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(i2_1 >= 1 and i2_1 < 57 and i3_1 >= 1 and i3_1 < 57, X[i0_1, i1_1, i2_1 - 1, i3_1 - 1], T.float32(0), dtype="float32") + for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 512, 56, 56, 512, 3, 3): + with T.block("compute"): + nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + with T.init(): + compute_1[nn, ff, yy, xx] = T.float32(0) + compute_1[nn, ff, yy, xx] = compute_1[nn, ff, yy, xx] + pad_temp[nn, rc, yy + ry, xx + rx] * W[ff, rc, ry, rx] + for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): + with T.block("compute_1"): + i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + compute[i0_2, i1_2, i2_2, i3_2] = T.max((compute_1[i0_2, i1_2, i2_2, i3_2] + B[i1_2, 0, 0]) * bn_scale[i1_2, 0, 0] + bn_offset[i1_2, 0, 0], T.float32(0)) + + +@tvm.script.ir_module +class MultiLevelTiledConv2D: + @T.prim_func + def main(var_X: T.handle, var_W: T.handle, var_B: T.handle, var_bn_scale: T.handle, var_bn_offset: T.handle, var_compute: T.handle) -> None: + X = T.match_buffer(var_X, [1, 512, 56, 56], dtype="float32") + W = T.match_buffer(var_W, [512, 512, 3, 3], dtype="float32") + B = T.match_buffer(var_B, [512, 1, 1], dtype="float32") + bn_scale = T.match_buffer(var_bn_scale, [512, 1, 1], dtype="float32") + bn_offset = T.match_buffer(var_bn_offset, [512, 1, 1], dtype="float32") + compute = T.match_buffer(var_compute, [1, 512, 56, 56], dtype="float32") + pad_temp = T.alloc_buffer([1, 512, 58, 58], dtype="float32") + compute_1 = T.alloc_buffer([1, 512, 56, 56], dtype="float32") + compute_local = T.alloc_buffer([1, 512, 56, 56], dtype="float32", scope="local") + pad_temp_shared = T.alloc_buffer([1, 512, 58, 58], dtype="float32", scope="shared") + W_shared = T.alloc_buffer([512, 512, 3, 3], dtype="float32", scope="shared") + for i0, i1, i2, i3 in T.grid(1, 512, 58, 58): + with T.block("pad_temp"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + pad_temp[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(i2_1 >= 1 and i2_1 < 57 and i3_1 >= 1 and i3_1 < 57, X[i0_1, i1_1, i2_1 - 1, i3_1 - 1], T.float32(0), dtype="float32") + for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(0, 224, thread="blockIdx.x"): + for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(0, 2, thread="vthread.x"): + for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(0, 8, thread="threadIdx.x"): + for i4_0, i5_0, i6_0 in T.grid(1, 3, 1): + for ax0_ax1_ax2_ax3_fused_0 in T.serial(0, 40960, annotations={"meta_schedule.cooperative_fetch":1}): + for ax0_ax1_ax2_ax3_fused_1 in T.vectorized(0, 3): + with T.block("pad_temp_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(512, (ax0_ax1_ax2_ax3_fused_0 * 3 + ax0_ax1_ax2_ax3_fused_1) // 30 // 8 % 512) + v2 = T.axis.spatial(58, i0_0_i1_0_i2_0_i3_0_fused % 14 // 2 * 8 + i5_0 + (ax0_ax1_ax2_ax3_fused_0 * 3 + ax0_ax1_ax2_ax3_fused_1) // 30 % 8) + v3 = T.axis.spatial(58, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + (ax0_ax1_ax2_ax3_fused_0 * 3 + ax0_ax1_ax2_ax3_fused_1) % 30) + pad_temp_shared[v0, v1, v2, v3] = pad_temp[v0, v1, v2, v3] + for ax0_ax1_ax2_ax3_fused_0 in T.serial(0, 12288, annotations={"meta_schedule.cooperative_fetch":1}): + for ax0_ax1_ax2_ax3_fused_1 in T.vectorized(0, 4): + with T.block("W_shared"): + v0 = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) // 1536) + v1 = T.axis.spatial(512, (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) // 3 % 512) + v2 = T.axis.spatial(3, i5_0) + v3 = T.axis.spatial(3, (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) % 3) + W_shared[v0, v1, v2, v3] = W[v0, v1, v2, v3] + for i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(32, 1, 1, 1, 1, 1, 1, 16, 1, 3, 1, 8, 2, 28): + with T.block("compute"): + nn = T.axis.spatial(1, 0) + ff = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + i1_4) + yy = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused // 2 % 7 * 8 + i0_1_i1_1_i2_1_i3_1_fused * 4 + i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + i2_4) + xx = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + i3_4) + rc = T.axis.reduce(512, i4_1 * 16 + i4_2) + ry, rx = T.axis.remap("RR", [i5_0, i6_2]) + with T.init(): + compute_local[nn, ff, yy, xx] = T.float32(0) + compute_local[nn, ff, yy, xx] = compute_local[nn, ff, yy, xx] + pad_temp_shared[nn, rc, yy + ry, xx + rx] * W_shared[ff, rc, ry, rx] + for ax0, ax1, ax2, ax3 in T.grid(1, 8, 2, 28): + with T.block("compute_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + ax1) + v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 14 // 2 * 8 + i0_1_i1_1_i2_1_i3_1_fused * 4 + i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + ax2) + v3 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + ax3) + compute_1[v0, v1, v2, v3] = compute_local[v0, v1, v2, v3] + for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): + with T.block("compute_1"): + i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + compute[i0_2, i1_2, i2_2, i3_2] = T.max((compute_1[i0_2, i1_2, i2_2, i3_2] + B[i1_2, 0, 0]) * bn_scale[i1_2, 0, 0] + bn_offset[i1_2, 0, 0], T.float32(0)) + + +@tvm.script.ir_module +class MultiLevelTiledConv2DAfterInline: + @T.prim_func + def main(X: T.Buffer[(1, 512, 56, 56), "float32"], W: T.Buffer[(512, 512, 3, 3), "float32"], B: T.Buffer[(512, 1, 1), "float32"], bn_scale: T.Buffer[(512, 1, 1), "float32"], bn_offset: T.Buffer[(512, 1, 1), "float32"], compute: T.Buffer[(1, 512, 56, 56), "float32"]) -> None: + compute_local = T.alloc_buffer([1, 512, 56, 56], dtype="float32", scope="local") + for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(224, thread="blockIdx.x"): + for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(2, thread="vthread.x"): + for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(8, thread="threadIdx.x"): + for i4_0, i5_0, i6_0, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(1, 3, 1, 32, 1, 1, 1, 1, 1, 1, 16, 1, 3, 1, 8, 2, 28): + with T.block("compute"): + nn = T.axis.spatial(1, 0) + ff = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + i1_4) + yy = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused // 2 % 7 * 8 + i0_1_i1_1_i2_1_i3_1_fused * 4 + i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + i2_4) + xx = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + i3_4) + rc = T.axis.reduce(512, i4_1 * 16 + i4_2) + ry, rx = T.axis.remap("RR", [i5_0, i6_2]) + with T.init(): + compute_local[nn, ff, yy, xx] = T.float32(0) + compute_local[nn, ff, yy, xx] = compute_local[nn, ff, yy, xx] + T.if_then_else(yy + ry >= 1 and yy + ry < 57 and xx + rx >= 1 and xx + rx < 57, X[nn, rc, yy + ry - 1, xx + rx - 1], T.float32(0), dtype="float32") * W[ff, rc, ry, rx] + for ax0, ax1, ax2, ax3 in T.grid(1, 8, 2, 28): + with T.block("compute_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + ax1) + v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 14 // 2 * 8 + i0_1_i1_1_i2_1_i3_1_fused * 4 + i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + ax2) + v3 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + ax3) + compute[v0, v1, v2, v3] = T.max((compute_local[v0, v1, v2, v3] + B[v1, 0, 0]) * bn_scale[v1, 0, 0] + bn_offset[v1, 0, 0], T.float32(0)) + + +@tvm.script.ir_module +class SoftmaxBeforeInline: + @T.prim_func + def main(A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]) -> None: + T_softmax_maxelem = T.alloc_buffer([256], dtype="float32") + T_softmax_exp = T.alloc_buffer([256, 256], dtype="float32") + T_softmax_expsum = T.alloc_buffer([256], dtype="float32") + for i0, i1 in T.grid(256, 256): + with T.block("T_softmax_maxelem"): + i0_1, k = T.axis.remap("SR", [i0, i1]) + with T.init(): + T_softmax_maxelem[i0_1] = T.min_value("float32") + T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k]) + for i0, i1 in T.grid(256, 256): + with T.block("T_softmax_exp"): + i0_2, i1_1 = T.axis.remap("SS", [i0, i1]) + T_softmax_exp[i0_2, i1_1] = T.exp(A[i0_2, i1_1] - T_softmax_maxelem[i0_2], dtype="float32") + for i0_3, i1 in T.grid(256, 256): + with T.block("T_softmax_expsum"): + i0_4, k = T.axis.remap("SR", [i0_3, i1]) + with T.init(): + T_softmax_expsum[i0_4] = T.float32(0) + T_softmax_expsum[i0_4] = T_softmax_expsum[i0_4] + T_softmax_exp[i0_4, k] + for i0_5, i1 in T.grid(256, 256): + with T.block("T_softmax_norm"): + i0_6, i1_2 = T.axis.remap("SS", [i0_5, i1]) + T_softmax_norm[i0_6, i1_2] = T_softmax_exp[i0_6, i1_2] / T_softmax_expsum[i0_6] + + +@tvm.script.ir_module +class SoftmaxAfterInline: + @T.prim_func + def main(A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]) -> None: + T_softmax_maxelem = T.alloc_buffer([256], dtype="float32") + T_softmax_expsum = T.alloc_buffer([256], dtype="float32") + for i0, i1 in T.grid(256, 256): + with T.block("T_softmax_maxelem"): + i0_1, k = T.axis.remap("SR", [i0, i1]) + with T.init(): + T_softmax_maxelem[i0_1] = T.min_value("float32") + T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k]) + for i0, i1 in T.grid(256, 256): + with T.block("T_softmax_expsum"): + i0_2, k = T.axis.remap("SR", [i0, i1]) + with T.init(): + T_softmax_expsum[i0_2] = T.float32(0) + T_softmax_expsum[i0_2] = T_softmax_expsum[i0_2] + T.exp(A[i0_2, k] - T_softmax_maxelem[i0_2], dtype="float32") + for i0_3, i1 in T.grid(256, 256): + with T.block("T_softmax_norm"): + i0_4, i1_1 = T.axis.remap("SS", [i0_3, i1]) + T_softmax_norm[i0_4, i1_1] = T.exp(A[i0_4, i1_1] - T_softmax_maxelem[i0_4], dtype="float32") / T_softmax_expsum[i0_4] + + +# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks +# fmt: on + + +def _create_context(mod, target, rule): + ctx = TuneContext( + mod=mod, + target=target, + space_generator=PostOrderApply(), + sch_rules=[rule], + task_name="test", + ) + ctx.space_generator.initialize_with_tune_context(ctx) + for sch_rule in ctx.sch_rules: + sch_rule.initialize_with_tune_context(ctx) + return ctx + + +def test_inline_consumer_chain(): + mod = Conv2DBiasBnReLU + target = Target("llvm") + ctx = _create_context( + mod=mod, + target=target, + rule=auto_inline(target=target), + ) + (space,) = ctx.space_generator.generate_design_space(mod=mod) + tvm.ir.assert_structural_equal(lhs=space.mod, rhs=Conv2DBiasBnReLUInlined) + + +def test_inline_into_cache(): + mod = MultiLevelTiledConv2D + target = Target("cuda", host="llvm") + ctx = _create_context( + mod=mod, + target=target, + rule=auto_inline(target=target), + ) + (space,) = ctx.space_generator.generate_design_space(mod=mod) + tvm.ir.assert_structural_equal(lhs=space.mod, rhs=MultiLevelTiledConv2DAfterInline) + + +def test_inline_into_multiple_consumers(): + mod = SoftmaxBeforeInline + target = Target("cuda", host="llvm") + ctx = _create_context( + mod=mod, + target=target, + rule=auto_inline(target=target), + ) + (space,) = ctx.space_generator.generate_design_space(mod=mod) + tvm.ir.assert_structural_equal(lhs=space.mod, rhs=SoftmaxAfterInline) + + +if __name__ == "__main__": + test_inline_consumer_chain() + test_inline_into_cache() + test_inline_into_multiple_consumers() diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py b/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py new file mode 100644 index 0000000000000..92c7da922c39c --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import tvm +from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply +from tvm.meta_schedule.schedule_rule import RandomComputeLocation +from tvm.meta_schedule.testing.space_generation import check_trace +from tvm.meta_schedule.tune_context import TuneContext +from tvm.script import tir as T +from tvm.target import Target + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks + +@tvm.script.ir_module +class Add: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, [2048, 2048, 2048], dtype="float32") + B = T.match_buffer(b, [2048, 2048, 2048], dtype="float32") + A_cached = T.alloc_buffer([2048, 2048, 2048], dtype="float32") + # body + for i, j, k in T.grid(2048, 2048, 2048): + with T.block("move"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + T.reads([A[vi, vj, vk]]) + T.writes([A_cached[vi, vj, vk]]) + A_cached[vi, vj, vk] = A[vi, vj, vk] + for i0, j0, i1, j1, k0, i2, j2, k1 in T.grid(128, 64, 4, 4, 64, 4, 8, 32): + with T.block("add"): + vi = T.axis.spatial(2048, i0 * 16 + i1 * 4 + i2) + vj = T.axis.spatial(2048, j0 * 32 + j1 * 8 + j2) + vk = T.axis.spatial(2048, k0 * 32 + k1) + T.reads([A_cached[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A_cached[vi, vj, vk] + T.float32(1) + +# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks +# fmt: on + + +def _create_context(mod, target, rule): + ctx = TuneContext( + mod=mod, + target=target, + space_generator=PostOrderApply(), + sch_rules=[rule], + task_name="test", + ) + ctx.space_generator.initialize_with_tune_context(ctx) + for sch_rule in ctx.sch_rules: + sch_rule.initialize_with_tune_context(ctx) + return ctx + + +def test_random_compute_location(): + expected = [ + [ + 'b0 = sch.get_block(name="move", func_name="main")', + "l1 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l1, preserve_unit_loops=1)", + ] + ] + mod = Add + target = Target("llvm") + ctx = _create_context( + mod=mod, + target=target, + rule=RandomComputeLocation(), + ) + spaces = ctx.space_generator.generate_design_space(mod=mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +if __name__ == "__main__": + test_random_compute_location() diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py b/tests/python/unittest/test_tir_analysis_get_block_access_region.py index e508fbb0f7477..54037541016d3 100644 --- a/tests/python/unittest/test_tir_analysis_get_block_access_region.py +++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py @@ -130,6 +130,41 @@ def access_in_branch_func() -> None: B[i] = A[i - 1] +@T.prim_func +def access_of_padding_pattern() -> None: + X = T.alloc_buffer([28, 28]) + X_pad = T.alloc_buffer([32, 32]) + Y = T.alloc_buffer([28, 28]) + for i, j in T.grid(32, 32): + with T.block("padding"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads( + [ + X[ + T.max(vi - 2, 0) : T.min(vi - 2, 27) + 1, + T.max(vj - 2, 0) : T.min(vj - 2, 27) + 1, + ] + ] + ) + T.writes([X_pad[vi, vj]]) + X_pad[vi, vj] = T.if_then_else( + 2 <= vi and vi < 30 and 2 <= vj and vj < 30, X[vi - 2, vj - 2], 0.0, dtype="float32" + ) + with T.block("padding_reverse"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads([X_pad[T.max(vi, 2) : T.min(vi, 29) + 1, T.max(vj, 2) : T.min(vj, 29) + 1]]) + T.writes( + [ + Y[ + T.max(vi - 2, 0) : T.min(vi - 2, 27) + 1, + T.max(vj - 2, 0) : T.min(vj - 2, 27) + 1, + ] + ] + ) + if 2 <= vi and vi < 30 and 2 <= vj and vj < 30: + Y[vi - 2, vj - 2] = X_pad[vi, vj] + + def test_block_access_region_detector(): block = func.body.block.body.block alloc_buffers = func.body.block.alloc_buffers @@ -220,6 +255,36 @@ def test_access_in_branch_func(): tvm.ir.assert_structural_equal(ret0[1], ret1[1]) +def test_access_of_padding_pattern(): + s = tvm.tir.schedule.Schedule(access_of_padding_pattern) + alloc_buffers = s.get_sref(s.get_block("root")).stmt.alloc_buffers + buffer_var_map = {buf.data: buf for buf in alloc_buffers} + + def do_compare_buffer_region(region, expect): + assert region.buffer == expect.buffer + analyzer = tvm.arith.Analyzer() + for k, rng in enumerate(region.region): + tvm.ir.assert_structural_equal( + analyzer.simplify(rng.min), analyzer.simplify(expect.region[k].min) + ) + tvm.ir.assert_structural_equal( + analyzer.simplify(rng.extent), analyzer.simplify(expect.region[k].extent) + ) + + def do_check_block(block_name): + block = s.get_sref(s.get_block(block_name)).stmt + expect_reads = block.reads + expect_writes = block.writes + ret = tir.analysis.get_block_access_region(block, buffer_var_map) + for i, read in enumerate(ret[0]): + do_compare_buffer_region(read, expect_reads[i]) + for i, write in enumerate(ret[1]): + do_compare_buffer_region(write, expect_writes[i]) + + do_check_block("padding") + do_check_block("padding_reverse") + + if __name__ == "__main__": test_block_access_region_detector() test_opaque_block() @@ -227,3 +292,4 @@ def test_access_in_branch_func(): test_match_buffer() test_access_in_if_then_else_func() test_access_in_branch_func() + test_access_of_padding_pattern() diff --git a/tests/python/unittest/test_tir_schedule_sampling.py b/tests/python/unittest/test_tir_schedule_sampling.py index 5d2676e41d1c1..4a4cd6c6c2b96 100644 --- a/tests/python/unittest/test_tir_schedule_sampling.py +++ b/tests/python/unittest/test_tir_schedule_sampling.py @@ -37,6 +37,67 @@ def elementwise(a: T.handle, b: T.handle) -> None: B[vi, vj, vk] = A[vi, vj, vk] * 2.0 +@T.prim_func +def tiled_conv2d_with_padding( + inputs: T.Buffer[(1, 224, 224, 3), "float32"], + weight: T.Buffer[(7, 7, 3, 64), "float32"], + conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float32"], +) -> None: + PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32") + for i0, i1, i2, i3 in T.grid(1, 230, 230, 3): + with T.block("PadInput"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(inputs[i0_1, i1_1 - 3, i2_1 - 3, i3_1]) + T.writes(PadInput[i0_1, i1_1, i2_1, i3_1]) + PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else( + 3 <= i1_1 and i1_1 < 227 and 3 <= i2_1 and i2_1 < 227, + inputs[i0_1, i1_1 - 3, i2_1 - 3, i3_1], + T.float32(0), + dtype="float32", + ) + for ( + i0_0, + i1_0, + i2_0, + i3_0, + i0_1_1, + i1_1_1, + i2_1_1, + i3_1_1, + i4_0, + i5_0, + i6_0, + i0_2, + i1_2, + i2_2, + i3_2, + i4_1, + i5_1, + i6_1, + i0_3, + i1_3, + i2_3, + i3_3, + ) in T.grid(1, 1, 4, 1, 1, 2, 4, 1, 7, 7, 1, 1, 1, 1, 1, 1, 1, 3, 1, 56, 7, 64): + with T.block("conv2d_nhwc"): + n = T.axis.spatial(1, 0) + h = T.axis.spatial(112, i1_1_1 * 56 + i1_3) + w = T.axis.spatial(112, i2_0 * 28 + i2_1_1 * 7 + i2_3) + co, rh, rw, rc = T.axis.remap("SRRR", [i3_3, i4_0, i5_0, i6_1]) + T.reads( + conv2d_nhwc[n, h, w, co], + PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], + weight[rh, rw, rc, co], + ) + T.writes(conv2d_nhwc[n, h, w, co]) + with T.init(): + conv2d_nhwc[n, h, w, co] = T.float32(0) + conv2d_nhwc[n, h, w, co] = ( + conv2d_nhwc[n, h, w, co] + + PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc] * weight[rh, rw, rc, co] + ) + + # pylint: enable=no-member,invalid-name,unused-variable @@ -116,5 +177,21 @@ def test_sample_perfect_tile_composite(): verify_trace_roundtrip(sch, mod=elementwise) +def test_sample_compute_location(): + n = 100 + sch = tir.Schedule(tiled_conv2d_with_padding, seed=42, debug_mask="all") + pad_input = sch.get_block("PadInput") + decision_dict = dict() + for _ in range(n): + _ = sch.sample_compute_location(pad_input) # pylint: disable=invalid-name + decision = sch.trace.decisions[sch.trace.insts[-1]] + decision_dict[decision] = decision_dict[decision] + 1 if decision in decision_dict else 1 + + n_candidates = 8 + expected_rate = 1.0 / n_candidates + for _, cnt in decision_dict.items(): + assert (expected_rate - 0.03) * n <= cnt <= (expected_rate + 0.03) * n + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py index 57c87e5dedf4a..9b844853f2438 100644 --- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py +++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py @@ -24,6 +24,7 @@ def _check(original, transformed): mod = tvm.IRModule.from_expr(func) mod = tvm.tir.transform.CompactBufferAllocation()(mod) mod = tvm.tir.transform.Simplify()(mod) + transformed = tvm.tir.transform.Simplify()(tvm.IRModule.from_expr(transformed))["main"] tvm.ir.assert_structural_equal(mod["main"], transformed) diff --git a/tests/python/unittest/test_tir_usmp_algo.py b/tests/python/unittest/test_tir_usmp_algo.py index 1a763d083b103..1995695100cbe 100644 --- a/tests/python/unittest/test_tir_usmp_algo.py +++ b/tests/python/unittest/test_tir_usmp_algo.py @@ -123,7 +123,7 @@ def test_no_pool_error(): buffer_pool_allocations = fusmp_algo(buffer_info_arr, 0) -@pytest.mark.parametrize("algorithm", ["greedy_by_size", "greedy_by_conflicts"]) +@pytest.mark.parametrize("algorithm", ["greedy_by_size", "greedy_by_conflicts", "hill_climb"]) def test_name_based_ordering(algorithm): """ This checks when the size and conlicts are same a stable result is generated""" @@ -142,9 +142,9 @@ def _test(): bi_c = usmp_utils.BufferInfo( name_hint="bi_c", size_bytes=10, pool_candidates=[global_workspace_pool] ) - bi_a.set_conflicts([bi_b]) - bi_b.set_conflicts([bi_c]) - bi_c.set_conflicts([bi_a]) + bi_a.set_conflicts([bi_b, bi_c]) + bi_b.set_conflicts([bi_c, bi_a]) + bi_c.set_conflicts([bi_a, bi_b]) buffer_info_arr = [bi_a, bi_b, bi_c] fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}") @@ -160,7 +160,7 @@ def _test(): @pytest.mark.parametrize( ["algorithm", "workspace_size"], - [("greedy_by_size", 140), ("greedy_by_conflicts", 140)], + [("greedy_by_size", 140), ("greedy_by_conflicts", 140), ("hill_climb", 140)], ) def test_linear(algorithm, workspace_size): """ @@ -222,7 +222,7 @@ def test_linear(algorithm, workspace_size): @pytest.mark.parametrize( ["algorithm", "workspace_size"], - [("greedy_by_size", 190), ("greedy_by_conflicts", 320)], + [("greedy_by_size", 190), ("greedy_by_conflicts", 320), ("hill_climb", 190)], ) def test_fanout(algorithm, workspace_size): """ @@ -364,7 +364,11 @@ def run_model(input: T.handle, output: T.handle) -> None: @pytest.mark.parametrize( ["algorithm", "fast_memory_size", "slow_memory_size"], - [("greedy_by_size", 200704, 1418528), ("greedy_by_conflicts", 200704, 1418528)], + [ + ("greedy_by_size", 200704, 1418528), + ("greedy_by_conflicts", 200704, 1418528), + ("hill_climb", 200704, 1117462), + ], ) def test_mobilenet_subgraph(algorithm, fast_memory_size, slow_memory_size): target = Target("c") @@ -529,7 +533,8 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place @pytest.mark.parametrize( - ["algorithm", "workspace_size"], [("greedy_by_size", 7920256), ("greedy_by_conflicts", 7200256)] + ["algorithm", "workspace_size"], + [("greedy_by_size", 7920256), ("greedy_by_conflicts", 7200256), ("hill_climb", 7200256)], ) def test_resnet_subgraph(algorithm, workspace_size): target = Target("c") diff --git a/tests/python/unittest/test_tir_usmp_algo_hill_climb.py b/tests/python/unittest/test_tir_usmp_algo_hill_climb.py new file mode 100644 index 0000000000000..a5f1158a90c14 --- /dev/null +++ b/tests/python/unittest/test_tir_usmp_algo_hill_climb.py @@ -0,0 +1,397 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import sys +import pytest +import random +import tvm +from tvm.tir.usmp.utils import BufferInfo, PoolInfo + + +def _check_max_workspace_size(buffer_pool_allocations, pool_info, size): + """Helper to check maximum allocated memory size""" + max_workspace_size = 0 + for buffer_info, pool_allocation in buffer_pool_allocations.items(): + if pool_allocation.pool_info == pool_info: + size_candidate = pool_allocation.byte_offset + buffer_info.size_bytes + if size_candidate > max_workspace_size: + max_workspace_size = size_candidate + _diff = max_workspace_size.value - size + return ( + (max_workspace_size.value == size), + "'{}': expected {} got {}, diff {:0.2f}% ({} bytes)".format( + pool_info.pool_name, size, max_workspace_size, 100 * _diff / size, _diff + ), + ) + + +def _verify_conflicts(buffer_info, pool_allocation, buffer_info_map): + """Helper to check expected liveness conflicts""" + for conflict in buffer_info.conflicts: + conflict_pool_allocation = buffer_info_map[conflict] + + if conflict_pool_allocation.pool_info == pool_allocation.pool_info: + assert conflict_pool_allocation.byte_offset != pool_allocation.byte_offset + l2 = ( + max( + conflict_pool_allocation.byte_offset + conflict.size_bytes, + pool_allocation.byte_offset + buffer_info.size_bytes, + ) + - min(conflict_pool_allocation.byte_offset, pool_allocation.byte_offset) + ) + assert ( + conflict.size_bytes + buffer_info.size_bytes <= l2 + ), 'Conflicting: \n"{} @{}"\n"{} @{}"'.format( + conflict, conflict_pool_allocation, buffer_info, pool_allocation + ) + + +def _verify_all_conflicts(buffer_pool_allocations): + """Helper to verify liveness conflicts""" + for buffer_info, pool_allocation in buffer_pool_allocations.items(): + _verify_conflicts(buffer_info, pool_allocation, buffer_pool_allocations) + + +def test_bounded(random_len=150, pools=[PoolInfo("default", {}, 65535), PoolInfo("slow", {})]): + """Tests two pools, one is bounded and one is not limited""" + random.seed(0) + mem_range = [BufferInfo(str(i), random.randrange(1, 65535), pools) for i in range(random_len)] + for mr in mem_range: + pr = random.choice(mem_range) + while pr in (*mr.conflicts, mr): + pr = random.choice(mem_range) + + mr.set_conflicts([*mr.conflicts, pr]) + pr.set_conflicts([*pr.conflicts, mr]) + + fusmp_algo = tvm.get_global_func("tir.usmp.algo.hill_climb") + result_map = fusmp_algo(mem_range, 0) + _verify_all_conflicts(result_map) + + +def __test_data_alloc_max(): + """Test data""" + intervals = [ + (0, 159, 2048), + (0, 13, 7904), + (4, 35, 16), + (12, 17, 32768), + (16, 21, 32768), + ] + return intervals + + +def __test_data_deep_speech(): + """Test data""" + intervals = [ + (0, 159, 2048), + (0, 151, 2048), + (0, 13, 7904), + (2, 49, 16), + (4, 35, 16), + (6, 21, 16), + (12, 17, 32768), + (16, 21, 32768), + (20, 27, 32768), + (26, 31, 32768), + (30, 35, 32768), + (34, 41, 32768), + (40, 45, 32768), + (44, 49, 32768), + (48, 145, 32768), + (54, 59, 2048), + (58, 483, 4096), + (60, 65, 2048), + (64, 461, 4096), + (66, 71, 2048), + (70, 439, 4096), + (72, 77, 2048), + (76, 417, 4096), + (78, 83, 2048), + (82, 395, 4096), + (84, 89, 2048), + (88, 373, 4096), + (90, 95, 2048), + (94, 351, 4096), + (96, 101, 2048), + (100, 329, 4096), + (102, 107, 2048), + (106, 307, 4096), + (108, 113, 2048), + (112, 285, 4096), + (114, 119, 2048), + (118, 263, 4096), + (120, 125, 2048), + (124, 241, 4096), + (126, 131, 2048), + (130, 219, 4096), + (132, 137, 2048), + (136, 197, 4096), + (138, 143, 2048), + (142, 175, 4096), + (144, 149, 2048), + (148, 153, 4096), + (152, 163, 8192), + (154, 171, 2048), + (156, 181, 2048), + (160, 167, 2048), + (162, 165, 2048), + (168, 171, 2048), + (170, 509, 2048), + (174, 185, 8192), + (176, 193, 2048), + (178, 203, 2048), + (182, 189, 2048), + (184, 187, 2048), + (190, 193, 2048), + (192, 511, 2048), + (196, 207, 8192), + (198, 215, 2048), + (200, 225, 2048), + (204, 211, 2048), + (206, 209, 2048), + (212, 215, 2048), + (214, 513, 2048), + (218, 229, 8192), + (220, 237, 2048), + (222, 247, 2048), + (226, 233, 2048), + (228, 231, 2048), + (234, 237, 2048), + (236, 515, 2048), + (240, 251, 8192), + (242, 259, 2048), + (244, 269, 2048), + (248, 255, 2048), + (250, 253, 2048), + (256, 259, 2048), + (258, 517, 2048), + (262, 273, 8192), + (264, 281, 2048), + (266, 291, 2048), + (270, 277, 2048), + (272, 275, 2048), + (278, 281, 2048), + (280, 519, 2048), + (284, 295, 8192), + (286, 303, 2048), + (288, 313, 2048), + (292, 299, 2048), + (294, 297, 2048), + (300, 303, 2048), + (302, 521, 2048), + (306, 317, 8192), + (308, 325, 2048), + (310, 335, 2048), + (314, 321, 2048), + (316, 319, 2048), + (322, 325, 2048), + (324, 523, 2048), + (328, 339, 8192), + (330, 347, 2048), + (332, 357, 2048), + (336, 343, 2048), + (338, 341, 2048), + (344, 347, 2048), + (346, 525, 2048), + (350, 361, 8192), + (352, 369, 2048), + (354, 379, 2048), + (358, 365, 2048), + (360, 363, 2048), + (366, 369, 2048), + (368, 527, 2048), + (372, 383, 8192), + (374, 391, 2048), + (376, 401, 2048), + (380, 387, 2048), + (382, 385, 2048), + (388, 391, 2048), + (390, 529, 2048), + (394, 405, 8192), + (396, 413, 2048), + (398, 423, 2048), + (402, 409, 2048), + (404, 407, 2048), + (410, 413, 2048), + (412, 531, 2048), + (416, 427, 8192), + (418, 435, 2048), + (420, 445, 2048), + (424, 431, 2048), + (426, 429, 2048), + (432, 435, 2048), + (434, 533, 2048), + (438, 449, 8192), + (440, 457, 2048), + (442, 467, 2048), + (446, 453, 2048), + (448, 451, 2048), + (454, 457, 2048), + (456, 535, 2048), + (460, 471, 8192), + (462, 479, 2048), + (464, 489, 2048), + (468, 475, 2048), + (470, 473, 2048), + (476, 479, 2048), + (478, 537, 2048), + (482, 493, 8192), + (484, 501, 2048), + (486, 497, 2048), + (490, 497, 2048), + (492, 495, 2048), + (496, 626, 2048), + (498, 501, 2048), + (500, 626, 2048), + (504, 549, 16), + (508, 543, 32768), + (542, 549, 32768), + (548, 555, 32768), + (554, 563, 464), + (560, 563, 256), + (562, 617, 2048), + (564, 567, 1856), + (566, 573, 1024), + (568, 619, 1024), + (570, 573, 1024), + (572, 577, 1024), + (576, 579, 1024), + (578, 605, 1024), + (580, 593, 1024), + (584, 587, 1024), + (586, 603, 1024), + (594, 597, 1024), + (596, 613, 1024), + (604, 607, 1024), + (606, 617, 1024), + (616, 621, 2048), + (618, 621, 1024), + (620, 626, 464), + ] + return intervals + + +def __test_data_five(): + """Test data""" + return [ + (4, 5, 95), + (1, 4, 52135), + (3, 4, 12136), + (3, 5, 62099), + (4, 5, 50458), + ] + + +def __test_data_simple(): + """Test data""" + return [ + (0, 23, 131072), # 0 + (4, 5, 65568), # 1 + (4, 9, 8192), # 2 + (8, 30, 15360), # 3 + (10, 11, 65568), # 4 + (10, 15, 4096), # 5 + (16, 17, 65552), # 6 + (16, 21, 2048), # 7 + (22, 23, 32784), # 8 + (22, 27, 1024), # 9 + ] + + +def find_maximum_from_intervals(intervals): + """Expected list of intervals of (start, end, size)""" + sorted_list = sorted(intervals, key=lambda _: _[0]) + max_mem = 0 + for t in range(sorted_list[0][0], sorted_list[-1][1] + 1): + max_mem = max( + max_mem, sum([size for (start, end, size) in sorted_list if t >= start and t <= end]) + ) + return max_mem + + +@pytest.mark.parametrize( + "intervals", + [__test_data_alloc_max(), __test_data_simple(), __test_data_deep_speech(), __test_data_five()], +) +def test_intervals(intervals): + """Tests supplied intervals""" + random.seed(0) + result = run_intervals(intervals) + assert result["tir.usmp.algo.hill_climb"] == True, f" {result}" + + +def generate_range(sz, max_segment_sz=65535): + """Helper func to generate list of size sz of ranges of random size max_segment_sz""" + for i in range(0, sz): + start = random.randrange(i, sz) + stop = random.randrange(start + 1, start + 2 + ((sz - start) // 2)) + assert stop - start > 0 + yield (start, stop, random.randrange(1, max_segment_sz)) + + +def test_random_intervals(interval_len=16): + """Tests randomly generated interval of length interval_len""" + random.seed(0) + intervals = list(generate_range(interval_len)) + return run_intervals(intervals) + + +def run_intervals(intervals): + """Helper to run intervals""" + expected_mem = find_maximum_from_intervals(intervals) + pools = [PoolInfo("default", {})] + buffers = [] + # populate + for i, (start, stop, size) in enumerate(intervals): + buf = BufferInfo(str(i), size, pools) + # buf.set_pool_candidates( ["default"] ) + buffers.append(buf) + + # intersect + for i, (i_start, i_stop, _) in enumerate(intervals): + conflicts = set() + for j, (j_start, j_stop, _) in enumerate(intervals): + start = min(i_start, j_start) + stop = max(i_stop, j_stop) + i_dur = i_stop - i_start + 1 + j_dur = j_stop - j_start + 1 + + if i != j and (stop - start + 1 < i_dur + j_dur): + conflicts.add(buffers[j]) + + buffers[i].set_conflicts([c for c in sorted(conflicts, key=lambda c: c.name_hint)]) + + result = {} + for (alg, params) in [ + ("tir.usmp.algo.hill_climb", (expected_mem,)), + ("tir.usmp.algo.greedy_by_size", (expected_mem,)), + ]: + fusmp_algo = tvm.get_global_func(alg) + print("\n", "started", alg) + buffer_info_arr = fusmp_algo(buffers, *params) + print() + + _verify_all_conflicts(buffer_info_arr) + result[alg], msg = _check_max_workspace_size(buffer_info_arr, pools[0], expected_mem) + if not result[alg]: + print(alg, msg) + + return result + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py index fc615775c160e..ab40c646391c3 100644 --- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py +++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py @@ -140,7 +140,7 @@ def run_model(input: T.handle, output: T.handle) -> None: @tvm.script.ir_module class LinearStructurePlanned: @T.prim_func - def run_model(input: T.handle, output: T.handle, fast_memory_0_var: T.handle, slow_memory_1_var: T.handle) -> None: + def run_model(input: T.handle, fast_memory_0_var: T.handle, slow_memory_1_var: T.handle, output: T.handle) -> None: fast_memory_0_buffer_var = T.match_buffer(fast_memory_0_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) slow_memory_1_buffer_var = T.match_buffer(slow_memory_1_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) # body @@ -464,7 +464,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla T.store(T_cast_5.data, ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_1_let, ax3_inner_2) + T.load("int32", placeholder_15.data, ax3_inner_2), 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) @T.prim_func - def run_model(input: T.handle, output: T.handle, global_workspace_0_var: T.handle) -> None: + def run_model(input: T.handle, global_workspace_0_var: T.handle, output: T.handle) -> None: global_workspace_0_buffer_var = T.match_buffer(global_workspace_0_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) # body T.attr("default", "device_id", 0) diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py index 58458b38d7f37..0e77b2a494546 100644 --- a/tests/python/unittest/test_tvmscript_syntax_sugar.py +++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py @@ -20,6 +20,7 @@ import pytest from tvm.ir import assert_structural_equal from tvm.script import tir as T +from tvm.script.parser import from_source from tvm.testing import check_error @@ -158,5 +159,27 @@ def elementwise_buffer_no_kwargs_failed( pass +# dynamic shape gemm +@T.prim_func +def gemm_dyn_shape(a: T.handle, b: T.handle, c: T.handle): + N = T.var("int32") + M = T.var("int32") + K = T.var("int32") + A = T.match_buffer(a, (N, K), "float32") + B = T.match_buffer(b, (K, M), "float32") + C = T.match_buffer(c, (N, M), "float32") + for i, j, k in T.grid(N, M, K): + with T.block("gemm"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +def test_dynamic_shape_gemm(): + gemm_dyn_shape_roundtrip = from_source(gemm_dyn_shape.script()) + assert_structural_equal(gemm_dyn_shape, gemm_dyn_shape_roundtrip) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/scripts/ci.py b/tests/scripts/ci.py index 416113319a926..c9c3cc2cbfcf5 100644 --- a/tests/scripts/ci.py +++ b/tests/scripts/ci.py @@ -147,6 +147,15 @@ def docker(name: str, image: str, scripts: List[str], env: Dict[str, str]): """ check_docker() + if os.getenv("USE_SCCACHE", "0") == "1": + scripts = [ + "sccache --start-server", + ] + scripts + # Set the C/C++ compiler so CMake picks them up in the build + env["CC"] = "/opt/sccache/cc" + env["CXX"] = "/opt/sccache/c++" + env["SCCACHE_CACHE_SIZE"] = os.getenv("SCCACHE_CACHE_SIZE", "50G") + docker_bash = REPO_ROOT / "docker" / "bash.sh" command = [docker_bash, "--name", name] for key, value in env.items(): diff --git a/tests/scripts/git_skip_ci.py b/tests/scripts/git_skip_ci.py index 73fcc6490ab84..c4b88676c34f7 100755 --- a/tests/scripts/git_skip_ci.py +++ b/tests/scripts/git_skip_ci.py @@ -17,56 +17,9 @@ # under the License. import os -import json import argparse -import subprocess -import re -from urllib import request -from typing import Dict, Tuple, Any - -class GitHubRepo: - def __init__(self, user, repo, token): - self.token = token - self.user = user - self.repo = repo - self.base = f"https://api.github.com/repos/{user}/{repo}/" - - def headers(self): - return { - "Authorization": f"Bearer {self.token}", - } - - def get(self, url: str) -> Dict[str, Any]: - url = self.base + url - print("Requesting", url) - req = request.Request(url, headers=self.headers()) - with request.urlopen(req) as response: - response = json.loads(response.read()) - return response - - -def parse_remote(remote: str) -> Tuple[str, str]: - """ - Get a GitHub (user, repo) pair out of a git remote - """ - if remote.startswith("https://"): - # Parse HTTP remote - parts = remote.split("/") - if len(parts) < 2: - raise RuntimeError(f"Unable to parse remote '{remote}'") - return parts[-2], parts[-1].replace(".git", "") - else: - # Parse SSH remote - m = re.search(r":(.*)/(.*)\.git", remote) - if m is None or len(m.groups()) != 2: - raise RuntimeError(f"Unable to parse remote '{remote}'") - return m.groups() - - -def git(command): - proc = subprocess.run(["git"] + command, stdout=subprocess.PIPE, check=True) - return proc.stdout.decode().strip() +from git_utils import git, GitHubRepo, parse_remote if __name__ == "__main__": diff --git a/tests/scripts/git_utils.py b/tests/scripts/git_utils.py new file mode 100644 index 0000000000000..530abe8029a6e --- /dev/null +++ b/tests/scripts/git_utils.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import json +import subprocess +import re +from urllib import request +from typing import Dict, Tuple, Any + + +class GitHubRepo: + def __init__(self, user, repo, token): + self.token = token + self.user = user + self.repo = repo + self.base = f"https://api.github.com/repos/{user}/{repo}/" + + def headers(self): + return { + "Authorization": f"Bearer {self.token}", + } + + def graphql(self, query: str) -> Dict[str, Any]: + return self._post("https://api.github.com/graphql", {"query": query}) + + def _post(self, full_url: str, body: Dict[str, Any]) -> Dict[str, Any]: + print("Requesting POST to", full_url, "with", body) + req = request.Request(full_url, headers=self.headers(), method="POST") + req.add_header("Content-Type", "application/json; charset=utf-8") + data = json.dumps(body) + data = data.encode("utf-8") + req.add_header("Content-Length", len(data)) + + with request.urlopen(req, data) as response: + response = json.loads(response.read()) + return response + + def post(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]: + return self._post(self.base + url, data) + + def get(self, url: str) -> Dict[str, Any]: + url = self.base + url + print("Requesting GET to", url) + req = request.Request(url, headers=self.headers()) + with request.urlopen(req) as response: + response = json.loads(response.read()) + return response + + def delete(self, url: str) -> Dict[str, Any]: + url = self.base + url + print("Requesting DELETE to", url) + req = request.Request(url, headers=self.headers(), method="DELETE") + with request.urlopen(req) as response: + response = json.loads(response.read()) + return response + + +def parse_remote(remote: str) -> Tuple[str, str]: + """ + Get a GitHub (user, repo) pair out of a git remote + """ + if remote.startswith("https://"): + # Parse HTTP remote + parts = remote.split("/") + if len(parts) < 2: + raise RuntimeError(f"Unable to parse remote '{remote}'") + return parts[-2], parts[-1].replace(".git", "") + else: + # Parse SSH remote + m = re.search(r":(.*)/(.*)\.git", remote) + if m is None or len(m.groups()) != 2: + raise RuntimeError(f"Unable to parse remote '{remote}'") + return m.groups() + + +def git(command): + command = ["git"] + command + print("Running", command) + proc = subprocess.run(command, stdout=subprocess.PIPE, check=True) + return proc.stdout.decode().strip() diff --git a/tests/scripts/github_cc_reviewers.py b/tests/scripts/github_cc_reviewers.py new file mode 100755 index 0000000000000..8e7198aa7b8fd --- /dev/null +++ b/tests/scripts/github_cc_reviewers.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import json +import argparse +import re +from urllib import error +from typing import Dict, Any, List + + +from git_utils import git, GitHubRepo, parse_remote + + +def find_reviewers(body: str) -> List[str]: + print(f"Parsing body:\n{body}") + matches = re.findall(r"(cc( @[-A-Za-z0-9]+)+)", body, flags=re.MULTILINE) + matches = [full for full, last in matches] + + print("Found matches:", matches) + reviewers = [] + for match in matches: + if match.startswith("cc "): + match = match.replace("cc ", "") + users = [x.strip() for x in match.split("@")] + reviewers += users + + reviewers = set(x for x in reviewers if x != "") + return sorted(list(reviewers)) + + +if __name__ == "__main__": + help = "Add @cc'ed people in a PR body as reviewers" + parser = argparse.ArgumentParser(description=help) + parser.add_argument("--remote", default="origin", help="ssh remote to parse") + parser.add_argument( + "--dry-run", + action="store_true", + default=False, + help="run but don't send any request to GitHub", + ) + args = parser.parse_args() + + remote = git(["config", "--get", f"remote.{args.remote}.url"]) + user, repo = parse_remote(remote) + + pr = json.loads(os.environ["PR"]) + + number = pr["number"] + body = pr["body"] + if body is None: + body = "" + + to_add = find_reviewers(body) + print("Adding reviewers:", to_add) + + if not args.dry_run: + github = GitHubRepo(token=os.environ["GITHUB_TOKEN"], user=user, repo=repo) + + # Add reviewers 1 by 1 since GitHub will error out if any of the + # requested reviewers aren't members / contributors + for reviewer in to_add: + try: + github.post(f"pulls/{number}/requested_reviewers", {"reviewers": [reviewer]}) + except error.HTTPError as e: + print(f"Failed to add reviewer {reviewer}: {e}") diff --git a/tests/scripts/task_config_build_qemu.sh b/tests/scripts/task_config_build_qemu.sh index 816af7a56c6da..134de9d3d3a32 100755 --- a/tests/scripts/task_config_build_qemu.sh +++ b/tests/scripts/task_config_build_qemu.sh @@ -25,6 +25,8 @@ cp ../cmake/config.cmake . echo set\(USE_SORT ON\) >> config.cmake echo set\(USE_MICRO ON\) >> config.cmake +echo set\(USE_CMSISNN ON\) >> config.cmake +echo set\(USE_ETHOSU ON\) >> config.cmake echo set\(USE_LLVM llvm-config-10\) >> config.cmake echo set\(CMAKE_CXX_COMPILER g++\) >> config.cmake echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake diff --git a/web/src/compact.ts b/web/src/compact.ts index 29569b5d005d3..ac6af35abeff8 100644 --- a/web/src/compact.ts +++ b/web/src/compact.ts @@ -19,9 +19,9 @@ /** NodeJS and Web compact layer */ /** - * Get performance masurement. + * Get performance measurement. */ -export function getPeformance(): Performance { +export function getPerformance(): Performance { if (typeof performance == "undefined") { // eslint-disable-next-line @typescript-eslint/no-var-requires const performanceNode = require("perf_hooks"); diff --git a/web/src/runtime.ts b/web/src/runtime.ts index 60a28d53f3617..b0e71d945f8a1 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -653,7 +653,7 @@ class GraphExecutor implements Disposable { */ async benchmarkRuns(dev: DLDevice, number=10, repeat=4): Promise { // Skip first run as it can involve GPU warmup and module loading time. - const perf = compact.getPeformance(); + const perf = compact.getPerformance(); const results = []; this.run(); await dev.sync(); @@ -1049,7 +1049,7 @@ export class Instance implements Disposable { /** Register global packed functions needed by the backend to the env. */ private registerEnvGlobalPackedFuncs(): void { // Register the timer function to enable the time_evaluator. - const perf = compact.getPeformance(); + const perf = compact.getPerformance(); // Helper function to time the finvoke const timeExecution = async (