diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 97cf467cca07..e69de29bb2d1 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,158 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# Github code owners file -# This file is used as a convenient tool to map -# committers' areas of expertise and faciliate the review process. -# -# This may not be the non-comprehensive list and is meant to be -# updated over time. - -# Per ASF policy, committer have global write permission. -# We normally recommend committers to shepherd code in their area of expertise. -* @apache/tvm-committers - -# Order is important; the last matching pattern takes the most precedence. -# The sub modules should be ordered first by depth. -# Making sure we append new sub-module rules after exisiting modules rules. - -############################## -# Top-level Fallbacks -############################## -include/** @tqchen @jroesch @yzhliu @icemelon @junrushao1994 @comaniac @zhiics -src/** @tqchen @jroesch @yzhliu @icemelon @junrushao1994 @comaniac @zhiics -apps/** @tqchen @jroesch @yzhliu @icemelon @junrushao1994 @comaniac @zhiics -python/** @tqchen @jroesch @yzhliu @icemelon @junrushao1994 @comaniac @zhiics - -# Thirdparty license audit -3rdparty/** @tqchen @jroesch -licenses/** @tqchen @jroesch - -# JVM language -jvm/** @yzhliu - -# Golang -golang/** @srkreddy1238 - -# WASM -web/** @tqchen @jroesch - -# Docker -docker/** @areusch @leandron @jroesch - -# Conda -conda/** @tqchen @junrushao1994 @comaniac - -# CMake -cmake/** @jroesch @tqchen @areusch @junrushao1994 @comaniac - -# rust bindings -rust/** @jroesch @nhynes @nhynes - -# vta -vta/** @tmoreau89 @vegaluisjose - -# docs -docs/** @comaniac @junrushao1994 @tqchen @jroesch @areusch @yzhliu @merrymercy @icemelon -tutorials/** @comaniac @junrushao1994 @tqchen @jroesch @areusch @yzhliu @merrymercy @icemelon - -# tests -tests/** @comaniac @junrushao1994 @tqchen @jroesch @areusch @yzhliu @merrymercy @icemelon - -############################## -# Specific modules -############################## - -# automation related -src/auto_scheduler/** @merrymercy @jcf94 @comaniac @junrushao1994 @vinx13 @Hzfengsy -include/tvm/auto_scheduler/** @merrymercy @jcf94 @comaniac @junrushao1994 @vinx13 @Hzfengsy -python/tvm/auto_scheduler/** @merrymercy @jcf94 @comaniac @junrushao1994 @vinx13 @Hzfengsy - -python/tvm/autotvm/** @merrymercy @jcf94 @comaniac @junrushao1994 @vinx13 - -# node system and reflection -src/node/** @junrushao1994 @vinx13 @tqchen @jroesch @comaniac -include/tvm/node/** @junrushao1994 @vinx13 @tqchen @jroesch @comaniac - -# ir: Common IR -src/ir/** @junrushao1994 @vinx13 @tqchen @jroesch @comaniac -include/tvm/ir/** @junrushao1994 @vinx13 @tqchen @jroesch @comaniac -python/tvm/ir/** @junrushao1994 @vinx13 @tqchen @jroesch @comaniac - -# tir -src/tir/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi @were @Hzfengsy -include/tvm/tir/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi @were @Hzfengsy -python/tvm/tir/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi @were @Hzfengsy - -# te -src/te/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi @were -include/tvm/te/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi @were -python/tvm/te/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi @were - -# target -src/target/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi -include/tvm/target/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi -python/tvm/target/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi - -# arith: Arithmetic module and simplifiers -src/arith/** @tqchen @junrushao1994 @vinx13 -include/tvm/arith/** @tqchen @junrushao1994 @vinx13 -python/tvm/arith/** @tqchen @junrushao1994 @vinx13 - -# parser -src/parser/** @jroesch @slyubomirsky - -# runtime -src/runtime/** @vinx13 @tqchen @FronzenGene @liangfu @areusch @tmoreau89 @ajtulloch @masahi @kazum @ZihengJiang @junrushao1994 -include/tvm/runtime/** @vinx13 @tqchen @FronzenGene @liangfu @areusch @tmoreau89 @ajtulloch @masahi @kazum @ZihengJiang @junrushao1994 -python/tvm/runtime/** @vinx13 @tqchen @FronzenGene @liangfu @areusch @tmoreau89 @ajtulloch @masahi @kazum @ZihengJiang @junrushao1994 - -# runtime/micro -src/runtime/micro/** @areusch @liangfu @tmoreau89 @manupa-arm -src/runtime/crt/** @areusch @liangfu @tmoreau89 @manupa-arm -include/tvm/runtime/crt/** @areusch @liangfu @tmoreau89 @manupa-arm -include/tvm/runtime/micro/** @areusch @liangfu @tmoreau89 @manupa-arm -python/tvm/micro/** @areusch @liangfu @tmoreau89 @manupa-arm - -# relay -src/relay/** @jroesch @slyubomirsky @icemelon @MarisaKirisame @ZihengJiang @yzhliu @vinx13 @mbrookhart @jwfromm @zhiics @anijain2305 @wweic @eqy @junrushao1994 -include/tvm/relay/** @jroesch @slyubomirsky @icemelon @MarisaKirisame @ZihengJiang @yzhliu @vinx13 @mbrookhart @jwfromm @zhiics @anijain2305 @wweic @eqy @junrushao1994 -python/tvm/relay/** @jroesch @slyubomirsky @icemelon @MarisaKirisame @ZihengJiang @yzhliu @vinx13 @mbrookhart @jwfromm @zhiics @anijain2305 @wweic @eqy @junrushao1994 - - -# relay/qnn -src/relay/qnn/** @jwfromm @anijain2305 @ZihengJiang -inlcude/tvm/relay/qnn/** @jwfromm @anijain2305 @ZihengJiang -python/tvm/relay/qnn/** @jwfromm @anijain2305 @ZihengJiang - -# relay/backend/contrib: BYOC -src/relay/backend/contrib/** @zhiics @trevor-m @comaniac @mbaret @manupa-arm - -# relay/frontends -python/tvm/relay/frontend/** @jwfromm @mbrookhart @srkreddy1238 @siju-samuel @Huyuwei @hlu1 @kazum @PariksheetPinjari909 - -# topi: Operator definitions -src/topi/** @Laurawly @Huyuwei @kevinthesun @jwfromm @vinx13 @masahi @FronzenGene @yzhliu @mbrookhart @ZihengJiang @jcf94 -include/tvm/topi/** @Laurawly @Huyuwei @kevinthesun @jwfromm @vinx13 @masahi @FronzenGene @yzhliu @mbrookhart @ZihengJiang @jcf94 -python/tvm/topi/** @Laurawly @Huyuwei @kevinthesun @jwfromm @vinx13 @masahi @FronzenGene @yzhliu @mbrookhart @ZihengJiang @jcf94 - - -# tvm/driver/ -python/tvm/driver/** @leandron @jwfromm @tqchen @jroesch - -# tvm/driver/tvmc -python/tvm/driver/tvmc/** @leandron @jwfromm diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h index 7cc4efe6b012..9ee3c3da8a83 100644 --- a/include/tvm/arith/int_set.h +++ b/include/tvm/arith/int_set.h @@ -169,6 +169,14 @@ Map ConvertDomMap(const std::unordered_map& * \return An integer set that can cover all the possible values of e. */ IntSet EvalSet(PrimExpr e, const Map& dom_map); +/*! + * \brief Same as EvalSet, but takes Map + * + * \param e The expression to be evaluated. + * \param dom_map The domain of each variable. + * \return An integer set that can cover all the possible values of e. + */ +IntSet EvalSet(PrimExpr e, const Map& dom_map); /*! * \brief Same as EvalSet, but takes unordered_map * @@ -177,6 +185,15 @@ IntSet EvalSet(PrimExpr e, const Map& dom_map); * \return An integer set that can cover all the possible values of e. */ IntSet EvalSet(PrimExpr e, const std::unordered_map& dom_map); +/*! + * \brief Same as EvalSet, but takes Array + * + * \param exprs The expressions to be evaluated. + * \param dom_map The domain of each variable. + * \return An array of integer sets that can cover all the possible values. + */ +Array EvalSet(const Array& exprs, const Map& dom_map); + /*! * \brief Find an symbolic integer set that contains is union over * all the possible conditional values in dom_map. diff --git a/include/tvm/meta_schedule/builder.h b/include/tvm/meta_schedule/builder.h index d7b23dd79c25..2b809459155e 100644 --- a/include/tvm/meta_schedule/builder.h +++ b/include/tvm/meta_schedule/builder.h @@ -32,7 +32,7 @@ class BuilderInputNode : public runtime::Object { IRModule mod; /*! \brief The target to be built for. */ Target target; - /*! \brief The optional parameters used for build */ + /*! \brief Parameters for Relay build module. */ Optional> params; void VisitAttrs(tvm::AttrVisitor* v) { @@ -55,7 +55,7 @@ class BuilderInput : public runtime::ObjectRef { * \brief Constructor of BuilderInput. * \param mod The IRModule to be built. * \param target The target to be built for. - * \param params The optional parameters used for build + * \param params Parameters for Relay build module. */ TVM_DLL explicit BuilderInput(IRModule mod, Target target, Optional> params = NullOpt); diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index 1675bcce05ed..7d1fecfaa7f5 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -137,6 +137,7 @@ class ScheduleRule : public runtime::ObjectRef { * \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended: * - NullOpt on CPU * - [blockIdx.x, vthread.x, threadIdx.x] on GPU + * \param use_tensor_core Whether to apply tensor core wmma intrinsic for the computation * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit * \param vector_load_lens The length of vector lane in vectorized cooperative fetching. * NullOpt means disable vectorization @@ -146,6 +147,7 @@ class ScheduleRule : public runtime::ObjectRef { */ TVM_DLL static ScheduleRule MultiLevelTiling(String structure, // Optional> tile_binds, // + bool use_tensor_core, // Optional max_innermost_factor, // Optional> vector_load_lens, // Optional> reuse_read, // diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index 428a2e80f4dd..7a7599b0a4f8 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -82,6 +82,8 @@ class TuneContextNode : public runtime::Object { v->Visit("rand_state", &rand_state); v->Visit("num_threads", &num_threads); v->Visit("is_stopped", &is_stopped); + v->Visit("builder_results", &builder_results); + v->Visit("runner_futures", &runner_futures); v->Visit("measure_candidates", &measure_candidates); } diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 1ab911b756df..86a171774015 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -187,6 +187,65 @@ class LinkedParam : public ObjectRef { TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode); }; +/*! \brief A mapping from multi-dimensional indices to another set of multi-dimensional indices */ +class IndexMapNode : public Object { + public: + /*! \brief The source indices */ + Array src_iters; + /*! \brief The target indices */ + Array tgt_iters; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("src_iters", &src_iters); + v->Visit("tgt_iters", &tgt_iters); + } + + /*! + * \brief Take `inputs` as the source indices and return the corresponding target indices. + * \param inputs The source indices. + * \return The target indices. + */ + Array Apply(const Array& inputs) const; + + /*! + * \brief Map a shape to the output space + * \param shape The shape in the source space + * \return The shape in the target space + */ + Array MapShape(const Array& shape) const; + + /*! + * \brief Convert to string representation in Python. + * \return The stringified lambda expression in Python. + */ + String ToPythonString() const; + + static constexpr const char* _type_key = "tir.IndexMap"; + TVM_DECLARE_FINAL_OBJECT_INFO(IndexMapNode, Object); +}; + +/*! + * \brief Managed reference to IndexMapNode. + * \sa IndexMapNode + */ +class IndexMap : public ObjectRef { + public: + /*! + * \brief Constructor. + * \param src_iters The source indices. + * \param tgt_iters The target indices. + */ + explicit IndexMap(Array src_iters, Array tgt_iters); + /*! + * \brief Create an index map from a packed function + * \param ndim The number of dimensions + * \param func The function to be applied + * \return The created index map + */ + static IndexMap FromFunc(int ndim, runtime::TypedPackedFunc(Array)> func); + TVM_DEFINE_OBJECT_REF_METHODS(IndexMap, ObjectRef, IndexMapNode); +}; + /*! * \brief Tensor intrinsics for tensorization */ diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index be06b44820cd..89871f0d6352 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -355,6 +355,11 @@ class ScheduleNode : public runtime::Object { */ virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope) = 0; + /******** Schedule: Data movement ********/ + virtual BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) = 0; + virtual BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope) = 0; /******** Schedule: Compute location ********/ /*! * \brief Move a producer block under the specific loop, and regenerate the @@ -521,6 +526,21 @@ class ScheduleNode : public runtime::Object { */ virtual void Unannotate(const BlockRV& block_rv, const String& ann_key) = 0; + /******** Schedule: Layout transformation ********/ + /*! + * \brief Apply a transformation represented by IndexMap to buffer + * \details The indices and the access region to the target buffer is transformed by the given + * index_map. The index_map is used to infer the new shape of the buffer. Buffer must be either + * a function parameter, or allocated in a block (it cannot be a buffer subregion created via + * 'match_buffer'). + * \param block_rv The block that accesses the target buffer. + * \param buffer_index The index of the buffer in block's read or write region. + * \param is_write_index Whether the buffer_index is the index of the block's write region. + * \param index_map The transformation to apply. + */ + virtual void TransformLayout(const BlockRV& block_rv, int buffer_index, bool is_write_index, + const IndexMap& index_map) = 0; + /******** Schedule: Misc ********/ /*! \brief A no-op that marks the start of postprocessing phase of scheduling */ virtual void EnterPostproc() = 0; diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index edb789b0bd7f..7b07146f446c 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1224,7 +1224,7 @@ class BlockRealize : public Stmt { TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockRealizeNode); }; -/*! \brief namespace of possible attribute sin AttrStmt.attr_key */ +/*! \brief namespace of possible attributes in AttrStmt.attr_key */ namespace attr { // The above attr does not pass to ir stage. /*! \brief Mark launching extent of thread, used by device API. */ @@ -1394,6 +1394,54 @@ constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_expl /*! \brief Mark auto-unroll setting on the block. */ constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit"; +/*! + * \brief Mark that the block need to add predicate for block var bounds during lowering + */ +constexpr const char* require_block_var_bound_predicate = "require_bound_predicate"; + +/*! + * \brief Mark that the block should be further rewritten using tensorization. + */ +constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize"; + +/*! \brief Mark that tensor core is enabled in the PrimExpr */ +constexpr const char* meta_schedule_tensor_core_enabled = "meta_schedule.tensor_core_enabled"; + +/*! + * \brief Mark a block as generated by cache_read or cache_write block. + * 0 means cache_read; 1 means cache_write. + * \sa meta_schedule_cache_type_read + * \sa meta_schedule_cache_type_write + */ +constexpr const char* meta_schedule_cache_type = "meta_schedule.cache_type"; + +/*! \sa meta_schedule_cache_type */ +constexpr const int meta_schedule_cache_type_read = 0; + +/*! \sa meta_schedule_cache_type */ +constexpr const int meta_schedule_cache_type_write = 1; + +/*! \brief Mark the scope of the software pipeline */ +constexpr const char* software_pipeline_scope = "software_pipeline_scope"; + +/*! \brief Mark the stage of a statement in the software pipeline */ +constexpr const char* software_pipeline_stage = "software_pipeline_stage"; + +/*! \brief Mark the order of a statement in the software pipeline */ +constexpr const char* software_pipeline_order = "software_pipeline_order"; + +/*! \brief Mark the stage of the result of the software pipeline lowering. This is used to specify + * the behavior of nested software pipelines. Should be a 3-tuple consisting of the stage of the + * prologue, the body, and the epilogue of the software pipeline. + */ +constexpr const char* nested_software_pipeline_stage = "nested_software_pipeline_stage"; + +/*! \brief Mark the stage of the result of the software pipeline lowering. This is used to specify + * the behavior of nested software pipelines. Should be a 3-tuple consisting of the stage of the + * prologue, the body, and the epilogue of the software pipeline. + */ +constexpr const char* nested_software_pipeline_order = "nested_software_pipeline_order"; + /*! * \brief Check if attr_key is a pragma key extension * \param attr_key The attr key to be compared diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index 0b4ace20078c..3ddebc5bf0f0 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -406,9 +406,12 @@ inline T Substitute(T input, const std::unordered_map& * \param stmt_or_expr The ir to be visited. * \param fvisit The visitor function to be applied. If fvisit returns false, it won't visit the * children of the node + * \param visit_init_block Whether or not to visit the init block + * children of the node */ TVM_DLL void PreOrderVisit(const ObjectRef& stmt_or_expr, - const std::function& fvisit); + const std::function& fvisit, + bool visit_init_block = true); } // namespace tir } // namespace tvm diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 6d17c396c12f..0904b9092cac 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -383,6 +383,20 @@ TVM_DLL Pass LowerInitBlock(); */ TVM_DLL Pass PlanAndUpdateBufferAllocationLocation(); +/*! + * \brief Narrow the extents of some loops by checking whether some constraints in the block iter + * bound predicates can be directly applied on the loops. + * \return The pass. + */ +TVM_DLL Pass ApplyBlockBoundPredicate(); + +/*! + * \brief Narrow the extents of some loops by checking whether some constraints in the block iter + * bound predicates can be directly applied on the loops. + * \return The pass. + */ +TVM_DLL Pass ApplyBlockBoundPredicate(); + /*! * \brief Substitute all the block vars with the PrimExprs they are bound to, indicated by the * corresponding iter_values in BlockRealize, for opaque blocks by removing all @@ -500,6 +514,24 @@ TVM_DLL Pass ConvertForLoopsToSerial(); */ TVM_DLL Pass UnifiedStaticMemoryPlanner(); +/*! + * \brief Transform annotated loops into pipelined one that ovarlaps producers and consumers. + * \return The IR transform pass. + */ +TVM_DLL Pass InjectSoftwarePipeline(); + +/*! + * \brief Automatically do memory optimizations for auto copy blocks + * \return The pass. + */ +TVM_DLL Pass LowerAutoCopy(); + +/*! + * \brief Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()) + * \return The pass. + */ +TVM_DLL Pass RenormalizeSplitPattern(); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index 1ac58e18db3e..07d74b9b6fb9 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -255,7 +255,7 @@ class IterVarNode : public Object { IterVarType iter_type; /*! * \brief additional tag on the iteration variable, - * set this if this is binded already to a known thread tag. + * set this if this is bound already to a known thread tag. */ String thread_tag; /*! diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 34823ebb1781..2bea0a5da6d9 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -231,6 +231,8 @@ def build( elif isinstance(inputs, PrimFunc): input_mod = lower(inputs, name=name) elif isinstance(inputs, tvm.IRModule): + if name is not None and name != "default_function": + warnings.warn("Specifying name with IRModule input is useless") input_mod = lower(inputs) elif not isinstance(inputs, (dict, container.Map)): raise ValueError( diff --git a/python/tvm/meta_schedule/builder/builder.py b/python/tvm/meta_schedule/builder/builder.py index ef5d3ca130a7..5d658f0fec23 100644 --- a/python/tvm/meta_schedule/builder/builder.py +++ b/python/tvm/meta_schedule/builder/builder.py @@ -15,8 +15,9 @@ # specific language governing permissions and limitations # under the License. """Meta Schedule builders that translate IRModule to runtime.Module, and then export""" -from typing import Dict, List, Optional +from typing import List, Optional, Dict +from tvm.runtime import NDArray from tvm._ffi import register_object from tvm.ir import IRModule from tvm.runtime import NDArray, Object @@ -42,6 +43,7 @@ class BuilderInput(Object): mod: IRModule target: Target + params: Optional[Dict[str, NDArray]] def __init__( self, diff --git a/python/tvm/meta_schedule/builder/local_builder.py b/python/tvm/meta_schedule/builder/local_builder.py index 954da87e6a63..f53949904472 100644 --- a/python/tvm/meta_schedule/builder/local_builder.py +++ b/python/tvm/meta_schedule/builder/local_builder.py @@ -129,7 +129,7 @@ def __init__( super().__init__() if max_workers is None: - max_workers = cpu_count() + max_workers = cpu_count(logical=True) logger.info("LocalBuilder: max_workers = %d", max_workers) self.pool = PopenPoolExecutor( diff --git a/python/tvm/meta_schedule/postproc/__init__.py b/python/tvm/meta_schedule/postproc/__init__.py index 96361e739186..50fbb0e0852b 100644 --- a/python/tvm/meta_schedule/postproc/__init__.py +++ b/python/tvm/meta_schedule/postproc/__init__.py @@ -21,4 +21,5 @@ from .rewrite_parallel_vectorize_unroll import RewriteParallelVectorizeUnroll from .rewrite_reduction_block import RewriteReductionBlock from .rewrite_unbound_block import RewriteUnboundBlock +from .rewrite_tensor_core import RewriteTensorCore from .verify_gpu_code import VerifyGPUCode diff --git a/python/tvm/meta_schedule/postproc/rewrite_tensor_core.py b/python/tvm/meta_schedule/postproc/rewrite_tensor_core.py new file mode 100644 index 000000000000..f858fed3a6d4 --- /dev/null +++ b/python/tvm/meta_schedule/postproc/rewrite_tensor_core.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A postprocessor that tensorize Tensor Core related components.""" + +from tvm._ffi.registry import register_object +from .. import _ffi_api +from .postproc import Postproc + + +@register_object("meta_schedule.RewriteTensorCore") +class RewriteTensorCore(Postproc): + """A postprocessor that tensorize Tensor Core related components.""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.PostprocRewriteTensorCore, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/runner/rpc_runner.py b/python/tvm/meta_schedule/runner/rpc_runner.py index a6fb169fa590..6085de809767 100644 --- a/python/tvm/meta_schedule/runner/rpc_runner.py +++ b/python/tvm/meta_schedule/runner/rpc_runner.py @@ -38,6 +38,8 @@ run_evaluator_common, ) +logger = logging.getLogger(__name__) + logger = logging.getLogger(__name__) # pylint: disable=invalid-name diff --git a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py index 2ff49168d0c6..17f8e22033f0 100644 --- a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py +++ b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py @@ -53,6 +53,8 @@ class MultiLevelTiling(ScheduleRule): For each level of tiles, which thread axis it is bound to. Recommended: - None on CPU - [blockIdx.x, vthread.x, threadIdx.x] on GPU + use_tensor_core : bool + Whether to apply tensor core wmma intrinsic for the computation max_innermost_factor : Optional[int] The maximum size of the innermost factor. None means no limit vector_load_lens : Optional[List[int]] @@ -68,6 +70,7 @@ def __init__( self, structure: str, tile_binds: Optional[List[str]] = None, + use_tensor_core: bool = False, max_innermost_factor: Optional[int] = None, vector_load_lens: Optional[List[int]] = None, reuse_read: Optional[ReuseType] = None, @@ -77,6 +80,7 @@ def __init__( _ffi_api.ScheduleRuleMultiLevelTiling, # type: ignore # pylint: disable=no-member structure, tile_binds, + use_tensor_core, max_innermost_factor, vector_load_lens, reuse_read.as_dict() if reuse_read is not None else None, diff --git a/python/tvm/meta_schedule/testing/relay_workload.py b/python/tvm/meta_schedule/testing/relay_workload.py index 2f1ffdd407fa..bf9287a8eb18 100644 --- a/python/tvm/meta_schedule/testing/relay_workload.py +++ b/python/tvm/meta_schedule/testing/relay_workload.py @@ -18,7 +18,6 @@ from enum import Enum from typing import Dict, Tuple -import tvm.relay.testing # pylint: disable=unused-import from tvm import relay from tvm.ir import IRModule from tvm.runtime import NDArray @@ -34,9 +33,74 @@ class MODEL_TYPE(Enum): # pylint: disable=invalid-name # Specify the type of each model MODEL_TYPES = { + # Image classification models "resnet18": MODEL_TYPE.IMAGE_CLASSIFICATION, + "resnet50": MODEL_TYPE.IMAGE_CLASSIFICATION, + "alexnet": MODEL_TYPE.IMAGE_CLASSIFICATION, + "vgg16": MODEL_TYPE.IMAGE_CLASSIFICATION, + "squeezenet1_0": MODEL_TYPE.IMAGE_CLASSIFICATION, + "densenet121": MODEL_TYPE.IMAGE_CLASSIFICATION, + "densenet161": MODEL_TYPE.IMAGE_CLASSIFICATION, + "densenet169": MODEL_TYPE.IMAGE_CLASSIFICATION, + "densenet201": MODEL_TYPE.IMAGE_CLASSIFICATION, + "inception_v3": MODEL_TYPE.IMAGE_CLASSIFICATION, + "googlenet": MODEL_TYPE.IMAGE_CLASSIFICATION, + "shufflenet_v2_x1_0": MODEL_TYPE.IMAGE_CLASSIFICATION, "mobilenet_v2": MODEL_TYPE.IMAGE_CLASSIFICATION, + "mobilenet_v3_large": MODEL_TYPE.IMAGE_CLASSIFICATION, + "mobilenet_v3_small": MODEL_TYPE.IMAGE_CLASSIFICATION, + "resnext50_32x4d": MODEL_TYPE.IMAGE_CLASSIFICATION, + "wide_resnet50_2": MODEL_TYPE.IMAGE_CLASSIFICATION, + "mnasnet1_0": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b0": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b1": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b2": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b3": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b4": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b5": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b6": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b7": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_400mf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_800mf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_1_6gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_3_2gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_8gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_16gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_32gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_400mf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_800mf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_1_6gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_3_2gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_8gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_16gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_32gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + # Semantic Segmentation models + "fcn_resnet50": MODEL_TYPE.SEGMENTATION, + "fcn_resnet101": MODEL_TYPE.SEGMENTATION, + "deeplabv3_resnet50": MODEL_TYPE.SEGMENTATION, + "deeplabv3_resnet101": MODEL_TYPE.SEGMENTATION, + "deeplabv3_mobilenet_v3_large": MODEL_TYPE.SEGMENTATION, + "lraspp_mobilenet_v3_large": MODEL_TYPE.SEGMENTATION, + # Object detection models + # @Sung: Following networks are not runnable since Torch frontend cannot handle aten::remainder. + # "retinanet_resnet50_fpn", "keypointrcnn_resnet50_fpn", + "fasterrcnn_resnet50_fpn": MODEL_TYPE.OBJECT_DETECTION, + "fasterrcnn_mobilenet_v3_large_fpn": MODEL_TYPE.OBJECT_DETECTION, + "fasterrcnn_mobilenet_v3_large_320_fpn": MODEL_TYPE.OBJECT_DETECTION, + "retinanet_resnet50_fpn": MODEL_TYPE.OBJECT_DETECTION, + "maskrcnn_resnet50_fpn": MODEL_TYPE.OBJECT_DETECTION, + "keypointrcnn_resnet50_fpn": MODEL_TYPE.OBJECT_DETECTION, + "ssd300_vgg16": MODEL_TYPE.OBJECT_DETECTION, + "ssdlite320_mobilenet_v3_large": MODEL_TYPE.OBJECT_DETECTION, + # Video classification + "r3d_18": MODEL_TYPE.VIDEO_CLASSIFICATION, + "mc3_18": MODEL_TYPE.VIDEO_CLASSIFICATION, + "r2plus1d_18": MODEL_TYPE.VIDEO_CLASSIFICATION, + # Text classification + "bert_tiny": MODEL_TYPE.TEXT_CLASSIFICATION, "bert_base": MODEL_TYPE.TEXT_CLASSIFICATION, + "bert_medium": MODEL_TYPE.TEXT_CLASSIFICATION, + "bert_large": MODEL_TYPE.TEXT_CLASSIFICATION, } @@ -73,31 +137,104 @@ def do_trace(model, inp): return model_trace # Load model from torchvision - if MODEL_TYPES[model_name] == MODEL_TYPE.TEXT_CLASSIFICATION: + if MODEL_TYPES[model_name] == MODEL_TYPE.IMAGE_CLASSIFICATION: + model = getattr(models, model_name)() + elif MODEL_TYPES[model_name] == MODEL_TYPE.SEGMENTATION: + model = getattr(models.segmentation, model_name)() + elif MODEL_TYPES[model_name] == MODEL_TYPE.OBJECT_DETECTION: + model = getattr(models.detection, model_name)() + elif MODEL_TYPES[model_name] == MODEL_TYPE.VIDEO_CLASSIFICATION: + model = getattr(models.video, model_name)() + elif MODEL_TYPES[model_name] == MODEL_TYPE.TEXT_CLASSIFICATION: os.environ["TOKENIZERS_PARALLELISM"] = "false" - model = transformers.BertModel( - transformers.BertConfig( + config_dict = { + "bert_tiny": transformers.BertConfig( + num_hidden_layers=6, + hidden_size=512, + intermediate_size=2048, + num_attention_heads=8, + return_dict=False, + ), + "bert_base": transformers.BertConfig( num_hidden_layers=12, hidden_size=768, intermediate_size=3072, num_attention_heads=12, return_dict=False, - ) - ) + ), + "bert_medium": transformers.BertConfig( + num_hidden_layers=12, + hidden_size=1024, + intermediate_size=4096, + num_attention_heads=16, + return_dict=False, + ), + "bert_large": transformers.BertConfig( + num_hidden_layers=24, + hidden_size=1024, + intermediate_size=4096, + num_attention_heads=16, + return_dict=False, + ), + } + configuration = config_dict[model_name] + model = transformers.BertModel(configuration) + A = torch.randint(10000, input_shape) + model.eval() - input_data = torch.randint(10000, input_shape) + scripted_model = torch.jit.trace(model, [A], strict=False) + shape_list = [("input_ids", input_shape)] - scripted_model = torch.jit.trace(model, [input_data], strict=False) - elif MODEL_TYPES[model_name] == MODEL_TYPE.IMAGE_CLASSIFICATION: - model = getattr(models, model_name)() - # Setup input - input_data = torch.randn(input_shape).type(torch.float32) - shape_list = [("input0", input_shape)] - # Get trace. Depending on the model type, wrapper may be necessary. - scripted_model = do_trace(model, input_data) + mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) + return mod, params else: raise ValueError("Unsupported model in Torch model zoo.") + # Setup input + input_data = torch.randn(input_shape).type(torch.float32) + shape_list = [("input0", input_shape)] + + # Get trace. Depending on the model type, wrapper may be necessary. + if MODEL_TYPES[model_name] == MODEL_TYPE.SEGMENTATION: + + class TraceWrapper(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, inp): + out = self.model(inp) + return out["out"] + + wrapped_model = TraceWrapper(model) + wrapped_model.eval() + with torch.no_grad(): + scripted_model = do_trace(wrapped_model, input_data) + + elif MODEL_TYPES[model_name] == MODEL_TYPE.OBJECT_DETECTION: + + def dict_to_tuple(out_dict): + if "masks" in out_dict.keys(): + return out_dict["boxes"], out_dict["scores"], out_dict["labels"], out_dict["masks"] + return out_dict["boxes"], out_dict["scores"], out_dict["labels"] + + class TraceWrapper(torch.nn.Module): # type: ignore + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, inp): + out = self.model(inp) + return dict_to_tuple(out[0]) + + wrapped_model = TraceWrapper(model) + wrapped_model.eval() + with torch.no_grad(): + _ = wrapped_model(input_data) + scripted_model = do_trace(wrapped_model, input_data) + else: + scripted_model = do_trace(model, input_data) + # Convert torch model to relay module mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) return mod, params @@ -110,6 +247,8 @@ def get_network( dtype: str = "float32", ) -> Tuple[IRModule, Dict[str, NDArray], Tuple[int, int, int, int], Tuple[int, int]]: """Get the symbol definition and random weight of a network""" + import tvm.relay.testing # pylint: disable=import-outside-toplevel,unused-import + # meta-schedule prefers NHWC layout if layout == "NHWC": image_shape = (224, 224, 3) diff --git a/python/tvm/meta_schedule/testing/run_ansor.sh b/python/tvm/meta_schedule/testing/run_ansor.sh new file mode 100644 index 000000000000..d5ea9df34485 --- /dev/null +++ b/python/tvm/meta_schedule/testing/run_ansor.sh @@ -0,0 +1,40 @@ +set -euxo pipefail + +RPC_HOST="192.168.6.66" +RPC_PORT="4445" +RPC_KEY="raspi4b-aarch64" +TARGET="raspberry-pi/4b-64" +NUM_TRIALS=800 +LOG_DIR=$HOME/logs/ansor-cpu/ + +mkdir -p $LOG_DIR + +run () { + name=$1 + echo "Running workload $name" + python tests/python/meta_schedule/test_ansor_cpu.py \ + --workload "$name" \ + --target "$TARGET" \ + --rpc-host "$RPC_HOST" \ + --rpc-port "$RPC_PORT" \ + --rpc-key "$RPC_KEY" \ + --num-trials "$NUM_TRIALS" \ + --log-dir $LOG_DIR \ + 2>&1 | tee "$LOG_DIR/$name.log" +} + +# Single op +run C1D +run C2D +run C3D +run CAP +run DEP +run DIL +run GMM +run GRP +run NRM +run T2D +# Subgraph +run C2d-BN-RELU +run TBG + diff --git a/python/tvm/meta_schedule/testing/run_meta_schedule.sh b/python/tvm/meta_schedule/testing/run_meta_schedule.sh new file mode 100644 index 000000000000..fa0c7ca42562 --- /dev/null +++ b/python/tvm/meta_schedule/testing/run_meta_schedule.sh @@ -0,0 +1,38 @@ +# set -euxo pipefail + +RPC_HOST="192.168.6.66" +RPC_PORT="4445" +RPC_KEY="raspi4b-aarch64" +TARGET="raspberry-pi/4b-64" +LOG_DIR=$HOME/logs/ms-cpu/ + +mkdir -p $LOG_DIR + +run () { + name=$1 + echo "Running workload $name" + python tests/python/meta_schedule/test_tune_te_cpu.py \ + --workload "$name" \ + --target "$TARGET" \ + --rpc-host "$RPC_HOST" \ + --rpc-port "$RPC_PORT" \ + --rpc-key "$RPC_KEY" \ + --num-trials 5000 \ + 2>&1 | tee "$LOG_DIR/$name.log" +} + +# Single op +run C1D +run C2D +# run C3D +run CAP +run DEP +run DIL +run GMM +run GRP +# run NRM +run T2D +# Subgraph +run C2d-BN-RELU +run TBG + diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py index b149f20c52e3..93af4febaf09 100644 --- a/python/tvm/meta_schedule/testing/schedule_rule.py +++ b/python/tvm/meta_schedule/testing/schedule_rule.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Default schedule rules""" +from typing import List from tvm.meta_schedule.schedule_rule import ( AddRFactor, AutoInline, @@ -28,6 +29,26 @@ from tvm.target import Target +def get(target: Target) -> List[ScheduleRule]: + """Default schedule rules""" + if target.kind.name == "llvm": + return [ + auto_inline(target), + add_rfactor(target), + multi_level_tiling(target), + parallel_vectorize_unroll(target), + random_compute_location(target), + ] + if target.kind.name == "cuda": + return [ + multi_level_tiling(target), + auto_inline_after_tiling(target), + cross_thread_reduction(target), + parallel_vectorize_unroll(target), + ] + raise NotImplementedError(f"{target.kind.name} is not supported") + + def auto_inline(target: Target) -> ScheduleRule: """Default schedule rules for auto inline""" if target.kind.name == "llvm": @@ -53,6 +74,31 @@ def auto_inline(target: Target) -> ScheduleRule: raise NotImplementedError(f"{target.kind.name} is not supported") +def auto_inline_after_tiling(target: Target) -> ScheduleRule: + """Default schedule rules for auto inline after tiling""" + if target.kind.name == "llvm": + return AutoInline( + into_producer=True, + into_consumer=True, + inline_const_tensor=True, + disallow_if_then_else=True, + require_injective=True, + require_ordered=True, + disallow_op=["tir.exp"], + ) + if target.kind.name == "cuda": + return AutoInline( + into_producer=True, + into_consumer=True, + inline_const_tensor=True, + disallow_if_then_else=False, + require_injective=False, + require_ordered=False, + disallow_op=None, + ) + raise NotImplementedError(f"{target.kind.name} is not supported") + + def add_rfactor(target: Target) -> ScheduleRule: """Default schedule rules for with add_rfactor""" if target.kind.name == "llvm": @@ -109,6 +155,29 @@ def random_compute_location(target: Target) -> ScheduleRule: raise NotImplementedError(f"{target.kind.name} is not supported") +def multi_level_tiling_tensor_core(target: Target) -> ScheduleRule: + """Default schedule rules for with multi-level tiling with tensor core and reuse""" + if target.kind.name == "cuda": + return MultiLevelTiling( + structure="SSSRRSRS", + tile_binds=["blockIdx.x", "blockIdx.y", "threadIdx.y"], + use_tensor_core=True, + max_innermost_factor=64, + vector_load_lens=[1, 2, 3, 4], + reuse_read=ReuseType( + req="must", + levels=[4], + scope="shared", + ), + reuse_write=ReuseType( + req="must", + levels=[3], + scope="local", + ), + ) + raise NotImplementedError(f"{target.kind.name} is not supported") + + def parallel_vectorize_unroll(target: Target) -> ScheduleRule: """Default schedule rules for with parallel-vectorize-unroll""" if target.kind.name == "llvm": @@ -126,3 +195,17 @@ def parallel_vectorize_unroll(target: Target) -> ScheduleRule: unroll_explicit=True, ) raise NotImplementedError(f"{target.kind.name} is not supported") + + +def add_rfactor(target: Target) -> ScheduleRule: + """Default schedule rules for with add_rfactor""" + if target.kind.name == "llvm": + return AddRFactor(max_jobs_per_core=16, max_innermost_factor=64) + raise NotImplementedError(f"{target.kind.name} is not supported") + + +def cross_thread_reduction(target: Target) -> ScheduleRule: + """Default schedule rules for with cross-thread reduction""" + if target.kind.name == "cuda": + return CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512]) + raise NotImplementedError(f"{target.kind.name} is not supported") diff --git a/python/tvm/meta_schedule/testing/space_generation.py b/python/tvm/meta_schedule/testing/space_generation.py index 10e31e7213cb..4abf090ddf95 100644 --- a/python/tvm/meta_schedule/testing/space_generation.py +++ b/python/tvm/meta_schedule/testing/space_generation.py @@ -15,11 +15,31 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring -from typing import List +from typing import List, Union -from tvm.tir import Schedule +from tvm.ir import IRModule +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.space_generator import PostOrderApply +from tvm.target import Target +from tvm.tir import PrimFunc, Schedule from tvm.tir.schedule import Trace +from . import schedule_rule as sch_rule + + +def create_context(mod: Union[IRModule, PrimFunc], target: Target) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + space_generator=PostOrderApply(), + sch_rules=sch_rule.get(target), + task_name="test", + ) + ctx.space_generator.initialize_with_tune_context(ctx) + for rule in ctx.sch_rules: + rule.initialize_with_tune_context(ctx) + return ctx + def check_trace(spaces: List[Schedule], expected: List[List[str]]): expected_traces = {"\n".join(t) for t in expected} @@ -31,3 +51,15 @@ def check_trace(spaces: List[Schedule], expected: List[List[str]]): actual_traces.add(str_trace) assert str_trace in expected_traces, "\n" + str_trace assert len(expected_traces) == len(actual_traces) + + +def debug_print_spaces(spaces: List[Schedule], trace_as_list: bool) -> None: + for i, space in enumerate(spaces): + print(f"##### Space {i}") + print(space.mod.script()) + trace = Trace(space.trace.insts, {}) + trace = trace.simplified(remove_postproc=True) + if trace_as_list: + print(str(trace).strip().splitlines()) + else: + print(trace) diff --git a/python/tvm/meta_schedule/testing/test_ansor_cpu.py b/python/tvm/meta_schedule/testing/test_ansor_cpu.py new file mode 100644 index 000000000000..36e42c2ab636 --- /dev/null +++ b/python/tvm/meta_schedule/testing/test_ansor_cpu.py @@ -0,0 +1,119 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import argparse +import os + +import tvm +from tvm import auto_scheduler +from tvm import meta_schedule as ms +from tvm.meta_schedule.testing.te_workload import CONFIGS + + +def _parse_args(): + args = argparse.ArgumentParser() + args.add_argument( + "--workload", + type=str, + required=True, + ) + args.add_argument( + "--target", + type=str, + required=True, + ) + args.add_argument( + "--num-trials", + type=int, + required=True, + ) + args.add_argument( + "--rpc-host", + type=str, + required=True, + ) + args.add_argument( + "--rpc-port", + type=int, + required=True, + ) + args.add_argument( + "--rpc-key", + type=str, + required=True, + ) + args.add_argument( + "--log-dir", + type=str, + required=True, + ) + parsed = args.parse_args() + parsed.target = tvm.target.Target(parsed.target) + rpc_config = ms.runner.RPCConfig( + tracker_host=parsed.rpc_host, + tracker_port=parsed.rpc_port, + tracker_key=parsed.rpc_key, + session_timeout_sec=60, + ) + parsed.rpc_workers = rpc_config.count_num_servers(allow_missing=False) + return parsed + + +ARGS = _parse_args() + + +def main(): + log_file = os.path.join(ARGS.log_dir, f"{ARGS.workload}.json") + workload_func, params = CONFIGS[ARGS.workload] + params = params[0] + workload_func = auto_scheduler.register_workload(workload_func) + task = auto_scheduler.SearchTask( + func=workload_func, + args=params, + target=ARGS.target, + hardware_params=auto_scheduler.HardwareParams( + num_cores=int(ARGS.target.attrs["num-cores"]), + target=ARGS.target, + ), + ) + runner = auto_scheduler.RPCRunner( + key=ARGS.rpc_key, + host=ARGS.rpc_host, + port=ARGS.rpc_port, + n_parallel=ARGS.rpc_workers, + ) + + # Inspect the computational graph + print("Computational DAG:") + print(task.compute_dag) + tune_option = auto_scheduler.TuningOptions( + num_measure_trials=ARGS.num_trials, + measure_callbacks=[auto_scheduler.RecordToFile(log_file)], + verbose=2, + runner=runner, + ) + print("Running AutoTuning:") + task.tune(tune_option) + print("History Best:") + print(task.print_best(log_file)) + sch, args = task.apply_best(log_file) + print("Lowered TIR:") + print(tvm.lower(sch, args, simple_mode=True)) + + +if __name__ == "__main__": + main() diff --git a/python/tvm/meta_schedule/testing/test_tune_te_cpu.py b/python/tvm/meta_schedule/testing/test_tune_te_cpu.py new file mode 100644 index 000000000000..b48fc4f9a04c --- /dev/null +++ b/python/tvm/meta_schedule/testing/test_tune_te_cpu.py @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import argparse +import logging + +import tvm +from tvm import meta_schedule as ms +from tvm import tir +from tvm.meta_schedule.testing import create_te_workload + + +def _parse_args(): + args = argparse.ArgumentParser() + args.add_argument( + "--workload", + type=str, + required=True, + ) + args.add_argument( + "--target", + type=str, + required=True, + ) + args.add_argument( + "--num-trials", + type=int, + required=True, + ) + args.add_argument( + "--rpc-host", + type=str, + required=True, + ) + args.add_argument( + "--rpc-port", + type=int, + required=True, + ) + args.add_argument( + "--rpc-key", + type=str, + required=True, + ) + parsed = args.parse_args() + parsed.target = tvm.target.Target(parsed.target) + parsed.rpc_config = ms.runner.RPCConfig( + tracker_host=parsed.rpc_host, + tracker_port=parsed.rpc_port, + tracker_key=parsed.rpc_key, + session_timeout_sec=60, + ) + parsed.rpc_workers = parsed.rpc_config.count_num_servers(allow_missing=False) + return parsed + + +logging.basicConfig() +logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) +ARGS = _parse_args() + + +def main(): + runner = ms.runner.RPCRunner( + rpc_config=ARGS.rpc_config, + alloc_repeat=3, + max_workers=ARGS.rpc_workers, + ) + sch: tir.Schedule = ms.tune_tir( + mod=create_te_workload(ARGS.workload, 0), + target=ARGS.target, + config=ms.ReplayTraceConfig( + num_trials_per_iter=64, + num_trials_total=ARGS.num_trials, + ), + runner=runner, + task_name=ARGS.workload, + ) + if sch is None: + print("No valid schedule found!") + else: + print(sch.mod.script()) + print(sch.trace) + + +if __name__ == "__main__": + main() diff --git a/python/tvm/meta_schedule/testing/tir_tensor_intrin.py b/python/tvm/meta_schedule/testing/tir_tensor_intrin.py new file mode 100644 index 000000000000..76f1920c2777 --- /dev/null +++ b/python/tvm/meta_schedule/testing/tir_tensor_intrin.py @@ -0,0 +1,307 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A collection of TIR tensor intrinsics""" +# pylint: disable=missing-function-docstring +import tvm +from tvm import tir +from tvm.script import tir as T + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks +# fmt: off + +@T.prim_func +def tensorcore_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), align=128, offset_factor=1) + B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) + C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) + + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + vk = T.axis.R(16, 0) + for i, j, k in T.grid(16, 16, 16): + with T.block("update"): + vii = T.axis.S(16, vi + i) + vjj = T.axis.S(16, vj + j) + vkk = T.axis.R(16, vk + k) + C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk] + + +@T.prim_func +def tensorcore_impl(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), align=128, offset_factor=1) + B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) + C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) + + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + vk = T.axis.R(16, 0) + T.reads([ + C[vi : vi + 16, vj : vj + 16], + A[vi : vi + 16, vk : vk + 16], + B[vj : vj + 16, vk : vk + 16], + ]) + T.writes(C[vi : vi + 16, vj : vj + 16]) + T.evaluate( + T.tvm_mma_sync( + C.data, + C.elem_offset // 256, + A.data, + A.elem_offset // 256, + B.data, + B.elem_offset // 256, + C.data, + C.elem_offset // 256, + dtype="handle", + ) + ) + + +@T.prim_func +def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (4,)) + B = T.match_buffer(b, (4,)) + C = T.match_buffer(c, (1,)) + + with T.block("root"): + v0 = T.axis.R(4, 0) + for i in range(0, 4): + with T.block("update"): + vi = T.axis.R(4, v0 + i) + C[0] = C[0] + A[vi] * B[vi] + + +@T.prim_func +def dot_product_impl(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (4,)) + B = T.match_buffer(b, (4,)) + C = T.match_buffer(c, (1,)) + + with T.block("root"): + v0 = T.axis.R(4, 0) + T.reads([C[0 : 1], A[v0 : v0 + 4], B[v0 : v0 + 4]]) + T.writes([C[0 : 1]]) + T.evaluate(T.call_extern( # pylint: disable=redundant-keyword-arg + "vec4add", + C.data, C.elem_offset, + A.data, A.elem_offset, + B.data, B.elem_offset, + dtype="int32", + )) + +@T.prim_func +def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=1, scope="wmma.matrix_a") + B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=1, scope="wmma.matrix_b") + C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=1, scope="wmma.accumulator") + + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + vk = T.axis.R(16, 0) + for i, j, k in T.grid(16, 16, 16): + with T.block("update"): + vii = T.axis.S(16, vi + i) + vjj = T.axis.S(16, vj + j) + vkk = T.axis.R(16, vk + k) + C[vii, vjj] = C[vii, vjj] + T.cast(A[vii, vkk], "float32") * T.cast(B[vkk, vjj], + "float32") + + +@T.prim_func +def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a") + B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b") + C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, + scope="wmma.accumulator") + + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + vk = T.axis.R(16, 0) + T.reads([C[vi: vi+16, vj: vj+16], A[vi: vi+16, vk: vk+16], B[vk: vk+16, vj: vj+16]]) + T.writes(C[vi: vi+16, vj: vj+16]) + T.evaluate(T.tvm_mma_sync(C.data, C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), + A.data, A.elem_offset // 256 + T.floordiv(T.floormod(A.elem_offset, 256), 16), + B.data, B.elem_offset // 256 + T.floordiv(T.floormod(B.elem_offset, 256), 16), + C.data, C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), + dtype="handle")) + + +@T.prim_func +def wmma_load_a_desc(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, + scope="shared") + C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, + scope="wmma.matrix_a") + + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + for i, j in T.grid(16, 16): + with T.block("load"): + vii = T.axis.S(16, vi + i) + vjj = T.axis.S(16, vj + j) + C[vii, vjj] = A[vii, vjj] + + +@T.prim_func +def wmma_load_a_impl(a: T.handle, c: T.handle) -> None: + s1 = T.var("int32") + s0 = T.var("int32") + A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared", strides=[s1, s0]) + C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a") + + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + T.reads(A[vi: vi+16, vj: vj+16]) + T.writes(C[vi: vi+16, vj: vj+16]) + T.evaluate(T.tvm_load_matrix_sync( + C.data, 16, 16, 16, C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), A.access_ptr("r"), s1, "row_major", + dtype="handle")) + + +@T.prim_func +def wmma_load_b_desc(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared") + C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b") + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + for i, j in T.grid(16, 16): + with T.block("load"): + vii = T.axis.S(16, vi + i) + vjj = T.axis.S(16, vj + j) + C[vii, vjj] = A[vii, vjj] + + +@T.prim_func +def wmma_load_b_impl(a: T.handle, c: T.handle) -> None: + s1 = T.var("int32") + s0 = T.var("int32") + A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared", strides=[s1, s0]) + C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b") + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + T.reads(A[vi: vi+16, vj: vj+16]) + T.writes(C[vi: vi+16, vj: vj+16]) + T.evaluate(T.tvm_load_matrix_sync( + C.data, 16, 16, 16, C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), A.access_ptr("r"), s1, "row_major", + dtype="handle")) + + +@T.prim_func +def wmma_fill_desc(c: T.handle) -> None: + C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator") + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + for i, j in T.grid(16, 16): + with T.block("init"): + vii = T.axis.S(16, vi + i) + vjj = T.axis.S(16, vj + j) + C[vii, vjj] = T.float32(0) + + +@T.prim_func +def wmma_fill_impl(c: T.handle) -> None: + C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator") + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + T.reads([]) + T.writes(C[vi : vi + 16, vj : vj + 16]) + T.evaluate(T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), T.float32(0), dtype="handle")) + + +@T.prim_func +def wmma_store_desc(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator") + C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, scope="global") + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + for i, j in T.grid(16, 16): + with T.block("store"): + vii = T.axis.S(16, vi + i) + vjj = T.axis.S(16, vj + j) + C[vii, vjj] = A[vii, vjj] + + +@T.prim_func +def wmma_store_impl(a: T.handle, c: T.handle) -> None: + s1 = T.var("int32") + s0 = T.var("int32") + A = T.match_buffer(a, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator") + C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, scope="global", strides=[s1, s0]) + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + T.reads(A[vi: vi + 16, vj: vj + 16]) + T.writes(C[vi: vi+16, vj: vj+16]) + T.evaluate(T.tvm_store_matrix_sync( + A.data, 16, 16, 16, A.elem_offset // 256 + T.floordiv(T.floormod(A.elem_offset, 256), 16), C.access_ptr("w"), s1, "row_major", + dtype="handle")) + + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks + +TENSORCORE_WMMA = tir.TensorIntrin.register( + "test.tensorcore.wmma", + tensorcore_desc, + tensorcore_impl, +) + +NEON_DOT = tir.TensorIntrin.register( + "test.neon.dot", + dot_product_desc, + dot_product_impl, +) + +WMMA_SYNC = tir.TensorIntrin.register( + "wmma_sync", + wmma_sync_desc, + wmma_sync_impl, +) + +WMMA_LOAD_A = tir.TensorIntrin.register( + "wmma_load_a", + wmma_load_a_desc, + wmma_load_a_impl, +) + +WMMA_LOAD_B = tir.TensorIntrin.register( + "wmma_load_b", + wmma_load_b_desc, + wmma_load_b_impl, +) + +WMMA_FILL = tir.TensorIntrin.register( + "wmma_fill", + wmma_fill_desc, + wmma_fill_impl, +) + +WMMA_FILL = tir.TensorIntrin.register( + "wmma_store", + wmma_store_desc, + wmma_store_impl, +) diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py index faf61f5de3e6..6b2aaec1e5cc 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -170,7 +170,6 @@ def _sch_rules() -> List[ScheduleRule]: M.AutoInline( into_producer=True, into_consumer=True, - # into_cache_only=False, inline_const_tensor=True, disallow_if_then_else=False, require_injective=False, @@ -208,7 +207,7 @@ def _mutator_probs() -> Dict[Mutator, float]: ) return { - # M.MutateTileSize(): 0.9, + M.MutateTileSize(): 0.9, M.MutateUnroll(): 0.1, } diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py index fc953771bf21..db3261e7a392 100644 --- a/python/tvm/script/tir/scope_handler.py +++ b/python/tvm/script/tir/scope_handler.py @@ -20,7 +20,7 @@ import synr import tvm.tir -from tvm.runtime import Object +from tvm.runtime import Object, String from tvm.ir import Span, Range from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind @@ -483,8 +483,14 @@ def create_loop_info( """ assert self.context and self.node, "call 'exit_scope' before 'enter_scope'" extent = end if begin == 0 else self.context.analyzer.simplify(end - begin) - self.annotations = annotations - self.loop_info.append(LoopInfo(begin, extent, kind, thread_binding, self.annotations)) + self.annotations: Mapping[str, Object] = {} + if annotations is not None: + self.annotations = { + key: String(val) if isinstance(val, str) else val + for key, val in annotations.items() + } + + self.loop_info.append(LoopInfo(begin, extent, kind, thread_binding, annotations)) @register diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 5854b9369c16..fa91fcb0200b 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -33,7 +33,7 @@ from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize -from .function import PrimFunc, TensorIntrin +from .function import PrimFunc, IndexMap, TensorIntrin from .op import call_packed, call_intrin, call_pure_extern, call_extern from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index bcebab9ddc0a..42bd52930b1a 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -16,7 +16,8 @@ # under the License. """Function data types.""" -from typing import Mapping, Union +from typing import Callable, List, Mapping, Union +import inspect import tvm._ffi import tvm.runtime @@ -210,3 +211,65 @@ def get(name: str): The TensorIntrin with the specified name. """ return _ffi_api.TensorIntrinGet(name) # pylint: type: ignore + + +@tvm._ffi.register_object("tir.IndexMap") +class IndexMap(Object): + """A mapping from multi-dimensional indices to another set of multi-dimensional indices + + Parameters + ---------- + src_iters : list of Var + The source indices + tgt_iters : list of PrimExpr + The target indices + """ + + src_iters: List[Var] + """The source indices""" + + tgt_iters: List[PrimExpr] + """The target indices""" + + def __init__(self, src_iters: List[Var], tgt_iters: List[PrimExpr]): + self._init_handle_by_constructor( + _ffi_api.IndexMap, # type: ignore # pylint: disable=no-member + src_iters, + tgt_iters, + ) + + def apply(self, indices: List[PrimExpr]) -> List[PrimExpr]: + """Apply the index map to a set of indices + + Parameters + ---------- + indices : List[PriExpr] + The indices to be mapped + + Returns + ------- + result : List[PrimExpr] + The mapped indices + """ + return _ffi_api.IndexMapApply(self, indices) # type: ignore # pylint: disable=no-member + + @staticmethod + def from_func(func: Callable) -> "IndexMap": + """Create an index map from a function + + Parameters + ---------- + func : Callable + The function to map from source indices to target indices + """ + + def wrap(args: List[Var]) -> List[PrimExpr]: + result = func(*args) + if isinstance(result, tuple): + return list(result) + if not isinstance(result, list): + result = [result] + return result + + ndim = len(inspect.signature(func).parameters) + return _ffi_api.IndexMapFromFunc(ndim, wrap) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/tir/schedule/__init__.py b/python/tvm/tir/schedule/__init__.py index 5f0e169c43e3..66ac7b9d772b 100644 --- a/python/tvm/tir/schedule/__init__.py +++ b/python/tvm/tir/schedule/__init__.py @@ -22,3 +22,5 @@ from .schedule import BlockRV, ExprRV, LoopRV, Schedule, ScheduleError from .state import ScheduleDebugMask, ScheduleState from .trace import Trace + +from . import analysis diff --git a/python/tvm/tir/schedule/analysis.py b/python/tvm/tir/schedule/analysis.py new file mode 100644 index 000000000000..7c0c77a372f3 --- /dev/null +++ b/python/tvm/tir/schedule/analysis.py @@ -0,0 +1,58 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Analysis used in TensorIR scheduling""" +from typing import List, Optional + +from ..buffer import Buffer +from ..stmt import For +from ..expr import PrimExpr +from ..function import IndexMap + +from . import _ffi_api + + +def suggest_index_map( + buffer: Buffer, + indices: List[PrimExpr], + loops: List[For], + predicate: PrimExpr, +) -> Optional[IndexMap]: + """Provided the access pattern to a buffer, suggest one of the possible layout + transformation to minimize the locality of the access pattern. + + Parameters + ---------- + buffer : Buffer + The buffer to be transformed. + indices : List[PrimExpr] + The access pattern to the buffer. + loops : List[For] + The loops above the buffer. + predicate : PrimExpr + The predicate of the access. + + Returns + ------- + index_map : Optional[IndexMap] + The suggested index map. None if no transformation is suggested. + """ + return _ffi_api.SuggestIndexMap( # type: ignore # pylint: disable=no-member + buffer, + indices, + loops, + predicate, + ) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 96fa21f30020..51cf67f92542 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -15,13 +15,13 @@ # specific language governing permissions and limitations # under the License. """The TensorIR schedule class""" -from typing import Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union from tvm._ffi import register_object as _register_object from tvm.error import TVMError, register_error from tvm.ir import IRModule, PrimExpr from tvm.runtime import Object, String -from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc +from tvm.tir import Block, FloatImm, For, IntImm, IndexMap, PrimFunc from . import _ffi_api from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod @@ -1055,6 +1055,30 @@ def after_cache_write(a: T.handle, b: T.handle) -> None: self, block, write_buffer_index, storage_scope ) + ########## Schedule: Data movement ########## + + def read_at( + self, + loop: LoopRV, + block: BlockRV, + read_buffer_index: int, + storage_scope: str, + ) -> BlockRV: + return _ffi_api.ScheduleReadAt( # type: ignore # pylint: disable=no-member + self, loop, block, read_buffer_index, storage_scope + ) + + def write_at( + self, + loop: LoopRV, + block: BlockRV, + write_buffer_index: int, + storage_scope: str, + ) -> BlockRV: + return _ffi_api.ScheduleWriteAt( # type: ignore # pylint: disable=no-member + self, loop, block, write_buffer_index, storage_scope + ) + ########## Schedule: Compute location ########## @type_checked @@ -2111,6 +2135,82 @@ def after_unannotate(a: T.handle, b: T.handle) -> None: self, block_or_loop, ann_key ) + ########## Schedule: Layout transformation ########## + + def transform_layout( + self, + block: BlockRV, + buffer_index: int, + is_write_index: bool, + index_map: Union[IndexMap, Callable], + ) -> None: + """Apply a transformation represented by IndexMap to buffer + + Parameters + ---------- + block_rv : BlockRV + The block that accesses the target buffer + buffer_index: int + The index of the buffer in block's read or write region + is_write_index : bool + Whether the buffer_index is the index of the block's write region + index_map : Union[IndexMap, Callable] + The transformation to apply + + Examples + -------- + + Before transform_layout, in TensorIR, the IR is: + + .. code-block:: python + + @T.prim_func + def before_transform_layout(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + B = T.alloc_buffer((128, 128), "float32") + C = T.match_buffer(c, (128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + + Create the schedule and do transform_layout: + + .. code-block:: python + + sch = tir.Schedule(before_storage_align) + sch.transform_layout(sch.get_block("B"), buffer_index=0, is_write_index=True, + index_map=lambda m, n: (m // 16, n // 16, m % 16, n % 16)) + print(sch.mod["main"].script()) + + After applying transform_layout, the IR becomes: + + .. code-block:: python + + @T.prim_func + def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + B = T.alloc_buffer((8, 8, 16, 16), "float32") + C = T.match_buffer(c, (128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi // 16, vj // 16, vi % 16, vj % 16] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi // 16, vj // 16, vi % 16, vj % 16] + 1.0 + """ + if callable(index_map): + index_map = IndexMap.from_func(index_map) + _ffi_api.ScheduleTransformLayout( # type: ignore # pylint: disable=no-member + self, block, buffer_index, is_write_index, index_map + ) + ########## Schedule: Misc ########## @type_checked diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index e2bcd6cf795b..b543225569b0 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -635,6 +635,18 @@ def PlanAndUpdateBufferAllocationLocation(): return _ffi_api.PlanAndUpdateBufferAllocationLocation() # type: ignore +def ApplyBlockBoundPredicate(): + """Narrow the extents of some loops by checking whether some constraints in the block iter + bound predicates can be directly applied on the loops. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.ApplyBlockBoundPredicate() # type: ignore + + def ConvertBlocksToOpaque(): """Substitute all the block vars with the PrimExprs they are bound to, indicated by the corresponding iter_values in BlockRealize, and then convert the blocks into @@ -760,3 +772,36 @@ def ConvertForLoopsToSerial(): The result pass """ return _ffi_api.ConvertForLoopsToSerial() # type: ignore + + +def InjectSoftwarePipeline(): + """Transform annotated loops into pipelined one that parallelize producers and consumers + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.InjectSoftwarePipeline() # type: ignore + + +def LowerAutoCopy(): + """Automatically do memory optimizations for auto copy blocks + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LowerAutoCopy() + + +def RenomalizeSplitPattern(): + """Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()) + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.RenormalizeSplitPattern() diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 94316c1e485f..0b9c47833843 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -776,6 +776,17 @@ IntSet EvalSet(PrimExpr e, const std::unordered_map& dom return EvalSet(e, ConvertDomMap(dom_map)); } +Array EvalSet(const Array& exprs, const Map& dom_map) { + Array result; + result.reserve(exprs.size()); + Analyzer ana; + IntervalSetEvaluator m(&ana, dom_map); + for (const PrimExpr& e : exprs) { + result.push_back(m.Eval(e)); + } + return result; +} + IntSet EvalSet(Range r, const Map& dom_map) { Analyzer ana; IntervalSetEvaluator m(&ana, dom_map); diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 6fff2a23ccfe..a4de6592ca13 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -201,8 +201,9 @@ class IterMapRewriter : public ExprMutator { return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr))); } - IterSumExpr RewriteIterConstraint(const PrimExpr& expr, const PrimExpr& predicate_induced_min, - const PrimExpr& predicate_induced_max) { + IterSumExpr RewriteIterConstraint(const PrimExpr& expr, + const Optional& predicate_induced_min, + const Optional& predicate_induced_max) { return NormalizeToIterOnBoundExpr(ToIterSumExpr(DirectMutate(expr)), predicate_induced_min, predicate_induced_max); } @@ -494,14 +495,16 @@ class IterMapRewriter : public ExprMutator { * \param predicate_induced_max Open upper bound from iter constraint, maybe undefined. * \return The Normalized expression. */ - IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, PrimExpr predicate_induced_min, - PrimExpr predicate_induced_max) { + IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, Optional predicate_induced_min, + Optional predicate_induced_max) { // normalize to zero base PrimExpr base = expr->base; if (!is_zero(base)) { expr.CopyOnWrite()->base = 0; - if (predicate_induced_min.defined()) predicate_induced_min = predicate_induced_min - base; - if (predicate_induced_max.defined()) predicate_induced_max = predicate_induced_max - base; + if (predicate_induced_min.defined()) + predicate_induced_min = predicate_induced_min.value() - base; + if (predicate_induced_max.defined()) + predicate_induced_max = predicate_induced_max.value() - base; } if (expr->args.size() < 1) return expr; Optional opt = TryFuseIters(expr); @@ -522,10 +525,10 @@ class IterMapRewriter : public ExprMutator { PrimExpr iter_min = mark_offset; PrimExpr iter_max = iter_min + mark->extent; if (predicate_induced_min.defined()) { - iter_min = max(predicate_induced_min, iter_min); + iter_min = max(predicate_induced_min.value(), iter_min); } if (predicate_induced_max.defined()) { - iter_max = min(predicate_induced_max, iter_max); + iter_max = min(predicate_induced_max.value(), iter_max); } if (!is_zero(iter_min)) { // structured form's offset should be updated @@ -536,7 +539,6 @@ class IterMapRewriter : public ExprMutator { } mark.CopyOnWrite()->extent = iter_max - iter_min; sum_fuse_map_[flattened_form] = {mark, iter_min}; - // we need to note down the flattened form of constrained iterators // to check the validity of constraints, see also CheckConstraints() constrained_iters_flattened_.push_back(flattened_form); @@ -771,14 +773,15 @@ class IterMapRewriter : public ExprMutator { struct IterConstraint { // The expr of the iter PrimExpr iter; - // The expr of the lower_bound - PrimExpr lower_bound; - // The expr of the upper_bound - PrimExpr upper_bound; + // The expr of the lower_bound, maybe undefined + Optional lower_bound; + // The expr of the upper_bound, maybe undefined + Optional upper_bound; // The size of the iter, which is the number of nodes size_t expr_size = 0; - IterConstraint(PrimExpr iter, PrimExpr lower_bound, PrimExpr upper_bound, size_t size) + IterConstraint(PrimExpr iter, Optional lower_bound, Optional upper_bound, + size_t size) : iter(std::move(iter)), lower_bound(std::move(lower_bound)), upper_bound(std::move(upper_bound)), @@ -788,11 +791,11 @@ struct IterConstraint { /*! * \brief Split the predicate into `(a < b) && (c < d) && ...` * \param pred The predicate to be split. + * \param result The result of predicate split. * \return A list of IterConstraint, empty if the split failed. */ -std::vector MatchBoundConstraints(PrimExpr pred, - const Map& input_iters) { - std::vector result; +bool MatchBoundConstraints(PrimExpr pred, Map& input_iters, + std::vector& result) { arith::PVar lhs, rhs, rest; for (;;) { // try extract comparisions @@ -821,14 +824,14 @@ std::vector MatchBoundConstraints(PrimExpr pred, is_equal = true; is_finish = true; } else { - return std::vector(); + return false; } PrimExpr lhs_expr = lhs.Eval(); PrimExpr rhs_expr = rhs.Eval(); // we only accept predicate of integers if (!((lhs_expr->dtype.is_int() || lhs_expr->dtype.is_uint()) && (rhs_expr->dtype.is_int() || rhs_expr->dtype.is_uint()))) { - return std::vector(); + return false; } // determine iter and bound, if we can not distinguish them simply, // try divide (lhs - rhs) into itervar aware and itervar free parts @@ -864,24 +867,25 @@ std::vector MatchBoundConstraints(PrimExpr pred, lhs_expr = analyzer.Simplify(lhs_expr); rhs_expr = analyzer.Simplify(rhs_expr); } - PrimExpr lower_bound, upper_bound, iter; + Optional lower_bound = NullOpt, upper_bound = NullOpt; + PrimExpr iter; if (is_greater) { if (bound_at_left) { - // bound > iter + // bound > iter / bound >= iter upper_bound = is_equal ? lhs_expr + 1 : lhs_expr; iter = rhs_expr; } else { - // iter > bound + // iter > bound / iter >= bound lower_bound = is_equal ? rhs_expr : rhs_expr + 1; iter = lhs_expr; } } else { if (bound_at_left) { - // bound < iter + // bound < iter / bound <= iter lower_bound = is_equal ? lhs_expr : lhs_expr + 1; iter = rhs_expr; } else { - // iter < bound + // iter < bound / iter <= bound upper_bound = is_equal ? rhs_expr + 1 : rhs_expr; iter = lhs_expr; } @@ -892,7 +896,7 @@ std::vector MatchBoundConstraints(PrimExpr pred, } pred = rest.Eval(); } - return result; + return true; } bool IterRangeSanityCheck(const Map& iter_ranges) { @@ -912,8 +916,10 @@ Array DetectIterMap(const Array& indices, const Map(); - std::vector constraints = MatchBoundConstraints(predicate, input_iters); - if (!is_one(predicate) && constraints.empty()) { + Map constrained_input_iters = input_iters; + std::vector constraints; + if (!is_one(predicate) && + !MatchBoundConstraints(predicate, constrained_input_iters, constraints)) { diag_ctx.Emit(Diagnostic::Error(predicate->span) << "Fail to collect constraints from iteration predicate: " << predicate); return Array(); @@ -930,10 +936,11 @@ Array DetectIterMap(const Array& indices, const Map(); } if (!rewriter.CheckConstraints()) { @@ -945,7 +952,10 @@ Array DetectIterMap(const Array& indices, const Map results; for (PrimExpr value : indices) { results.push_back(rewriter.Rewrite(value)); - if (rewriter.unresolved_count() != 0) return Array(); + if (rewriter.unresolved_count() != 0) { + diag_ctx.Emit(Diagnostic::Error(predicate->span) << "Affine mapping detection failed"); + return Array(); + } } // Step1: IterIndependenceChecker checks if the iterator are independent. if (!rewriter.CheckMapping(results, require_bijective)) { @@ -1306,7 +1316,8 @@ class IterMapToExprNormalizer : public ExprMutator { } else if (analyzer_->CanProve(expr->source->extent == expr->lower_factor * expr->extent)) { return floordiv(source, expr->lower_factor) * expr->scale; } else { - return floormod(floordiv(source, expr->lower_factor), expr->extent) * expr->scale; + return floordiv(floormod(source, expr->lower_factor * expr->extent), expr->lower_factor) * + expr->scale; } } diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc index ac176b2623a3..99f90b9be90e 100644 --- a/src/arith/modular_set.cc +++ b/src/arith/modular_set.cc @@ -196,6 +196,18 @@ class ModularSetAnalyzer::Impl : public ExprFunctorb); + if (b.is_const()) { + int64_t c2 = b.base; + ICHECK(c2 != 0) << "MathError: the divisor is 0"; + Entry a = VisitExpr(op->a); + int64_t coeff = ZeroAwareGCD(a.coeff, c2); + return Entry(coeff, a.base % c2); + } + return Everything(); + } + Entry VisitExpr_(const MinNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 4a99e10211b7..84473337a452 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -192,6 +192,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { TVM_TRY_REWRITE(truncdiv(x, c1) * c1 + truncmod(x, c1), x); // floor div TVM_TRY_REWRITE(floordiv(x, c1) * c1 + floormod(x, c1), x); + TVM_TRY_REWRITE_IF(floordiv(floormod(x, c2) + c1, c2) + floordiv(x, c2), floordiv(x + c1, c2), + c2.Eval()->value > 0); // canonicalization rule // will try rewrite again after canonicalization. @@ -771,6 +773,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), x * floordiv(c1, c2) + floordiv(y, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), floordiv(x, floordiv(c2, c1)), + c1.Eval()->value > 0 && c2.Eval()->value > 0 && + c2.Eval()->value % c1.Eval()->value == 0 && + CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0)); + TVM_TRY_REWRITE_IF(floordiv(min(x * c1, y), c2), min(x * floordiv(c1, c2), floordiv(y, c2)), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); @@ -780,6 +787,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE_IF(floordiv(y + x * c1, c2), floordiv(y, c2) + x * floordiv(c1, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF(floordiv(y + x * c1, c2), floordiv(x, floordiv(c2, c1)), + c1.Eval()->value > 0 && c2.Eval()->value > 0 && + c2.Eval()->value % c1.Eval()->value == 0 && + CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0)); + TVM_TRY_REWRITE_IF(floordiv(min(y, x * c1), c2), min(floordiv(y, c2), x * floordiv(c1, c2)), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 5431499d7c9f..ad49e26aa7fc 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -36,6 +36,8 @@ #include #include +#include "../printer/text_printer.h" + namespace tvm { // Register build pipeline related options @@ -189,6 +191,14 @@ transform::Pass Filter(FCond fcond) { return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {}); } +Pass Print() { + auto pass_func = [](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { + LOG(INFO) << tir::AsTVMScript(f); + return f; + }; + return tir::transform::CreatePrimFuncPass(pass_func, 0, "tir.Print", {}); +} + Array CreatePassList(bool disable_loop_partition) { transform::PassContext pass_ctx = transform::PassContext::Current(); @@ -245,10 +255,14 @@ Array CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::LowerCrossThreadReduction()); pass_list.push_back(tir::transform::LowerInitBlock()); pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); + pass_list.push_back(tir::transform::ApplyBlockBoundPredicate()); pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); - pass_list.push_back(tir::transform::UnifyThreadBinding()); pass_list.push_back(tir::transform::CompactBufferAllocation()); + pass_list.push_back(tir::transform::Simplify()); + pass_list.push_back(tir::transform::LowerAutoCopy()); + pass_list.push_back(tir::transform::UnifyThreadBinding()); pass_list.push_back(tir::transform::LowerMatchBuffer()); + pass_list.push_back(tir::transform::InjectSoftwarePipeline()); pass_list.push_back(tir::transform::FlattenBuffer()); pass_list.push_back(tir::transform::BF16Legalize()); pass_list.push_back(tir::transform::NarrowDataType(32)); @@ -268,12 +282,14 @@ Array CreatePassList(bool disable_loop_partition) { if (!disable_storage_rewrite) { pass_list.push_back(tir::transform::StorageRewrite()); } + pass_list.push_back(tir::transform::Simplify()); pass_list.push_back(tir::transform::UnrollLoop()); // Add user-defined phase-2 passes pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end()); // PHASE 3 + pass_list.push_back(tir::transform::RenormalizeSplitPattern()); pass_list.push_back(tir::transform::Simplify()); pass_list.push_back(tir::transform::RemoveNoOp()); pass_list.push_back(tir::transform::RewriteUnsafeSelect()); diff --git a/src/meta_schedule/feature_extractor/per_store_feature.cc b/src/meta_schedule/feature_extractor/per_store_feature.cc index 42cd5d5d5113..6b025c83902b 100644 --- a/src/meta_schedule/feature_extractor/per_store_feature.cc +++ b/src/meta_schedule/feature_extractor/per_store_feature.cc @@ -288,6 +288,7 @@ Sequential PassListForPerStoreFeature() { tir::transform::LowerCrossThreadReduction(), tir::transform::LowerInitBlock(), tir::transform::PlanAndUpdateBufferAllocationLocation(), + tir::transform::ApplyBlockBoundPredicate(), tir::transform::ConvertBlocksToOpaque(), tir::transform::UnifyThreadBinding(), tir::transform::CompactBufferAllocation(), @@ -437,6 +438,7 @@ struct Feature { kPosMiddleReduce = 5, // The annotated iterator is a middle reduce iterator kPosOuterReduce = 6, // The annotated iterator is the outermost reduce iterator kPosMixed = 7, // The annotated iterator is a mixed space and reduce iterator + kEnd = 8, }; int64_t num = 0; // The number of iterators with the annotation int64_t prod = 0; // The product of the lengths of iterators with the annotation diff --git a/src/meta_schedule/measure_callback/update_cost_model.cc b/src/meta_schedule/measure_callback/update_cost_model.cc index 58c86abadfe9..00f6f94eb7d3 100644 --- a/src/meta_schedule/measure_callback/update_cost_model.cc +++ b/src/meta_schedule/measure_callback/update_cost_model.cc @@ -33,7 +33,20 @@ class UpdateCostModelNode : public MeasureCallbackNode { ICHECK(task->measure_candidates.defined()) // << "Task's measure candidates must be present!"; CostModel cost_model = task_scheduler->cost_model.value(); - cost_model->Update(task, task->measure_candidates.value(), runner_results); + ICHECK_EQ(measure_candidates.size(), builder_results.size()); + ICHECK_EQ(runner_results.size(), builder_results.size()); + int n = builder_results.size(); + Array pruned_candidate; + Array pruned_runner_result; + pruned_candidate.reserve(n); + pruned_runner_result.reserve(n); + for (int i = 0; i < n; i++) { + if (!builder_results[i]->error_msg.defined()) { + pruned_candidate.push_back(measure_candidates[i]); + pruned_runner_result.push_back(runner_results[i]); + } + } + cost_model->Update(task, pruned_candidate, pruned_runner_result); } static constexpr const char* _type_key = "meta_schedule.UpdateCostModel"; diff --git a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc index ad8ee9854265..ad94ca5f25dd 100644 --- a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc +++ b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc @@ -64,6 +64,26 @@ Optional ParseAnnotate(const Schedule& sch, const Instruction& inst, in return Downcast(inst->inputs[0]); } +/*! + * \brief Parse instruction: sch.annotate(..., attr::meta_schedule_tensor_core_enabled) + * \param sch The schedule + * \param inst The instruction to be parsed + * \return Whether ths parsing is successful + */ +bool ParseTensorCoreAnn(const Schedule& sch, const Instruction& inst) { + static InstructionKind inst_kind_annotate = InstructionKind::Get("Annotate"); + if (!inst->kind.same_as(inst_kind_annotate)) { + return false; + } + ICHECK_EQ(inst->inputs.size(), 2); + ICHECK_EQ(inst->attrs.size(), 1); + String ann_key = Downcast(inst->attrs[0]); + if (ann_key != attr::meta_schedule_tensor_core_enabled) { + return false; + } + return true; +} + } // namespace tir namespace meta_schedule { @@ -97,6 +117,8 @@ bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) { } else if (Optional new_thread_extent = tir::ParseThreadBinding(sch, inst, "threadIdx.y")) { thread_extent_y = new_thread_extent.value()->value; + } else if (tir::ParseTensorCoreAnn(sch, inst)) { + thread_extent_x = 32; } else if (Optional block_rv = tir::ParseAnnotate(sch, inst, &vector_lane)) { ICHECK_NE(thread_extent_x, -1); if (vector_lane > 1) { @@ -117,6 +139,7 @@ bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) { sch->Vectorize(split[3]); sch->Bind(split[2], "threadIdx.x"); sch->Bind(split[1], "threadIdx.y"); + sch->StorageAlign(block, 0, -2, 32, 8); } }); } else { @@ -132,6 +155,7 @@ bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) { Integer(thread_extent_x)}); sch->Bind(split[2], "threadIdx.x"); sch->Bind(split[1], "threadIdx.y"); + sch->StorageAlign(block, 0, -2, 32, 8); } }); } diff --git a/src/meta_schedule/postproc/rewrite_reduction_block.cc b/src/meta_schedule/postproc/rewrite_reduction_block.cc index cea1f5b93c9f..386894723d65 100644 --- a/src/meta_schedule/postproc/rewrite_reduction_block.cc +++ b/src/meta_schedule/postproc/rewrite_reduction_block.cc @@ -135,6 +135,16 @@ bool RewriteReductionBlockNode::Apply(const tir::Schedule& sch) { tir::BlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name); Array loop_rvs = sch->GetLoops(block_rv); tir::BlockRV init_block_rv = sch->DecomposeReduction(block_rv, loop_rvs[decompose_point]); + // If the block is the isolation block of tensor core, + // we mark the init block for later postprocessor to handle the tensorization step + if (HasAnn(block_sref, tir::attr::meta_schedule_auto_tensorize, "wmma_fill")) { + sch->Unannotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize); + sch->Unannotate(block_rv, tir::attr::meta_schedule_auto_tensorize); + Array init_inner_block_rv = sch->GetChildBlocks(init_block_rv); + ICHECK_EQ(init_inner_block_rv.size(), 1); + sch->Annotate(init_inner_block_rv[0], tir::attr::meta_schedule_auto_tensorize, + String("wmma_fill")); + } ++rewritten; } if (rewritten == 0) { diff --git a/src/meta_schedule/postproc/rewrite_tensor_core.cc b/src/meta_schedule/postproc/rewrite_tensor_core.cc new file mode 100644 index 000000000000..68442dec3082 --- /dev/null +++ b/src/meta_schedule/postproc/rewrite_tensor_core.cc @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +using tir::BlockRV; +using tir::LoopRV; + +using BlockPosition = std::tuple; + +class RewriteTensorCoreNode : public PostprocNode { + public: + // Inherited from PostprocNode + void InitializeWithTuneContext(const TuneContext& context) final {} + + // Inherited from PostprocNode + bool Apply(const tir::Schedule& sch) final; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + static constexpr const char* _type_key = "meta_schedule.RewriteTensorCore"; + TVM_DECLARE_FINAL_OBJECT_INFO(RewriteTensorCoreNode, PostprocNode); +}; + +void CollectTensorized(const tir::Schedule& sch, const String& func_name, + const tir::PrimFuncNode* func, std::vector& tasks) { + tir::PreOrderVisit( + func->body, + [&](const ObjectRef& obj) -> bool { + if (const auto* block = obj.as()) { + tir::StmtSRef block_sref = sch->GetSRef(block); + if (Optional intrin_name = + tir::GetAnn(block_sref, tir::attr::meta_schedule_auto_tensorize)) { + tasks.push_back(std::make_tuple(block_sref->StmtAs()->name_hint, + func_name, intrin_name.value())); + } + } + return true; + }, + /*visit_init_block=*/false); +} + +bool RewriteTensorCoreNode::Apply(const tir::Schedule& sch) { + std::vector tasks; + for (const auto& kv : sch->mod()->functions) { + GlobalVar g_var = kv.first; + BaseFunc base_func = kv.second; + if (const tir::PrimFuncNode* prim_func = base_func.as()) { + CollectTensorized(sch, g_var->name_hint, prim_func, tasks); + } + } + for (const BlockPosition& task : tasks) { + // Retrieve the block rv according to the task noted down before + BlockRV block_rv = sch->GetBlock(std::get<0>(task), std::get<1>(task)); + String intrin_name = std::get<2>(task); + sch->Unannotate(block_rv, tir::attr::meta_schedule_auto_tensorize); + Optional tiled_loop_rv = TilingwithTensorIntrin(sch, block_rv, intrin_name); + if (!tiled_loop_rv.defined()) continue; + sch->Tensorize(tiled_loop_rv.value(), intrin_name); + } + return true; +} + +Postproc Postproc::RewriteTensorCore() { + ObjectPtr n = make_object(); + return Postproc(n); +} + +TVM_REGISTER_NODE_TYPE(RewriteTensorCoreNode); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteTensorCore") + .set_body_typed(Postproc::RewriteTensorCore); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index edf13e36bef4..7b93678b3e94 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -20,6 +20,56 @@ #include "../utils.h" +namespace tvm { +namespace tir { + +class ThreadExtentChecker : private StmtVisitor { + public: + static bool Check(const Stmt& stmt) { + try { + ThreadExtentChecker().VisitStmt(stmt); + return true; + } catch (const dmlc::Error& e) { + return false; + } + } + + private: + void VisitStmt_(const ForNode* loop) { + if (IsThreadIdx(GetThreadScope(loop))) { + if (const int64_t* p_ext = GetLoopIntExtent(loop)) { + thread_extent_product *= *p_ext; + StmtVisitor::VisitStmt_(loop); + thread_extent_product /= *p_ext; + return; + } else { + throw dmlc::Error("Dynamic thread extent"); + } + } + StmtVisitor::VisitStmt_(loop); + } + + void VisitStmt_(const BlockNode* block) { + if (Optional low_inclusive = + GetAnn(block, attr::meta_schedule_thread_extent_low_inclusive)) { + if (Optional high_inclusive = + GetAnn(block, attr::meta_schedule_thread_extent_high_inclusive)) { + int64_t low = low_inclusive.value()->value; + int64_t high = high_inclusive.value()->value; + if (!(low <= thread_extent_product && thread_extent_product <= high)) { + throw dmlc::Error("Thread extent"); + } + } + } + StmtVisitor::VisitStmt_(block); + } + + int64_t thread_extent_product = 1; +}; + +} // namespace tir +} // namespace tvm + namespace tvm { namespace meta_schedule { @@ -66,6 +116,9 @@ class VerifyGPUCodeNode : public PostprocNode { const GlobalVar& g_var = kv.first; const BaseFunc& base_func = kv.second; if (const auto* prim_func = base_func.as()) { + if (!tir::ThreadExtentChecker::Check(prim_func->body)) { + return false; + } IRModule lowered{nullptr}; try { auto pass_list = Array(); @@ -77,10 +130,14 @@ class VerifyGPUCodeNode : public PostprocNode { pass_list.push_back(tir::transform::LowerCrossThreadReduction()); pass_list.push_back(tir::transform::LowerInitBlock()); pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); + pass_list.push_back(tir::transform::ApplyBlockBoundPredicate()); pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); - pass_list.push_back(tir::transform::UnifyThreadBinding()); pass_list.push_back(tir::transform::CompactBufferAllocation()); + pass_list.push_back(tir::transform::Simplify()); + pass_list.push_back(tir::transform::LowerAutoCopy()); + pass_list.push_back(tir::transform::UnifyThreadBinding()); pass_list.push_back(tir::transform::LowerMatchBuffer()); + pass_list.push_back(tir::transform::InjectSoftwarePipeline()); pass_list.push_back(tir::transform::FlattenBuffer()); pass_list.push_back(tir::transform::BF16Legalize()); pass_list.push_back(tir::transform::NarrowDataType(32)); diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index d0bfff40fcbe..a35398591494 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -44,6 +44,67 @@ std::vector GetReadBufferNDims(const StmtSRef& block_sref) { return results; } +Optional TilingwithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv, + const String& intrin_name) { + Optional opt_tensorize_info = GetTensorizeLoopMapping( + sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->desc); + if (!opt_tensorize_info) return NullOpt; + const tir::TensorizeInfoNode* info = opt_tensorize_info.value().get(); + // Construct a mapping from tir loops back to LoopRVs + Map loop2rv; + { + Array loop_rvs = sch->GetLoops(block_rv); + for (const LoopRV& loop_rv : loop_rvs) { + loop2rv.Set(sch->GetSRef(loop_rv), loop_rv); + } + } + // Split the loops + arith::Analyzer analyzer; + std::unordered_set inner_loops; + std::vector reorder_suffix; + reorder_suffix.resize(info->loop_map.size()); + for (const auto& kv : info->loop_map) { + // Extract mapping (block_loop => desc_loop) + const tir::StmtSRef& block_loop_sref = kv.first; + const tir::ForNode* block_loop = block_loop_sref->StmtAs(); + const tir::ForNode* desc_loop = kv.second.get(); + ICHECK(block_loop != nullptr && desc_loop != nullptr); + // Extract the loop extent + PrimExpr block_extent = analyzer.Simplify(block_loop->extent); + PrimExpr desc_extent = analyzer.Simplify(desc_loop->extent); + const auto* int_block_extent = block_extent.as(); + const auto* int_desc_extent = desc_extent.as(); + ICHECK(int_block_extent != nullptr && int_desc_extent != nullptr); + // Check divisibility + int64_t total = int_block_extent->value; + int64_t inner = int_desc_extent->value; + ICHECK_EQ(total % inner, 0); + int64_t outer = int_block_extent->value / int_desc_extent->value; + // Do the split + Array split = sch->Split(loop2rv.at(block_loop_sref), {Integer(outer), Integer(inner)}); + ICHECK_EQ(split.size(), 2); + inner_loops.insert(sch->GetSRef(split[1]).operator->()); + // The inner split will be reordered to the loop domain that is tensorized + int desc_loop_index = info->desc_loop_indexer.at(GetRef(desc_loop)); + reorder_suffix[desc_loop_index] = split[1]; + } + // Reorder the loops + std::vector reorder_list; + bool meet = false; + Array all_loops = sch->GetLoops(block_rv); + for (const LoopRV& loop : all_loops) { + if (inner_loops.count(sch->GetSRef(loop).operator->())) { + meet = true; + } else if (meet) { + reorder_list.push_back(loop); + } + } + reorder_list.insert(reorder_list.end(), reorder_suffix.begin(), reorder_suffix.end()); + sch->Reorder(reorder_list); + ICHECK(!reorder_suffix.empty()); + return reorder_suffix[0]; +} + } // namespace tir } // namespace tvm @@ -113,13 +174,31 @@ struct State { Schedule sch; /*! \brief The block to be tiled */ BlockRV block_rv; + /*! \brief The write cache */ + Optional write_cache; + /*! \brief Indicating if the write cache is generated by cache_write */ + bool write_cache_is_added; /*! \brief The loop tiles */ Array> tiles; + /*! \brief Whether Tensor Core is used for the inner computation */ + bool tensor_core_is_used; + /*! \brief The Tensor Core cache read block A for Tensor Core computation */ + Optional tensor_core_load_A; + /*! \brief The Tensor Core cache read block B for Tensor Core computation */ + Optional tensor_core_load_B; + /*! \brief The Tensor Core cache write block for Tensor Core computation */ + Optional tensor_core_store; /*! \brief Default constructor */ explicit State(Schedule sch, BlockRV block_rv, Optional write_cache = NullOpt, - bool write_cache_is_added = false, Array> tiles = {}) - : sch(sch), block_rv(block_rv), tiles(tiles) {} + bool write_cache_is_added = false, Array> tiles = {}, + bool tensor_core_is_used = false) + : sch(sch), + block_rv(block_rv), + write_cache(write_cache), + write_cache_is_added(write_cache_is_added), + tiles(tiles), + tensor_core_is_used(tensor_core_is_used) {} }; /*! @@ -145,12 +224,67 @@ std::vector SubRule(std::vector states, FLambda sub_rule) { */ class MultiLevelTilingNode : public ScheduleRuleNode { public: + // SubRule 0. detect compute intrin + inline std::vector DetectTensorCore(State state) const; // SubRule 1. add write cache inline std::vector AddWriteReuse(State state) const; // SubRule 2. tile the loop nest inline std::vector TileLoopNest(State state) const; // SubRule 3. add read cache inline std::vector AddReadReuse(State state) const; + // SubRule 4. fuse write cache + inline std::vector FuseWriteReuse(State state) const; + + State TensorCoreLoad(State state) const { + // Add the cache read stage for Tensor Core + state.tensor_core_load_A = state.sch->CacheRead(state.block_rv, 1, "wmma.matrix_a"); + state.tensor_core_load_B = state.sch->CacheRead(state.block_rv, 2, "wmma.matrix_b"); + const Array& r_tiles = state.tiles[r_indices_.back()]; + // Insert cache_read block to the proper place + ICHECK(!r_tiles.empty()) << "ValueError: Cannot find any reduction loop in the block"; + state.sch->ComputeAt(state.tensor_core_load_A.value(), r_tiles.back(), true); + state.sch->ComputeAt(state.tensor_core_load_B.value(), r_tiles.back(), true); + // Annotate the block + state.sch->Annotate(state.tensor_core_load_A.value(), tir::attr::meta_schedule_auto_tensorize, + String("wmma_load_a")); + state.sch->Annotate(state.tensor_core_load_B.value(), tir::attr::meta_schedule_auto_tensorize, + String("wmma_load_b")); + return state; + } + + State TensorCoreStore(State state) const { + // Add the cache read stage for Tensor Core + state.tensor_core_store = state.sch->CacheWrite(state.block_rv, 0, "wmma.accumulator"); + // Annotate the block + state.sch->Annotate(state.tensor_core_store.value(), tir::attr::meta_schedule_auto_tensorize, + String("wmma_store")); + return state; + } + + State TensorCoreStoreFusion(State state, int level) const { + const LoopRV& loop = state.tiles[level].back(); + state.sch->ReverseComputeAt(state.tensor_core_store.value(), loop, true); + return state; + } + + BlockRV GetRootBlockRV(const Schedule& sch, BlockRV block_rv) const { + const tir::StmtSRefNode* block = sch->GetSRef(block_rv).get(); + for (; block->parent != nullptr; block = block->parent) + ; + for (const auto& kv : sch->mod()->functions) { + const GlobalVar& gv = kv.first; + const BaseFunc& base_func = kv.second; + if (const auto* func = base_func.as()) { + const tir::BlockNode* root = func->body.as()->block.get(); + if (root == block->StmtAs()) { + BlockRV root_rv = sch->GetBlock(root->name_hint, gv->name_hint); + return root_rv; + } + } + } + ICHECK(false) << "Ill schedule data structure"; + throw; + } // Do nothing; Inherited from ScheduleRuleNode void InitializeWithTuneContext(const TuneContext& context) final { @@ -172,9 +306,11 @@ class MultiLevelTilingNode : public ScheduleRuleNode { sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure); std::vector states{State(sch, block_rv)}; - states = SubRule(std::move(states), [&](State state) { return TileLoopNest(state); }); + states = SubRule(std::move(states), [&](State state) { return DetectTensorCore(state); }); states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(state); }); + states = SubRule(std::move(states), [&](State state) { return TileLoopNest(state); }); states = SubRule(std::move(states), [&](State state) { return AddReadReuse(state); }); + states = SubRule(std::move(states), [&](State state) { return FuseWriteReuse(state); }); Array results; for (auto&& state : states) { results.push_back(std::move(state.sch)); @@ -191,6 +327,8 @@ class MultiLevelTilingNode : public ScheduleRuleNode { String structure; /*! \brief For each level of tiles, which thread axis it is bound to */ Array tile_binds; + /*! \brief Whether to use Tensor Core */ + bool use_tensor_core; /*! \brief The maximum size of the innermost factor */ int max_innermost_factor; /*! \brief The length of vector lane in vectorized cooperative fetching */ @@ -211,6 +349,7 @@ class MultiLevelTilingNode : public ScheduleRuleNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("structure", &structure); v->Visit("tile_binds", &tile_binds); + v->Visit("use_tensor_core", &use_tensor_core); v->Visit("max_innermost_factor", &max_innermost_factor); // `vector_load_lens` is not visited // `reuse_read_` is not visited @@ -225,45 +364,66 @@ class MultiLevelTilingNode : public ScheduleRuleNode { TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingNode, ScheduleRuleNode); }; +inline std::vector MultiLevelTilingNode::DetectTensorCore(State state) const { + std::vector result; + // If Tensor Core is not allowed, we skip this subrule + if (!use_tensor_core) return {state}; + // Do tiling to match Tensor Core wmma sync intrin + BlockRV block_rv = state.block_rv; + Optional tiled_loop_rv = TilingwithTensorIntrin(state.sch, block_rv, "wmma_sync"); + if (!tiled_loop_rv.defined()) return {state}; + // Do blockize + state.block_rv = state.sch->Blockize(tiled_loop_rv.value()); + // Annotate the block + state.sch->Annotate(block_rv, tir::attr::meta_schedule_auto_tensorize, String("wmma_sync")); + state.sch->Annotate(state.block_rv, tir::attr::meta_schedule_auto_tensorize, String("wmma_fill")); + state.tensor_core_is_used = true; + // Annotate the root block to notify the following postprocessors + state.sch->Annotate(GetRootBlockRV(state.sch, state.block_rv), + tir::attr::meta_schedule_tensor_core_enabled, String("1")); + result.push_back(state); + return result; +} + inline std::vector MultiLevelTilingNode::AddWriteReuse(State state) const { const ReuseConfig& config = this->reuse_write_; if (config.req == ReuseType::kNoReuse) { + if (state.tensor_core_is_used) state = TensorCoreStore(state); return {std::move(state)}; } - std::vector results; + // Case 1. If the write cache is already there, we don't need to add another. if (config.req == ReuseType::kMayReuse) { - // Case 1. If the write cache is already there, we don't need to add another. Array consumer_rvs = state.sch->GetConsumers(state.block_rv); if (consumer_rvs.size() == 1 && IsWriteCache(state.sch->GetSRef(consumer_rvs[0]))) { - for (int level : config.levels) { - State new_state = state; - new_state.sch = state.sch->Copy(); - new_state.sch->Seed(state.sch->ForkSeed()); - const LoopRV& loop_rv = new_state.tiles[level - 1].back(); - new_state.sch->ReverseComputeAt(consumer_rvs[0], loop_rv, true); - results.push_back(std::move(new_state)); - } - results.push_back(state); - return results; - } else { - // Case 2. No write cache is added - State new_state(/*sch=*/state.sch->Copy(), /*block_rv=*/state.block_rv); - new_state.sch->Seed(state.sch->ForkSeed()); - results.emplace_back(std::move(new_state)); + state.write_cache = consumer_rvs[0]; + state.write_cache_is_added = false; + if (state.tensor_core_is_used) state = TensorCoreStore(state); + return {std::move(state)}; } } - + std::vector results; + results.reserve(2); + // Case 2. No write cache is added + if (config.req == ReuseType::kMayReuse) { + State new_state(/*sch=*/state.sch->Copy(), /*block_rv=*/state.block_rv, + /*write_cache=*/NullOpt, + /*write_cache_is_added=*/false); + new_state.sch->Seed(state.sch->ForkSeed()); + if (new_state.tensor_core_is_used) new_state = TensorCoreStore(new_state); + results.emplace_back(std::move(new_state)); + } // Case 3. Add one write cache BlockRV write_cache = state.sch->CacheWrite(/*block_rv=*/state.block_rv, /*read_buffer_index=*/0, /*storage_scope=*/config.scope); - for (int level : config.levels) { - State new_state = state; - new_state.sch = state.sch->Copy(); - new_state.sch->Seed(state.sch->ForkSeed()); - const LoopRV& loop_rv = new_state.tiles[level - 1].back(); - new_state.sch->ReverseComputeAt(write_cache, loop_rv, true); - results.push_back(std::move(new_state)); + state.write_cache = write_cache; + { + tir::Annotate(state.sch->state(), state.sch->GetSRef(write_cache), // + tir::attr::meta_schedule_cache_type, // + Integer(tir::attr::meta_schedule_cache_type_write)); } + state.write_cache_is_added = true; + if (state.tensor_core_is_used) state = TensorCoreStore(state); + results.emplace_back(std::move(state)); return results; } @@ -334,6 +494,7 @@ inline std::vector MultiLevelTilingNode::TileLoopNest(State state) const inline std::vector MultiLevelTilingNode::AddReadReuse(State state) const { const ReuseConfig& config = this->reuse_read_; if (config.req == ReuseType::kNoReuse) { + if (state.tensor_core_is_used) state = TensorCoreLoad(state); return {std::move(state)}; } ICHECK(config.req != ReuseType::kMayReuse); @@ -353,6 +514,11 @@ inline std::vector MultiLevelTilingNode::AddReadReuse(State state) const } // Do cache_read BlockRV cache_read_block = sch->CacheRead(block_rv, i, config.scope); + { + tir::Annotate(sch->state(), sch->GetSRef(cache_read_block), // + tir::attr::meta_schedule_cache_type, + Integer(tir::attr::meta_schedule_cache_type_read)); + } // Insert cache_read block to the proper place sch->ComputeAt(cache_read_block, loop_rv, true); // Fuse the iterators of the cache_read @@ -372,6 +538,40 @@ inline std::vector MultiLevelTilingNode::AddReadReuse(State state) const } State new_state = state; new_state.sch = sch; + if (new_state.tensor_core_is_used) new_state = TensorCoreLoad(new_state); + results.push_back(std::move(new_state)); + } + return results; +} + +inline std::vector MultiLevelTilingNode::FuseWriteReuse(State state) const { + const ReuseConfig& config = this->reuse_write_; + if (config.req == ReuseType::kNoReuse) { + if (state.tensor_core_is_used) state = TensorCoreStoreFusion(state, r_indices_.front() - 1); + return {std::move(state)}; + } + // If the only-consumer does not exist, or is not elementwise, then do not do fusion + if (!state.write_cache.defined()) { + if (state.tensor_core_is_used) state = TensorCoreStoreFusion(state, r_indices_.front() - 1); + return {std::move(state)}; + } + std::vector results; + // Special case. + // Stages added by `cache_write` must be fused at some level, otherwise it has no benefit. + // On the other hand, If the consumer stage is not added by `cache_write`, + // we may choose not to fuse by setting `must_cache_write = False` + if (!state.write_cache_is_added && config.req != ReuseType::kMustReuse) { + results.push_back(state); + } + BlockRV consumer = state.write_cache.value(); + // Enumerate the level of tile to be fused at + for (int level : config.levels) { + State new_state = state; + new_state.sch = state.sch->Copy(); + new_state.sch->Seed(state.sch->ForkSeed()); + const LoopRV& loop_rv = new_state.tiles[level - 1].back(); + if (new_state.tensor_core_is_used) new_state = TensorCoreStoreFusion(new_state, level - 1); + new_state.sch->ReverseComputeAt(consumer, loop_rv, true); results.push_back(std::move(new_state)); } return results; @@ -380,6 +580,7 @@ inline std::vector MultiLevelTilingNode::AddReadReuse(State state) const // Constructor ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional> tile_binds, + bool use_tensor_core, Optional max_innermost_factor, Optional> vector_load_lens, Optional> reuse_read, @@ -387,6 +588,15 @@ ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional n = make_object(); n->structure = structure; n->tile_binds = tile_binds.value_or({}); + n->use_tensor_core = use_tensor_core; + if (use_tensor_core) { + // Check whether corresponding wmma intrinsics are registered + tir::TensorIntrin::Get("wmma_sync"); + tir::TensorIntrin::Get("wmma_load_a"); + tir::TensorIntrin::Get("wmma_load_b"); + tir::TensorIntrin::Get("wmma_store"); + tir::TensorIntrin::Get("wmma_fill"); + } n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value; n->vector_load_lens = vector_load_lens.defined() ? support::AsVector(vector_load_lens.value()) diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index bc616327eb3b..e9a5f268ec2d 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -86,29 +86,73 @@ class PostOrderApplyNode : public SpaceGeneratorNode { // `sch_rules_` is not visited } - void InitializeWithTuneContext(const TuneContext& context) final { - this->rand_state_ = ForkSeed(&context->rand_state); - CHECK(context->sch_rules.defined()) + void InitializeWithTuneContext(const TuneContext& tune_context) final { + this->rand_state_ = ForkSeed(&tune_context->rand_state); + CHECK(tune_context->sch_rules.defined()) << "ValueError: Schedules rules not given in PostOrderApply!"; - this->sch_rules_ = context->sch_rules; + this->sch_rules_ = tune_context->sch_rules; } Array GenerateDesignSpace(const IRModule& mod_) final { using ScheduleAndUnvisitedBlocks = std::pair>; - tir::Schedule sch = tir::Schedule::Traced( // - /*mod=*/mod_, // - /*rand_state=*/ForkSeed(&this->rand_state_), // - /*debug_mode=*/tir::kVerifySRefTree | tir::kVerifyCachedFlags, // + tir::Schedule sch = tir::Schedule::Traced( // + /*mod=*/mod_, // + /*rand_state=*/ForkSeed(&this->rand_state_), // + /*debug_mode=*/0, // tir::kVerifySRefTree | tir::kVerifyCachedFlags /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); std::vector stack; - Array result{sch}; + Array result; + Array all_blocks = BlockCollector::Collect(sch), func_blocks, non_func_blocks; + for (const tir::BlockRV& block_rv : all_blocks) { + if (Optional custom_sch_rule_name_opt = + tir::GetAnn(sch->GetSRef(block_rv), "schedule_rule")) { + if (custom_sch_rule_name_opt.value() != "None") { + func_blocks.push_back(block_rv); + } + } else { + non_func_blocks.push_back(block_rv); + } + } + + // only do this once for schedule rules on block annotations + stack.emplace_back(sch, func_blocks); + while (!stack.empty()) { + // get the stack.top() + tir::Schedule sch; + Array blocks; + std::tie(sch, blocks) = stack.back(); + stack.pop_back(); + // if all blocks are visited + if (blocks.empty()) { + result.push_back(sch); + continue; + } + // otherwise, get the last block that is not visited + tir::BlockRV block_rv = blocks.back(); + blocks.pop_back(); + if (sch->HasBlock(block_rv)) { + // pick out the blocks with annotation for customized search space + Optional custom_sch_rule_name_opt = + tir::GetAnn(sch->GetSRef(block_rv), "schedule_rule"); + ICHECK(custom_sch_rule_name_opt.defined() && custom_sch_rule_name_opt.value() != "None"); + String custom_sch_rule_name = custom_sch_rule_name_opt.value(); + const auto* custom_sch_rule_func = runtime::Registry::Get(custom_sch_rule_name); + CHECK(custom_sch_rule_func) << "The given custom schedule function is not defined!"; + Array applied = (*custom_sch_rule_func)(sch, block_rv); + for (const tir::Schedule& sch : applied) { + stack.emplace_back(sch, blocks); + } + } else { + stack.emplace_back(sch, blocks); + } + } + // Enumerate the schedule rules first because you can // always concat multiple schedule rules as one - Array all_blocks = BlockCollector::Collect(sch); for (ScheduleRule sch_rule : sch_rules_) { for (const tir::Schedule& sch : result) { - stack.emplace_back(sch, all_blocks); + stack.emplace_back(sch, non_func_blocks); } result.clear(); diff --git a/src/support/nd_int_set.h b/src/support/nd_int_set.h index ae4a0386d404..46bbd2bceb9a 100644 --- a/src/support/nd_int_set.h +++ b/src/support/nd_int_set.h @@ -144,6 +144,29 @@ inline NDIntSet NDIntSetEval( return ret; } +/*! + * \brief Output the N-dimensional integer set to a stream. + * \param os The output stream. + * \param nd_int_set The N-dimensional integer set to be output. + * \return The output stream. + */ +inline std::ostream& operator<<(std::ostream& os, const NDIntSet& nd_int_set) { + os << '['; + bool is_first = true; + for (const arith::IntSet& int_set : nd_int_set) { + if (is_first) { + is_first = false; + } else { + os << ", "; + } + PrimExpr min = int_set.min(); + PrimExpr max = int_set.max(); + os << min << ":" << max; + } + os << ']'; + return os; +} + } // namespace support } // namespace tvm diff --git a/src/target/tag.cc b/src/target/tag.cc index a931a288924e..39f8f37aff2b 100644 --- a/src/target/tag.cc +++ b/src/target/tag.cc @@ -70,6 +70,30 @@ Target TargetTag::AddTag(String name, Map config, bool overri /********** Register Target tags **********/ +TVM_REGISTER_TARGET_TAG("raspberry-pi/4b-64") + .set_config({{"kind", String("llvm")}, + {"mtriple", String("aarch64-linux-gnu")}, + {"mcpu", String("cortex-a72")}, + {"mattr", Array{"+neon"}}, + {"num-cores", Integer(4)}, + {"host", Map{{"kind", String("llvm")}, + {"mtriple", String("aarch64-linux-gnu")}, + {"mcpu", String("cortex-a72")}, + {"mattr", Array{"+neon"}}, + {"num-cores", Integer(4)}}}}); + +TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-xavier") + .set_config({{"kind", String("cuda")}, + {"arch", String("sm_72")}, + {"shared_memory_per_block", Integer(49152)}, + {"registers_per_block", Integer(65536)}, + {"max_threads_per_block", Integer(1024)}, + {"thread_warp_size", Integer(32)}, + {"host", Map{{"kind", String("llvm")}, + {"mtriple", String("aarch64-linux-gnu")}, + {"mcpu", String("carmel")}, + {"num-cores", Integer(4)}}}}); + #define TVM_REGISTER_CUDA_TAG(Name, Arch, SharedMem, RegPerBlock) \ TVM_REGISTER_TARGET_TAG(Name).set_config({ \ {"kind", String("cuda")}, \ @@ -318,7 +342,6 @@ TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-415m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-480m", "sm_20", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/geforce-710m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/geforce-410m", "sm_21", 49152, 32768); -TVM_REGISTER_CUDA_TAG("nvidia/jetson-agx-xavier", "sm_72", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/jetson-nano", "sm_53", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/jetson-tx2", "sm_62", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/jetson-tx1", "sm_53", 49152, 32768); diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index dc1ed1c193e8..d01788e92c4c 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -92,7 +92,7 @@ class GPUCodeVerifier : public StmtExprVisitor { const auto* extent = op->value.as(); ICHECK(extent); - std::string name = var.get()->name_hint; + std::string name = op->node.as()->thread_tag; // record the number of threads in a block if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z" || name == "vthread") { @@ -151,6 +151,7 @@ class GPUCodeVerifier : public StmtExprVisitor { errors_.push_back(s.str()); } }; + err("threads per block", thread_per_block_, max_threads_per_block_); err("local memory per block", local_memory_per_block_, max_local_memory_per_block_); err("shared memory per block", shared_memory_per_block_, max_shared_memory_per_block_); diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index 1c34e34468b5..77c0ffd4b8df 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -21,9 +21,13 @@ * \file src/tir/ir/function.cc * \brief The function data structure. */ +#include #include #include #include +#include + +#include "../../support/nd_int_set.h" namespace tvm { namespace tir { @@ -64,6 +68,135 @@ FuncType PrimFuncNode::func_type_annotation() const { TVM_REGISTER_NODE_TYPE(PrimFuncNode); +Array IndexMapNode::Apply(const Array& inputs) const { + CHECK_EQ(inputs.size(), this->src_iters.size()); + arith::Analyzer analyzer; + int n = inputs.size(); + for (int i = 0; i < n; ++i) { + analyzer.Bind(this->src_iters[i], inputs[i]); + } + Array results; + results.reserve(this->tgt_iters.size()); + for (PrimExpr result : this->tgt_iters) { + results.push_back(analyzer.Simplify(std::move(result))); + } + return results; +} + +Array IndexMapNode::MapShape(const Array& shape) const { + using namespace support; + Array indices; + std::unordered_map dom_map; + for (const PrimExpr dim : shape) { + Var var; + indices.push_back(var); + dom_map.emplace(var.get(), arith::IntSet::FromMinExtent(0, dim)); + } + Array mapped_indices = Apply(indices); + NDIntSet nd_int_set = NDIntSetFromPoint(mapped_indices); + nd_int_set = NDIntSetEval(nd_int_set, dom_map); + Array new_shape; + for (const auto& int_set : nd_int_set) { + ICHECK(is_zero(int_set.min())); + new_shape.push_back(int_set.max() + 1); + } + auto fmul = [](PrimExpr a, PrimExpr b, Span span) { return a * b; }; + PrimExpr old_size = foldl(fmul, Integer(1), shape); + PrimExpr new_size = foldl(fmul, Integer(1), new_shape); + + arith::Analyzer analyzer; + CHECK(analyzer.CanProveEqual(old_size, new_size)) + << "ValueError: The size of the new shape after IndexMap " << new_shape + << " doesn't match the size of the original shape " << shape; + return new_shape; +} + +String IndexMapNode::ToPythonString() const { + std::unordered_set used_names; + Map var_remap; + for (const Var& src_iter : src_iters) { + if (used_names.count(src_iter->name_hint)) { + std::string new_name = src_iter->name_hint + std::to_string(used_names.size()); + used_names.insert(new_name); + var_remap.Set(src_iter, Var(new_name)); + } else { + used_names.insert(src_iter->name_hint); + } + } + std::ostringstream oss; + oss << "lambda "; + for (size_t i = 0; i < src_iters.size(); ++i) { + if (i != 0) { + oss << ", "; + } + auto it = var_remap.find(src_iters[i]); + if (it != var_remap.end()) { + oss << (*it).second; + } else { + oss << src_iters[i]; + } + } + oss << ": ("; + for (size_t i = 0; i < tgt_iters.size(); ++i) { + if (i != 0) { + oss << ", "; + } + oss << Substitute(tgt_iters[i], var_remap); + } + if (tgt_iters.size() == 1) { + oss << ","; + } + oss << ")"; + return String(oss.str()); +} + +IndexMap::IndexMap(Array src_iters, Array tgt_iters) { + ObjectPtr n = make_object(); + n->src_iters = std::move(src_iters); + n->tgt_iters = std::move(tgt_iters); + data_ = std::move(n); +} + +IndexMap IndexMap::FromFunc(int ndim, runtime::TypedPackedFunc(Array)> func) { + Array src_iters; + src_iters.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + src_iters.push_back(Var("i" + std::to_string(i), DataType::Int(32))); + } + return IndexMap(src_iters, func(src_iters)); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + const auto* n = node.as(); + ICHECK(n); + p->stream << "IndexMap: ("; + for (int i = 0, total = n->src_iters.size(); i < total; ++i) { + if (i != 0) { + p->stream << ", "; + } + p->stream << n->src_iters[i]; + } + p->stream << ") => "; + p->stream << "("; + for (int i = 0, total = n->tgt_iters.size(); i < total; ++i) { + if (i != 0) { + p->stream << ", "; + } + p->stream << n->tgt_iters[i]; + } + p->stream << ")"; + }); + +TVM_REGISTER_NODE_TYPE(IndexMapNode); +TVM_REGISTER_GLOBAL("tir.IndexMap") + .set_body_typed([](Array src_iters, Array tgt_iters) { + return IndexMap(src_iters, tgt_iters); + }); +TVM_REGISTER_GLOBAL("tir.IndexMapFromFunc").set_body_typed(IndexMap::FromFunc); +TVM_REGISTER_GLOBAL("tir.IndexMapApply").set_body_method(&IndexMapNode::Apply); + + class TensorIntrinManager { public: Map reg; diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index d60ec72a7589..d3342fafdc06 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -700,10 +700,11 @@ Array Substitute(const Array& region, const Map& vm } void PreOrderVisit(const ObjectRef& stmt_or_expr, - const std::function& fvisit) { + const std::function& fvisit, bool visit_init_block) { class PreOrderVisitor : public StmtExprVisitor { public: - explicit PreOrderVisitor(const std::function& f) : f_(f) {} + explicit PreOrderVisitor(const std::function& f, bool visit_init_block) + : f_(f), visit_init_block_(visit_init_block) {} private: void VisitExpr(const PrimExpr& expr) final { @@ -726,11 +727,35 @@ void PreOrderVisit(const ObjectRef& stmt_or_expr, } } + void VisitStmt_(const BlockNode* op) final { + auto fvisit_buffer_region = [this](const BufferRegion& s) { + for (const auto& range : s->region) { + this->VisitExpr(range->min); + this->VisitExpr(range->extent); + } + }; + VisitArray(op->iter_vars, [this](const IterVar& iter_var) { + this->VisitExpr(iter_var->dom->min); + this->VisitExpr(iter_var->dom->extent); + }); + VisitArray(op->reads, fvisit_buffer_region); + VisitArray(op->writes, fvisit_buffer_region); + VisitArray(op->match_buffers, + [fvisit_buffer_region](const MatchBufferRegion& match_buffer_region) { + fvisit_buffer_region(match_buffer_region->source); + }); + if (visit_init_block_ && op->init.defined()) { + this->VisitStmt(op->init.value()); + } + this->VisitStmt(op->body); + } + const std::function& f_; + bool visit_init_block_; std::unordered_set visited_; }; - PreOrderVisitor visitor(fvisit); + PreOrderVisitor visitor(fvisit, visit_init_block); if (const auto* stmt = stmt_or_expr.as()) { visitor(GetRef(stmt)); } else if (const auto* expr = stmt_or_expr.as()) { diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 9c6d1e6e96da..92bd6bd4bf99 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -339,6 +339,47 @@ bool HasBeenMultiLevelTiled(const StmtSRef& block_sref); std::pair, std::vector> CollectComputeLocation(const ScheduleState& self, const StmtSRef& block_sref); +/******** Tensorization ********/ + +/*! \brief Necessary information used for tensorization */ +class TensorizeInfoNode : public Object { + public: + /*! \brief Maps block loops to desc loops */ + Map loop_map; + /*! \brief Maps loops in desc to its index, outer to inner */ + Map desc_loop_indexer; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("loop_map", &loop_map); + v->Visit("desc_loop_indexer", &desc_loop_indexer); + } + + static constexpr const char* _type_key = "tir.analysis.TensorizeInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(TensorizeInfoNode, Object); +}; + +/*! + * \brief Managed reference to TensorizeInfoNode + * \sa TensorizeInfoNode + */ +class TensorizeInfo : public ObjectRef { + public: + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorizeInfo, ObjectRef, TensorizeInfoNode); +}; + +/*! + * \brief Check if the given block can be tensorized, and in the meantime gather the necessary + * information for tensorization + * \param self The schedule state + * \param block_sref The block to be analyzed + * \param desc_func The target function for tensorization + * \return The necessary information used for tensorization, or NullOpt if the block cannot be + * tensorized + */ +TVM_DLL Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, + const tir::StmtSRef& block_sref, + const tir::PrimFunc& desc_func); + /******** Producer-consumer relation ********/ /*! @@ -401,6 +442,16 @@ struct ProducerConsumerSplit { */ Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, bool is_write); +/*! + * \brief Find the defining site of the buffer in the given block and its ancestors + * \param block_sref The block sref + * \param buffer The buffer + * \return The defining site of the buffer and whether the buffer is allocated (otherwise the + * buffer is from match_buffer). + */ +std::pair, bool> GetBufferDefiningSite(const StmtSRef& block_sref, + const Buffer& buffer); + /******** Reduction Block Related ********/ /*! @@ -468,6 +519,80 @@ bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner, /******** Misc ********/ +/*! + * \brief Given the read/write region, extract the pattern of their index correspondence + * namely, the mapping from read index to the write index. + * \param read_region The read region + * \param write_region The write region + * \return A tuple of booleans, the extracted pattern + * 0) exists: if the pattern is found + * 1) surjective: if the pattern is surjective, i.e. each write index is mapped at least once + * e.g. A[i, j] = B[i, i, j] + * 2) injective: if the pattern is injective, i.e. each write index is mapped at most once. + * e.g. A[i, j] = B[i] + * 3) ordered: if the mapping is ordered + * 4) no_const_read: if there is no constant indexing in the read indices, + * e.g. A[i, j] = B[0, i, j] + * 5) no_shift_read: if there is no constant shift in the read indices, + * e.g. A[i, j] = B[i + 1, j] + */ +std::tuple +AnalyzeReadWritePattern(const BufferRegion& read_region, const BufferRegion& write_region); + +/*! + * \brief Checks if the given block has data reuse opportunity and thus multi-level tiling is + * beneficial. + * \param self The schedule state + * \param block_sref The block to be checked + * \return A boolean indicating whether the block has data reuse opportunity + */ +bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref); + +/*! + * \brief Checks if the given block has been applied by multi-level tiling. We check this by examine + * the block's annotation. + * \param block_sref The block to be checked + * \return A boolean indicating whether the block has been multi-level tiled. + */ +bool HasBeenMultiLevelTiled(const StmtSRef& block_sref); + +/*! + * \brief Checks if the rfactor or cross thread reduction is beneficial to the given block. + * \param self The schedule state. + * \param block_sref The block to be checked. + * \param max_parallel_extent The maximum parallel jobs on the target. + * \param max_parallel_extent The maximum cores on the target. + * \return A boolean indicating whether the operation is beneficial. + */ +bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // + const tir::StmtSRef& block_sref, // + int64_t max_parallel_extent, // + int64_t max_parallel_basic); + +/*! + * \brief Checks if the given AST contains the specific operators + * \param stmt The AST to be checked + * \param ops The list of operators to be checked + * \return A boolean indicating whether the AST contains the specific operators + */ +bool HasOp(const Stmt& stmt, const Array& ops); + +/*! + * \brief Checks if the given AST contains if-then-else, including + * 1) IfThenElse statement + * 2) Select expression + * 3) The operator `tir.if_then_else` + * 4) Block predicates + */ +bool HasIfThenElse(const Stmt& stmt); + +/******** Storage Scope ********/ + /*! * \brief Check whether the input storage scope string is valid. Throw an error if not. * \param self The schedule state @@ -515,6 +640,19 @@ bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const S bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref, bool preserve_unit_loops); +/*! + * \brief Provided the access pattern to a buffer, suggest one of the possible layout + * transformation to minimize the locality of the access pattern. + * \param buffer The buffer to be transformed + * \param indices The access pattern to the buffer + * \param loops The loops above the buffer + * \param predicate The predicate of the access + * \param analyzer Arithmetic analyzer + */ +Optional SuggestIndexMap(const Buffer& buffer, const Array& indices, + const Array& loops, const PrimExpr& predicate, + arith::Analyzer* analyzer); + /*! * \brief Checks if the given AST contains the specific operators * \param stmt The AST statement to be checked diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index c7ed67187793..bdb4295e900b 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -16,6 +16,9 @@ * specific language governing permissions and limitations * under the License. */ +#include + +#include "../ir_comparator.h" #include "../utils.h" namespace tvm { @@ -856,6 +859,179 @@ std::pair, std::vector> CollectComputeLocation(const Schedu return std::make_pair(location_srefs, location_indices); } +/******** Tensorization ********/ + +class AutoTensorizeComparator : public tir::TensorizeComparator { + public: + AutoTensorizeComparator(IRModule lhs_mod) : tir::TensorizeComparator(lhs_mod, false) {} + + bool VisitStmt(const tir::Stmt& n, const tir::Stmt& rhs) override { + if (n.same_as(rhs)) return true; + tir::Stmt lhs = n; + if (lhs->type_index() != rhs->type_index()) { + return false; + } + bool equal = tir::StmtComparator::VisitStmt(lhs, rhs); + ICHECK(equal || !assert_mode_) << "Statements are not matching between:\n" + << n << "\nand\n" + << rhs; + return equal; + } + + bool CompareBuffer(const tir::Buffer& lhs, const tir::Buffer& rhs) override { + if (lhs.same_as(rhs)) return true; + auto it = rhs_buffer_map_.find(rhs); + bool equal; + if (it != rhs_buffer_map_.end()) { + equal = (*it).second.same_as(lhs); + } else { + // Remap both buffer itself and buffer data, skip buffer shape and scope + equal = DefEqual(lhs->data, rhs->data) && lhs->dtype == rhs->dtype; + if (equal) { + rhs_buffer_map_[rhs] = lhs; + } + } + return equal; + } +}; + +Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, + const tir::StmtSRef& block_sref, + const tir::PrimFunc& desc_func) { + // Try to do tiling automatically if possible + // Now the heuristic is that if block's block var binding is constant + loop var, + // in other words, with tir.block(..., vi=Ci+i, vj=Cj+j, vk=Ck+k), then we split and reorder + // i, j, k according to the loops outside desc_block + // Collect the loops outside block + arith::Analyzer analyzer; + const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref); + // Step 1. Analyze desc_func, extract its block, loops and loop vars + const tir::BlockRealizeNode* desc_block = nullptr; + std::vector desc_loops; + std::unordered_set desc_loop_vars; + { + auto f_visit = [&desc_block, &desc_loops, &desc_loop_vars, + &analyzer](const ObjectRef& obj) -> bool { + // Extract the block + if (const auto* block = obj.as()) { + desc_block = block; + return false; + } + // Extract the loops + if (const auto* loop = obj.as()) { + desc_loops.push_back(loop); + desc_loop_vars.insert(loop->loop_var.get()); + if (!analyzer.CanProve(loop->min == 0)) { + return false; + } + } + return true; + }; + const auto* desc_body = + Downcast(desc_func->body)->block->body.as(); + ICHECK(desc_body); + tir::PostOrderVisit(desc_body->block->body, f_visit); + std::reverse(desc_loops.begin(), desc_loops.end()); + ICHECK(desc_block); + } + // Step 2. Check if `desc_block` matches `block` + // Ignore the scope of buffers when comparing, since we can do cache_read/write + if (!AutoTensorizeComparator(self->mod).VisitStmt(block, GetRef(desc_block))) { + return NullOpt; + } + // Step 3. Extract the loops on top of the block. It is a mirror step of Step 1 + std::vector block_loops; + std::unordered_set block_loop_vars; + { + for (const tir::StmtSRefNode* loop_sref = block_sref->parent;; loop_sref = loop_sref->parent) { + const auto* loop = loop_sref->StmtAs(); + if (loop == nullptr || loop->body->IsInstance()) { + break; + } + block_loops.push_back(loop); + block_loop_vars.insert(loop->loop_var.get()); + if (!analyzer.CanProve(loop->min == 0)) { + return NullOpt; + } + } + std::reverse(block_loops.begin(), block_loops.end()); + } + // Step 4. Map from block loops to desc block loops + ObjectPtr ret = make_object(); + int n_block_vars = block->iter_values.size(); + int n_desc_vars = desc_block->iter_values.size(); + int offset = n_block_vars - n_desc_vars; + if (offset < 0) { + return NullOpt; + } + // We align the block and desc block's bindings from the right side + // block (v0=..., v1=..., v2=...) + // ^ i_block + // desc_block( v1=..., v2=...) + // ^ i_desc + for (int i_desc = 0, i_block = offset; i_desc < n_desc_vars; ++i_desc, ++i_block) { + // For each block var binding, we find + const PrimExpr& block_bind = block->iter_values[i_block]; + const PrimExpr& desc_bind = desc_block->iter_values[i_desc]; + // Step 4.1. Find the corresponding loop of the i-th block var of block + const tir::ForNode* block_loop = nullptr; + for (int i = 0, n = block_loops.size(); i < n; ++i) { + // Check if block_bind = block_loops[i]->loop_var + stuff-irrelevant-of-loop-vars + PrimExpr r = analyzer.Simplify(block_bind - block_loops[i]->loop_var); + if (!UsesVar(r, + [&block_loop_vars](const VarNode* var) { return block_loop_vars.count(var); })) { + block_loop = block_loops[i]; + break; + } + } + if (block_loop == nullptr) { + return NullOpt; + } + // Step 4.2. Find the corresponding loop of the i-th block var of desc + const tir::ForNode* desc_loop = nullptr; + for (int i = 0, n = desc_loops.size(); i < n; ++i) { + // Check if desc_bind = loops[i]->loop_var + stuff-irrelevant-of-loop-vars + PrimExpr r = analyzer.Simplify(desc_bind - desc_loops[i]->loop_var); + if (!UsesVar(r, + [&desc_loop_vars](const VarNode* var) { return desc_loop_vars.count(var); })) { + desc_loop = desc_loops[i]; + break; + } + } + if (block_loop == nullptr) { + return NullOpt; + } + // Step 4.3. Check divisibility of loop extents + PrimExpr block_extent = analyzer.Simplify(block_loop->extent); + PrimExpr desc_extent = analyzer.Simplify(desc_loop->extent); + if (const auto* int_block_extent = block_extent.as()) { + if (const auto* int_desc_extent = desc_extent.as()) { + if (int_block_extent->value % int_desc_extent->value != 0) { + return NullOpt; + } + } else { + return NullOpt; + } + } else { + return NullOpt; + } + // Step 4.4. Maps the result of Step 4.1 to Step 4.2 + const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop]; + auto it = ret->loop_map.find(block_loop_sref); + if (it == ret->loop_map.end()) { + ret->loop_map.Set(block_loop_sref, GetRef(desc_loop)); + } else if ((*it).second.get() != desc_loop) { + return NullOpt; + } + } + for (int i = 0, n = desc_loops.size(); i < n; ++i) { + ret->desc_loop_indexer.Set(GetRef(desc_loops[i]), Integer(i)); + } + return TensorizeInfo(ret); +} + +TVM_REGISTER_NODE_TYPE(TensorizeInfoNode); + /******** Producer-consumer relation ********/ Array GetProducers(const StmtSRef& block_sref, const BlockScope& scope) { @@ -1029,6 +1205,37 @@ Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, return access_region[n]->buffer; } +std::pair, bool> GetBufferDefiningSite(const StmtSRef& block_sref, + const Buffer& buffer) { + // Climb up along the sref tree, and find the block where `buffer` is in alloc_buffers or + // match_buffers. + const StmtSRefNode* defining_site_sref = block_sref.get(); + while (defining_site_sref != nullptr) { + const auto* block = defining_site_sref->StmtAs(); + // If this sref is not a block sref, skip it. + if (block == nullptr) { + defining_site_sref = defining_site_sref->parent; + continue; + } + // Try to find the buffer in `alloc_buffers` + for (const Buffer& alloc_buffer : block->alloc_buffers) { + if (buffer.same_as(alloc_buffer)) { + return {GetRef(defining_site_sref), true}; + } + } + // Try to find the buffer in `match_buffers` + for (const MatchBufferRegion match_buffer : block->match_buffers) { + if (buffer.same_as(match_buffer)) { + return {GetRef(defining_site_sref), false}; + } + } + defining_site_sref = defining_site_sref->parent; + } + // If we cannot find the defining site block, it means that the buffer must be in the function's + // buffer_map, which isn't an intermediate buffer. + return {NullOpt, false}; +} + /******** Pattern Matcher ********/ /*! diff --git a/src/tir/schedule/analysis/layout.cc b/src/tir/schedule/analysis/layout.cc new file mode 100644 index 000000000000..144b3a55a467 --- /dev/null +++ b/src/tir/schedule/analysis/layout.cc @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Calculate the strides of the buffer + * \param buffer The buffer + * \return The strides + */ +Array GetStrides(const Buffer& buffer) { + if (!buffer->strides.empty()) { + ICHECK_EQ(buffer->strides.size(), buffer->shape.size()); + return buffer->strides; + } + int ndim = buffer->shape.size(); + if (ndim == 0) { + return {}; + } + Array strides(ndim, PrimExpr{nullptr}); + PrimExpr stride = make_const(buffer->DefaultIndexType(), 1); + for (int i = ndim - 1; i >= 0; --i) { + strides.Set(i, stride); + stride = stride * buffer->shape[i]; + } + return strides; +} + +/*! + * \brief Auxiliary class that collects the IterSplitExpr in the indexing pattern + * to help decision making in layout transformation + */ +class SplitExprCollector { + public: + /*! + * \brief The corresponding IterSplitExpr, simplified for our case + * The pattern is `source // lower_factor % extent * scale` + */ + struct SplitExpr { + /*! \brief The source variable */ + Var source; + /*! \brief The lower factor of the split expression */ + int64_t lower_factor; + /*! \brief The extent of the split expression */ + int64_t extent; + }; + + /*! + * \brief Collect the split expressions in the indexing pattern + * \param index The indexing pattern + * \param input_iters The input iterators' domain + * \param predicate The predicate of the affine map + * \param require_bijective Whether the affine map is required to be bijective + * \param analyzer The analyzer + * \return The collected split expressions + */ + static std::vector Collect(const PrimExpr& index, + const Map& input_iters, // + const PrimExpr& predicate, // + bool require_bijective, // + arith::Analyzer* analyzer) { + DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule())); + Array iter_sum_exprs = arith::DetectIterMap( + {analyzer->Simplify(index)}, input_iters, predicate, require_bijective, analyzer, diag_ctx); + if (iter_sum_exprs.empty()) { + return {}; + } + ICHECK_EQ(iter_sum_exprs.size(), 1); + if (iter_sum_exprs[0]->args.size() == 0) { + return {}; + } + SplitExprCollector collector; + collector.Visit(iter_sum_exprs[0]); + if (collector.failed_) { + return {}; + } + return std::move(collector.exprs_); + } + + private: + void Visit(const arith::IterSplitExpr& expr) { + if (const auto* var = expr->source->source.as()) { + const int64_t* lower_factor = as_const_int(expr->lower_factor); + const int64_t* extent = as_const_int(expr->extent); + if (lower_factor == nullptr || extent == nullptr) { + failed_ = true; + return; + } + exprs_.push_back(SplitExpr{GetRef(var), *lower_factor, *extent}); + } else if (const auto* iter_sum_expr = expr->source->source.as()) { + Visit(GetRef(iter_sum_expr)); + } else { + ICHECK(false) << "Unexpected type: " << expr->source->source->GetTypeKey(); + } + } + + void Visit(const arith::IterSumExpr& expr) { + for (const arith::IterSplitExpr& arg : expr->args) { + Visit(arg); + } + } + + /*! \brief Whether the analysis failed */ + bool failed_ = false; + /*! \brief The collected split expressions */ + std::vector exprs_; +}; + +Optional SuggestIndexMap(const Buffer& buffer, const Array& indices, + const Array& loops, const PrimExpr& predicate, + arith::Analyzer* analyzer) { + int ndim = buffer->shape.size(); + int n_loops = loops.size(); + // Step 1. Collect the domains and indices of loop variables + Map input_iters; + std::unordered_map var2id; + var2id.reserve(n_loops); + for (int i = 0; i < n_loops; ++i) { + const For& loop = loops[i]; + input_iters.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + var2id.emplace(loop->loop_var.get(), i); + } + // Step 2. Calculate a functor that flattens a multi-dimensional index + auto f_flatten_index = [ndim, strides = GetStrides(buffer), dtype = buffer->DefaultIndexType()]( + const Array& indices) -> PrimExpr { + PrimExpr flatten_index = make_const(dtype, 0); + for (int i = 0; i < ndim; ++i) { + flatten_index = flatten_index + strides[i] * indices[i]; + } + return flatten_index; + }; + // Step 3. Detect the IterSplitExpr of the indexing pattern + std::vector split_exprs = SplitExprCollector::Collect( + /*index=*/f_flatten_index(indices), input_iters, predicate, + /*require_bijective=*/false, analyzer); + if (split_exprs.empty()) { + return NullOpt; + } + // Step 4. Sort the order of the split expressions + std::vector order(split_exprs.size(), 0); + std::generate(order.begin(), order.end(), [n = 0]() mutable { return n++; }); + std::sort(order.begin(), order.end(), [&split_exprs, &var2id](int _a, int _b) -> bool { + const SplitExprCollector::SplitExpr& a = split_exprs[_a]; + const SplitExprCollector::SplitExpr& b = split_exprs[_b]; + int a_var_id = var2id.at(a.source.get()); + int b_var_id = var2id.at(b.source.get()); + if (a_var_id != b_var_id) { + return a_var_id < b_var_id; + } + return a.lower_factor > b.lower_factor; + }); + // Step 5. Create the indexing mapping + auto f_alter_layout = [f_flatten_index = std::move(f_flatten_index), // + split_exprs = std::move(split_exprs), // + order = std::move(order), // + shape = buffer->shape, // + analyzer // + ](Array indices) -> Array { + ICHECK_EQ(indices.size(), shape.size()); + for (int i = 0, n = indices.size(); i < n; ++i) { + analyzer->Bind(indices[i], Range::FromMinExtent(0, shape[i])); + } + PrimExpr index = f_flatten_index({indices.begin(), indices.end()}); + int ndim = split_exprs.size(); + // Step 5.1. Split the flattened index according to `split_exprs` + std::vector split; + split.reserve(ndim); + for (int i = ndim - 1; i >= 0; --i) { + index = analyzer->Simplify(index); + int64_t extent = split_exprs[i].extent; + split.push_back(analyzer->Simplify(floormod(index, extent))); + index = floordiv(index, extent); + } + std::reverse(split.begin(), split.end()); + // Step 5.2. Reorder the indexing pattern according to `order` + Array results; + results.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + results.push_back(split[order[i]]); + } + return results; + }; + return IndexMap::FromFunc(ndim, f_alter_layout); +} + +TVM_REGISTER_GLOBAL("tir.schedule.SuggestIndexMap") + .set_body_typed([](Buffer buffer, Array indices, Array loops, + PrimExpr predicate) { + arith::Analyzer analyzer; + return SuggestIndexMap(buffer, indices, loops, predicate, &analyzer); + }); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 394f0f26db35..92f33d3f4511 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -18,8 +18,6 @@ */ #include "./concrete_schedule.h" -#include - namespace tvm { namespace tir { @@ -214,7 +212,7 @@ Schedule ConcreteScheduleNode::Copy() const { void ConcreteScheduleNode::Seed(support::LinearCongruentialEngine::TRandState seed) { if (seed == -1) { - seed = std::random_device()(); + seed = support::LinearCongruentialEngine::DeviceRandom(); } support::LinearCongruentialEngine(&rand_state_).Seed(seed); } @@ -513,6 +511,30 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buff return CreateRV(result); } +/******** Schedule: Data movement ********/ + +BlockRV ConcreteScheduleNode::ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, + int read_buffer_index, const String& storage_scope) { + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::ReadAt(state_, this->GetSRef(loop_rv), this->GetSRef(block_rv), read_buffer_index, + storage_scope); + TVM_TIR_SCHEDULE_END("read-at", this->error_render_level_); + this->state_->DebugVerify(); + return CreateRV(result); +} + +BlockRV ConcreteScheduleNode::WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, + int write_buffer_index, const String& storage_scope) { + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::WriteAt(state_, this->GetSRef(loop_rv), this->GetSRef(block_rv), write_buffer_index, + storage_scope); + TVM_TIR_SCHEDULE_END("write-at", this->error_render_level_); + this->state_->DebugVerify(); + return CreateRV(result); +} + /******** Schedule: Compute location ********/ void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, @@ -657,7 +679,20 @@ ObjectRef ConcreteScheduleNode::CheckAndGetAnnotationValue(const ObjectRef& ann_ void ConcreteScheduleNode::Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) { TVM_TIR_SCHEDULE_BEGIN(); - tir::Annotate(state_, this->GetSRef(loop_rv), ann_key, this->CheckAndGetAnnotationValue(ann_val)); + if (const auto* str = ann_val.as()) { + tir::Annotate(state_, this->GetSRef(loop_rv), ann_key, GetRef(str)); + } else if (const auto* expr = ann_val.as()) { + ICHECK(!ann_val->IsInstance()) + << "TypeError: runtime::String is expected, but gets StringImm"; + tir::Annotate(state_, this->GetSRef(loop_rv), ann_key, this->Get(GetRef(expr))); + } else if (ann_val.as()) { + tir::Annotate(state_, this->GetSRef(loop_rv), ann_key, ann_val); + } else { + LOG(FATAL) + << "TypeError: Only strings, integers, floats and ExprRVs are supported for now, but gets: " + << ann_val->GetTypeKey(); + throw; + } this->state_->DebugVerify(); TVM_TIR_SCHEDULE_END("annotate", this->error_render_level_); } @@ -672,8 +707,20 @@ void ConcreteScheduleNode::Unannotate(const LoopRV& loop_rv, const String& ann_k void ConcreteScheduleNode::Annotate(const BlockRV& block_rv, const String& ann_key, const ObjectRef& ann_val) { TVM_TIR_SCHEDULE_BEGIN(); - tir::Annotate(state_, this->GetSRef(block_rv), ann_key, - this->CheckAndGetAnnotationValue(ann_val)); + if (const auto* str = ann_val.as()) { + tir::Annotate(state_, this->GetSRef(block_rv), ann_key, GetRef(str)); + } else if (const auto* expr = ann_val.as()) { + ICHECK(!ann_val->IsInstance()) + << "TypeError: runtime::String is expected, but gets StringImm"; + tir::Annotate(state_, this->GetSRef(block_rv), ann_key, this->Get(GetRef(expr))); + } else if (ann_val.as()) { + tir::Annotate(state_, this->GetSRef(block_rv), ann_key, ann_val); + } else { + LOG(FATAL) + << "TypeError: Only strings, integers, floats and ExprRVs are supported for now, but gets: " + << ann_val->GetTypeKey(); + throw; + } this->state_->DebugVerify(); TVM_TIR_SCHEDULE_END("annotate", this->error_render_level_); } @@ -685,6 +732,15 @@ void ConcreteScheduleNode::Unannotate(const BlockRV& block_rv, const String& ann TVM_TIR_SCHEDULE_END("unannotate", this->error_render_level_); } +/******** Schedule: Layout transformation ********/ +void ConcreteScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_index, + bool is_write_index, const IndexMap& index_map) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::TransformLayout(state_, this->GetSRef(block_rv), buffer_index, is_write_index, index_map); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("transform_layout", this->error_render_level_); +} + /******** Schedule: Misc ********/ } // namespace tir diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index f0f25ecafa3a..3501e7cb723f 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -109,6 +109,11 @@ class ConcreteScheduleNode : public ScheduleNode { const String& storage_scope) override; BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope) override; + /******** Schedule: Data movement ********/ + BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) override; + BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope) override; /******** Schedule: Compute location ********/ void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) override; void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, @@ -132,6 +137,9 @@ class ConcreteScheduleNode : public ScheduleNode { void Annotate(const BlockRV& block_rv, const String& ann_key, const ObjectRef& ann_val) override; void Unannotate(const BlockRV& block_rv, const String& ann_key) override; + /******** Schedule: Layout transformation ********/ + void TransformLayout(const BlockRV& block_rv, int buffer_index, bool is_write_index, + const IndexMap& index_map) override; /******** Schedule: Misc ********/ void EnterPostproc() override {} diff --git a/src/tir/schedule/instruction_traits.h b/src/tir/schedule/instruction_traits.h index 91a79456e579..14d05a4a340c 100644 --- a/src/tir/schedule/instruction_traits.h +++ b/src/tir/schedule/instruction_traits.h @@ -197,6 +197,8 @@ class PythonAPICall { inline void Input(String arg_name, int arg); /*! \brief Add an integer input */ inline void Input(String arg_name, int64_t arg); + /*! \brief Add a bool input */ + inline void Input(String arg_name, bool arg); /*! \brief Add a double input */ inline void Input(String arg_name, double arg); /*! \brief Add an input random variable */ @@ -462,6 +464,17 @@ void PythonAPICall::Input(String arg_name, int64_t arg) { args_.push_back(std::to_string(arg)); } +void PythonAPICall::Input(String arg_name, bool arg) { + static const char* true_str = "True"; + static const char* false_str = "False"; + arg_names_.emplace_back(std::move(arg_name)); + if (arg) { + args_.push_back(true_str); + } else { + args_.push_back(false_str); + } +} + void PythonAPICall::Input(String arg_name, double arg) { arg_names_.emplace_back(std::move(arg_name)); std::ostringstream os; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 0cd2d3e6f38a..b445b5a9ded8 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -253,6 +253,15 @@ TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int r */ TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, const String& storage_scope); + +/******** Schedule: Data movement ********/ + +TVM_DLL StmtSRef ReadAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int read_buffer_index, const String& storage_scope); + +TVM_DLL StmtSRef WriteAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int write_buffer_index, const String& storage_scope); + /******** Schedule: Compute location ********/ /*! * \brief Move a producer block under the specific loop, and regenerate the @@ -407,6 +416,23 @@ TVM_DLL void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref, */ TVM_DLL void Annotate(ScheduleState self, const StmtSRef& sref, const String& ann_key, const ObjectRef& ann_val); + +/******** Schedule: Layout transformation ********/ +/*! + * \brief Apply a transformation represented by IndexMap to buffer + * \details The indices and the access region to the target buffer is transformed by the given + * index_map. The index_map is also used to infer the new shape of the buffer. Buffer must be + * one of the parameter of the function, or allocated in some blocks (it cannot be a buffer + * subregion created via match_buffer). + * \param self The state of the schedule + * \param block_sref The block sref that accesses the target buffer. + * \param buffer_index The index of the buffer in block's read or write region. + * \param is_write_index Whether the buffer_index is the index of the block's write region. + * \param index_map The transformation to apply. + */ +TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index, + bool is_write_index, const IndexMap& index_map); + /*! * \brief Unannotate a block/loop's annotation with key ann_key * \param self The state of the schedule diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc index 418e770a5c93..f9cec421cd21 100644 --- a/src/tir/schedule/primitive/block_annotate.cc +++ b/src/tir/schedule/primitive/block_annotate.cc @@ -64,44 +64,6 @@ class StorageAlignAxisOutOfRangeError : public ScheduleError { int axis_; }; -/*! - * \brief Find the defining site of the buffer in the given block and its ancestors - * \param block_sref The block sref - * \param buffer The buffer - * \return The defining site of the buffer and whether the buffer is allocated (otherwise the - * buffer is from match_buffer). - */ -std::pair, bool> GetBufferDefiningSite(const StmtSRef& block_sref, - const Buffer& buffer) { - // Climb up along the sref tree, and find the block where `buffer` is in alloc_buffers or - // match_buffers. - const StmtSRefNode* defining_site_sref = block_sref.get(); - while (defining_site_sref != nullptr) { - const auto* block = defining_site_sref->StmtAs(); - // If this sref is not a block sref, skip it. - if (block == nullptr) { - defining_site_sref = defining_site_sref->parent; - continue; - } - // Try to find the buffer in `allloc_buffers` - for (const Buffer& alloc_buffer : block->alloc_buffers) { - if (buffer.same_as(alloc_buffer)) { - return {GetRef(defining_site_sref), true}; - } - } - // We do not allow the buffer being defined in `match_buffer`. - for (const MatchBufferRegion match_buffer : block->match_buffers) { - if (buffer.same_as(match_buffer)) { - return {GetRef(defining_site_sref), false}; - } - } - defining_site_sref = defining_site_sref->parent; - } - // If we cannot find the defining site block, it means that the buffer must be in the function's - // buffer_map, which isn't an intermediate buffer. - return {NullOpt, false}; -} - class NonAllocatedBufferError : public ScheduleError { public: explicit NonAllocatedBufferError(IRModule mod, Buffer buffer) : mod_(mod), buffer_(buffer) {} diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index 9a9860b42bc6..1ff71675838b 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -542,8 +542,7 @@ class ReverseComputeInliner : public BaseInliner { PrimExpr producer_rhs_{nullptr}; }; -void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref, - bool check_only = false) { +std::function ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref) { const BlockNode* _producer_block = TVM_SREF_TO_BLOCK(_producer_block, producer_block_sref); Block producer_block = GetRef(_producer_block); Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, producer_block); @@ -567,27 +566,24 @@ void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref, throw OpaqueAccessError(self->mod, scope_root_sref); } // Step 6. Do the real mutation on the AST and the sref tree in the schedule state - if (check_only) { - return; - } - self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse); + return [=]() -> void { self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse); }; } void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) { - ComputeInlineImpl(self, producer_block_sref); + ComputeInlineImpl(self, producer_block_sref)(); } bool CanComputeInline(const ScheduleState& self, const StmtSRef& producer_block_sref) { try { - ComputeInlineImpl(self, producer_block_sref, true); + ComputeInlineImpl(self, producer_block_sref); } catch (const tvm::runtime::Error& e) { return false; } return true; } -void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block_sref, - bool check_only = false) { +std::function ReverseComputeInlineImpl(ScheduleState self, + const StmtSRef& consumer_block_sref) { const BlockNode* _consumer_block = TVM_SREF_TO_BLOCK(_consumer_block, consumer_block_sref); Block consumer_block = GetRef(_consumer_block); // Step 1. Get the scope block @@ -613,15 +609,12 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block throw OpaqueAccessError(self->mod, scope_root_sref); } // Step 7. Do the real mutation on the AST and the sref tree in the schedule state - if (check_only) { - return; - } - self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse); + return [=]() -> void { self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse); }; } bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sref) { try { - ReverseComputeInlineImpl(self, block_sref, true); + ReverseComputeInlineImpl(self, block_sref); } catch (const tvm::runtime::Error& e) { return false; } @@ -629,7 +622,7 @@ bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sr } void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sref) { - ReverseComputeInlineImpl(self, consumer_block_sref); + ReverseComputeInlineImpl(self, consumer_block_sref)(); } /******** InstructionKind Registration ********/ diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc new file mode 100644 index 000000000000..79bca35f9791 --- /dev/null +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +class TransformLayoutRewriter : private StmtExprMutator { + public: + /*! + * \brief Rewrite the access to the buffer after the transformation + * \param scope_stmt The parent statement that contains all accesses to the target buffer + * \param old_buffer The target buffer before transformation + * \param new_buffer The new buffer after transformation + * \param index_map The transformation applied to the buffer + * \return The new AST rooting at the original parent scope and the map from the old block to the + * new block + */ + static std::pair> Rewrite(const Stmt& scope_stmt, + const Buffer& old_buffer, + const Buffer& new_buffer, + const IndexMap& index_map) { + TransformLayoutRewriter rewriter(old_buffer, new_buffer, index_map); + Stmt result = rewriter(scope_stmt); + return {result, rewriter.block_sref_reuse_}; + } + + private: + TransformLayoutRewriter(const Buffer& old_buffer, const Buffer& new_buffer, + const IndexMap& index_map) + : old_buffer_(old_buffer), + new_buffer_(new_buffer), + index_map_(index_map), + buffer_data_to_buffer_{{new_buffer->data, new_buffer}} {} + + void RewriteBufferAccess(Buffer* buffer, Array* indices) { + *buffer = new_buffer_; + *indices = index_map_->Apply(*indices); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad buffer_load = Downcast(StmtExprMutator::VisitExpr_(op)); + if (buffer_load->buffer.same_as(old_buffer_)) { + auto* n = buffer_load.CopyOnWrite(); + RewriteBufferAccess(&n->buffer, &n->indices); + } + return std::move(buffer_load); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore buffer_store = Downcast(StmtExprMutator::VisitStmt_(op)); + if (buffer_store->buffer.same_as(old_buffer_)) { + auto* n = buffer_store.CopyOnWrite(); + RewriteBufferAccess(&n->buffer, &n->indices); + } + return std::move(buffer_store); + } + + void RewriteAccessRegion(Array* old_access_regions, + const Array& infered_access_regions) { + auto fmutate = [this, &infered_access_regions](const BufferRegion& buffer_region) { + if (buffer_region->buffer.same_as(old_buffer_)) { + ICHECK(infered_access_regions.size() == 1); + return infered_access_regions[0]; + } + return buffer_region; + }; + (*old_access_regions).MutateByApply(fmutate); + } + + Stmt VisitStmt_(const BlockNode* op) final { + Block block = Downcast(StmtExprMutator::VisitStmt_(op)); + auto infered_access_regions = GetBlockReadWriteRegion(block, buffer_data_to_buffer_); + auto* n = block.CopyOnWrite(); + RewriteAccessRegion(&n->reads, infered_access_regions[0]); + RewriteAccessRegion(&n->writes, infered_access_regions[1]); + block_sref_reuse_.Set(GetRef(op), block); + return std::move(block); + } + + const Buffer& old_buffer_; + const Buffer& new_buffer_; + const IndexMap& index_map_; + Map buffer_data_to_buffer_; + Map block_sref_reuse_; +}; + +class BufferIsSubregionError : public ScheduleError { + public: + explicit BufferIsSubregionError(IRModule mod, Buffer buffer) : mod_(mod), buffer_(buffer) {} + + String FastErrorString() const final { + return "ScheduleError: The input buffer is defined in `match_buffer` of a block, it is expected" + " to be a function parameter or allocated by a block"; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "ScheduleError: The input buffer " << buffer_->name << " is defined in `match_buffer` of " + << "a block, it is expected to be a function parameter or allocated by a block."; + return os.str(); + } + + Array LocationsOfInterest() const final { return {}; } + IRModule mod() const final { return mod_; } + + private: + IRModule mod_; + Buffer buffer_; +}; + +void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index, + bool is_write_index, const IndexMap& index_map) { + const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref); + Buffer old_buffer = GetNthAccessBuffer(self, GetRef(block_ptr), buffer_index, + /*is_write=*/!is_write_index); + Optional defining_site_sref; + bool is_alloc; + std::tie(defining_site_sref, is_alloc) = GetBufferDefiningSite(block_sref, old_buffer); + if (defining_site_sref.defined() && !is_alloc) { + throw BufferIsSubregionError(self->mod, old_buffer); + } + + StmtSRef scope_sref = defining_site_sref.defined() + ? defining_site_sref.value() + : GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); + const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref); + + // Step 1: Infer the shape of the new buffer + ObjectPtr new_buffer_node = make_object(*(old_buffer.get())); + new_buffer_node->shape = index_map->MapShape(old_buffer->shape); + Buffer new_buffer{new_buffer_node}; + + // Step 2: Rewrite access indices and regions of the buffer + Stmt new_stmt; + Map block_sref_reuse; + std::tie(new_stmt, block_sref_reuse) = TransformLayoutRewriter::Rewrite( + GetRef(scope_block), old_buffer, new_buffer, index_map); + Block new_scope_block = Downcast(new_stmt); + + // Step 3: Rewrite alloc_buffer of the block or buffer_map of the PrimFunc. + if (defining_site_sref.defined()) { + auto* n = new_scope_block.CopyOnWrite(); + n->alloc_buffers.MutateByApply([&old_buffer, &new_buffer](const Buffer& buffer) { + if (buffer.same_as(old_buffer)) { + return new_buffer; + } + return buffer; + }); + block_sref_reuse.Set(GetRef(scope_block), new_scope_block); + } else { + GlobalVar g_var; + GetRootPrimFunc(self->mod, scope_block, &g_var); + IRModuleNode* new_mod = self->mod.CopyOnWrite(); + MapNode* new_map = new_mod->functions.CopyOnWrite(); + PrimFunc ref_new_func = Downcast(std::move(new_map->at(g_var))); + PrimFuncNode* new_func = ref_new_func.CopyOnWrite(); + MapNode* new_buffer_map = new_func->buffer_map.CopyOnWrite(); + for (auto it = new_buffer_map->begin(); it != new_buffer_map->end(); ++it) { + if ((*it).second.same_as(old_buffer)) { + (*it).second = new_buffer; + } + } + new_map->at(g_var) = std::move(ref_new_func); + } + + // Step 4: Replace the scope block with the new block + self->Replace(scope_sref, new_scope_block, block_sref_reuse); +} + +/******** InstructionKind Registration ********/ + +struct TransformLayoutTraits : public UnpackedInstTraits { + static constexpr const char* kName = "TransformLayout"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 3; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index, + Bool is_write_index, IndexMap index_map) { + return sch->TransformLayout(block_rv, buffer_index, is_write_index, index_map); + } + + static String UnpackedAsPython(Array outputs, String block_rv, Integer buffer_index, + Bool is_write_index, IndexMap index_map) { + PythonAPICall py("transform_layout"); + py.Input("block", block_rv); + py.Input("buffer_index", buffer_index); + py.Input("is_write_index", is_write_index); + py.Input("index_map", index_map->ToPythonString()); + return py.Str(); + } + + public: + static ObjectRef AttrsAsJSON(const Array& attrs) { + Array attrs_record; + attrs_record.reserve(kNumAttrs); + attrs_record.push_back(attrs[0]); + attrs_record.push_back(attrs[1]); + attrs_record.push_back(String(::tvm::SaveJSON(attrs[2]))); + return std::move(attrs_record); + } + + static Array AttrsFromJSON(const ObjectRef& attrs_record_) { + Array attrs_record = Downcast>(attrs_record_); + Array attrs; + attrs.push_back(attrs_record[0]); + attrs.push_back(attrs_record[1]); + attrs.push_back(::tvm::LoadJSON(Downcast(attrs_record[2]))); + return attrs; + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(TransformLayoutTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index 7b9ac488b8b9..fa2a4469b8c9 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -413,7 +413,7 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, for (int i = 0; i < n; i++) { const PrimExpr& factor = factors[i]; Var var = loop->loop_var.copy_with_suffix("_" + std::to_string(i)); - substitute_value = substitute_value * factor + var; + if (!is_one(factor)) substitute_value = substitute_value * factor + var; analyzer.Bind(var, Range::FromMinExtent(0, factor)); new_loop_vars.emplace_back(std::move(var)); } @@ -505,11 +505,14 @@ StmtSRef Fuse(ScheduleState self, const Array& loop_srefs) { Var fused_var = loops[0]->loop_var.copy_with_suffix(suffix); Array substitute_value; substitute_value.resize(loops.size()); - PrimExpr tot = fused_var; - for (int i = static_cast(loops.size()) - 1; i >= 0; i--) { - substitute_value.Set(i, floormod(tot, loops[i]->extent)); - tot = floordiv(tot, loops[i]->extent); - } + PrimExpr lower = 1; + for (int i = static_cast(loops.size()) - 1; i > 0; i--) { + substitute_value.Set(i, is_one(loops[i]->extent) + ? 0 + : floordiv(floormod(fused_var, lower * loops[i]->extent), lower)); + lower = lower * loops[i]->extent; + } + substitute_value.Set(0, is_one(loops[0]->extent) ? 0 : floordiv(fused_var, lower)); Stmt new_stmt = loops.back()->body; Map opaque_block_reuse; auto f_substitute = [&](const Var& v) -> Optional { @@ -534,6 +537,7 @@ StmtSRef Fuse(ScheduleState self, const Array& loop_srefs) { self->Replace(loop_srefs[0], new_stmt, opaque_block_reuse); return self->stmt2ref.at(new_stmt.get()); } + /*! * \brief Collect an array of loop srefs into a set * \param self The schedule state diff --git a/src/tir/schedule/primitive/read_write_at.cc b/src/tir/schedule/primitive/read_write_at.cc new file mode 100644 index 000000000000..2702f08343a0 --- /dev/null +++ b/src/tir/schedule/primitive/read_write_at.cc @@ -0,0 +1,421 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include + +#include "../utils.h" + +namespace tvm { +namespace tir { + +using support::NDIntSet; + +bool HasBuffer(const Array& buffer_regions, const Buffer& buffer) { + for (const BufferRegion& buffer_region : buffer_regions) { + if (buffer_region->buffer.same_as(buffer)) { + return true; + } + } + return false; +} + +void RelaxBufferRegions(const Array& buffer_regions, + const Buffer& buffer, // + const Map& var_dom, // + const Map& bindings, // + std::vector* relaxed_regions) { + for (const BufferRegion& buffer_region : buffer_regions) { + if (buffer_region->buffer.same_as(buffer)) { + Array relaxed_region = + arith::EvalSet(Substitute(buffer_region->region, bindings), var_dom); + relaxed_regions->push_back({relaxed_region.begin(), relaxed_region.end()}); + } + } +} + +class ScopeReplacer : public StmtMutator { + public: + static Block Replace(const BlockNode* scope_block, const Buffer& dst, const ForNode* old_loop, + const ForNode* new_loop) { + ObjectPtr new_scope_block = make_object(*scope_block); + new_scope_block->body = ScopeReplacer(old_loop, new_loop)(std::move(new_scope_block->body)); + new_scope_block->alloc_buffers.push_back(dst); + return Block(new_scope_block); + } + + private: + explicit ScopeReplacer(const ForNode* old_loop, const ForNode* new_loop) + : old_loop_(old_loop), new_loop_(new_loop), found_(false) {} + + Stmt VisitStmt(const Stmt& stmt) final { return found_ ? stmt : StmtMutator::VisitStmt(stmt); } + Stmt VisitStmt_(const BlockNode* block) final { return GetRef(block); } + Stmt VisitStmt_(const ForNode* loop) final { + if (loop == old_loop_) { + found_ = true; + return GetRef(new_loop_); + } + return StmtMutator::VisitStmt_(loop); + } + + const ForNode* old_loop_; + const ForNode* new_loop_; + bool found_; +}; + +class ReadWriteAtBufferReplacer : public StmtExprMutator { + public: + explicit ReadWriteAtBufferReplacer(const Buffer& src, const Buffer& dst, + Map* block_sref_reuse) + : src_(src), dst_(dst), block_sref_reuse_(block_sref_reuse) {} + + private: + Stmt VisitStmt_(const BufferStoreNode* _store) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(_store)); + if (store->buffer.same_as(src_)) { + ObjectPtr new_store = make_object(*store.get()); + new_store->buffer = dst_; + return BufferStore(new_store); + } + return store; + } + + PrimExpr VisitExpr_(const BufferLoadNode* _load) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(_load)); + if (load->buffer.same_as(src_)) { + ObjectPtr new_load = make_object(*load.get()); + new_load->buffer = dst_; + return BufferLoad(new_load); + } + return load; + } + + Stmt VisitStmt_(const BlockNode* _block) final { + Block old_block = GetRef(_block); + Block block = Downcast(StmtExprMutator::VisitStmt_(_block)); + ObjectPtr new_block = make_object(*block.get()); + new_block->reads = ReplaceBuffer(new_block->reads, src_, dst_); + new_block->writes = ReplaceBuffer(new_block->writes, src_, dst_); + block_sref_reuse_->Set(old_block, Block(new_block)); + return Block(new_block); + } + + const Buffer& src_; + const Buffer& dst_; + Map* block_sref_reuse_; +}; + +struct ReadWriteAtImpl { + template + static StmtSRef Main(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int buffer_index, const String& storage_scope, + Map annotations) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + Buffer src = + GetNthAccessBuffer(self, GetRef(block), buffer_index, /*is_write=*/!is_read); + Buffer dst = WithScope(src, storage_scope); + ReadWriteAtImpl impl(self, loop_sref, src, dst, annotations); + std::pair new_loop_block = + impl.MakeLoopAndBlock(src->name + "_" + storage_scope); + StmtSRef result_block_sref = + impl.ReplaceScopeBlock(new_loop_block.first.get(), new_loop_block.second->block.get()); + impl.UpdateBlockInfo(result_block_sref); + return result_block_sref; + } + + private: + static Map GetLoopDomain(const StmtSRefNode* loop_sref) { + Map result; + for (const ForNode* loop; (loop = loop_sref->StmtAs()) != nullptr; + loop_sref = loop_sref->parent) { + result.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + } + return result; + } + + StmtSRef ReplaceScopeBlock(const ForNode* new_loop, const BlockNode* new_block) { + StmtSRef scope_root_sref = GetScopeRoot(self_, loop_sref_, + /*require_stage_pipeline=*/true); + const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_root_sref); + Block new_scope_block = ScopeReplacer::Replace(scope_block, dst_, loop_, new_loop); + block_sref_reuse_.Set(GetRef(scope_block), new_scope_block); + self_->Replace(scope_root_sref, new_scope_block, block_sref_reuse_); + return self_->stmt2ref.at(new_block); + } + + void UpdateBlockInfo(const StmtSRef& new_block_sref) { + BlockInfo& block_info = self_->block_info[new_block_sref]; + block_info.affine_binding = true; + block_info.region_cover = true; + block_info.scope->stage_pipeline = true; + } + + template + std::pair MakeLoopAndBlock(const String& new_block_name_hint) { + Array subtrees = AsArray(loop_->body); + int n_subtrees = subtrees.size(); + runtime::StorageScope scope = runtime::StorageScope::Create(dst_.scope()); + std::vector relaxed_regions; + std::vector r_pos; + std::vector w_pos; + relaxed_regions.reserve(n_subtrees); + r_pos.reserve(n_subtrees); + w_pos.reserve(n_subtrees); + // Step 1. Iterate over all subtrees + for (int i = 0; i < n_subtrees; ++i) { + bool r_visited = false; + bool w_visited = false; + auto f_visit = [this, &relaxed_regions, &r_visited, &w_visited, + &scope](const ObjectRef& obj) -> bool { + const BlockRealizeNode* realize = obj.as(); + if (realize == nullptr) { + return true; + } + const BlockNode* block = realize->block.get(); + bool has_r = HasBuffer(block->reads, src_); + bool has_w = HasBuffer(block->writes, src_); + r_visited = r_visited || has_r; + w_visited = w_visited || has_w; + if (is_read ? has_r : has_w) { + RelaxBufferRegions( + /*buffer_regions=*/is_read ? block->reads : block->writes, + /*buffer=*/src_, + /*var_dom=*/ + AsIntSet(LoopDomainOfSRefTreePath( + /*low_inclusive=*/GetRef(self_->stmt2ref.at(block)->parent), + /*high_exclusive=*/loop_sref_, + /*extra_relax_scope=*/scope)), + /*bindings=*/GetBindings(GetRef(realize)), + /*relaxed_regions=*/&relaxed_regions); + } + return false; + }; + PreOrderVisit(subtrees[i], f_visit); + if (r_visited) { + r_pos.push_back(i); + } + if (w_visited) { + w_pos.push_back(i); + } + } + // Step 2. Calculate `insert_pos` and [st, ed) for buffer replacement + int insert_pos = -1, st = -1, ed = -1; + if (is_read) { + ICHECK(!r_pos.empty()); + // No write after the first read + ICHECK(w_pos.empty() || w_pos.back() < r_pos.front()); + // Can be inserted at [0, r_pos.front()], i.e. before the first read + insert_pos = r_pos.front(); + // Buffer reads in [insert_pos, +oo) is rewritten + st = insert_pos; + ed = n_subtrees; + } else { + ICHECK(!w_pos.empty()); + // No read after the last write + ICHECK(r_pos.empty() || r_pos.back() <= w_pos.back()); + // Can be inserted into (w_pos.back(), +oo), i.e. after the last write + insert_pos = w_pos.back() + 1; + st = 0; + ed = insert_pos; + } + // Step 3. Calculate `domain`, the domain of buffer access + NDIntSet relaxed = support::NDIntSetUnion(relaxed_regions); + int ndim = relaxed.size(); + Array domain; + domain.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + const arith::IntSet& int_set = relaxed[i]; + PrimExpr min = analyzer_->Simplify(int_set.min()); + PrimExpr extent = analyzer_->Simplify(int_set.max() + 1 - min); + domain.push_back(Range::FromMinExtent(min, extent)); + } + // Step 4. Insert the auto copy block and replace buffers + ReadWriteAtBufferReplacer replacer(src_, dst_, &block_sref_reuse_); + for (int i = st; i < ed; ++i) { + Stmt stmt = subtrees[i]; + subtrees.Set(i, Stmt(nullptr)); + subtrees.Set(i, replacer(std::move(stmt))); + } + BlockRealize realize = + is_read + ? MakeBlock(src_, dst_, new_block_name_hint, GetLoopDomain(loop_sref_.get()), domain) + : MakeBlock(dst_, src_, new_block_name_hint, GetLoopDomain(loop_sref_.get()), domain); + subtrees.insert(subtrees.begin() + insert_pos, realize); + ObjectPtr new_loop = make_object(*loop_); + new_loop->body = SeqStmt(std::move(subtrees)); + return {For(new_loop), realize}; + } + + BlockRealize MakeBlock(const Buffer& copy_from, const Buffer& copy_to, const String& name_hint, + const Map& loop_domain, Array domain) const { + int n = domain.size(); + std::vector loop_vars; + loop_vars.reserve(n); + for (int i = 0; i < n; ++i) { + loop_vars.push_back(Var("ax" + std::to_string(i))); + } + Map bindings; + Array iter_vars; + Array iter_values; + Array indices; + iter_vars.reserve(n); + iter_values.reserve(n); + indices.reserve(n); + for (int i = 0; i < n; ++i) { + auto f_substitute = [&loop_domain, &bindings, &iter_vars, + &iter_values](const Var& var) -> Optional { + auto it = bindings.find(var); + if (it != bindings.end()) { + return (*it).second; + } + Range range = loop_domain.at(var); + ObjectPtr v = make_object(*var.get()); + v->name_hint = "v" + std::to_string(iter_vars.size()); + bindings.Set(var, Var(v)); + iter_values.push_back(var); + iter_vars.push_back(IterVar(range, Var(v), IterVarType::kDataPar)); + return Var(v); + }; + ObjectPtr dom = make_object(*domain[i].get()); + dom->min = Substitute(std::move(dom->min), f_substitute); + dom->extent = Substitute(std::move(dom->extent), f_substitute); + domain.Set(i, Range(dom)); + } + for (int i = 0; i < n; ++i) { + indices.push_back(domain[i]->min + loop_vars[i]); + } + Stmt stmt = BufferStore(copy_to, /*value=*/BufferLoad(copy_from, indices), /*indices=*/indices); + for (int i = n - 1; i >= 0; --i) { + stmt = For(loop_vars[i], Integer(0), domain[i]->extent, ForKind::kSerial, stmt); + } + return BlockRealize( + /*values=*/iter_values, + /*predicate=*/const_true(), + Block(/*iter_vars=*/iter_vars, + /*reads=*/{BufferRegion(copy_from, domain)}, + /*writes=*/{BufferRegion(copy_to, domain)}, + /*name_hint=*/name_hint, // + /*body=*/std::move(stmt), + /*init=*/NullOpt, + /*alloc_buffers=*/{}, + /*match_buffers=*/{}, + /*annotations=*/annotations_)); + } + + explicit ReadWriteAtImpl(ScheduleState self, const StmtSRef& loop_sref, const Buffer& src, + const Buffer& dst, Map annotations) + : self_(self), + loop_sref_(loop_sref), + loop_(nullptr), + src_(src), + dst_(dst), + annotations_(annotations), + block_sref_reuse_(), + analyzer_(std::make_unique()) { + loop_ = TVM_SREF_TO_FOR(loop_, loop_sref); + } + + ScheduleState self_; + const StmtSRef& loop_sref_; + const ForNode* loop_; + const Buffer& src_; + const Buffer& dst_; + Map annotations_; + Map block_sref_reuse_; + std::unique_ptr analyzer_; +}; + +StmtSRef ReadAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int read_buffer_index, const String& storage_scope) { + return ReadWriteAtImpl::Main(self, loop_sref, block_sref, read_buffer_index, storage_scope, + {{"auto_copy", Integer(1)}}); +} + +StmtSRef WriteAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int write_buffer_index, const String& storage_scope) { + return ReadWriteAtImpl::Main(self, loop_sref, block_sref, write_buffer_index, + storage_scope, {{"auto_copy", Integer(1)}}); +} + +/******** Instruction Registration ********/ + +struct ReadAtTraits : public UnpackedInstTraits { + static constexpr const char* kName = "ReadAt"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 2; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 0; + + StmtSRef ReadAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int buffer_index, const String& storage_scope); + static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop, BlockRV block, + Integer read_buffer_index, String storage_scope) { + return sch->ReadAt(loop, block, read_buffer_index->value, storage_scope); + } + + static String UnpackedAsPython(Array outputs, String loop, String block, + Integer read_buffer_index, String storage_scope) { + PythonAPICall py("read_at"); + py.Input("loop", loop); + py.Input("block", block); + py.Input("read_buffer_index", read_buffer_index->value); + py.Input("storage_scope", storage_scope); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct WriteAtTraits : public UnpackedInstTraits { + static constexpr const char* kName = "WriteAt"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 2; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 0; + + static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop, BlockRV block, + Integer write_buffer_index, String storage_scope) { + return sch->WriteAt(loop, block, write_buffer_index->value, storage_scope); + } + + static String UnpackedAsPython(Array outputs, String loop, String block, + Integer write_buffer_index, String storage_scope) { + PythonAPICall py("write_at"); + py.Input("loop", loop); + py.Input("block", block); + py.Input("write_buffer_index", write_buffer_index->value); + py.Input("storage_scope", storage_scope); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(ReadAtTraits); +TVM_REGISTER_INST_KIND_TRAITS(WriteAtTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index 4b9b78e3b299..03ffb4fe159e 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -485,7 +485,7 @@ class LoopPropertyError : public ScheduleError { throw LoopPropertyError(self->mod, loop, kDataParIterTouchRFactorLoop); } continue; - } else if (reduction_touched) { + } else if (reduction_touched) { if (!meet_reduction_loop) { CheckGetSingleChildBlockRealizeOnSRefTree(self, self->stmt2ref.at(loop.get())); meet_reduction_loop = true; @@ -559,10 +559,13 @@ class BaseBlockCreator { } void CreateBlock() { - CreateAdditionalIter(); for (int i = 0; i < n_block_iters_; ++i) { CreateNormalIters(i); } + if (!additional_iter_.defined()) { + ICHECK(arith::Analyzer().CanProveEqual(rf_loop_->extent, Integer(1))); + CreateAdditionalIter(); + } CreateReductionUpdate(); CreateReadWriteRegions(); @@ -600,6 +603,8 @@ class BaseBlockCreator { BlockRealize new_block_realize_; /*! \brief The indices used to access the intermediate rfactor buffer */ Array rf_buf_access_indices_; + /*! \brief The additional block iter of the new created block for the rfactor loop. */ + IterVar additional_iter_; protected: /*! \brief The old block-realize */ @@ -671,15 +676,6 @@ class RFactorBlockCreator : public BaseBlockCreator { combiner_rhs_(std::move(combiner_rhs)) {} private: - void CreateAdditionalIter() final { - // Create a new data parallel block iter for the rfactor loop. - additional_iter_ = - IterVarFromLoop(rf_loop_, "v" + rf_loop_->loop_var->name_hint, IterVarType::kDataPar); - loop_var2block_binding_[rf_loop_->loop_var.get()] = additional_iter_->var; - iter_vars_.push_back(additional_iter_); - iter_values_.push_back(rf_loop_->loop_var); - } - void CreateNormalIters(int idx) final { IterVar old_iter = old_block_realize_->block->iter_vars[idx]; PrimExpr old_binding = old_block_realize_->iter_values[idx]; @@ -705,20 +701,31 @@ class RFactorBlockCreator : public BaseBlockCreator { } const For& loop = it->second; if (loop_var2block_binding_.find(var.get()) == loop_var2block_binding_.end()) { - // We haven't created the new block iter for `var`. So here we create it, append it - // and its binding to `rf_block_iter_vars` and `rf_block_iter_values` respectively. - IterVar new_iter_var = - IterVarFromLoop(loop, "v" + loop->loop_var->name_hint, IterVarType::kCommReduce); + // - We haven't created the new block iter for `var`. So here we create it, append it + // and its binding to `rf_block_iter_vars` and `rf_block_iter_values` respectively. + // - If the loop is the rfactor loop, envoke `CreateAdditionalIter()`. + if (loop.same_as(rf_loop_)) { + CreateAdditionalIter(); + continue; + } + IterVar new_iter_var = IterVarFromLoop(loop, "v" + loop->loop_var->name_hint, kCommReduce); loop_var2block_binding_[var.get()] = new_iter_var->var; iter_vars_.push_back(new_iter_var); iter_values_.push_back(var); } } // Substitute the original binding with new block iters. Store the result expression - // in `rf_var_map` for future substitution. + // in `var_map_` for future substitution. var_map_.Set(old_iter->var, Substitute(old_binding, loop_var2block_binding_)); } + void CreateAdditionalIter() final { + additional_iter_ = IterVarFromLoop(rf_loop_, "v" + rf_loop_->loop_var->name_hint, kDataPar); + iter_vars_.insert(iter_vars_.end(), additional_iter_); + iter_values_.insert(iter_values_.end(), rf_loop_->loop_var); + loop_var2block_binding_[rf_loop_->loop_var.get()] = additional_iter_; + } + void CreateReductionUpdate() final { rf_buf_access_indices_ = old_reduction_update_->indices; rf_buf_access_indices_.insert(rf_buf_access_indices_.begin() + factor_axis_, @@ -753,10 +760,6 @@ class RFactorBlockCreator : public BaseBlockCreator { return new_regions; } - public: - /*! \brief The generated additional block iter in rfactor block for the rfactor loop */ - IterVar additional_iter_; - private: /*! * \brief A mapping which maps a loop var to its corresponding For loop for all the reduction @@ -797,14 +800,11 @@ class WriteBackBlockCreator : public BaseBlockCreator { private: void CreateAdditionalIter() final { - // Create a new reduction block iter for the rfactor loop. - IterVar wb_new_block_iter = - IterVarFromLoop(rf_loop_, "v" + rf_loop_->loop_var->name_hint, kCommReduce); - iter_vars_.push_back(wb_new_block_iter); - iter_values_.push_back(rf_loop_->loop_var); - var_map_.Set(rf_additional_iter_->var, wb_new_block_iter->var); + additional_iter_ = IterVarFromLoop(rf_loop_, "v" + rf_loop_->loop_var->name_hint, kCommReduce); + iter_vars_.insert(iter_vars_.end(), additional_iter_); + iter_values_.insert(iter_values_.end(), rf_loop_->loop_var); + var_map_.Set(rf_additional_iter_->var, additional_iter_->var); } - void CreateNormalIters(int idx) final { IterVar old_block_iter = old_block_realize_->block->iter_vars[idx]; if (old_block_iter->iter_type == IterVarType::kDataPar) { @@ -812,6 +812,16 @@ class WriteBackBlockCreator : public BaseBlockCreator { kDataPar); iter_values_.push_back(old_block_realize_->iter_values[idx]); var_map_.Set(old_block_iter->var, iter_vars_.back()); + return; + } + + ICHECK(old_block_iter->iter_type == IterVarType::kCommReduce); + // If the old block iter touches the reduction loop and we have not created a new reduction + // block iter for the rfactor loop, create one now. + if (!additional_iter_.defined() && + UsesVar(old_block_realize_->iter_values[idx], + [v = rf_loop_->loop_var.get()](const VarNode* var) { return var == v; })) { + CreateAdditionalIter(); } } diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index b466843f9459..80a2680aa3ad 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -165,6 +165,10 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheRead") .set_body_method(&ScheduleNode::CacheRead); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheWrite") .set_body_method(&ScheduleNode::CacheWrite); +/******** (FFI) Data movement ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReadAt").set_body_method(&ScheduleNode::ReadAt); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleWriteAt") + .set_body_method(&ScheduleNode::WriteAt); /******** (FFI) Compute location ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeAt") .set_body_method(&ScheduleNode::ComputeAt); @@ -226,6 +230,9 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnannotate") throw; }); +/******** (FFI) Layout transformation ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformLayout") + .set_body_method(&ScheduleNode::TransformLayout); /******** (FFI) Misc ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleEnterPostproc") .set_body_method(&ScheduleNode::EnterPostproc); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 1e2e57eb6eca..878a56934765 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -264,6 +264,31 @@ BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer return result; } +BlockRV TracedScheduleNode::ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, + int read_buffer_index, const String& storage_scope) { + BlockRV result = + ConcreteScheduleNode::ReadAt(loop_rv, block_rv, read_buffer_index, storage_scope); + + static const InstructionKind& kind = InstructionKind::Get("ReadAt"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv, block_rv}, + /*attrs=*/{Integer(read_buffer_index), storage_scope}, + /*outputs=*/{result})); + return result; +} + +BlockRV TracedScheduleNode::WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, + int write_buffer_index, const String& storage_scope) { + BlockRV result = + ConcreteScheduleNode::WriteAt(loop_rv, block_rv, write_buffer_index, storage_scope); + + static const InstructionKind& kind = InstructionKind::Get("WriteAt"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv, block_rv}, + /*attrs=*/{Integer(write_buffer_index), storage_scope}, + /*outputs=*/{result})); + return result; +} /******** Schedule: Compute location ********/ void TracedScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, @@ -273,7 +298,7 @@ void TracedScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_r static const InstructionKind& kind = InstructionKind::Get("ComputeAt"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, /*inputs=*/{block_rv, loop_rv}, - /*attrs=*/{Integer(preserve_unit_loops)}, + /*attrs=*/{Bool(preserve_unit_loops)}, /*outputs=*/{})); } @@ -284,7 +309,7 @@ void TracedScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopRV& static const InstructionKind& kind = InstructionKind::Get("ReverseComputeAt"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, /*inputs=*/{block_rv, loop_rv}, - /*attrs=*/{Integer(preserve_unit_loops)}, + /*attrs=*/{Bool(preserve_unit_loops)}, /*outputs=*/{})); } @@ -427,6 +452,19 @@ void TracedScheduleNode::Unannotate(const BlockRV& block_rv, const String& ann_k /*outputs=*/{})); } +/******** Schedule: Layout transformation ********/ + +void TracedScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_index, + bool is_write_index, const IndexMap& index_map) { + ConcreteScheduleNode::TransformLayout(block_rv, buffer_index, is_write_index, index_map); + static const InstructionKind& kind = InstructionKind::Get("TransformLayout"); + trace_->Append( + /*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{Integer(buffer_index), Bool(is_write_index), index_map}, + /*outputs=*/{})); +} + /******** Schedule: Misc ********/ void TracedScheduleNode::EnterPostproc() { diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 5d3fdbf570de..0a00eb3793b4 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -73,6 +73,11 @@ class TracedScheduleNode : public ConcreteScheduleNode { const String& storage_scope) final; BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope) final; + /******** Schedule: Data movement ********/ + BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) final; + BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope) final; /******** Schedule: Compute location ********/ void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) final; void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, @@ -95,6 +100,9 @@ class TracedScheduleNode : public ConcreteScheduleNode { void Unannotate(const LoopRV& loop_rv, const String& ann_key) override; void Annotate(const BlockRV& block_rv, const String& ann_key, const ObjectRef& ann_val) override; void Unannotate(const BlockRV& block_rv, const String& ann_key) override; + /******** Schedule: Layout transformation ********/ + void TransformLayout(const BlockRV& block_rv, int buffer_index, bool is_write_index, + const IndexMap& index_map) override; /******** Schedule: Misc ********/ void EnterPostproc() final; }; diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index ffb6b2d52628..fb3829c59a01 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -136,5 +136,98 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ throw OnlyLeafError(self->mod, GetRef(leaf_block), GetRef(scope_block)); } +/******** Utilities for tensorization ********/ + +class IRSubstituteInScope : public StmtExprMutator { + public: + explicit IRSubstituteInScope(std::function fmap) + : fmap_(std::move(fmap)) {} + + PrimExpr VisitExpr_(const VarNode* op) final { + auto it = fmap_(op); + if (it.defined()) { + return it; + } else { + return GetRef(op); + } + } + + Stmt VisitStmt_(const BlockRealizeNode* op) final { + arith::Analyzer analyzer; + auto fmutate = [&](const PrimExpr& e) { return this->VisitExpr(e); }; + Array v = op->iter_values; + v.MutateByApply(fmutate); + PrimExpr pred = this->VisitExpr(op->predicate); + if (v.same_as(op->iter_values) && pred.same_as(op->predicate)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->iter_values = std::move(v); + n->predicate = std::move(analyzer.Simplify(pred)); + return Stmt(n); + } + } + + private: + const std::function fmap_; +}; + +Stmt SubstituteInScope(const Stmt& stmt, + const std::function& value_func) { + return IRSubstituteInScope(value_func)(stmt); +} + +Stmt SubstituteInScope(const Stmt& stmt, + const std::unordered_map& var_map) { + auto vmap = [&](const VarNode* v) -> PrimExpr { + const auto& it = var_map.find(v); + if (it != var_map.end()) { + return it->second; + } else { + return NullValue(); + } + }; + return IRSubstituteInScope(vmap)(stmt); +} + +PrimExpr SubstituteInScope(const PrimExpr& expr, + const std::unordered_map& var_map) { + auto vmap = [&](const VarNode* v) -> PrimExpr { + const auto& it = var_map.find(v); + if (it != var_map.end()) { + return it->second; + } else { + return NullValue(); + } + }; + return IRSubstituteInScope(vmap)(expr); +} + +Stmt SubstituteInScope(const Stmt& stmt, + const std::unordered_map& var_map) { + auto vmap = [&](const VarNode* v) -> PrimExpr { + const auto& it = var_map.find(v); + if (it != var_map.end()) { + return GetRef(it->second); + } else { + return NullValue(); + } + }; + return IRSubstituteInScope(vmap)(stmt); +} + +PrimExpr SubstituteInScope(const PrimExpr& expr, + const std::unordered_map& var_map) { + auto vmap = [&](const VarNode* v) -> PrimExpr { + const auto& it = var_map.find(v); + if (it != var_map.end()) { + return GetRef(it->second); + } else { + return NullValue(); + } + }; + return IRSubstituteInScope(vmap)(expr); +} + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index 2de8ef6e0c93..9d4ddfc49fc2 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -267,7 +267,7 @@ inline Map AsIntSet(const Map& var_dom) { return {result.begin(), result.end()}; } -/**************** Loop extents ****************/ +/**************** PrimExpr parsing and extents ****************/ /*! * \brief Get the extents of a loop @@ -431,6 +431,40 @@ inline void ReorderAndFuseReductionLoops(const tir::Schedule& sch, const tir::Bl } } +/******** Tensorization ******/ +/*! + * \brief Rewrite the block's outer loops to match the tensor intrin + * \param sch The schedule + * \param block_rv The block_rv we want to rewrite + * \param intrin_name The name of the tensor intrin we want to match + */ +Optional TilingwithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv, + const String& intrin_name); + +/*! + * \brief Substitute the var in current block scope specified in key->var to be value. + * \param stmt The source stmt to be substituted + * \param value_func The function of new values mapping. + * \return The converted stmt. + */ +Stmt SubstituteInScope(const Stmt& stmt, const std::function& value_func); + +/*! + * \brief Substitute the var in current block scope specified in var map + * \param stmt The source stmt to be substituted + * \param var_map The mapping of var + * \return The converted stmt + */ +Stmt SubstituteInScope(const Stmt& stmt, + const std::unordered_map& var_map); + +/*! + * \param var_map The mapping of var + * \return The converted stmt + */ +Stmt SubstituteInScope(const Stmt& stmt, + const std::unordered_map& var_map); + } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/apply_block_bound_predicate.cc b/src/tir/transforms/apply_block_bound_predicate.cc new file mode 100644 index 000000000000..2e93f4b13063 --- /dev/null +++ b/src/tir/transforms/apply_block_bound_predicate.cc @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file apply_block_bound_predicate.cc + * \brief Apply the block iter bound predicate to loops. + */ + +#include +#include +#include + +#include "../../arith/pattern_match.h" +#include "ir_utils.h" + +namespace tvm { +namespace tir { + +class BoundPredicateParserSimplifier : public ExprMutator { + public: + explicit BoundPredicateParserSimplifier(Map binding_map, + Map* bound_intset) + : binding_map_(std::move(binding_map)), bound_intset_(bound_intset) {} + + private: + PrimExpr VisitExpr(const PrimExpr& expr) final { + if (expr->IsInstance() || expr->IsInstance() || expr->IsInstance()) { + return ExprMutator::VisitExpr(expr); + } + ICHECK(false) << "InternalError: PrimExpr \"" << expr + << "\" is not supposed to appear as a bound predicate"; + throw; + } + + PrimExpr VisitExpr_(const LTNode* lt) final { + const VarNode* var = lt->a.as(); + if (!var) { + ICHECK(false) << "InternalError: LHS of logical expression here is required to be variables"; + } + Optional binding = binding_map_.Get(GetRef(var)); + if (!binding.defined()) { + ICHECK(false) << "InternalError: The LHS variable is supposed to be a block iterator"; + } + const VarNode* loop_var = binding.value().as(); + if (!loop_var) { + return GetRef(lt); + } + + arith::IntSet intset = + bound_intset_->Get(GetRef(loop_var)).value_or(arith::IntSet::Everything()); + intset = arith::Intersect( + {intset, arith::IntSet::FromRange(Range(min_value(lt->b.dtype()), lt->b))}); + bound_intset_->Set(GetRef(loop_var), intset); + return const_true(); + } + + PrimExpr VisitExpr_(const GENode* ge) final { + const VarNode* var = ge->a.as(); + if (!var) { + ICHECK(false) << "InternalError: LHS of logical expression here is required to be variables"; + } + Optional binding = binding_map_.Get(GetRef(var)); + if (!binding.defined()) { + ICHECK(false) << "InternalError: The LHS variable is supposed to be a block iterator"; + } + const VarNode* loop_var = binding.value().as(); + if (!loop_var) { + return GetRef(ge); + } + + arith::IntSet intset = + bound_intset_->Get(GetRef(loop_var)).value_or(arith::IntSet::Everything()); + intset = arith::Intersect( + {intset, arith::IntSet::FromRange(Range(ge->b, max_value(ge->b.dtype())))}); + bound_intset_->Set(GetRef(loop_var), intset); + return const_true(); + } + + Map binding_map_; + Map* bound_intset_; +}; + +/*! + * \brief Narrow the extents of some loops by checking whether some constraints in the block iter + * bound predicates can be directly applied on the loops. + */ +class LoopExtentMutator : public StmtMutator { + private: + Stmt VisitStmt_(const BlockRealizeNode* realize) final { + // Step 1. Mutate recursively. + BlockRealize new_realize = Downcast(StmtMutator::VisitStmt_(realize)); + // Step 2. If the block has no "require_block_var_bound_predicate" annotation, skip this block. + Block block = new_realize->block; + const Optional& bound_predicate = + block->annotations.Get(tir::attr::require_block_var_bound_predicate); + if (!bound_predicate.defined()) { + return new_realize; + } + // Step 3. Make a mapping from block iters to bindings. + Map binding_map; + ICHECK_EQ(block->iter_vars.size(), new_realize->iter_values.size()); + int n_iter = static_cast(block->iter_vars.size()); + for (int i = 0; i < n_iter; ++i) { + binding_map.Set(block->iter_vars[i]->var, new_realize->iter_values[i]); + } + // Step 4. Parse the bound predicate, removing constraints on the block vars whose binding are + // single vars. + PrimExpr new_predicate = BoundPredicateParserSimplifier( + binding_map, &bound_intset_)(Downcast(bound_predicate.value())); + // Step 5. Update the block annotation and update the new block-realize. + ObjectPtr p_new_block = CopyOnWrite(block.get()); + if (ana_.CanProveEqual(new_predicate, const_true())) { + p_new_block->annotations.erase(tir::attr::require_block_var_bound_predicate); + } else { + p_new_block->annotations.Set(tir::attr::require_block_var_bound_predicate, new_predicate); + } + ObjectPtr p_new_realize = CopyOnWrite(new_realize.get()); + p_new_realize->block = Block(p_new_block); + + return BlockRealize(p_new_realize); + } + + Stmt VisitStmt_(const ForNode* loop) final { + // Step 1. Mutate recursively. + For new_loop = Downcast(StmtMutator::VisitStmt_(loop)); + // Step 2. Check whether this loop has a bound intset. If not, return the new loop. + Optional intset = bound_intset_.Get(new_loop->loop_var); + if (!intset.defined()) { + return new_loop; + } + // Step 3. Update the new loop's `min` and `extent` according to the extent. + PrimExpr new_min = max(new_loop->min, intset.value().min()); + PrimExpr new_extent = min(new_loop->min + new_loop->extent, intset.value().max() + 1) - new_min; + // Step 4. Update the new loop. + ObjectPtr p_new_loop = CopyOnWrite(new_loop.get()); + p_new_loop->min = ana_.Simplify(new_min); + p_new_loop->extent = ana_.Simplify(new_extent); + + return For(p_new_loop); + } + + /*! \brief The bounds of loop vars, provided by the block iter bound predicate */ + Map bound_intset_; + /*! \brief The analyzer */ + arith::Analyzer ana_; +}; + +PrimFunc ApplyBlockBoundPredicate(PrimFunc f) { + // Only apply this pass to TIR that is not from TE schedules + if (!IsFromLegacyTESchedule(f)) { + PrimFuncNode* fptr = f.CopyOnWrite(); + fptr->body = LoopExtentMutator()(f->body); + return f; + } else { + return f; + } +} + +namespace transform { + +Pass ApplyBlockBoundPredicate() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return ApplyBlockBoundPredicate(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.ApplyBlockBoundPredicate", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.ApplyBlockBoundPredicate") + .set_body_typed(ApplyBlockBoundPredicate); +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index 20ddd7f84a35..c12671bf98fc 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -53,12 +53,29 @@ Region SimplifyAndNarrowBufferRegionFromNDIntSet(const NDIntSet& nd_int_set, for (size_t i = 0; i < nd_int_set.size(); ++i) { const arith::IntSet& int_set = nd_int_set[i]; Range range = int_set.CoverRange(Range(/*begin=*/0, /*end=*/original_shape[i])); - result.push_back( - Range::FromMinExtent(analyzer->Simplify(range->min), analyzer->Simplify(range->extent))); + result.push_back(Range::FromMinExtent( + analyzer->Simplify(range->min), analyzer->Simplify(min(original_shape[i], range->extent)))); } return result; } +NDIntSet NDIntSetEval(Region region, PrimExpr predicate, + std::unordered_map& dom_map, + arith::Analyzer* analyzer) { + std::unordered_map var_dom; + for (const auto& it : dom_map) { + var_dom[GetRef(it.first)] = it.second.CoverRange(Range::FromMinExtent(0, 0)); + } + Optional> eval_res = + arith::EstimateRegionLowerBound(region, var_dom, predicate, analyzer); + if (eval_res.defined()) { + NDIntSet res(0); + for (const auto& it : eval_res.value()) res.push_back(it); + return res; + } + return support::NDIntSetEval(support::NDIntSetFromRegion(region), dom_map); +} + /*! * \brief Collect the access region of each buffer. * \note The param buffer regions will not be collected. @@ -149,7 +166,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor { } return; } - return StmtExprVisitor::VisitExpr_(op); + StmtExprVisitor::VisitExpr_(op); } void VisitStmt_(const BlockNode* op) final { @@ -198,6 +215,13 @@ class BufferAccessRegionCollector : public StmtExprVisitor { } } + void VisitStmt_(const BlockRealizeNode* op) final { + PrimExpr cur_predicate = predicate_in_scope; + predicate_in_scope = op->predicate; + StmtExprVisitor::VisitStmt_(op); + predicate_in_scope = cur_predicate; + } + /**************** Helper functions ****************/ void VisitBufferAccess(const BufferRegion& buffer_region) { @@ -206,7 +230,6 @@ class BufferAccessRegionCollector : public StmtExprVisitor { if (it != buffer_var_in_scope_.end()) { const Buffer& buffer = it->second.first; size_t n_ancestor_loops = it->second.second; - NDIntSet nd_int_set = support::NDIntSetFromRegion(buffer_region->region); // Step 1. Stop ancestor loop vars out of the allocation block from // being relaxed unless NeedRelaxThread() is true. std::vector non_relaxed(n_ancestor_loops); @@ -222,7 +245,8 @@ class BufferAccessRegionCollector : public StmtExprVisitor { dom_map_.erase(dom_it); } // Step 2. Relax the access region - nd_int_set = support::NDIntSetEval(nd_int_set, dom_map_); + NDIntSet nd_int_set = + NDIntSetEval(buffer_region->region, predicate_in_scope, dom_map_, &dom_analyzer_); // Step 3. Restore the non-relaxed ancestor loops domain for (size_t i = 0; i < n_ancestor_loops; ++i) { const VarNode* v = ancestor_loops_[i]->loop_var.get(); @@ -279,6 +303,8 @@ class BufferAccessRegionCollector : public StmtExprVisitor { */ std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_var_in_scope_; + /*! \brief The block predicate of current scope */ + PrimExpr predicate_in_scope{true}; /*! \brief The map from loop vars to their iter range. */ std::unordered_map dom_map_; diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index e9d99cda7e13..e3b32cb6c460 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -65,6 +65,12 @@ class BufferFlattener : public StmtExprMutator { if (!is_one(predicate)) { body = IfThenElse(predicate, std::move(body)); } + // If the block has bound predicates, transform it to if-then-else + const Optional& bound_predicate = + new_block->annotations.Get(tir::attr::require_block_var_bound_predicate); + if (bound_predicate.defined()) { + body = IfThenElse(Downcast(bound_predicate.value()), std::move(body)); + } // Step 3. Handle allocations in reverse order for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) { const Buffer& buffer = new_block->alloc_buffers[i - 1]; diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc new file mode 100644 index 000000000000..a5cecc1d4707 --- /dev/null +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -0,0 +1,808 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file inject_software_pipeline.cc + * \brief Transform annotated loops into pipelined one that parallelize producers and consumers + */ +#include +#include +#include + +#include "../../support/utils.h" +#include "../schedule/utils.h" +#include "./ir_utils.h" + +namespace tvm { +namespace tir { + +namespace software_pipeline { + +Block MakeBlock(const Stmt& body, const Map& buffer_data_to_buffer) { + Block block = Block({}, {}, {}, "", body); + auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer); + auto* n = block.CopyOnWrite(); + n->reads = access[0]; + n->writes = access[1]; + return block; +} + +struct PipelineStageOrder { + int stage; + int order; + PipelineStageOrder(int stage, int order) : stage(stage), order(order) {} +}; + +using PipelineInfo = std::unordered_map; + +struct BufferAccessInfo { + int def; // the defining stage of the buffer + int use; // the last using stage of the buffer + BufferAccessInfo(int def = -1, int use = -1) : def(def), use(use){}; +}; + +/*! + * \brief Rewriter for the body of the software pipeline. This pass inserts `floormod` to indices + * of accessing to remapped buffer to select the version corresponding to the pipeline stage. + */ +class PipelineBodyRewriter : public StmtExprMutator { + public: + /*! + * \brief Constructor of PipelineBodyRewriter. + * \param buffer_data_to_buffer The map from buffer data to buffer. + * \param buffer_remap The map from original buffer to the buffer with updated shape for + * multi-versioning in the sofeware pipeline. + * \param pipeline_loop The original loop to be software pipelined. + * \param access_all_versions Whether all versions the the buffers in the software pipeline are + * accessed. This will be used to update block access region. In the prologue and epilogue + * of a two-stage software pipeline, only one version of these buffers are accessed. + */ + PipelineBodyRewriter(const Map& buffer_data_to_buffer, + const Map& buffer_remap, For pipeline_loop, + bool access_all_versions) + : buffer_data_to_buffer_(buffer_data_to_buffer), + buffer_remap_(buffer_remap), + pipeline_loop_(pipeline_loop), + access_all_versions_(access_all_versions) {} + + private: + BufferRegion RewritePipelineBufferRegion(const BufferRegion& buffer_region) const { + auto it = buffer_remap_.find(buffer_region->buffer); + if (it != buffer_remap_.end()) { + Region new_region = buffer_region->region; + const Buffer& new_buffer = (*it).second; + // For pipeline buffers, always relax the access region of the first dimension to full extent + Range accessed_version = + access_all_versions_ + ? Range::FromMinExtent(0, new_buffer->shape[0]) + : Range::FromMinExtent(floormod((pipeline_loop_->loop_var - pipeline_loop_->min), + new_buffer->shape[0]), + Integer(1)); + new_region.insert(new_region.begin(), accessed_version); + return BufferRegion(new_buffer, new_region); + } + return buffer_region; + } + + Stmt VisitStmt_(const BlockNode* op) final { + for (const Buffer& alloc_buffer : op->alloc_buffers) { + buffer_data_to_buffer_.Set(alloc_buffer->data, alloc_buffer); + } + Block block = Downcast(StmtExprMutator::VisitStmt_(op)); + BlockNode* n = block.CopyOnWrite(); + n->reads.MutateByApply( + std::bind(&PipelineBodyRewriter::RewritePipelineBufferRegion, this, std::placeholders::_1)); + n->writes.MutateByApply( + std::bind(&PipelineBodyRewriter::RewritePipelineBufferRegion, this, std::placeholders::_1)); + for (const Buffer& alloc_buffer : op->alloc_buffers) { + buffer_data_to_buffer_.erase(alloc_buffer->data); + } + return block; + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); + auto it = buffer_remap_.find(store->buffer); + if (it == buffer_remap_.end()) { + return std::move(store); + } + const Buffer& new_buffer = (*it).second; + auto* n = store.CopyOnWrite(); + n->buffer = new_buffer; + PrimExpr version = + floormod((pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]); + n->indices.insert(n->indices.begin(), version); + return std::move(store); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); + auto it = buffer_remap_.find(load->buffer); + if (it == buffer_remap_.end()) { + return std::move(load); + } + const Buffer& new_buffer = (*it).second; + auto* n = load.CopyOnWrite(); + n->buffer = new_buffer; + PrimExpr version = + floormod((pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]); + n->indices.insert(n->indices.begin(), version); + return std::move(load); + } + + PrimExpr RewriteWmmaFragmentIndex(const Buffer& old_buffer, const Buffer& new_buffer, + const PrimExpr& old_index) { + PrimExpr new_buffer_offset = old_index; + + const int fragment_size = 256; + PrimExpr offset = + floordiv(foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, + make_const(DataType::Int(32), 1), old_buffer->shape), + fragment_size); + new_buffer_offset += + floormod(pipeline_loop_->loop_var - pipeline_loop_->min, new_buffer->shape[0]) * offset; + return new_buffer_offset; + } + + PrimExpr VisitExpr_(const CallNode* op) final { + // Intrinsic calls should be handled explicitly here as they are opaque accesses to + // buffer. + static const auto& load_matrix_sync = builtin::tvm_load_matrix_sync(); + static const auto& store_matrix_sync = builtin::tvm_store_matrix_sync(); + static const auto& mma_sync = builtin::tvm_mma_sync(); + static const auto& access_ptr = builtin::tvm_access_ptr(); + Call call = Downcast(StmtExprMutator::VisitExpr_(op)); + if (call->op.same_as(load_matrix_sync) || call->op.same_as(store_matrix_sync)) { + const Buffer& buffer = buffer_data_to_buffer_.at(Downcast(call->args[0])); + auto it = buffer_remap_.find(buffer); + if (it != buffer_remap_.end()) { + Array new_args = call->args; + const Buffer& new_buffer = (*it).second; + new_args.Set(4, RewriteWmmaFragmentIndex(buffer, new_buffer, call->args[4])); + return Call(call->dtype, call->op, new_args, call->span); + } + } else if (call->op.same_as(mma_sync)) { + Array new_args = call->args; + for (int i = 0; i < 4; i++) { + const Var& buffer_var = Downcast(call->args[i * 2]); + const PrimExpr& index = call->args[i * 2 + 1]; + const Buffer& buffer = buffer_data_to_buffer_.at(buffer_var); + auto it = buffer_remap_.find(buffer); + if (it != buffer_remap_.end()) { + PrimExpr new_index = RewriteWmmaFragmentIndex(buffer, (*it).second, index); + new_args.Set(i * 2 + 1, new_index); + } + } + return Call(call->dtype, call->op, new_args, call->span); + } else if (call->op.same_as(access_ptr)) { + const Buffer& buffer = buffer_data_to_buffer_.at(Downcast(call->args[1])); + auto it = buffer_remap_.find(buffer); + if (it != buffer_remap_.end()) { + Array new_args = call->args; + const Buffer& new_buffer = (*it).second; + const PrimExpr& old_index = call->args[2]; + PrimExpr offset; + if (new_buffer->strides.empty()) { + offset = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, + make_const(DataType::Int(32), 1), buffer->shape); + } else { + offset = new_buffer->strides[0]; + } + PrimExpr new_index = old_index + floormod(pipeline_loop_->loop_var, 2) * offset; + new_args.Set(2, new_index); + return Call(call->dtype, call->op, new_args, call->span); + } + } + return std::move(call); + } + + Map buffer_data_to_buffer_; + Map buffer_remap_; + For pipeline_loop_; + bool access_all_versions_; +}; + +class PipelineRewriter : public StmtExprMutator { + public: + static Stmt Rewrite( + Map buffer_data_to_buffer, + const std::unordered_set& double_buffers, + const Array pipeline_allocs, const For& pipeline_loop, + const PipelineInfo& pipeline_info) { + PipelineRewriter rewriter(buffer_data_to_buffer, double_buffers, pipeline_allocs, pipeline_loop, + pipeline_info); + return rewriter.BuildPipeline(); + } + + private: + PipelineRewriter(Map buffer_data_to_buffer, + const std::unordered_set& double_buffers, + const Array& pipeline_allocs, const For& pipeline_loop, + const PipelineInfo& pipeline_info) + + : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)), + double_buffers_(double_buffers), + pipeline_allocs_(pipeline_allocs), + pipeline_loop_(pipeline_loop), + pipeline_info_(pipeline_info) {} + + Stmt BuildPipeline() { + // Step 1: Analyze accesses to the buffers in the pipeline and compute the number of versions + // need to maintain for each buffer. + RemapPipelineBuffers(pipeline_allocs_); + + ordered_stmts_.resize(pipeline_info_.size()); + for (const auto& pair : pipeline_info_) { + const Block& block = pair.first; + int order = pair.second.order; + ordered_stmts_.Set(order, block); + } + + // Step 2: Emit the pipeline prologue, body and epilogue. + Stmt prologue = EmitImpl(pipeline_loop_->min, pipeline_loop_->min + max_stage_, true); + Stmt body = EmitImpl(pipeline_loop_->min + max_stage_, + pipeline_loop_->min + pipeline_loop_->extent, false); + Stmt epilogue = EmitImpl(pipeline_loop_->min + pipeline_loop_->extent, + pipeline_loop_->min + pipeline_loop_->extent + max_stage_, true); + + SeqStmt stmt = SeqStmt({prologue, body, epilogue}); + + // Step 3: Add annotations of nested software pipeline (if appliable) + stmt = AnnotateNestedPipeline(stmt); + + // Step 4: Make a new block that contains new buffer allocations after pipeline rewriting. + Array alloc_buffers; + for (const auto& alloc : pipeline_allocs_) { + auto it = buffer_remap_.find(alloc); + if (it != buffer_remap_.end()) { + alloc_buffers.push_back((*it).second); + } else { + alloc_buffers.push_back(alloc); + } + buffer_data_to_buffer_.erase(alloc->data); + } + Block block = MakeBlock(stmt, buffer_data_to_buffer_); + auto* n = block.CopyOnWrite(); + n->alloc_buffers = std::move(alloc_buffers); + return BlockRealize({}, Bool(true), block); + } + + private: + /*! + * \brief Annotate the result of software pipeline rewriting with user-provided annotations. + * + * When there are nested software pipelines, after rewriting the inner software pipeline, + * it is required to add annotations to the result of the inner software pipeline to specify + * the rewriting behavior of the outer software pipeline. + * This method expects the annotations `attr::nested_software_pipeline_order`, and + * `attr::nested_software_pipeline_stage` are present on the inner software pipeline loop. + * + * \param pipeline_seq The sequence of statements after pipeline rewriting, which consists of + * three BlockRealize that represents the prologue, the body, and the epilogue of the software + * pipeline. + * \return The sequence of the statements that consists of the annotated software pipeline. + */ + SeqStmt AnnotateNestedPipeline(const SeqStmt& pipeline_seq) { + auto it = pipeline_loop_->annotations.find(attr::nested_software_pipeline_stage); + if (it == pipeline_loop_->annotations.end()) { + return pipeline_seq; + } + Array nested_stage = Downcast>((*it).second); + CHECK(pipeline_loop_->annotations.count(attr::nested_software_pipeline_order)) + << "ValueError: Annotation for the order of the nested software pipeline is missing."; + Array nested_order = Downcast>( + pipeline_loop_->annotations.at(attr::nested_software_pipeline_order)); + CHECK_EQ(nested_stage.size(), 3) << "ValueError: Annotation for the stage of the nested " + "software pipeline should be a 3-tuple"; + CHECK_EQ(nested_order.size(), 3) << "ValueError: Annotation for the order of the nested " + "software pipeline should be a 3-tuple"; + Array new_seq; + new_seq.reserve(pipeline_seq->seq.size()); + for (size_t i = 0; i < pipeline_seq->seq.size(); i++) { + BlockRealize block_realize = Downcast(pipeline_seq->seq[i]); + auto* block = block_realize.CopyOnWrite()->block.CopyOnWrite(); + block->annotations.Set(attr::software_pipeline_stage, nested_stage[i]); + block->annotations.Set(attr::software_pipeline_order, nested_order[i]); + new_seq.push_back(std::move(block_realize)); + } + return SeqStmt(std::move(new_seq)); + } + + /*! + * \brief Analyze accesses to the buffers in the software pipeline. + * + * This method check the 'define' and 'use' stage of the buffers in the software pipeline, which + * can be used to compute the number of versions needed to maintain after rewriting. + */ + std::unordered_map + GetBufferAccessInfo() { + std::unordered_map infos; + for (const auto& pair : pipeline_info_) { + const Block& block = pair.first; + int stage = pair.second.stage; + max_stage_ = std::max(max_stage_, stage); + + for (const BufferRegion& write : block->writes) { + if (!infos.count(write->buffer)) { + infos.emplace(write->buffer, BufferAccessInfo{}); + } + auto& info = infos.at(write->buffer); + if (info.def == -1) { + info.def = stage; + } + } + + for (const BufferRegion& read : block->reads) { + if (!infos.count(read->buffer)) { + infos.emplace(read->buffer, BufferAccessInfo{}); + } + auto& info = infos.at(read->buffer); + info.use = std::max(info.use, stage); + } + } + return infos; + } + + /*! + * \brief Check whether two regions have intersections. + * \param region1 The first region. + * \param region2 The second region. + * \return Whether region1 and region2 have intersections. + */ + bool MayConflict(Region region1, Region region2) { + ICHECK(region1.size() == region2.size()); + for (size_t i = 0; i < region1.size(); i++) { + Range dim1 = region1[i]; + Range dim2 = region2[i]; + auto int_set1 = arith::IntSet::FromRange(dim1); + auto int_set2 = arith::IntSet::FromRange(dim2); + if (arith::Intersect({int_set1, int_set2}).IsNothing()) { + return false; + } + } + return true; + } + + /*! + * \brief Compute the number of versions need to maintain for buffer accessed in the software + * pipeline. + * + * This method applies liveness analysis to the target buffer to compute the number of versions + * need to maintain during the software pipeline. + * Annotation `attr::double_buffer_scope` is handled here which provides a way to override the + * result of the analysis. Additional double buffering in the software pipeline can be useful + * to eliminate synchonizations in GPU devices. + * + * \param buffer The target buffer + * \param buffer_info The access information of the target buffer. + * \return The number of versions required for the target buffer. + */ + int ComputeBufferVersions(const Buffer& buffer, const BufferAccessInfo& buffer_info) { + if (buffer_info.def == -1) { + // Keep the original number of versions as buffers defined outside the software pipeline + // should not be mutated. + return 1; + } + + // `use - def + 1` is a upper bound of the needed versions + // We optimize a few case where the number of versions can be smaller than the upper bound + int num_versions = buffer_info.use - buffer_info.def + 1; + if (num_versions == 2) { + // A special case when `use - def + 1 == 2`. Double buffering is only needed in this case when + // these exists a reader block_i and a writer block_j such that + // order(block_i) < order(block_j) and stage(block_i) < stage(block_j) and the access regions + // of block_i and block_j overlap. + bool need_multi_version = false; + for (const auto& pair1 : pipeline_info_) { + const Block& writer_block = pair1.first; + const auto& writer_info = pair1.second; + + auto it1 = std::find_if(writer_block->writes.begin(), writer_block->writes.end(), + [&](const BufferRegion& buffer_region) { + return buffer_region->buffer.same_as(buffer); + }); + if (it1 == writer_block->writes.end()) { + continue; + } + + for (const auto& pair2 : pipeline_info_) { + const Block& reader_block = pair2.first; + const auto& reader_info = pair2.second; + auto it2 = std::find_if(reader_block->reads.begin(), reader_block->reads.end(), + [&](const BufferRegion& buffer_region) { + return buffer_region->buffer.same_as(buffer); + }); + if (it2 == reader_block->reads.end()) { + continue; + } + if (writer_info.order < reader_info.order && writer_info.stage < reader_info.stage && + MayConflict((*it1)->region, (*it2)->region)) { + need_multi_version = true; + break; + } + } + } + if (!need_multi_version) { + num_versions = 1; + } + } + if (num_versions == 1 && double_buffers_.count(buffer)) { + num_versions = 2; + } + return num_versions; + } + + /*! + * \brief Rewrite buffer allocations to create new buffers with new shapes according to + * the software pipeline. + * \param pipeline_allocs The buffer allocations inside the software pipeline scope. + */ + void RemapPipelineBuffers(Array pipeline_allocs) { + std::unordered_map infos = + GetBufferAccessInfo(); + for (const auto& pair : infos) { + const Buffer& buffer = pair.first; + const BufferAccessInfo& buffer_info = pair.second; + int num_versions = ComputeBufferVersions(buffer, buffer_info); + if (num_versions > 1) { + Buffer new_buffer = RewriteAllocBuffer(buffer, num_versions); + CHECK(std::find(pipeline_allocs.begin(), pipeline_allocs.end(), buffer) != + pipeline_allocs.end()); + buffer_remap_.Set(pair.first, new_buffer); + } + } + } + + /*! + * \brief Rewrite buffer allocation to keep multiple versions of original buffer for pipelined + * accesses. + * \param buffer The buffer to be resized. + * \param num_versions The number of versions to keep. + * \return The resized buffer. + */ + Buffer RewriteAllocBuffer(const Buffer& buffer, int num_versions) { + ObjectPtr new_buffer = make_object(*(buffer.get())); + new_buffer->shape.insert(new_buffer->shape.begin(), num_versions); + if (new_buffer->strides.size()) { + ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size()); + PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1]; + new_buffer->strides.insert(new_buffer->strides.begin(), stride_0); + } + return Buffer(new_buffer); + } + + Stmt EmitImpl(PrimExpr start, PrimExpr end, bool unroll_loop) { + Array stmts; + PrimExpr new_loop_var; + bool is_unit_loop = analyzer_.CanProveEqual(start + 1, end); + if (is_unit_loop) { + new_loop_var = start; + } else { + new_loop_var = pipeline_loop_->loop_var.copy_with_suffix(""); + analyzer_.Bind(Downcast(new_loop_var), Range(start, end), true); + } + + for (const Block block : ordered_stmts_) { + int stage = pipeline_info_.at(block).stage; + PrimExpr skewed_loop_var = new_loop_var - stage; + PrimExpr inbound = (skewed_loop_var >= pipeline_loop_->min) && + (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent); + inbound = analyzer_.Simplify(inbound); + if (analyzer_.CanProve(!inbound)) { + continue; + } + Block new_block = Downcast(PipelineBodyRewriter( + buffer_data_to_buffer_, buffer_remap_, pipeline_loop_, max_stage_ != 1)(block)); + Map subst_map; + if (is_unit_loop) { + subst_map.Set(pipeline_loop_->loop_var, skewed_loop_var); + } else { + // normalize loop range + subst_map.Set(pipeline_loop_->loop_var, skewed_loop_var + (start - pipeline_loop_->min)); + } + new_block = Downcast(Substitute(new_block, subst_map)); + stmts.push_back(BlockRealize({}, inbound, new_block)); + } + + Stmt stmt; + if (is_unit_loop) { + stmt = stmts.size() == 1 ? stmts[0] : SeqStmt(stmts); + } else { + stmt = For(Downcast(new_loop_var), pipeline_loop_->min, end - start, + unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind, SeqStmt(stmts)); + } + if (stmt->IsInstance()) { + return stmt; + } + return BlockRealize({}, Bool(true), MakeBlock(stmt, buffer_data_to_buffer_)); + } + + arith::Analyzer analyzer_; + Map buffer_data_to_buffer_; + const std::unordered_set& double_buffers_; + Array pipeline_allocs_; + For pipeline_loop_; + PipelineInfo pipeline_info_; + int max_stage_ = -1; + Map buffer_remap_; + Array ordered_stmts_; +}; + +class PipelineInjector : private StmtExprMutator { + public: + static Stmt Inject(const PrimFunc& func) { + PipelineInjector injector; + for (const auto& kv : func->buffer_map) { + const Buffer& buffer = kv.second; + injector.buffer_data_to_buffer_.Set(buffer->data, buffer); + } + return injector(func->body); + } + + private: + PipelineInjector() = default; + + PipelineStageOrder CheckAndRemovePipelineAnnotation(Map* annotations) const { + CHECK(annotations->count(attr::software_pipeline_stage)) + << "ValueError: Stage of the statement in the software pipeline is not defined."; + CHECK(annotations->count(attr::software_pipeline_order)) + << "ValueError: Order of the statement in the software pipeline is not defined."; + Integer stage = Downcast(annotations->at(attr::software_pipeline_stage)); + Integer order = Downcast(annotations->at(attr::software_pipeline_order)); + annotations->erase(attr::software_pipeline_stage); + annotations->erase(attr::software_pipeline_order); + return {static_cast(stage->value), static_cast(order->value)}; + } + + /*! + * \brief Check the pipeline satisfies the following conditions: + * 1) No conflicting order: The order of each statement should be unique. + * 2) No reordering with the same stage: Statements in the same stage are not allowed to be + * reordered. + */ + void ValidatePipelineBody(const PipelineInfo& pipeline_info, const Array& original_order) { + std::unordered_set used_orders; + std::unordered_map stage_max_order; + for (const Block& block : original_order) { + const auto& stmt_info = pipeline_info.at(block); + int stage = stmt_info.stage; + int order = stmt_info.order; + CHECK(!used_orders.count(order)) + << "ValueError: Two statements in the software pipeline cannot have the same order"; + used_orders.insert(order); + CHECK(!stage_max_order.count(stage) || stage_max_order[stage] < order) + << "ValueError: Statements in the same stage of the software pipeline must have " + "increasing order."; + stage_max_order[stage] = order; + } + } + + Stmt VisitStmt_(const ForNode* op) final { + // Step 1: Recursively rewrite the children first. + For for_node = Downcast(StmtExprMutator::VisitStmt_(op)); + bool is_pipeline = HasPipelineAnnotation(op); + if (!is_pipeline) { + return std::move(for_node); + } + // Step 2: Find the body of the pipeline. It can be direct child of the for-loop. If the + // for-loop as BlockRealize as its child, the pipeline body will be the child of the block. + Stmt pipeline_body; + Array pipeline_allocs; + if (const auto* realize = for_node->body.as()) { + const auto& block = realize->block; + for (const auto& buffer : block->alloc_buffers) { + ICHECK(buffer->IsInstance()); + buffer_data_to_buffer_.Set(buffer->data, buffer); + } + pipeline_body = block->body; + pipeline_allocs = block->alloc_buffers; + } else { + pipeline_body = for_node->body; + } + + const SeqStmtNode* pipeline_body_seq = pipeline_body.as(); + CHECK(pipeline_body_seq) + << "ValueError: The body of the software pipeline should be SeqStmt, got " + << pipeline_body->GetTypeKey(); + const SeqStmtNode* original_seq = + op->body->IsInstance() + ? op->body.as()->block->body.as() + : op->body.as(); + ICHECK(original_seq); + + // Step 3: Blockize the components of the pipeline. Each child of the pipelined loop should + // be converted into a block. + PipelineInfo pipeline_info; + Array original_order; + + auto f_add_child = [&](const Stmt& child) { + const auto* block_realize = child.as(); + Block block = (block_realize && is_one(block_realize->predicate)) + ? block_realize->block + : MakeBlock(child, buffer_data_to_buffer_); + original_order.push_back(block); + }; + for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) { + const auto* nested_block_realize = pipeline_body_seq->seq[i].as(); + if (nested_block_realize && is_one(nested_block_realize->predicate) && + nested_block_realize->block->body->IsInstance()) { + const Block& nested_pipeline_block = nested_block_realize->block; + ICHECK( + nested_pipeline_block->match_buffers.empty()); // match_buffer should have been lowered + for (const auto& buffer : nested_pipeline_block->alloc_buffers) { + pipeline_allocs.push_back(buffer); + buffer_data_to_buffer_.Set(buffer->data, buffer); + } + const auto* nested_seq = nested_pipeline_block->body.as(); + for (size_t j = 0; j < nested_seq->seq.size(); j++) { + f_add_child(nested_seq->seq[j]); + } + } else { + f_add_child(pipeline_body_seq->seq[i]); + } + } + + auto pipeline_stages = + Downcast>(op->annotations.at(attr::software_pipeline_stage)); + auto pipeline_orders = + Downcast>(op->annotations.at(attr::software_pipeline_order)); + CHECK_EQ(pipeline_stages.size(), original_order.size()); + CHECK_EQ(pipeline_orders.size(), original_order.size()); + for (size_t i = 0; i < pipeline_stages.size(); i++) { + PipelineStageOrder stage_order(pipeline_stages[i]->value, pipeline_orders[i]->value); + pipeline_info.emplace(original_order[i], stage_order); + } + // ValidatePipelineBody(pipeline_info, original_order); + + // Step 4: Rewrite the pipeline body. + Stmt pipeline = PipelineRewriter::Rewrite(buffer_data_to_buffer_, double_buffers, + pipeline_allocs, GetRef(op), pipeline_info); + + if (const auto* realize = op->body.as()) { + const auto& block = realize->block; + for (const auto& buffer : block->alloc_buffers) { + buffer_data_to_buffer_.erase(buffer->data); + } + } + return pipeline; + } + + /*! + * \brief Add buffer allocations to a block and update the write region of the block. + * \param n The block pointer to which the buffer allocations are added. + * \param alloc_buffers The buffer allocations to be added. + */ + void AddAllocBuffers(BlockNode* n, const Array alloc_buffers) { + for (const Buffer& alloc_buffer : alloc_buffers) { + n->alloc_buffers.push_back(alloc_buffer); + Region region; + region.reserve(alloc_buffer->shape.size()); + for (const PrimExpr& dim : alloc_buffer->shape) { + region.push_back(Range::FromMinExtent(0, dim)); + } + n->writes.push_back(BufferRegion(alloc_buffer, region)); + } + } + + /*! + * \brief Flatten nested SeqStmt while passing through BlockRealize / Block. + * \param block The block which has SeqStmt body to rewrite. + * \return The new block that contains flattened SeqStmt as its body. + */ + Block FlattenNestedBlocks(Block block) { + const SeqStmtNode* seq = block->body.as(); + auto* n = block.CopyOnWrite(); + Array new_seq; + new_seq.reserve(seq->seq.size()); + bool changed = false; + for (size_t i = 0; i < seq->seq.size(); i++) { + const auto* nested_block_realize = seq->seq[i].as(); + if (!nested_block_realize || !is_one(nested_block_realize->predicate) || + !nested_block_realize->block->body->IsInstance()) { + new_seq.push_back(seq->seq[i]); + continue; + } + AddAllocBuffers(n, nested_block_realize->block->alloc_buffers); + const auto* nested_seq = nested_block_realize->block->body.as(); + new_seq.reserve(new_seq.size() + nested_seq->seq.size()); + for (const auto& nested_seq_body : nested_seq->seq) { + new_seq.push_back(nested_seq_body); + } + changed = true; + } + if (changed) { + n->body = SeqStmt(new_seq); + } + return block; + } + + Stmt VisitStmt_(const BlockNode* op) final { + for (const auto& buffer : op->alloc_buffers) { + ICHECK(buffer->IsInstance()); + buffer_data_to_buffer_.Set(buffer->data, buffer); + } + + auto it = op->annotations.find(attr::double_buffer_scope); + if (it != op->annotations.end()) { + int buffer_index = Downcast((*it).second); + CHECK(buffer_index >= 0 && static_cast(buffer_index) < op->writes.size()) + << "ValueError: Index of the buffer exceeds the size of the write regions of the block. (" + << buffer_index << " vs. " << op->writes.size() << ")"; + double_buffers.insert(op->writes[buffer_index]->buffer); + } + Block block = Downcast(StmtExprMutator::VisitStmt_(op)); + + // if (block->body->IsInstance()) { + // // Rewriting for software pipelining will produce nested SeqStmt. These statements need to + // be + // // flattened for rewriting outer software pipeline (if nested software pipelines are + // present). block = FlattenNestedBlocks(block); + // } + + for (const auto& buffer : op->alloc_buffers) { + buffer_data_to_buffer_.erase(buffer->data); + } + return block; + } + + bool HasPipelineAnnotation(const ForNode* op) const { + auto it1 = op->annotations.find(attr::software_pipeline_stage); + auto it2 = op->annotations.find(attr::software_pipeline_order); + bool has_stage = it1 != op->annotations.end(); + bool has_order = it2 != op->annotations.end(); + if (has_stage && has_order) { + return true; + } + if (has_stage) { + LOG(FATAL) << "ValueError: Order of the software pipeline is not defined."; + } + if (has_order) { + LOG(FATAL) << "ValueError: Stage of the software pipeline is not defined."; + } + return false; + } + + Map buffer_data_to_buffer_; + std::unordered_set double_buffers; +}; + +} // namespace software_pipeline + +namespace transform { + +/*! + * \brief Transform annotated loops into pipelined one that parallelize producers and consumers. + * \return The IR transform pass. + */ +Pass InjectSoftwarePipeline() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto* fptr = f.CopyOnWrite(); + fptr->body = software_pipeline::PipelineInjector::Inject(f); + fptr->body = ConvertSSA(std::move(fptr->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.InjectSoftwarePipeline", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.InjectSoftwarePipeline").set_body_typed(InjectSoftwarePipeline); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc index 4df38ff543b5..d2dd95a581ed 100644 --- a/src/tir/transforms/lower_cross_thread_reduction.cc +++ b/src/tir/transforms/lower_cross_thread_reduction.cc @@ -149,22 +149,22 @@ Array RemoveBufferFromBufferRegions(const Array& buf /*! * \brief Substitute a given source buffer with a given target buffer in statements or expressions */ -class BufferReplacer : private StmtExprMutator { +class BufferMutator : private StmtExprMutator { public: static Stmt Run(Buffer src_buffer, Buffer tgt_buffer, Stmt stmt) { - return BufferReplacer(src_buffer, tgt_buffer)(std::move(stmt)); + return BufferMutator(src_buffer, tgt_buffer)(std::move(stmt)); } private: - explicit BufferReplacer(Buffer src_buffer, Buffer tgt_buffer) + explicit BufferMutator(Buffer src_buffer, Buffer tgt_buffer) : src_buffer_(std::move(src_buffer)), tgt_buffer_(std::move(tgt_buffer)) {} - PrimExpr VisitExpr_(const BufferLoadNode* load) final { + PrimExpr VisitExpr_(const BufferLoadNode* load) override { return load->buffer.same_as(src_buffer_) ? BufferLoad(tgt_buffer_, {0}) : GetRef(load); } - Stmt VisitStmt_(const BufferStoreNode* store) final { + Stmt VisitStmt_(const BufferStoreNode* store) override { if (store->buffer.same_as(src_buffer_)) { PrimExpr value = StmtExprMutator::VisitExpr(store->value); return BufferStore(tgt_buffer_, value, {0}); @@ -287,7 +287,7 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, const Optionalwrites = {it_buffer_region.value()}; new_block->name_hint = new_block->name_hint + "_in_thread"; new_block->body = - BufferReplacer::Run(wb_buffer, it_buffer.value(), std::move(new_block->body)); + BufferMutator::Run(wb_buffer, it_buffer.value(), std::move(new_block->body)); new_block->init = NullOpt; ObjectPtr n = make_object(*realize); n->block = Block(new_block); diff --git a/src/tir/transforms/memhammer_coalesce.cc b/src/tir/transforms/memhammer_coalesce.cc new file mode 100644 index 000000000000..7925f4e090c4 --- /dev/null +++ b/src/tir/transforms/memhammer_coalesce.cc @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../../runtime/thread_storage_scope.h" +#include "./memhammer_rewrite_rule.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Fuse consecutive loops + * \param body the outer-most loop + * \return the fused loop + */ +Stmt FuseNestLoops(Stmt body) { + std::vector loops; + while (const ForNode* loop = body.as()) { + loops.push_back(loop); + body = loop->body; + } + std::string suffix; + int n = loops.size(); + for (int i = 1; i < n; i++) { + suffix += "_" + loops[i]->loop_var->name_hint; + } + suffix += "_fused"; + Var fused_var = loops[0]->loop_var.copy_with_suffix(suffix); + Map subst_map; + PrimExpr tot = fused_var; + for (int i = n - 1; i >= 0; i--) { + subst_map.Set(loops[i]->loop_var, floormod(tot, loops[i]->extent)); + tot = floordiv(tot, loops[i]->extent); + } + auto f_substitute = [&](const Var& v) -> Optional { + return subst_map.Get(v).value_or(v); + }; + PrimExpr fused_extent = 1; + for (int i = 0; i < n; i++) { + fused_extent *= loops[i]->extent; + } + return For(fused_var, 0, fused_extent, ForKind::kSerial, + Substitute(std::move(body), f_substitute)); +} + +/*! + * \brief a combination of split, bind, vectorize, + * a helper function to perform coalesced load/store + * \param stmt the stmt to do transformation + * \param constraints The constraints, including thread extents, vector bytes, and data bits. + * \return The stmt after transformation + */ +Stmt SplitBindVectorize(const Stmt& stmt, const ConstraintSet& constraints) { + const ForNode* loop = TVM_TYPE_AS(loop, stmt, ForNode); + int loop_extent = Downcast(loop->extent)->value; + int vector_bytes = constraints.vector_bytes; + int data_bits = constraints.data_bits; + int vector_len = std::max(1, vector_bytes * 8 / data_bits); + int tot_threads = 1; + // generate thread binding loops + std::vector factors{-1}; + std::vector thread_axis; + if (Optional o_t = constraints.thread_extent.Get("threadIdx.z")) { + int t = o_t.value()->value; + tot_threads *= t; + factors.push_back(t); + thread_axis.push_back("threadIdx.z"); + } + if (Optional o_t = constraints.thread_extent.Get("threadIdx.y")) { + int t = o_t.value()->value; + tot_threads *= t; + factors.push_back(t); + thread_axis.push_back("threadIdx.y"); + } + if (Optional o_t = constraints.thread_extent.Get("threadIdx.x")) { + int t = o_t.value()->value; + tot_threads *= t; + factors.push_back(t); + thread_axis.push_back("threadIdx.x"); + } + // generate vectorized loop + factors.push_back(vector_len); + // generate outer loop + ICHECK_EQ(loop_extent % (tot_threads * vector_len), 0); + factors[0] = loop_extent / (tot_threads * vector_len); + // create new loop vars + int n = factors.size(); + std::vector new_loop_vars; + new_loop_vars.reserve(n); + for (int i = 0; i < n; i++) { + new_loop_vars.push_back(loop->loop_var.copy_with_suffix("_" + std::to_string(i))); + } + // substitute fused loop var with new loop vars + PrimExpr substitute_value = 0; + for (int i = 0; i < n; i++) { + substitute_value *= factors[i]; + substitute_value += new_loop_vars[i]; + } + // Construct the new loop nest + Stmt body = Substitute(loop->body, [&](const Var& v) -> Optional { + if (v.same_as(loop->loop_var)) { + return substitute_value; + } else { + return NullOpt; + } + }); + body = For(new_loop_vars.back(), 0, vector_len, ForKind::kVectorized, std::move(body)); + for (int i = n - 2; i >= 1; i--) { + body = For(new_loop_vars[i], 0, factors[i], ForKind::kThreadBinding, std::move(body), + IterVar(Range(nullptr), Var(thread_axis[i - 1]), kThreadIndex, thread_axis[i - 1])); + } + return For(new_loop_vars[0], 0, factors[0], ForKind::kSerial, std::move(body)); +} + +Stmt CoalescedAccess::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const { + Stmt after_fuse = FuseNestLoops(stmt); + Stmt after_split = SplitBindVectorize(std::move(after_fuse), constraints); + return after_split; +} + +/*! + * \brief Get the index mapping of a specific stmt. + * The stmt is like: + * for i0: + * ... + * for in: + * A[f(i0, ..., in])] = B[i0, ..., in], + * where f is the index mapping we want to get. + * \param constraints The constraints, including the write region that is required to calculate + * the index mapping + * \return The mapping in the form of j0, ..., jm, where j0, ... jm = f(i0, ..., in) + */ +Array GetMapping(const Stmt& stmt, const ConstraintSet& constraints) { + Stmt body = stmt; + while (const ForNode* loop = body.as()) { + body = loop->body; + } + const BufferStoreNode* buf_store = TVM_TYPE_AS(buf_store, body, BufferStoreNode); + BufferRegion write_region = constraints.write_region; + const Array& write_index = buf_store->indices; + ICHECK(write_region->region.size() == write_index.size() && + write_region->buffer.same_as(buf_store->buffer)); + Array result; + arith::Analyzer analyzer; + for (int i = 0; i < static_cast(write_region->region.size()); i++) { + PrimExpr pattern = analyzer.Simplify(write_index[i] - write_region->region[i]->min); + if (!is_zero(pattern)) { + result.push_back(pattern); + } + } + return result; +} + +Stmt InverseMapping::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const { + Stmt body = stmt; + Map var_range; + Array loop_vars; + // Step 1. Get index mapping + Array mapping_pattern = GetMapping(stmt, constraints); + while (const ForNode* loop = body.as()) { + var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + loop_vars.push_back(loop->loop_var); + body = loop->body; + } + // Step 2. Get Inverse mapping + arith::Analyzer analyzer; + DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule())); + Array iter_map = + arith::DetectIterMap(mapping_pattern, var_range, Bool(true), true, &analyzer, diag_ctx); + CHECK_EQ(iter_map.size(), loop_vars.size()); + Map inverse_mapping = arith::InverseAffineIterMap(iter_map, loop_vars); + // Step 3. Generate new body + BufferRegion read_region = constraints.read_region; + BufferRegion write_region = constraints.write_region; + Array write_index; + Array read_index; + Array new_loop_vars; + Map substitute_map; + // Step 3.1 construct target buffer indices + for (int i = 0, j = 0; i < static_cast(write_region->region.size()); i++) { + if (is_one(write_region->region[i]->extent)) { + write_index.push_back(write_region->region[i]->min); + } else { + Var var = runtime::Downcast(loop_vars[j]).copy_with_suffix("_inverse"); + new_loop_vars.push_back(var); + substitute_map.Set(runtime::Downcast(loop_vars[j++]), var); + write_index.push_back(write_region->region[i]->min + var); + } + } + // Step 3.2 construct source buffer indices + for (int i = 0, j = 0; i < static_cast(read_region->region.size()); i++) { + if (is_one(read_region->region[i]->extent)) { + read_index.push_back(read_region->region[i]->min); + } else { + read_index.push_back( + read_region->region[i]->min + + Substitute(inverse_mapping[Downcast(loop_vars[j++])], substitute_map)); + } + } + BufferLoad new_buf_load = BufferLoad(read_region->buffer, read_index); + BufferStore new_buf_store = BufferStore(write_region->buffer, new_buf_load, write_index); + Stmt ret = new_buf_store; + // Step 3.3 construct loop body + for (int i = static_cast(new_loop_vars.size()) - 1; i >= 0; i--) { + PrimExpr extent = write_region->region[i]->extent; + ret = For(new_loop_vars[i], 0, extent, ForKind::kSerial, std::move(ret)); + } + return ret; +} +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/memhammer_intermediate_stage.cc b/src/tir/transforms/memhammer_intermediate_stage.cc new file mode 100644 index 000000000000..4ffffc9fdeab --- /dev/null +++ b/src/tir/transforms/memhammer_intermediate_stage.cc @@ -0,0 +1,428 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "memhammer_rewrite_rule.h" + +namespace tvm { +namespace tir { + +Stmt CopyLoopChain(const std::vector loops, const Stmt& inner_body, int ith = -1, + Stmt* ith_loop = nullptr) { + Stmt ret = inner_body; + for (int i = static_cast(loops.size() - 1); i >= 0; i--) { + ObjectPtr new_loop = make_object(*loops[i]); + new_loop->body = ret; + ret = For(new_loop); + if (ith == i) { + *ith_loop = ret; + } + } + return ret; +} + +/*! + * \brief lift all the thread binding loops + * \param stmt the top loop + * \return a pair. The first is the transformed stmt. + * The second is the lowest thread binding loop. + */ +std::pair> LiftThreadBindingLoops(Stmt stmt) { + std::vector normal_loops; + std::vector thread_binding_loops; + Stmt body = stmt; + while (const ForNode* loop = body.as()) { + if (loop->kind == ForKind::kThreadBinding) { + thread_binding_loops.push_back(loop); + } else { + normal_loops.push_back(loop); + } + body = loop->body; + } + body = CopyLoopChain(normal_loops, std::move(body)); + For compute_location{nullptr}; + body = CopyLoopChain(thread_binding_loops, std::move(body), + static_cast(thread_binding_loops.size()) - 1, &compute_location); + return std::make_pair(body, compute_location); +} + +/*! + * \brief Analyze the access pattern for buffer rank promotion. + * Rank promotion is a transformation that reshapes the buffer + * but doesn't change its underlying data layout. + * After the reshape, we expect that all dimensions of the access indices + * will be in the form of floormod(floordiv(x, a), b). + * Rank promotion removes strided access, thus enabling further buffer compacting + */ +class IndexPatternFinder : public ExprVisitor { + public: + IndexPatternFinder(const Map& var_range, Array* resulting_index) + : var_range_(var_range), resulting_index_(resulting_index) {} + + /*! + * \brief Calculate the new buffer shape after rank promotion. + * For each dimension of original shape, it will be split into multiple parts. + * The inner array represents the multiple parts of one original dimension, + * and the outer array represents the original dimensions + * For example, original shape [4, 8] may be split into [[2, 2], [2, 4]] + * \param indices The access indices of the buffer + * \param var_range The iter range of the vars in the indices + * \param rewrite_indices The access indices after rank promotion + * \return The new buffer shape after rank promotion. + */ + static Array> GetRankPromotedShape(Array indices, + const Map& var_range, + Array* rewrite_indices) { + Map var_dom = AsIntSet(var_range); + Array> new_shape; + for (const PrimExpr& expr : indices) { + IndexPatternFinder extractor(var_range, rewrite_indices); + arith::IntSet intset = arith::EvalSet(expr, var_dom); + extractor.mod_ = intset.max() + 1; + extractor.div_ = 1; + extractor.offset_ = 0; + extractor(expr); + Array access_shape = extractor.access_shape_; + for (int i = static_cast(access_shape.size()) - 1; i >= 1; i--) { + if (!is_zero(floormod(extractor.offset_, access_shape[i]))) { + return {}; + } else { + extractor.offset_ = floordiv(extractor.offset_, access_shape[i]); + } + } + access_shape.Set(0, extractor.offset_ + access_shape[0]); + new_shape.push_back(access_shape); + } + return new_shape; + } + + private: + void VisitExpr_(const VarNode* op) final { + arith::Analyzer analyzer; + PrimExpr extent = var_range_[GetRef(op)]->extent; + PrimExpr access_iter_range = min(mod_, (max(1, floordiv(extent, div_)))); + if (!analyzer.CanProveEqual(1, access_iter_range)) { + access_shape_.push_back(access_iter_range); + resulting_index_->push_back(floormod(floordiv(GetRef(op), div_), mod_)); + } + } + + void VisitExpr_(const FloorDivNode* op) final { + PrimExpr old_div = div_; + div_ *= op->b; + ExprVisitor::VisitExpr_(op); + div_ = old_div; + } + + void VisitExpr_(const FloorModNode* op) final { + PrimExpr old_mod = mod_; + mod_ = max(1, min(floordiv(op->b, div_), mod_)); + ExprVisitor::VisitExpr_(op); + mod_ = old_mod; + } + + void VisitExpr_(const MulNode* op) final { + PrimExpr old_mod = mod_; + PrimExpr old_div = div_; + div_ = max(1, floordiv(div_, op->b)); + mod_ = max(1, floordiv(mod_, floordiv(op->b, floordiv(old_div, div_)))); + ExprVisitor::VisitExpr_(op); + mod_ = old_mod; + div_ = old_div; + } + + void VisitExpr_(const AddNode* op) final { + if (is_const_int(op->b)) { + offset_ += floormod(floordiv(op->b, div_), mod_); + } + ExprVisitor::VisitExpr_(op); + } + + PrimExpr div_; + PrimExpr mod_; + PrimExpr offset_; + Map var_range_; + Array access_shape_; + Array* resulting_index_; +}; + +/*! + * \brief Utilities to perform rank promotion + */ +class RankPromoter : public StmtExprMutator { + public: + /*! + * \brief Flatten the buffer shape like performing inverse rank promotion. + * For example, [[i0, i1], [j0, j1]] to [i0 * i1, j0 * j1] + * \param new_shape The buffer shape in the special form as returned by GetRankPromotedShape + * \return The buffer shape after flatten + */ + static Array FlattenNewShape(const Array>& new_shape) { + Array ret; + ret.reserve(new_shape.size()); + for (int i = 0; i < static_cast(new_shape.size()); i++) { + PrimExpr prod = 1; + for (int j = 0; j < static_cast(new_shape[i].size()); j++) { + prod *= new_shape[i][j]; + } + ret.push_back(prod); + } + return ret; + } + /** + * \brief Rewrite the index given the shape after rank promotion + * \param indices The original indices + * \param new_shape The buffer shape after rank promotion + * \return The new indices + */ + static Array RewriteIndex(const Array& indices, + const Array>& new_shape) { + Array new_indices; + ICHECK_EQ(indices.size(), new_shape.size()); + for (int i = 0; i < static_cast(indices.size()); i++) { + PrimExpr index = indices[i]; + // The indices transformed from one original dimension + Array index_dim(new_shape[i].size(), 0); + for (int j = static_cast(new_shape[i].size()) - 1; j >= 0; j--) { + index_dim.Set(j, floormod(index, new_shape[i][j])); + index = floordiv(index, new_shape[i][j]); + } + for (int j = 0; j < static_cast(new_shape[i].size()); j++) { + new_indices.push_back(index_dim[j]); + } + } + return new_indices; + } + /*! + * \brief Rewrite the index after buffer flattening + * \param indices The original indices + * \param new_shape The shape before buffer flattening + * \return The indices after buffer flattening + */ + static Array RewriteBackIndex(const Array& indices, + const Array>& new_shape) { + Array new_indices; + int offset = 0; + for (int i = 0; i < static_cast(new_shape.size()); i++) { + PrimExpr index = 0; + for (int j = 0; j < static_cast(new_shape[i].size()); j++) { + index *= new_shape[i][j]; + index += indices[offset + j]; + } + new_indices.push_back(index); + offset += new_shape[i].size(); + } + return new_indices; + } + RankPromoter(const Buffer& src, const Buffer& dst, const Array>& new_shape, + const Array>& relaxed_new_shape, const Array& relaxed_region) + : src_(src), + dst_(dst), + new_shape_(new_shape), + relaxed_new_shape_(relaxed_new_shape), + relaxed_region_(relaxed_region) {} + + static Stmt RewriteBody(Stmt stmt, const Buffer& src, const Buffer& dst, + const Array>& new_shape, + const Array>& relaxed_new_shape, + const Array& relaxed_region) { + RankPromoter promoter(src, dst, new_shape, relaxed_new_shape, relaxed_region); + return promoter(stmt); + } + + private: + Stmt VisitStmt_(const BufferStoreNode* _store) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(_store)); + if (store->buffer.same_as(src_)) { + ObjectPtr new_store = make_object(*store.get()); + new_store->buffer = dst_; + new_store->indices = ConvertIndices(new_store->indices); + return BufferStore(new_store); + } + return std::move(store); + } + + PrimExpr VisitExpr_(const BufferLoadNode* _load) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(_load)); + if (load->buffer.same_as(src_)) { + ObjectPtr new_load = make_object(*load.get()); + new_load->buffer = dst_; + new_load->indices = ConvertIndices(new_load->indices); + return BufferLoad(new_load); + } + return std::move(load); + } + + /*! + * \brief Rewrite the indices after performing buffer rank promotion + + * buffer compacting + buffer flattening. + * \param indices The original indices + * \return The indices after these transformations + */ + Array ConvertIndices(const Array& indices) { + Array rewrite_indices = RewriteIndex(indices, new_shape_); + arith::Analyzer analyzer; + for (int i = 0; i < static_cast(rewrite_indices.size()); i++) { + rewrite_indices.Set(i, analyzer.Simplify(rewrite_indices[i] - relaxed_region_[i]->min)); + } + return RewriteBackIndex(rewrite_indices, relaxed_new_shape_); + } + + const Buffer& src_; + const Buffer& dst_; + Array> new_shape_; + Array> relaxed_new_shape_; + Array relaxed_region_; +}; + +std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String storage_scope, + Optional compute_location, + const Array& outer_loops, Buffer* alloc_buffer) { + Stmt body = stmt; + std::vector loops; + bool need_relax = !compute_location.defined(); + Map relax_var_range; + Map all_var_range; + PrimExpr vector_bytes = -1; + // Step 1. Perform rank promotion on the buffer access, turning a strided-changing dimension into + // several contiguous-changing dimensions + // Step 1.1 collect loop var range for rank promotion + while (const ForNode* loop = body.as()) { + if (need_relax) { + relax_var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + } else { + loops.push_back(loop); + } + all_var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + if (loop == compute_location.get()) { + need_relax = true; + } + if (loop->kind == ForKind::kVectorized) { + vector_bytes = loop->extent; + } + body = loop->body; + } + for (const For& loop : outer_loops) { + if (loop->kind == ForKind::kThreadBinding) { + const String& thread_tag = loop->thread_binding.value()->thread_tag; + if (CanRelaxStorageUnderThread(runtime::StorageScope::Create(storage_scope), + runtime::ThreadScope::Create(thread_tag))) { + relax_var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + } + } + all_var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + } + + const BufferStoreNode* buf_store = TVM_TYPE_AS(buf_store, body, BufferStoreNode); + // TODO: the assumption that the RHS of BufferStore is BufferLoad may not be accurate + const BufferLoadNode* buf_load = TVM_TYPE_AS(buf_load, buf_store->value, BufferLoadNode); + Buffer orig_buffer = is_write_cache ? buf_store->buffer : buf_load->buffer; + Array indices = is_write_cache ? buf_store->indices : buf_load->indices; + // Step 1.2 get the new shape and new access indices after rank promotion + Array rewrite_indices; + Array> new_shape = + IndexPatternFinder::GetRankPromotedShape(indices, all_var_range, &rewrite_indices); + // Step 2. relax the access region after rank promotion + arith::Analyzer analyzer; + analyzer.Bind(all_var_range); + Array relaxed_region; + relaxed_region.reserve(rewrite_indices.size()); + { + Map relax_var_intset = AsIntSet(relax_var_range); + for (const PrimExpr& index : rewrite_indices) { + arith::IntSet int_set = arith::EvalSet(index, relax_var_intset); + relaxed_region.push_back(Range::FromMinExtent( + int_set.min(), analyzer.Simplify(int_set.max() - int_set.min() + 1))); + } + } + // Step 3. generate the data copy bodies + // preparation work + Array new_loop_vars; + Array orig_buf_indices, new_buf_indices; + Array> relaxed_new_shape; + for (int i = 0; i < static_cast(relaxed_region.size()); i++) { + Var new_loop_var = Var("ax" + std::to_string(i)); + new_loop_vars.push_back(new_loop_var); + orig_buf_indices.push_back(relaxed_region[i]->min + new_loop_var); + new_buf_indices.push_back(new_loop_var); + } + relaxed_new_shape.reserve(new_shape.size()); + for (int i = 0, ct = 0; i < static_cast(new_shape.size()); i++) { + Array layer; + for (int j = 0; j < static_cast(new_shape[i].size()); j++, ct++) { + layer.push_back(relaxed_region[ct]->extent); + } + relaxed_new_shape.push_back(layer); + } + // Step 3.1 create a buffer for the cache + Buffer new_buffer = WithScope(orig_buffer, storage_scope); + new_buffer.CopyOnWrite()->shape = RankPromoter::FlattenNewShape(relaxed_new_shape); + *alloc_buffer = new_buffer; + Array real_orig_buf_indices = + RankPromoter::RewriteBackIndex(orig_buf_indices, new_shape); + Array real_new_buf_indices = + RankPromoter::RewriteBackIndex(new_buf_indices, relaxed_new_shape); + // Step 3.2 generate a body that writes to the cache + Stmt generate_body = is_write_cache + ? BufferStore(orig_buffer, BufferLoad(new_buffer, real_new_buf_indices), + real_orig_buf_indices) + : BufferStore(new_buffer, BufferLoad(orig_buffer, real_orig_buf_indices), + real_new_buf_indices); + for (int i = static_cast(relaxed_region.size()) - 1; i >= 0; i--) { + if (i == static_cast(relaxed_region.size()) - 1 && !is_const_int(vector_bytes, -1)) { + ICHECK(analyzer.CanProve(vector_bytes == relaxed_region[i]->extent)); + generate_body = + For(new_loop_vars[i], 0, relaxed_region[i]->extent, ForKind::kVectorized, generate_body); + } else { + generate_body = + For(new_loop_vars[i], 0, relaxed_region[i]->extent, ForKind::kSerial, generate_body); + } + } + // Step 3.3 rewrite the original body to load from cache + Stmt rewrite_body; + if (compute_location.defined()) { + rewrite_body = compute_location.value()->body; + } else { + rewrite_body = stmt; + } + rewrite_body = RankPromoter::RewriteBody(rewrite_body, orig_buffer, new_buffer, new_shape, + relaxed_new_shape, relaxed_region); + SeqStmt insert_location; + if (is_write_cache) { + generate_body = insert_location = SeqStmt({rewrite_body, generate_body}); + } else { + generate_body = insert_location = SeqStmt({generate_body, rewrite_body}); + } + generate_body = CopyLoopChain(loops, generate_body); + return std::make_pair(generate_body, insert_location); +} + +Stmt CreateLocalStage::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const { + Stmt body; + Optional compute_location; + std::tie(body, compute_location) = LiftThreadBindingLoops(std::move(stmt)); + Buffer cache_buffer; + Stmt after_caching = InsertCacheStage(body, false, "local", compute_location, + constraints.outer_loops, &cache_buffer) + .first; + output->alloc_buffer.push_back(cache_buffer); + return after_caching; +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/memhammer_lower_auto_copy.cc b/src/tir/transforms/memhammer_lower_auto_copy.cc new file mode 100644 index 000000000000..a0103aab380b --- /dev/null +++ b/src/tir/transforms/memhammer_lower_auto_copy.cc @@ -0,0 +1,763 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include "../../runtime/thread_storage_scope.h" +#include "../schedule/utils.h" +#include "./ir_utils.h" +#include "./memhammer_rewrite_rule.h" + +namespace tvm { +namespace tir { + +using support::NDIntSet; + +// rewrite rules +static InverseMapping inverse_mapping; +static CoalescedAccess coalesced_access; +static CreateLocalStage create_local_stage; +static SharedToWmma shared_to_wmma; +static WmmaToGlobal wmma_to_global; +static WmmaToShared wmma_to_shared; + +/*! + * \brief A class to perform auto padding. + * + * One simple way to perform auto padding is to fix each padding size for each dimension at the + * same time, calculate the precise access index and the bank conflict, + * and choose the one with minimal conflict. However, this algorithm has exponential complexity. + * Suppose we have d dimensions and the padding size is 0-31, we need to calculate bank + * conflict for 32^{d-1} times. + * We propose a fast incremental algorithm that works for affine inputs, and it only calculate + * bank conflict for 32*{d-1} times. To be specific, we first decide the optimal padding size for + * dimension d-2, then for dimension d-3, ..., finally for dimension 0. It involves 2 steps. + * + * First, we analyze how a typical warp accesses the shared memory banks. + * A typical warp means setting all irrelevant loop vars to 0, and only keeps the threads in a warp. + * For each dimension, the access index is represented by + * x_1 * scale_1 + ... + x_n * scale_n (x_i is loop var) + * Note: The affine property guarantees that {x_i} must be independent, + * otherwise the algorithm is wrong. + * We will use this information to keep a list for each dimension called "iteration space" that + * records the resulting index as x_i takes each possible value. + * + * For example, the index is [outer*2+ty, tx*4+vec], where tx is threadIdx.x, and ty is threadIdx.y. + * tx is in [0, 16), and ty is in [0, 2). + * We will first get a warp access [ty, tx*4] because outer and vec are irrelevant loop vars. + * It's obvious that ty, tx*4 are both in the form of x_1 * scale_1 + ... + x_n * scale_n. + * In this case, we will keep lists {{0, 1}, {0, 4, ..., 60}} + * + * Next, we choose a padding size that has minimal conflict from the last dimension to first one. + * To calculate the conflict, we calculate the Cartesian product of the iteration space of all + * dimensions not higher than this. Each single point of product space represents access index + * of a particular thread, by which we can calculate the accessed memory bank. The conflict is + * the highest access frequency among the banks. + * + */ +class AutoPadder { + public: + /** + * \brief Do padding to the given buffers in shard memory + * \param buffers the given buffers + * \return the list of new padded buffers + */ + Array PadSharedMemory(const Array& buffers) { + Array result; + + for (const Buffer& buffer : buffers) { + runtime::StorageScope scope = runtime::StorageScope::Create(buffer.scope()); + if (scope.rank == runtime::StorageRank::kShared) { + auto iter_spaces = iter_spaces_[buffer.get()]; + if (iter_spaces.empty()) { + result.push_back(buffer); + continue; + } + // The access index represented by points in the cartesian product of lower dimension + // iteration spaces + std::vector> low_dim_iter_space(iter_spaces.size(), std::vector()); + + int n = buffer->shape.size(); + int data_bits = buffer->dtype.bits(); + // Step 1. initialize `low_dim_iter_space` with the iteration space of the last dim + for (int i = 0; i < static_cast(iter_spaces.size()); i++) { + auto last_dim_iter_space = iter_spaces[i][n - 1]; + low_dim_iter_space[i] = last_dim_iter_space; + } + PrimExpr stride = 1; + Array reverse_strides; + int pad_min = padding_min_.Get(buffer).value_or(Integer(1)); + // Step 2. For each dimension, select a padding that has minimal bank conflict + for (int k = n - 2; k >= 0; k--) { // dims + int max_pad_size = std::min( + int(max_pad_factor_ * (stride * buffer->shape[k + 1]).as()->value), + 32 * 32 / data_bits); + int min_conflict = INT32_MAX; + int min_conflict_pad = -1; + for (int pad = 0; pad <= max_pad_size; pad += pad_min) { // select padding + int padded_stride = ((stride * buffer->shape[k + 1]).as()->value + pad) % + (32 * 32 / data_bits); + int conflict = 0; + for (int i = 0; i < static_cast(iter_spaces.size()); i++) { // accesses + auto iter_space = iter_spaces[i][k]; + int bank[32]{0}; + for (int v1 : iter_space) { + for (int v2 : low_dim_iter_space[i]) { + int comb = (v1 * padded_stride + v2) * data_bits / 32 % 32; + bank[comb]++; + } + } + for (int j = 0; j < 32; j++) { + conflict = std::max(conflict, bank[j]); + } + } + if (conflict < min_conflict) { + min_conflict = conflict; + min_conflict_pad = pad; + } + } + // update low_dim_iter_space with + for (int i = 0; i < static_cast(iter_spaces.size()); i++) { // accesses + auto iter_space = iter_spaces[i][k]; + if (!iter_space.empty()) { + int padded_stride = + ((stride * buffer->shape[k + 1]).as()->value + min_conflict_pad) % + (32 * 32 / data_bits); + std::vector span; + for (int v1 : iter_space) { + for (int v2 : low_dim_iter_space[i]) { + span.push_back(((v1 * padded_stride + v2) * data_bits) % (32 * 32 / data_bits)); + } + } + low_dim_iter_space[i] = span; + } else { + ICHECK(min_conflict_pad == 0); + } + } + stride = stride * buffer->shape[k + 1] + min_conflict_pad; + reverse_strides.push_back(stride); + } + // Step 3. create the new padded buffer + ObjectPtr b = make_object(*buffer.get()); + Array strides; + for (int i = static_cast(reverse_strides.size()) - 1; i >= 0; i--) { + strides.push_back(reverse_strides[i]); + } + strides.push_back(1); + b->strides = strides; + Buffer new_buffer(b); + result.push_back(new_buffer); + padded_buffer_map_.Set(buffer, new_buffer); + } else { + result.push_back(buffer); + } + } + return result; + } + + /** + * \brief Replace all occurrence of the old buffer with the new buffer in the stmt + * \param stmt the stmt to do replacement + * \return the stmt after replacement + */ + Stmt RewriteBufferAccess(const Stmt& stmt) { + class Rewriter : public StmtExprMutator { + public: + Rewriter(const Map& buffer_map) : buffer_map_(buffer_map) {} + + private: + PrimExpr VisitExpr_(const BufferLoadNode* _op) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(_op)); + BufferLoadNode* op = load.CopyOnWrite(); + if (buffer_map_.count(op->buffer)) { + op->buffer = buffer_map_[op->buffer]; + } + return std::move(load); + } + + Stmt VisitStmt_(const BufferStoreNode* _op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(_op)); + BufferStoreNode* op = store.CopyOnWrite(); + if (buffer_map_.count(op->buffer)) { + op->buffer = buffer_map_[op->buffer]; + } + return std::move(store); + } + + Stmt VisitStmt_(const BlockNode* op) final { + // To reduce the number of blocks in block sref reuse map, we check whether the block is + // really mutated (i.e., the old buffer appears in the block). If so, we return the block + // after mutation. Otherwise we just return the original block. + bool changed = false; + // Step 1. Mutate the read region. + Array reads; + for (const BufferRegion& read : op->reads) { + if (buffer_map_.count(read->buffer)) { + changed = true; + reads.push_back(BufferRegion(buffer_map_[read->buffer], read->region)); + } else { + reads.push_back(read); + } + } + // Step 2. Mutate the write region. + Array writes; + for (const BufferRegion& write : op->writes) { + if (buffer_map_.count(write->buffer)) { + changed = true; + writes.push_back(BufferRegion(buffer_map_[write->buffer], write->region)); + } else { + writes.push_back(write); + } + } + // Step 4. Mutate `match_buffers`. If an old buffer appears as a source of + // MatchBufferRegion, the storage scope of the target buffer also needs to be set. + Array match_buffers; + for (const MatchBufferRegion& match_buffer : op->match_buffers) { + if (buffer_map_.count(match_buffer->source->buffer)) { + changed = true; + Buffer new_buffer = buffer_map_[match_buffer->source->buffer]; + match_buffers.push_back(MatchBufferRegion( + match_buffer->buffer, BufferRegion(new_buffer, match_buffer->source->region))); + } else { + match_buffers.push_back(match_buffer); + } + } + // Step 5. Recursively mutate the block. + Stmt res = StmtMutator::VisitStmt_(op); + if (res.get() != op) { + changed = true; + } + + if (changed) { + ObjectPtr block = CopyOnWrite(res.as()); + block->reads = std::move(reads); + block->writes = std::move(writes); + block->match_buffers = std::move(match_buffers); + return Stmt(block); + } else { + return GetRef(op); + } + } + const Map& buffer_map_; + }; + Rewriter rewriter(padded_buffer_map_); + return rewriter(stmt); + } + + /** + * \brief an equivalent of scale * loop_var with loop_var: {min=0, extent=extent} + */ + struct Pattern { + int extent; + int scale; + }; + + /** + * \brief Collect pattern from indices + */ + class PatternCollector : public StmtExprVisitor { + void VisitExpr_(const VarNode* op) final { + if (!success_) { + return; + } + int extent = var_range_[GetRef(op)]->extent.as()->value; + if (extent > 1) { + stack_.push({{extent, 1}}); + } else { + stack_.push({}); + } + } + + void VisitExpr_(const AddNode* op) final { + ExprVisitor::VisitExpr_(op); + if (!success_) { + return; + } + std::vector merged_patterns; + std::vector r = stack_.top(); + stack_.pop(); + std::vector l = stack_.top(); + stack_.pop(); + for (const Pattern& pattern : l) { + merged_patterns.push_back(pattern); + } + for (const Pattern& pattern : r) { + merged_patterns.push_back(pattern); + } + if (merged_patterns.empty()) { + stack_.push({}); + return; + } + std::vector ret; + ret.push_back(merged_patterns[0]); + for (int i = 0; i < static_cast(merged_patterns.size()); i++) { + Pattern prev_pattern = ret.back(); + if (merged_patterns[i].extent * merged_patterns[i].scale == prev_pattern.scale) { + ret.pop_back(); + ret.push_back( + {prev_pattern.extent * merged_patterns[i].extent, merged_patterns[i].scale}); + } + } + stack_.push(ret); + } + + void VisitExpr_(const FloorDivNode* op) final { + ExprVisitor::VisitExpr_(op); + if (!success_) { + return; + } + std::vector inner = stack_.top(); + stack_.pop(); + int lower_factor = op->b.as()->value; + std::vector ret; + for (const Pattern& pattern : inner) { + if (pattern.scale >= lower_factor) { + if (pattern.scale % lower_factor == 0) { + ret.push_back({pattern.extent, pattern.scale / lower_factor}); + } else { + success_ = false; + } + } else if (pattern.scale * pattern.extent > lower_factor) { + if ((pattern.scale * pattern.extent) % lower_factor == 0) { + ret.push_back({pattern.extent * pattern.scale / lower_factor, 1}); + } else { + success_ = false; + } + } + } + stack_.push(ret); + } + + void VisitExpr_(const FloorModNode* op) final { + ExprVisitor::VisitExpr_(op); + if (!success_) { + return; + } + std::vector inner = stack_.top(); + stack_.pop(); + int extent = op->b.as()->value; + std::vector ret; + for (const Pattern& pattern : inner) { + if (pattern.scale < extent) { + if (extent % pattern.scale == 0) { + if (extent / pattern.scale < pattern.extent) { + ret.push_back({extent / pattern.scale, pattern.scale}); + } else { + ret.push_back({pattern.extent, pattern.scale}); + } + } else { + success_ = false; + } + } + } + stack_.push(ret); + } + + void VisitExpr_(const MulNode* op) final { + ExprVisitor::VisitExpr_(op); + if (!success_) { + return; + } + std::vector inner = stack_.top(); + stack_.pop(); + int scale = op->b.as()->value; + std::vector ret; + for (const Pattern& pattern : inner) { + ret.push_back({pattern.extent, pattern.scale * scale}); + } + stack_.push(ret); + } + + public: + PatternCollector(const Map& var_range) : var_range_(var_range) {} + + /*! + * \brief Collect the iteration space for given indices. The iteration space is the possible + * values that an index can take (do not remove duplicate). + * For example, the input is [ty, tx*4], where tx is in [0, 16), and ty is in [0, 2). + * The output would be {{0, 1}, {0, 4, ..., 60}} + * \param indices The indices to analyze + * \param var_range The range of loop variables + * \param data_bits The size of dtype in bits + * \return The iteration space. The first array represents dimensions, and the second array + * represents the iteration space of one dimension + */ + static std::vector> CollectIterationSpace(const Array& indices, + const Map& var_range, + int data_bits) { + PatternCollector collector(var_range); + std::vector> ret; + for (int i = 0; i < static_cast(indices.size()); i++) { + collector(indices[i]); + if (collector.success_ && collector.stack_.size() == 1) { + auto patterns = collector.stack_.top(); + int extent_prod = 1; + for (const Pattern& p : patterns) { + extent_prod *= p.extent; + } + std::vector iter_space; + for (int thread_id = 0; thread_id < extent_prod; thread_id++) { + int index = 0; + int n = thread_id; + for (int j = static_cast(patterns.size()) - 1; j >= 0; j--) { + int val = n % patterns[j].extent; + index += val * patterns[j].scale; + n /= patterns[j].extent; + } + iter_space.push_back(index); + } + + ret.push_back(iter_space); + collector.stack_.pop(); + } else { + ret.push_back({}); + } + } + return ret; + } + + std::stack> stack_; + const Map& var_range_; + bool success_ = true; + }; + + /*! A utility class for calling CollectIterationSpace to each buffer access*/ + class IterSpaceAnalyzer : public StmtExprVisitor { + public: + IterSpaceAnalyzer(const Map& substitute_map, AutoPadder* self, int data_bits, + const Map warp_thread_extent) + : substitute_map_(substitute_map), + self(self), + data_bits_(data_bits), + warp_thread_extent_(warp_thread_extent) {} + + private: + bool CheckVarContiguous(PrimExpr e, Var var) { + PrimExpr e1 = Substitute(e, [var](const Var& v) -> Optional { + if (v.same_as(var)) { + return Integer(0); + } else { + return v; + } + }); + PrimExpr e2 = Substitute(e, [var](const Var& v) -> Optional { + if (v.same_as(var)) { + return Integer(1); + } else { + return v; + } + }); + arith::Analyzer analyzer; + return analyzer.CanProve(e2 - e1 == 1); + } + + void VisitStmt_(const ForNode* op) final { + if (op->kind != ForKind::kThreadBinding) { + substitute_map_.Set(op->loop_var, op->min); + } else { + Integer extent = + warp_thread_extent_.Get(op->thread_binding.value()->thread_tag).value_or(1); + var_range_.Set(op->loop_var, Range::FromMinExtent(op->min, extent)); + } + if (op->kind == ForKind::kVectorized) { + vector_var = op->loop_var; + vector_length_ = op->extent.as()->value; + } + StmtExprVisitor::VisitStmt_(op); + if (op->kind == ForKind::kVectorized) { + vector_length_ = -1; + } + if (op->kind != ForKind::kThreadBinding) { + substitute_map_.erase(op->loop_var); + } + } + /*! + * \brief Take a typical warp and collect the iteration space for buffer store + * For example, the access is A[outer*2+ty, tx*4+vec] = xxx, where tx is threadIdx.x, and ty is + * threadIdx.y. tx is in [0, 16), and ty is in [0, 2). + * The iteration space would be {{0, 1}, {0, 4, ..., 60}}. + * \param op the buffer store + */ + void VisitStmt_(const BufferStoreNode* op) final { + runtime::StorageScope scope = runtime::StorageScope::Create(op->buffer.scope()); + if (scope.rank == runtime::StorageRank::kShared) { + Array substitued_indices; + arith::Analyzer analyzer; + for (const PrimExpr& e : op->indices) { + substitued_indices.push_back(analyzer.Simplify(Substitute(e, substitute_map_))); + } + std::vector> iter_space = + PatternCollector::CollectIterationSpace(substitued_indices, var_range_, data_bits_); + if (!iter_space.empty()) { + self->iter_spaces_[op->buffer.get()].push_back(iter_space); + } + if (vector_length_ != -1 && CheckVarContiguous(substitued_indices.back(), vector_var)) { + Integer m = self->padding_min_.Get(op->buffer).value_or(1); + self->padding_min_.Set(op->buffer, Downcast(max(vector_length_, m))); + } + } + StmtExprVisitor::VisitStmt_(op); + } + /*! + * \brief Take a typical warp and collect the iteration space for buffer load + * For example, the access is xxx = A[outer*2+ty, tx*4+vec], where tx is threadIdx.x, and ty is + * threadIdx.y. tx is in [0, 16), and ty is in [0, 2). + * The iteration space would be {{0, 1}, {0, 4, ..., 60}}. + * \param op the buffer load + */ + void VisitExpr_(const BufferLoadNode* op) final { + runtime::StorageScope scope = runtime::StorageScope::Create(op->buffer.scope()); + if (scope.rank == runtime::StorageRank::kShared) { + Array substitued_indices; + arith::Analyzer analyzer; + for (const PrimExpr& e : op->indices) { + substitued_indices.push_back(analyzer.Simplify(Substitute(e, substitute_map_))); + } + std::vector> iter_space = + PatternCollector::CollectIterationSpace(substitued_indices, var_range_, data_bits_); + if (!iter_space.empty()) { + self->iter_spaces_[op->buffer.get()].push_back(iter_space); + } + if (vector_length_ != -1 && CheckVarContiguous(substitued_indices.back(), vector_var)) { + Integer m = self->padding_min_.Get(op->buffer).value_or(1); + self->padding_min_.Set(op->buffer, Downcast(max(vector_length_, m))); + } + } + StmtExprVisitor::VisitExpr_(op); + } + + /*! + * \brief Take a typical warp and collect the iteration space for load_matrix_sync and + * store_matrix_sync + * For example, the access region is A[y*16+16, x*16+16], where y and x are not bound to + * threadIdx. The iteration space would be {{0, 1, ..., 15}, {0, 1, ..., 15}}. + * \param op the call node + */ + void VisitStmt_(const BlockNode* op) final { + if (const auto* eval = op->body.as()) { + if (const auto* call = eval->value.as()) { + if (call->op == builtin::tvm_load_matrix_sync() || + call->op == builtin::tvm_store_matrix_sync()) { + for (const MatchBufferRegion& r : op->match_buffers) { + Buffer src_buffer = r->source->buffer; + runtime::StorageScope scope = runtime::StorageScope::Create(src_buffer.scope()); + if (scope.rank == runtime::StorageRank::kShared) { + Region region = r->source->region; + Array indices; + for (int i = 0; i < static_cast(region.size()); i++) { + Var var("region" + std::to_string(i)); + indices.push_back(region[i]->min + var); + var_range_.Set(var, Range::FromMinExtent(0, region[i]->extent)); + } + Array substitued_indices; + arith::Analyzer analyzer; + for (const PrimExpr& e : indices) { + substitued_indices.push_back(analyzer.Simplify(Substitute(e, substitute_map_))); + } + std::vector> iter_space = PatternCollector::CollectIterationSpace( + substitued_indices, var_range_, data_bits_); + if (!iter_space.empty()) { + self->iter_spaces_[src_buffer.get()].push_back(iter_space); + } + } + } + } + } + } + } + + Map substitute_map_; + AutoPadder* self; + int data_bits_; + Map warp_thread_extent_; + Map var_range_; + int vector_length_ = -1; + Var vector_var; + }; + + /*! + * \brief Analyze the shared memory access + * \param stmt The data copy + * \param outer_loops The outer loops of the stmt + * \param data_bits The length of dtype in bits + * \param thread_extent The extents of all thread binding loops + */ + void AnalyzeSharedMemoryAccess(const Stmt& stmt, const Array& outer_loops, int data_bits, + const Map& thread_extent) { + Map warp_thread_extent; + Integer prod = 1; + Array thread_tags{"threadIdx.x", "threadIdx.y", "threadIdx.z"}; + arith::Analyzer analyzer; + for (int i = 0; i < 3; i++) { + Integer extent = thread_extent.Get(thread_tags[i]).value_or(1); + if (analyzer.CanProve(prod * extent >= 32)) { + warp_thread_extent.Set(thread_tags[i], Downcast(floordiv(32, prod))); + prod *= floordiv(32, prod); + break; + } else { + warp_thread_extent.Set(thread_tags[i], Downcast(extent)); + prod *= extent; + } + } + Map substitute_map; + for (const For& loop : outer_loops) { + substitute_map.Set(loop->loop_var, loop->min); + } + IterSpaceAnalyzer iter_space_analyzer(substitute_map, this, data_bits, warp_thread_extent); + iter_space_analyzer(stmt); + } + + private: + /*! \brief A map from the old buffers to the new padded buffers */ + Map padded_buffer_map_; + /*! \brief A map from each buffer to the iteration spaces of the accesses*/ + std::unordered_map>>> iter_spaces_; + /*! \brief A map from each buffer to their minimal padding size */ + Map padding_min_; + /*! \brief max padding size in relative to the original shape*/ + const double max_pad_factor_ = 0.25; + + friend class AutoCopyMutator; +}; + +class AutoCopyMutator : public StmtExprMutator { + public: + explicit AutoCopyMutator(Map thread_extent) : thread_extent_(thread_extent) {} + /** + * \brief Replace old buffers with padded buffers in the stmt + * \param stmt The stmt to rewrite + * \return The stmt after rewrite + */ + Stmt RewritePaddingBody(const Stmt& stmt) { return padder.RewriteBufferAccess(stmt); } + + private: + Stmt VisitStmt_(const BlockNode* op) final { + Block block = Downcast(StmtMutator::VisitStmt_(op)); + // only rewrite the block annotated with "auto_copy" + if (GetAnn(op, "auto_copy").value_or(0)->value == 0) { + BlockNode* n = block.CopyOnWrite(); + n->alloc_buffers = padder.PadSharedMemory(std::move(n->alloc_buffers)); + return std::move(block); + } + ICHECK_EQ(block->reads.size(), 1); + ICHECK_EQ(block->writes.size(), 1); + int data_bits = block->reads[0]->buffer->dtype.bits(); + ConstraintSet constraints(this->thread_extent_, // + this->outer_loops_, // + block->reads[0], // + block->writes[0], // + data_bits, // + block->annotations); + BlockNode* n = block.CopyOnWrite(); + OutputSet outputs; + for (RewriteRule* rule : rules) { + n->body = rule->Apply(std::move(n->body), constraints, &outputs); + } + for (const Buffer& buffer : outputs.alloc_buffer) { + n->alloc_buffers.push_back(buffer); + } + for (const auto& p : outputs.padding_min) { + Integer m = padder.padding_min_.Get(p.first).value_or(1); + padder.padding_min_.Set(p.first, Downcast(max(p.second, m))); + } + padder.AnalyzeSharedMemoryAccess(block->body, outer_loops_, data_bits, thread_extent_); + n->alloc_buffers = padder.PadSharedMemory(std::move(n->alloc_buffers)); + return std::move(block); + } + + Stmt VisitStmt_(const ForNode* op) final { + outer_loops_.push_back(GetRef(op)); + Stmt stmt = StmtMutator::VisitStmt_(op); + outer_loops_.pop_back(); + return stmt; + } + + /*! \brief Thread extents collected. */ + Map thread_extent_; + /*! \brief The outer loops during recursive visit */ + Array outer_loops_; + /*! \brief Calculating optimal padding size */ + AutoPadder padder; + + /*! \brief All rewrite rules. */ + const std::array rules = { + &inverse_mapping, // + &coalesced_access, // + &create_local_stage, // + &shared_to_wmma, // + &wmma_to_global, // + &wmma_to_shared, + }; +}; + +/*! + * \brief Collect the extent for all thread binding loops. + */ +class ThreadExtentCollector : public StmtVisitor { + public: + static Map CollectThreadExtent(const Stmt& stmt) { + ThreadExtentCollector collector; + collector(stmt); + return collector.thread_extent_; + } + + private: + void VisitStmt_(const BlockNode* op) final { + if (Optional warp_execution = GetAnn(op, "warp_execution")) { + if (warp_execution.value()->value != 0) { + thread_extent_.Set("threadIdx.x", Integer(32)); + } + } + StmtVisitor::VisitStmt_(op); + } + void VisitStmt_(const ForNode* op) final { + if (op->thread_binding.defined() && op->thread_binding.value()->iter_type == kThreadIndex) { + thread_extent_.Set(op->thread_binding.value()->thread_tag, Downcast(op->extent)); + } + StmtVisitor::VisitStmt_(op); + } + + /*! \brief the map from thread tag to its extent */ + Map thread_extent_; +}; + +namespace transform { + +Pass LowerAutoCopy() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + AutoCopyMutator mutator(ThreadExtentCollector::CollectThreadExtent(n->body)); + n->body = mutator(std::move(n->body)); + n->body = mutator.RewritePaddingBody(n->body); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.LowerAutoCopy", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.LowerAutoCopy").set_body_typed(LowerAutoCopy); + +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/memhammer_rewrite_rule.h b/src/tir/transforms/memhammer_rewrite_rule.h new file mode 100644 index 000000000000..1cb0ea496a03 --- /dev/null +++ b/src/tir/transforms/memhammer_rewrite_rule.h @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include +#include +#include + +#include "../schedule/utils.h" + +namespace tvm { +namespace tir { + +/*! \brief The set containing all possible constraints of a data copy */ +struct ConstraintSet { + /*! \brief The extents of the thread binding loops */ + Map thread_extent; + /*! \brief The outer loops surrounding the data copy */ + Array outer_loops; + /*! \brief The read region of the data copy */ + BufferRegion read_region; + /*! \brief The write region of the data copy */ + BufferRegion write_region; + /*! \brief The dtype size in bits */ + int data_bits; + /*! \brief Whether to insert a local stage in the data copy */ + int add_local_stage = 0; + /*! \brief The vectorization length in bytes */ + int vector_bytes = 1; + + explicit ConstraintSet(Map thread_extent, // + Array outer_loops, // + BufferRegion read_region, // + BufferRegion write_region, // + int data_bits, // + const Map& ann) + : thread_extent(thread_extent), + outer_loops(outer_loops), + read_region(read_region), + write_region(write_region), + data_bits(data_bits) { + if (Optional add_local_stage = ann.Get("local_stage")) { + this->add_local_stage = Downcast(add_local_stage.value())->value; + } + if (Optional vector_bytes = ann.Get("vector_bytes")) { + this->vector_bytes = Downcast(vector_bytes.value())->value; + } + } +}; + +/*! \brief The set containing all possible outputs of a rewrite rule */ +struct OutputSet { + /*! \brief New buffers allocated after rewrite */ + Array alloc_buffer; + /*! \brief The minimal padding size of a buffer in base 2 logarithm */ + Map padding_min; +}; + +/*! + * \brief Rules to rewrite a data copy. + */ +class RewriteRule { + protected: + /* RewriteRule() = default; */ + /*! + * \brief Rewrite the stmt under certain constraints + * \param stmt The stmt + * \param constraints The constraints of the rewrite + * \param output Some additional information that the rewrite rule produces. (including the new + * buffer to be allocated, etc.) + * \return the stmt after rewrite + */ + virtual Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const = 0; + /*! + * \brief Whether the rewrite rule can be applied to the stmt under certain constraints + * \param stmt The stmt + * \param constraints The constraints of the rewrite + * \return A boolean flag indicating whether the rule can be applied + */ + virtual bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const { return true; } + + public: + inline Stmt Apply(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const { + if (CanApply(stmt, constraints)) { + return Rewrite(stmt, constraints, output); + } else { + return stmt; + } + } +}; + +inline bool IsCopyBetweenScope(const Buffer& src_buffer, const Buffer& tgt_buffer, + runtime::StorageRank src_rank, runtime::StorageRank tgt_rank) { + runtime::StorageScope src_scope = runtime::StorageScope::Create(src_buffer.scope()); + runtime::StorageScope tgt_scope = runtime::StorageScope::Create(tgt_buffer.scope()); + return src_scope.rank == src_rank && tgt_scope.rank == tgt_rank; +} + +/*! + * \brief Coalesce and vectorize memory access. + */ +class CoalescedAccess : public RewriteRule { + public: + CoalescedAccess() = default; + Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const final; + bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const final { + Buffer src_buffer = constraints.read_region->buffer; + Buffer tgt_buffer = constraints.write_region->buffer; + return IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kGlobal, + runtime::StorageRank::kShared) || + IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kShared, + runtime::StorageRank::kGlobal); + } +}; + +/*! + * \brief Transform from A[f(i,j)] = B[i,j] to A[i,j] = B[f^{-1}(i,j)] + */ +class InverseMapping : public RewriteRule { + public: + InverseMapping() = default; + Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const final; + bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const final { + Buffer src_buffer = constraints.read_region->buffer; + Buffer tgt_buffer = constraints.write_region->buffer; + return IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kShared, + runtime::StorageRank::kGlobal); + } +}; + +/*! + * \brief Create a local stage when loading from global memory to shared memory. + */ +class CreateLocalStage : public RewriteRule { + public: + CreateLocalStage() = default; + Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const final; + bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const final { + Buffer src_buffer = constraints.read_region->buffer; + Buffer tgt_buffer = constraints.write_region->buffer; + return IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kGlobal, + runtime::StorageRank::kShared) && + is_one(constraints.add_local_stage); + } +}; + +/*! + * \brief Add a cache stage in shared memory. Perform tensor core rewrite for wmma->shared, and + * perform coalescing and vectorizing for shared->global. + */ +class WmmaToGlobal : public RewriteRule { + public: + WmmaToGlobal() = default; + Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const final; + bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const final { + Buffer src_buffer = constraints.read_region->buffer; + Buffer tgt_buffer = constraints.write_region->buffer; + return IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kWMMAAccumulator, + runtime::StorageRank::kGlobal); + } +}; + +/*! + * \brief Rewrite shared->wmma data copy with load_matrix_sync + */ +class SharedToWmma : public RewriteRule { + public: + SharedToWmma() = default; + Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const final; + bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const final { + Buffer src_buffer = constraints.read_region->buffer; + Buffer tgt_buffer = constraints.write_region->buffer; + return IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kShared, + runtime::StorageRank::kWMMAMatrixA) || + IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kShared, + runtime::StorageRank::kWMMAMatrixB); + } +}; + +/*! + * \brief Rewrite wmma->shared data copy with store_matrix_sync + */ +class WmmaToShared : public RewriteRule { + public: + WmmaToShared() = default; + Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const final; + bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const final { + Buffer src_buffer = constraints.read_region->buffer; + Buffer tgt_buffer = constraints.write_region->buffer; + return IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kWMMAAccumulator, + runtime::StorageRank::kShared); + } +}; + +/*! + * \brief Insert a cache stage to the compute location + * \param stmt the stmt + * \param is_write_cache whether to write a read cache or write cache + * \param storage_scope the storage scope of the new cache + * \param compute_location the compute location. + * \param outer_loops the outer loops of this stmt + * \param alloc_buffer the new cache block + * \return a pair. The first is the stmt after transformation. + * The second is the SeqStmt that contains 2 stages (one original and another inserted). + */ +std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String storage_scope, + Optional compute_location, + const Array& outer_loops, Buffer* alloc_buffer); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/memhammer_tensorcore_rewrite.cc b/src/tir/transforms/memhammer_tensorcore_rewrite.cc new file mode 100644 index 000000000000..6e880146d618 --- /dev/null +++ b/src/tir/transforms/memhammer_tensorcore_rewrite.cc @@ -0,0 +1,336 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./memhammer_rewrite_rule.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Tile the 2 innermost loops to extent=16. This helps further tensor core rewrite. + * \param stmt The stmt + * \return A pair. The first is the stmt after transformation. + * The second is the compute location where we may add write cache. + */ +std::pair> TileWmmaBlock(Stmt stmt) { + Stmt body = stmt; + std::vector loops; + while (const ForNode* loop = body.as()) { + loops.push_back(loop); + body = loop->body; + } + int n = loops.size(); + PrimExpr extent_last1 = loops[n - 1]->extent; + PrimExpr extent_last2 = loops[n - 2]->extent; + { + arith::Analyzer analyzer; + if (!analyzer.CanProveEqual(floormod(extent_last1, 16), 0) || + !analyzer.CanProveEqual(floormod(extent_last2, 16), 0)) { + return std::make_pair(stmt, NullOpt); + } + } + Var new_loop_vars[4] = { + /*0:*/ loops[n - 2]->loop_var.copy_with_suffix("_0"), + /*1:*/ loops[n - 1]->loop_var.copy_with_suffix("_0"), + /*2:*/ loops[n - 2]->loop_var.copy_with_suffix("_1"), + /*3:*/ loops[n - 1]->loop_var.copy_with_suffix("_1"), + }; + body = Substitute(std::move(body), + Map{ + {loops[n - 2]->loop_var, new_loop_vars[0] * 16 + new_loop_vars[2]}, + {loops[n - 1]->loop_var, new_loop_vars[1] * 16 + new_loop_vars[3]}, + }); + { + PrimExpr factor[4] = { + /*0:*/ floordiv(extent_last2, 16), // + /*1:*/ floordiv(extent_last1, 16), // + /*3:*/ 16, // + /*4:*/ 16, // + }; + body = For(new_loop_vars[3], 0, factor[3], ForKind::kSerial, std::move(body)); + body = For(new_loop_vars[2], 0, factor[2], ForKind::kSerial, std::move(body)); + body = For(new_loop_vars[1], 0, factor[1], ForKind::kSerial, std::move(body)); + body = For(new_loop_vars[0], 0, factor[0], ForKind::kSerial, std::move(body)); + } + For compute_location = Downcast(body); + for (int i = n - 3; i >= 0; i--) { + body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind, std::move(body), + loops[i]->thread_binding, loops[i]->annotations); + } + return {body, compute_location}; +} + +Array RelaxIndices(const Array& indices, const Array& shape, + const Map& var_dom) { + Array int_set = arith::EvalSet(indices, var_dom); + int ndim = int_set.size(); + Array region; + region.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + region.push_back(int_set[i].CoverRange(Range::FromMinExtent(0, shape[i]))); + }; + return region; +} + +/*! + * \brief Rewrite the data copy that stores to wmma fragment with wmma::load_matrix_sync + * \param stmt The stmt to rewrite + * \return The stmt after rewrite + */ +Stmt RewriteWmmaLoad(Stmt stmt) { + using arith::IntSet; + const DataType dtype = DataType::Float(16); + const DataType int32 = DataType::Int(32); + + Stmt body = stmt; + std::vector loops; + while (const ForNode* loop = body.as()) { + loops.push_back(loop); + body = loop->body; + } + int n = loops.size(); + + Map var_dom{ + {loops[n - 1]->loop_var, IntSet::FromMinExtent(loops[n - 1]->min, loops[n - 1]->extent)}, + {loops[n - 2]->loop_var, IntSet::FromMinExtent(loops[n - 2]->min, loops[n - 2]->extent)}, + }; + // TODO: the assumption that the RHS of BufferStore is BufferLoad may not be accurate + const BufferStoreNode* buf_store = TVM_TYPE_AS(buf_store, body, BufferStoreNode); + const BufferLoadNode* buf_load = TVM_TYPE_AS(buf_load, buf_store->value, BufferLoadNode); + Buffer src_buffer = buf_load->buffer; + Buffer tgt_buffer = buf_store->buffer; + + Buffer new_src_buffer( + /*data=*/Var("src", PointerType(PrimType(dtype), src_buffer.scope())), + /*dtype=*/dtype, + /*shape=*/{Integer(16), Integer(16)}, + /*strides=*/{Var("s1", int32), Var("s0", int32)}, + /*elem_offset=*/Var("src_elem_offset", int32), + /*name=*/"src", + /*data_alignment=*/128, + /*offset_factor=*/16, + /*buffer_type=*/kDefault); + Buffer new_tgt_buffer( + /*data=*/Var("tgt", PointerType(PrimType(dtype), tgt_buffer.scope())), + /*dtype=*/dtype, + /*shape=*/{Integer(16), Integer(16)}, + /*strides=*/{}, + /*elem_offset=*/Var("tgt_elem_offset", int32), + /*name=*/"tgt", + /*data_alignment=*/128, + /*offset_factor=*/16, + /*buffer_type=*/kDefault); + Array read_region = RelaxIndices(buf_load->indices, src_buffer->shape, var_dom); + Array write_region = RelaxIndices(buf_store->indices, tgt_buffer->shape, var_dom); + Stmt wmma_body = BlockRealize( + /*iter_values=*/{}, + /*predicate=*/Bool(true), + Block( + /*iter_vars=*/{}, + /*reads=*/{BufferRegion(src_buffer, read_region)}, + /*writes=*/{BufferRegion(tgt_buffer, write_region)}, + /*name_hint=*/"wmma_load", + /*body=*/ + Evaluate(Call( + /*data=*/runtime::DataType::Handle(), + /*op=*/builtin::tvm_load_matrix_sync(), + { + /*0:*/ new_tgt_buffer->data, + /*1:*/ 16, + /*2:*/ 16, + /*3:*/ 16, + /*4:*/ floordiv(new_tgt_buffer->elem_offset, 256) + + floordiv(floormod(new_tgt_buffer->elem_offset, 256), 16), + /*5:*/ + Call( + /*dtype=*/runtime::DataType::Handle(), + /*op=*/builtin::tvm_access_ptr(), + /*args=*/ + { + /*0:*/ TypeAnnotation(new_src_buffer->dtype), + /*1:*/ new_src_buffer->data, + /*2:*/ new_src_buffer->elem_offset, + /*3:*/ new_src_buffer->strides[new_src_buffer->strides.size() - 2] * 16, + /*4:*/ 1, + }), + /*6:*/ new_src_buffer->strides[new_src_buffer->strides.size() - 2], + /*7:*/ StringImm("row_major"), + })), + /*init=*/NullOpt, + /*alloc_buffers=*/{}, + /*match_buffers=*/ + { + /*0:*/ MatchBufferRegion(new_src_buffer, BufferRegion(src_buffer, read_region)), + /*1:*/ MatchBufferRegion(new_tgt_buffer, BufferRegion(tgt_buffer, write_region)), + }, + /*annotations=*/{})); + for (int i = n - 3; i >= 0; i--) { + wmma_body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind, + std::move(wmma_body), loops[i]->thread_binding, loops[i]->annotations); + } + return wmma_body; +} + +/*! + * \brief Rewrite the data copy that loads from wmma fragment with wmma::store_matrix_sync + * \param stmt The stmt to rewrite + * \return The stmt after rewrite + */ +Stmt RewriteWmmaStore(Stmt stmt) { + using arith::IntSet; + const DataType dtype = DataType::Float(32); + const DataType int32 = DataType::Int(32); + + Stmt body = stmt; + std::vector loops; + while (const ForNode* loop = body.as()) { + loops.push_back(loop); + body = loop->body; + } + int n = loops.size(); + + Map var_dom{ + {loops[n - 1]->loop_var, IntSet::FromMinExtent(loops[n - 1]->min, loops[n - 1]->extent)}, + {loops[n - 2]->loop_var, IntSet::FromMinExtent(loops[n - 2]->min, loops[n - 2]->extent)}, + }; + // TODO: the assumption that the RHS of BufferStore is BufferLoad may not be accurate + const BufferStoreNode* buf_store = TVM_TYPE_AS(buf_store, body, BufferStoreNode); + const BufferLoadNode* buf_load = TVM_TYPE_AS(buf_load, buf_store->value, BufferLoadNode); + Buffer src_buffer = buf_load->buffer; + Buffer tgt_buffer = buf_store->buffer; + + Buffer new_src_buffer(/*data=*/Var("src", PointerType(PrimType(dtype), src_buffer.scope())), + /*dtype=*/dtype, + /*shape=*/{Integer(16), Integer(16)}, + /*strides=*/{}, + /*elem_offset=*/Var("src_elem_offset", int32), + /*name=*/"src", + /*data_alignment=*/128, + /*offset_factor=*/16, + /*buffer_type=*/kDefault); + Buffer new_tgt_buffer(/*data=*/Var("tgt", PointerType(PrimType(dtype), tgt_buffer.scope())), + /*dtype=*/dtype, + /*shape=*/{Integer(16), Integer(16)}, + /*strides=*/{Var("s1", int32), Var("s0", int32)}, + /*elem_offset=*/Var("tgt_elem_offset", int32), + /*name=*/"tgt", + /*data_alignment=*/128, + /*offset_factor=*/16, + /*buffer_type=*/kDefault); + + Array read_region = RelaxIndices(buf_load->indices, src_buffer->shape, var_dom); + Array write_region = RelaxIndices(buf_store->indices, tgt_buffer->shape, var_dom); + + Stmt wmma_body = BlockRealize( + /*iter_values=*/{}, // + /*predicate=*/Bool(true), + Block(/*iter_vars=*/{}, + /*reads=*/{BufferRegion(src_buffer, read_region)}, + /*writes=*/{BufferRegion(tgt_buffer, write_region)}, + /*name_hint=*/"wmma_store", + Evaluate(Call( + /*data=*/runtime::DataType::Handle(), + /*op=*/builtin::tvm_store_matrix_sync(), + {/*0:*/ new_src_buffer->data, + /*1:*/ 16, + /*2:*/ 16, + /*3:*/ 16, + /*4:*/ floordiv(new_src_buffer->elem_offset, 256) + + floordiv(floormod(new_src_buffer->elem_offset, 256), 16), + /*5:*/ + Call( + /*data=*/runtime::DataType::Handle(), + /*op=*/builtin::tvm_access_ptr(), + { + /*0:*/ TypeAnnotation(new_tgt_buffer->dtype), + /*1:*/ new_tgt_buffer->data, + /*2:*/ new_tgt_buffer->elem_offset, + /*3:*/ new_tgt_buffer->strides[0] * 16, + /*4:*/ 2, + }), + /*6:*/ new_tgt_buffer->strides[0], + /*7:*/ StringImm("row_major")})), + /*init=*/NullOpt, + /*alloc_buffers=*/{}, + /*match_buffers=*/ + { + MatchBufferRegion(new_src_buffer, BufferRegion(src_buffer, read_region)), + MatchBufferRegion(new_tgt_buffer, BufferRegion(tgt_buffer, write_region)), + }, + /*annotations=*/{})); + for (int i = n - 3; i >= 0; i--) { + wmma_body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind, + std::move(wmma_body), loops[i]->thread_binding, loops[i]->annotations); + } + return wmma_body; +} + +Stmt SharedToWmma::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const { + Stmt after_tiling = TileWmmaBlock(stmt).first; + output->padding_min.Set(constraints.read_region->buffer, 8); + return RewriteWmmaLoad(after_tiling); +} + +Stmt WmmaToShared::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const { + Stmt after_tiling = TileWmmaBlock(stmt).first; + output->padding_min.Set(constraints.write_region->buffer, 8); + return RewriteWmmaStore(after_tiling); +} + +class WmmaToGlobalRewriter : public StmtExprMutator { + public: + WmmaToGlobalRewriter(const SeqStmtNode* tgt_stmt, const ConstraintSet& constraints) + : tgt_stmt_(tgt_stmt), constraints_(constraints) {} + + private: + Stmt VisitStmt_(const SeqStmtNode* op) final { + if (op == tgt_stmt_) { + ICHECK_EQ(op->seq.size(), 2); + Stmt wmma_to_shared = RewriteWmmaStore(op->seq[0]); + Stmt shared_to_global = CoalescedAccess().Rewrite(op->seq[1], constraints_, nullptr); + return SeqStmt({wmma_to_shared, shared_to_global}); + } else { + return StmtMutator::VisitStmt_(op); + } + } + + const SeqStmtNode* tgt_stmt_; + const ConstraintSet& constraints_; +}; + +Stmt WmmaToGlobal::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const { + Stmt body{nullptr}; + Optional compute_location{nullptr}; + std::tie(body, compute_location) = TileWmmaBlock(stmt); + SeqStmt seq{nullptr}; + Buffer cache_buffer; + // Step 1. add a shared memory cache + std::tie(body, seq) = InsertCacheStage(std::move(body), true, "shared.dyn", compute_location, + constraints.outer_loops, &cache_buffer); + output->alloc_buffer.push_back(cache_buffer); + output->padding_min.Set(cache_buffer, 8); + // Step 2. do coalesced rewrite and tensor core rewrite respectively for 2 parts + WmmaToGlobalRewriter rewriter(seq.get(), constraints); + return rewriter(body); +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/renormalize_split_pattern.cc b/src/tir/transforms/renormalize_split_pattern.cc new file mode 100644 index 000000000000..dd19d7923e77 --- /dev/null +++ b/src/tir/transforms/renormalize_split_pattern.cc @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file renormalize_split_pattern.cc + * \brief Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()) + */ +#include +#include +#include +#include +#include +#include + +#include "../../arith/ir_mutator_with_analyzer.h" +#include "../../arith/pattern_match.h" + +namespace tvm { +namespace tir { + +using namespace arith; + +// macro for doing simple rewrite +#define TVM_TRY_REWRITE(SrcExpr, ResExpr) \ + if ((SrcExpr).Match(ret)) { \ + return (ResExpr).Eval(); \ + } + +// macro rewrite + recursive_rewrite only if CondExor is true after match. +#define TVM_TRY_RECURSIVE_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \ + if ((SrcExpr).Match(ret) && (CondExpr)) { \ + return RecursiveRewrite((ResExpr).Eval()); \ + } + +class SplitPatternReNormalizer : public IRMutatorWithAnalyzer { + public: + explicit SplitPatternReNormalizer(Analyzer* analyzer) : IRMutatorWithAnalyzer(analyzer) {} + + PrimExpr VisitExpr_(const FloorDivNode* op) final { + PrimExpr a = VisitExpr(op->a); + PrimExpr b = VisitExpr(op->b); + PrimExpr ret = floordiv(a, b); + // Pattern var to match any expression + PVar x, y, z; + // Pattern var match IntImm + PVar c1, c2, c3; + // Pattern var for lanes in broadcast and ramp + PVar lanes; + + // floordiv(floormod(x, c1 * c2), c2) = floormod(floordiv(x, c2), c1) + TVM_TRY_RECURSIVE_REWRITE_IF(floordiv(floormod(x, c3), c2), + floormod(floordiv(x, c2), floordiv(c3, c2)), + c3.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_RECURSIVE_REWRITE_IF( + floordiv(floormod(x, broadcast(c3, lanes)), broadcast(c2, lanes)), + floormod(floordiv(x, broadcast(c2, lanes)), broadcast(floordiv(c3, c2), lanes)), + c3.Eval()->value % c2.Eval()->value == 0); + + // floordiv(x*c1*c3 + y, c2*c3) = floordiv(x*c1 + floordiv(y, c3), c2) + if ((floordiv(x * c1 + y, c2)).Match(ret)) { + int64_t c1_val = c1.Eval()->value; + int64_t c2_val = c2.Eval()->value; + if (c1_val > 0 && c2_val > 0) { + int64_t c3 = ZeroAwareGCD(c1_val, c2_val); + if (c3 > 1) { + IntImm c1_div = IntImm(c1.Eval().dtype(), c1_val / c3); + IntImm c2_div = IntImm(c2.Eval().dtype(), c2_val / c3); + return RecursiveRewrite(floordiv(x.Eval() * c1_div + floordiv(y.Eval(), c3), c2_div)); + } + } + } + if ((floordiv(x * broadcast(c1, lanes) + y, broadcast(c2, lanes))).Match(ret)) { + int64_t c1_val = c1.Eval()->value; + int64_t c2_val = c2.Eval()->value; + if (c1_val > 0 && c2_val > 0) { + int64_t c3 = ZeroAwareGCD(c1_val, c2_val); + if (c3 > 1) { + IntImm c1_div = IntImm(c1.Eval().dtype(), c1_val / c3); + IntImm c2_div = IntImm(c2.Eval().dtype(), c2_val / c3); + return RecursiveRewrite(floordiv( + x.Eval() * Broadcast(c1_div, lanes.Eval()) + + floordiv(y.Eval(), Broadcast(IntImm(c1.Eval().dtype(), c3), lanes.Eval())), + Broadcast(c2_div, lanes.Eval()))); + } + } + } + + // floordiv(x*c1*c3 + y + z, c2*c3) = floordiv(x*c1 + floordiv(y + z, c3), c2) + if ((floordiv(x * c1 + y + z, c2)).Match(ret)) { + int64_t c1_val = c1.Eval()->value; + int64_t c2_val = c2.Eval()->value; + if (c1_val > 0 && c2_val > 0) { + int64_t c3 = ZeroAwareGCD(c1_val, c2_val); + if (c3 > 1) { + IntImm c1_div = IntImm(c1.Eval().dtype(), c1_val / c3); + IntImm c2_div = IntImm(c2.Eval().dtype(), c2_val / c3); + return RecursiveRewrite(floordiv(x.Eval() * c1_div + floordiv(y.Eval() + z.Eval(), c3), c2_div)); + } + } + } + if ((floordiv(x * broadcast(c1, lanes) + y + z, broadcast(c2, lanes))).Match(ret)) { + int64_t c1_val = c1.Eval()->value; + int64_t c2_val = c2.Eval()->value; + if (c1_val > 0 && c2_val > 0) { + int64_t c3 = ZeroAwareGCD(c1_val, c2_val); + if (c3 > 1) { + IntImm c1_div = IntImm(c1.Eval().dtype(), c1_val / c3); + IntImm c2_div = IntImm(c2.Eval().dtype(), c2_val / c3); + return RecursiveRewrite(floordiv( + x.Eval() * Broadcast(c1_div, lanes.Eval()) + + floordiv(y.Eval() + z.Eval(), Broadcast(IntImm(c1.Eval().dtype(), c3), lanes.Eval())), + Broadcast(c2_div, lanes.Eval()))); + } + } + } + + return ret; + } + + PrimExpr VisitExpr_(const LENode* op) { return this->VisitExpr(Not(op->b < op->a)); } + + PrimExpr VisitExpr_(const GTNode* op) { return this->VisitExpr(op->b < op->a); } + + PrimExpr VisitExpr_(const GENode* op) { return this->VisitExpr(Not(op->a < op->b)); } + + PrimExpr VisitExpr_(const LTNode* op) { + PrimExpr a = VisitExpr(op->a); + PrimExpr b = VisitExpr(op->b); + PrimExpr ret = tir::LT(a, b); + // Pattern var to match any expression + PVar x; + // Pattern var match IntImm + PVar c1, c2; + TVM_TRY_RECURSIVE_REWRITE_IF(xvalue> 0); + return ret; + } + + PrimExpr VisitExpr_(const NotNode* op) { + PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); + // Pattern var to match any expression + PVar x, y; + TVM_TRY_REWRITE(!(!x), x); + TVM_TRY_REWRITE(!(x <= y), y < x); + TVM_TRY_REWRITE(!(x >= y), x < y); + TVM_TRY_REWRITE(!(x < y), y <= x); + TVM_TRY_REWRITE(!(x > y), x <= y); + return ret; + } + + Stmt VisitStmt_(const ForNode* op) final { + analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); + With ctx1(analyzer_, op->loop_var >= op->min); + With ctx2(analyzer_, op->loop_var < op->min + op->extent); + return IRMutatorWithAnalyzer::VisitStmt_(op); + } + + // Recursive rewrite x + // we limit maximum depth of recursive rewrite allowed to + // avoid infinite loop + PrimExpr RecursiveRewrite(const PrimExpr& x) { + if (recur_depth_ >= kMaxRecurDepth) return x; + ++recur_depth_; + PrimExpr res = this->VisitExpr(x); + --recur_depth_; + return res; + } + + private: + // counter to record recursive rewrite depth. + int recur_depth_{0}; + // maximum number of recursion allowed during a single pass. + static const constexpr int kMaxRecurDepth = 5; +}; + +namespace transform { + +Pass RenormalizeSplitPattern() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + arith::Analyzer analyzer; + n->body = SplitPatternReNormalizer(&analyzer)(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.RenormalizeSplitPattern", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.RenormalizeSplitPattern") + .set_body_typed(RenormalizeSplitPattern); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/tests/python/meta_schedule/run_ansor_cpu.sh b/tests/python/meta_schedule/run_ansor_cpu.sh new file mode 100644 index 000000000000..a080ded8fdd9 --- /dev/null +++ b/tests/python/meta_schedule/run_ansor_cpu.sh @@ -0,0 +1,41 @@ +set -euxo pipefail + +RPC_HOST="192.168.6.66" +RPC_PORT="4445" +RPC_KEY="raspi4b-aarch64" +TARGET="raspberry-pi/4b-64" +NUM_TRIALS=800 +LOG_DIR=$HOME/logs/ansor-cpu/ + +mkdir -p $LOG_DIR + +run () { + name=$1 + echo "Running workload $name" + python tests/python/meta_schedule/test_ansor.py \ + --workload "$name" \ + --target "$TARGET" \ + --rpc-host "$RPC_HOST" \ + --rpc-port "$RPC_PORT" \ + --rpc-key "$RPC_KEY" \ + --num-trials "$NUM_TRIALS" \ + --log-dir $LOG_DIR \ + 2>&1 | tee "$LOG_DIR/$name.log" +} + +# Single op +run C1D +run C2D +run C3D +run CAP +run DEP +run DIL +run GMM +run GRP +run NRM +run SFM +run T2D +# Subgraph +run C2d-BN-RELU +run TBG + diff --git a/tests/python/meta_schedule/run_ansor_cuda.sh b/tests/python/meta_schedule/run_ansor_cuda.sh new file mode 100644 index 000000000000..6eda12fe119c --- /dev/null +++ b/tests/python/meta_schedule/run_ansor_cuda.sh @@ -0,0 +1,39 @@ +# set -euxo pipefail + +RPC_HOST="192.168.6.66" +RPC_PORT="4445" +RPC_KEY="jetson-agx-xavier" +TARGET="nvidia/jetson-agx-xavier" +LOG_DIR=$HOME/logs/ansor-cuda/ +NUM_TRIALS=2000 + +mkdir -p $LOG_DIR + +run () { + name=$1 + echo "Running workload $name" + python tests/python/meta_schedule/test_ansor.py \ + --workload "$name" \ + --target "$TARGET" \ + --rpc-host "$RPC_HOST" \ + --rpc-port "$RPC_PORT" \ + --rpc-key "$RPC_KEY" \ + --num-trials "$NUM_TRIALS" \ + --log-dir $LOG_DIR \ + 2>&1 | tee "$LOG_DIR/$name.log" +} + +run C1D +run C2D +run CAP +run DEP +run DIL +run GMM +run GRP +run T2D +run C2d-BN-RELU +run TBG + +run C3D +run NRM +run SFM diff --git a/tests/python/meta_schedule/run_meta_schedule_cpu.sh b/tests/python/meta_schedule/run_meta_schedule_cpu.sh new file mode 100644 index 000000000000..87bc17f9e8b6 --- /dev/null +++ b/tests/python/meta_schedule/run_meta_schedule_cpu.sh @@ -0,0 +1,40 @@ +set -euxo pipefail + +RPC_HOST="192.168.6.66" +RPC_PORT="4445" +RPC_KEY="raspi4b-aarch64" +TARGET="raspberry-pi/4b-64" +LOG_DIR=$HOME/logs/ms-cpu/ +NUM_TRIALS=2000 + +mkdir -p $LOG_DIR + +run () { + name=$1 + echo "Running workload $name" + python tests/python/meta_schedule/test_meta_schedule.py \ + --workload "$name" \ + --target "$TARGET" \ + --rpc-host "$RPC_HOST" \ + --rpc-port "$RPC_PORT" \ + --rpc-key "$RPC_KEY" \ + --num-trials $NUM_TRIALS \ + 2>&1 | tee "$LOG_DIR/$name.log" +} + +# Single op +run C1D +run C2D +run C3D +run CAP +run DEP +run DIL +run GMM +run GRP +run NRM +run SFM +run T2D +# Subgraph +run C2d-BN-RELU +run TBG + diff --git a/tests/python/meta_schedule/run_meta_schedule_cuda.sh b/tests/python/meta_schedule/run_meta_schedule_cuda.sh new file mode 100644 index 000000000000..28132a05045a --- /dev/null +++ b/tests/python/meta_schedule/run_meta_schedule_cuda.sh @@ -0,0 +1,41 @@ +set -euxo pipefail + +RPC_HOST="192.168.6.66" +RPC_PORT="4445" +RPC_KEY="jetson-agx-xavier" +TARGET="nvidia/jetson-agx-xavier" +LOG_DIR=$HOME/logs/ms-cuda/ +NUM_TRIALS=2000 + +mkdir -p $LOG_DIR + +run () { + name=$1 + work_dir=$LOG_DIR/$name/ + mkdir -p $work_dir + echo "Running workload $name" + python tests/python/meta_schedule/test_meta_schedule.py \ + --workload "$name" \ + --target "$TARGET" \ + --work-dir "$work_dir" \ + --rpc-host "$RPC_HOST" \ + --rpc-port "$RPC_PORT" \ + --rpc-key "$RPC_KEY" \ + --num-trials $NUM_TRIALS \ + 2>&1 | tee "$work_dir/$name.log" +} + +run C1D +run C2D +run CAP +run DEP +run DIL +run GMM +run GRP +run T2D +run C2d-BN-RELU +run TBG + +run C3D +run NRM +run SFM diff --git a/tests/python/meta_schedule/test_ansor.py b/tests/python/meta_schedule/test_ansor.py new file mode 100644 index 000000000000..1e548c49afa3 --- /dev/null +++ b/tests/python/meta_schedule/test_ansor.py @@ -0,0 +1,133 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import argparse +import os + +import tvm +from tvm import auto_scheduler +from tvm import meta_schedule as ms +from tvm.meta_schedule.testing.te_workload import CONFIGS + + +def _parse_args(): + args = argparse.ArgumentParser() + args.add_argument( + "--workload", + type=str, + required=True, + ) + args.add_argument( + "--target", + type=str, + required=True, + ) + args.add_argument( + "--num-trials", + type=int, + required=True, + ) + args.add_argument( + "--rpc-host", + type=str, + required=True, + ) + args.add_argument( + "--rpc-port", + type=int, + required=True, + ) + args.add_argument( + "--rpc-key", + type=str, + required=True, + ) + args.add_argument( + "--log-dir", + type=str, + required=True, + ) + parsed = args.parse_args() + parsed.target = tvm.target.Target(parsed.target) + rpc_config = ms.runner.RPCConfig( + tracker_host=parsed.rpc_host, + tracker_port=parsed.rpc_port, + tracker_key=parsed.rpc_key, + session_timeout_sec=60, + ) + parsed.rpc_workers = rpc_config.count_num_servers(allow_missing=False) + return parsed + + +ARGS = _parse_args() + + +def main(): + log_file = os.path.join(ARGS.log_dir, f"{ARGS.workload}.json") + workload_func, params = CONFIGS[ARGS.workload] + params = params[0] + workload_func = auto_scheduler.register_workload(workload_func) + + if ARGS.target.device_name == "cpu": + hardware_params = auto_scheduler.HardwareParams( + num_cores=int(ARGS.target.attrs["num-cores"]), + target=ARGS.target, + ) + else: + hardware_params = auto_scheduler.HardwareParams( + num_cores=-1, + vector_unit_bytes=16, + cache_line_bytes=64, + max_shared_memory_per_block=int(ARGS.target.attrs["shared_memory_per_block"]), + max_local_memory_per_block=int(ARGS.target.attrs["registers_per_block"]), + max_threads_per_block=int(ARGS.target.attrs["max_threads_per_block"]), + max_vthread_extent=8, + warp_size=32, + ) + task = auto_scheduler.SearchTask( + func=workload_func, + args=params, + target=ARGS.target, + hardware_params=hardware_params, + ) + runner = auto_scheduler.RPCRunner( + key=ARGS.rpc_key, + host=ARGS.rpc_host, + port=ARGS.rpc_port, + n_parallel=ARGS.rpc_workers, + ) + + # Inspect the computational graph + print("Computational DAG:") + print(task.compute_dag) + tune_option = auto_scheduler.TuningOptions( + num_measure_trials=ARGS.num_trials, + measure_callbacks=[auto_scheduler.RecordToFile(log_file)], + verbose=2, + runner=runner, + ) + print("Running AutoTuning:") + task.tune(tune_option) + print("History Best:") + print(task.print_best(log_file)) + sch, args = task.apply_best(log_file) + print("Lowered TIR:") + print(tvm.lower(sch, args, simple_mode=True)) + + +if __name__ == "__main__": + main() diff --git a/tests/python/meta_schedule/test_debug_ansor.py b/tests/python/meta_schedule/test_debug_ansor.py new file mode 100644 index 000000000000..be562963a1a0 --- /dev/null +++ b/tests/python/meta_schedule/test_debug_ansor.py @@ -0,0 +1,144 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +from typing import Tuple + +import tvm +from tvm import te, topi + + +TARGET = tvm.target.Target("nvidia/jetson-agx-xavier") + +@tvm.register_func +def tvm_callback_cuda_postproc(code): + import os + if not os.path.exists("/tmp/perf"): + os.mkdir("/tmp/perf") + with open("/tmp/perf/te.cu", "w") as f: + f.write(code) + return code + + +def func( # pylint: disable=invalid-name,missing-docstring + N: int, + L: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, L, CI), name="inputs") + weight = te.placeholder((kernel_size, CI // groups, CO), name="weight") + + batch_size, in_len, _ = inputs.shape + k_len, channel_per_group, out_channel = weight.shape + out_channel_per_group = out_channel // groups + out_len = (in_len + 2 * padding - dilation * (k_len - 1) - 1) // stride + 1 + rc = te.reduce_axis((0, channel_per_group), name="rc") + rl = te.reduce_axis((0, k_len), name="rl") + + padded = topi.nn.pad(inputs, [0, padding, 0]) + output = te.compute( + (batch_size, out_len, out_channel), + lambda n, l, co: te.sum( + ( + padded[ + n, + l * stride + rl * dilation, + co // out_channel_per_group * channel_per_group + rc, + ] + * weight[rl, rc, co] + ), + axis=[rl, rc], + ), + name="conv1d_nlc", + ) + return (inputs, weight, padded, output) + + +def main(): + inputs, weight, PadInput, conv1d_nlc = func(1, 256, 64, 128, 3, 2, 1) + s = te.create_schedule(conv1d_nlc.op) + # fmt: off + PadInput_i0, PadInput_i1, PadInput_i2 = tuple(PadInput.op.axis) + tuple(PadInput.op.reduce_axis) + conv1d_nlc_n, conv1d_nlc_l, conv1d_nlc_co, conv1d_nlc_rl, conv1d_nlc_rc = tuple(conv1d_nlc.op.axis) + tuple(conv1d_nlc.op.reduce_axis) + conv1d_nlc_local, = s.cache_write([conv1d_nlc], "local") + conv1d_nlc_local_n_c, conv1d_nlc_local_l_c, conv1d_nlc_local_co_c, conv1d_nlc_local_rl, conv1d_nlc_local_rc = tuple(conv1d_nlc_local.op.axis) + tuple(conv1d_nlc_local.op.reduce_axis) + conv1d_nlc_local_n_c_o_i, conv1d_nlc_local_n_c_i = s[conv1d_nlc_local].split(conv1d_nlc_local_n_c, factor=1) + conv1d_nlc_local_n_c_o_o_i, conv1d_nlc_local_n_c_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_n_c_o_i, factor=1) + conv1d_nlc_local_n_c_o_o_o_i, conv1d_nlc_local_n_c_o_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_n_c_o_o_i, factor=1) + conv1d_nlc_local_n_c_o_o_o_o, conv1d_nlc_local_n_c_o_o_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_n_c_o_o_o_i, factor=1) + conv1d_nlc_local_l_c_o_i, conv1d_nlc_local_l_c_i = s[conv1d_nlc_local].split(conv1d_nlc_local_l_c, factor=1) + conv1d_nlc_local_l_c_o_o_i, conv1d_nlc_local_l_c_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_l_c_o_i, factor=4) + conv1d_nlc_local_l_c_o_o_o_i, conv1d_nlc_local_l_c_o_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_l_c_o_o_i, factor=8) + conv1d_nlc_local_l_c_o_o_o_o, conv1d_nlc_local_l_c_o_o_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_l_c_o_o_o_i, factor=1) + conv1d_nlc_local_co_c_o_i, conv1d_nlc_local_co_c_i = s[conv1d_nlc_local].split(conv1d_nlc_local_co_c, factor=2) + conv1d_nlc_local_co_c_o_o_i, conv1d_nlc_local_co_c_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_co_c_o_i, factor=1) + conv1d_nlc_local_co_c_o_o_o_i, conv1d_nlc_local_co_c_o_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_co_c_o_o_i, factor=16) + conv1d_nlc_local_co_c_o_o_o_o, conv1d_nlc_local_co_c_o_o_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_co_c_o_o_o_i, factor=1) + conv1d_nlc_local_rl_o_i, conv1d_nlc_local_rl_i = s[conv1d_nlc_local].split(conv1d_nlc_local_rl, factor=3) + conv1d_nlc_local_rl_o_o, conv1d_nlc_local_rl_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_rl_o_i, factor=1) + conv1d_nlc_local_rc_o_i, conv1d_nlc_local_rc_i = s[conv1d_nlc_local].split(conv1d_nlc_local_rc, factor=2) + conv1d_nlc_local_rc_o_o, conv1d_nlc_local_rc_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_rc_o_i, factor=8) + s[conv1d_nlc_local].reorder(conv1d_nlc_local_n_c_o_o_o_o, conv1d_nlc_local_l_c_o_o_o_o, conv1d_nlc_local_co_c_o_o_o_o, conv1d_nlc_local_n_c_o_o_o_i, conv1d_nlc_local_l_c_o_o_o_i, conv1d_nlc_local_co_c_o_o_o_i, conv1d_nlc_local_n_c_o_o_i, conv1d_nlc_local_l_c_o_o_i, conv1d_nlc_local_co_c_o_o_i, conv1d_nlc_local_rl_o_o, conv1d_nlc_local_rc_o_o, conv1d_nlc_local_rl_o_i, conv1d_nlc_local_rc_o_i, conv1d_nlc_local_n_c_o_i, conv1d_nlc_local_l_c_o_i, conv1d_nlc_local_co_c_o_i, conv1d_nlc_local_rl_i, conv1d_nlc_local_rc_i, conv1d_nlc_local_n_c_i, + conv1d_nlc_local_l_c_i, conv1d_nlc_local_co_c_i) + conv1d_nlc_n_o_i, conv1d_nlc_n_i = s[conv1d_nlc].split(conv1d_nlc_n, factor=1) + conv1d_nlc_n_o_o_i, conv1d_nlc_n_o_i = s[conv1d_nlc].split(conv1d_nlc_n_o_i, factor=1) + conv1d_nlc_n_o_o_o, conv1d_nlc_n_o_o_i = s[conv1d_nlc].split(conv1d_nlc_n_o_o_i, factor=1) + conv1d_nlc_l_o_i, conv1d_nlc_l_i = s[conv1d_nlc].split(conv1d_nlc_l, factor=4) + conv1d_nlc_l_o_o_i, conv1d_nlc_l_o_i = s[conv1d_nlc].split(conv1d_nlc_l_o_i, factor=8) + conv1d_nlc_l_o_o_o, conv1d_nlc_l_o_o_i = s[conv1d_nlc].split(conv1d_nlc_l_o_o_i, factor=1) + conv1d_nlc_co_o_i, conv1d_nlc_co_i = s[conv1d_nlc].split(conv1d_nlc_co, factor=2) + conv1d_nlc_co_o_o_i, conv1d_nlc_co_o_i = s[conv1d_nlc].split(conv1d_nlc_co_o_i, factor=16) + conv1d_nlc_co_o_o_o, conv1d_nlc_co_o_o_i = s[conv1d_nlc].split(conv1d_nlc_co_o_o_i, factor=1) + s[conv1d_nlc].reorder(conv1d_nlc_n_o_o_o, conv1d_nlc_l_o_o_o, conv1d_nlc_co_o_o_o, conv1d_nlc_n_o_o_i, conv1d_nlc_l_o_o_i, conv1d_nlc_co_o_o_i, conv1d_nlc_n_o_i, conv1d_nlc_l_o_i, conv1d_nlc_co_o_i, conv1d_nlc_n_i, conv1d_nlc_l_i, conv1d_nlc_co_i) + s[conv1d_nlc_local].compute_at(s[conv1d_nlc], conv1d_nlc_co_o_i) + weight_shared = s.cache_read(weight, "shared", [conv1d_nlc_local]) + weight_shared_ax0, weight_shared_ax1, weight_shared_ax2 = tuple(weight_shared.op.axis) + s[weight_shared].compute_at(s[conv1d_nlc_local], conv1d_nlc_local_rc_o_o) + PadInput_shared = s.cache_read(PadInput, "shared", [conv1d_nlc_local]) + PadInput_shared_ax0, PadInput_shared_ax1, PadInput_shared_ax2 = tuple(PadInput_shared.op.axis) + s[PadInput_shared].compute_at(s[conv1d_nlc_local], conv1d_nlc_local_rc_o_o) + s[PadInput].compute_inline() + conv1d_nlc_n_o_o_o_l_o_o_o_fused_co_o_o_o_fused = s[conv1d_nlc].fuse(conv1d_nlc_n_o_o_o, conv1d_nlc_l_o_o_o, conv1d_nlc_co_o_o_o) + s[conv1d_nlc].bind(conv1d_nlc_n_o_o_o_l_o_o_o_fused_co_o_o_o_fused, te.thread_axis("blockIdx.x")) + conv1d_nlc_n_o_o_i_l_o_o_i_fused_co_o_o_i_fused = s[conv1d_nlc].fuse(conv1d_nlc_n_o_o_i, conv1d_nlc_l_o_o_i, conv1d_nlc_co_o_o_i) + s[conv1d_nlc].bind(conv1d_nlc_n_o_o_i_l_o_o_i_fused_co_o_o_i_fused, te.thread_axis("vthread")) + conv1d_nlc_n_o_i_l_o_i_fused_co_o_i_fused = s[conv1d_nlc].fuse(conv1d_nlc_n_o_i, conv1d_nlc_l_o_i, conv1d_nlc_co_o_i) + s[conv1d_nlc].bind(conv1d_nlc_n_o_i_l_o_i_fused_co_o_i_fused, te.thread_axis("threadIdx.x")) + weight_shared_ax0_ax1_fused_ax2_fused = s[weight_shared].fuse(weight_shared_ax0, weight_shared_ax1, weight_shared_ax2) + weight_shared_ax0_ax1_fused_ax2_fused_o, weight_shared_ax0_ax1_fused_ax2_fused_i = s[weight_shared].split(weight_shared_ax0_ax1_fused_ax2_fused, factor=1) + s[weight_shared].vectorize(weight_shared_ax0_ax1_fused_ax2_fused_i) + weight_shared_ax0_ax1_fused_ax2_fused_o_o, weight_shared_ax0_ax1_fused_ax2_fused_o_i = s[weight_shared].split(weight_shared_ax0_ax1_fused_ax2_fused_o, factor=128) + s[weight_shared].bind(weight_shared_ax0_ax1_fused_ax2_fused_o_i, te.thread_axis("threadIdx.x")) + PadInput_shared_ax0_ax1_fused_ax2_fused = s[PadInput_shared].fuse(PadInput_shared_ax0, PadInput_shared_ax1, PadInput_shared_ax2) + PadInput_shared_ax0_ax1_fused_ax2_fused_o, PadInput_shared_ax0_ax1_fused_ax2_fused_i = s[PadInput_shared].split(PadInput_shared_ax0_ax1_fused_ax2_fused, factor=1) + s[PadInput_shared].vectorize(PadInput_shared_ax0_ax1_fused_ax2_fused_i) + PadInput_shared_ax0_ax1_fused_ax2_fused_o_o, PadInput_shared_ax0_ax1_fused_ax2_fused_o_i = s[PadInput_shared].split(PadInput_shared_ax0_ax1_fused_ax2_fused_o, factor=128) + s[PadInput_shared].bind(PadInput_shared_ax0_ax1_fused_ax2_fused_o_i, te.thread_axis("threadIdx.x")) + # s[conv1d_nlc_local].pragma(conv1d_nlc_local_n_c_o_o_o_o, "auto_unroll_max_step", 1024) + # s[conv1d_nlc_local].pragma(conv1d_nlc_local_n_c_o_o_o_o, "unroll_explicit", True) + # fmt: off + print(tvm.lower(s, [inputs, weight, conv1d_nlc]).script()) + tvm.build(s, [inputs, weight, conv1d_nlc], target=TARGET) + + +if __name__ == "__main__": + main() diff --git a/tests/python/meta_schedule/test_debug_meta_schedule.py b/tests/python/meta_schedule/test_debug_meta_schedule.py new file mode 100644 index 000000000000..b93a01dae737 --- /dev/null +++ b/tests/python/meta_schedule/test_debug_meta_schedule.py @@ -0,0 +1,163 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring + +from typing import List + +import tvm +from tvm import meta_schedule as ms +from tvm.ir import IRModule +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.postproc import Postproc +from tvm.meta_schedule.testing import create_te_workload +from tvm.meta_schedule.tune import DefaultCUDA, DefaultLLVM +from tvm.meta_schedule.utils import remove_build_dir +from tvm.target import Target +from tvm.tir import Schedule + + +RPC_HOST = "192.168.6.66" +RPC_PORT = 4445 +RPC_KEY = "jetson-agx-xavier" +TARGET = Target("nvidia/jetson-agx-xavier") +WORKLOAD = "C1D" +POSTPROCS: List[Postproc] = DefaultCUDA._postproc() # pylint: disable=protected-access + +TARGET = tvm.target.Target("nvidia/jetson-agx-xavier") + + +@tvm.register_func +def tvm_callback_cuda_postproc(code): + import os + + if not os.path.exists("/tmp/perf"): + os.mkdir("/tmp/perf") + with open("/tmp/perf/tir.cu", "w") as f: + f.write(code) + return code + + +def schedule_fn(sch: Schedule): + # pylint: disable=invalid-name,line-too-long,unused-variable + # fmt: off + b0 = sch.get_block(name="PadInput", func_name="main") + b1 = sch.get_block(name="conv1d_nlc", func_name="main") + b2 = sch.get_block(name="root", func_name="main") + b3 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local") + l4, l5, l6, l7, l8 = sch.get_loops(block=b1) + v9, v10, v11, v12, v13 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l14, l15, l16, l17, l18 = sch.split(loop=l4, factors=[v9, v10, v11, v12, v13]) + v19, v20, v21, v22, v23 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[4, 1, 8, 4, 1]) + l24, l25, l26, l27, l28 = sch.split(loop=l5, factors=[v19, v20, v21, v22, v23]) + v29, v30, v31, v32, v33 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[4, 1, 16, 1, 2]) + l34, l35, l36, l37, l38 = sch.split(loop=l6, factors=[v29, v30, v31, v32, v33]) + v39, v40, v41 = sch.sample_perfect_tile(loop=l7, n=3, max_innermost_factor=64, decision=[1, 1, 3]) + l42, l43, l44 = sch.split(loop=l7, factors=[v39, v40, v41]) + v45, v46, v47 = sch.sample_perfect_tile(loop=l8, n=3, max_innermost_factor=64, decision=[4, 8, 2]) + l48, l49, l50 = sch.split(loop=l8, factors=[v45, v46, v47]) + sch.reorder(l14, l24, l34, l15, l25, l35, l16, l26, l36, l42, l48, l43, l49, l17, l27, l37, l44, l50, l18, l28, l38) + l51 = sch.fuse(l14, l24, l34) + sch.bind(loop=l51, thread_axis="blockIdx.x") + l52 = sch.fuse(l15, l25, l35) + sch.bind(loop=l52, thread_axis="vthread.x") + l53 = sch.fuse(l16, l26, l36) + sch.bind(loop=l53, thread_axis="threadIdx.x") + + b54 = sch.cache_read(block=b1, read_buffer_index=1, storage_scope="shared") + sch.compute_at(block=b54, loop=l48, preserve_unit_loops=True) + l55, l56, l57, l58, l59, l60, l61, l62 = sch.get_loops(block=b54) + l63 = sch.fuse(l60, l61, l62) + v64, v65 = sch.sample_perfect_tile(loop=l63, n=2, max_innermost_factor=4, decision=[1040, 1]) + sch.annotate(block_or_loop=b54, ann_key="meta_schedule.cooperative_fetch", ann_val=v65) + + b66 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="shared") + sch.compute_at(block=b66, loop=l48, preserve_unit_loops=True) + l67, l68, l69, l70, l71, l72, l73, l74 = sch.get_loops(block=b66) + l75 = sch.fuse(l72, l73, l74) + v76, v77 = sch.sample_perfect_tile(loop=l75, n=2, max_innermost_factor=4, decision=[1536, 1]) + sch.annotate(block_or_loop=b66, ann_key="meta_schedule.cooperative_fetch", ann_val=v77) + + sch.reverse_compute_at(block=b3, loop=l53, preserve_unit_loops=True) + sch.compute_inline(block=b0) + # v78 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.2, 0.2, 0.2, 0.2, 0.2], decision=4) + # sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v78) + # fmt: on + return sch + + +def _make_sch() -> Schedule: + prim_func = create_te_workload(WORKLOAD, 0) + prim_func = prim_func.with_attr("global_symbol", "main") + prim_func = prim_func.with_attr("tir.noalias", True) + mod = IRModule({"main": prim_func}) + return Schedule(mod, debug_mask="all") + + +def _apply_postproc(sch: Schedule): + sch.enter_postproc() + ctx = TuneContext(target=TARGET) + for p in POSTPROCS: + p.initialize_with_tune_context(ctx) + assert p.apply(sch) + + +def run_sch(sch: Schedule): + print(sch.mod.script()) + print(sch.trace) + print(tvm.lower(sch.mod).script()) + tvm.build(sch.mod, target=TARGET) + builder = ms.builder.LocalBuilder() + runner = ms.runner.RPCRunner( + rpc_config=ms.runner.RPCConfig( + tracker_host=RPC_HOST, + tracker_port=RPC_PORT, + tracker_key=RPC_KEY, + session_timeout_sec=60, + ), + alloc_repeat=3, + max_workers=5, + ) + (builder_result,) = builder.build( # pylint: disable=unbalanced-tuple-unpacking + [ms.builder.BuilderInput(sch.mod, TARGET)] + ) + if builder_result.error_msg is not None: + print(builder_result.error_msg) + return + try: + runner_input = ms.runner.RunnerInput( + builder_result.artifact_path, + device_type=TARGET.kind.name, + args_info=ms.arg_info.ArgInfo.from_prim_func(sch.mod["main"]), + ) + (runner_future,) = runner.run([runner_input]) # pylint: disable=unbalanced-tuple-unpacking + runner_result = runner_future.result() + if runner_result.error_msg is not None: + print(runner_result.error_msg) + else: + print([float(x) * 1000.0 for x in runner_result.run_secs]) + finally: + remove_build_dir(builder_result.artifact_path) + + +def main(): + sch = schedule_fn(_make_sch()) + _apply_postproc(sch) + run_sch(sch) + + +if __name__ == "__main__": + main() diff --git a/tests/python/meta_schedule/test_meta_schedule.py b/tests/python/meta_schedule/test_meta_schedule.py new file mode 100644 index 000000000000..64890f426791 --- /dev/null +++ b/tests/python/meta_schedule/test_meta_schedule.py @@ -0,0 +1,113 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import argparse +import logging +from os import cpu_count + +import tvm +from tvm import meta_schedule as ms +from tvm import tir +from tvm.meta_schedule.testing import create_te_workload + + +def _parse_args(): + args = argparse.ArgumentParser() + args.add_argument( + "--workload", + type=str, + required=True, + ) + args.add_argument( + "--target", + type=str, + required=True, + ) + args.add_argument( + "--num-trials", + type=int, + required=True, + ) + args.add_argument( + "--work-dir", + type=str, + required=True, + ) + args.add_argument( + "--rpc-host", + type=str, + required=True, + ) + args.add_argument( + "--rpc-port", + type=int, + required=True, + ) + args.add_argument( + "--rpc-key", + type=str, + required=True, + ) + parsed = args.parse_args() + parsed.target = tvm.target.Target(parsed.target) + if parsed.target.attrs.get("mtriple", None) == "aarch64-linux-gnu": + parsed.alloc_repeat = 3 + else: + parsed.alloc_repeat = 1 + parsed.rpc_config = ms.runner.RPCConfig( + tracker_host=parsed.rpc_host, + tracker_port=parsed.rpc_port, + tracker_key=parsed.rpc_key, + session_timeout_sec=30, + ) + parsed.rpc_workers = parsed.rpc_config.count_num_servers(allow_missing=False) + return parsed + + +logging.basicConfig() +logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) +ARGS = _parse_args() + + +def main(): + runner = ms.runner.RPCRunner( + rpc_config=ARGS.rpc_config, + alloc_repeat=3, + max_workers=ARGS.rpc_workers, + ) + sch: tir.Schedule = ms.tune_tir( + mod=create_te_workload(ARGS.workload, 0), + target=ARGS.target, + config=ms.EvolutionarySearchConfig( + num_trials_per_iter=64, + num_trials_total=ARGS.num_trials, + init_min_unmeasured=50 + ), + runner=runner, + task_name=ARGS.workload, + work_dir=ARGS.work_dir, + num_threads=cpu_count(), + ) + if sch is None: + print("No valid schedule found!") + else: + print(sch.mod.script()) + print(sch.trace) + + +if __name__ == "__main__": + main() diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index b40f3c9f56ea..e8fae11d5cc9 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -102,7 +102,14 @@ def test_mod(): floordiv = tvm.te.floordiv z = te.var("z") ck.analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 3)) - ck.verify(flm(y, 8), {y: tvm.arith.IntervalSet(z * 8 + x * 4, z * 8 + x * 4 + 3)}, (0, 7)) + ck.verify( + flm(y, 8), + {y: tvm.arith.IntervalSet(z * 8 + x * 4, z * 8 + x * 4 + 3)}, + ( + z * 8 + x * 4 - 8 * floordiv(z * 8 + x * 4, 8), + z * 8 + x * 4 + 3 - 8 * floordiv(z * 8 + x * 4, 8), + ), + ) ck1 = IntSetChecker() ck1.analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 2)) ck1.verify( diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index 2de30eff3f5c..cb8bbd1063c9 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -633,7 +633,7 @@ def test_subspace_division(): tvm.ir.assert_structural_equal(res[0][0], j0[0]) tvm.ir.assert_structural_equal(res[0][1], floordiv(l0[0] * 6 + l1[0], 6)) tvm.ir.assert_structural_equal(res[1][0], 0) - tvm.ir.assert_structural_equal(res[1][1], floormod(floordiv(l0[0] * 6 + l1[0], 3), 2)) + tvm.ir.assert_structural_equal(res[1][1], floordiv(floormod(l0[0] * 6 + l1[0], 6), 3)) tvm.ir.assert_structural_equal(res[2][0], 0) tvm.ir.assert_structural_equal(res[2][1], (floormod(l0[0] * 6 + l1[0], 3) * 3) + j3[0]) diff --git a/tests/python/unittest/test_arith_modular_set.py b/tests/python/unittest/test_arith_modular_set.py index 4a4cd6a31ef1..7914195effe1 100644 --- a/tests/python/unittest/test_arith_modular_set.py +++ b/tests/python/unittest/test_arith_modular_set.py @@ -16,6 +16,7 @@ # under the License. import tvm from tvm import te +from tvm.arith import analyzer def test_cast(): @@ -50,6 +51,14 @@ def test_mul(): assert m.base == 2 +def test_floormod(): + analyzer = tvm.arith.Analyzer() + x, y = te.var("x"), te.var("y") + m = analyzer.modular_set(tvm.tir.floormod(x * 128 + y * 4, 256)) + assert m.coeff == 4 + assert m.base == 0 + + def test_div_shift(): analyzer = tvm.arith.Analyzer() x, y = te.var("x"), te.var("y") @@ -175,6 +184,7 @@ def test_let(): test_add_sub() test_mul() test_div_shift() + test_floormod() test_min_max_select() test_mix_index() test_constraint_scope() diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 6ca2a2a5fcb0..549882126d50 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -80,6 +80,10 @@ def test_vector_simplify(): ck.verify(fld(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), (x).astype("int32x4")) ck.verify(fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8), fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8)) ck.verify(fld(tvm.tir.Ramp(x, 8, 5), tvm.tir.Broadcast(4, 5)), tvm.tir.Ramp(fld(x, 4), 2, 5)) + ck.verify( + fld(tvm.tir.Ramp(flm(x * 4, 256), 1, 4), tvm.tir.Broadcast(8, 4)), + tvm.tir.Broadcast(fld(flm(x * 4, 256), 8), 4) + ) ck.verify( fld(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)), fld(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)), @@ -277,6 +281,7 @@ def test_add_index_simplify(): flm = tvm.te.floormod ck.verify(y * flm(x, 8) + 10 * flm(x, 8), flm(x, 8) * (y + 10)) ck.verify(fld(x, 8) * 8 + flm(x, 8), x) + ck.verify(fld(flm(x, 2) + 7, 2) + fld(x, 2), fld(x + 7, 2)) def test_sub_index_simplify(): diff --git a/tests/python/unittest/test_meta_schedule_feature_extractor.py b/tests/python/unittest/test_meta_schedule_feature_extractor.py index d95397b42c77..9dadf94973f4 100644 --- a/tests/python/unittest/test_meta_schedule_feature_extractor.py +++ b/tests/python/unittest/test_meta_schedule_feature_extractor.py @@ -15,13 +15,14 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring -import re from typing import List +import re import numpy as np + from tvm.meta_schedule import TuneContext -from tvm.meta_schedule.feature_extractor import PyFeatureExtractor from tvm.meta_schedule.search_strategy import MeasureCandidate +from tvm.meta_schedule.feature_extractor import PyFeatureExtractor def test_meta_schedule_feature_extractor(): diff --git a/tests/python/unittest/test_meta_schedule_mutator.py b/tests/python/unittest/test_meta_schedule_mutator.py new file mode 100644 index 000000000000..b4d94dc9a8e3 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_mutator.py @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +from typing import List, Optional + +import re + +import tvm +from tvm.ir.base import assert_structural_equal +from tvm.script import tir as T + +from tvm.meta_schedule.mutator import PyMutator +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.utils import _get_hex_address +from tvm.tir.schedule import Schedule, Trace + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, +# fmt: off + +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def test_meta_schedule_mutator(): + class FancyMutator(PyMutator): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, trace: Trace) -> Optional[Trace]: + return Trace(trace.insts, {}) + + mutator = FancyMutator() + sch = Schedule(Matmul) + res = mutator.apply(sch.trace) + assert res is not None + new_sch = sch.copy() + res.apply_to_schedule(new_sch, remove_postproc=True) + assert_structural_equal(sch.mod, new_sch.mod) + + +def test_meta_schedule_mutator_as_string(): + class YetAnotherFancyMutator(PyMutator): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, trace: Trace) -> Optional[Trace]: + pass + + def __str__(self) -> str: + return f"YetAnotherFancyMutator({_get_hex_address(self.handle)})" + + mutator = YetAnotherFancyMutator() + pattern = re.compile(r"YetAnotherFancyMutator\(0x[a-f|0-9]*\)") + assert pattern.match(str(mutator)) + + +if __name__ == "__main__": + test_meta_schedule_mutator() + test_meta_schedule_mutator_as_string() diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py index 78477e6acdd6..e9105ea0c337 100644 --- a/tests/python/unittest/test_meta_schedule_post_order_apply.py +++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py @@ -29,7 +29,7 @@ from tvm.script import tir as T from tvm.target import Target from tvm.tir.schedule import BlockRV, Schedule - +from tvm import register_func # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, # fmt: off @@ -50,6 +50,42 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] +@tvm.script.ir_module +class MatmulCustomized: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + with T.block("root"): + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + T.block_attr({"schedule_rule": "tvm.meta_schedule.test.custom_search_space"}) + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +@tvm.script.ir_module +class MatmulCustomizedNoneRule: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + with T.block("root"): + T.block_attr({"schedule_rule": "None"}) + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + T.block_attr({"schedule_rule": "None"}) + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + @tvm.script.ir_module class DuplicateMatmul: @T.prim_func @@ -102,7 +138,7 @@ def main(a: T.handle, d: T.handle) -> None: A = T.match_buffer(a, [1024, 1024], dtype="float32") D = T.match_buffer(d, [1024, 1024], dtype="float32") # body - # with tir.block("root") + # with T.block("root") B = T.alloc_buffer([1024, 1024], dtype="float32") for i0_0, i1_0, i0_1, i1_1 in T.grid(16, 64, 64, 16): with T.block("A"): @@ -120,6 +156,209 @@ def main(a: T.handle, d: T.handle) -> None: D[vi, vj] = (B[vi, vj] + T.float32(3)) * T.float32(5) +# with T.block("root"): + +# with T.block("A"): +# # template: meta_schedule.testing.some_rule +# ... +# with T.block("B"): +# # ReLU +# ... +# with T.block("C"): +# # bias_add +# ... + + + +@tvm.script.ir_module +class Conv2d_Winograd: + @T.prim_func + def main(var_placeholder: T.handle, var_placeholder_1: T.handle, var_conv2d_winograd: T.handle) -> None: + # function attr dict + T.func_attr({"layout_free_placeholders": [var_placeholder_1]}) + placeholder = T.match_buffer(var_placeholder, [1, 14, 14, 128], elem_offset=0, align=128, offset_factor=1) + placeholder_1 = T.match_buffer(var_placeholder_1, [6, 6, 128, 128], elem_offset=0, align=128, offset_factor=1) + conv2d_winograd = T.match_buffer(var_conv2d_winograd, [1, 12, 12, 128], elem_offset=0, align=128, offset_factor=1) + # body + with T.block("root"): + T.block_attr({"schedule_rule": "tvm.meta_schedule.test.custom_search_space.winograd"}) + data_pad = T.alloc_buffer([1, 16, 16, 128], elem_offset=0, align=128, offset_factor=1) + input_tile = T.alloc_buffer([6, 6, 9, 128], elem_offset=0, align=128, offset_factor=1) + B = T.alloc_buffer([6, 6], elem_offset=0, align=128, offset_factor=1) + data_pack = T.alloc_buffer([6, 6, 9, 128], elem_offset=0, align=128, offset_factor=1) + bgemm = T.alloc_buffer([6, 6, 9, 128], elem_offset=0, align=128, offset_factor=1) + A = T.alloc_buffer([6, 4], elem_offset=0, align=128, offset_factor=1) + inverse = T.alloc_buffer([4, 4, 9, 128], elem_offset=0, align=128, offset_factor=1) + for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 16, 16, 128): + with T.block("data_pad"): + T.reads([placeholder[i0_1, i1_1, i2_1, i3_1]]) + T.writes([data_pad[i0_1, i1_1, i2_1, i3_1]]) + T.block_attr({ + "schedule_rule": "None", + }) + data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(((((0 <= i1_1) and (i1_1 < 14)) and (0 <= i2_1)) and (i2_1 < 14)), placeholder[i0_1, i1_1, i2_1, i3_1], T.float32(0), dtype="float32") + for eps, nu, p, ci in T.grid(6, 6, 9, 128): + with T.block("input_tile"): + T.reads([data_pad[T.floordiv(p, 9), ((T.floordiv(T.floormod(p, 9), 3)*4) + eps), ((T.floormod(p, 3)*4) + nu), ci]]) + T.writes([input_tile[eps, nu, p, ci]]) + T.block_attr({ + "schedule_rule": "None", + }) + input_tile[eps, nu, p, ci] = data_pad[T.floordiv(p, 9), ((T.floordiv(T.floormod(p, 9), 3)*4) + eps), ((T.floormod(p, 3)*4) + nu), ci] + for i, j in T.grid(6, 6): + with T.block("B"): + T.writes([B[i, j]]) + T.block_attr({ + "const_matrix" : True, + "schedule_rule": "None", + }) + B[i, j] = T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 0)), T.float32(0), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 5)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 4)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 3)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 2)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 0)), T.float32(1), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 5)), T.float32(-2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 4)), T.float32(-0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 3)), T.float32(2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 2)), T.float32(2.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 1)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 0)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 5)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 4)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 3)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 2)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 1)), T.float32(-2.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 0)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 4)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 3)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 2)), T.float32(-1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 0)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 5)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 0)), T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) + for i0_4, i1_4, i2_3, i3_3, i4, i5 in T.grid(6, 6, 9, 128, 6, 6): + with T.block("data_pack"): + eps_1, nu_1, p_1, ci_1, r_a, r_b = T.axis.remap("SSSSRR", [i0_4, i1_4, i2_3, i3_3, i4, i5]) + T.reads([data_pack[eps_1, nu_1, p_1, ci_1], input_tile[r_a, r_b, p_1, ci_1], B[T.min(r_a, r_b):(T.min(r_a, r_b) + ((T.max(r_a, r_b) + 1) - T.min(r_a, r_b))), T.min(eps_1, nu_1):(T.min(eps_1, nu_1) + ((T.max(eps_1, nu_1) + 1) - T.min(eps_1, nu_1)))]]) + T.writes([data_pack[eps_1, nu_1, p_1, ci_1]]) + T.block_attr({ + "auto_scheduler_simplify_const_tensor_indices":["eps", "nu", "r_a", "r_b"], + "schedule_rule": "None", + }) + with T.init(): + data_pack[eps_1, nu_1, p_1, ci_1] = T.float32(0) + data_pack[eps_1, nu_1, p_1, ci_1] = (data_pack[eps_1, nu_1, p_1, ci_1] + ((input_tile[r_a, r_b, p_1, ci_1]*B[r_a, eps_1])*B[r_b, nu_1])) + for i0_5, i1_5, i2_4, i3_4, i4_1 in T.grid(6, 6, 9, 128, 128): + with T.block("bgemm"): + eps_2, nu_2, p_2, co, ci_2 = T.axis.remap("SSSSR", [i0_5, i1_5, i2_4, i3_4, i4_1]) + T.reads([bgemm[eps_2, nu_2, p_2, co], data_pack[eps_2, nu_2, p_2, ci_2], placeholder_1[eps_2, nu_2, co, ci_2]]) + T.writes([bgemm[eps_2, nu_2, p_2, co]]) + T.block_attr({ + "schedule_rule": "None", + }) + with T.init(): + bgemm[eps_2, nu_2, p_2, co] = T.float32(0) + bgemm[eps_2, nu_2, p_2, co] = (bgemm[eps_2, nu_2, p_2, co] + (data_pack[eps_2, nu_2, p_2, ci_2]*placeholder_1[eps_2, nu_2, co, ci_2])) + for i_1, j_1 in T.grid(6, 4): + with T.block("A"): + T.writes([A[i_1, j_1]]) + T.block_attr({ + "const_matrix" : True, + "schedule_rule": "None", + }) + A[i_1, j_1] = T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 0)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 3)), T.float32(-8), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 2)), T.float32(4), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 1)), T.float32(-2), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 3)), T.float32(0.125), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 2)), T.float32(0.25), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 1)), T.float32(0.5), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 1)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 3)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 1)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 3)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.float32(0))))))))))))))))))))))))) + for i0_7, i1_7, i2_5, i3_5, i4_2, i5_1 in T.grid(4, 4, 9, 128, 6, 6): + with T.block("inverse"): + vh, vw, p_3, co_1, r_a_1, r_b_1 = T.axis.remap("SSSSRR", [i0_7, i1_7, i2_5, i3_5, i4_2, i5_1]) + T.reads([inverse[vh, vw, p_3, co_1], bgemm[r_a_1, r_b_1, p_3, co_1], A[T.min(r_a_1, r_b_1):(T.min(r_a_1, r_b_1) + ((T.max(r_a_1, r_b_1) + 1) - T.min(r_a_1, r_b_1))), T.min(vh, vw):(T.min(vh, vw) + ((T.max(vh, vw) + 1) - T.min(vh, vw)))]]) + T.writes([inverse[vh, vw, p_3, co_1]]) + T.block_attr({ + "schedule_rule": "None", + "auto_scheduler_simplify_const_tensor_indices":["vh", "vw", "r_a", "r_b"], + }) + with T.init(): + inverse[vh, vw, p_3, co_1] = T.float32(0) + inverse[vh, vw, p_3, co_1] = (inverse[vh, vw, p_3, co_1] + ((bgemm[r_a_1, r_b_1, p_3, co_1]*A[r_a_1, vh])*A[r_b_1, vw])) + for n, h, w, co_2 in T.grid(1, 12, 12, 128): + with T.block("conv2d_winograd"): + T.reads([inverse[T.floormod(h, 4), T.floormod(w, 4), (((n*9) + (T.floordiv(h, 4)*3)) + T.floordiv(w, 4)), co_2]]) + T.writes([conv2d_winograd[n, h, w, co_2]]) + T.block_attr({ + "schedule_rule": "None" + }) + conv2d_winograd[n, h, w, co_2] = inverse[T.floormod(h, 4), T.floormod(w, 4), (((n*9) + (T.floordiv(h, 4)*3)) + T.floordiv(w, 4)), co_2] + +@tvm.script.ir_module +class Conv2d_Winograd_Cuda: + @T.prim_func + def main(var_placeholder: T.handle, var_placeholder_1: T.handle, var_conv2d_winograd: T.handle) -> None: + # function attr dict + T.func_attr({"layout_free_placeholders": [var_placeholder_1]}) + placeholder = T.match_buffer(var_placeholder, [1, 14, 14, 128], elem_offset=0, align=128, offset_factor=1) + placeholder_1 = T.match_buffer(var_placeholder_1, [6, 6, 128, 128], elem_offset=0, align=128, offset_factor=1) + conv2d_winograd = T.match_buffer(var_conv2d_winograd, [1, 12, 12, 128], elem_offset=0, align=128, offset_factor=1) + # body + with T.block("root"): + data_pad = T.alloc_buffer([1, 16, 16, 128], elem_offset=0, align=128, offset_factor=1) + input_tile = T.alloc_buffer([6, 6, 9, 128], elem_offset=0, align=128, offset_factor=1) + B = T.alloc_buffer([6, 6], elem_offset=0, align=128, offset_factor=1) + data_pack = T.alloc_buffer([6, 6, 9, 128], elem_offset=0, align=128, offset_factor=1) + bgemm = T.alloc_buffer([6, 6, 9, 128], elem_offset=0, align=128, offset_factor=1) + A = T.alloc_buffer([6, 4], elem_offset=0, align=128, offset_factor=1) + inverse = T.alloc_buffer([4, 4, 9, 128], elem_offset=0, align=128, offset_factor=1) + for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 16, 16, 128): + with T.block("data_pad"): + T.block_attr({ + "schedule_rule": "None", + }) + T.reads([placeholder[i0_1, i1_1, i2_1, i3_1]]) + T.writes([data_pad[i0_1, i1_1, i2_1, i3_1]]) + data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(((((0 <= i1_1) and (i1_1 < 14)) and (0 <= i2_1)) and (i2_1 < 14)), placeholder[i0_1, i1_1, i2_1, i3_1], T.float32(0), dtype="float32") + for eps, nu, p, ci in T.grid(6, 6, 9, 128): + with T.block("input_tile"): + T.block_attr({ + "schedule_rule": "None", + }) + T.reads([data_pad[T.floordiv(p, 9), ((T.floordiv(T.floormod(p, 9), 3)*4) + eps), ((T.floormod(p, 3)*4) + nu), ci]]) + T.writes([input_tile[eps, nu, p, ci]]) + input_tile[eps, nu, p, ci] = data_pad[T.floordiv(p, 9), ((T.floordiv(T.floormod(p, 9), 3)*4) + eps), ((T.floormod(p, 3)*4) + nu), ci] + for i, j in T.grid(6, 6): + with T.block("B"): + T.writes([B[i, j]]) + T.block_attr({ + "const_matrix":True, + "schedule_rule": "None", + }) + B[i, j] = T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 0)), T.float32(0), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 5)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 4)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 3)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 2)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 0)), T.float32(1), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 5)), T.float32(-2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 4)), T.float32(-0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 3)), T.float32(2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 2)), T.float32(2.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 1)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 0)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 5)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 4)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 3)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 2)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 1)), T.float32(-2.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 0)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 4)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 3)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 2)), T.float32(-1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 0)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 5)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 0)), T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) + for i0_4, i1_4, i2_3, i3_3, i4, i5 in T.grid(6, 6, 9, 128, 6, 6): + with T.block("data_pack"): + eps_1, nu_1, p_1, ci_1, r_a, r_b = T.axis.remap("SSSSRR", [i0_4, i1_4, i2_3, i3_3, i4, i5]) + T.reads([data_pack[eps_1, nu_1, p_1, ci_1], input_tile[r_a, r_b, p_1, ci_1], B[T.min(r_a, r_b):(T.min(r_a, r_b) + ((T.max(r_a, r_b) + 1) - T.min(r_a, r_b))), T.min(eps_1, nu_1):(T.min(eps_1, nu_1) + ((T.max(eps_1, nu_1) + 1) - T.min(eps_1, nu_1)))]]) + T.writes([data_pack[eps_1, nu_1, p_1, ci_1]]) + T.block_attr({ + "auto_scheduler_simplify_const_tensor_indices": ["eps", "nu", "r_a", "r_b"], + "schedule_rule": "None", + }) + with T.init(): + data_pack[eps_1, nu_1, p_1, ci_1] = T.float32(0) + data_pack[eps_1, nu_1, p_1, ci_1] = (data_pack[eps_1, nu_1, p_1, ci_1] + ((input_tile[r_a, r_b, p_1, ci_1]*B[r_a, eps_1])*B[r_b, nu_1])) + for i0_5, i1_5, i2_4, i3_4, i4_1 in T.grid(6, 6, 9, 128, 128): + with T.block("bgemm"): + T.block_attr({ + "schedule_rule": "None", + }) + eps_2, nu_2, p_2, co, ci_2 = T.axis.remap("SSSSR", [i0_5, i1_5, i2_4, i3_4, i4_1]) + T.reads([bgemm[eps_2, nu_2, p_2, co], data_pack[eps_2, nu_2, p_2, ci_2], placeholder_1[eps_2, nu_2, co, ci_2]]) + T.writes([bgemm[eps_2, nu_2, p_2, co]]) + with T.init(): + bgemm[eps_2, nu_2, p_2, co] = T.float32(0) + bgemm[eps_2, nu_2, p_2, co] = (bgemm[eps_2, nu_2, p_2, co] + (data_pack[eps_2, nu_2, p_2, ci_2]*placeholder_1[eps_2, nu_2, co, ci_2])) + for i_1, j_1 in T.grid(6, 4): + with T.block("A"): + T.writes([A[i_1, j_1]]) + T.block_attr({ + "const_matrix":True, + "schedule_rule": "None", + }) + A[i_1, j_1] = T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 0)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 3)), T.float32(-8), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 2)), T.float32(4), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 1)), T.float32(-2), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 3)), T.float32(0.125), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 2)), T.float32(0.25), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 1)), T.float32(0.5), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 1)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 3)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 1)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 3)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.float32(0))))))))))))))))))))))))) + for i0_7, i1_7, i2_5, i3_5, i4_2, i5_1 in T.grid(4, 4, 9, 128, 6, 6): + with T.block("inverse"): + vh, vw, p_3, co_1, r_a_1, r_b_1 = T.axis.remap("SSSSRR", [i0_7, i1_7, i2_5, i3_5, i4_2, i5_1]) + T.reads([inverse[vh, vw, p_3, co_1], bgemm[r_a_1, r_b_1, p_3, co_1], A[T.min(r_a_1, r_b_1):(T.min(r_a_1, r_b_1) + ((T.max(r_a_1, r_b_1) + 1) - T.min(r_a_1, r_b_1))), T.min(vh, vw):(T.min(vh, vw) + ((T.max(vh, vw) + 1) - T.min(vh, vw)))]]) + T.writes([inverse[vh, vw, p_3, co_1]]) + T.block_attr({ + "auto_scheduler_simplify_const_tensor_indices":["vh", "vw", "r_a", "r_b"], + "schedule_rule": "None", + }) + with T.init(): + inverse[vh, vw, p_3, co_1] = T.float32(0) + inverse[vh, vw, p_3, co_1] = (inverse[vh, vw, p_3, co_1] + ((bgemm[r_a_1, r_b_1, p_3, co_1]*A[r_a_1, vh])*A[r_b_1, vw])) + for n, h, w, co_2 in T.grid(1, 12, 12, 128): + with T.block("conv2d_winograd"): + T.block_attr({ + "schedule_rule": "None", + }) + T.reads([inverse[T.floormod(h, 4), T.floormod(w, 4), (((n*9) + (T.floordiv(h, 4)*3)) + T.floordiv(w, 4)), co_2]]) + T.writes([conv2d_winograd[n, h, w, co_2]]) + conv2d_winograd[n, h, w, co_2] = inverse[T.floormod(h, 4), T.floormod(w, 4), (((n*9) + (T.floordiv(h, 4)*3)) + T.floordiv(w, 4)), co_2] + # fmt: on # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument @@ -338,5 +577,437 @@ def correct_trace(a, b, c, d): ) +def test_meta_schedule_post_order_apply_custom_search_space(): + @register_func("tvm.meta_schedule.test.custom_search_space") + def custom_search_space_func(sch: Schedule, block: BlockRV): + raise ValueError("Customized search space triggered!") + + mod = MatmulCustomized + context = TuneContext( + mod=mod, + target=Target("llvm"), + task_name="Custom Search Space Task", + sch_rules=[], + ) + post_order_apply = PostOrderApply() + post_order_apply.initialize_with_tune_context(context) + with pytest.raises(ValueError, match="Customized search space triggered!"): + _ = post_order_apply.generate_design_space(mod) + + +class DontCallThisRule(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + print(sch.get(block)) + raise RuntimeError("This schedule rule should not be called!") + + +def test_meta_schedule_post_order_apply_custom_search_space_none_rule(): + mod = MatmulCustomizedNoneRule + context = TuneContext( + mod=mod, + target=Target("llvm"), + task_name="Custom Search Space Task", + sch_rules=[DontCallThisRule()], + ) + post_order_apply = PostOrderApply() + post_order_apply.initialize_with_tune_context(context) + _ = post_order_apply.generate_design_space(mod) + + +@pytest.mark.xfail # for compute_at bug +def test_meta_schedule_post_order_apply_custom_search_space_winograd(): + @register_func("tvm.meta_schedule.test.custom_search_space.winograd") + def custom_search_space_winograd_func(sch: Schedule, block: BlockRV) -> List[Schedule]: + b1 = sch.get_block(name="A") + sch.compute_inline(block=b1) + b2 = sch.get_block(name="B") + sch.compute_inline(block=b2) + b3 = sch.get_block(name="inverse") + l4, l5, l6, l7, l8, l9 = sch.get_loops(block=b3) + sch.unroll(loop=l4) + sch.unroll(loop=l5) + sch.unroll(loop=l8) + sch.unroll(loop=l9) + v10, v11 = sch.sample_perfect_tile(n=2, loop=l6, max_innermost_factor=64, decision=[1, 9]) + l12, l13 = sch.split(loop=l6, factors=[v10, v11]) + v14, v15 = sch.sample_perfect_tile(n=2, loop=l7, max_innermost_factor=64, decision=[2, 64]) + l16, l17 = sch.split(loop=l7, factors=[v14, v15]) + sch.reorder(l12, l16, l13, l17, l4, l5, l8, l9) + b18 = sch.get_block(name="data_pack") + l19, l20, l21, l22, l23, l24 = sch.get_loops(block=b18) + sch.unroll(loop=l19) + sch.unroll(loop=l20) + sch.unroll(loop=l23) + sch.unroll(loop=l24) + v25, v26 = sch.sample_perfect_tile(n=2, loop=l21, max_innermost_factor=64, decision=[9, 1]) + l27, l28 = sch.split(loop=l21, factors=[v25, v26]) + v29, v30 = sch.sample_perfect_tile(n=2, loop=l22, max_innermost_factor=64, decision=[32, 4]) + l31, l32 = sch.split(loop=l22, factors=[v29, v30]) + sch.reorder(l27, l31, l28, l32, l19, l20, l23, l24) + b33 = sch.get_block(name="bgemm") + b34 = sch.cache_write(block=b33, write_buffer_index=0, storage_scope="global") + b33, b34 = b34, b33 + l35, l36, l37, l38, l39 = sch.get_loops(block=b34) + v40, v41, v42, v43 = sch.sample_perfect_tile( + n=4, loop=l35, max_innermost_factor=64, decision=[1, 2, 3, 1] + ) + l44, l45, l46, l47 = sch.split(loop=l35, factors=[v40, v41, v42, v43]) + v48, v49, v50, v51 = sch.sample_perfect_tile( + n=4, loop=l36, max_innermost_factor=64, decision=[1, 1, 1, 6] + ) + l52, l53, l54, l55 = sch.split(loop=l36, factors=[v48, v49, v50, v51]) + v56, v57, v58, v59 = sch.sample_perfect_tile( + n=4, loop=l37, max_innermost_factor=64, decision=[1, 1, 1, 9] + ) + l60, l61, l62, l63 = sch.split(loop=l37, factors=[v56, v57, v58, v59]) + v64, v65, v66, v67 = sch.sample_perfect_tile( + n=4, loop=l38, max_innermost_factor=64, decision=[2, 1, 16, 4] + ) + l68, l69, l70, l71 = sch.split(loop=l38, factors=[v64, v65, v66, v67]) + v72, v73 = sch.sample_perfect_tile(n=2, loop=l39, max_innermost_factor=64, decision=[16, 8]) + l74, l75 = sch.split(loop=l39, factors=[v72, v73]) + sch.reorder( + l44, l52, l60, l68, l45, l53, l61, l69, l74, l46, l54, l62, l70, l75, l47, l55, l63, l71 + ) + sch.reverse_compute_at(block=b33, loop=l69, preserve_unit_loops=True) + b76 = sch.get_block(name="root") + sch.annotate(block_or_loop=b76, ann_key="auto_parallel_extent", ann_val=64) + sch.annotate(block_or_loop=b76, ann_key="auto_vectorize_extent", ann_val=32) + v77 = sch.sample_categorical( + candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=1 + ) + sch.annotate(block_or_loop=b76, ann_key="auto_unroll_explicit", ann_val=v77) + + b78 = sch.get_block(name="input_tile") + (b79,) = sch.get_consumers(block=b78) + l80 = sch.sample_compute_location(block=b79, decision=4) + sch.compute_at(block=b78, loop=l80, preserve_unit_loops=True) + + b81 = sch.get_block(name="data_pad") + (b82,) = sch.get_consumers(block=b81) + l83 = sch.sample_compute_location(block=b82, decision=-2) + sch.compute_at(block=b81, loop=l83, preserve_unit_loops=True) + return [sch] + + mod = Conv2d_Winograd + + # Add annotation + sch = Schedule(mod) + sch.annotate( + sch.get_block("root"), + "schedule_rule", + "tvm.meta_schedule.test.custom_search_space.winograd", + ) + mod = sch.mod + context = TuneContext( + mod=mod, + target=Target("llvm --num-cores=16"), + task_name="Custom Search Space Task", + sch_rules=[DontCallThisRule()], + ) + post_order_apply = PostOrderApply() + post_order_apply.initialize_with_tune_context(context) + schs = post_order_apply.generate_design_space(mod) + assert len(schs) == 1 + (sch,) = schs + assert str(sch.trace) == "\n".join( + [ + 'b0 = sch.get_block(name="data_pad", func_name="main")', + 'b1 = sch.get_block(name="input_tile", func_name="main")', + 'b2 = sch.get_block(name="B", func_name="main")', + 'b3 = sch.get_block(name="data_pack", func_name="main")', + 'b4 = sch.get_block(name="bgemm", func_name="main")', + 'b5 = sch.get_block(name="A", func_name="main")', + 'b6 = sch.get_block(name="inverse", func_name="main")', + 'b7 = sch.get_block(name="conv2d_winograd", func_name="main")', + 'b8 = sch.get_block(name="root", func_name="main")', + 'b9 = sch.get_block(name="A", func_name="main")', + "sch.compute_inline(block=b9)", + 'b10 = sch.get_block(name="B", func_name="main")', + "sch.compute_inline(block=b10)", + 'b11 = sch.get_block(name="inverse", func_name="main")', + "l12, l13, l14, l15, l16, l17 = sch.get_loops(block=b11)", + "sch.unroll(loop=l12)", + "sch.unroll(loop=l13)", + "sch.unroll(loop=l16)", + "sch.unroll(loop=l17)", + "v18, v19 = sch.sample_perfect_tile(loop=l14, n=2, max_innermost_factor=64, decision=[1, 9])", + "l20, l21 = sch.split(loop=l14, factors=[v18, v19])", + "v22, v23 = sch.sample_perfect_tile(loop=l15, n=2, max_innermost_factor=64, decision=[2, 64])", + "l24, l25 = sch.split(loop=l15, factors=[v22, v23])", + "sch.reorder(l20, l24, l21, l25, l12, l13, l16, l17)", + 'b26 = sch.get_block(name="data_pack", func_name="main")', + "l27, l28, l29, l30, l31, l32 = sch.get_loops(block=b26)", + "sch.unroll(loop=l27)", + "sch.unroll(loop=l28)", + "sch.unroll(loop=l31)", + "sch.unroll(loop=l32)", + "v33, v34 = sch.sample_perfect_tile(loop=l29, n=2, max_innermost_factor=64, decision=[9, 1])", + "l35, l36 = sch.split(loop=l29, factors=[v33, v34])", + "v37, v38 = sch.sample_perfect_tile(loop=l30, n=2, max_innermost_factor=64, decision=[32, 4])", + "l39, l40 = sch.split(loop=l30, factors=[v37, v38])", + "sch.reorder(l35, l39, l36, l40, l27, l28, l31, l32)", + 'b41 = sch.get_block(name="bgemm", func_name="main")', + 'b42 = sch.cache_write(block=b41, write_buffer_index=0, storage_scope="global")', + "l43, l44, l45, l46, l47 = sch.get_loops(block=b41)", + "v48, v49, v50, v51 = sch.sample_perfect_tile(loop=l43, n=4, max_innermost_factor=64, decision=[1, 2, 3, 1])", + "l52, l53, l54, l55 = sch.split(loop=l43, factors=[v48, v49, v50, v51])", + "v56, v57, v58, v59 = sch.sample_perfect_tile(loop=l44, n=4, max_innermost_factor=64, decision=[1, 1, 1, 6])", + "l60, l61, l62, l63 = sch.split(loop=l44, factors=[v56, v57, v58, v59])", + "v64, v65, v66, v67 = sch.sample_perfect_tile(loop=l45, n=4, max_innermost_factor=64, decision=[1, 1, 1, 9])", + "l68, l69, l70, l71 = sch.split(loop=l45, factors=[v64, v65, v66, v67])", + "v72, v73, v74, v75 = sch.sample_perfect_tile(loop=l46, n=4, max_innermost_factor=64, decision=[2, 1, 16, 4])", + "l76, l77, l78, l79 = sch.split(loop=l46, factors=[v72, v73, v74, v75])", + "v80, v81 = sch.sample_perfect_tile(loop=l47, n=2, max_innermost_factor=64, decision=[16, 8])", + "l82, l83 = sch.split(loop=l47, factors=[v80, v81])", + "sch.reorder(l52, l60, l68, l76, l53, l61, l69, l77, l82, l54, l62, l70, l78, l83, l55, l63, l71, l79)", + "sch.reverse_compute_at(block=b42, loop=l77, preserve_unit_loops=True)", + 'b84 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b84, ann_key="auto_parallel_extent", ann_val=64)', + 'sch.annotate(block_or_loop=b84, ann_key="auto_vectorize_extent", ann_val=32)', + "v85 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=1)", + 'sch.annotate(block_or_loop=b84, ann_key="auto_unroll_explicit", ann_val=v85)', + 'b86 = sch.get_block(name="input_tile", func_name="main")', + "l87 = sch.sample_compute_location(block=b86, decision=-1)", + "sch.compute_at(block=b86, loop=l87, preserve_unit_loops=True)", + 'b88 = sch.get_block(name="data_pad", func_name="main")', + "l89 = sch.sample_compute_location(block=b88, decision=-1)", + "sch.compute_at(block=b88, loop=l89, preserve_unit_loops=True)", + ], + ) + + +@pytest.mark.xfail # for compute_at bug +def test_meta_schedule_post_order_apply_custom_search_space_winograd_cuda(): + @register_func("tvm.meta_schedule.test.custom_search_space.winograd.cuda") + def custom_search_space_winograd_func_cuda(sch: Schedule, block: BlockRV) -> List[Schedule]: + b1 = sch.get_block(name="inverse") + l2, l3, l4, l5, l6, l7 = sch.get_loops(block=b1) + sch.unroll(loop=l2) + sch.unroll(loop=l3) + sch.unroll(loop=l6) + sch.unroll(loop=l7) + v8, v9 = sch.sample_perfect_tile(n=2, loop=l4, max_innermost_factor=64, decision=[3, 3]) + l10, l11 = sch.split(loop=l4, factors=[v8, v9]) + v12, v13 = sch.sample_perfect_tile(n=2, loop=l5, max_innermost_factor=64, decision=[2, 64]) + l14, l15 = sch.split(loop=l5, factors=[v12, v13]) + sch.reorder(l10, l14, l11, l15, l2, l3, l6, l7) + b16 = sch.get_block(name="data_pack") + l17, l18, l19, l20, l21, l22 = sch.get_loops(block=b16) + sch.unroll(loop=l17) + sch.unroll(loop=l18) + sch.unroll(loop=l21) + sch.unroll(loop=l22) + v23, v24 = sch.sample_perfect_tile(n=2, loop=l19, max_innermost_factor=64, decision=[3, 3]) + l25, l26 = sch.split(loop=l19, factors=[v23, v24]) + v27, v28 = sch.sample_perfect_tile(n=2, loop=l20, max_innermost_factor=64, decision=[64, 2]) + l29, l30 = sch.split(loop=l20, factors=[v27, v28]) + sch.reorder(l25, l29, l26, l30, l17, l18, l21, l22) + b31 = sch.get_block(name="bgemm") + b32 = sch.cache_write(block=b31, write_buffer_index=0, storage_scope="local") + b31, b32 = b32, b31 + l33, l34, l35, l36, l37 = sch.get_loops(block=b32) + v38, v39, v40, v41, v42 = sch.sample_perfect_tile( + n=5, loop=l33, max_innermost_factor=64, decision=[1, 1, 1, 1, 6] + ) + l43, l44, l45, l46, l47 = sch.split(loop=l33, factors=[v38, v39, v40, v41, v42]) + v48, v49, v50, v51, v52 = sch.sample_perfect_tile( + n=5, loop=l34, max_innermost_factor=64, decision=[1, 1, 1, 3, 2] + ) + l53, l54, l55, l56, l57 = sch.split(loop=l34, factors=[v48, v49, v50, v51, v52]) + v58, v59, v60, v61, v62 = sch.sample_perfect_tile( + n=5, loop=l35, max_innermost_factor=64, decision=[3, 1, 1, 1, 3] + ) + l63, l64, l65, l66, l67 = sch.split(loop=l35, factors=[v58, v59, v60, v61, v62]) + v68, v69, v70, v71, v72 = sch.sample_perfect_tile( + n=5, loop=l36, max_innermost_factor=64, decision=[4, 2, 1, 4, 4] + ) + l73, l74, l75, l76, l77 = sch.split(loop=l36, factors=[v68, v69, v70, v71, v72]) + v78, v79, v80 = sch.sample_perfect_tile( + n=3, loop=l37, max_innermost_factor=64, decision=[32, 1, 4] + ) + l81, l82, l83 = sch.split(loop=l37, factors=[v78, v79, v80]) + sch.reorder( + l43, + l53, + l63, + l73, + l44, + l54, + l64, + l74, + l45, + l55, + l65, + l75, + l81, + l82, + l46, + l56, + l66, + l76, + l83, + l47, + l57, + l67, + l77, + ) + l84 = sch.fuse(l43, l53, l63, l73) + sch.bind(loop=l84, thread_axis="blockIdx.x") + l85 = sch.fuse(l44, l54, l64, l74) + sch.bind(loop=l85, thread_axis="vthread.x") + l86 = sch.fuse(l45, l55, l65, l75) + sch.bind(loop=l86, thread_axis="threadIdx.x") + b87 = sch.cache_read(block=b32, read_buffer_index=2, storage_scope="shared") + sch.compute_at(block=b87, loop=l81, preserve_unit_loops=True) + l88, l89, l90, l91, l92, l93, l94, l95 = sch.get_loops(block=b87) + l96 = sch.fuse(l92, l93, l94, l95) + v97, v98 = sch.sample_perfect_tile( + n=2, loop=l96, max_innermost_factor=4, decision=[1536, 3] + ) + l99, l100 = sch.split(loop=l96, factors=[v97, v98]) + sch.vectorize(loop=l100) + sch.annotate(block_or_loop=l99, ann_key="loop_type", ann_val="lazy_cooperative_fetch") + b101 = sch.cache_read(block=b32, read_buffer_index=1, storage_scope="shared") + sch.compute_at(block=b101, loop=l81, preserve_unit_loops=True) + l102, l103, l104, l105, l106, l107, l108, l109 = sch.get_loops(block=b101) + l110 = sch.fuse(l106, l107, l108, l109) + v111, v112 = sch.sample_perfect_tile( + n=2, loop=l110, max_innermost_factor=4, decision=[432, 1] + ) + l113, l114 = sch.split(loop=l110, factors=[v111, v112]) + sch.vectorize(loop=l114) + sch.annotate(block_or_loop=l113, ann_key="loop_type", ann_val="lazy_cooperative_fetch") + sch.reverse_compute_at(block=b31, loop=l86, preserve_unit_loops=True) + b115 = sch.get_block(name="input_tile") + (b116,) = sch.get_consumers(block=b115) + l117, l118, l119, l120, l121, l122, l123, l124 = sch.get_loops(block=b116) + sch.compute_at(block=b115, loop=l120, preserve_unit_loops=True) + sch.set_scope(block=b115, buffer_index=0, storage_scope="local") + b125 = sch.get_block(name="A") + sch.compute_inline(block=b125) + b126 = sch.get_block(name="B") + sch.compute_inline(block=b126) + b127 = sch.get_block(name="data_pad") + sch.compute_inline(block=b127) + b128 = sch.get_block(name="root") + v129 = sch.sample_categorical( + candidates=[0, 16, 64, 512, 1024], probs=[0.2, 0.2, 0.2, 0.2, 0.2], decision=0 + ) + sch.annotate(block_or_loop=b128, ann_key="auto_unroll_explicit", ann_val=v129) + return [sch] + + mod = Conv2d_Winograd_Cuda + + # Add annotation + sch = Schedule(mod) + sch.annotate( + sch.get_block("root"), + "schedule_rule", + "tvm.meta_schedule.test.custom_search_space.winograd.cuda", + ) + mod = sch.mod + context = TuneContext( + mod=mod, + target=Target("nvidia/geforce-rtx-3070"), + task_name="Custom Search Space Task", + sch_rules=[DontCallThisRule()], + ) + post_order_apply = PostOrderApply() + post_order_apply.initialize_with_tune_context(context) + schs = post_order_apply.generate_design_space(mod) + assert len(schs) == 1 + (sch,) = schs + assert str(sch.trace) == "\n".join( + [ + 'b0 = sch.get_block(name="data_pad", func_name="main")', + 'b1 = sch.get_block(name="input_tile", func_name="main")', + 'b2 = sch.get_block(name="B", func_name="main")', + 'b3 = sch.get_block(name="data_pack", func_name="main")', + 'b4 = sch.get_block(name="bgemm", func_name="main")', + 'b5 = sch.get_block(name="A", func_name="main")', + 'b6 = sch.get_block(name="inverse", func_name="main")', + 'b7 = sch.get_block(name="conv2d_winograd", func_name="main")', + 'b8 = sch.get_block(name="root", func_name="main")', + 'b9 = sch.get_block(name="inverse", func_name="main")', + "l10, l11, l12, l13, l14, l15 = sch.get_loops(block=b9)", + "sch.unroll(loop=l10)", + "sch.unroll(loop=l11)", + "sch.unroll(loop=l14)", + "sch.unroll(loop=l15)", + "v16, v17 = sch.sample_perfect_tile(loop=l12, n=2, max_innermost_factor=64, decision=[3, 3])", + "l18, l19 = sch.split(loop=l12, factors=[v16, v17])", + "v20, v21 = sch.sample_perfect_tile(loop=l13, n=2, max_innermost_factor=64, decision=[2, 64])", + "l22, l23 = sch.split(loop=l13, factors=[v20, v21])", + "sch.reorder(l18, l22, l19, l23, l10, l11, l14, l15)", + 'b24 = sch.get_block(name="data_pack", func_name="main")', + "l25, l26, l27, l28, l29, l30 = sch.get_loops(block=b24)", + "sch.unroll(loop=l25)", + "sch.unroll(loop=l26)", + "sch.unroll(loop=l29)", + "sch.unroll(loop=l30)", + "v31, v32 = sch.sample_perfect_tile(loop=l27, n=2, max_innermost_factor=64, decision=[3, 3])", + "l33, l34 = sch.split(loop=l27, factors=[v31, v32])", + "v35, v36 = sch.sample_perfect_tile(loop=l28, n=2, max_innermost_factor=64, decision=[64, 2])", + "l37, l38 = sch.split(loop=l28, factors=[v35, v36])", + "sch.reorder(l33, l37, l34, l38, l25, l26, l29, l30)", + 'b39 = sch.get_block(name="bgemm", func_name="main")', + 'b40 = sch.cache_write(block=b39, write_buffer_index=0, storage_scope="local")', + "l41, l42, l43, l44, l45 = sch.get_loops(block=b39)", + "v46, v47, v48, v49, v50 = sch.sample_perfect_tile(loop=l41, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 6])", + "l51, l52, l53, l54, l55 = sch.split(loop=l41, factors=[v46, v47, v48, v49, v50])", + "v56, v57, v58, v59, v60 = sch.sample_perfect_tile(loop=l42, n=5, max_innermost_factor=64, decision=[1, 1, 1, 3, 2])", + "l61, l62, l63, l64, l65 = sch.split(loop=l42, factors=[v56, v57, v58, v59, v60])", + "v66, v67, v68, v69, v70 = sch.sample_perfect_tile(loop=l43, n=5, max_innermost_factor=64, decision=[3, 1, 1, 1, 3])", + "l71, l72, l73, l74, l75 = sch.split(loop=l43, factors=[v66, v67, v68, v69, v70])", + "v76, v77, v78, v79, v80 = sch.sample_perfect_tile(loop=l44, n=5, max_innermost_factor=64, decision=[4, 2, 1, 4, 4])", + "l81, l82, l83, l84, l85 = sch.split(loop=l44, factors=[v76, v77, v78, v79, v80])", + "v86, v87, v88 = sch.sample_perfect_tile(loop=l45, n=3, max_innermost_factor=64, decision=[32, 1, 4])", + "l89, l90, l91 = sch.split(loop=l45, factors=[v86, v87, v88])", + "sch.reorder(l51, l61, l71, l81, l52, l62, l72, l82, l53, l63, l73, l83, l89, l90, l54, l64, l74, l84, l91, l55, l65, l75, l85)", + "l92 = sch.fuse(l51, l61, l71, l81)", + 'sch.bind(loop=l92, thread_axis="blockIdx.x")', + "l93 = sch.fuse(l52, l62, l72, l82)", + 'sch.bind(loop=l93, thread_axis="vthread.x")', + "l94 = sch.fuse(l53, l63, l73, l83)", + 'sch.bind(loop=l94, thread_axis="threadIdx.x")', + 'b95 = sch.cache_read(block=b39, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b95, loop=l89, preserve_unit_loops=True)", + "l96, l97, l98, l99, l100, l101, l102, l103 = sch.get_loops(block=b95)", + "l104 = sch.fuse(l100, l101, l102, l103)", + "v105, v106 = sch.sample_perfect_tile(loop=l104, n=2, max_innermost_factor=4, decision=[1536, 3])", + "l107, l108 = sch.split(loop=l104, factors=[v105, v106])", + "sch.vectorize(loop=l108)", + 'sch.annotate(block_or_loop=l107, ann_key="loop_type", ann_val="lazy_cooperative_fetch")', + 'b109 = sch.cache_read(block=b39, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b109, loop=l89, preserve_unit_loops=True)", + "l110, l111, l112, l113, l114, l115, l116, l117 = sch.get_loops(block=b109)", + "l118 = sch.fuse(l114, l115, l116, l117)", + "v119, v120 = sch.sample_perfect_tile(loop=l118, n=2, max_innermost_factor=4, decision=[432, 1])", + "l121, l122 = sch.split(loop=l118, factors=[v119, v120])", + "sch.vectorize(loop=l122)", + 'sch.annotate(block_or_loop=l121, ann_key="loop_type", ann_val="lazy_cooperative_fetch")', + "sch.reverse_compute_at(block=b40, loop=l94, preserve_unit_loops=True)", + 'b123 = sch.get_block(name="input_tile", func_name="main")', + "b124, = sch.get_consumers(block=b123)", + "l125, l126, l127, l128, l129, l130, l131, l132 = sch.get_loops(block=b124)", + "sch.compute_at(block=b123, loop=l128, preserve_unit_loops=True)", + 'sch.set_scope(block=b123, buffer_index=0, storage_scope="local")', + 'b133 = sch.get_block(name="A", func_name="main")', + "sch.compute_inline(block=b133)", + 'b134 = sch.get_block(name="B", func_name="main")', + "sch.compute_inline(block=b134)", + 'b135 = sch.get_block(name="data_pad", func_name="main")', + "sch.compute_inline(block=b135)", + 'b136 = sch.get_block(name="root", func_name="main")', + "v137 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=0)", + 'sch.annotate(block_or_loop=b136, ann_key="auto_unroll_explicit", ann_val=v137)', + ] + ) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_meta_schedule_postproc.py b/tests/python/unittest/test_meta_schedule_postproc.py new file mode 100644 index 000000000000..6e17e7bac3f2 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_postproc.py @@ -0,0 +1,119 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import math +import re + +import tvm +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.postproc import PyPostproc +from tvm.meta_schedule.utils import _get_hex_address +from tvm.script import tir as T +from tvm.target.target import Target +from tvm.tir.schedule import Schedule + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant +# fmt: off + +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable + + +def _check_correct(schedule: Schedule): + trace = schedule.trace + for inst in trace.decisions: + assert math.prod(trace.decisions[inst]) == 1024 + + +def schedule_matmul(sch: Schedule): + block = sch.get_block("matmul") + i, j, k = sch.get_loops(block=block) + i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2]) + j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2]) + k_0, k_1 = sch.split(loop=k, factors=[32, 32]) + sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) + + +def test_meta_schedule_postproc(): + class FancyPostproc(PyPostproc): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule) -> bool: + schedule_matmul(sch) + return True + + postproc = FancyPostproc() + mod = Matmul + sch = Schedule(mod) + assert postproc.apply(sch) + try: + tvm.ir.assert_structural_equal(sch.mod, mod) + raise Exception("The postprocessors did not change the schedule.") + except ValueError: + _check_correct(sch) + + +def test_meta_schedule_postproc_fail(): + class FailingPostproc(PyPostproc): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule) -> bool: + return False + + postproc = FailingPostproc() + sch = Schedule(Matmul) + assert not postproc.apply(sch) + + +def test_meta_schedule_postproc_as_string(): + class NotSoFancyPostproc(PyPostproc): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule) -> bool: + pass + + def __str__(self) -> str: + return f"NotSoFancyPostproc({_get_hex_address(self.handle)})" + + postproc = NotSoFancyPostproc() + pattern = re.compile(r"NotSoFancyPostproc\(0x[a-f|0-9]*\)") + assert pattern.match(str(postproc)) + + +if __name__ == "__main__": + test_meta_schedule_postproc() + test_meta_schedule_postproc_fail() + test_meta_schedule_postproc_as_string() diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py index 31e92e09e50e..0cf96e086520 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py @@ -103,6 +103,107 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: C[v0, v1] = C_local[v0, v1] +@tvm.script.ir_module +class AfterRewrite1: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(var_A, [512, 512], dtype="float16") + B = T.match_buffer(var_B, [512, 512], dtype="float16") + C = T.match_buffer(var_C, [512, 512], dtype="float32") + # body + with T.block("root"): + T.reads([]) + T.writes([]) + T.block_attr({"meta_schedule.tensor_core_enabled":"1"}) + C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local") + C_local_wmma_accumulator = T.alloc_buffer([512, 512], dtype="float32", scope="wmma.accumulator") + A_shared = T.alloc_buffer([512, 512], dtype="float16", scope="shared") + B_shared = T.alloc_buffer([512, 512], dtype="float16", scope="shared") + A_shared_wmma_matrix_a = T.alloc_buffer([512, 512], dtype="float16", scope="wmma.matrix_a") + B_shared_wmma_matrix_b = T.alloc_buffer([512, 512], dtype="float16", scope="wmma.matrix_b") + for i0_0_0_i1_0_0_fused in T.thread_binding(0, 1, thread="blockIdx.x"): + for i0_0_1_i1_0_1_fused in T.thread_binding(0, 4, thread="blockIdx.y"): + for i0_0_2_i1_0_2_fused in T.thread_binding(0, 8, thread="threadIdx.y"): + for i2_0_0 in T.serial(0, 4): + for ax0_ax1_fused_0 in T.serial(0, 128): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.y"): + for ax0_ax1_fused_2 in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("A_shared"): + v0 = T.axis.spatial(512, i0_0_1_i1_0_1_fused // 2 * 256 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 32 + ax0_ax1_fused_2) // 128) + v1 = T.axis.spatial(512, i2_0_0 * 128 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 32 + ax0_ax1_fused_2) % 128) + T.reads([A[v0, v1]]) + T.writes([A_shared[v0, v1]]) + T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":1}) + A_shared[v0, v1] = A[v0, v1] + for ax0_ax1_fused_0 in T.serial(0, 32): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.y"): + for ax0_ax1_fused_2 in T.thread_binding(0, 32, thread="threadIdx.x"): + for ax0_ax1_fused_3 in T.vectorized(0, 4): + with T.block("B_shared"): + v0 = T.axis.spatial(512, i2_0_0 * 128 + (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 256) + v1 = T.axis.spatial(512, i0_0_1_i1_0_1_fused % 2 * 256 + (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 256) + T.reads([B[v0, v1]]) + T.writes([B_shared[v0, v1]]) + T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":4}) + B_shared[v0, v1] = B[v0, v1] + for i2_0_1, i0_0_3, i1_0_3, i2_0_2 in T.grid(8, 1, 1, 1): + for ax0, ax1 in T.grid(256, 16): + with T.block("A_shared_wmma.matrix_a"): + v0 = T.axis.spatial(512, i0_0_1_i1_0_1_fused // 2 * 256 + ax0) + v1 = T.axis.spatial(512, i2_0_0 * 128 + i2_0_1 * 16 + ax1) + T.reads([A_shared[v0, v1]]) + T.writes([A_shared_wmma_matrix_a[v0, v1]]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_a"}) + A_shared_wmma_matrix_a[v0, v1] = A_shared[v0, v1] + for ax0, ax1 in T.grid(16, 32): + with T.block("B_shared_wmma.matrix_b"): + v0 = T.axis.spatial(512, i2_0_0 * 128 + i2_0_1 * 16 + ax0) + v1 = T.axis.spatial(512, i0_0_1_i1_0_1_fused % 2 * 256 + i0_0_2_i1_0_2_fused * 32 + ax1) + T.reads([B_shared[v0, v1]]) + T.writes([B_shared_wmma_matrix_b[v0, v1]]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_b"}) + B_shared_wmma_matrix_b[v0, v1] = B_shared[v0, v1] + for i0_0_4, i1_0_4 in T.grid(16, 2): + with T.block("blockized_C"): + io = T.axis.spatial(32, i0_0_1_i1_0_1_fused // 2 * 16 + i0_0_4) + jo = T.axis.spatial(32, i0_0_1_i1_0_1_fused % 2 * 16 + i0_0_2_i1_0_2_fused * 2 + i1_0_4) + ko = T.axis.reduce(32, i2_0_0 * 8 + i2_0_1) + T.reads([C_local_wmma_accumulator[io * 16 : io * 16 + 16, jo * 16 : jo * 16 + 16], A_shared_wmma_matrix_a[io * 16 : io * 16 + 16, ko * 16 : ko * 16 + 16], B_shared_wmma_matrix_b[ko * 16 : ko * 16 + 16, jo * 16 : jo * 16 + 16]]) + T.writes([C_local_wmma_accumulator[io * 16 : io * 16 + 16, jo * 16 : jo * 16 + 16]]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_fill"}) + with T.init(): + for i0_1, i1_1 in T.grid(16, 16): + with T.block("C_init"): + i_init, j_init = T.axis.remap("SS", [i0_1, i1_1]) + T.reads() + T.writes(C_local_wmma_accumulator[io * 16 + i_init, jo * 16 + j_init]) + C_local_wmma_accumulator[io * 16 + i_init, jo * 16 + j_init] = T.float32(0) + for i0_1, i1_1, i2_1 in T.grid(16, 16, 16): + with T.block("C"): + i, j, k = T.axis.remap("SSR", [i0_1, i1_1, i2_1]) + T.reads(C_local_wmma_accumulator[io * 16 + i, jo * 16 + j], A_shared_wmma_matrix_a[io * 16 + i, ko * 16 + k], B_shared_wmma_matrix_b[ko * 16 + k, jo * 16 + j]) + T.writes(C_local_wmma_accumulator[io * 16 + i, jo * 16 + j]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync"}) + C_local_wmma_accumulator[io * 16 + i, jo * 16 + j] = C_local_wmma_accumulator[io * 16 + i, jo * 16 + j] + T.cast(A_shared_wmma_matrix_a[io * 16 + i, ko * 16 + k], "float32") * T.cast(B_shared_wmma_matrix_b[ko * 16 + k, jo * 16 + j], "float32") + for ax0, ax1 in T.grid(256, 32): + with T.block("C_local_wmma.accumulator"): + v0 = T.axis.spatial(512, i0_0_1_i1_0_1_fused // 2 * 256 + ax0) + v1 = T.axis.spatial(512, i0_0_1_i1_0_1_fused % 2 * 256 + i0_0_2_i1_0_2_fused * 32 + ax1) + T.reads([C_local_wmma_accumulator[v0, v1]]) + T.writes([C_local[v0, v1]]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_store"}) + C_local[v0, v1] = C_local_wmma_accumulator[v0, v1] + for ax0, ax1 in T.grid(256, 32): + with T.block("C_local"): + v0 = T.axis.spatial(512, i0_0_1_i1_0_1_fused // 2 * 256 + ax0) + v1 = T.axis.spatial(512, i0_0_1_i1_0_1_fused % 2 * 256 + i0_0_2_i1_0_2_fused * 32 + ax1) + T.reads([C_local[v0, v1]]) + T.writes([C[v0, v1]]) + C[v0, v1] = C_local[v0, v1] + + # pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks # fmt: on @@ -151,5 +252,70 @@ def test_rewrite_cooperative_fetch(): tvm.ir.assert_structural_equal(sch.mod, AfterRewrite0) +def test_rewrite_cooperative_fetch_tensor_core(): + mod = create_prim_func(te_workload.matmul_fp16(n=512, m=512, k=512)) + target = _target() + ctx = _create_context(mod, target) + + sch = tir.Schedule(mod, debug_mask="all") + # fmt: off + # pylint: disable=line-too-long,invalid-name + b0 = sch.get_block(name="C", func_name="main") + l1, l2, l3 = sch.get_loops(block=b0) + _, l5 = sch.split(loop=l1, factors=[32, 16]) + _, l7 = sch.split(loop=l2, factors=[32, 16]) + _, l9 = sch.split(loop=l3, factors=[32, 16]) + _, _, l12, _, l14, _ = sch.get_loops(block=b0) + sch.reorder(l12, l14, l5, l7, l9) + b16 = sch.blockize(loop=l5) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_sync") + sch.annotate(block_or_loop=b16, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_fill") + b17 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b17, ann_key="meta_schedule.tensor_core_enabled", ann_val="1") + b18 = sch.cache_write(block=b16, write_buffer_index=0, storage_scope="local") + b19 = sch.cache_write(block=b16, write_buffer_index=0, storage_scope="wmma.accumulator") + sch.annotate(block_or_loop=b19, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store") + l20, l21, l22 = sch.get_loops(block=b16) + v23, v24, v25, v26, v27 = sch.sample_perfect_tile(loop=l20, n=5, max_innermost_factor=64, decision=[1, 2, 1, 1, 16]) + l28, l29, l30, l31, l32 = sch.split(loop=l20, factors=[v23, v24, v25, v26, v27]) + v33, v34, v35, v36, v37 = sch.sample_perfect_tile(loop=l21, n=5, max_innermost_factor=64, decision=[1, 2, 8, 1, 2]) + l38, l39, l40, l41, l42 = sch.split(loop=l21, factors=[v33, v34, v35, v36, v37]) + v43, v44, v45 = sch.sample_perfect_tile(loop=l22, n=3, max_innermost_factor=64, decision=[4, 8, 1]) + l46, l47, l48 = sch.split(loop=l22, factors=[v43, v44, v45]) + sch.reorder(l28, l38, l29, l39, l30, l40, l46, l47, l31, l41, l48, l32, l42) + l49 = sch.fuse(l28, l38) + sch.bind(loop=l49, thread_axis="blockIdx.x") + l50 = sch.fuse(l29, l39) + sch.bind(loop=l50, thread_axis="blockIdx.y") + l51 = sch.fuse(l30, l40) + sch.bind(loop=l51, thread_axis="threadIdx.y") + b52 = sch.cache_read(block=b16, read_buffer_index=1, storage_scope="shared") + sch.compute_at(block=b52, loop=l46, preserve_unit_loops=True) + _, _, _, _, l57, l58 = sch.get_loops(block=b52) + l59 = sch.fuse(l57, l58) + _, v61 = sch.sample_perfect_tile(loop=l59, n=2, max_innermost_factor=4, decision=[32768, 1]) + sch.annotate(block_or_loop=b52, ann_key="meta_schedule.cooperative_fetch", ann_val=v61) + b62 = sch.cache_read(block=b16, read_buffer_index=2, storage_scope="shared") + sch.compute_at(block=b62, loop=l46, preserve_unit_loops=True) + _, _, _, _, l67, l68 = sch.get_loops(block=b62) + l69 = sch.fuse(l67, l68) + _, v71 = sch.sample_perfect_tile(loop=l69, n=2, max_innermost_factor=4, decision=[8192, 4]) + sch.annotate(block_or_loop=b62, ann_key="meta_schedule.cooperative_fetch", ann_val=v71) + b72 = sch.cache_read(block=b16, read_buffer_index=1, storage_scope="wmma.matrix_a") + b73 = sch.cache_read(block=b16, read_buffer_index=2, storage_scope="wmma.matrix_b") + sch.compute_at(block=b72, loop=l48, preserve_unit_loops=True) + sch.compute_at(block=b73, loop=l48, preserve_unit_loops=True) + sch.annotate(block_or_loop=b72, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_a") + sch.annotate(block_or_loop=b73, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_b") + sch.reverse_compute_at(block=b19, loop=l51, preserve_unit_loops=True) + sch.reverse_compute_at(block=b18, loop=l51, preserve_unit_loops=True) + # pylint: enable=line-too-long,invalid-name + # fmt: on + sch.enter_postproc() + assert ctx.postprocs[0].apply(sch) + tvm.ir.assert_structural_equal(sch.mod, AfterRewrite1) + + if __name__ == "__main__": test_rewrite_cooperative_fetch() + test_rewrite_cooperative_fetch_tensor_core() diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensor_core.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensor_core.py new file mode 100644 index 000000000000..c11890aefa80 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensor_core.py @@ -0,0 +1,275 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +import tvm +from tvm import tir +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.postproc import RewriteTensorCore +from tvm.script import tir as T +from tvm.target import Target +from tvm.meta_schedule.testing import tir_tensor_intrin + + +def _target() -> Target: + return Target("cuda", host="llvm") + + +def _create_context(mod, target) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + postprocs=[ + RewriteTensorCore(), + ], + task_name="test", + ) + for rule in ctx.postprocs: + rule.initialize_with_tune_context(ctx) + return ctx + + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks + + +@tvm.script.ir_module +class Before0: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: + A = T.match_buffer(var_A, [512, 512], dtype="float16") + B = T.match_buffer(var_B, [512, 512], dtype="float16") + C_local = T.match_buffer(var_C, [512, 512], dtype="float32") + # body + with T.block("root"): + T.reads([]) + T.writes([]) + T.block_attr({"meta_schedule.tensor_core_enabled":"1"}) + # C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local") + C_local_wmma_accumulator = T.alloc_buffer([512, 512], dtype="float32", scope="wmma.accumulator") + A_shared = T.alloc_buffer([512, 512], dtype="float16", scope="shared") + B_shared = T.alloc_buffer([512, 512], dtype="float16", scope="shared") + A_shared_wmma_matrix_a = T.alloc_buffer([512, 512], dtype="float16", scope="wmma.matrix_a") + B_shared_wmma_matrix_b = T.alloc_buffer([512, 512], dtype="float16", scope="wmma.matrix_b") + for i0_0_0_i1_0_0_fused in T.thread_binding(0, 1, thread="blockIdx.x"): + for i0_0_1_i1_0_1_fused in T.thread_binding(0, 4, thread="blockIdx.y"): + for i0_0_2_i1_0_2_fused in T.thread_binding(0, 8, thread="threadIdx.y"): + for i0_0_4_init, i1_0_4_init in T.grid(16, 2): + with T.block("blockized_C_init"): + io = T.axis.spatial(32, i0_0_1_i1_0_1_fused // 2 * 16 + i0_0_4_init) + jo = T.axis.spatial(32, i0_0_1_i1_0_1_fused % 2 * 16 + i0_0_2_i1_0_2_fused * 2 + i1_0_4_init) + T.reads([]) + T.writes([C_local_wmma_accumulator[io * 16 : io * 16 + 16, jo * 16 : jo * 16 + 16]]) + for i0_1, i1_1 in T.grid(16, 16): + with T.block("C_init"): + i_init = T.axis.spatial(512, io * 16 + i0_1) + j_init = T.axis.spatial(512, jo * 16 + i1_1) + T.reads([]) + T.writes([C_local_wmma_accumulator[i_init, j_init]]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_fill"}) + C_local_wmma_accumulator[i_init, j_init] = T.float32(0) + for i2_0_0 in T.serial(0, 4): + for ax0_ax1_fused_0 in T.serial(0, 128): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.y"): + for ax0_ax1_fused_2 in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("A_shared"): + v0 = T.axis.spatial(512, i0_0_1_i1_0_1_fused // 2 * 256 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 32 + ax0_ax1_fused_2) // 128) + v1 = T.axis.spatial(512, i2_0_0 * 128 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 32 + ax0_ax1_fused_2) % 128) + T.reads([A[v0, v1]]) + T.writes([A_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":1}) + A_shared[v0, v1] = A[v0, v1] + for ax0_ax1_fused_0 in T.serial(0, 32): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.y"): + for ax0_ax1_fused_2 in T.thread_binding(0, 32, thread="threadIdx.x"): + for ax0_ax1_fused_3 in T.vectorized(0, 4): + with T.block("B_shared"): + v0 = T.axis.spatial(512, i2_0_0 * 128 + (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 256) + v1 = T.axis.spatial(512, i0_0_1_i1_0_1_fused % 2 * 256 + (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 256) + T.reads([B[v0, v1]]) + T.writes([B_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":4}) + B_shared[v0, v1] = B[v0, v1] + for i2_0_1, i0_0_3, i1_0_3, i2_0_2 in T.grid(8, 1, 1, 1): + for ax0, ax1 in T.grid(256, 16): + with T.block("A_shared_wmma.matrix_a"): + v0 = T.axis.spatial(512, i0_0_1_i1_0_1_fused // 2 * 256 + ax0) + v1 = T.axis.spatial(512, i2_0_0 * 128 + i2_0_1 * 16 + ax1) + T.reads([A_shared[v0, v1]]) + T.writes([A_shared_wmma_matrix_a[v0, v1]]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_a"}) + A_shared_wmma_matrix_a[v0, v1] = A_shared[v0, v1] + for ax0, ax1 in T.grid(16, 32): + with T.block("B_shared_wmma.matrix_b"): + v0 = T.axis.spatial(512, i2_0_0 * 128 + i2_0_1 * 16 + ax0) + v1 = T.axis.spatial(512, i0_0_1_i1_0_1_fused % 2 * 256 + i0_0_2_i1_0_2_fused * 32 + ax1) + T.reads([B_shared[v0, v1]]) + T.writes([B_shared_wmma_matrix_b[v0, v1]]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_b"}) + B_shared_wmma_matrix_b[v0, v1] = B_shared[v0, v1] + for i0_0_4, i1_0_4 in T.grid(16, 2): + with T.block("blockized_C_update"): + io = T.axis.spatial(32, i0_0_1_i1_0_1_fused // 2 * 16 + i0_0_4) + jo = T.axis.spatial(32, i0_0_1_i1_0_1_fused % 2 * 16 + i0_0_2_i1_0_2_fused * 2 + i1_0_4) + ko = T.axis.reduce(32, i2_0_0 * 8 + i2_0_1) + T.reads([C_local_wmma_accumulator[io * 16 : io * 16 + 16, jo * 16 : jo * 16 + 16], A_shared_wmma_matrix_a[io * 16 : io * 16 + 16, ko * 16 : ko * 16 + 16], B_shared_wmma_matrix_b[ko * 16 : ko * 16 + 16, jo * 16 : jo * 16 + 16]]) + T.writes([C_local_wmma_accumulator[io * 16 : io * 16 + 16, jo * 16 : jo * 16 + 16]]) + for i0_1, i1_1, i2_1 in T.grid(16, 16, 16): + with T.block("C"): + i = T.axis.spatial(512, io * 16 + i0_1) + j = T.axis.spatial(512, jo * 16 + i1_1) + k = T.axis.reduce(512, ko * 16 + i2_1) + T.reads([C_local_wmma_accumulator[i, j], A_shared_wmma_matrix_a[i, k], B_shared_wmma_matrix_b[k, j]]) + T.writes([C_local_wmma_accumulator[i, j]]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync"}) + C_local_wmma_accumulator[i, j] = C_local_wmma_accumulator[i, j] + T.cast(A_shared_wmma_matrix_a[i, k], "float32") * T.cast(B_shared_wmma_matrix_b[k, j], "float32") + for ax0, ax1 in T.grid(256, 32): + with T.block("C_local_wmma.accumulator"): + v0 = T.axis.spatial(512, i0_0_1_i1_0_1_fused // 2 * 256 + ax0) + v1 = T.axis.spatial(512, i0_0_1_i1_0_1_fused % 2 * 256 + i0_0_2_i1_0_2_fused * 32 + ax1) + T.reads([C_local_wmma_accumulator[v0, v1]]) + T.writes([C_local[v0, v1]]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_store"}) + C_local[v0, v1] = C_local_wmma_accumulator[v0, v1] + + +@tvm.script.ir_module +class After0: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: + s0 = T.var("int32") + s0_1 = T.var("int32") + s0_2 = T.var("int32") + s1 = T.var("int32") + s1_1 = T.var("int32") + s1_2 = T.var("int32") + A = T.match_buffer(var_A, [512, 512], dtype="float16") + B = T.match_buffer(var_B, [512, 512], dtype="float16") + C_local = T.match_buffer(var_C, [512, 512], dtype="float32") + # body + with T.block("root"): + T.reads([]) + T.writes([]) + T.block_attr({"meta_schedule.tensor_core_enabled":"1"}) + C_local_wmma_accumulator = T.alloc_buffer([512, 512], dtype="float32", scope="wmma.accumulator") + A_shared = T.alloc_buffer([512, 512], dtype="float16", scope="shared") + B_shared = T.alloc_buffer([512, 512], dtype="float16", scope="shared") + A_shared_wmma_matrix_a = T.alloc_buffer([512, 512], dtype="float16", scope="wmma.matrix_a") + B_shared_wmma_matrix_b = T.alloc_buffer([512, 512], dtype="float16", scope="wmma.matrix_b") + for i0_0_0_i1_0_0_fused in T.thread_binding(0, 1, thread="blockIdx.x"): + for i0_0_1_i1_0_1_fused in T.thread_binding(0, 4, thread="blockIdx.y"): + for i0_0_2_i1_0_2_fused in T.thread_binding(0, 8, thread="threadIdx.y"): + for i0_0_4_init, i1_0_4_init in T.grid(16, 2): + with T.block("blockized_C_init"): + io = T.axis.spatial(32, i0_0_1_i1_0_1_fused // 2 * 16 + i0_0_4_init) + jo = T.axis.spatial(32, i0_0_1_i1_0_1_fused % 2 * 16 + i0_0_2_i1_0_2_fused * 2 + i1_0_4_init) + T.reads([]) + T.writes([C_local_wmma_accumulator[io * 16 : io * 16 + 16, jo * 16 : jo * 16 + 16]]) + for i0_1_0, i1_1_0 in T.grid(1, 1): + with T.block("blockized_C_init"): + i_inito = T.axis.spatial(1, 0) + j_inito = T.axis.spatial(1, 0) + T.reads([]) + T.writes([C_local_wmma_accumulator[io * 16 : io * 16 + 16, jo * 16 : jo * 16 + 16]]) + C = T.match_buffer(C_local_wmma_accumulator[io * 16 : io * 16 + 16, jo * 16 : jo * 16 + 16], [16, 16], dtype="float32", scope="wmma.accumulator", offset_factor=16) + T.evaluate(T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // 256 + C.elem_offset % 256 // 16, T.float32(0), dtype="handle")) + for i2_0_0 in T.serial(0, 4): + for ax0_ax1_fused_0 in T.serial(0, 128): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.y"): + for ax0_ax1_fused_2 in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("A_shared"): + v0 = T.axis.spatial(512, i0_0_1_i1_0_1_fused // 2 * 256 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 32 + ax0_ax1_fused_2) // 128) + v1 = T.axis.spatial(512, i2_0_0 * 128 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 32 + ax0_ax1_fused_2) % 128) + T.reads([A[v0, v1]]) + T.writes([A_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":1}) + A_shared[v0, v1] = A[v0, v1] + for ax0_ax1_fused_0 in T.serial(0, 32): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.y"): + for ax0_ax1_fused_2 in T.thread_binding(0, 32, thread="threadIdx.x"): + for ax0_ax1_fused_3 in T.vectorized(0, 4): + with T.block("B_shared"): + v0 = T.axis.spatial(512, i2_0_0 * 128 + (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 256) + v1 = T.axis.spatial(512, i0_0_1_i1_0_1_fused % 2 * 256 + (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 256) + T.reads([B[v0, v1]]) + T.writes([B_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":4}) + B_shared[v0, v1] = B[v0, v1] + for i2_0_1, i0_0_3, i1_0_3, i2_0_2 in T.grid(8, 1, 1, 1): + for ax0_0, ax1_0 in T.grid(16, 1): + with T.block("blockized_A_shared_wmma.matrix_a"): + v0o = T.axis.spatial(32, i0_0_1_i1_0_1_fused // 2 * 16 + ax0_0) + v1o = T.axis.spatial(32, i2_0_0 * 8 + i2_0_1) + T.reads([A_shared[v0o * 16 : v0o * 16 + 16, v1o * 16 : v1o * 16 + 16]]) + T.writes([A_shared_wmma_matrix_a[v0o * 16 : v0o * 16 + 16, v1o * 16 : v1o * 16 + 16]]) + A_1 = T.match_buffer(A_shared[v0o * 16 : v0o * 16 + 16, v1o * 16 : v1o * 16 + 16], [16, 16], dtype="float16", strides=[s1, s0], scope="shared", offset_factor=16) + C_1 = T.match_buffer(A_shared_wmma_matrix_a[v0o * 16 : v0o * 16 + 16, v1o * 16 : v1o * 16 + 16], [16, 16], dtype="float16", scope="wmma.matrix_a", offset_factor=16) + T.evaluate(T.tvm_load_matrix_sync(C_1.data, 16, 16, 16, C_1.elem_offset // 256 + C_1.elem_offset % 256 // 16, T.tvm_access_ptr(T.type_annotation(dtype="float16"), A_1.data, A_1.elem_offset, s1 * 16, 1, dtype="handle"), s1, "row_major", dtype="handle")) + for ax0_0, ax1_0 in T.grid(1, 2): + with T.block("blockized_B_shared_wmma.matrix_b"): + v0o = T.axis.spatial(32, i2_0_0 * 8 + i2_0_1) + v1o = T.axis.spatial(32, i0_0_1_i1_0_1_fused % 2 * 16 + i0_0_2_i1_0_2_fused * 2 + ax1_0) + T.reads([B_shared[v0o * 16 : v0o * 16 + 16, v1o * 16 : v1o * 16 + 16]]) + T.writes([B_shared_wmma_matrix_b[v0o * 16 : v0o * 16 + 16, v1o * 16 : v1o * 16 + 16]]) + A_2 = T.match_buffer(B_shared[v0o * 16 : v0o * 16 + 16, v1o * 16 : v1o * 16 + 16], [16, 16], dtype="float16", strides=[s1_1, s0_1], scope="shared", offset_factor=16) + C_2 = T.match_buffer(B_shared_wmma_matrix_b[v0o * 16 : v0o * 16 + 16, v1o * 16 : v1o * 16 + 16], [16, 16], dtype="float16", scope="wmma.matrix_b", offset_factor=16) + T.evaluate(T.tvm_load_matrix_sync(C_2.data, 16, 16, 16, C_2.elem_offset // 256 + C_2.elem_offset % 256 // 16, T.tvm_access_ptr(T.type_annotation(dtype="float16"), A_2.data, A_2.elem_offset, s1_1 * 16, 1, dtype="handle"), s1_1, "row_major", dtype="handle")) + for i0_0_4, i1_0_4 in T.grid(16, 2): + with T.block("blockized_C_update"): + io = T.axis.spatial(32, i0_0_1_i1_0_1_fused // 2 * 16 + i0_0_4) + jo = T.axis.spatial(32, i0_0_1_i1_0_1_fused % 2 * 16 + i0_0_2_i1_0_2_fused * 2 + i1_0_4) + ko = T.axis.reduce(32, i2_0_0 * 8 + i2_0_1) + T.reads([C_local_wmma_accumulator[io * 16 : io * 16 + 16, jo * 16 : jo * 16 + 16], A_shared_wmma_matrix_a[io * 16 : io * 16 + 16, ko * 16 : ko * 16 + 16], B_shared_wmma_matrix_b[ko * 16 : ko * 16 + 16, jo * 16 : jo * 16 + 16]]) + T.writes([C_local_wmma_accumulator[io * 16 : io * 16 + 16, jo * 16 : jo * 16 + 16]]) + for i0_1_0, i1_1_0, i2_1_0 in T.grid(1, 1, 1): + with T.block("blockized_C"): + io_1 = T.axis.spatial(1, 0) + jo_1 = T.axis.spatial(1, 0) + ko_1 = T.axis.reduce(1, 0) + T.reads([C_local_wmma_accumulator[io * 16 : io * 16 + 16, jo * 16 : jo * 16 + 16], A_shared_wmma_matrix_a[io * 16 : io * 16 + 16, ko * 16 : ko * 16 + 16], B_shared_wmma_matrix_b[ko * 16 : ko * 16 + 16, jo * 16 : jo * 16 + 16]]) + T.writes([C_local_wmma_accumulator[io * 16 : io * 16 + 16, jo * 16 : jo * 16 + 16]]) + A_3 = T.match_buffer(A_shared_wmma_matrix_a[io * 16 : io * 16 + 16, ko * 16 : ko * 16 + 16], [16, 16], dtype="float16", scope="wmma.matrix_a", offset_factor=16) + B_1 = T.match_buffer(B_shared_wmma_matrix_b[ko * 16 : ko * 16 + 16, jo * 16 : jo * 16 + 16], [16, 16], dtype="float16", scope="wmma.matrix_b", offset_factor=16) + C_3 = T.match_buffer(C_local_wmma_accumulator[io * 16 : io * 16 + 16, jo * 16 : jo * 16 + 16], [16, 16], dtype="float32", scope="wmma.accumulator", offset_factor=16) + T.evaluate(T.tvm_mma_sync(C_3.data, C_3.elem_offset // 256 + C_3.elem_offset % 256 // 16, A_3.data, A_3.elem_offset // 256 + A_3.elem_offset % 256 // 16, B_1.data, B_1.elem_offset // 256 + B_1.elem_offset % 256 // 16, C_3.data, C_3.elem_offset // 256 + C_3.elem_offset % 256 // 16, dtype="handle")) + for ax0_0, ax1_0 in T.grid(16, 2): + with T.block("blockized_C_local_wmma.accumulator"): + v0o = T.axis.spatial(32, i0_0_1_i1_0_1_fused // 2 * 16 + ax0_0) + v1o = T.axis.spatial(32, i0_0_1_i1_0_1_fused % 2 * 16 + i0_0_2_i1_0_2_fused * 2 + ax1_0) + T.reads([C_local_wmma_accumulator[v0o * 16 : v0o * 16 + 16, v1o * 16 : v1o * 16 + 16]]) + T.writes([C_local[v0o * 16 : v0o * 16 + 16, v1o * 16 : v1o * 16 + 16]]) + A_4 = T.match_buffer(C_local_wmma_accumulator[v0o * 16 : v0o * 16 + 16, v1o * 16 : v1o * 16 + 16], [16, 16], dtype="float32", scope="wmma.accumulator", offset_factor=16) + C_4 = T.match_buffer(C_local[v0o * 16 : v0o * 16 + 16, v1o * 16 : v1o * 16 + 16], [16, 16], dtype="float32", strides=[s1_2, s0_2], offset_factor=16) + T.evaluate(T.tvm_store_matrix_sync(A_4.data, 16, 16, 16, A_4.elem_offset // 256 + A_4.elem_offset % 256 // 16, T.tvm_access_ptr(T.type_annotation(dtype="float32"), C_4.data, C_4.elem_offset, s1_2 * 16, 2, dtype="handle"), s1_2, "row_major", dtype="handle")) + + +# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks +# fmt: on + + +def test_rewrite_tensor_core(): + mod = Before0 + target = _target() + ctx = _create_context(mod, target) + sch = tir.Schedule(mod, debug_mask="all") + sch.enter_postproc() + assert ctx.postprocs[0].apply(sch) + tvm.ir.assert_structural_equal(sch.mod, After0) + + +if __name__ == "__main__": + test_rewrite_tensor_core() diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py index 4ab2741da181..9b39ad1bff3e 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py @@ -99,7 +99,7 @@ def main(A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[(1,), "float32"]) -> with T.block("C"): b = T.axis.S(1, 0) i, j = T.axis.remap("RR", [i1, i2]) - T.where(i0_fused_0 * 32 + i0_fused_1 < 1) + T.where(i0_fused_1 < 1) with T.init(): C[b] = T.float32(0) C[b] = C[b] + A[b, i, j] * A[b, i, j] @@ -107,7 +107,7 @@ def main(A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[(1,), "float32"]) -> for i0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): with T.block("D"): b = T.axis.S(1, 0) - T.where(i0_fused_0 * 32 + i0_fused_1 < 1) + T.where(i0_fused_1 < 1) D[b] = T.sqrt(C[b], dtype="float32") diff --git a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py index bebfec6122b3..d7158bef6fa5 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py +++ b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py @@ -195,6 +195,177 @@ def main(a: T.handle, b: T.handle) -> None: T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on +@T.prim_func +def GmmCuda0(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None: + Z_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local") + X_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared") + Y_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared") + for i0_0_i1_0_i2_0_fused in T.thread_binding(16, thread="blockIdx.x"): + for i0_1_i1_1_i2_1_fused in T.thread_binding(1, thread="vthread.x"): + for i0_2_i1_2_i2_2_fused in T.thread_binding(128, thread="threadIdx.x"): + for i1_3_init, i2_4_init in T.grid(4, 2): + with T.block("Z_init"): + b = T.axis.spatial(1, 0) + i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3_init) + j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4_init) + T.reads() + T.writes(Z_local[b, i, j]) + Z_local[b, i, j] = T.float32(0) + for i3_0 in T.serial(4): + for ax0_ax1_ax2_fused_0 in T.serial(4): + for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + for ax0_ax1_ax2_fused_2 in T.vectorized(2): + with T.block("X_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) // 32) + v2 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) % 32) + T.reads(X[v0, v1, v2]) + T.writes(X_shared[v0, v1, v2]) + X_shared[v0, v1, v2] = X[v0, v1, v2] + for ax0_ax1_ax2_fused_0 in T.serial(8): + for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + with T.block("Y_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) // 32) + v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) % 32) + T.reads(Y[v0, v1, v2]) + T.writes(Y_shared[v0, v1, v2]) + Y_shared[v0, v1, v2] = Y[v0, v1, v2] + for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(1, 1, 4, 1, 32, 1, 1, 2): + with T.block("Z_update"): + b = T.axis.spatial(1, 0) + i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3) + j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4) + k = T.axis.reduce(128, i3_0 * 32 + i3_2) + T.reads(Z_local[b, i, j], X_shared[b, i, k], Y_shared[b, k, j]) + T.writes(Z_local[b, i, j]) + Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j] + for ax0, ax1, ax2 in T.grid(1, 4, 2): + with T.block("Z_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + ax1) + v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + ax2) + T.reads(Z_local[v0, v1, v2]) + T.writes(Z[v0, v1, v2]) + Z[v0, v1, v2] = Z_local[v0, v1, v2] + +@T.prim_func +def GmmCuda1(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None: + Z_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local") + X_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared") + Y_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared") + for i0_0_i1_0_i2_0_fused in T.thread_binding(16, thread="blockIdx.x"): + for i0_1_i1_1_i2_1_fused in T.thread_binding(1, thread="vthread.x"): + for i0_2_i1_2_i2_2_fused in T.thread_binding(128, thread="threadIdx.x"): + for i1_3_init, i2_4_init in T.grid(4, 2): + with T.block("Z_init"): + b = T.axis.spatial(1, 0) + i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3_init) + j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4_init) + T.reads() + T.writes(Z_local[b, i, j]) + Z_local[b, i, j] = T.float32(0) + for i3_0 in T.serial(4): + for ax0_ax1_ax2_fused_0 in T.serial(4): + for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + for ax0_ax1_ax2_fused_2 in T.vectorized(2): + with T.block("X_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) // 32) + v2 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) % 32) + T.reads(X[v0, v1, v2]) + T.writes(X_shared[v0, v1, v2]) + X_shared[v0, v1, v2] = X[v0, v1, v2] + for ax0_ax1_ax2_fused_0 in T.serial(8): + for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + with T.block("Y_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) // 32) + v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) % 32) + T.reads(Y[v0, v1, v2]) + T.writes(Y_shared[v0, v1, v2]) + Y_shared[v0, v1, v2] = Y[v0, v1, v2] + for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(1, 1, 4, 1, 32, 1, 1, 2): + with T.block("Z_update"): + b = T.axis.spatial(1, 0) + i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3) + j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4) + k = T.axis.reduce(128, i3_0 * 32 + i3_2) + T.block_attr({ + "meta_schedule.thread_extent_low_inclusive": 0, + "meta_schedule.thread_extent_high_inclusive": 32, + }) + T.reads(Z_local[b, i, j], X_shared[b, i, k], Y_shared[b, k, j]) + T.writes(Z_local[b, i, j]) + Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j] + for ax0, ax1, ax2 in T.grid(1, 4, 2): + with T.block("Z_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + ax1) + v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + ax2) + T.reads(Z_local[v0, v1, v2]) + T.writes(Z[v0, v1, v2]) + Z[v0, v1, v2] = Z_local[v0, v1, v2] + + +@T.prim_func +def GmmCuda2(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None: + Z_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local") + X_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared") + Y_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared") + for i0_0_i1_0_i2_0_fused in T.thread_binding(16, thread="blockIdx.x"): + for i0_1_i1_1_i2_1_fused in T.thread_binding(1, thread="vthread.x"): + for i0_2_i1_2_i2_2_fused in T.thread_binding(128, thread="threadIdx.x"): + for i1_3_init, i2_4_init in T.grid(4, 2): + with T.block("Z_init"): + b = T.axis.spatial(1, 0) + i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3_init) + j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4_init) + T.reads() + T.writes(Z_local[b, i, j]) + Z_local[b, i, j] = T.float32(0) + for i3_0 in T.serial(4): + for ax0_ax1_ax2_fused_0 in T.serial(4): + for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + for ax0_ax1_ax2_fused_2 in T.vectorized(2): + with T.block("X_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) // 32) + v2 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) % 32) + T.reads(X[v0, v1, v2]) + T.writes(X_shared[v0, v1, v2]) + X_shared[v0, v1, v2] = X[v0, v1, v2] + for ax0_ax1_ax2_fused_0 in T.serial(8): + for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + with T.block("Y_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) // 32) + v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) % 32) + T.reads(Y[v0, v1, v2]) + T.writes(Y_shared[v0, v1, v2]) + Y_shared[v0, v1, v2] = Y[v0, v1, v2] + for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(1, 1, 4, 1, 32, 1, 1, 2): + with T.block("Z_update"): + b = T.axis.spatial(1, 0) + i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3) + j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4) + k = T.axis.reduce(128, i3_0 * 32 + i3_2) + T.block_attr({ + "meta_schedule.thread_extent_low_inclusive": 1024, + "meta_schedule.thread_extent_high_inclusive": 1024, + }) + T.reads(Z_local[b, i, j], X_shared[b, i, k], Y_shared[b, k, j]) + T.writes(Z_local[b, i, j]) + Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j] + for ax0, ax1, ax2 in T.grid(1, 4, 2): + with T.block("Z_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + ax1) + v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + ax2) + T.reads(Z_local[v0, v1, v2]) + T.writes(Z[v0, v1, v2]) + Z[v0, v1, v2] = Z_local[v0, v1, v2] + # fmt: on # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant @@ -227,5 +398,26 @@ def test_postproc_verify_gpu_3(): assert not ctx.postprocs[0].apply(sch) +def test_postproc_verify_gpu_4(): + mod = GmmCuda0 + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert ctx.postprocs[0].apply(sch) + + +def test_postproc_verify_gpu_5(): + mod = GmmCuda1 + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert not ctx.postprocs[0].apply(sch) + + +def test_postproc_verify_gpu_6(): + mod = GmmCuda2 + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert not ctx.postprocs[0].apply(sch) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule.py b/tests/python/unittest/test_meta_schedule_schedule_rule.py new file mode 100644 index 000000000000..1d34d94bfe05 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_schedule_rule.py @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import math +import re +from typing import List + +import tvm +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.schedule_rule import PyScheduleRule +from tvm.script import tir as T +from tvm.tir.schedule import BlockRV, Schedule + + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, +# fmt: off + +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def _check_correct(schedule: Schedule): + trace = schedule.trace + for inst in trace.decisions: + assert math.prod(trace.decisions[inst]) == 1024 + + +def test_meta_schedule_schedule_rule(): + class FancyScheduleRule(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: TuneContext) -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + i, j, k = sch.get_loops(block=block) + i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2]) + j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2]) + k_0, k_1 = sch.split(loop=k, factors=[32, 32]) + sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) + return [sch] + + sch_rule = FancyScheduleRule() + mod = Matmul + sch = Schedule(mod) + res = sch_rule.apply(sch, block=sch.get_block("matmul")) + assert len(res) == 1 + try: + tvm.ir.assert_structural_equal(mod, res[0].mod) + raise Exception("The schedule rule did not change the schedule.") + except ValueError: + _check_correct(res[0]) + + +def test_meta_schedule_schedule_rule_as_string(): + class YetStillSomeFancyScheduleRule(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: TuneContext) -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + pass + + sch_rule = YetStillSomeFancyScheduleRule() + pattern = re.compile(r"YetStillSomeFancyScheduleRule\(0x[a-f|0-9]*\)") + assert pattern.match(str(sch_rule)) + + +if __name__ == "__main__": + test_meta_schedule_schedule_rule() + test_meta_schedule_schedule_rule_as_string() diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py index 7bed18b0f9ea..47f405842c98 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py @@ -84,7 +84,7 @@ def test_gpu_softmax_mn(): "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", "l5, l6 = sch.split(loop=l3, factors=[None, v4])", 'sch.bind(loop=l6, thread_axis="threadIdx.x")', - "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=1)", + "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)", 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', "l7, l8, l9 = sch.get_loops(block=b0)", "l10, l11 = sch.split(loop=l9, factors=[None, v4])", @@ -97,7 +97,7 @@ def test_gpu_softmax_mn(): "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", "l5, l6 = sch.split(loop=l3, factors=[None, v4])", 'sch.bind(loop=l6, thread_axis="threadIdx.x")', - "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=1)", + "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)", 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', "l7, l8, l9 = sch.get_loops(block=b0)", "l10, l11 = sch.split(loop=l9, factors=[None, v4])", @@ -111,7 +111,7 @@ def test_gpu_softmax_mn(): "v5 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", "l6, l7 = sch.split(loop=l4, factors=[None, v5])", 'sch.bind(loop=l7, thread_axis="threadIdx.x")', - "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=1)", + "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True)", 'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")', "l8, l9, l10 = sch.get_loops(block=b1)", "l11, l12 = sch.split(loop=l10, factors=[None, v5])", @@ -121,7 +121,7 @@ def test_gpu_softmax_mn(): "v16 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", "l17, l18 = sch.split(loop=l15, factors=[None, v16])", 'sch.bind(loop=l18, thread_axis="threadIdx.x")', - "sch.compute_at(block=b0, loop=l14, preserve_unit_loops=1)", + "sch.compute_at(block=b0, loop=l14, preserve_unit_loops=True)", 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', "l19, l20, l21 = sch.get_loops(block=b0)", "l22, l23 = sch.split(loop=l21, factors=[None, v16])", @@ -161,7 +161,7 @@ def test_gpu_softmax_mn_after_inline(): "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", "l5, l6 = sch.split(loop=l3, factors=[None, v4])", 'sch.bind(loop=l6, thread_axis="threadIdx.x")', - "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=1)", + "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)", 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', "l7, l8, l9 = sch.get_loops(block=b0)", "l10, l11 = sch.split(loop=l9, factors=[None, v4])", @@ -175,14 +175,14 @@ def test_gpu_softmax_mn_after_inline(): "v5 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", "l6, l7 = sch.split(loop=l4, factors=[None, v5])", 'sch.bind(loop=l7, thread_axis="threadIdx.x")', - "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=1)", + "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True)", 'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")', "l8, l9, l10 = sch.get_loops(block=b1)", "l11, l12 = sch.split(loop=l10, factors=[None, v5])", 'sch.bind(loop=l12, thread_axis="threadIdx.x")', "b13, b14 = sch.get_consumers(block=b0)", "l15, l16, l17, l18 = sch.get_loops(block=b13)", - "sch.compute_at(block=b0, loop=l15, preserve_unit_loops=1)", + "sch.compute_at(block=b0, loop=l15, preserve_unit_loops=True)", 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', "l19, l20, l21 = sch.get_loops(block=b0)", "l22, l23 = sch.split(loop=l21, factors=[None, v5])", @@ -210,7 +210,7 @@ def test_gpu_batch_norm_bmn(): "v3 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", "l4, l5 = sch.split(loop=l2, factors=[None, v3])", 'sch.bind(loop=l5, thread_axis="threadIdx.x")', - "sch.compute_at(block=b0, loop=l4, preserve_unit_loops=1)", + "sch.compute_at(block=b0, loop=l4, preserve_unit_loops=True)", 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', "l6, l7, l8, l9 = sch.get_loops(block=b0)", "l10 = sch.fuse(l8, l9)", diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py index c6a63aae7427..c2ad9258f275 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py @@ -19,12 +19,14 @@ from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply from tvm.meta_schedule.testing.schedule_rule import ( multi_level_tiling, + multi_level_tiling_tensor_core, ) from tvm.meta_schedule.testing.space_generation import check_trace from tvm.meta_schedule.tune_context import TuneContext from tvm.te import create_prim_func from tvm.meta_schedule.testing import te_workload from tvm.target import Target +from tvm.meta_schedule.testing import tir_tensor_intrin def _create_context(mod, target, rule) -> TuneContext: @@ -46,30 +48,30 @@ def test_cpu_matmul(): [ 'b0 = sch.get_block(name="C", func_name="main")', 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', - "l1, l2, l3 = sch.get_loops(block=b0)", - "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", - "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])", - "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", - "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])", - "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", - "l22, l23 = sch.split(loop=l3, factors=[v20, v21])", - "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", - 'b24 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', - "sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=1)", + 'b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])", + "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])", + "v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64)", + "l23, l24 = sch.split(loop=l4, factors=[v21, v22])", + "sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)", + "sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=True)", ], [ 'b0 = sch.get_block(name="C", func_name="main")', 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', - "l1, l2, l3 = sch.get_loops(block=b0)", - "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", - "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])", - "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", - "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])", - "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", - "l22, l23 = sch.split(loop=l3, factors=[v20, v21])", - "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", - 'b24 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', - "sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=1)", + 'b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])", + "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])", + "v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64)", + "l23, l24 = sch.split(loop=l4, factors=[v21, v22])", + "sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)", + "sch.reverse_compute_at(block=b1, loop=l17, preserve_unit_loops=True)", ], [ 'b0 = sch.get_block(name="C", func_name="main")', @@ -107,30 +109,30 @@ def test_cpu_matmul_relu(): [ 'b0 = sch.get_block(name="C", func_name="main")', 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', - "l1, l2, l3 = sch.get_loops(block=b0)", - "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", - "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])", - "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", - "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])", - "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", - "l22, l23 = sch.split(loop=l3, factors=[v20, v21])", - "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", - "b24, = sch.get_consumers(block=b0)", - "sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=1)", + "b1, = sch.get_consumers(block=b0)", + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])", + "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])", + "v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64)", + "l23, l24 = sch.split(loop=l4, factors=[v21, v22])", + "sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)", + "sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=True)", ], [ 'b0 = sch.get_block(name="C", func_name="main")', 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', - "l1, l2, l3 = sch.get_loops(block=b0)", - "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", - "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])", - "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", - "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])", - "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", - "l22, l23 = sch.split(loop=l3, factors=[v20, v21])", - "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", - "b24, = sch.get_consumers(block=b0)", - "sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=1)", + "b1, = sch.get_consumers(block=b0)", + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])", + "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])", + "v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64)", + "l23, l24 = sch.split(loop=l4, factors=[v21, v22])", + "sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)", + "sch.reverse_compute_at(block=b1, loop=l17, preserve_unit_loops=True)", ], [ 'b0 = sch.get_block(name="C", func_name="main")', @@ -169,36 +171,36 @@ def test_cuda_matmul(): [ 'b0 = sch.get_block(name="C", func_name="main")', 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")', - "l1, l2, l3 = sch.get_loops(block=b0)", - "v4, v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l1, n=5, max_innermost_factor=64)", - "l9, l10, l11, l12, l13 = sch.split(loop=l1, factors=[v4, v5, v6, v7, v8])", - "v14, v15, v16, v17, v18 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)", - "l19, l20, l21, l22, l23 = sch.split(loop=l2, factors=[v14, v15, v16, v17, v18])", - "v24, v25, v26 = sch.sample_perfect_tile(loop=l3, n=3, max_innermost_factor=64)", - "l27, l28, l29 = sch.split(loop=l3, factors=[v24, v25, v26])", - "sch.reorder(l9, l19, l10, l20, l11, l21, l27, l28, l12, l22, l29, l13, l23)", - "l30 = sch.fuse(l9, l19)", - 'sch.bind(loop=l30, thread_axis="blockIdx.x")', + 'b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)", + "l10, l11, l12, l13, l14 = sch.split(loop=l2, factors=[v5, v6, v7, v8, v9])", + "v15, v16, v17, v18, v19 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64)", + "l20, l21, l22, l23, l24 = sch.split(loop=l3, factors=[v15, v16, v17, v18, v19])", + "v25, v26, v27 = sch.sample_perfect_tile(loop=l4, n=3, max_innermost_factor=64)", + "l28, l29, l30 = sch.split(loop=l4, factors=[v25, v26, v27])", + "sch.reorder(l10, l20, l11, l21, l12, l22, l28, l29, l13, l23, l30, l14, l24)", "l31 = sch.fuse(l10, l20)", - 'sch.bind(loop=l31, thread_axis="vthread.x")', + 'sch.bind(loop=l31, thread_axis="blockIdx.x")', "l32 = sch.fuse(l11, l21)", - 'sch.bind(loop=l32, thread_axis="threadIdx.x")', + 'sch.bind(loop=l32, thread_axis="vthread.x")', + "l33 = sch.fuse(l12, l22)", + 'sch.bind(loop=l33, thread_axis="threadIdx.x")', 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32)', 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024)', - 'b33 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', - "sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=1)", 'b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")', - "sch.compute_at(block=b34, loop=l27, preserve_unit_loops=1)", + "sch.compute_at(block=b34, loop=l28, preserve_unit_loops=True)", "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)", "l41 = sch.fuse(l39, l40)", "v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", 'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)', 'b43 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', - "sch.compute_at(block=b43, loop=l27, preserve_unit_loops=1)", + "sch.compute_at(block=b43, loop=l28, preserve_unit_loops=True)", "l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)", "l50 = sch.fuse(l48, l49)", "v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", 'sch.annotate(block_or_loop=b43, ann_key="meta_schedule.cooperative_fetch", ann_val=v51)', + "sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=True)", ] ] # pylint: enable=line-too-long @@ -225,34 +227,34 @@ def test_cuda_matmul_relu(): [ 'b0 = sch.get_block(name="C", func_name="main")', 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")', - "l1, l2, l3 = sch.get_loops(block=b0)", - "v4, v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l1, n=5, max_innermost_factor=64)", - "l9, l10, l11, l12, l13 = sch.split(loop=l1, factors=[v4, v5, v6, v7, v8])", - "v14, v15, v16, v17, v18 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)", - "l19, l20, l21, l22, l23 = sch.split(loop=l2, factors=[v14, v15, v16, v17, v18])", - "v24, v25, v26 = sch.sample_perfect_tile(loop=l3, n=3, max_innermost_factor=64)", - "l27, l28, l29 = sch.split(loop=l3, factors=[v24, v25, v26])", - "sch.reorder(l9, l19, l10, l20, l11, l21, l27, l28, l12, l22, l29, l13, l23)", - "l30 = sch.fuse(l9, l19)", - 'sch.bind(loop=l30, thread_axis="blockIdx.x")', + 'b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)", + "l10, l11, l12, l13, l14 = sch.split(loop=l2, factors=[v5, v6, v7, v8, v9])", + "v15, v16, v17, v18, v19 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64)", + "l20, l21, l22, l23, l24 = sch.split(loop=l3, factors=[v15, v16, v17, v18, v19])", + "v25, v26, v27 = sch.sample_perfect_tile(loop=l4, n=3, max_innermost_factor=64)", + "l28, l29, l30 = sch.split(loop=l4, factors=[v25, v26, v27])", + "sch.reorder(l10, l20, l11, l21, l12, l22, l28, l29, l13, l23, l30, l14, l24)", "l31 = sch.fuse(l10, l20)", - 'sch.bind(loop=l31, thread_axis="vthread.x")', + 'sch.bind(loop=l31, thread_axis="blockIdx.x")', "l32 = sch.fuse(l11, l21)", - 'sch.bind(loop=l32, thread_axis="threadIdx.x")', - 'b33 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', - "sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=1)", + 'sch.bind(loop=l32, thread_axis="vthread.x")', + "l33 = sch.fuse(l12, l22)", + 'sch.bind(loop=l33, thread_axis="threadIdx.x")', 'b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")', - "sch.compute_at(block=b34, loop=l27, preserve_unit_loops=1)", + "sch.compute_at(block=b34, loop=l28, preserve_unit_loops=True)", "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)", "l41 = sch.fuse(l39, l40)", "v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", 'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)', 'b43 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', - "sch.compute_at(block=b43, loop=l27, preserve_unit_loops=1)", + "sch.compute_at(block=b43, loop=l28, preserve_unit_loops=True)", "l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)", "l50 = sch.fuse(l48, l49)", "v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", 'sch.annotate(block_or_loop=b43, ann_key="meta_schedule.cooperative_fetch", ann_val=v51)', + "sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=True)", ] ] # pylint: enable=line-too-long @@ -273,8 +275,154 @@ def test_cuda_matmul_relu(): check_trace(spaces, expected) +def test_cuda_tensor_core_matmul(): + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "l4, l5 = sch.split(loop=l1, factors=[32, 16])", + "l6, l7 = sch.split(loop=l2, factors=[32, 16])", + "l8, l9 = sch.split(loop=l3, factors=[32, 16])", + "l10, l11, l12, l13, l14, l15 = sch.get_loops(block=b0)", + "sch.reorder(l12, l14, l5, l7, l9)", + "b16 = sch.blockize(loop=l5)", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_sync")', + 'sch.annotate(block_or_loop=b16, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_fill")', + 'b17 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b17, ann_key="meta_schedule.tensor_core_enabled", ann_val="1")', + 'b18 = sch.cache_write(block=b16, write_buffer_index=0, storage_scope="local")', + 'b19 = sch.cache_write(block=b16, write_buffer_index=0, storage_scope="wmma.accumulator")', + 'sch.annotate(block_or_loop=b19, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store")', + "l20, l21, l22 = sch.get_loops(block=b16)", + "v23, v24, v25, v26, v27 = sch.sample_perfect_tile(loop=l20, n=5, max_innermost_factor=64)", + "l28, l29, l30, l31, l32 = sch.split(loop=l20, factors=[v23, v24, v25, v26, v27])", + "v33, v34, v35, v36, v37 = sch.sample_perfect_tile(loop=l21, n=5, max_innermost_factor=64)", + "l38, l39, l40, l41, l42 = sch.split(loop=l21, factors=[v33, v34, v35, v36, v37])", + "v43, v44, v45 = sch.sample_perfect_tile(loop=l22, n=3, max_innermost_factor=64)", + "l46, l47, l48 = sch.split(loop=l22, factors=[v43, v44, v45])", + "sch.reorder(l28, l38, l29, l39, l30, l40, l46, l47, l31, l41, l48, l32, l42)", + "l49 = sch.fuse(l28, l38)", + 'sch.bind(loop=l49, thread_axis="blockIdx.x")', + "l50 = sch.fuse(l29, l39)", + 'sch.bind(loop=l50, thread_axis="blockIdx.y")', + "l51 = sch.fuse(l30, l40)", + 'sch.bind(loop=l51, thread_axis="threadIdx.y")', + 'b52 = sch.cache_read(block=b16, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b52, loop=l46, preserve_unit_loops=True)", + "l53, l54, l55, l56, l57, l58 = sch.get_loops(block=b52)", + "l59 = sch.fuse(l57, l58)", + "v60 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b52, ann_key="meta_schedule.cooperative_fetch", ann_val=v60)', + 'b61 = sch.cache_read(block=b16, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b61, loop=l46, preserve_unit_loops=True)", + "l62, l63, l64, l65, l66, l67 = sch.get_loops(block=b61)", + "l68 = sch.fuse(l66, l67)", + "v69 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b61, ann_key="meta_schedule.cooperative_fetch", ann_val=v69)', + 'b70 = sch.cache_read(block=b16, read_buffer_index=1, storage_scope="wmma.matrix_a")', + 'b71 = sch.cache_read(block=b16, read_buffer_index=2, storage_scope="wmma.matrix_b")', + "sch.compute_at(block=b70, loop=l48, preserve_unit_loops=True)", + "sch.compute_at(block=b71, loop=l48, preserve_unit_loops=True)", + 'sch.annotate(block_or_loop=b70, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_a")', + 'sch.annotate(block_or_loop=b71, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_b")', + "sch.reverse_compute_at(block=b19, loop=l51, preserve_unit_loops=True)", + "sch.reverse_compute_at(block=b18, loop=l51, preserve_unit_loops=True)", + ] + ] + target = Target("cuda", host="llvm") + ctx = _create_context( + create_prim_func( + te_workload.matmul_fp16( + n=512, + m=512, + k=512, + ) + ), + target=target, + rule=multi_level_tiling_tensor_core(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +def test_cuda_tensor_core_matmul_relu(): + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "l4, l5 = sch.split(loop=l1, factors=[32, 16])", + "l6, l7 = sch.split(loop=l2, factors=[32, 16])", + "l8, l9 = sch.split(loop=l3, factors=[32, 16])", + "l10, l11, l12, l13, l14, l15 = sch.get_loops(block=b0)", + "sch.reorder(l12, l14, l5, l7, l9)", + "b16 = sch.blockize(loop=l5)", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_sync")', + 'sch.annotate(block_or_loop=b16, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_fill")', + 'b17 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b17, ann_key="meta_schedule.tensor_core_enabled", ann_val="1")', + 'b18 = sch.cache_write(block=b16, write_buffer_index=0, storage_scope="local")', + 'b19 = sch.cache_write(block=b16, write_buffer_index=0, storage_scope="wmma.accumulator")', + 'sch.annotate(block_or_loop=b19, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store")', + "l20, l21, l22 = sch.get_loops(block=b16)", + "v23, v24, v25, v26, v27 = sch.sample_perfect_tile(loop=l20, n=5, max_innermost_factor=64)", + "l28, l29, l30, l31, l32 = sch.split(loop=l20, factors=[v23, v24, v25, v26, v27])", + "v33, v34, v35, v36, v37 = sch.sample_perfect_tile(loop=l21, n=5, max_innermost_factor=64)", + "l38, l39, l40, l41, l42 = sch.split(loop=l21, factors=[v33, v34, v35, v36, v37])", + "v43, v44, v45 = sch.sample_perfect_tile(loop=l22, n=3, max_innermost_factor=64)", + "l46, l47, l48 = sch.split(loop=l22, factors=[v43, v44, v45])", + "sch.reorder(l28, l38, l29, l39, l30, l40, l46, l47, l31, l41, l48, l32, l42)", + "l49 = sch.fuse(l28, l38)", + 'sch.bind(loop=l49, thread_axis="blockIdx.x")', + "l50 = sch.fuse(l29, l39)", + 'sch.bind(loop=l50, thread_axis="blockIdx.y")', + "l51 = sch.fuse(l30, l40)", + 'sch.bind(loop=l51, thread_axis="threadIdx.y")', + 'b52 = sch.cache_read(block=b16, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b52, loop=l46, preserve_unit_loops=True)", + "l53, l54, l55, l56, l57, l58 = sch.get_loops(block=b52)", + "l59 = sch.fuse(l57, l58)", + "v60 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b52, ann_key="meta_schedule.cooperative_fetch", ann_val=v60)', + 'b61 = sch.cache_read(block=b16, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b61, loop=l46, preserve_unit_loops=True)", + "l62, l63, l64, l65, l66, l67 = sch.get_loops(block=b61)", + "l68 = sch.fuse(l66, l67)", + "v69 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b61, ann_key="meta_schedule.cooperative_fetch", ann_val=v69)', + 'b70 = sch.cache_read(block=b16, read_buffer_index=1, storage_scope="wmma.matrix_a")', + 'b71 = sch.cache_read(block=b16, read_buffer_index=2, storage_scope="wmma.matrix_b")', + "sch.compute_at(block=b70, loop=l48, preserve_unit_loops=True)", + "sch.compute_at(block=b71, loop=l48, preserve_unit_loops=True)", + 'sch.annotate(block_or_loop=b70, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_a")', + 'sch.annotate(block_or_loop=b71, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_b")', + "sch.reverse_compute_at(block=b19, loop=l51, preserve_unit_loops=True)", + "sch.reverse_compute_at(block=b18, loop=l51, preserve_unit_loops=True)", + ] + ] + target = Target("cuda", host="llvm") + ctx = _create_context( + create_prim_func( + te_workload.matmul_relu_fp16( + n=512, + m=512, + k=512, + ) + ), + target=target, + rule=multi_level_tiling_tensor_core(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + if __name__ == "__main__": test_cpu_matmul() test_cpu_matmul_relu() test_cuda_matmul() test_cuda_matmul_relu() + test_cuda_tensor_core_matmul() + test_cuda_tensor_core_matmul_relu() diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py b/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py index 92c7da922c39..18db006c6ca8 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py @@ -74,7 +74,7 @@ def test_random_compute_location(): [ 'b0 = sch.get_block(name="move", func_name="main")', "l1 = sch.sample_compute_location(block=b0)", - "sch.compute_at(block=b0, loop=l1, preserve_unit_loops=1)", + "sch.compute_at(block=b0, loop=l1, preserve_unit_loops=True)", ] ] mod = Add diff --git a/tests/python/unittest/test_meta_schedule_sketch_cpu.py b/tests/python/unittest/test_meta_schedule_sketch_cpu.py new file mode 100644 index 000000000000..e6dd82294b8d --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_sketch_cpu.py @@ -0,0 +1,795 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +from typing import List + +from tvm.meta_schedule.testing import te_workload +from tvm.meta_schedule.testing.space_generation import check_trace, create_context +from tvm.target import Target +from tvm.te import create_prim_func + + +def _target() -> Target: + return Target("llvm --num-cores=16") + + +def test_meta_schedule_cpu_sketch_matmul(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])", + "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])", + "v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64)", + "l23, l24 = sch.split(loop=l4, factors=[v21, v22])", + "sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=32)', + "v25 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v25)', + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + 'b2 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', + "l3, l4, l5 = sch.get_loops(block=b0)", + "v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l10, l11, l12, l13 = sch.split(loop=l3, factors=[v6, v7, v8, v9])", + "v14, v15, v16, v17 = sch.sample_perfect_tile(loop=l4, n=4, max_innermost_factor=64)", + "l18, l19, l20, l21 = sch.split(loop=l4, factors=[v14, v15, v16, v17])", + "v22, v23 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l24, l25 = sch.split(loop=l5, factors=[v22, v23])", + "sch.reorder(l10, l18, l11, l19, l24, l12, l20, l25, l13, l21)", + "sch.reverse_compute_at(block=b2, loop=l18, preserve_unit_loops=True)", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=32)', + "v26 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v26)', + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + 'b2 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', + "l3, l4, l5 = sch.get_loops(block=b0)", + "v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l10, l11, l12, l13 = sch.split(loop=l3, factors=[v6, v7, v8, v9])", + "v14, v15, v16, v17 = sch.sample_perfect_tile(loop=l4, n=4, max_innermost_factor=64)", + "l18, l19, l20, l21 = sch.split(loop=l4, factors=[v14, v15, v16, v17])", + "v22, v23 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l24, l25 = sch.split(loop=l5, factors=[v22, v23])", + "sch.reorder(l10, l18, l11, l19, l24, l12, l20, l25, l13, l21)", + "sch.reverse_compute_at(block=b2, loop=l19, preserve_unit_loops=True)", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=32)', + "v26 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v26)', + ], + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.matmul( + n=512, + m=512, + k=512, + ) + ), + target=_target(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 3 + check_trace(spaces, expected) + + +def test_meta_schedule_cpu_sketch_matmul_relu(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])", + "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])", + "v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64)", + "l23, l24 = sch.split(loop=l4, factors=[v21, v22])", + "sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=32)', + "v25 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v25)', + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "b2, = sch.get_consumers(block=b0)", + "l3, l4, l5 = sch.get_loops(block=b0)", + "v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l10, l11, l12, l13 = sch.split(loop=l3, factors=[v6, v7, v8, v9])", + "v14, v15, v16, v17 = sch.sample_perfect_tile(loop=l4, n=4, max_innermost_factor=64)", + "l18, l19, l20, l21 = sch.split(loop=l4, factors=[v14, v15, v16, v17])", + "v22, v23 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l24, l25 = sch.split(loop=l5, factors=[v22, v23])", + "sch.reorder(l10, l18, l11, l19, l24, l12, l20, l25, l13, l21)", + "sch.reverse_compute_at(block=b2, loop=l18, preserve_unit_loops=True)", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=32)', + "v26 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v26)', + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "b2, = sch.get_consumers(block=b0)", + "l3, l4, l5 = sch.get_loops(block=b0)", + "v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l10, l11, l12, l13 = sch.split(loop=l3, factors=[v6, v7, v8, v9])", + "v14, v15, v16, v17 = sch.sample_perfect_tile(loop=l4, n=4, max_innermost_factor=64)", + "l18, l19, l20, l21 = sch.split(loop=l4, factors=[v14, v15, v16, v17])", + "v22, v23 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l24, l25 = sch.split(loop=l5, factors=[v22, v23])", + "sch.reorder(l10, l18, l11, l19, l24, l12, l20, l25, l13, l21)", + "sch.reverse_compute_at(block=b2, loop=l19, preserve_unit_loops=True)", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=32)', + "v26 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v26)', + ], + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.matmul_relu( + n=512, + m=512, + k=512, + ) + ), + target=_target(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 3 + check_trace(spaces, expected) + + +def test_meta_schedule_cpu_sketch_conv2d_nchw(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="pad_temp", func_name="main")', + 'b1 = sch.get_block(name="conv2d_nchw", func_name="main")', + 'b2 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "l3, l4, l5, l6, l7, l8, l9 = sch.get_loops(block=b1)", + "v10, v11, v12, v13 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l14, l15, l16, l17 = sch.split(loop=l3, factors=[v10, v11, v12, v13])", + "v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l4, n=4, max_innermost_factor=64)", + "l22, l23, l24, l25 = sch.split(loop=l4, factors=[v18, v19, v20, v21])", + "v26, v27, v28, v29 = sch.sample_perfect_tile(loop=l5, n=4, max_innermost_factor=64)", + "l30, l31, l32, l33 = sch.split(loop=l5, factors=[v26, v27, v28, v29])", + "v34, v35, v36, v37 = sch.sample_perfect_tile(loop=l6, n=4, max_innermost_factor=64)", + "l38, l39, l40, l41 = sch.split(loop=l6, factors=[v34, v35, v36, v37])", + "v42, v43 = sch.sample_perfect_tile(loop=l7, n=2, max_innermost_factor=64)", + "l44, l45 = sch.split(loop=l7, factors=[v42, v43])", + "v46, v47 = sch.sample_perfect_tile(loop=l8, n=2, max_innermost_factor=64)", + "l48, l49 = sch.split(loop=l8, factors=[v46, v47])", + "v50, v51 = sch.sample_perfect_tile(loop=l9, n=2, max_innermost_factor=64)", + "l52, l53 = sch.split(loop=l9, factors=[v50, v51])", + "sch.reorder(l14, l22, l30, l38, l15, l23, l31, l39, l44, l48, l52, l16, l24, l32, l40, l45, l49, l53, l17, l25, l33, l41)", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.vectorize", ann_val=32)', + "v54 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v54)', + "l55 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l55, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="pad_temp", func_name="main")', + 'b1 = sch.get_block(name="conv2d_nchw", func_name="main")', + 'b2 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + 'b3 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="global")', + "l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b1)", + "v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l4, n=4, max_innermost_factor=64)", + "l15, l16, l17, l18 = sch.split(loop=l4, factors=[v11, v12, v13, v14])", + "v19, v20, v21, v22 = sch.sample_perfect_tile(loop=l5, n=4, max_innermost_factor=64)", + "l23, l24, l25, l26 = sch.split(loop=l5, factors=[v19, v20, v21, v22])", + "v27, v28, v29, v30 = sch.sample_perfect_tile(loop=l6, n=4, max_innermost_factor=64)", + "l31, l32, l33, l34 = sch.split(loop=l6, factors=[v27, v28, v29, v30])", + "v35, v36, v37, v38 = sch.sample_perfect_tile(loop=l7, n=4, max_innermost_factor=64)", + "l39, l40, l41, l42 = sch.split(loop=l7, factors=[v35, v36, v37, v38])", + "v43, v44 = sch.sample_perfect_tile(loop=l8, n=2, max_innermost_factor=64)", + "l45, l46 = sch.split(loop=l8, factors=[v43, v44])", + "v47, v48 = sch.sample_perfect_tile(loop=l9, n=2, max_innermost_factor=64)", + "l49, l50 = sch.split(loop=l9, factors=[v47, v48])", + "v51, v52 = sch.sample_perfect_tile(loop=l10, n=2, max_innermost_factor=64)", + "l53, l54 = sch.split(loop=l10, factors=[v51, v52])", + "sch.reorder(l15, l23, l31, l39, l16, l24, l32, l40, l45, l49, l53, l17, l25, l33, l41, l46, l50, l54, l18, l26, l34, l42)", + "sch.reverse_compute_at(block=b3, loop=l39, preserve_unit_loops=True)", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.vectorize", ann_val=32)', + "v55 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v55)', + "l56 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l56, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="pad_temp", func_name="main")', + 'b1 = sch.get_block(name="conv2d_nchw", func_name="main")', + 'b2 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + 'b3 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="global")', + "l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b1)", + "v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l4, n=4, max_innermost_factor=64)", + "l15, l16, l17, l18 = sch.split(loop=l4, factors=[v11, v12, v13, v14])", + "v19, v20, v21, v22 = sch.sample_perfect_tile(loop=l5, n=4, max_innermost_factor=64)", + "l23, l24, l25, l26 = sch.split(loop=l5, factors=[v19, v20, v21, v22])", + "v27, v28, v29, v30 = sch.sample_perfect_tile(loop=l6, n=4, max_innermost_factor=64)", + "l31, l32, l33, l34 = sch.split(loop=l6, factors=[v27, v28, v29, v30])", + "v35, v36, v37, v38 = sch.sample_perfect_tile(loop=l7, n=4, max_innermost_factor=64)", + "l39, l40, l41, l42 = sch.split(loop=l7, factors=[v35, v36, v37, v38])", + "v43, v44 = sch.sample_perfect_tile(loop=l8, n=2, max_innermost_factor=64)", + "l45, l46 = sch.split(loop=l8, factors=[v43, v44])", + "v47, v48 = sch.sample_perfect_tile(loop=l9, n=2, max_innermost_factor=64)", + "l49, l50 = sch.split(loop=l9, factors=[v47, v48])", + "v51, v52 = sch.sample_perfect_tile(loop=l10, n=2, max_innermost_factor=64)", + "l53, l54 = sch.split(loop=l10, factors=[v51, v52])", + "sch.reorder(l15, l23, l31, l39, l16, l24, l32, l40, l45, l49, l53, l17, l25, l33, l41, l46, l50, l54, l18, l26, l34, l42)", + "sch.reverse_compute_at(block=b3, loop=l40, preserve_unit_loops=True)", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.vectorize", ann_val=32)', + "v55 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v55)', + "l56 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l56, preserve_unit_loops=True)", + ], + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.conv2d_nchw( + n=1, + h=56, + w=56, + ci=512, + co=512, + kh=3, + kw=3, + stride=1, + padding=1, + ) + ), + target=_target(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 3 + check_trace(spaces, expected) + + +def test_meta_schedule_cpu_sketch_conv2d_nchw_bias_bn_relu(): # pylint: disable=invalid-name + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="pad_temp", func_name="main")', + 'b1 = sch.get_block(name="conv2d_nchw", func_name="main")', + 'b2 = sch.get_block(name="bias_add", func_name="main")', + 'b3 = sch.get_block(name="bn_mul", func_name="main")', + 'b4 = sch.get_block(name="bn_add", func_name="main")', + 'b5 = sch.get_block(name="root", func_name="main")', + "sch.compute_inline(block=b4)", + "sch.compute_inline(block=b3)", + "sch.compute_inline(block=b2)", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "l6, l7, l8, l9, l10, l11, l12 = sch.get_loops(block=b1)", + "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l6, n=4, max_innermost_factor=64)", + "l17, l18, l19, l20 = sch.split(loop=l6, factors=[v13, v14, v15, v16])", + "v21, v22, v23, v24 = sch.sample_perfect_tile(loop=l7, n=4, max_innermost_factor=64)", + "l25, l26, l27, l28 = sch.split(loop=l7, factors=[v21, v22, v23, v24])", + "v29, v30, v31, v32 = sch.sample_perfect_tile(loop=l8, n=4, max_innermost_factor=64)", + "l33, l34, l35, l36 = sch.split(loop=l8, factors=[v29, v30, v31, v32])", + "v37, v38, v39, v40 = sch.sample_perfect_tile(loop=l9, n=4, max_innermost_factor=64)", + "l41, l42, l43, l44 = sch.split(loop=l9, factors=[v37, v38, v39, v40])", + "v45, v46 = sch.sample_perfect_tile(loop=l10, n=2, max_innermost_factor=64)", + "l47, l48 = sch.split(loop=l10, factors=[v45, v46])", + "v49, v50 = sch.sample_perfect_tile(loop=l11, n=2, max_innermost_factor=64)", + "l51, l52 = sch.split(loop=l11, factors=[v49, v50])", + "v53, v54 = sch.sample_perfect_tile(loop=l12, n=2, max_innermost_factor=64)", + "l55, l56 = sch.split(loop=l12, factors=[v53, v54])", + "sch.reorder(l17, l25, l33, l41, l18, l26, l34, l42, l47, l51, l55, l19, l27, l35, l43, l48, l52, l56, l20, l28, l36, l44)", + 'sch.annotate(block_or_loop=b5, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b5, ann_key="meta_schedule.vectorize", ann_val=32)', + "v57 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b5, ann_key="meta_schedule.unroll_explicit", ann_val=v57)', + "l58 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l58, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="pad_temp", func_name="main")', + 'b1 = sch.get_block(name="conv2d_nchw", func_name="main")', + 'b2 = sch.get_block(name="bias_add", func_name="main")', + 'b3 = sch.get_block(name="bn_mul", func_name="main")', + 'b4 = sch.get_block(name="bn_add", func_name="main")', + 'b5 = sch.get_block(name="root", func_name="main")', + "sch.compute_inline(block=b4)", + "sch.compute_inline(block=b3)", + "sch.compute_inline(block=b2)", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "b6, = sch.get_consumers(block=b1)", + "l7, l8, l9, l10, l11, l12, l13 = sch.get_loops(block=b1)", + "v14, v15, v16, v17 = sch.sample_perfect_tile(loop=l7, n=4, max_innermost_factor=64)", + "l18, l19, l20, l21 = sch.split(loop=l7, factors=[v14, v15, v16, v17])", + "v22, v23, v24, v25 = sch.sample_perfect_tile(loop=l8, n=4, max_innermost_factor=64)", + "l26, l27, l28, l29 = sch.split(loop=l8, factors=[v22, v23, v24, v25])", + "v30, v31, v32, v33 = sch.sample_perfect_tile(loop=l9, n=4, max_innermost_factor=64)", + "l34, l35, l36, l37 = sch.split(loop=l9, factors=[v30, v31, v32, v33])", + "v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l10, n=4, max_innermost_factor=64)", + "l42, l43, l44, l45 = sch.split(loop=l10, factors=[v38, v39, v40, v41])", + "v46, v47 = sch.sample_perfect_tile(loop=l11, n=2, max_innermost_factor=64)", + "l48, l49 = sch.split(loop=l11, factors=[v46, v47])", + "v50, v51 = sch.sample_perfect_tile(loop=l12, n=2, max_innermost_factor=64)", + "l52, l53 = sch.split(loop=l12, factors=[v50, v51])", + "v54, v55 = sch.sample_perfect_tile(loop=l13, n=2, max_innermost_factor=64)", + "l56, l57 = sch.split(loop=l13, factors=[v54, v55])", + "sch.reorder(l18, l26, l34, l42, l19, l27, l35, l43, l48, l52, l56, l20, l28, l36, l44, l49, l53, l57, l21, l29, l37, l45)", + "sch.reverse_compute_at(block=b6, loop=l42, preserve_unit_loops=True)", + 'sch.annotate(block_or_loop=b5, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b5, ann_key="meta_schedule.vectorize", ann_val=32)', + "v58 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b5, ann_key="meta_schedule.unroll_explicit", ann_val=v58)', + "l59 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l59, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="pad_temp", func_name="main")', + 'b1 = sch.get_block(name="conv2d_nchw", func_name="main")', + 'b2 = sch.get_block(name="bias_add", func_name="main")', + 'b3 = sch.get_block(name="bn_mul", func_name="main")', + 'b4 = sch.get_block(name="bn_add", func_name="main")', + 'b5 = sch.get_block(name="root", func_name="main")', + "sch.compute_inline(block=b4)", + "sch.compute_inline(block=b3)", + "sch.compute_inline(block=b2)", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "b6, = sch.get_consumers(block=b1)", + "l7, l8, l9, l10, l11, l12, l13 = sch.get_loops(block=b1)", + "v14, v15, v16, v17 = sch.sample_perfect_tile(loop=l7, n=4, max_innermost_factor=64)", + "l18, l19, l20, l21 = sch.split(loop=l7, factors=[v14, v15, v16, v17])", + "v22, v23, v24, v25 = sch.sample_perfect_tile(loop=l8, n=4, max_innermost_factor=64)", + "l26, l27, l28, l29 = sch.split(loop=l8, factors=[v22, v23, v24, v25])", + "v30, v31, v32, v33 = sch.sample_perfect_tile(loop=l9, n=4, max_innermost_factor=64)", + "l34, l35, l36, l37 = sch.split(loop=l9, factors=[v30, v31, v32, v33])", + "v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l10, n=4, max_innermost_factor=64)", + "l42, l43, l44, l45 = sch.split(loop=l10, factors=[v38, v39, v40, v41])", + "v46, v47 = sch.sample_perfect_tile(loop=l11, n=2, max_innermost_factor=64)", + "l48, l49 = sch.split(loop=l11, factors=[v46, v47])", + "v50, v51 = sch.sample_perfect_tile(loop=l12, n=2, max_innermost_factor=64)", + "l52, l53 = sch.split(loop=l12, factors=[v50, v51])", + "v54, v55 = sch.sample_perfect_tile(loop=l13, n=2, max_innermost_factor=64)", + "l56, l57 = sch.split(loop=l13, factors=[v54, v55])", + "sch.reorder(l18, l26, l34, l42, l19, l27, l35, l43, l48, l52, l56, l20, l28, l36, l44, l49, l53, l57, l21, l29, l37, l45)", + "sch.reverse_compute_at(block=b6, loop=l43, preserve_unit_loops=True)", + 'sch.annotate(block_or_loop=b5, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b5, ann_key="meta_schedule.vectorize", ann_val=32)', + "v58 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b5, ann_key="meta_schedule.unroll_explicit", ann_val=v58)', + "l59 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l59, preserve_unit_loops=True)", + ], + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.conv2d_nchw_bias_bn_relu( + n=1, + h=56, + w=56, + ci=512, + co=512, + kh=3, + kw=3, + stride=1, + padding=1, + ) + ), + target=_target(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 3 + check_trace(spaces, expected) + + +def test_meta_schedule_sketch_cpu_max_pool2d_nchw(): # pylint: disable=invalid-name + # pylint: disable=line-too-long + expected: List[List[str]] = [ + [ + 'b0 = sch.get_block(name="pad_temp", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=32)', + "v2 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v2)', + "l3 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l3, preserve_unit_loops=True)", + ], + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.max_pool2d_nchw( + n=1, + h=56, + w=56, + ci=512, + padding=1, + ) + ), + target=_target(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +def test_meta_schedule_cpu_sketch_batchnorm(): # pylint: disable=invalid-name + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + "l2, l3, l4 = sch.get_loops(block=b0)", + "l5 = sch.fuse(l3, l4)", + "v6, v7 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l8, l9 = sch.split(loop=l5, factors=[v6, v7])", + "b10 = sch.rfactor(loop=l8, factor_axis=1)", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=32)', + "v11 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v11)', + "b12, = sch.get_producers(block=b0)", + 'sch.unannotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer")', + "l13 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l13, preserve_unit_loops=True)", + "l14 = sch.sample_compute_location(block=b12)", + "sch.compute_at(block=b12, loop=l14, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + "l2, l3, l4 = sch.get_loops(block=b0)", + "l5 = sch.fuse(l3, l4)", + "v6, v7 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l8, l9 = sch.split(loop=l5, factors=[v6, v7])", + "b10 = sch.rfactor(loop=l9, factor_axis=1)", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=32)', + "v11 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v11)', + "b12, = sch.get_producers(block=b0)", + 'sch.unannotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer")', + "l13 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l13, preserve_unit_loops=True)", + "l14 = sch.sample_compute_location(block=b12)", + "sch.compute_at(block=b12, loop=l14, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=32)', + "v2 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v2)', + "l3 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l3, preserve_unit_loops=True)", + ], + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func(te_workload.norm_bmn(B=1, M=256, N=256)), + target=_target(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 3 + check_trace(spaces, expected) + + +def test_meta_schedule_cpu_sketch_softmax(): # pylint: disable=invalid-name + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', + 'b1 = sch.get_block(name="T_softmax_exp", func_name="main")', + 'b2 = sch.get_block(name="T_softmax_expsum", func_name="main")', + 'b3 = sch.get_block(name="root", func_name="main")', + "l4, l5 = sch.get_loops(block=b2)", + "v6, v7 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l8, l9 = sch.split(loop=l5, factors=[v6, v7])", + "b10 = sch.rfactor(loop=l8, factor_axis=1)", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + "l11, l12 = sch.get_loops(block=b0)", + "v13, v14 = sch.sample_perfect_tile(loop=l12, n=2, max_innermost_factor=64)", + "l15, l16 = sch.split(loop=l12, factors=[v13, v14])", + "b17 = sch.rfactor(loop=l15, factor_axis=1)", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.vectorize", ann_val=32)', + "v18 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v18)', + "b19, = sch.get_producers(block=b2)", + 'sch.unannotate(block_or_loop=b2, ann_key="meta_schedule.random_compute_producer")', + "l20 = sch.sample_compute_location(block=b2)", + "sch.compute_at(block=b2, loop=l20, preserve_unit_loops=True)", + "l21 = sch.sample_compute_location(block=b19)", + "sch.compute_at(block=b19, loop=l21, preserve_unit_loops=True)", + "l22 = sch.sample_compute_location(block=b1)", + "sch.compute_at(block=b1, loop=l22, preserve_unit_loops=True)", + "b23, = sch.get_producers(block=b0)", + 'sch.unannotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer")', + "l24 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l24, preserve_unit_loops=True)", + "l25 = sch.sample_compute_location(block=b23)", + "sch.compute_at(block=b23, loop=l25, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', + 'b1 = sch.get_block(name="T_softmax_exp", func_name="main")', + 'b2 = sch.get_block(name="T_softmax_expsum", func_name="main")', + 'b3 = sch.get_block(name="root", func_name="main")', + "l4, l5 = sch.get_loops(block=b2)", + "v6, v7 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l8, l9 = sch.split(loop=l5, factors=[v6, v7])", + "b10 = sch.rfactor(loop=l8, factor_axis=1)", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + "l11, l12 = sch.get_loops(block=b0)", + "v13, v14 = sch.sample_perfect_tile(loop=l12, n=2, max_innermost_factor=64)", + "l15, l16 = sch.split(loop=l12, factors=[v13, v14])", + "b17 = sch.rfactor(loop=l16, factor_axis=1)", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.vectorize", ann_val=32)', + "v18 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v18)', + "b19, = sch.get_producers(block=b2)", + 'sch.unannotate(block_or_loop=b2, ann_key="meta_schedule.random_compute_producer")', + "l20 = sch.sample_compute_location(block=b2)", + "sch.compute_at(block=b2, loop=l20, preserve_unit_loops=True)", + "l21 = sch.sample_compute_location(block=b19)", + "sch.compute_at(block=b19, loop=l21, preserve_unit_loops=True)", + "l22 = sch.sample_compute_location(block=b1)", + "sch.compute_at(block=b1, loop=l22, preserve_unit_loops=True)", + "b23, = sch.get_producers(block=b0)", + 'sch.unannotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer")', + "l24 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l24, preserve_unit_loops=True)", + "l25 = sch.sample_compute_location(block=b23)", + "sch.compute_at(block=b23, loop=l25, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', + 'b1 = sch.get_block(name="T_softmax_exp", func_name="main")', + 'b2 = sch.get_block(name="T_softmax_expsum", func_name="main")', + 'b3 = sch.get_block(name="root", func_name="main")', + "l4, l5 = sch.get_loops(block=b2)", + "v6, v7 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l8, l9 = sch.split(loop=l5, factors=[v6, v7])", + "b10 = sch.rfactor(loop=l8, factor_axis=1)", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.vectorize", ann_val=32)', + "v11 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v11)', + "b12, = sch.get_producers(block=b2)", + 'sch.unannotate(block_or_loop=b2, ann_key="meta_schedule.random_compute_producer")', + "l13 = sch.sample_compute_location(block=b2)", + "sch.compute_at(block=b2, loop=l13, preserve_unit_loops=True)", + "l14 = sch.sample_compute_location(block=b12)", + "sch.compute_at(block=b12, loop=l14, preserve_unit_loops=True)", + "l15 = sch.sample_compute_location(block=b1)", + "sch.compute_at(block=b1, loop=l15, preserve_unit_loops=True)", + "l16 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l16, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', + 'b1 = sch.get_block(name="T_softmax_exp", func_name="main")', + 'b2 = sch.get_block(name="T_softmax_expsum", func_name="main")', + 'b3 = sch.get_block(name="root", func_name="main")', + "l4, l5 = sch.get_loops(block=b2)", + "v6, v7 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l8, l9 = sch.split(loop=l5, factors=[v6, v7])", + "b10 = sch.rfactor(loop=l9, factor_axis=1)", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + "l11, l12 = sch.get_loops(block=b0)", + "v13, v14 = sch.sample_perfect_tile(loop=l12, n=2, max_innermost_factor=64)", + "l15, l16 = sch.split(loop=l12, factors=[v13, v14])", + "b17 = sch.rfactor(loop=l15, factor_axis=1)", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.vectorize", ann_val=32)', + "v18 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v18)', + "b19, = sch.get_producers(block=b2)", + 'sch.unannotate(block_or_loop=b2, ann_key="meta_schedule.random_compute_producer")', + "l20 = sch.sample_compute_location(block=b2)", + "sch.compute_at(block=b2, loop=l20, preserve_unit_loops=True)", + "l21 = sch.sample_compute_location(block=b19)", + "sch.compute_at(block=b19, loop=l21, preserve_unit_loops=True)", + "l22 = sch.sample_compute_location(block=b1)", + "sch.compute_at(block=b1, loop=l22, preserve_unit_loops=True)", + "b23, = sch.get_producers(block=b0)", + 'sch.unannotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer")', + "l24 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l24, preserve_unit_loops=True)", + "l25 = sch.sample_compute_location(block=b23)", + "sch.compute_at(block=b23, loop=l25, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', + 'b1 = sch.get_block(name="T_softmax_exp", func_name="main")', + 'b2 = sch.get_block(name="T_softmax_expsum", func_name="main")', + 'b3 = sch.get_block(name="root", func_name="main")', + "l4, l5 = sch.get_loops(block=b2)", + "v6, v7 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l8, l9 = sch.split(loop=l5, factors=[v6, v7])", + "b10 = sch.rfactor(loop=l9, factor_axis=1)", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + "l11, l12 = sch.get_loops(block=b0)", + "v13, v14 = sch.sample_perfect_tile(loop=l12, n=2, max_innermost_factor=64)", + "l15, l16 = sch.split(loop=l12, factors=[v13, v14])", + "b17 = sch.rfactor(loop=l16, factor_axis=1)", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.vectorize", ann_val=32)', + "v18 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v18)', + "b19, = sch.get_producers(block=b2)", + 'sch.unannotate(block_or_loop=b2, ann_key="meta_schedule.random_compute_producer")', + "l20 = sch.sample_compute_location(block=b2)", + "sch.compute_at(block=b2, loop=l20, preserve_unit_loops=True)", + "l21 = sch.sample_compute_location(block=b19)", + "sch.compute_at(block=b19, loop=l21, preserve_unit_loops=True)", + "l22 = sch.sample_compute_location(block=b1)", + "sch.compute_at(block=b1, loop=l22, preserve_unit_loops=True)", + "b23, = sch.get_producers(block=b0)", + 'sch.unannotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer")', + "l24 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l24, preserve_unit_loops=True)", + "l25 = sch.sample_compute_location(block=b23)", + "sch.compute_at(block=b23, loop=l25, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', + 'b1 = sch.get_block(name="T_softmax_exp", func_name="main")', + 'b2 = sch.get_block(name="T_softmax_expsum", func_name="main")', + 'b3 = sch.get_block(name="root", func_name="main")', + "l4, l5 = sch.get_loops(block=b2)", + "v6, v7 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l8, l9 = sch.split(loop=l5, factors=[v6, v7])", + "b10 = sch.rfactor(loop=l9, factor_axis=1)", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.vectorize", ann_val=32)', + "v11 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v11)', + "b12, = sch.get_producers(block=b2)", + 'sch.unannotate(block_or_loop=b2, ann_key="meta_schedule.random_compute_producer")', + "l13 = sch.sample_compute_location(block=b2)", + "sch.compute_at(block=b2, loop=l13, preserve_unit_loops=True)", + "l14 = sch.sample_compute_location(block=b12)", + "sch.compute_at(block=b12, loop=l14, preserve_unit_loops=True)", + "l15 = sch.sample_compute_location(block=b1)", + "sch.compute_at(block=b1, loop=l15, preserve_unit_loops=True)", + "l16 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l16, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', + 'b1 = sch.get_block(name="T_softmax_exp", func_name="main")', + 'b2 = sch.get_block(name="T_softmax_expsum", func_name="main")', + 'b3 = sch.get_block(name="root", func_name="main")', + "l4, l5 = sch.get_loops(block=b0)", + "v6, v7 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l8, l9 = sch.split(loop=l5, factors=[v6, v7])", + "b10 = sch.rfactor(loop=l8, factor_axis=1)", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.vectorize", ann_val=32)', + "v11 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v11)', + "l12 = sch.sample_compute_location(block=b2)", + "sch.compute_at(block=b2, loop=l12, preserve_unit_loops=True)", + "l13 = sch.sample_compute_location(block=b1)", + "sch.compute_at(block=b1, loop=l13, preserve_unit_loops=True)", + "b14, = sch.get_producers(block=b0)", + 'sch.unannotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer")', + "l15 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l15, preserve_unit_loops=True)", + "l16 = sch.sample_compute_location(block=b14)", + "sch.compute_at(block=b14, loop=l16, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', + 'b1 = sch.get_block(name="T_softmax_exp", func_name="main")', + 'b2 = sch.get_block(name="T_softmax_expsum", func_name="main")', + 'b3 = sch.get_block(name="root", func_name="main")', + "l4, l5 = sch.get_loops(block=b0)", + "v6, v7 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l8, l9 = sch.split(loop=l5, factors=[v6, v7])", + "b10 = sch.rfactor(loop=l9, factor_axis=1)", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.vectorize", ann_val=32)', + "v11 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v11)', + "l12 = sch.sample_compute_location(block=b2)", + "sch.compute_at(block=b2, loop=l12, preserve_unit_loops=True)", + "l13 = sch.sample_compute_location(block=b1)", + "sch.compute_at(block=b1, loop=l13, preserve_unit_loops=True)", + "b14, = sch.get_producers(block=b0)", + 'sch.unannotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer")', + "l15 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l15, preserve_unit_loops=True)", + "l16 = sch.sample_compute_location(block=b14)", + "sch.compute_at(block=b14, loop=l16, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', + 'b1 = sch.get_block(name="T_softmax_exp", func_name="main")', + 'b2 = sch.get_block(name="T_softmax_expsum", func_name="main")', + 'b3 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.vectorize", ann_val=32)', + "v4 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v4)', + "l5 = sch.sample_compute_location(block=b2)", + "sch.compute_at(block=b2, loop=l5, preserve_unit_loops=True)", + "l6 = sch.sample_compute_location(block=b1)", + "sch.compute_at(block=b1, loop=l6, preserve_unit_loops=True)", + "l7 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l7, preserve_unit_loops=True)", + ], + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func(te_workload.softmax_mn(m=256, n=256)), + target=_target(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 9 + check_trace(spaces, expected) + + +if __name__ == "__main__": + test_meta_schedule_cpu_sketch_matmul() + test_meta_schedule_cpu_sketch_matmul_relu() + test_meta_schedule_cpu_sketch_conv2d_nchw() + test_meta_schedule_cpu_sketch_conv2d_nchw_bias_bn_relu() + test_meta_schedule_sketch_cpu_max_pool2d_nchw() + test_meta_schedule_cpu_sketch_batchnorm() + test_meta_schedule_cpu_sketch_softmax() diff --git a/tests/python/unittest/test_meta_schedule_sketch_cuda.py b/tests/python/unittest/test_meta_schedule_sketch_cuda.py new file mode 100644 index 000000000000..ab118d34f94e --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_sketch_cuda.py @@ -0,0 +1,426 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +from tvm.meta_schedule.testing import te_workload +from tvm.meta_schedule.testing.space_generation import check_trace, create_context +from tvm.target import Target +from tvm.te import create_prim_func + + +def _target() -> Target: + return Target("cuda", host="llvm") + + +def _target_with_max_threads_per_block() -> Target: + return Target("nvidia/geforce-rtx-3080") + + +def test_meta_schedule_cuda_sketch_matmul(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")', + 'b2 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', + "l3, l4, l5 = sch.get_loops(block=b0)", + "v6, v7, v8, v9, v10 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64)", + "l11, l12, l13, l14, l15 = sch.split(loop=l3, factors=[v6, v7, v8, v9, v10])", + "v16, v17, v18, v19, v20 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64)", + "l21, l22, l23, l24, l25 = sch.split(loop=l4, factors=[v16, v17, v18, v19, v20])", + "v26, v27, v28 = sch.sample_perfect_tile(loop=l5, n=3, max_innermost_factor=64)", + "l29, l30, l31 = sch.split(loop=l5, factors=[v26, v27, v28])", + "sch.reorder(l11, l21, l12, l22, l13, l23, l29, l30, l14, l24, l31, l15, l25)", + "l32 = sch.fuse(l11, l21)", + 'sch.bind(loop=l32, thread_axis="blockIdx.x")', + "l33 = sch.fuse(l12, l22)", + 'sch.bind(loop=l33, thread_axis="vthread.x")', + "l34 = sch.fuse(l13, l23)", + 'sch.bind(loop=l34, thread_axis="threadIdx.x")', + 'b35 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b35, loop=l29, preserve_unit_loops=True)", + "l36, l37, l38, l39, l40, l41 = sch.get_loops(block=b35)", + "l42 = sch.fuse(l40, l41)", + "v43 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b35, ann_key="meta_schedule.cooperative_fetch", ann_val=v43)', + 'b44 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b44, loop=l29, preserve_unit_loops=True)", + "l45, l46, l47, l48, l49, l50 = sch.get_loops(block=b44)", + "l51 = sch.fuse(l49, l50)", + "v52 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v52)', + "sch.reverse_compute_at(block=b2, loop=l34, preserve_unit_loops=True)", + "v53 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v53)', + ] + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.matmul( + n=512, + m=512, + k=512, + ) + ), + target=_target(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +def test_meta_schedule_cuda_sketch_matmul_relu(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="compute", func_name="main")', + 'b2 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")', + 'b3 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', + "l4, l5, l6 = sch.get_loops(block=b0)", + "v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64)", + "l12, l13, l14, l15, l16 = sch.split(loop=l4, factors=[v7, v8, v9, v10, v11])", + "v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64)", + "l22, l23, l24, l25, l26 = sch.split(loop=l5, factors=[v17, v18, v19, v20, v21])", + "v27, v28, v29 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64)", + "l30, l31, l32 = sch.split(loop=l6, factors=[v27, v28, v29])", + "sch.reorder(l12, l22, l13, l23, l14, l24, l30, l31, l15, l25, l32, l16, l26)", + "l33 = sch.fuse(l12, l22)", + 'sch.bind(loop=l33, thread_axis="blockIdx.x")', + "l34 = sch.fuse(l13, l23)", + 'sch.bind(loop=l34, thread_axis="vthread.x")', + "l35 = sch.fuse(l14, l24)", + 'sch.bind(loop=l35, thread_axis="threadIdx.x")', + 'b36 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b36, loop=l30, preserve_unit_loops=True)", + "l37, l38, l39, l40, l41, l42 = sch.get_loops(block=b36)", + "l43 = sch.fuse(l41, l42)", + "v44 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b36, ann_key="meta_schedule.cooperative_fetch", ann_val=v44)', + 'b45 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b45, loop=l30, preserve_unit_loops=True)", + "l46, l47, l48, l49, l50, l51 = sch.get_loops(block=b45)", + "l52 = sch.fuse(l50, l51)", + "v53 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b45, ann_key="meta_schedule.cooperative_fetch", ann_val=v53)', + "sch.reverse_compute_at(block=b3, loop=l35, preserve_unit_loops=True)", + "sch.reverse_compute_inline(block=b1)", + "v54 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v54)', + ] + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.matmul_relu( + n=512, + m=512, + k=512, + ) + ), + target=_target(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +def test_meta_schedule_cuda_sketch_conv2d_nchw(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="pad_temp", func_name="main")', + 'b1 = sch.get_block(name="conv2d_nchw", func_name="main")', + 'b2 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")', + 'b3 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local")', + "l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b1)", + "v11, v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64)", + "l16, l17, l18, l19, l20 = sch.split(loop=l4, factors=[v11, v12, v13, v14, v15])", + "v21, v22, v23, v24, v25 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64)", + "l26, l27, l28, l29, l30 = sch.split(loop=l5, factors=[v21, v22, v23, v24, v25])", + "v31, v32, v33, v34, v35 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64)", + "l36, l37, l38, l39, l40 = sch.split(loop=l6, factors=[v31, v32, v33, v34, v35])", + "v41, v42, v43, v44, v45 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64)", + "l46, l47, l48, l49, l50 = sch.split(loop=l7, factors=[v41, v42, v43, v44, v45])", + "v51, v52, v53 = sch.sample_perfect_tile(loop=l8, n=3, max_innermost_factor=64)", + "l54, l55, l56 = sch.split(loop=l8, factors=[v51, v52, v53])", + "v57, v58, v59 = sch.sample_perfect_tile(loop=l9, n=3, max_innermost_factor=64)", + "l60, l61, l62 = sch.split(loop=l9, factors=[v57, v58, v59])", + "v63, v64, v65 = sch.sample_perfect_tile(loop=l10, n=3, max_innermost_factor=64)", + "l66, l67, l68 = sch.split(loop=l10, factors=[v63, v64, v65])", + "sch.reorder(l16, l26, l36, l46, l17, l27, l37, l47, l18, l28, l38, l48, l54, l60, l66, l55, l61, l67, l19, l29, l39, l49, l56, l62, l68, l20, l30, l40, l50)", + "l69 = sch.fuse(l16, l26, l36, l46)", + 'sch.bind(loop=l69, thread_axis="blockIdx.x")', + "l70 = sch.fuse(l17, l27, l37, l47)", + 'sch.bind(loop=l70, thread_axis="vthread.x")', + "l71 = sch.fuse(l18, l28, l38, l48)", + 'sch.bind(loop=l71, thread_axis="threadIdx.x")', + 'b72 = sch.cache_read(block=b1, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b72, loop=l66, preserve_unit_loops=True)", + "l73, l74, l75, l76, l77, l78, l79, l80, l81, l82 = sch.get_loops(block=b72)", + "l83 = sch.fuse(l79, l80, l81, l82)", + "v84 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b72, ann_key="meta_schedule.cooperative_fetch", ann_val=v84)', + 'b85 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b85, loop=l66, preserve_unit_loops=True)", + "l86, l87, l88, l89, l90, l91, l92, l93, l94, l95 = sch.get_loops(block=b85)", + "l96 = sch.fuse(l92, l93, l94, l95)", + "v97 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b85, ann_key="meta_schedule.cooperative_fetch", ann_val=v97)', + "sch.reverse_compute_at(block=b3, loop=l71, preserve_unit_loops=True)", + "sch.compute_inline(block=b0)", + "v98 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v98)', + ] + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.conv2d_nchw( + n=1, + h=56, + w=56, + ci=512, + co=512, + kh=3, + kw=3, + stride=1, + padding=1, + ) + ), + target=_target(), + ) + + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +def test_meta_schedule_cuda_sketch_conv2d_nchw_bias_bn_relu(): # pylint: disable=invalid-name + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="pad_temp", func_name="main")', + 'b1 = sch.get_block(name="conv2d_nchw", func_name="main")', + 'b2 = sch.get_block(name="bias_add", func_name="main")', + 'b3 = sch.get_block(name="bn_mul", func_name="main")', + 'b4 = sch.get_block(name="bn_add", func_name="main")', + 'b5 = sch.get_block(name="compute", func_name="main")', + 'b6 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")', + 'b7 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local")', + "l8, l9, l10, l11, l12, l13, l14 = sch.get_loops(block=b1)", + "v15, v16, v17, v18, v19 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64)", + "l20, l21, l22, l23, l24 = sch.split(loop=l8, factors=[v15, v16, v17, v18, v19])", + "v25, v26, v27, v28, v29 = sch.sample_perfect_tile(loop=l9, n=5, max_innermost_factor=64)", + "l30, l31, l32, l33, l34 = sch.split(loop=l9, factors=[v25, v26, v27, v28, v29])", + "v35, v36, v37, v38, v39 = sch.sample_perfect_tile(loop=l10, n=5, max_innermost_factor=64)", + "l40, l41, l42, l43, l44 = sch.split(loop=l10, factors=[v35, v36, v37, v38, v39])", + "v45, v46, v47, v48, v49 = sch.sample_perfect_tile(loop=l11, n=5, max_innermost_factor=64)", + "l50, l51, l52, l53, l54 = sch.split(loop=l11, factors=[v45, v46, v47, v48, v49])", + "v55, v56, v57 = sch.sample_perfect_tile(loop=l12, n=3, max_innermost_factor=64)", + "l58, l59, l60 = sch.split(loop=l12, factors=[v55, v56, v57])", + "v61, v62, v63 = sch.sample_perfect_tile(loop=l13, n=3, max_innermost_factor=64)", + "l64, l65, l66 = sch.split(loop=l13, factors=[v61, v62, v63])", + "v67, v68, v69 = sch.sample_perfect_tile(loop=l14, n=3, max_innermost_factor=64)", + "l70, l71, l72 = sch.split(loop=l14, factors=[v67, v68, v69])", + "sch.reorder(l20, l30, l40, l50, l21, l31, l41, l51, l22, l32, l42, l52, l58, l64, l70, l59, l65, l71, l23, l33, l43, l53, l60, l66, l72, l24, l34, l44, l54)", + "l73 = sch.fuse(l20, l30, l40, l50)", + 'sch.bind(loop=l73, thread_axis="blockIdx.x")', + "l74 = sch.fuse(l21, l31, l41, l51)", + 'sch.bind(loop=l74, thread_axis="vthread.x")', + "l75 = sch.fuse(l22, l32, l42, l52)", + 'sch.bind(loop=l75, thread_axis="threadIdx.x")', + 'b76 = sch.cache_read(block=b1, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b76, loop=l70, preserve_unit_loops=True)", + "l77, l78, l79, l80, l81, l82, l83, l84, l85, l86 = sch.get_loops(block=b76)", + "l87 = sch.fuse(l83, l84, l85, l86)", + "v88 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b76, ann_key="meta_schedule.cooperative_fetch", ann_val=v88)', + 'b89 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b89, loop=l70, preserve_unit_loops=True)", + "l90, l91, l92, l93, l94, l95, l96, l97, l98, l99 = sch.get_loops(block=b89)", + "l100 = sch.fuse(l96, l97, l98, l99)", + "v101 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b89, ann_key="meta_schedule.cooperative_fetch", ann_val=v101)', + "sch.reverse_compute_at(block=b7, loop=l75, preserve_unit_loops=True)", + "sch.reverse_compute_inline(block=b5)", + "sch.reverse_compute_inline(block=b4)", + "sch.reverse_compute_inline(block=b3)", + "sch.reverse_compute_inline(block=b2)", + "sch.compute_inline(block=b0)", + "v102 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b6, ann_key="meta_schedule.unroll_explicit", ann_val=v102)', + ] + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.conv2d_nchw_bias_bn_relu( + n=1, + h=56, + w=56, + ci=512, + co=512, + kh=3, + kw=3, + stride=1, + padding=1, + ) + ), + target=_target(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +def test_meta_schedule_cuda_sketch_batchnorm(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + "b2, = sch.get_consumers(block=b0)", + "l3, = sch.get_loops(block=b2)", + "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", + "l5, l6 = sch.split(loop=l3, factors=[None, v4])", + 'sch.bind(loop=l6, thread_axis="threadIdx.x")', + "sch.compute_at(block=b0, loop=l5, preserve_unit_loops=True)", + 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', + "l7, l8, l9, l10 = sch.get_loops(block=b0)", + "l11 = sch.fuse(l9, l10)", + "l12, l13 = sch.split(loop=l11, factors=[None, v4])", + 'sch.bind(loop=l13, thread_axis="threadIdx.x")', + "v14 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v14)', + ], + [ + 'b0 = sch.get_block(name="root", func_name="main")', + "v1 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.unroll_explicit", ann_val=v1)', + ], + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.norm_bmn( + B=1, + M=256, + N=256, + ) + ), + target=_target_with_max_threads_per_block(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 2 + check_trace(spaces, expected) + + +def test_meta_schedule_cuda_sketch_softmax(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', + 'b1 = sch.get_block(name="T_softmax_exp", func_name="main")', + 'b2 = sch.get_block(name="T_softmax_expsum", func_name="main")', + 'b3 = sch.get_block(name="root", func_name="main")', + "sch.compute_inline(block=b1)", + "b4, = sch.get_consumers(block=b2)", + "l5, l6 = sch.get_loops(block=b4)", + "v7 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", + "l8, l9 = sch.split(loop=l6, factors=[None, v7])", + 'sch.bind(loop=l9, thread_axis="threadIdx.x")', + "sch.compute_at(block=b2, loop=l5, preserve_unit_loops=True)", + 'sch.set_scope(block=b2, buffer_index=0, storage_scope="shared")', + "l10, l11, l12 = sch.get_loops(block=b2)", + "l13, l14 = sch.split(loop=l12, factors=[None, v7])", + 'sch.bind(loop=l14, thread_axis="threadIdx.x")', + "b15, b16 = sch.get_consumers(block=b0)", + "l17, l18, l19, l20 = sch.get_loops(block=b15)", + "sch.compute_at(block=b0, loop=l17, preserve_unit_loops=True)", + 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', + "l21, l22, l23 = sch.get_loops(block=b0)", + "l24, l25 = sch.split(loop=l23, factors=[None, v7])", + 'sch.bind(loop=l25, thread_axis="threadIdx.x")', + "v26 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v26)', + ], + [ + 'b0 = sch.get_block(name="T_softmax_exp", func_name="main")', + 'b1 = sch.get_block(name="T_softmax_expsum", func_name="main")', + 'b2 = sch.get_block(name="root", func_name="main")', + "sch.compute_inline(block=b0)", + "b3, = sch.get_consumers(block=b1)", + "l4, l5 = sch.get_loops(block=b3)", + "v6 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", + "l7, l8 = sch.split(loop=l5, factors=[None, v6])", + 'sch.bind(loop=l8, thread_axis="threadIdx.x")', + "sch.compute_at(block=b1, loop=l4, preserve_unit_loops=True)", + 'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")', + "l9, l10, l11 = sch.get_loops(block=b1)", + "l12, l13 = sch.split(loop=l11, factors=[None, v6])", + 'sch.bind(loop=l13, thread_axis="threadIdx.x")', + "v14 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v14)', + ], + [ + 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', + 'b1 = sch.get_block(name="T_softmax_exp", func_name="main")', + 'b2 = sch.get_block(name="root", func_name="main")', + "sch.compute_inline(block=b1)", + "v3 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", + "l4, l5 = sch.get_loops(block=b0)", + "l6, l7 = sch.split(loop=l5, factors=[None, v3])", + 'sch.bind(loop=l7, thread_axis="threadIdx.x")', + "v8 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v8)', + ], + [ + 'b0 = sch.get_block(name="T_softmax_exp", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + "sch.compute_inline(block=b0)", + "v2 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v2)', + ], + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.softmax_mn( + m=256, + n=256, + ) + ), + target=_target_with_max_threads_per_block(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 4 + check_trace(spaces, expected) + + +if __name__ == "__main__": + test_meta_schedule_cuda_sketch_matmul() + test_meta_schedule_cuda_sketch_matmul_relu() + test_meta_schedule_cuda_sketch_conv2d_nchw() + test_meta_schedule_cuda_sketch_conv2d_nchw_bias_bn_relu() + test_meta_schedule_cuda_sketch_batchnorm() + test_meta_schedule_cuda_sketch_softmax() diff --git a/tests/python/unittest/test_meta_schedule_task_extraction.py b/tests/python/unittest/test_meta_schedule_task_extraction.py new file mode 100644 index 000000000000..8523275f5186 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_task_extraction.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import sys +from typing import Tuple + +import pytest + +import tvm +from tvm import meta_schedule as ms +from tvm.meta_schedule.testing import MODEL_TYPE, MODEL_TYPES, get_torch_model + + +@pytest.mark.skip("Skip because it runs too slowly as a unittest") +@pytest.mark.parametrize( + "model_name", + [ + # Image classification + "resnet50", + "alexnet", + "vgg16", + "squeezenet1_0", + "densenet121", + "densenet161", + "densenet169", + "densenet201", + "inception_v3", + "googlenet", + "shufflenet_v2_x1_0", + "mobilenet_v2", + "mobilenet_v3_large", + "mobilenet_v3_small", + "resnext50_32x4d", + "wide_resnet50_2", + "mnasnet1_0", + # Segmentation + "fcn_resnet50", + "fcn_resnet101", + "deeplabv3_resnet50", + "deeplabv3_resnet101", + "deeplabv3_mobilenet_v3_large", + "lraspp_mobilenet_v3_large", + # Object detection + "fasterrcnn_resnet50_fpn", + "fasterrcnn_mobilenet_v3_large_fpn", + "fasterrcnn_mobilenet_v3_large_320_fpn", + "maskrcnn_resnet50_fpn", + # video classification + "r3d_18", + "mc3_18", + "r2plus1d_18", + ], +) +@pytest.mark.parametrize("batch_size", [1, 8, 16]) +@pytest.mark.parametrize("target", ["llvm", "cuda"]) +def test_meta_schedule_extract_from_torch_model(model_name: str, batch_size: int, target: str): + if model_name == "inception_v3" and batch_size == 1: + pytest.skip("inception_v3 does not handle batch_size of 1") + + input_shape: Tuple[int, ...] + if MODEL_TYPES[model_name] == MODEL_TYPE.IMAGE_CLASSIFICATION: + input_shape = (batch_size, 3, 299, 299) + elif MODEL_TYPES[model_name] == MODEL_TYPE.SEGMENTATION: + input_shape = (batch_size, 3, 299, 299) + elif MODEL_TYPES[model_name] == MODEL_TYPE.OBJECT_DETECTION: + input_shape = (1, 3, 300, 300) + elif MODEL_TYPES[model_name] == MODEL_TYPE.VIDEO_CLASSIFICATION: + input_shape = (batch_size, 3, 3, 299, 299) + else: + raise ValueError("Unsupported model: " + model_name) + + output_shape: Tuple[int, int] = (batch_size, 1000) + mod, params = get_torch_model( + model_name=model_name, + input_shape=input_shape, + output_shape=output_shape, + dtype="float32", + ) + target = tvm.target.Target(target) + ms.integration.extract_task_from_relay(mod, params=params, target=target) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_meta_schedule_tune_tir.py b/tests/python/unittest/test_meta_schedule_tune_tir.py index 277fa2407bd1..60cc7d83aee1 100644 --- a/tests/python/unittest/test_meta_schedule_tune_tir.py +++ b/tests/python/unittest/test_meta_schedule_tune_tir.py @@ -112,7 +112,6 @@ def _sch_rules(): M.AutoInline( into_producer=False, into_consumer=True, - # into_cache_only=False, inline_const_tensor=True, disallow_if_then_else=False, require_injective=False, @@ -122,7 +121,7 @@ def _sch_rules(): M.MultiLevelTiling( structure="SSSRRSRS", tile_binds=["blockIdx.x", "blockIdx.y", "threadIdx.y"], - # use_tensor_core=True, + use_tensor_core=True, max_innermost_factor=64, vector_load_lens=[1, 2, 3, 4], reuse_read=schedule_rule.ReuseType( @@ -139,7 +138,6 @@ def _sch_rules(): M.AutoInline( into_producer=True, into_consumer=True, - # into_cache_only=True, inline_const_tensor=True, disallow_if_then_else=False, require_injective=False, @@ -161,10 +159,10 @@ def _postproc(): ) return [ - # M.RewriteCooperativeFetch(), + M.RewriteCooperativeFetch(), M.RewriteParallelVectorizeUnroll(), M.RewriteReductionBlock(), - # M.RewriteTensorCore(), + M.RewriteTensorCore(), M.VerifyGPUCode(), ] diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py new file mode 100644 index 000000000000..0bad0154a665 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_analysis.py @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +from typing import List + +from tvm.tir import ( + Evaluate, + For, + ForKind, + IndexMap, + Var, + decl_buffer, + floordiv, + floormod, +) +from tvm.tir.analysis import expr_deep_equal +from tvm.tir.schedule.analysis import suggest_index_map + + +def _make_vars(*args: str) -> List[Var]: + return [Var(arg, dtype="int32") for arg in args] + + +def _make_loops(loop_vars: List[Var], extents: List[int]) -> List[For]: + assert len(loop_vars) == len(extents) + return [ + For( + loop_var=loop_var, + min_val=0, + extent=extent, + kind=ForKind.SERIAL, + body=Evaluate(0), + ) + for loop_var, extent in zip(loop_vars, extents) + ] + + +def _assert_equal_index_map(map1: IndexMap, map2: IndexMap) -> None: + iters_1 = map1.apply(map2.src_iters) + iters_2 = map2.tgt_iters + assert len(iters_1) == len(iters_2) + for iter1, iter2 in zip(iters_1, iters_2): + assert expr_deep_equal(iter1, iter2) + + +def test_suggest_index_map_simple(): + i, j = _make_vars("i", "j") + index_map = suggest_index_map( + buffer=decl_buffer(shape=[8, 256]), + indices=[ + floordiv(i, 16) * 4 + floordiv(j, 16), + floormod(i, 16) * 16 + floormod(j, 16), + ], + loops=_make_loops( + loop_vars=[i, j], + extents=[32, 64], + ), + predicate=True, + ) + expected_index_map = IndexMap.from_func( + lambda x, y: [ + floordiv(x, 4), + floordiv(y, 16), + floormod(x, 4), + floormod(y, 16), + ], + ) + _assert_equal_index_map(index_map, expected_index_map) + + +def test_suggest_index_map_bijective(): + i, j = _make_vars("i", "j") + index_map = suggest_index_map( + buffer=decl_buffer(shape=[8]), + indices=[floormod(j, 4) * 2 + i], + loops=_make_loops( + loop_vars=[i, j], + extents=[2, 32], + ), + predicate=True, + ) + expected_index_map = IndexMap.from_func( + lambda x: [ + floormod(x, 2), + floordiv(x, 2), + ], + ) + _assert_equal_index_map(index_map, expected_index_map) + + +if __name__ == "__main__": + test_suggest_index_map_simple() + test_suggest_index_map_bijective() diff --git a/tests/python/unittest/test_tir_schedule_read_write_at.py b/tests/python/unittest/test_tir_schedule_read_write_at.py new file mode 100644 index 000000000000..79a7aad10f25 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_read_write_at.py @@ -0,0 +1,221 @@ +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import sys + +import pytest + +import tvm +from tvm import tir +from tvm.script import tir as T +from tvm.tir.schedule.testing import verify_trace_roundtrip + + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks,not-callable + +@T.prim_func +def cuda_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=undefined-loop-variable + A = T.match_buffer(a, [2048, 2048], "float32") + B = T.match_buffer(b, [2048, 2048], "float32") + C = T.match_buffer(c, [2048, 2048], "float32") + for by in T.thread_binding(0, 32, thread = "blockIdx.y"): + for bx in T.thread_binding(0, 32, thread = "blockIdx.x"): + for vy in T.thread_binding(0, 2, thread = "vthread.y"): + for vx in T.thread_binding(0, 2, thread = "vthread.x"): + for ty in T.thread_binding(0, 8, thread = "threadIdx.y"): + for tx in T.thread_binding(0, 8, thread = "threadIdx.x"): + for k0 in T.serial(0, 256): + for k1 in T.unroll(0, 8): + for _, i, j in T.grid(1, 4, 4): + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k0 * 8 + k1) + T.reads([C[vi, vj], A[vi, vk], B[vk, vj]]) + T.writes([C[vi, vj]]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +@T.prim_func +def cuda_matmul_read_at_a(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [2048, 2048], dtype="float32") + B = T.match_buffer(b, [2048, 2048], dtype="float32") + C = T.match_buffer(c, [2048, 2048], dtype="float32") + A_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") + for by in T.thread_binding(0, 32, thread="blockIdx.y"): + for bx in T.thread_binding(0, 32, thread="blockIdx.x"): + for vy in T.thread_binding(0, 2, thread="vthread.y"): + for vx in T.thread_binding(0, 2, thread="vthread.x"): + for ty in T.thread_binding(0, 8, thread="threadIdx.y"): + for tx in T.thread_binding(0, 8, thread="threadIdx.x"): + for k0 in T.serial(0, 256): + with T.block("A_shared"): + v0 = T.axis.S(32, by) + v1 = T.axis.S(256, k0) + T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) + T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) + T.block_attr({"auto_copy":1}) + for ax0, ax1 in T.grid(64, 8): + A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1] + for k1 in T.unroll(0, 8): + for v_, i, j in T.grid(1, 4, 4): + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k0 * 8 + k1) + T.reads([C[vi, vj], A_shared[vi, vk], B[vk, vj]]) + T.writes([C[vi, vj]]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A_shared[vi, vk] * B[vk, vj] + + +@T.prim_func +def cuda_matmul_read_at_ab(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [2048, 2048], dtype="float32") + B = T.match_buffer(b, [2048, 2048], dtype="float32") + C = T.match_buffer(c, [2048, 2048], dtype="float32") + A_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") + B_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") + for by in T.thread_binding(0, 32, thread="blockIdx.y"): + for bx in T.thread_binding(0, 32, thread="blockIdx.x"): + for vy in T.thread_binding(0, 2, thread="vthread.y"): + for vx in T.thread_binding(0, 2, thread="vthread.x"): + for ty in T.thread_binding(0, 8, thread="threadIdx.y"): + for tx in T.thread_binding(0, 8, thread="threadIdx.x"): + for k0 in T.serial(0, 256): + with T.block("A_shared"): + v0 = T.axis.S(32, by) + v1 = T.axis.S(256, k0) + T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) + T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) + T.block_attr({"auto_copy":1}) + for ax0, ax1 in T.grid(64, 8): + A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1] + with T.block("B_shared"): + v0 = T.axis.S(256, k0) + v1 = T.axis.S(32, bx) + T.reads([B[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]]) + T.writes([B_shared[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]]) + T.block_attr({"auto_copy":1}) + for ax0, ax1 in T.grid(8, 64): + B_shared[v0 * 8 + ax0, v1 * 64 + ax1] = B[v0 * 8 + ax0, v1 * 64 + ax1] + for k1 in T.unroll(0, 8): + for v_, i, j in T.grid(1, 4, 4): + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k0 * 8 + k1) + T.reads([C[vi, vj], A_shared[vi, vk], B_shared[vk, vj]]) + T.writes([C[vi, vj]]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A_shared[vi, vk] * B_shared[vk, vj] + +@T.prim_func +def cuda_matmul_write_at_c(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [2048, 2048], dtype="float32") + B = T.match_buffer(b, [2048, 2048], dtype="float32") + C = T.match_buffer(c, [2048, 2048], dtype="float32") + A_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") + B_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") + C_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") + for by in T.thread_binding(0, 32, thread="blockIdx.y"): + for bx in T.thread_binding(0, 32, thread="blockIdx.x"): + for vy in T.thread_binding(0, 2, thread="vthread.y"): + for vx in T.thread_binding(0, 2, thread="vthread.x"): + for ty in T.thread_binding(0, 8, thread="threadIdx.y"): + for tx in T.thread_binding(0, 8, thread="threadIdx.x"): + for k0 in T.serial(0, 256): + with T.block("A_shared"): + v0 = T.axis.S(32, by) + v1 = T.axis.S(256, k0) + T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) + T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) + T.block_attr({"auto_copy":1}) + for ax0, ax1 in T.grid(64, 8): + A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1] + with T.block("B_shared"): + v0 = T.axis.S(256, k0) + v1 = T.axis.S(32, bx) + T.reads([B[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]]) + T.writes([B_shared[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]]) + T.block_attr({"auto_copy":1}) + for ax0, ax1 in T.grid(8, 64): + B_shared[v0 * 8 + ax0, v1 * 64 + ax1] = B[v0 * 8 + ax0, v1 * 64 + ax1] + for k1 in T.unroll(0, 8): + for v_, i, j in T.grid(1, 4, 4): + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k0 * 8 + k1) + T.reads([C_shared[vi, vj], A_shared[vi, vk], B_shared[vk, vj]]) + T.writes([C_shared[vi, vj]]) + with T.init(): + C_shared[vi, vj] = T.float32(0) + C_shared[vi, vj] = C_shared[vi, vj] + A_shared[vi, vk] * B_shared[vk, vj] + with T.block("C_shared"): + v0 = T.axis.S(32, by) + v1 = T.axis.S(32, bx) + T.reads([C_shared[v0 * 64 : v0 * 64 + 64, v1 * 64 : v1 * 64 + 64]]) + T.writes([C[v0 * 64 : v0 * 64 + 64, v1 * 64 : v1 * 64 + 64]]) + T.block_attr({"auto_copy":1}) + for ax0, ax1 in T.grid(64, 64): + C[v0 * 64 + ax0, v1 * 64 + ax1] = C_shared[v0 * 64 + ax0, v1 * 64 + ax1] + + +# pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks,not-callable +# fmt: on + + +def test_read_at_global_to_shared_a(): + sch = tir.Schedule(cuda_matmul, debug_mask="all") + block = sch.get_block("C") + # pylint: disable=invalid-name + _by, _bx, _vy, _vx, _ty, _tx, k0, _k1, _, _i, _j = sch.get_loops(block) + # pylint: enable=invalid-name + sch.read_at(k0, block, 1, "shared") + tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_read_at_a) + verify_trace_roundtrip(sch, cuda_matmul) + + +def test_read_at_global_to_shared_ab(): + sch = tir.Schedule(cuda_matmul_read_at_a, debug_mask="all") + block = sch.get_block("C") + # pylint: disable=invalid-name + _by, _bx, _vy, _vx, _ty, _tx, k0, _k1, _, _i, _j = sch.get_loops(block) + # pylint: enable=invalid-name + sch.read_at(k0, block, 2, "shared") + tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_read_at_ab) + verify_trace_roundtrip(sch, cuda_matmul_read_at_a) + + +def test_read_at_local_to_shared_c(): + sch = tir.Schedule(cuda_matmul_read_at_ab, debug_mask="all") + block = sch.get_block("C") + # pylint: disable=invalid-name + _by, _bx, _vy, _vx, _ty, tx, _k0, _k1, _, _i, _j = sch.get_loops(block) + # pylint: enable=invalid-name + sch.write_at(tx, block, 0, "shared") + tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_write_at_c) + verify_trace_roundtrip(sch, cuda_matmul_read_at_ab) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_rfactor.py b/tests/python/unittest/test_tir_schedule_rfactor.py index f5fc5a73d038..d9ddec6795a9 100644 --- a/tests/python/unittest/test_tir_schedule_rfactor.py +++ b/tests/python/unittest/test_tir_schedule_rfactor.py @@ -37,40 +37,36 @@ def transformed_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: with T.block("update"): vi, vj = T.axis.remap("SS", [i0, i1]) vk = T.axis.R(128, i2_outer * 32 + i2_inner_outer * 4 + i2_inner_inner) - T.reads([C[vi, vj], A[vi, vk], B[vj, vk]]) - T.writes([C[vi, vj]]) with T.init(): C[vi, vj] = 0.0 C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk]) @T.prim_func -def matmul_rfactor(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, [128, 128]) - B = T.match_buffer(b, [128, 128]) - C = T.match_buffer(c, [128, 128]) - C_rf = T.alloc_buffer([4, 128, 128]) - +def matmul_rfactor( + A: T.Buffer[(128, 128), "float32"], + B: T.Buffer[(128, 128), "float32"], + C: T.Buffer[(128, 128), "float32"], +) -> None: + C_rf = T.alloc_buffer([4, 128, 128], dtype="float32") for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4): with T.block("update_rf"): - vi2_inner_inner = T.axis.S(4, i2_inner_inner) - vi = T.axis.S(128, i0) - vj = T.axis.S(128, i1) - vi2_outer = T.axis.R(4, i2_outer) - vi2_inner_outer = T.axis.R(8, i2_inner_outer) - with T.init(): - C_rf[vi2_inner_inner, vi, vj] = 0.0 - C_rf[vi2_inner_inner, vi, vj] = C_rf[vi2_inner_inner, vi, vj] + ( - A[vi, (((vi2_outer * 32) + (vi2_inner_outer * 4)) + vi2_inner_inner)] - * B[vj, (((vi2_outer * 32) + (vi2_inner_outer * 4)) + vi2_inner_inner)] + vi, vj, vi2_outer, vi2_inner_outer, vi2_inner_inner = T.axis.remap( + "SSRRS", [i0, i1, i2_outer, i2_inner_outer, i2_inner_inner] ) - - for i0_1, i1_1, i2_inner_inner_1 in T.grid(128, 128, 4): + with T.init(): + C_rf[vi2_inner_inner, vi, vj] = T.float32(0) + C_rf[vi2_inner_inner, vi, vj] = ( + C_rf[vi2_inner_inner, vi, vj] + + A[vi, vi2_outer * 32 + vi2_inner_outer * 4 + vi2_inner_inner] + * B[vj, vi2_outer * 32 + vi2_inner_outer * 4 + vi2_inner_inner] + ) + for i0, i1, i2_inner_inner in T.grid(128, 128, 4): with T.block("update"): - vi2_inner_inner_1, vi_1, vj_1 = T.axis.remap("RSS", [i2_inner_inner_1, i0_1, i1_1]) + vi, vj, vi2_inner_inner = T.axis.remap("SSR", [i0, i1, i2_inner_inner]) with T.init(): - C[vi_1, vj_1] = 0.0 - C[vi_1, vj_1] = C[vi_1, vj_1] + C_rf[vi2_inner_inner_1, vi_1, vj_1] + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + C_rf[vi2_inner_inner, vi, vj] @T.prim_func @@ -141,24 +137,22 @@ def square_sum(a: T.handle, c: T.handle) -> None: @T.prim_func -def square_sum_rfactor(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, [16, 256, 256]) - C = T.match_buffer(c, [16]) - C_rf = T.alloc_buffer([16, 256]) - - for i0, i1, i2 in T.grid(16, 256, 256): +def square_sum_rfactor( + A: T.Buffer[(16, 256, 256), "float32"], C: T.Buffer[(16,), "float32"] +) -> None: + C_rf = T.alloc_buffer([16, 256], dtype="float32") + for b0, i0, j0 in T.grid(16, 256, 256): with T.block("C_rf"): - vi2, b, i = T.axis.remap("SSR", [i2, i0, i1]) + b, i, vj0 = T.axis.remap("SRS", [b0, i0, j0]) with T.init(): - C_rf[b, vi2] = 0.0 - C_rf[b, vi2] = C_rf[b, vi2] + (A[b, i, vi2] * A[b, i, vi2]) - - for i0_1, i2_1 in T.grid(16, 256): + C_rf[b, vj0] = T.float32(0) + C_rf[b, vj0] = C_rf[b, vj0] + A[b, i, vj0] * A[b, i, vj0] + for b0, j0 in T.grid(16, 256): with T.block("C"): - vi2_1, b_1 = T.axis.remap("RS", [i2_1, i0_1]) + b, vj0 = T.axis.remap("SR", [b0, j0]) with T.init(): - C[b_1] = 0.0 - C[b_1] = C[b_1] + C_rf[b_1, vi2_1] + C[b] = T.float32(0) + C[b] = C[b] + C_rf[b, vj0] @T.prim_func @@ -167,51 +161,150 @@ def transformed_square_sum_square_root(a: T.handle, d: T.handle) -> None: D = T.match_buffer(d, [16]) C = T.alloc_buffer([16]) - for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1): + for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 32768, 2): with T.block("C"): b = T.axis.S(16, i0) - i = T.axis.R(256, T.floordiv(i1_i2_fused_outer, 256)) - j = T.axis.R(256, T.floormod(i1_i2_fused_outer, 256)) - T.reads([C[b], A[b, i, j]]) - T.writes([C[b]]) + i = T.axis.R(256, T.floordiv(i1_i2_fused_outer * 2 + i1_i2_fused_inner, 256)) + j = T.axis.R(256, T.floormod(i1_i2_fused_outer * 2 + i1_i2_fused_inner, 256)) with T.init(): C[b] = 0.0 C[b] = C[b] + (A[b, i, j] * A[b, i, j]) for i0_1 in T.serial(0, 16): with T.block("D"): b_1 = T.axis.S(16, i0_1) - T.reads([C[b_1]]) - T.writes([D[b_1]]) D[b_1] = T.sqrt(C[b_1], dtype="float32") @T.prim_func -def square_sum_square_root_rfactor(a: T.handle, d: T.handle) -> None: +def square_sum_square_root_rfactor( + A: T.Buffer[(16, 256, 256), "float32"], D: T.Buffer[(16,), "float32"] +) -> None: + C = T.alloc_buffer([16], dtype="float32") + C_rf = T.alloc_buffer([2, 16], dtype="float32") + for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 32768, 2): + with T.block("C_rf"): + b, vi1_i2_fused_outer, vi1_i2_fused_inner = T.axis.remap( + "SRS", [i0, i1_i2_fused_outer, i1_i2_fused_inner] + ) + with T.init(): + C_rf[vi1_i2_fused_inner, b] = T.float32(0) + C_rf[vi1_i2_fused_inner, b] = ( + C_rf[vi1_i2_fused_inner, b] + + A[ + b, + (vi1_i2_fused_outer * 2 + vi1_i2_fused_inner) // 256, + (vi1_i2_fused_outer * 2 + vi1_i2_fused_inner) % 256, + ] + * A[ + b, + (vi1_i2_fused_outer * 2 + vi1_i2_fused_inner) // 256, + (vi1_i2_fused_outer * 2 + vi1_i2_fused_inner) % 256, + ] + ) + for i0, i1_i2_fused_inner in T.grid(16, 2): + with T.block("C"): + b, vi1_i2_fused_inner = T.axis.remap("SR", [i0, i1_i2_fused_inner]) + with T.init(): + C[b] = T.float32(0) + C[b] = C[b] + C_rf[vi1_i2_fused_inner, b] + for i0_1 in T.serial(16): + with T.block("D"): + b_1 = T.axis.spatial(16, i0_1) + D[b_1] = T.sqrt(C[b_1], dtype="float32") + + +@T.prim_func +def transformed_square_sum_square_root_factor_one_1(a: T.handle, d: T.handle) -> None: A = T.match_buffer(a, [16, 256, 256]) D = T.match_buffer(d, [16]) C = T.alloc_buffer([16]) - C_rf = T.alloc_buffer([1, 16]) for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1): - with T.block("C_rf"): - vi1_i2_fused_inner, b = T.axis.remap("SS", [i1_i2_fused_inner, i0]) + with T.block("C"): + b = T.axis.S(16, i0) i = T.axis.R(256, T.floordiv(i1_i2_fused_outer, 256)) j = T.axis.R(256, T.floormod(i1_i2_fused_outer, 256)) with T.init(): - C_rf[vi1_i2_fused_inner, b] = 0.0 - C_rf[vi1_i2_fused_inner, b] = C_rf[vi1_i2_fused_inner, b] + (A[b, i, j] * A[b, i, j]) + C[b] = 0.0 + C[b] = C[b] + (A[b, i, j] * A[b, i, j]) + for i0_1 in T.serial(0, 16): + with T.block("D"): + b_1 = T.axis.S(16, i0_1) + D[b_1] = T.sqrt(C[b_1], dtype="float32") + - for i0_1, i1_i2_fused_inner_1 in T.grid(16, 1): +@T.prim_func +def square_sum_square_root_factor_one_1_rfactor( + A: T.Buffer[(16, 256, 256), "float32"], D: T.Buffer[(16,), "float32"] +) -> None: + C = T.alloc_buffer([16], dtype="float32") + C_rf = T.alloc_buffer([1, 16], dtype="float32") + for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1): + with T.block("C_rf"): + b = T.axis.spatial(16, i0) + i = T.axis.reduce(256, i1_i2_fused_outer // 256) + j = T.axis.reduce(256, i1_i2_fused_outer % 256) + vi1_i2_fused_inner = T.axis.spatial(1, i1_i2_fused_inner) + with T.init(): + C_rf[vi1_i2_fused_inner, b] = T.float32(0) + C_rf[vi1_i2_fused_inner, b] = C_rf[vi1_i2_fused_inner, b] + A[b, i, j] * A[b, i, j] + for i0, i1_i2_fused_inner in T.grid(16, 1): with T.block("C"): - vi1_i2_fused_inner_1, b_1 = T.axis.remap("RS", [i1_i2_fused_inner_1, i0_1]) + b, vi1_i2_fused_inner = T.axis.remap("SR", [i0, i1_i2_fused_inner]) with T.init(): - C[b_1] = 0.0 - C[b_1] = C[b_1] + C_rf[vi1_i2_fused_inner_1, b_1] + C[b] = T.float32(0) + C[b] = C[b] + C_rf[vi1_i2_fused_inner, b] + for i0_1 in T.serial(16): + with T.block("D"): + b_1 = T.axis.spatial(16, i0_1) + D[b_1] = T.sqrt(C[b_1], dtype="float32") + - for i0_2 in T.serial(0, 16): +@T.prim_func +def transformed_square_sum_square_root_factor_one_2(a: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, [16, 256, 256]) + D = T.match_buffer(d, [16]) + C = T.alloc_buffer([16]) + + for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 1, 65536): + with T.block("C"): + b = T.axis.S(16, i0) + i = T.axis.R(256, T.floordiv(i1_i2_fused_inner, 256)) + j = T.axis.R(256, T.floormod(i1_i2_fused_inner, 256)) + with T.init(): + C[b] = 0.0 + C[b] = C[b] + (A[b, i, j] * A[b, i, j]) + for i0_1 in T.serial(0, 16): with T.block("D"): - b_2 = T.axis.S(16, i0_2) - D[b_2] = T.sqrt(C[b_2], dtype="float32") + b_1 = T.axis.S(16, i0_1) + D[b_1] = T.sqrt(C[b_1], dtype="float32") + + +@T.prim_func +def square_sum_square_root_factor_one_2_rfactor( + A: T.Buffer[(16, 256, 256), "float32"], D: T.Buffer[(16,), "float32"] +) -> None: + C = T.alloc_buffer([16], dtype="float32") + C_rf = T.alloc_buffer([16, 1], dtype="float32") + for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 1, 65536): + with T.block("C_rf"): + b = T.axis.spatial(16, i0) + i = T.axis.reduce(256, i1_i2_fused_inner // 256) + j = T.axis.reduce(256, i1_i2_fused_inner % 256) + vi1_i2_fused_outer = T.axis.spatial(1, i1_i2_fused_outer) + with T.init(): + C_rf[b, vi1_i2_fused_outer] = T.float32(0) + C_rf[b, vi1_i2_fused_outer] = C_rf[b, vi1_i2_fused_outer] + A[b, i, j] * A[b, i, j] + for i0, i1_i2_fused_outer in T.grid(16, 1): + with T.block("C"): + b, vi1_i2_fused_outer = T.axis.remap("SR", [i0, i1_i2_fused_outer]) + with T.init(): + C[b] = T.float32(0) + C[b] = C[b] + C_rf[b, vi1_i2_fused_outer] + for i0_1 in T.serial(16): + with T.block("D"): + b_1 = T.axis.spatial(16, i0_1) + D[b_1] = T.sqrt(C[b_1], dtype="float32") @T.prim_func @@ -229,26 +322,24 @@ def square_sum_with_annotation(a: T.handle, c: T.handle) -> None: @T.prim_func -def square_sum_with_annotation_rfactor(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, [16, 256, 256]) - C = T.match_buffer(c, [16]) - C_rf = T.alloc_buffer([16, 256]) - - for i0, i1, i2 in T.grid(16, 256, 256): +def square_sum_with_annotation_rfactor( + A: T.Buffer[(16, 256, 256), "float32"], C: T.Buffer[(16,), "float32"] +) -> None: + C_rf = T.alloc_buffer([16, 256], dtype="float32") + for b0, i0, j0 in T.grid(16, 256, 256): with T.block("C_rf"): + b, i, vj0 = T.axis.remap("SRS", [b0, i0, j0]) T.block_attr({"test_annotation": 1}) - vi2, b, i = T.axis.remap("SSR", [i2, i0, i1]) with T.init(): - C_rf[b, vi2] = 0.0 - C_rf[b, vi2] = C_rf[b, vi2] + (A[b, i, vi2] * A[b, i, vi2]) - - for i0_1, i2_1 in T.grid(16, 256): + C_rf[b, vj0] = T.float32(0) + C_rf[b, vj0] = C_rf[b, vj0] + A[b, i, vj0] * A[b, i, vj0] + for b0, j0 in T.grid(16, 256): with T.block("C"): + b, vj0 = T.axis.remap("SR", [b0, j0]) T.block_attr({"test_annotation": 1}) - vi2_1, b_1 = T.axis.remap("RS", [i2_1, i0_1]) with T.init(): - C[b_1] = 0.0 - C[b_1] = C[b_1] + C_rf[b_1, vi2_1] + C[b] = T.float32(0) + C[b] = C[b] + C_rf[b, vj0] @T.prim_func @@ -370,24 +461,20 @@ def rowsum_zero_dim(a: T.handle, b: T.handle) -> None: @T.prim_func -def rowsum_zero_dim_rfactor(a: T.handle, b: T.handle) -> None: - A = T.match_buffer(a, [128]) - B = T.match_buffer(b, []) - B_rf = T.alloc_buffer([128]) - - for i in range(128): +def rowsum_zero_dim_rfactor(A: T.Buffer[(128,), "float32"], B: T.Buffer[(), "float32"]) -> None: + B_rf = T.alloc_buffer([128], dtype="float32") + for k0 in T.serial(128): with T.block("B_rf"): - vi0 = T.axis.S(128, i) + vk0 = T.axis.spatial(128, k0) with T.init(): - B_rf[vi0] = 0.0 - B_rf[vi0] = B_rf[vi0] + A[vi0] - - for i in range(128): + B_rf[vk0] = T.float32(0) + B_rf[vk0] = B_rf[vk0] + A[vk0] + for k0 in T.serial(128): with T.block("B"): - vi0_1 = T.axis.R(128, i) + vk0 = T.axis.reduce(128, k0) with T.init(): - B[()] = 0.0 - B[()] = B[()] + B_rf[vi0_1] + B[()] = T.float32(0) + B[()] = B[()] + B_rf[vk0] @T.prim_func @@ -405,20 +492,20 @@ def rowsum_predicate(a: T.handle, b: T.handle) -> None: @T.prim_func -def rowsum_predicate_rfactor(a: T.handle, b: T.handle) -> None: - A = T.match_buffer(a, [128, 128], dtype="float32") - B = T.match_buffer(b, [128], dtype="float32") +def rowsum_predicate_rfactor( + A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128,), "float32"] +) -> None: B_rf = T.alloc_buffer([128, 13], dtype="float32") for i, k_0, k_1 in T.grid(128, 13, 10): with T.block("B_rf"): - vk_0, vi, vk_1 = T.axis.remap("SSR", [k_0, i, k_1]) + vi, vk_0, vk_1 = T.axis.remap("SSR", [i, k_0, k_1]) T.where(k_0 * 10 + k_1 < 128) with T.init(): B_rf[vi, vk_0] = T.float32(0) B_rf[vi, vk_0] = B_rf[vi, vk_0] + A[vi, vk_0 * 10 + vk_1] for i, k_0 in T.grid(128, 13): with T.block("B"): - vk_0, vi = T.axis.remap("RS", [k_0, i]) + vi, vk_0 = T.axis.remap("SR", [i, k_0]) with T.init(): B[vi] = T.float32(0) B[vi] = B[vi] + B_rf[vi, vk_0] @@ -466,50 +553,49 @@ def multiple_reduction_blocks(a: T.handle, f: T.handle) -> None: @T.prim_func -def multiple_reduction_blocks_rfactor(a: T.handle, f: T.handle) -> None: - A = T.match_buffer(a, [16, 16, 16]) - C = T.alloc_buffer([16, 16]) - D = T.alloc_buffer([16, 16]) - E = T.alloc_buffer([16, 16]) - F = T.match_buffer(f, [16, 16]) - C_rf = T.alloc_buffer([16, 16, 4]) - +def multiple_reduction_blocks_rfactor( + A: T.Buffer[(16, 16, 16), "float32"], F: T.Buffer[(16, 16), "float32"] +) -> None: + C = T.alloc_buffer([16, 16], dtype="float32") + D = T.alloc_buffer([16, 16], dtype="float32") + E = T.alloc_buffer([16, 16], dtype="float32") + C_rf = T.alloc_buffer([16, 16, 4], dtype="float32") for i, j1, k1o, k1i in T.grid(16, 16, 4, 4): with T.block("C_rf"): - vk1o, ci, cj, vk1i = T.axis.remap("SSSR", [k1o, i, j1, k1i]) + ci, cj, vk1o, vk1i = T.axis.remap("SSSR", [i, j1, k1o, k1i]) with T.init(): - C_rf[ci, cj, vk1o] = 0.0 - C_rf[ci, cj, vk1o] = C_rf[ci, cj, vk1o] + A[ci, cj, ((vk1o * 4) + vk1i)] - for i_1 in T.serial(0, 16): - for j1_1 in T.serial(0, 16): - for k1o_1 in T.serial(0, 4): + C_rf[ci, cj, vk1o] = T.float32(0) + C_rf[ci, cj, vk1o] = C_rf[ci, cj, vk1o] + A[ci, cj, vk1o * 4 + vk1i] + for i in T.serial(16): + for j1 in T.serial(16): + for k1o in T.serial(4): with T.block("C"): - vk1o_1, ci_1, cj_1 = T.axis.remap("RSS", [k1o_1, i_1, j1_1]) + ci, cj, vk1o = T.axis.remap("SSR", [i, j1, k1o]) with T.init(): - C[ci_1, cj_1] = 0.0 - C[ci_1, cj_1] = C[ci_1, cj_1] + C_rf[ci_1, cj_1, vk1o_1] + C[ci, cj] = T.float32(0) + C[ci, cj] = C[ci, cj] + C_rf[ci, cj, vk1o] for k2o, k2i in T.grid(4, 4): with T.block("D"): - di, dj = T.axis.remap("SS", [i_1, j1_1]) - dk = T.axis.R(16, k2o * 4 + k2i) + di, dj = T.axis.remap("SS", [i, j1]) + dk = T.axis.reduce(16, k2o * 4 + k2i) with T.init(): - D[di, dj] = 0.0 - D[di, dj] = (D[di, dj] + A[di, dj, dk]) + C[di, dj] - for j2 in T.serial(0, 16): + D[di, dj] = T.float32(0) + D[di, dj] = D[di, dj] + A[di, dj, dk] + C[di, dj] + for j2 in T.serial(16): for k3o, k3i in T.grid(4, 4): with T.block("E"): - ei, ej = T.axis.remap("SS", [i_1, j2]) - ek = T.axis.R(16, k3o * 4 + k3i) + ei, ej = T.axis.remap("SS", [i, j2]) + ek = T.axis.reduce(16, k3o * 4 + k3i) with T.init(): - E[ei, ej] = 0.0 - E[ei, ej] = (E[ei, ej] + A[ei, ej, ek]) + D[ei, ej] + E[ei, ej] = T.float32(0) + E[ei, ej] = E[ei, ej] + A[ei, ej, ek] + D[ei, ej] for k4o, k4i in T.grid(4, 4): with T.block("F"): - fi, fj = T.axis.remap("SS", [i_1, j2]) - fk = T.axis.R(16, k4o * 4 + k4i) + fi, fj = T.axis.remap("SS", [i, j2]) + fk = T.axis.reduce(16, k4o * 4 + k4i) with T.init(): - F[fi, fj] = 0.0 - F[fi, fj] = (F[fi, fj] + A[fi, fj, fk]) + E[fi, fj] + F[fi, fj] = T.float32(0) + F[fi, fj] = F[fi, fj] + A[fi, fj, fk] + E[fi, fj] # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg @@ -548,6 +634,28 @@ def test_reduction_rfactor_square_sum_square_root(): verify_trace_roundtrip(s, mod=transformed_square_sum_square_root) +def test_reduction_rfactor_square_sum_square_root_factor_one_1(): + s = tir.Schedule(transformed_square_sum_square_root_factor_one_1, debug_mask="all") + C = s.get_block("C") + _, _, f_i = s.get_loops(C) + rf_block = s.rfactor(f_i, 0) + tvm.ir.assert_structural_equal(s.mod["main"], square_sum_square_root_factor_one_1_rfactor) + assert s.get(rf_block).same_as(s.get(s.get_block("C_rf"))) + assert s.get(C).same_as(s.get(s.get_block("C"))) + verify_trace_roundtrip(s, mod=transformed_square_sum_square_root_factor_one_1) + + +def test_reduction_rfactor_square_sum_square_root_factor_one_2(): + s = tir.Schedule(transformed_square_sum_square_root_factor_one_2, debug_mask="all") + C = s.get_block("C") + _, f_o, _ = s.get_loops(C) + rf_block = s.rfactor(f_o, 1) + tvm.ir.assert_structural_equal(s.mod["main"], square_sum_square_root_factor_one_2_rfactor) + assert s.get(rf_block).same_as(s.get(s.get_block("C_rf"))) + assert s.get(C).same_as(s.get(s.get_block("C"))) + verify_trace_roundtrip(s, mod=transformed_square_sum_square_root_factor_one_2) + + def test_reduction_rfactor_loop_multiple_children(): s = tir.Schedule(matmul_loop_multiple_children, debug_mask="all") k, _, _ = s.get_loops(s.get_block("C")) diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py index 84ececebbcba..fd2115bddbed 100644 --- a/tests/python/unittest/test_tir_schedule_split_fuse.py +++ b/tests/python/unittest/test_tir_schedule_split_fuse.py @@ -66,7 +66,7 @@ def elementwise_symbolic_fused(a: T.handle, b: T.handle, n: T.int32) -> None: for i_j_k_fused in T.serial(0, (n * 16384)): with T.block("B"): vi = T.axis.S(128, T.floordiv(i_j_k_fused, n * 128)) - vj = T.axis.S(128, T.floormod(T.floordiv(i_j_k_fused, n), 128)) + vj = T.axis.S(128, T.floordiv(T.floormod(i_j_k_fused, n*128), n)) vk = T.axis.S(n, T.floormod(i_j_k_fused, n)) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) @@ -164,7 +164,7 @@ def elementwise_fused(a: T.handle, b: T.handle) -> None: for fused in T.serial(0, 2097152): with T.block("B"): vi = T.axis.S(128, T.floordiv(fused, 16384)) - vj = T.axis.S(128, T.floormod(T.floordiv(fused, 128), 128)) + vj = T.axis.S(128, T.floordiv(T.floormod(fused, 16384), 128)) vk = T.axis.S(128, T.floormod(fused, 128)) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) @@ -205,7 +205,7 @@ def elementwise_split_with_predicate(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128, 128]) for i0, i1, i2, j0, j1, k0, k1 in T.grid(1000, 2, 3, 1, 129, 3, 43): with T.block("B"): - T.where((i0 * 2 + i1) * 3 + i2 < 128 and j0 * 129 + j1 < 128 and k0 * 43 + k1 < 128) + T.where((i0 * 2 + i1) * 3 + i2 < 128 and j1 < 128 and k0 * 43 + k1 < 128) vi = T.axis.S(128, i0 * 6 + i1 * 3 + i2) vj = T.axis.S(128, j1) vk = T.axis.S(128, k0 * 43 + k1) @@ -223,8 +223,8 @@ def elementwise_fuse_with_opaque_block(a: T.handle, b: T.handle) -> None: T.reads( [ A[ - T.floormod(T.floordiv(T.floordiv(i_j_k_fused, 128), 128), 128), - T.floormod(T.floordiv(i_j_k_fused, 128), 128), + T.floordiv(i_j_k_fused, 16384), + T.floordiv(T.floormod(i_j_k_fused, 16384), 128), T.floormod(i_j_k_fused, 128), ] ] @@ -232,15 +232,15 @@ def elementwise_fuse_with_opaque_block(a: T.handle, b: T.handle) -> None: T.writes( [ B[ - T.floormod(T.floordiv(T.floordiv(i_j_k_fused, 128), 128), 128), - T.floormod(T.floordiv(i_j_k_fused, 128), 128), + T.floordiv(i_j_k_fused, 16384), + T.floordiv(T.floormod(i_j_k_fused, 16384), 128), T.floormod(i_j_k_fused, 128), ] ] ) with T.block("B"): vi = T.axis.S(128, T.floordiv(i_j_k_fused, 16384)) - vj = T.axis.S(128, T.floormod(T.floordiv(i_j_k_fused, 128), 128)) + vj = T.axis.S(128, T.floordiv(T.floormod(i_j_k_fused, 16384), 128)) vk = T.axis.S(128, T.floormod(i_j_k_fused, 128)) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) @@ -343,7 +343,7 @@ def elementwise_not_affine_fused(a: T.handle, b: T.handle) -> None: with T.block("B"): vi = T.axis.S( 127, - i * 32 + T.floormod(T.floordiv(j_k_fused, 128), T.min(31, 126 - i * 32) + 1), + i * 32 + T.floordiv(j_k_fused, 128), ) vj = T.axis.S(128, T.floormod(j_k_fused, 128)) T.reads([A[vi, vj]]) diff --git a/tests/python/unittest/test_tir_schedule_trace.py b/tests/python/unittest/test_tir_schedule_trace.py index f1c97c57b2ff..1923eb23af5b 100644 --- a/tests/python/unittest/test_tir_schedule_trace.py +++ b/tests/python/unittest/test_tir_schedule_trace.py @@ -82,6 +82,15 @@ def _make_compute_inline(input): # pylint: disable=redefined-builtin ) +def _make_split(inputs, outputs): # pylint: disable=redefined-builtin + return Instruction( + kind=InstructionKind.get("Split"), + inputs=inputs, + attrs=[], + outputs=outputs, + ) + + def _make_enter_postproc(): return Instruction( kind=InstructionKind.get("EnterPostproc"), @@ -129,6 +138,17 @@ def _make_trace_3(b0, b1, add_postproc): # pylint: disable=invalid-name return Trace(insts=insts, decisions={}) +def _make_trace_4(b0, l1, l2, l3): # pylint: disable=invalid-name + return Trace( + insts=[ + _make_get_block(name="B", output=b0), + _make_get_loops(input=b0, outputs=[l1]), + _make_split([l1, None, 32], [l2, l3]), + ], + decisions={}, + ) + + def test_trace_construct_1(): trace = _make_trace_1(BlockRV(), LoopRV(), LoopRV()) assert str(trace) == "\n".join( @@ -235,6 +255,17 @@ def test_trace_simplified_2(): ) +def test_trace_simplified_3(): + trace = _make_trace_4(BlockRV(), LoopRV(), LoopRV(), LoopRV()).simplified(remove_postproc=False) + assert str(trace) == "\n".join( + ( + 'b0 = sch.get_block(name="B", func_name="main")', + "l1, = sch.get_loops(block=b0)", + "l2, l3 = sch.split(loop=l1, factors=[None, 32])", + ) + ) + + def test_apply_json_to_schedule_1(): trace = _make_trace_2(BlockRV()) json_obj = trace.as_json() diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py new file mode 100644 index 000000000000..0962e147ff96 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -0,0 +1,170 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import sys + +import pytest + +import tvm +from tvm import tir +from tvm.script import tir as T +from tvm.tir.schedule.testing import verify_trace_roundtrip + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks + +def packed_index_map_func(m, n): + return m // 16, n // 16, m % 16, n % 16 + + +@T.prim_func +def two_elementwise(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + B = T.alloc_buffer((128, 128), "float32") + C = T.match_buffer(c, (128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + + +@T.prim_func +def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + B = T.alloc_buffer((8, 8, 16, 16), "float32") + C = T.match_buffer(c, (128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi // 16, vj // 16, vi % 16, vj % 16] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi // 16, vj // 16, vi % 16, vj % 16] + 1.0 + + +@T.prim_func +def two_elementwise_transformed_input_buffer(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (8, 8, 16, 16), "float32") + B = T.alloc_buffer((128, 128), "float32") + C = T.match_buffer(c, (128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi // 16, vj // 16, vi % 16, vj % 16] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + + +@T.prim_func +def two_elementwise_transformed_output_buffer(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + B = T.alloc_buffer((128, 128), "float32") + C = T.match_buffer(c, (8, 8, 16, 16), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi // 16, vj // 16, vi % 16, vj % 16] = B[vi, vj] + 1.0 + + +@T.prim_func +def permuted_shared_memory(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + C = T.match_buffer(c, (128, 128)) + A_shared = T.alloc_buffer((128, 128), scope="shared") + for i0, j0, in T.grid(32, 4): + for fused_i1_j1 in T.thread_binding(0, 32, 'threadIdx.x'): + for j2 in T.vectorized(0, 4): + with T.block("A_shared"): + vi = T.axis.S(128, i0 * 4 + fused_i1_j1 // 8) + vj = T.axis.S(128, j0 * 32 + fused_i1_j1 % 8 * 4 + j2) + A_shared[vi, vj] = A[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A_shared[vi, vj] + 1.0 + + +@T.prim_func +def permuted_shared_memory_transformed(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + C = T.match_buffer(c, (128, 128)) + A_shared = T.alloc_buffer((32, 4, 4, 32), scope="shared") + for i0, j0, in T.grid(32, 4): + for fused_i1_j1 in T.thread_binding(0, 32, 'threadIdx.x'): + for j2 in T.vectorized(0, 4): + with T.block("A_shared"): + vi = T.axis.S(128, i0 * 4 + fused_i1_j1 // 8) + vj = T.axis.S(128, j0 * 32 + fused_i1_j1 % 8 * 4 + j2) + A_shared[vi // 4, vj // 32, vi % 4, (((vj % 32) // 8) ^ (vi % 4)) + vj % 8] = A[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A_shared[vi // 4, vj // 32, vi % 4, (((vj % 32) // 8) ^ (vi % 4)) + vj % 8] + 1.0 + + +# pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks +# fmt: on + + +def test_two_elementwise_transform_intermediate_buffer(): + sch = tir.Schedule(two_elementwise, debug_mask="all") + block = sch.get_block("B") + sch.transform_layout(block, 0, False, packed_index_map_func) + tvm.ir.assert_structural_equal(two_elementwise_transformed_intermediate_buffer, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=two_elementwise) + + +def test_two_elementwise_transform_input_buffer(): + sch = tir.Schedule(two_elementwise, debug_mask="all") + block = sch.get_block("B") + sch.transform_layout(block, 0, True, packed_index_map_func) + print(sch.mod['main'].script()) + tvm.ir.assert_structural_equal(two_elementwise_transformed_input_buffer, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=two_elementwise) + + +def test_two_elementwise_transform_output_buffer(): + sch = tir.Schedule(two_elementwise, debug_mask="all") + block = sch.get_block("C") + sch.transform_layout(block, 0, False, packed_index_map_func) + tvm.ir.assert_structural_equal(two_elementwise_transformed_output_buffer, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=two_elementwise) + + +@pytest.mark.skip("xor is not supported by IntSet") +def test_permuted_layout(): + sch = tir.Schedule(permuted_shared_memory, debug_mask="all") + block = sch.get_block("A_shared") + sch.transform_layout(block, 0, False, + lambda i, j: (i // 4, j // 32, i % 4, (((j % 32) // 8) ^ (i % 4)) + j % 8)) + tvm.ir.assert_structural_equal(permuted_shared_memory_transformed, sch.mod['main']) + verify_trace_roundtrip(sch=sch, mod=permuted_shared_memory) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_transform_apply_block_bound_predicate.py b/tests/python/unittest/test_tir_transform_apply_block_bound_predicate.py new file mode 100644 index 000000000000..7e651fb7f1ea --- /dev/null +++ b/tests/python/unittest/test_tir_transform_apply_block_bound_predicate.py @@ -0,0 +1,187 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import tir, te +from tvm.script import tir as T + + +def _check(original, transformed): + func = original + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.ApplyBlockBoundPredicate()(mod) + mod = tvm.tir.transform.Simplify()(mod) + tvm.ir.assert_structural_equal(mod["main"], transformed, True) + + +def _check_print(original): + func = original + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.LowerCrossThreadReduction()(mod) + mod = tvm.tir.transform.LowerInitBlock()(mod) + mod = tvm.tir.transform.PlanAndUpdateBufferAllocationLocation()(mod) + mod = tvm.tir.transform.Simplify()(mod) + print(mod["main"].script()) + mod = tvm.tir.transform.ApplyBlockBoundPredicate()(mod) + mod = tvm.tir.transform.Simplify()(mod) + print(mod["main"].script()) + + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks + + +@T.prim_func +def read_out_of_bound_after_compute_at(A: T.Buffer[(16,), "float32"], C: T.Buffer[(16,), "float32"]) -> None: + B = T.alloc_buffer([16], dtype="float32") + for j in T.serial(16): + for ax0 in T.serial(2): + with T.block("B"): + v = T.axis.spatial(16, j + ax0) + T.reads(A[v]) + T.writes(B[v]) + T.block_attr({"require_bound_predicate":v >= 0 and v < 16}) + B[v] = A[v] + with T.block("C"): + v = T.axis.spatial(16, j) + T.reads(B[v : v + 2]) + T.writes(C[v]) + C[v] = T.if_then_else(v < 15, T.max(B[v], B[v + 1]), B[v], dtype="float32") + + +@T.prim_func +def tiled_pooling_cache_after_compute_at(a: T.handle, b: T.handle) -> None: + X = T.match_buffer(a, [224, 224], dtype="float32") + Y = T.match_buffer(b, [224, 224], dtype="float32") + cache = T.alloc_buffer([224, 224], dtype="float32") + dache = T.alloc_buffer([224, 224], dtype="float32") + for hh_0, ww_0 in T.grid(28, 28): + for ax0, ax1 in T.grid(10, 10): + with T.block("cache"): + h = T.axis.spatial(224, hh_0 * 8 + ax0 - 1) + w = T.axis.spatial(224, ww_0 * 8 + ax1 - 1) + T.reads(X[h, w]) + T.writes(cache[h, w]) + T.block_attr({"require_bound_predicate":h >= 0 and h < 224 and w >= 0 and w < 224}) + cache[h, w] = X[h, w] + for ax0, ax1 in T.grid(10, 10): + with T.block("dache"): + h = T.axis.spatial(224, hh_0 * 8 + ax0 - 1) + w = T.axis.spatial(224, ww_0 * 8 + ax1 - 1) + T.reads(X[h, w]) + T.writes(dache[h, w]) + T.block_attr({"require_bound_predicate":h >= 0 and h < 224 and w >= 0 and w < 224}) + dache[h, w] = X[h, w] + for hh_1, ww_1, khh, kww in T.grid(8, 8, 3, 3): + with T.block("compute"): + h = T.axis.spatial(224, hh_0 * 8 + hh_1) + w = T.axis.spatial(224, ww_0 * 8 + ww_1) + kh, kw = T.axis.remap("RR", [khh, kww]) + T.reads([Y[h, w], cache[h + kh - 1, w + kw - 1], dache[h + kh - 1, w + kw - 1]]) + T.writes([Y[h, w]]) + with T.init(): + Y[h, w] = 0.0 + Y[h, w] = T.max(Y[h, w], T.if_then_else( + T.likely(1 <= h + kh, dtype="bool") and \ + T.likely(h + kh < 225, dtype="bool") and \ + T.likely(1 <= w + kw, dtype="bool") and \ + T.likely(w + kw < 225, dtype="bool"), + cache[h + kh - 1, w + kw - 1]+ dache[h + kh - 1, w + kw - 1], 0.0, dtype="float32")) + + +@T.prim_func +def batch_norm_after_compute_at(A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[(1,), "float32"]) -> None: + for i0_0 in T.serial(1): + with T.block(): + T.reads(A[0 : 64, 0 : 256, 0 : 256]) + T.writes(D[0 : 64]) + C = T.alloc_buffer([1], dtype="float32") + for ax0, ax1, ax2 in T.grid(64, 256, 256): + with T.block("C"): + b = T.axis.spatial(1, ax0) + i, j = T.axis.remap("RR", [ax1, ax2]) + T.reads(C[b], A[b, i, j]) + T.writes(C[b]) + T.block_attr({"require_bound_predicate":b >= 0 and b < 1}) + if i == 0 and j == 0: + C[b] = T.float32(0) + C[b] = C[b] + A[b, i, j] * A[b, i, j] + for i0_1 in T.thread_binding(64, thread="threadIdx.x"): + with T.block("D"): + b = T.axis.spatial(1, i0_1) + T.where(i0_1 < 1) + T.reads(C[b]) + T.writes(D[b]) + D[b] = T.sqrt(C[b], dtype="float32") + + +@T.prim_func +def transformed_batch_norm(A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[(1,), "float32"]) -> None: + for i0_0 in T.serial(1): + with T.block(): + T.reads(A[0 : 64, 0 : 256, 0 : 256]) + T.writes(D[0 : 64]) + C = T.alloc_buffer([1], dtype="float32") + for ax0, ax1, ax2 in T.grid(1, 256, 256): + with T.block("C"): + b = T.axis.spatial(1, 0) + i, j = T.axis.remap("RR", [ax1, ax2]) + T.reads(C[b], A[b, i, j]) + T.writes(C[b]) + if i == 0 and j == 0: + C[b] = T.float32(0) + C[b] = C[b] + A[b, i, j] * A[b, i, j] + for i0_1 in T.thread_binding(64, thread="threadIdx.x"): + with T.block("D"): + b = T.axis.spatial(1, i0_1) + T.where(i0_1 < 1) + T.reads(C[b]) + T.writes(D[b]) + D[b] = T.sqrt(C[b], dtype="float32") + + +# pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks +# fmt: on + + +def test_read_out_of_bound(): + # This IR should not be mutated in this pass. + _check(read_out_of_bound_after_compute_at, read_out_of_bound_after_compute_at) + + +def test_tiled_pooling_cache(): + # This IR should not be mutated in this pass. + _check(tiled_pooling_cache_after_compute_at, tiled_pooling_cache_after_compute_at) + + +def test_batch_norm(): + _check(batch_norm_after_compute_at, transformed_batch_norm) + + +def test_lower_te(): + x = te.placeholder((1,)) + y = te.compute((1,), lambda i: x[i] + 2) + s = te.create_schedule(y.op) + orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y]) + mod = tvm.tir.transform.ApplyBlockBoundPredicate()(orig_mod) + tvm.ir.assert_structural_equal(mod, orig_mod) # FlattenBuffer should do nothing on TE + + +if __name__ == "__main__": + test_read_out_of_bound() + test_tiled_pooling_cache() + test_batch_norm() + test_lower_te() diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py index 9b844853f243..715e63d03619 100644 --- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py +++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py @@ -221,7 +221,7 @@ def compacted_symbolic_func(a: T.handle, c: T.handle, n: T.int32) -> None: with T.block(): T.reads(A[i * 8 : i * 8 + 8]) T.writes(C[i * 8 : i * 8 + 8]) - B = T.alloc_buffer((8,), "float32") + B = T.alloc_buffer((T.min(n, 1) * 8,), "float32") for j in range(0, 8): with T.block() as []: T.reads(A[i * 8 + j]) diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index ca3d4aa70d0b..6e41e8f96e40 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -254,6 +254,64 @@ def annotated_loops(a: T.handle) -> None: A[i] = 0.0 +@T.prim_func +def tiled_pooling_cache_after_compute_at(a: T.handle, b: T.handle) -> None: + X = T.match_buffer(a, [224, 224], dtype="float32") + Y = T.match_buffer(b, [224, 224], dtype="float32") + # body + # with T.block("root") + cache = T.alloc_buffer([10, 10], dtype="float32") + dache = T.alloc_buffer([10, 10], dtype="float32") + for hh_0, ww_0 in T.grid(28, 28): + for ax0, ax1 in T.grid(10, 10): + with T.block("cache"): + T.reads(X[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1]) + T.writes(cache[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1]) + T.block_attr({"require_bound_predicate":hh_0 * 8 - 1 + ax0 >= 0 and hh_0 * 8 - 1 + ax0 < 224 and ww_0 * 8 - 1 + ax1 >= 0 and ww_0 * 8 - 1 + ax1 < 224}) + cache[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1] = X[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1] + for ax0, ax1 in T.grid(10, 10): + with T.block("dache"): + T.reads(X[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1]) + T.writes(dache[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1]) + T.block_attr({"require_bound_predicate":hh_0 * 8 - 1 + ax0 >= 0 and hh_0 * 8 - 1 + ax0 < 224 and ww_0 * 8 - 1 + ax1 >= 0 and ww_0 * 8 - 1 + ax1 < 224}) + dache[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1] = X[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1] + for hh_1, ww_1, khh, kww in T.grid(8, 8, 3, 3): + with T.block("compute"): + T.reads(Y[hh_0 * 8 + hh_1, ww_0 * 8 + ww_1], cache[hh_0 * 8 + hh_1 + khh - 1, ww_0 * 8 + ww_1 + kww - 1], dache[hh_0 * 8 + hh_1 + khh - 1, ww_0 * 8 + ww_1 + kww - 1]) + T.writes(Y[hh_0 * 8 + hh_1, ww_0 * 8 + ww_1]) + Y[hh_0 * 8 + hh_1, ww_0 * 8 + ww_1] = T.max(Y[hh_0 * 8 + hh_1, ww_0 * 8 + ww_1], + T.if_then_else(T.likely(1 <= hh_0 * 8 + hh_1 + khh, dtype="bool") + and T.likely(hh_0 * 8 + hh_1 + khh < 225, dtype="bool") + and T.likely(1 <= ww_0 * 8 + ww_1 + kww, dtype="bool") + and T.likely(ww_0 * 8 + ww_1 + kww < 225, dtype="bool"), + cache[hh_0 * 8 + hh_1 + khh - 1, ww_0 * 8 + ww_1 + kww - 1] + + dache[hh_0 * 8 + hh_1 + khh - 1, ww_0 * 8 + ww_1 + kww - 1], + T.float32(0), dtype="float32")) + + +@T.prim_func +def flattened_tiled_pooling_cache_after_compute_at(X: T.Buffer[(224, 224), "float32"], Y: T.Buffer[(224, 224), "float32"]) -> None: + cache = T.allocate([100], "float32", "global") + dache = T.allocate([100], "float32", "global") + for hh_0, ww_0 in T.grid(28, 28): + for ax0, ax1 in T.grid(10, 10): + if 1 <= hh_0 * 8 + ax0 and hh_0 * 8 + ax0 < 225 and 1 <= ww_0 * 8 + ax1 and ww_0 * 8 + ax1 < 225: + T.store(cache, hh_0 * 80 + ax0 * 10 + ww_0 * 8 + ax1 - 11, T.load("float32", X.data, hh_0 * 1792 + ax0 * 224 + ww_0 * 8 + ax1 - 225), True) + for ax0, ax1 in T.grid(10, 10): + if 1 <= hh_0 * 8 + ax0 and hh_0 * 8 + ax0 < 225 and 1 <= ww_0 * 8 + ax1 and ww_0 * 8 + ax1 < 225: + T.store(dache, hh_0 * 80 + ax0 * 10 + ww_0 * 8 + ax1 - 11, T.load("float32", X.data, hh_0 * 1792 + ax0 * 224 + ww_0 * 8 + ax1 - 225), True) + for hh_1, ww_1, khh, kww in T.grid(8, 8, 3, 3): + T.store(Y.data, hh_0 * 1792 + hh_1 * 224 + ww_0 * 8 + ww_1, + T.max(T.load("float32", Y.data, hh_0 * 1792 + hh_1 * 224 + ww_0 * 8 + ww_1), + T.if_then_else(T.likely(1 <= hh_0 * 8 + hh_1 + khh, dtype="bool") + and T.likely(hh_0 * 8 + hh_1 + khh < 225, dtype="bool") + and T.likely(1 <= ww_0 * 8 + ww_1 + kww, dtype="bool") + and T.likely(ww_0 * 8 + ww_1 + kww < 225, dtype="bool"), + T.load("float32", cache, hh_0 * 80 + hh_1 * 10 + khh * 10 + ww_0 * 8 + ww_1 + kww - 11) + + T.load("float32", dache, hh_0 * 80 + hh_1 * 10 + khh * 10 + ww_0 * 8 + ww_1 + kww - 11), + T.float32(0), dtype="float32")), True) + + def test_elementwise(): _check(compacted_elementwise_func, flattened_elementwise_func) @@ -305,6 +363,10 @@ def test_annotated_loops(): tvm.ir.assert_structural_equal(attr3.value, tvm.tir.FloatImm("float32", 0.0)) +def test_bound_predicate(): + _check(tiled_pooling_cache_after_compute_at, flattened_tiled_pooling_cache_after_compute_at) + + if __name__ == "__main__": test_elementwise() test_gpu_workload() @@ -315,3 +377,4 @@ def test_annotated_loops(): test_strided_buffer() test_lower_te() test_annotated_loops() + test_bound_predicate() diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py new file mode 100644 index 000000000000..1c9b69665d1c --- /dev/null +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -0,0 +1,741 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import sys + +import tvm +from tvm import tir, te, TVMError +from tvm.script import tir as T + + +def _check(original, transformed): + func = original + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.InjectSoftwarePipeline()(mod) + mod = tvm.tir.transform.Simplify()(mod) + print(mod['main'].script()) + tvm.ir.assert_structural_equal(mod["main"], transformed, True) + + +def _check_error(func): + mod = tvm.IRModule.from_expr(func) + with pytest.raises(ValueError): + tvm.tir.transform.InjectSoftwarePipeline()(mod) + + +@T.prim_func +def simple_compute(a: T.handle, c: T.handle): + A = T.match_buffer(a, (16, 16), dtype="float32") + C = T.match_buffer(c, (16, 16), dtype="float32") + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 1], 'software_pipeline_order': [0, 1]}): + with T.block(): + T.reads(A[tx, i]) + T.writes(C[tx, i]) + B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, i]) + T.writes(B[tx, 0]) + B[tx, 0] = A[tx, i] * T.float32(2) + with T.block(): + T.reads(B[tx, 0]) + T.writes(C[tx, i]) + C[tx, i] = B[tx, 0] + T.float32(1) + + +@T.prim_func +def transformed_simple_compute(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [16, 16], dtype="float32") + C = T.match_buffer(c, [16, 16], dtype="float32") + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + with T.block(): + T.reads([A[tx, 0:16]]) + T.writes([C[tx, 0:16]]) + B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") + with T.block(): + T.reads([A[tx, 0]]) + T.writes([B[0, tx, 0]]) + B[0, tx, 0] = A[tx, 0] * T.float32(2) + with T.block(): + T.reads([A[tx, 1:16], B[0:2, tx, 0]]) + T.writes([B[0:2, tx, 0], C[tx, 0:15]]) + for i in T.serial(0, 15): + with T.block(): + T.reads([A[tx, i + 1]]) + T.writes([B[(i + 1) % 2, tx, 0]]) + B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2) + with T.block(): + T.reads([B[i % 2, tx, 0]]) + T.writes([C[tx, i]]) + C[tx, i] = B[i % 2, tx, 0] + T.float32(1) + with T.block(): + T.reads([B[1, tx, 0]]) + T.writes([C[tx, 15]]) + C[tx, 15] = B[1, tx, 0] + T.float32(1) + + +@T.prim_func +def nested_pipeline_simple(a: T.handle, c: T.handle): + A = T.match_buffer(a, [16, 16, 16], dtype="float32") + C = T.match_buffer(c, [16, 16, 16], dtype="float32") + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial(0, 16, + annotations={"software_pipeline_stage": [0, 1, 1, 1], + "software_pipeline_order": [0, 1, 2, 3]}): + with T.block(): + T.reads(A[tx, i, 0:16]) + T.writes(C[tx, i, 0:16]) + A_shared = T.alloc_buffer((16, 1, 16), dtype="float32", scope="shared") + for j in T.serial(0, 16): + with T.block(): + T.reads(A[tx, i, j]) + T.writes(A_shared[tx, 0, j]) + A_shared[tx, 0, j] = A[tx, i, j] + for j in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, 1], + "software_pipeline_order": [0, 1] + }, + ): + with T.block(): + T.reads(A_shared[tx, 0, j]) + T.writes(C[tx, i, j]) + B = T.alloc_buffer((16, 1, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A_shared[tx, i, j]) + T.writes(B[tx, i, 0]) + B[tx, i, 0] = A_shared[tx, 0, j] * T.float32(2) + with T.block(): + T.reads(B[tx, i, 0]) + T.writes(C[tx, i, j]) + C[tx, i, j] = B[tx, i, 0] + T.float32(1) + + +@T.prim_func +def transformed_nested_pipeline_simple(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [16, 16, 16], dtype="float32") + C = T.match_buffer(c, [16, 16, 16], dtype="float32") + # body + # with T.block("root") + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + with T.block(): + T.reads([A[tx, 0:16, 0:16]]) + T.writes([C[tx, 0:16, 0:16]]) + A_shared = T.alloc_buffer([2, 16, 1, 16], dtype="float32", scope="shared") + B = T.alloc_buffer([2, 16, 1, 1], dtype="float32", scope="shared") + with T.block(): + T.reads([A[tx, 0, 0:16]]) + T.writes([A_shared[0, tx, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, 0, j]]) + T.writes([A_shared[0, tx, 0, j]]) + A_shared[0, tx, 0, j] = A[tx, 0, j] + with T.block(): + T.reads([A[tx, 1:16, 0:16], A_shared[0:2, tx, 0:15, 0:16], B[0:2, tx, 0:15, 0]]) + T.writes([A_shared[0:2, tx, 0, 0:16], B[0:2, tx, 0:15, 0], C[tx, 0:15, 0:16]]) + for i in T.serial(0, 15): + with T.block(): + T.reads([A[tx, i + 1, 0:16]]) + T.writes([A_shared[(i + 1) % 2, tx, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, i + 1, j]]) + T.writes([A_shared[(i + 1) % 2, tx, 0, j]]) + A_shared[(i + 1) % 2, tx, 0, j] = A[tx, i + 1, j] + with T.block(): + T.reads([A_shared[i % 2, tx, i, 0]]) + T.writes([B[0, tx, i, 0]]) + B[0, tx, i, 0] = A_shared[i % 2, tx, 0, 0] * T.float32(2) + with T.block(): + T.reads([A_shared[i % 2, tx, i, 1:16], B[0:2, tx, i, 0]]) + T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]]) + for j in T.serial(0, 15): + with T.block(): + T.reads([A_shared[i % 2, tx, i, j + 1]]) + T.writes([B[(j + 1) % 2, tx, i, 0]]) + B[(j + 1) % 2, tx, i, 0] = A_shared[ + i % 2, tx, 0, j + 1 + ] * T.float32(2) + with T.block(): + T.reads([B[j % 2, tx, i, 0]]) + T.writes([C[tx, i, j]]) + C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1) + with T.block(): + T.reads([B[1, tx, i, 0]]) + T.writes([C[tx, i, 15]]) + C[tx, i, 15] = B[1, tx, i, 0] + T.float32(1) + with T.block(): + T.reads([A_shared[1, tx, 15, 0:16], B[0:2, tx, 15, 0]]) + T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:16]]) + with T.block(): + T.reads([A_shared[1, tx, 15, 0]]) + T.writes([B[0, tx, 15, 0]]) + B[0, tx, 15, 0] = A_shared[1, tx, 0, 0] * T.float32(2) + with T.block(): + T.reads([A_shared[1, tx, 15, 1:16], B[0:2, tx, 15, 0]]) + T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:15]]) + for j in T.serial(0, 15): + with T.block(): + T.reads([A_shared[1, tx, 15, j + 1]]) + T.writes([B[(j + 1) % 2, tx, 15, 0]]) + B[(j + 1) % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2) + with T.block(): + T.reads([B[j % 2, tx, 15, 0]]) + T.writes([C[tx, 15, j]]) + C[tx, 15, j] = B[j % 2, tx, 15, 0] + T.float32(1) + with T.block(): + T.reads([B[1, tx, 15, 0]]) + T.writes([C[tx, 15, 15]]) + C[tx, 15, 15] = B[1, tx, 15, 0] + T.float32(1) + + +@T.prim_func +def nested_pipeline_prefetch_inner(a: T.handle, c: T.handle): + A = T.match_buffer(a, [16, 16, 16], dtype="float32") + C = T.match_buffer(c, [16, 16, 16], dtype="float32") + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 0, 1, 1], "software_pipeline_order": [0, 2, 1, 3]}): + with T.block(): + T.reads(A[tx, i, 0:16]) + T.writes(C[tx, i, 0:16]) + A_shared = T.alloc_buffer((16, 1, 16), dtype="float32", scope="shared") + for j in T.serial(0, 16): + with T.block(): + T.reads(A[tx, i, j]) + T.writes(A_shared[tx, 0, j]) + A_shared[tx, 0, j] = A[tx, i, j] + for j in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, 1], + "software_pipeline_order": [0, 1], + }, + ): + with T.block(): + T.reads(A_shared[tx, 0, j]) + T.writes(C[tx, i, j]) + B = T.alloc_buffer((16, 1, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A_shared[tx, i, j]) + T.writes(B[tx, i, 0]) + B[tx, i, 0] = A_shared[tx, 0, j] * T.float32(2) + with T.block(): + T.reads(B[tx, i, 0]) + T.writes(C[tx, i, j]) + C[tx, i, j] = B[tx, i, 0] + T.float32(1) + + +@T.prim_func +def transformed_nested_pipeline_prefetch_inner(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [16, 16, 16], dtype="float32") + C = T.match_buffer(c, [16, 16, 16], dtype="float32") + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + with T.block(): + T.reads([A[tx, 0:16, 0:16]]) + T.writes([C[tx, 0:16, 0:16]]) + A_shared = T.alloc_buffer([2, 16, 1, 16], dtype="float32", scope="shared") + B = T.alloc_buffer([2, 16, 1, 1], dtype="float32", scope="shared") + with T.block(): + T.reads([A[tx, 0, 0:16], A_shared[0, tx, 0, 0]]) + T.writes([A_shared[0, tx, 0, 0:16], B[0, tx, 0, 0]]) + with T.block(): + T.reads([A[tx, 0, 0:16]]) + T.writes([A_shared[0, tx, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, 0, j]]) + T.writes([A_shared[0, tx, 0, j]]) + A_shared[0, tx, 0, j] = A[tx, 0, j] + with T.block(): + T.reads([A_shared[0, tx, 0, 0]]) + T.writes([B[0, tx, 0, 0]]) + B[0, tx, 0, 0] = A_shared[0, tx, 0, 0] * T.float32(2) + with T.block(): + T.reads([A[tx, 1:16, 0:16], A_shared[0:2, tx, 0:16, 0:16], B[0:2, tx, 0:15, 0]]) + T.writes([A_shared[0:2, tx, 0, 0:16], B[0:2, tx, 0:16, 0], C[tx, 0:15, 0:16]]) + for i in T.serial(0, 15): + with T.block(): + T.reads([A[tx, i + 1, 0:16]]) + T.writes([A_shared[(i + 1) % 2, tx, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, i + 1, j]]) + T.writes([A_shared[(i + 1) % 2, tx, 0, j]]) + A_shared[(i + 1) % 2, tx, 0, j] = A[tx, i + 1, j] + with T.block(): + T.reads([A_shared[i % 2, tx, i, 1:16], B[0:2, tx, i, 0]]) + T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]]) + for j in T.serial(0, 15): + with T.block(): + T.reads([A_shared[i % 2, tx, i, j + 1]]) + T.writes([B[(j + 1) % 2, tx, i, 0]]) + B[(j + 1) % 2, tx, i, 0] = A_shared[ + i % 2, tx, 0, j + 1 + ] * T.float32(2) + with T.block(): + T.reads([B[j % 2, tx, i, 0]]) + T.writes([C[tx, i, j]]) + C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1) + with T.block(): + T.reads([A_shared[(i + 1) % 2, tx, i + 1, 0]]) + T.writes([B[0, tx, i + 1, 0]]) + B[0, tx, i + 1, 0] = A_shared[(i + 1) % 2, tx, 0, 0] * T.float32(2) + with T.block(): + T.reads([B[1, tx, i, 0]]) + T.writes([C[tx, i, 15]]) + C[tx, i, 15] = B[1, tx, i, 0] + T.float32(1) + with T.block(): + T.reads([A_shared[1, tx, 15, 1:16], B[0:2, tx, 15, 0]]) + T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:16]]) + with T.block(): + T.reads([A_shared[1, tx, 15, 1:16], B[0:2, tx, 15, 0]]) + T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:15]]) + for j in T.serial(0, 15): + with T.block(): + T.reads([A_shared[1, tx, 15, j + 1]]) + T.writes([B[(j + 1) % 2, tx, 15, 0]]) + B[(j + 1) % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2) + with T.block(): + T.reads([B[j % 2, tx, 15, 0]]) + T.writes([C[tx, 15, j]]) + C[tx, 15, j] = B[j % 2, tx, 15, 0] + T.float32(1) + with T.block(): + T.reads([B[1, tx, 15, 0]]) + T.writes([C[tx, 15, 15]]) + C[tx, 15, 15] = B[1, tx, 15, 0] + T.float32(1) + + +@T.prim_func +def nested_pipeline_interleaving(a: T.handle, c: T.handle): + A = T.match_buffer(a, [16, 16, 16], dtype="float32") + C = T.match_buffer(c, [16, 16, 16], dtype="float32") + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 0, 0, 1, 1], "software_pipeline_order": [0, 2, 3, 1, 4]}): + with T.block(): + T.reads(A[tx, i, 0:16]) + T.writes(C[tx, i, 0:16]) + A_shared = T.alloc_buffer((16, 1, 16), dtype="float32", scope="shared") + A_local = T.alloc_buffer((1, 1, 16), dtype="float32", scope="local") + for j in T.serial(0, 16): + with T.block(): + T.reads(A[tx, i, j]) + T.writes(A_shared[tx, 0, j]) + A_shared[tx, 0, j] = A[tx, i, j] + for j in T.serial(0, 16): + with T.block(): + T.reads(A_shared[tx, 0, j]) + T.writes(A_local[0, 0, j]) + A_local[0, 0, j] = A_shared[tx, i, j] + for j in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, 1], + "software_pipeline_order": [0, 1], + }, + ): + with T.block(): + T.reads(A_local[0, 0, j]) + T.writes(C[tx, i, j]) + B = T.alloc_buffer((16, 1, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A_local[tx, i, j]) + T.writes(B[tx, i, 0]) + B[tx, i, 0] = A_local[0, 0, j] * T.float32(2) + with T.block(): + T.reads(B[tx, i, 0]) + T.writes(C[tx, i, j]) + C[tx, i, j] = B[tx, i, 0] + T.float32(1) + + +@T.prim_func +def transformed_nested_pipeline_interleaving(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [16, 16, 16], dtype="float32") + C = T.match_buffer(c, [16, 16, 16], dtype="float32") + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + with T.block(): + T.reads([A[tx, 0:16, 0:16]]) + T.writes([C[tx, 0:16, 0:16]]) + A_shared = T.alloc_buffer([16, 1, 16], dtype="float32", scope="shared") + A_local = T.alloc_buffer([1, 1, 16], dtype="float32", scope="local") + B = T.alloc_buffer([2, 16, 1, 1], dtype="float32", scope="shared") + with T.block(): + T.reads([A[tx, 0, 0:16], A_shared[tx, 0, 0:16], A_local[tx, 0, 0]]) + T.writes([A_shared[tx, 0, 0:16], A_local[0, 0, 0:16], B[0, tx, 0, 0]]) + with T.block(): + T.reads([A[tx, 0, 0:16]]) + T.writes([A_shared[tx, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, 0, j]]) + T.writes([A_shared[tx, 0, j]]) + A_shared[tx, 0, j] = A[tx, 0, j] + with T.block(): + T.reads([A_shared[tx, 0, 0:16]]) + T.writes([A_local[0, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A_shared[tx, 0, j]]) + T.writes([A_local[0, 0, j]]) + A_local[0, 0, j] = A_shared[tx, 0, j] + with T.block(): + T.reads([A_local[tx, 0, 0]]) + T.writes([B[0, tx, 0, 0]]) + B[0, tx, 0, 0] = A_local[0, 0, 0] * T.float32(2) + with T.block(): + T.reads( + [ + A[tx, 1:16, 0:16], + A_local[tx, 0:16, 0:16], + B[0:2, tx, 0:15, 0], + A_shared[tx, 0, 0:16], + ] + ) + T.writes( + [ + A_shared[tx, 0, 0:16], + B[0:2, tx, 0:16, 0], + C[tx, 0:15, 0:16], + A_local[0, 0, 0:16], + ] + ) + for i in T.serial(0, 15): + with T.block(): + T.reads([A[tx, i + 1, 0:16]]) + T.writes([A_shared[tx, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, i + 1, j]]) + T.writes([A_shared[tx, 0, j]]) + A_shared[tx, 0, j] = A[tx, i + 1, j] + with T.block(): + T.reads([A_local[tx, i, 1:16], B[0:2, tx, i, 0]]) + T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]]) + for j in T.serial(0, 15): + with T.block(): + T.reads([A_local[tx, i, j + 1]]) + T.writes([B[(j + 1) % 2, tx, i, 0]]) + B[(j + 1) % 2, tx, i, 0] = A_local[0, 0, j + 1] * T.float32(2) + with T.block(): + T.reads([B[j % 2, tx, i, 0]]) + T.writes([C[tx, i, j]]) + C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1) + with T.block(): + T.reads([A_shared[tx, 0, 0:16]]) + T.writes([A_local[0, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A_shared[tx, 0, j]]) + T.writes([A_local[0, 0, j]]) + A_local[0, 0, j] = A_shared[tx, i + 1, j] + with T.block(): + T.reads([A_local[tx, i + 1, 0]]) + T.writes([B[0, tx, i + 1, 0]]) + B[0, tx, i + 1, 0] = A_local[0, 0, 0] * T.float32(2) + with T.block(): + T.reads([B[1, tx, i, 0]]) + T.writes([C[tx, i, 15]]) + C[tx, i, 15] = B[1, tx, i, 0] + T.float32(1) + with T.block(): + T.reads([A_local[tx, 15, 1:16], B[0:2, tx, 15, 0]]) + T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:16]]) + with T.block(): + T.reads([A_local[tx, 15, 1:16], B[0:2, tx, 15, 0]]) + T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:15]]) + for j in T.serial(0, 15): + with T.block(): + T.reads([A_local[tx, 15, j + 1]]) + T.writes([B[(j + 1) % 2, tx, 15, 0]]) + B[(j + 1) % 2, tx, 15, 0] = A_local[0, 0, j + 1] * T.float32(2) + with T.block(): + T.reads([B[j % 2, tx, 15, 0]]) + T.writes([C[tx, 15, j]]) + C[tx, 15, j] = B[j % 2, tx, 15, 0] + T.float32(1) + with T.block(): + T.reads([B[1, tx, 15, 0]]) + T.writes([C[tx, 15, 15]]) + C[tx, 15, 15] = B[1, tx, 15, 0] + T.float32(1) + + +@T.prim_func +def nested_pipeline_double_buffer(a: T.handle, c: T.handle): + A = T.match_buffer(a, [16, 16, 16], dtype="float32") + C = T.match_buffer(c, [16, 16, 16], dtype="float32") + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 0, 0, 1, 1], "software_pipeline_order": [0, 2, 3, 1, 4]}): + with T.block(): + T.reads(A[tx, i, 0:16]) + T.writes(C[tx, i, 0:16]) + A_shared = T.alloc_buffer((16, 1, 16), dtype="float32", scope="shared") + A_local = T.alloc_buffer((1, 1, 16), dtype="float32", scope="local") + for j in T.serial(0, 16): + with T.block(): + T.reads(A[tx, i, j]) + T.writes(A_shared[tx, 0, j]) + A_shared[tx, 0, j] = A[tx, i, j] + for j in T.serial(0, 16): + with T.block(): + T.block_attr({"double_buffer_scope": 0}) + T.reads(A_shared[tx, 0, j]) + T.writes(A_local[0, 0, j]) + A_local[0, 0, j] = A_shared[tx, i, j] + for j in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, 1], + "software_pipeline_order": [0, 1], + }, + ): + with T.block(): + T.reads(A_local[0, 0, j]) + T.writes(C[tx, i, j]) + B = T.alloc_buffer((16, 1, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A_local[tx, i, j]) + T.writes(B[tx, i, 0]) + B[tx, i, 0] = A_local[0, 0, j] * T.float32(2) + with T.block(): + T.reads(B[tx, i, 0]) + T.writes(C[tx, i, j]) + C[tx, i, j] = B[tx, i, 0] + T.float32(1) + + +@T.prim_func +def transformed_nested_pipeline_double_buffer(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [16, 16, 16], dtype="float32") + C = T.match_buffer(c, [16, 16, 16], dtype="float32") + # body + # with T.block("root") + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + with T.block(): + T.reads([A[tx, 0:16, 0:16]]) + T.writes([C[tx, 0:16, 0:16]]) + A_shared = T.alloc_buffer([16, 1, 16], dtype="float32", scope="shared") + A_local = T.alloc_buffer([2, 1, 1, 16], dtype="float32", scope="local") + B = T.alloc_buffer([2, 16, 1, 1], dtype="float32", scope="shared") + with T.block(): + T.reads([A[tx, 0, 0:16], A_shared[tx, 0, 0:16], A_local[0, tx, 0, 0]]) + T.writes([A_shared[tx, 0, 0:16], A_local[0, 0, 0, 0:16], B[0, tx, 0, 0]]) + with T.block(): + T.reads([A[tx, 0, 0:16]]) + T.writes([A_shared[tx, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, 0, j]]) + T.writes([A_shared[tx, 0, j]]) + A_shared[tx, 0, j] = A[tx, 0, j] + with T.block(): + T.reads([A_shared[tx, 0, 0:16]]) + T.writes([A_local[0, 0, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A_shared[tx, 0, j]]) + T.writes([A_local[0, 0, 0, j]]) + T.block_attr({"double_buffer_scope": 0}) + A_local[0, 0, 0, j] = A_shared[tx, 0, j] + with T.block(): + T.reads([A_local[0, tx, 0, 0]]) + T.writes([B[0, tx, 0, 0]]) + B[0, tx, 0, 0] = A_local[0, 0, 0, 0] * T.float32(2) + with T.block(): + T.reads( + [ + A[tx, 1:16, 0:16], + A_local[0:2, tx, 0:16, 0:16], + B[0:2, tx, 0:15, 0], + A_shared[tx, 0, 0:16], + ] + ) + T.writes( + [ + A_shared[tx, 0, 0:16], + B[0:2, tx, 0:16, 0], + C[tx, 0:15, 0:16], + A_local[0:2, 0, 0, 0:16], + ] + ) + for i in T.serial(0, 15): + with T.block(): + T.reads([A[tx, i + 1, 0:16]]) + T.writes([A_shared[tx, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, i + 1, j]]) + T.writes([A_shared[tx, 0, j]]) + A_shared[tx, 0, j] = A[tx, i + 1, j] + with T.block(): + T.reads([A_local[i % 2, tx, i, 1:16], B[0:2, tx, i, 0]]) + T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]]) + for j in T.serial(0, 15): + with T.block(): + T.reads([A_local[i % 2, tx, i, j + 1]]) + T.writes([B[(j + 1) % 2, tx, i, 0]]) + B[(j + 1) % 2, tx, i, 0] = A_local[i % 2, 0, 0, j + 1] * T.float32( + 2 + ) + with T.block(): + T.reads([B[j % 2, tx, i, 0]]) + T.writes([C[tx, i, j]]) + C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1) + with T.block(): + T.reads([A_shared[tx, 0, 0:16]]) + T.writes([A_local[(i + 1) % 2, 0, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A_shared[tx, 0, j]]) + T.writes([A_local[(i + 1) % 2, 0, 0, j]]) + T.block_attr({"double_buffer_scope": 0}) + A_local[(i + 1) % 2, 0, 0, j] = A_shared[tx, i + 1, j] + with T.block(): + T.reads([A_local[(i + 1) % 2, tx, i + 1, 0]]) + T.writes([B[0, tx, i + 1, 0]]) + B[0, tx, i + 1, 0] = A_local[(i + 1) % 2, 0, 0, 0] * T.float32(2) + with T.block(): + T.reads([B[1, tx, i, 0]]) + T.writes([C[tx, i, 15]]) + C[tx, i, 15] = B[1, tx, i, 0] + T.float32(1) + with T.block(): + T.reads([A_local[1, tx, 15, 1:16], B[0:2, tx, 15, 0]]) + T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:16]]) + with T.block(): + T.reads([A_local[1, tx, 15, 1:16], B[0:2, tx, 15, 0]]) + T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:15]]) + for j in T.serial(0, 15): + with T.block(): + T.reads([A_local[1, tx, 15, j + 1]]) + T.writes([B[(j + 1) % 2, tx, 15, 0]]) + B[(j + 1) % 2, tx, 15, 0] = A_local[1, 0, 0, j + 1] * T.float32(2) + with T.block(): + T.reads([B[j % 2, tx, 15, 0]]) + T.writes([C[tx, 15, j]]) + C[tx, 15, j] = B[j % 2, tx, 15, 0] + T.float32(1) + with T.block(): + T.reads([B[1, tx, 15, 0]]) + T.writes([C[tx, 15, 15]]) + C[tx, 15, 15] = B[1, tx, 15, 0] + T.float32(1) + + +@T.prim_func +def simple_compute_incorrect_reorder(a: T.handle, d: T.handle): + A = T.match_buffer(a, (16, 16), dtype="float32") + D = T.match_buffer(d, (16, 16), dtype="float32") + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 1, 1], "software_pipeline_order": [0, 2, 1]}): + with T.block(): + T.reads(A[tx, i]) + T.writes(D[tx, i]) + B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + C = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, i]) + T.writes(B[tx, 0]) + B[tx, 0] = A[tx, i] * T.float32(2) + with T.block(): + T.reads(B[tx, 0]) + T.writes(C[tx, 0]) + C[tx, 0] = B[tx, 0] + T.float32(2) + with T.block(): + T.reads(C[tx, 0]) + T.writes(D[tx, i]) + D[tx, i] = C[tx, 0] + T.float32(1) + + +@T.prim_func +def simple_compute_conflicting_order(a: T.handle, d: T.handle): + A = T.match_buffer(a, (16, 16), dtype="float32") + D = T.match_buffer(d, (16, 16), dtype="float32") + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 1, 1], "software_pipeline_order": [ 0, 1, 1]}): + with T.block(): + T.reads(A[tx, i]) + T.writes(D[tx, i]) + B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + C = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, i]) + T.writes(B[tx, 0]) + B[tx, 0] = A[tx, i] * T.float32(2) + with T.block(): + T.reads(B[tx, 0]) + T.writes(C[tx, 0]) + C[tx, 0] = B[tx, 0] + T.float32(2) + with T.block(): + T.reads(C[tx, 0]) + T.writes(D[tx, i]) + D[tx, i] = C[tx, 0] + T.float32(1) + + +@T.prim_func +def simple_compute_missing_annotation(a: T.handle, c: T.handle): + A = T.match_buffer(a, (16, 16), dtype="float32") + C = T.match_buffer(c, (16, 16), dtype="float32") + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 1]}): + with T.block(): + T.reads(A[tx, i]) + T.writes(C[tx, i]) + B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, i]) + T.writes(B[tx, 0]) + B[tx, 0] = A[tx, i] * T.float32(2) + with T.block(): + T.reads(B[tx, 0]) + T.writes(C[tx, i]) + C[tx, i] = B[tx, 0] + T.float32(1) + + + +def test_simple_compute(): + _check(simple_compute, transformed_simple_compute) + + +def test_nest_pipeline_simple(): + _check(nested_pipeline_simple, transformed_nested_pipeline_simple) + + +def test_nest_pipeline_prefetch_inner(): + _check(nested_pipeline_prefetch_inner, transformed_nested_pipeline_prefetch_inner) + + +def test_nest_pipeline_interleaving(): + _check(nested_pipeline_interleaving, transformed_nested_pipeline_interleaving) + + +def test_nest_pipeline_double_buffer(): + _check(nested_pipeline_double_buffer, transformed_nested_pipeline_double_buffer) + + +# def test_error_reorder(): +# _check_error(simple_compute_incorrect_reorder) + + +# def test_error_conflicting_order(): +# _check_error(simple_compute_conflicting_order) + + +def test_error_missing_annotation(): + _check_error(simple_compute_missing_annotation) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_transform_memhammer_lower_auto_copy.py b/tests/python/unittest/test_tir_transform_memhammer_lower_auto_copy.py new file mode 100644 index 000000000000..5a2ede204769 --- /dev/null +++ b/tests/python/unittest/test_tir_transform_memhammer_lower_auto_copy.py @@ -0,0 +1,398 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm import te +from tvm.script import tir as T +import sys +import pytest + + +@tvm.script.ir_module +class Transpose: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [1024, 1024]) + B = T.match_buffer(b, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution": True}) + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer([16, 128], dtype="float32", scope="shared.dyn") + with T.block("A_shared"): + T.block_attr({"auto_copy": 1}) + for ax0, ax1 in T.grid(128, 16): + A_shared_dyn[ax1, ax0] = A[ax0, ax1] + with T.block("B"): + for ax1, ax0 in T.grid(16, 128): + T.block_attr({"auto_copy": 1}) + B[ax1, ax0] = A_shared_dyn[ax1, ax0] + + +@tvm.script.ir_module +class GlobalToShared: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [1024, 1024]) + B = T.match_buffer(b, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution": True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer([128, 128], dtype="float32", scope="shared.dyn") + with T.block("A_shared"): + T.block_attr({"auto_copy": 1, "vector_bytes": 16}) + for ax0, ax1 in T.grid(128, 128): + A_shared_dyn[ax0, ax1] = A[bx * 128 + ax0, by * 128 + ax1] + with T.block("B"): + for ax0, ax1 in T.grid(128, 128): + B[bx * 128 + ax0, by * 128 + ax1] = A_shared_dyn[ax0, ax1] + + +@tvm.script.ir_module +class SharedToGlobal: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [1024, 1024]) + B = T.match_buffer(b, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution": True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer([128, 128], dtype="float32", scope="shared.dyn") + with T.block("A_shared"): + for ax0, ax1 in T.grid(128, 128): + A_shared_dyn[ax1, ax0] = A[bx * 128 + ax0, by * 128 + ax1] + with T.block("B"): + T.block_attr({"auto_copy": 1, "vector_bytes": 16}) + for ax1, ax0 in T.grid(128, 128): + B[bx * 128 + ax0, by * 128 + ax1] = A_shared_dyn[ax1, ax0] + + +@tvm.script.ir_module +class GlobalToSharedWithLocalStage: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [1024, 1024]) + B = T.match_buffer(b, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution": True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer([128, 128], dtype="float32", scope="shared.dyn") + with T.block("A_shared"): + T.block_attr({"auto_copy": 1, "vector_bytes": 16, "local_stage": True}) + for ax0, ax1 in T.grid(128, 128): + A_shared_dyn[ax0, ax1] = A[bx * 128 + ax0, by * 128 + ax1] + with T.block("B"): + for ax0, ax1 in T.grid(128, 128): + B[bx * 128 + ax0, by * 128 + ax1] = A_shared_dyn[ax0, ax1] + + +@tvm.script.ir_module +class SharedToWmma: + @T.prim_func + def main() -> None: + with T.block("root"): + T.block_attr({"warp_execution": True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer([128, 128], dtype="float16", scope="shared.dyn") + A_wmma = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a") + with T.block("A_wmma"): + T.block_attr({"auto_copy": 1}) + for ax0, ax1 in T.grid(128, 128): + A_wmma[ax0, ax1] = A_shared_dyn[ax0, ax1] + + +@tvm.script.ir_module +class WmmaToShared: + @T.prim_func + def main() -> None: + with T.block("root"): + T.block_attr({"warp_execution": True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + C_accum = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") + C_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared.dyn") + with T.block("C_shared"): + T.block_attr({"auto_copy": 1}) + for ax0, ax1 in T.grid(128, 128): + C_shared[ax0, ax1] = C_accum[ax0, ax1] + + +@tvm.script.ir_module +class WmmaToGlobal: + @T.prim_func + def main(c: T.handle) -> None: + C = T.match_buffer(c, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution": True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + C_accum = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") + with T.block("C_global"): + T.block_attr({"auto_copy": 1, "vector_bytes": 16}) + for ax0, ax1 in T.grid(128, 128): + C[bx * 128 + ax0, by * 128 + ax1] = C_accum[ax0, ax1] + +@tvm.script.ir_module +class TransformedGlobalToShared: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [1024, 1024]) + B = T.match_buffer(b, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution":True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer([128, 128], dtype="float32", strides=[128, 1], scope="shared.dyn") + with T.block("A_shared"): + T.block_attr({"auto_copy":1, "vector_bytes":16}) + for outer in T.serial(16): + for ty_1 in T.thread_binding(8, thread="threadIdx.y"): + for tx in T.thread_binding(32, thread="threadIdx.x"): + for vec in T.vectorized(4): + A_shared_dyn[(((outer * 8 + ty_1) * 32 + tx) * 4 + vec) // 128 % 128, (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) % 128] = A[bx * 128 + (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) // 128 % 128, by * 128 + (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) % 128] + with T.block("B"): + for ax0, ax1 in T.grid(128, 128): + B[bx * 128 + ax0, by * 128 + ax1] = A_shared_dyn[ax0, ax1] + +@tvm.script.ir_module +class TransformedSharedToGlobal: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [1024, 1024]) + B = T.match_buffer(b, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution":True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer([128, 128], dtype="float32", strides=[129, 1], scope="shared.dyn") + with T.block("A_shared"): + T.reads(A[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) + T.writes(A_shared_dyn[0 : 128, 0 : 128]) + for ax0, ax1 in T.grid(128, 128): + A_shared_dyn[ax1, ax0] = A[bx * 128 + ax0, by * 128 + ax1] + with T.block("B"): + T.block_attr({"auto_copy":1, "vector_bytes":16}) + for outer in T.serial(16): + for ty_1 in T.thread_binding(8, thread="threadIdx.y"): + for tx in T.thread_binding(32, thread="threadIdx.x"): + for vec in T.vectorized(4): + B[bx * 128 + (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) // 128 % 128, by * 128 + (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) % 128] = A_shared_dyn[(((outer * 8 + ty_1) * 32 + tx) * 4 + vec) % 128, (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) // 128 % 128] + +@tvm.script.ir_module +class TransformedGlobalToSharedWithLocalStage: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [1024, 1024]) + B = T.match_buffer(b, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution":True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer([128, 128], dtype="float32", strides=[128, 1], scope="shared.dyn") + with T.block("A_shared"): + T.reads(A[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) + T.writes(A_shared_dyn[0 : 128, 0 : 128]) + T.block_attr({"auto_copy":1, "local_stage":True, "vector_bytes":16}) + A_local = T.alloc_buffer([16, 4], dtype="float32", scope="local") + for ty_1 in T.thread_binding(8, thread="threadIdx.y"): + for tx in T.thread_binding(32, thread="threadIdx.x"): + for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 16, 1, 1, 1): + for vec in T.vectorized(4): + A_local[ax0 * 16 + ax1 + ax2, (ax3 + ax4) * 4 + vec] = A[((bx % 8 + ax0) * 16 + ax1) * 8 + (ty_1 % 128 + ax2), ((by % 8 + ax3) * 32 + (tx % 32 + ax4)) * 4 + vec] + for serial in T.serial(16): + for vec in T.vectorized(4): + A_shared_dyn[(((serial * 8 + ty_1) * 32 + tx) * 4 + vec) // 128 % 128, (((serial * 8 + ty_1) * 32 + tx) * 4 + vec) % 128] = A_local[(serial * 8 + (tx * 4 + vec) // 128 + ty_1) % 128 // 8 + (((tx * 4 + vec) // 128 + ty_1) % 8 - ty_1 % 128), ((tx * 4 + vec) % 128 // 4 - tx % 32) * 4 + vec % 4] + with T.block("B"): + for ax0, ax1 in T.grid(128, 128): + B[bx * 128 + ax0, by * 128 + ax1] = A_shared_dyn[ax0, ax1] + +@tvm.script.ir_module +class TransformedSharedToWmma: + @T.prim_func + def main() -> None: + s0 = T.var("int32") + s1 = T.var("int32") + # body + with T.block("root"): + T.block_attr({"warp_execution":True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer([128, 128], dtype="float16", strides=[136, 1], scope="shared.dyn") + A_wmma = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a") + with T.block("C_shared"): + T.reads(A_shared_dyn[0 : 128, 0 : 128]) + T.writes(A_wmma[0 : 128, 0 : 128]) + T.block_attr({"auto_copy":1}) + for ax00, ax10 in T.grid(8, 8): + with T.block("wmma_load"): + T.reads(A_shared_dyn[ax00 * 16 : ax00 * 16 + 16, ax10 * 16 : ax10 * 16 + 16]) + T.writes(A_wmma[ax00 * 16 : ax00 * 16 + 16, ax10 * 16 : ax10 * 16 + 16]) + src = T.match_buffer(A_shared_dyn[ax00 * 16 : ax00 * 16 + 16, ax10 * 16 : ax10 * 16 + 16], [16, 16], dtype="float16", strides=[s1, s0], scope="shared.dyn", offset_factor=16) + tgt = T.match_buffer(A_wmma[ax00 * 16 : ax00 * 16 + 16, ax10 * 16 : ax10 * 16 + 16], [16, 16], dtype="float16", scope="wmma.matrix_a", offset_factor=16) + T.evaluate(T.tvm_load_matrix_sync(tgt.data, 16, 16, 16, tgt.elem_offset // 256 + tgt.elem_offset % 256 // 16, T.tvm_access_ptr(T.type_annotation(dtype="float16"), src.data, src.elem_offset, s1 * 16, 1, dtype="handle"), s1, "row_major", dtype="handle")) + +@tvm.script.ir_module +class TransformedWmmaToShared: + @T.prim_func + def main() -> None: + s0 = T.var("int32") + s1 = T.var("int32") + # body + with T.block("root"): + T.block_attr({"warp_execution":True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + C_accum = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") + C_shared = T.alloc_buffer([128, 128], dtype="float32", strides=[136, 1], scope="shared.dyn") + with T.block("A_wmma"): + T.reads(C_accum[0 : 128, 0 : 128]) + T.writes(C_shared[0 : 128, 0 : 128]) + T.block_attr({"auto_copy":1}) + for ax00, ax10 in T.grid(8, 8): + with T.block("wmma_store"): + T.reads(C_accum[ax00 * 16 : ax00 * 16 + 16, ax10 * 16 : ax10 * 16 + 16]) + T.writes(C_shared[ax00 * 16 : ax00 * 16 + 16, ax10 * 16 : ax10 * 16 + 16]) + src = T.match_buffer(C_accum[ax00 * 16 : ax00 * 16 + 16, ax10 * 16 : ax10 * 16 + 16], [16, 16], dtype="float32", scope="wmma.accumulator", offset_factor=16) + tgt = T.match_buffer(C_shared[ax00 * 16 : ax00 * 16 + 16, ax10 * 16 : ax10 * 16 + 16], [16, 16], dtype="float32", strides=[s1, s0], scope="shared.dyn", offset_factor=16) + T.evaluate(T.tvm_store_matrix_sync(src.data, 16, 16, 16, src.elem_offset // 256 + src.elem_offset % 256 // 16, T.tvm_access_ptr(T.type_annotation(dtype="float32"), tgt.data, tgt.elem_offset, s1 * 16, 2, dtype="handle"), s1, "row_major", dtype="handle")) + +@tvm.script.ir_module +class TransformedWmmaToGlobal: + @T.prim_func + def main(C: T.Buffer[(1024, 1024), "float32"]) -> None: + s0 = T.var("int32") + s1 = T.var("int32") + # body + with T.block("root"): + T.reads() + T.writes(C[0 : 1024, 0 : 1024]) + T.block_attr({"warp_execution":True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + T.reads() + T.writes(C[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) + C_accum = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") + with T.block("C_global"): + T.reads(C_accum[0 : 128, 0 : 128]) + T.writes(C[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) + T.block_attr({"auto_copy":1, "vector_bytes":16}) + C_shared_dyn = T.alloc_buffer([16, 128], dtype="float32", strides=[136, 1], scope="shared.dyn") + for ax0_0 in T.serial(8): + for ax1_0 in T.serial(8): + with T.block("wmma_store"): + T.reads(C_accum[ax0_0 * 16 : ax0_0 * 16 + 16, ax1_0 * 16 : ax1_0 * 16 + 16]) + T.writes(C_shared_dyn[(ax0_0 // 8 + bx) % 8 * 16 + ax0_0 % 8 * 16 - ax0_0 % 64 * 16 - bx % 8 * 16 : (ax0_0 // 8 + bx) % 8 * 16 + ax0_0 % 8 * 16 - ax0_0 % 64 * 16 - bx % 8 * 16 + 16, (ax1_0 // 8 + by) % 8 * 128 + ax1_0 % 8 * 16 - by % 8 * 128 : (ax1_0 // 8 + by) % 8 * 128 + ax1_0 % 8 * 16 - by % 8 * 128 + 16]) + src = T.match_buffer(C_accum[ax0_0 * 16 : ax0_0 * 16 + 16, ax1_0 * 16 : ax1_0 * 16 + 16], [16, 16], dtype="float32", scope="wmma.accumulator", offset_factor=16) + tgt = T.match_buffer(C_shared_dyn[(ax0_0 // 8 + bx) % 8 * 16 + ax0_0 % 8 * 16 - ax0_0 % 64 * 16 - bx % 8 * 16 : (ax0_0 // 8 + bx) % 8 * 16 + ax0_0 % 8 * 16 - ax0_0 % 64 * 16 - bx % 8 * 16 + 16, (ax1_0 // 8 + by) % 8 * 128 + ax1_0 % 8 * 16 - by % 8 * 128 : (ax1_0 // 8 + by) % 8 * 128 + ax1_0 % 8 * 16 - by % 8 * 128 + 16], [16, 16], dtype="float32", strides=[s1, s0], scope="shared.dyn", offset_factor=16) + T.evaluate(T.tvm_store_matrix_sync(src.data, 16, 16, 16, src.elem_offset // 256 + src.elem_offset % 256 // 16, T.tvm_access_ptr(T.type_annotation(dtype="float32"), tgt.data, tgt.elem_offset, s1 * 16, 2, dtype="handle"), s1, "row_major", dtype="handle")) + for ax0_ax1_ax2_ax3_ax4_ax5_fused_0 in T.serial(2): + for ax0_ax1_ax2_ax3_ax4_ax5_fused_1 in T.thread_binding(8, thread="threadIdx.y"): + for ax0_ax1_ax2_ax3_ax4_ax5_fused_2 in T.thread_binding(32, thread="threadIdx.x"): + for ax0_ax1_ax2_ax3_ax4_ax5_fused_3 in T.vectorized(4): + C[((bx % 8 + 0) * 8 + (ax0_0 % 64 + 0)) * 16 + (((ax0_ax1_ax2_ax3_ax4_ax5_fused_0 * 8 + ax0_ax1_ax2_ax3_ax4_ax5_fused_1) * 32 + ax0_ax1_ax2_ax3_ax4_ax5_fused_2) * 4 + ax0_ax1_ax2_ax3_ax4_ax5_fused_3) // 16 // 8 % 16, ((by % 8 + 0) * 8 + (((ax0_ax1_ax2_ax3_ax4_ax5_fused_0 * 8 + ax0_ax1_ax2_ax3_ax4_ax5_fused_1) * 32 + ax0_ax1_ax2_ax3_ax4_ax5_fused_2) * 4 + ax0_ax1_ax2_ax3_ax4_ax5_fused_3) // 16 % 8) * 16 + (((ax0_ax1_ax2_ax3_ax4_ax5_fused_0 * 8 + ax0_ax1_ax2_ax3_ax4_ax5_fused_1) * 32 + ax0_ax1_ax2_ax3_ax4_ax5_fused_2) * 4 + ax0_ax1_ax2_ax3_ax4_ax5_fused_3) % 16] = C_shared_dyn[(0 + 0) * 16 + (((ax0_ax1_ax2_ax3_ax4_ax5_fused_0 * 8 + ax0_ax1_ax2_ax3_ax4_ax5_fused_1) * 32 + ax0_ax1_ax2_ax3_ax4_ax5_fused_2) * 4 + ax0_ax1_ax2_ax3_ax4_ax5_fused_3) // 16 // 8 % 16, (0 * 8 + (((ax0_ax1_ax2_ax3_ax4_ax5_fused_0 * 8 + ax0_ax1_ax2_ax3_ax4_ax5_fused_1) * 32 + ax0_ax1_ax2_ax3_ax4_ax5_fused_2) * 4 + ax0_ax1_ax2_ax3_ax4_ax5_fused_3) // 16 % 8) * 16 + (((ax0_ax1_ax2_ax3_ax4_ax5_fused_0 * 8 + ax0_ax1_ax2_ax3_ax4_ax5_fused_1) * 32 + ax0_ax1_ax2_ax3_ax4_ax5_fused_2) * 4 + ax0_ax1_ax2_ax3_ax4_ax5_fused_3) % 16] + + +def _check(original, transformed): + mod = tvm.tir.transform.LowerAutoCopy()(original) + tvm.ir.assert_structural_equal(mod, transformed, True) + + +def test_coalesce_vectorize(): + _check(GlobalToShared, TransformedGlobalToShared) + + +def test_inverse(): + _check(SharedToGlobal, TransformedSharedToGlobal) + + +def test_local_stage(): + _check(GlobalToSharedWithLocalStage, TransformedGlobalToSharedWithLocalStage) + + +def test_rewrite_shared_to_wmma(): + _check(SharedToWmma, TransformedSharedToWmma) + + +def test_rewrite_wmma_to_shared(): + _check(WmmaToShared, TransformedWmmaToShared) + + +def test_rewrite_wmma_to_global(): + _check(WmmaToGlobal, TransformedWmmaToGlobal) + + +def verify_single_allocation(stmt, alloc_size=None): + num_alloc = [0] + alloc_extents = [] + + def verify(n): + if ( + isinstance(n, tvm.tir.Allocate) + and n.buffer_var.type_annotation.storage_scope == "shared.dyn" + ): + num_alloc[0] += 1 + alloc_extents.append(n.extents[0]) + + tvm.tir.stmt_functor.post_order_visit(stmt, verify) + assert num_alloc[0] == 1 + + if alloc_size: + assert alloc_extents[0] == alloc_size + + +def test_auto_padding(): + mod = tvm.tir.transform.LowerAutoCopy()(Transpose) + mod = tvm.tir.transform.FlattenBuffer()(mod) + verify_single_allocation(mod['main'].body, 16 * 130) + + +if __name__ == "__main__": + test_coalesce_vectorize() + test_inverse() + test_local_stage() + test_rewrite_shared_to_wmma() + test_rewrite_wmma_to_shared() + test_rewrite_wmma_to_global() + test_auto_padding() diff --git a/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py b/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py new file mode 100644 index 000000000000..6217d2f0989a --- /dev/null +++ b/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm.script import tir as T + + +@tvm.script.ir_module +class Before: + @T.prim_func + def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # var definition + threadIdx_x = T.env_thread("threadIdx.x") + blockIdx_x = T.env_thread("blockIdx.x") + # body + T.launch_thread(blockIdx_x, 64) + conv2d_transpose_nhwc_local = T.allocate([8], "float32", "local") + PadInput_shared = T.allocate([768], "float32", "shared") + weight_shared = T.allocate([4096], "float32", "shared") + T.launch_thread(threadIdx_x, 32) + for i2_3_init, i1_4_init, i2_4_init in T.grid(2, 2, 2): + T.store(conv2d_transpose_nhwc_local, i1_4_init * 4 + i2_3_init * 2 + i2_4_init, T.float32(0), True) + for i6_0 in T.serial(16): + for ax0_ax1_ax2_ax3_fused_0 in T.serial(24): + T.store(PadInput_shared, ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x, T.if_then_else(128 <= ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x and ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x < 640 and 1 <= blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 and blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 < 5, T.load("float32", inputs.data, blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560), T.float32(0), dtype="float32"), True) + for ax0_ax1_ax2_ax3_fused_0 in T.serial(32): + T.store(weight_shared, T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4), T.load("float32x4", weight.data, T.ramp((ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) // 256 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) % 256 // 8 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4), T.broadcast(True, 4)), T.broadcast(True, 4)) + for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2): + T.store(conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4, T.load("float32", conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4) + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, T.load("float32", PadInput_shared, threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2), T.float32(0), dtype="float32") * T.load("float32", weight_shared, i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024), True) + for ax1, ax2 in T.grid(2, 4): + T.store(conv2d_transpose_nhwc.data, threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8, T.load("float32", conv2d_transpose_nhwc_local, ax1 * 4 + ax2), True) + + +@tvm.script.ir_module +class After: + @T.prim_func + def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # var definition + threadIdx_x = T.env_thread("threadIdx.x") + blockIdx_x = T.env_thread("blockIdx.x") + # body + T.launch_thread(blockIdx_x, 64) + conv2d_transpose_nhwc_local = T.allocate([8], "float32", "local") + PadInput_shared = T.allocate([768], "float32", "shared") + weight_shared = T.allocate([4096], "float32", "shared") + T.launch_thread(threadIdx_x, 32) + for i2_3_init, i1_4_init, i2_4_init in T.grid(2, 2, 2): + T.store(conv2d_transpose_nhwc_local, i1_4_init * 4 + i2_3_init * 2 + i2_4_init, T.float32(0), True) + for i6_0 in T.serial(16): + for ax0_ax1_ax2_ax3_fused_0 in T.serial(24): + T.store(PadInput_shared, ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x, T.if_then_else(128 <= ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x and ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x < 640 and 1 <= blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 and blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 < 5, T.load("float32", inputs.data, blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560), T.float32(0), dtype="float32"), True) + for ax0_ax1_ax2_ax3_fused_0 in T.serial(32): + T.store(weight_shared, T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4), T.load("float32x4", weight.data, T.ramp((ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) // 256 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 16 + threadIdx_x // 2) % 32 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4), T.broadcast(True, 4)), T.broadcast(True, 4)) + for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2): + T.store(conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4, T.load("float32", conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4) + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, T.load("float32", PadInput_shared, threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2), T.float32(0), dtype="float32") * T.load("float32", weight_shared, i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024), True) + for ax1, ax2 in T.grid(2, 4): + T.store(conv2d_transpose_nhwc.data, threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8, T.load("float32", conv2d_transpose_nhwc_local, ax1 * 4 + ax2), True) + + +def tesd_renormalize_split_pattern(): + after = tvm.tir.transform.RenomalizeSplitPattern()(Before) + tvm.ir.assert_structural_equal(after, After) + + +if __name__ == "__main__": + tesd_renormalize_split_pattern()