From 96416c4941df0ba292dcf2f7ddb35c8909860341 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Fri, 4 Feb 2022 07:12:29 +0800 Subject: [PATCH] [TIR][Schedule] Update compact_dataflow constraint (#10158) --- .../schedule_rule/auto_inline.cc | 5 +- .../schedule_rule/random_compute_location.cc | 5 +- src/tir/schedule/analysis.h | 23 +++-- src/tir/schedule/analysis/analysis.cc | 85 +++++++++---------- .../schedule/primitive/blockize_tensorize.cc | 6 +- .../schedule/primitive/cache_read_write.cc | 6 +- src/tir/schedule/primitive/compute_at.cc | 7 +- src/tir/schedule/primitive/compute_inline.cc | 8 +- src/tir/schedule/primitive/for_kind.cc | 6 +- src/tir/schedule/primitive/get_block_loop.cc | 6 +- src/tir/schedule/primitive/reduction.cc | 6 +- .../unittest/test_tir_schedule_compute_at.py | 52 +++++++++++- .../unittest/test_tir_schedule_for_kind.py | 61 +++++++++++++ 13 files changed, 189 insertions(+), 87 deletions(-) diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc index 38156f86e6cb..0cfe35298dd6 100644 --- a/src/meta_schedule/schedule_rule/auto_inline.cc +++ b/src/meta_schedule/schedule_rule/auto_inline.cc @@ -121,9 +121,8 @@ inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch, } } // Last cond: Check inline into the consumers or the spatial producer - tir::StmtSRef scope_block = tir::GetScopeRoot(sch->state(), block_sref, // - /*require_stage_pipeline=*/false, // - /*require_subtree_compact_dataflow=*/false); + tir::StmtSRef scope_block = tir::GetScopeRoot(sch->state(), block_sref, + /*require_stage_pipeline=*/false); if (into_consumer) { Array consumer_srefs = GetConsumers(state, block_sref); if (!consumer_srefs.empty() && CanComputeInline(state, block_sref)) { diff --git a/src/meta_schedule/schedule_rule/random_compute_location.cc b/src/meta_schedule/schedule_rule/random_compute_location.cc index 957ad89af106..e4b5d5bde256 100644 --- a/src/meta_schedule/schedule_rule/random_compute_location.cc +++ b/src/meta_schedule/schedule_rule/random_compute_location.cc @@ -67,9 +67,8 @@ class RandomComputeLocationNode : public ScheduleRuleNode { return false; } // Cond 2. The block should be the direct child block of the root block. - if (GetScopeRoot(sch->state(), block_sref, // - /*require_stage_pipeline=*/false, // - /*require_subtree_compact_dataflow=*/false) + if (GetScopeRoot(sch->state(), block_sref, + /*require_stage_pipeline=*/false) ->parent != nullptr) { return false; } diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index cdbb70bef6dd..dc8adb144b4b 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -76,20 +76,12 @@ StmtSRef GetSRefTreeRoot(const StmtSRef& sref); * \param self The schedule state * \param sref The sref whose scope is to be checked * \param require_stage_pipeline A boolean indicating whether to check stage pipeline - * \param require_subtree_compact_dataflow A boolean indicating whether to check - * subtree compact dataflow property. The scope root may have one or more subtrees rooted at - * its direct children, and this property requires all the blocks of the subtree - * that the specified sref is in to be complete block or reduction block. 
* \throw ScheduleError if * 1) the sref has been the root of the AST (so it has no scope root), or * 2) require_stage_pipeline = true, but its scope root is not a stage pipeline - * 3) require_subtree_compact_dataflow = true, but the subtree that the sref is in doesn't satisfy - * the compact dataflow condition, i.e. a block in the subtree is neither complete block nor - * reduction block * \return The block sref to the scope root */ -StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, bool require_stage_pipeline, - bool require_subtree_compact_dataflow); +StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, bool require_stage_pipeline); /*! * \brief The information of a block scope, including the leaf blocks, @@ -173,6 +165,19 @@ void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref); +/*! + * \brief Check the subtree compact dataflow property. The scope root may have one or more subtrees + * rooted at its direct children, and this property requires all the blocks of the subtree + * that the specified sref is in to be complete block or reduction block. + * \param self The schedule state + * \param subtree_root The sref of the subtree root to be checked + * \param scope_root_sref The scope root of the block + * \throw ScheduleError If the subtree that the sref is in doesn't satisfy the compact + * dataflow condition, i.e. a block in the subtree is neither complete block nor + * reduction block + */ +void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subtree_root, + const StmtSRef& scope_root_sref); /*! * \brief Check if the block is an output block, i.e. 
the block writes to at least a buffer that is * not allocated under the current scope diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index afdff9d5f832..c7ed67187793 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -47,9 +47,8 @@ const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_bl /******** Scope ********/ -StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, // - bool require_stage_pipeline, // - bool require_subtree_compact_dataflow) { +StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, + bool require_stage_pipeline) { class RootBlockError : public ScheduleError { public: explicit RootBlockError(IRModule mod) : mod_(mod) {} @@ -85,31 +84,6 @@ Definition of a scope that is a stage pipeline: Block block_; }; - class NotCompactDataFlowError : public ScheduleError { - public: - explicit NotCompactDataFlowError(IRModule mod, Stmt subtree_root, Block violate_block) - : mod_(std::move(mod)), - subtree_root_(std::move(subtree_root)), - violate_block_(std::move(violate_block)) { - ICHECK(subtree_root_->IsInstance() || subtree_root_->IsInstance()); - } - String FastErrorString() const final { - return "ScheduleError: The queried subtree root in SRef tree does not have compact dataflow, " - "because some of its child block on SRef tree is neither a complete block nor a " - "reduction block"; - } - String DetailRenderTemplate() const final { - return "The queried subtree root {0} in SRef tree does not have compact dataflow, because " - "its child block {1} on SRef tree is neither a complete block nor a reduction block"; - } - IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {subtree_root_, violate_block_}; } - - IRModule mod_; - Stmt subtree_root_; - Block violate_block_; - }; - StmtSRef scope_root_sref{nullptr}; StmtSRef scope_root_subtree{nullptr}; // Step 1. Find the scope root and the subtree that the given sref is in @@ -135,18 +109,6 @@ Definition of a scope that is a stage pipeline: throw NotStagePipelineError(self->mod, GetRef(block)); } } - // Step 3. 
Handle `require_subtree_compact_dataflow` - if (require_subtree_compact_dataflow) { - Array child_block_srefs = GetChildBlockSRefOnSRefTree(self, scope_root_subtree); - for (const StmtSRef& block_sref : child_block_srefs) { - if (!IsCompleteBlock(self, block_sref, scope_root_sref) && - !IsReductionBlock(self, block_sref, scope_root_sref)) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); - throw NotCompactDataFlowError(self->mod, GetRef(scope_root_subtree->stmt), - GetRef(block)); - } - } - } return scope_root_sref; } @@ -401,6 +363,44 @@ void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& bl reduction_block_error_code); } +void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subtree_root, + const StmtSRef& scope_root_sref) { + class NotCompactDataFlowError : public ScheduleError { + public: + explicit NotCompactDataFlowError(IRModule mod, Stmt subtree_root, Block violate_block) + : mod_(std::move(mod)), + subtree_root_(std::move(subtree_root)), + violate_block_(std::move(violate_block)) { + ICHECK(subtree_root_->IsInstance() || subtree_root_->IsInstance()); + } + String FastErrorString() const final { + return "ScheduleError: The queried subtree root in SRef tree does not have compact dataflow, " + "because some of its child block on SRef tree is neither a complete block nor a " + "reduction block"; + } + String DetailRenderTemplate() const final { + return "The queried subtree root {0} in SRef tree does not have compact dataflow, because " + "its child block {1} on SRef tree is neither a complete block nor a reduction block"; + } + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {subtree_root_, violate_block_}; } + + IRModule mod_; + Stmt subtree_root_; + Block violate_block_; + }; + + Array child_block_srefs = GetChildBlockSRefOnSRefTree(self, subtree_root); + for (const StmtSRef& block_sref : child_block_srefs) { + if (!IsCompleteBlock(self, block_sref, scope_root_sref) && + !IsReductionBlock(self, block_sref, scope_root_sref)) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + throw NotCompactDataFlowError(self->mod, GetRef(subtree_root->stmt), + GetRef(block)); + } + } +} + bool IsOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref) { const BlockNode* scope_root = TVM_SREF_TO_BLOCK(scope_root, scope_root_sref); @@ -1843,9 +1843,8 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // } // Cond 2. The block is a reduction block and has trivial binding. 
- const StmtSRef& scope_sref = GetScopeRoot(self, block_sref, // - /*require_stage_pipeline=*/false, // - /*require_subtree_compact_dataflow=*/false); + const StmtSRef& scope_sref = GetScopeRoot(self, block_sref, + /*require_stage_pipeline=*/false); if (!IsReductionBlock(self, block_sref, scope_sref) // || !IsTrivialBinding(self, block_sref) // || HasBeenMultiLevelTiled(block_sref)) { diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index bbeb9caaab9b..bbabcbeb4592 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -515,8 +515,7 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) { // Step 8: Update the cached flags StmtSRef outer_block_sref = self->stmt2ref.at(outer_realize->block.get()); - StmtSRef scope_root = tir::GetScopeRoot(self, outer_block_sref, /*require_stage_pipeline=*/false, - /*require_subtree_compact_dataflow=*/false); + StmtSRef scope_root = tir::GetScopeRoot(self, outer_block_sref, /*require_stage_pipeline=*/false); bool scope_block_affine_binding = self->IsAffineBlockBinding(scope_root); self->UpdateScopeBlockInfo(tir::GetBlockRealize(self, scope_root)); self->block_info[scope_root].affine_binding = scope_block_affine_binding; @@ -629,8 +628,7 @@ void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref, self->Replace(block_sref, new_block, {{block_realize->block, new_block}}); // Step 6: Update the cached flags. - StmtSRef scope_root = tir::GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false, - /*require_subtree_compact_dataflow=*/false); + StmtSRef scope_root = tir::GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); self->UpdateScopeBlockInfo(static_cast(scope_root->stmt)->body); } diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index 4a80279d97cb..05695a8c4dc4 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -631,8 +631,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); Buffer read_buffer = GetNthAccessBuffer(self, GetRef(block), read_buffer_index, /*is_write=*/false); - StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true, - /*require_subtree_compact_dataflow=*/false); + StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref); // Step 2. Create CacheStageInfo @@ -703,8 +702,7 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); Buffer write_buffer = GetNthAccessBuffer(self, GetRef(block), write_buffer_index, /*is_write=*/true); - StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true, - /*require_subtree_compact_dataflow=*/false); + StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); // Step 2. 
Creating CacheStageInfo CacheStageInfo info; diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc index 00886e8f8a22..0063a9ab43f0 100644 --- a/src/tir/schedule/primitive/compute_at.cc +++ b/src/tir/schedule/primitive/compute_at.cc @@ -456,14 +456,15 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); // Step 1. Bunch of checks - // Check condition 1) and 2): stage pipeline and subtree compact dataflow + // Check condition 1) : scope stage pipeline StmtSRef scope_root_sref = GetScopeRoot(self, block_sref, - /*require_stage_pipeline=*/true, - /*require_subtree_compact_dataflow=*/true); + /*require_stage_pipeline=*/true); Block scope_root = GetRef(scope_root_sref->StmtAs()); BlockScope scope = self->GetBlockScope(scope_root_sref); Array producer_srefs = GetProducers(block_sref, scope); Array consumer_srefs = GetConsumers(block_sref, scope); + // Check condition 2) : `block` is a complete or reduction block + CheckCompleteOrReductionBlock(self, block_sref, scope_root_sref); // Check condition 3): `block` and `loop` are under the same scope, // and `loop` is not the ancestor of `block` NotInSameScopeError::CheckAndBindLoopDomain(self, block_sref, loop_sref, scope_root_sref, diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index fe2c679142b7..9a9860b42bc6 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -548,9 +548,8 @@ void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref, Block producer_block = GetRef(_producer_block); Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, producer_block); // Step 1. Get the scope block - StmtSRef scope_root_sref = GetScopeRoot(self, producer_block_sref, // - /*require_stage_pipeline=*/true, - /*require_subtree_compact_dataflow=*/false); + StmtSRef scope_root_sref = GetScopeRoot(self, producer_block_sref, + /*require_stage_pipeline=*/true); // Step 2. Check completeness CheckNotOutputBlock(self, producer_block_sref, scope_root_sref); CheckCompleteBlock(self, producer_block_sref, scope_root_sref); @@ -593,8 +592,7 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block Block consumer_block = GetRef(_consumer_block); // Step 1. Get the scope block StmtSRef scope_root_sref = GetScopeRoot(self, consumer_block_sref, // - /*require_stage_pipeline=*/true, - /*require_subtree_compact_dataflow=*/false); + /*require_stage_pipeline=*/true); Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleRead(self, consumer_block, scope_root_sref); // Step 2. Check completeness diff --git a/src/tir/schedule/primitive/for_kind.cc b/src/tir/schedule/primitive/for_kind.cc index bff429312f31..333d78346453 100644 --- a/src/tir/schedule/primitive/for_kind.cc +++ b/src/tir/schedule/primitive/for_kind.cc @@ -157,9 +157,9 @@ void ParallelizeComputation(const ScheduleState& self, const StmtSRef& loop_sref * parallelized/vectorized/bound. */ // Step 1. Check whether the subtree rooted from the `loop` in sref tree has compact data flow. 
- GetScopeRoot(self, loop_sref, // - /*require_stage_pipeline=*/true, - /*require_subtree_compact_dataflow=*/true); + StmtSRef scope_root_sref = GetScopeRoot(self, loop_sref, + /*require_stage_pipeline=*/true); + CheckSubtreeCompactDataflow(self, loop_sref, scope_root_sref); // Step 2. Check whether the loop can be parallelized/vectorized/bound with regard to each // underlying block. diff --git a/src/tir/schedule/primitive/get_block_loop.cc b/src/tir/schedule/primitive/get_block_loop.cc index c044de3bc644..a13e52515708 100644 --- a/src/tir/schedule/primitive/get_block_loop.cc +++ b/src/tir/schedule/primitive/get_block_loop.cc @@ -78,8 +78,7 @@ Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent } Array GetProducers(const ScheduleState& self, const StmtSRef& block_sref) { - StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false, - /*require_stage_pipeline=*/false); + StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); Array edges = self->GetBlockScope(scope_root)->GetDepsByDst(block_sref); Array results; results.reserve(edges.size()); @@ -92,8 +91,7 @@ Array GetProducers(const ScheduleState& self, const StmtSRef& block_sr } Array GetConsumers(const ScheduleState& self, const StmtSRef& block_sref) { - StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false, - /*require_stage_pipeline=*/false); + StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); Array edges = self->GetBlockScope(scope_root)->GetDepsBySrc(block_sref); Array results; results.reserve(edges.size()); diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index 72ea199bf436..4b9b78e3b299 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -203,8 +203,7 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, } // Cond 1. Check block is reduction StmtSRef scope_root_sref = GetScopeRoot(self, block_sref, - /*require_stage_pipeline=*/false, - /*require_subtree_compact_dataflow=*/false); + /*require_stage_pipeline=*/false); CheckReductionBlock(self, block_sref, scope_root_sref); // Cond 2. 
Check 'loop' is higher than all the loops related to block var of type reduction LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref); @@ -1009,8 +1008,7 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax const StmtSRef& block_sref = self->stmt2ref.at(block_realize->block.get()); const Block& block = block_realize->block; StmtSRef scope_root = GetScopeRoot(self, block_sref, // - /*require_stage_pipeline=*/true, - /*require_subtree_compact_dataflow=*/false); + /*require_stage_pipeline=*/true); CheckReductionBlock(self, block_sref, scope_root); const ForNode* rf_loop = TVM_SREF_TO_FOR(rf_loop, rf_loop_sref); if (rf_loop->kind != ForKind::kSerial) { diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py b/tests/python/unittest/test_tir_schedule_compute_at.py index 240a1cc9f53b..756407dfd72d 100644 --- a/tests/python/unittest/test_tir_schedule_compute_at.py +++ b/tests/python/unittest/test_tir_schedule_compute_at.py @@ -628,6 +628,7 @@ def factorized_after_reverse_compute_at(a: T.handle, b: T.handle) -> None: B[vi] = 0.0 B[vi] = B[vi] + B_rf_local[vk, vi] + @T.prim_func def not_all_compact_data_flow(a: T.handle, c: T.handle): A = T.match_buffer(a, (128, 128), "float32") @@ -645,6 +646,7 @@ def not_all_compact_data_flow(a: T.handle, c: T.handle): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj * 2 + 1] = B[vi, vj * 2 + 1] * 2.0 + @T.prim_func def not_all_compact_data_flow_after_compute_at(a: T.handle, c: T.handle): A = T.match_buffer(a, (128, 128), "float32") @@ -663,6 +665,7 @@ def not_all_compact_data_flow_after_compute_at(a: T.handle, c: T.handle): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj * 2 + 1] = B[vi, vj * 2 + 1] * 2.0 + @T.prim_func def fail_subtree_compact_dataflow(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (128, 128), "float32") @@ -757,6 +760,42 @@ def read_out_of_bound_after_compute_at(a: T.handle, c: T.handle) -> None: C[v] = T.if_then_else(v < 15, T.max(B[v], B[v + 1]), B[v], dtype="float32") +@T.prim_func +def multi_reduction(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(), "float32"]): + B = T.alloc_buffer((16, ), dtype="float32") + for i, k in T.grid(16, 16): + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) + with T.init(): + B[vi] = 0.0 + B[vi] += A[vi, vk] + for k in T.grid(16): + with T.block("C"): + vk = T.axis.remap("R", [k]) + with T.init(): + C[()] = 0.0 + C[()] += B[vk] + + +@T.prim_func +def multi_reduction_after_compute_at( + A: T.Buffer[(16, 16), "float32"], + C:T.Buffer[(), "float32"], +): + B = T.alloc_buffer((16, ), dtype="float32") + for k in T.grid(16): + for kk in T.grid(16): + with T.block("B"): + vi, vk = T.axis.remap("SR", [k, kk]) + with T.init(): + B[vi] = 0.0 + B[vi] += A[vi, vk] + with T.block("C"): + vk = T.axis.remap("R", [k]) + with T.init(): + C[()] = 0.0 + C[()] += B[vk] + # pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks # fmt: on @@ -833,6 +872,15 @@ def test_compute_at_cuda_matmul_4(): verify_trace_roundtrip(sch=sch, mod=cuda_matmul_4) +def test_compute_at_reduction_block(): + sch = tir.Schedule(multi_reduction, debug_mask="all") + block = sch.get_block("B") + (loop,) = sch.get_loops(sch.get_block("C")) + sch.compute_at(block, loop, preserve_unit_loops=False) + tvm.ir.assert_structural_equal(multi_reduction_after_compute_at, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=multi_reduction) + + def test_reverse_compute_at_tiled(): sch = 
tir.Schedule(tiled, debug_mask="all") block = sch.get_block("C") @@ -878,11 +926,11 @@ def test_compact_dataflow(): verify_trace_roundtrip(sch=sch, mod=not_all_compact_data_flow) -def test_fail_subtree_compact_dataflow(): +def test_fail_subtree_complete_block(): sch = tir.Schedule(fail_subtree_compact_dataflow, debug_mask="all") block = sch.get_block("B_0") loop, _ = sch.get_loops(sch.get_block("C")) - with pytest.raises(tvm.tir.ScheduleError, match="compact dataflow"): + with pytest.raises(tvm.tir.ScheduleError, match="complete block"): sch.compute_at(block, loop) diff --git a/tests/python/unittest/test_tir_schedule_for_kind.py b/tests/python/unittest/test_tir_schedule_for_kind.py index 93876c668913..caecde05b40f 100644 --- a/tests/python/unittest/test_tir_schedule_for_kind.py +++ b/tests/python/unittest/test_tir_schedule_for_kind.py @@ -277,6 +277,59 @@ def thread_bound_block_inside_init(a: T.handle, b: T.handle) -> None: B[vi, vj] = B[vi, vj] + A[vi, vj, vk] +@T.prim_func +def decomposed_gemm( + A: T.Buffer[(16, 16), "float32"], + B: T.Buffer[(16, 16), "float32"], + C: T.Buffer[(16, 16), "float32"], +): + local = T.alloc_buffer((16, 16), "float32") + for i, j in T.grid(4, 4): + for ii, jj in T.grid(4, 4): + with T.block("init"): + vi = T.axis.S(16, i * 4 + ii) + vj = T.axis.S(16, j * 4 + jj) + local[vi, vj] = 0 + for k, ii, jj in T.grid(16, 4, 4): + with T.block("update"): + vi = T.axis.S(16, i * 4 + ii) + vj = T.axis.S(16, j * 4 + jj) + vk = T.axis.R(16, k) + local[vi, vj] += A[vi, vk] * B[vj, vk] + for ii, jj in T.grid(4, 4): + with T.block("C"): + vi = T.axis.S(16, i * 4 + ii) + vj = T.axis.S(16, j * 4 + jj) + C[vi, vj] = local[vi, vj] + + +@T.prim_func +def decomposed_gemm_after_vectorize( + A: T.Buffer[(16, 16), "float32"], + B: T.Buffer[(16, 16), "float32"], + C: T.Buffer[(16, 16), "float32"], +): + local = T.alloc_buffer((16, 16), "float32") + for i, j in T.grid(4, 4): + for ii, jj in T.grid(4, 4): + with T.block("init"): + vi = T.axis.S(16, i * 4 + ii) + vj = T.axis.S(16, j * 4 + jj) + local[vi, vj] = 0 + for k, ii, jj in T.grid(16, 4, 4): + with T.block("update"): + vi = T.axis.S(16, i * 4 + ii) + vj = T.axis.S(16, j * 4 + jj) + vk = T.axis.R(16, k) + local[vi, vj] += A[vi, vk] * B[vj, vk] + for ii in range(4): + for jj in T.vectorized(4): + with T.block("C"): + vi = T.axis.S(16, i * 4 + ii) + vj = T.axis.S(16, j * 4 + jj) + C[vi, vj] = local[vi, vj] + + # pylint: enable=no-member,invalid-name,unused-variable @@ -407,5 +460,13 @@ def test_block_inside_init(): verify_trace_roundtrip(s, mod=block_inside_init) +def test_vectorize_after_decompose(): + s = tir.Schedule(decomposed_gemm, debug_mask="all") + jj = s.get_loops(s.get_block("C"))[-1] + s.vectorize(jj) + tvm.ir.assert_structural_equal(s.mod["main"], decomposed_gemm_after_vectorize) + verify_trace_roundtrip(s, mod=decomposed_gemm) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))
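
Callers that previously passed the removed require_subtree_compact_dataflow flag to GetScopeRoot now obtain the scope root first and invoke CheckSubtreeCompactDataflow explicitly, exactly as the for_kind.cc hunk above does. A minimal sketch of that calling pattern, assuming the declarations from src/tir/schedule/analysis.h as changed by this patch; the wrapper name ExampleLoopCheck and the include line are illustrative only:

    // Sketch of a caller inside src/tir/schedule (namespace tvm::tir);
    // "../utils.h" is assumed to transitively provide the analysis.h declarations.
    #include "../utils.h"

    namespace tvm {
    namespace tir {

    void ExampleLoopCheck(const ScheduleState& self, const StmtSRef& loop_sref) {
      // Before this patch the check was folded into GetScopeRoot:
      //   GetScopeRoot(self, loop_sref,
      //                /*require_stage_pipeline=*/true,
      //                /*require_subtree_compact_dataflow=*/true);
      // After this patch: fetch the scope root, then run the subtree check separately.
      StmtSRef scope_root_sref = GetScopeRoot(self, loop_sref,
                                              /*require_stage_pipeline=*/true);
      CheckSubtreeCompactDataflow(self, loop_sref, scope_root_sref);
    }

    }  // namespace tir
    }  // namespace tvm

Splitting the check out lets primitives such as compute_at require only that the moved block itself is a complete or reduction block (CheckCompleteOrReductionBlock), while parallel/vectorize/bind keep the stricter whole-subtree requirement.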