From 1935341af7b2accae0fc4b1e2d6b94e053a27979 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Thu, 27 Jan 2022 03:55:29 +0800 Subject: [PATCH] [MetaSchedule] postproc: rewrite_parallel_vectorize_unroll (#10071) Co-authored-by: Junru Shao Co-authored-by: Xiyou Zhou Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin Co-authored-by: Junru Shao Co-authored-by: Xiyou Zhou Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin --- python/tvm/meta_schedule/postproc/__init__.py | 1 + .../rewrite_parallel_vectorize_unroll.py | 33 ++ .../rewrite_parallel_vectorize_unroll.cc | 399 ++++++++++++++++++ ...tproc_rewrite_parallel_vectorize_unroll.py | 87 ++++ 4 files changed, 520 insertions(+) create mode 100644 python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py create mode 100644 src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc create mode 100644 tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py diff --git a/python/tvm/meta_schedule/postproc/__init__.py b/python/tvm/meta_schedule/postproc/__init__.py index eaab8c7bd484..0c914ac809f9 100644 --- a/python/tvm/meta_schedule/postproc/__init__.py +++ b/python/tvm/meta_schedule/postproc/__init__.py @@ -17,6 +17,7 @@ """The tvm.meta_schedule.postproc package.""" from .postproc import Postproc, PyPostproc from .disallow_dynamic_loop import DisallowDynamicLoop +from .rewrite_parallel_vectorize_unroll import RewriteParallelVectorizeUnroll from .rewrite_reduction_block import RewriteReductionBlock from .rewrite_unbound_block import RewriteUnboundBlock from .verify_gpu_code import VerifyGPUCode diff --git a/python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py 
b/python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py new file mode 100644 index 000000000000..abe7288acba9 --- /dev/null +++ b/python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A postprocessor that applies parallelization, vectorization and auto unrolling +according to the annotation of each block""" + +from tvm._ffi.registry import register_object +from .. 
import _ffi_api +from .postproc import Postproc + + +@register_object("meta_schedule.RewriteParallelVectorizeUnroll") +class RewriteParallelVectorizeUnroll(Postproc): + """A postprocessor that applies parallelization, vectorization and auto unrolling + according to the annotation of each block""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.PostprocRewriteParallelVectorizeUnroll, # type: ignore # pylint: disable=no-member + ) diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc new file mode 100644 index 000000000000..69e8dfb858bc --- /dev/null +++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc @@ -0,0 +1,399 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Check whether the loop is bound to a thread or carries any annotation + * \param loop The loop to be checked + * \return Whether the loop has a thread binding or any annotation + */ +inline bool HasAnnOrBinding(const ForNode* loop) { + return loop->kind == ForKind::kThreadBinding || !loop->annotations.empty(); +} + +/*! \brief The visitor for extracting the stride of a var in a PrimExpr. 
*/ +class StrideExtractor : public ExprVisitor { + public: + /*! + * \brief Extracting the stride of a var in a PrimExpr. + * e.g., the stride of `x` in `(x * 2 + 1) * 3 + 1` is 6 + * \param expr The given PrimExpr. + * \param var The target var. + * \return The stride of the var, i.e. the entry recorded in `strides_` for the root of `expr`. + */ + static int64_t Extract(const PrimExpr& expr, const Var& var) { + StrideExtractor extractor(var); + extractor.VisitExpr(expr); + return extractor.strides_[expr.get()]; + } + + private: + explicit StrideExtractor(const Var& var) : var_(var) {} + + void VisitExpr_(const MulNode* node) final { + ExprVisitor::VisitExpr_(node); + + if (const auto* a = node->a.as()) { + if (strides_.count(node->b.get())) { + strides_[node] = strides_[node->b.get()] * a->value; + } + } else if (const auto* b = node->b.as()) { + if (strides_.count(node->a.get())) { + strides_[node] = strides_[node->a.get()] * b->value; + } + } + } + + void VisitExpr_(const AddNode* node) final { + ExprVisitor::VisitExpr_(node); + int64_t stride_a, stride_b; + if (strides_.count(node->a.get())) { + stride_a = strides_[node->a.get()]; + } else { + stride_a = INT64_MAX; + } + if (strides_.count(node->b.get())) { + stride_b = strides_[node->b.get()]; + } else { + stride_b = INT64_MAX; + } + if (stride_a != INT64_MAX || stride_b != INT64_MAX) { + strides_[node] = std::min(stride_a, stride_b); + } + } + + void VisitExpr_(const VarNode* node) final { + if (node == var_.get()) { + strides_[node] = 1; + } + } + + const Var& var_; + std::unordered_map strides_; +}; + +struct ParsedAnnotation { + int max_parallel_extent; + int max_vectorize_extent; + int unroll_explicit; + int unroll_implicit; + int num_parallel_loops; + int num_vectorize_loops; +}; + +bool ParseAnnotation(const Block& block, ParsedAnnotation* parsed) { + bool found = false; + *parsed = ParsedAnnotation{-1, -1, -1, -1, -1, -1}; + for (const auto& ann : block->annotations) { + if (ann.first == attr::meta_schedule_parallel) { + found = true; + if (const auto* imm = 
ann.second.as()) { + parsed->max_parallel_extent = imm->value; + } + } else if (ann.first == attr::meta_schedule_vectorize) { + found = true; + if (const auto* imm = ann.second.as()) { + parsed->max_vectorize_extent = imm->value; + } + } else if (ann.first == attr::meta_schedule_unroll_explicit) { + found = true; + if (const auto* imm = ann.second.as()) { + parsed->unroll_explicit = imm->value; + } + } else if (ann.first == attr::meta_schedule_unroll_implicit) { + found = true; + if (const auto* imm = ann.second.as()) { + parsed->unroll_implicit = imm->value; + } + } + } + return found; +} + +void RemoveParsedAnn(const Schedule& sch, const BlockRV& block_rv, const ParsedAnnotation& parsed) { + if (parsed.max_parallel_extent != -1) { + sch->Unannotate(block_rv, attr::meta_schedule_parallel); + } + if (parsed.max_vectorize_extent != -1) { + sch->Unannotate(block_rv, attr::meta_schedule_vectorize); + } + if (parsed.unroll_explicit != -1) { + sch->Unannotate(block_rv, attr::meta_schedule_unroll_explicit); + } + if (parsed.unroll_implicit != -1) { + sch->Unannotate(block_rv, attr::meta_schedule_unroll_implicit); + } +} + +void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv, + const Array& loop_rvs, ParsedAnnotation* parsed) { + StmtSRef block_sref = sch->GetSRef(block_rv); + if (parsed->max_parallel_extent == -1 && parsed->max_vectorize_extent == -1) { + return; + } + int n_loops = loop_rvs.size(); + if (n_loops == 0) { + parsed->max_parallel_extent = -1; + parsed->max_vectorize_extent = -1; + return; + } + // Extract loop_srefs, and calculate the iterator types + Array loop_srefs; + std::vector loop_types; + { + loop_srefs.reserve(n_loops); + loop_types.reserve(n_loops); + for (const LoopRV& loop_rv : loop_rvs) { + loop_srefs.push_back(sch->GetSRef(loop_rv)); + loop_types.push_back(GetLoopIterType(loop_srefs.back())); + } + } + // Check the maximal number of axes that are vectorizable (contiguous memory access) + BlockRealize realize = 
GetBlockRealize(sch->state(), block_sref); + Array buffer_access(realize->block->reads); + buffer_access.insert(buffer_access.end(), realize->block->writes.begin(), + realize->block->writes.end()); + std::unordered_map binding_map; + for (size_t i = 0; i < realize->iter_values.size(); i++) { + binding_map[realize->block->iter_vars[i]->var.get()] = realize->iter_values[i]; + } + int max_fusible = INT32_MAX; + // for each block read/write, get the strides of the loop vars and find the fusible + // (vectorizable) axes + for (const BufferRegion& access : buffer_access) { + int fusible = 0; + std::vector strides; + // get strides for each loop var + for (const StmtSRef& loop_sref : loop_srefs) { + int64_t stride = 0, buffer_stride = 1; + const auto* var = loop_sref->StmtAs(); + arith::Analyzer analyzer; + for (int i = access->region.size() - 1; i >= 0; i--) { + PrimExpr idx = analyzer.Simplify(Substitute(access->region[i]->min, binding_map)); + int64_t coef = StrideExtractor::Extract(idx, var->loop_var); + if (coef != 0) { + stride = coef * buffer_stride; + break; + } + buffer_stride *= access->buffer->shape[i].as()->value; + } + strides.push_back(stride); + } + int prev_used_iter = -1; + // check the number of fusible loops + for (int i = strides.size() - 1; i >= 0; i--) { + if (strides[i] == 0) { + // not used in the buffer access, safe to fuse + fusible++; + continue; + } else if (prev_used_iter == -1) { + // a stride other than 1 on the last used loop means the memory access is not contiguous + if (strides[i] != 1) { + break; + } + fusible++; + prev_used_iter = i; + } else { + // contiguous only if this stride equals the previously used stride times that loop's extent + const auto* prev_loop = loop_srefs[prev_used_iter]->StmtAs(); + int64_t prev_used_iter_extent = prev_loop->extent.as()->value; + if (strides[i] == strides[prev_used_iter] * prev_used_iter_extent) { + fusible++; + prev_used_iter = i; + } else { + break; + } + } + } + max_fusible = std::min(max_fusible, fusible); + } + // Calculate the parallelize extent + if 
(parsed->max_parallel_extent != -1) { + int max_extent = parsed->max_parallel_extent; + int& num_fusible = parsed->num_parallel_loops = 0; + int64_t prod_extent = 1; + for (int i = 0; i < n_loops && loop_types[i] == IterVarType::kDataPar; ++i) { + const StmtSRef& loop_sref = loop_srefs[i]; + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + if (HasAnnOrBinding(loop)) { + break; + } + // Check if the loop extent is valid + const int64_t* extent = GetLoopIntExtent(loop_sref); + if (extent == nullptr) { + break; + } + // Then we can fuse it in + ++num_fusible; + // Stop after this loop once the fused extent exceeds the limit or its body is not a single stmt + prod_extent *= *extent; + if (prod_extent > max_extent || !IsSingleStmt(loop->body)) { + break; + } + } + if (prod_extent == 1) { + num_fusible = -1; + } + } + // Calculate the vectorize extent + if (parsed->max_vectorize_extent != -1) { + int max_extent = parsed->max_vectorize_extent; + int& num_fusible = parsed->num_vectorize_loops = 0; + int64_t prod_extent = 1; + for (int i = n_loops - 1; + i >= 0 && loop_types[i] == IterVarType::kDataPar && num_fusible < max_fusible; --i) { + const StmtSRef& loop_sref = loop_srefs[i]; + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + if (HasAnnOrBinding(loop)) { + break; + } + // Cannot vectorize reduce axis + if (GetLoopIterType(loop_sref) != IterVarType::kDataPar) { + break; + } + // Cannot fuse with a loop with multiple children + if (!IsSingleStmt(loop->body)) { + break; + } + // Check if the loop extent is valid + const int64_t* extent = GetLoopIntExtent(loop_sref); + if (extent == nullptr) { + break; + } + // Check if the extent is still in a good range + prod_extent *= *extent; + if (prod_extent > max_extent) { + break; + } + ++num_fusible; + } + if (prod_extent == 1) { + num_fusible = -1; + } + } + // Prefer num_vectorize to num_parallel: parallel loops must not overlap the vectorized ones + if (parsed->num_parallel_loops != -1 && parsed->num_vectorize_loops != -1) { + parsed->num_parallel_loops = std::min(parsed->num_parallel_loops, // + n_loops - 
parsed->num_vectorize_loops); + } +} + +bool FindAnnotatedRootBlock(const Schedule& sch, ParsedAnnotation* parsed, BlockRV* root_rv) { + IRModule mod = sch->mod(); + for (const auto& kv : mod->functions) { + const GlobalVar& g_var = kv.first; + const BaseFunc& base_func = kv.second; + if (const auto* prim_func = base_func.as()) { + Block block = Downcast(prim_func->body)->block; + if (ParseAnnotation(block, parsed)) { + *root_rv = sch->GetBlock(block->name_hint, g_var->name_hint); + RemoveParsedAnn(sch, *root_rv, *parsed); + return true; + } + } + } + return false; +} + +void RewriteParallel(const Schedule& sch, size_t n, Array* loop_rvs) { + ICHECK_LE(n, loop_rvs->size()); + LoopRV fused = sch->Fuse({loop_rvs->begin(), loop_rvs->begin() + n}); + sch->Parallel(fused); + for (size_t i = 0; i < n; ++i) { + loop_rvs->Set(i, fused); + } +} + +void RewriteVectorize(const Schedule& sch, size_t n, Array* loop_rvs) { + size_t n_loops = loop_rvs->size(); + ICHECK_LE(n, n_loops); + LoopRV fused = sch->Fuse({loop_rvs->end() - n, loop_rvs->end()}); + sch->Vectorize(fused); + for (size_t i = n_loops - n; i < n_loops; ++i) { + loop_rvs->Set(i, fused); + } +} + +void RewriteUnroll(const Schedule& sch, int unroll_explicit, int max_step, const LoopRV& loop) { + if (max_step > 0) { + sch->Annotate(loop, attr::pragma_auto_unroll_max_step, IntImm(DataType::Int(32), max_step)); + sch->Annotate(loop, attr::pragma_unroll_explicit, IntImm(DataType::Int(32), unroll_explicit)); + } +} + +} // namespace tir + +namespace meta_schedule { + +using tir::Schedule; + +class RewriteParallelVectorizeUnrollNode : public PostprocNode { + public: + void InitializeWithTuneContext(const TuneContext& context) final {} + + bool Apply(const Schedule& sch) final { + tir::ParsedAnnotation parsed_root; + tir::BlockRV root_rv{nullptr}; + while (tir::FindAnnotatedRootBlock(sch, &parsed_root, &root_rv)) { + for (tir::BlockRV block_rv : sch->GetChildBlocks(root_rv)) { + Array loop_rvs = sch->GetLoops(block_rv); + 
if (loop_rvs.empty()) { + continue; + } + tir::ParsedAnnotation parsed = parsed_root; + tir::AdjustParallelVectorize(sch, block_rv, loop_rvs, &parsed); + // Parallel + if (parsed.num_parallel_loops > 0) { + tir::RewriteParallel(sch, parsed.num_parallel_loops, &loop_rvs); + } + // Vectorize + if (parsed.num_vectorize_loops > 0) { + tir::RewriteVectorize(sch, parsed.num_vectorize_loops, &loop_rvs); + } + // AutoUnroll: one of the two values is -1, so `explicit + implicit + 1` is the one that is set + if (parsed.unroll_explicit != -1 || parsed.unroll_implicit != -1) { + ICHECK(parsed.unroll_explicit == -1 || parsed.unroll_implicit == -1); + int unroll_explicit = parsed.unroll_explicit != -1; + int max_step = parsed.unroll_explicit + parsed.unroll_implicit + 1; + tir::RewriteUnroll(sch, unroll_explicit, max_step, loop_rvs[0]); + } + } + } + return true; + } + + static constexpr const char* _type_key = "meta_schedule.RewriteParallelVectorizeUnroll"; + TVM_DECLARE_FINAL_OBJECT_INFO(RewriteParallelVectorizeUnrollNode, PostprocNode); +}; + +Postproc Postproc::RewriteParallelVectorizeUnroll() { + ObjectPtr n = + make_object(); + return Postproc(n); +} + +TVM_REGISTER_NODE_TYPE(RewriteParallelVectorizeUnrollNode); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteParallelVectorizeUnroll") + .set_body_typed(Postproc::RewriteParallelVectorizeUnroll); + +} // namespace meta_schedule +} // namespace tvm diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py new file mode 100644 index 000000000000..9988e874b81d --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import tvm +from tvm.script import tir as T + +from tvm.meta_schedule.postproc import RewriteParallelVectorizeUnroll +from tvm.tir.schedule import Schedule + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant +# fmt: off + +@tvm.script.ir_module +class Move_PUV: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, [1024, 1024, 1024], dtype="float32") + B = T.match_buffer(b, [1024, 1024, 1024], dtype="float32") + # body: input fixture; the root block carries the parallel/vectorize annotations under test + with T.block("root"): + T.block_attr({"meta_schedule.parallel":128, "meta_schedule.vectorize":32}) + for i0, j0, i1, j1, k0, i2, j2, k1 in T.grid(128, 64, 4, 4, 64, 4, 8, 32): + with T.block("move"): + vi = T.axis.spatial(1024, i0 * 16 + i1 * 4 + i2) + vj = T.axis.spatial(1024, j0 * 32 + j1 * 8 + j2) + vk = T.axis.spatial(1024, k0 * 32 + k1) + T.where((i0 * 4 + i1) * 4 + i2 < 1024 and (j0 * 4 + j1) * 8 + j2 < 1024 and k0 * 32 + k1 < 1024) + T.reads([A[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] + + +@T.prim_func +def Move_PUV0(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, [1024, 1024, 1024], 
dtype="float32") + B = T.match_buffer(b, [1024, 1024, 1024], dtype="float32") + # body: expected output; i0/j0 are fused into one parallel loop and k1 becomes a vectorized loop + with T.block("root"): + for i0_j0_fused in T.parallel(0, 8192): + for i1, j1, k0, i2, j2 in T.grid(4, 4, 64, 4, 8): + for k1_fused in T.vectorized(0, 32): + with T.block("move"): + vi = T.axis.spatial(1024, i0_j0_fused // 64 * 16 + i1 * 4 + i2) + vj = T.axis.spatial(1024, i0_j0_fused % 64 * 32 + j1 * 8 + j2) + vk = T.axis.spatial(1024, k0 * 32 + k1_fused) + T.where( + i0_j0_fused // 64 * 16 + i1 * 4 + i2 < 1024 + and i0_j0_fused % 64 * 32 + j1 * 8 + j2 < 1024 + and k0 * 32 + k1_fused < 1024 + ) + T.reads([A[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable + + +def test_meta_schedule_postproc_rewrite_parallel_unroll_vectorize(): + postproc = RewriteParallelVectorizeUnroll() + sch = Schedule(Move_PUV) + assert postproc.apply(sch) + print(sch.mod["main"].script()) + mod = tvm.tir.transform.Simplify()(sch.mod) + tvm.ir.assert_structural_equal(mod["main"], Move_PUV0) + + +if __name__ == "__main__": + test_meta_schedule_postproc_rewrite_parallel_unroll_vectorize()