Skip to content

Commit

Permalink
[TIR][Schedule] Update compact_dataflow constraint (#10158)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hzfengsy committed Feb 3, 2022
1 parent a8741e2 commit 96416c4
Show file tree
Hide file tree
Showing 13 changed files with 189 additions and 87 deletions.
5 changes: 2 additions & 3 deletions src/meta_schedule/schedule_rule/auto_inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,8 @@ inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch,
}
}
// Last cond: Check inline into the consumers or the spatial producer
tir::StmtSRef scope_block = tir::GetScopeRoot(sch->state(), block_sref, //
/*require_stage_pipeline=*/false, //
/*require_subtree_compact_dataflow=*/false);
tir::StmtSRef scope_block = tir::GetScopeRoot(sch->state(), block_sref,
/*require_stage_pipeline=*/false);
if (into_consumer) {
Array<tir::StmtSRef> consumer_srefs = GetConsumers(state, block_sref);
if (!consumer_srefs.empty() && CanComputeInline(state, block_sref)) {
Expand Down
5 changes: 2 additions & 3 deletions src/meta_schedule/schedule_rule/random_compute_location.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,8 @@ class RandomComputeLocationNode : public ScheduleRuleNode {
return false;
}
// Cond 2. The block should be the direct child block of the root block.
if (GetScopeRoot(sch->state(), block_sref, //
/*require_stage_pipeline=*/false, //
/*require_subtree_compact_dataflow=*/false)
if (GetScopeRoot(sch->state(), block_sref,
/*require_stage_pipeline=*/false)
->parent != nullptr) {
return false;
}
Expand Down
23 changes: 14 additions & 9 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,20 +76,12 @@ StmtSRef GetSRefTreeRoot(const StmtSRef& sref);
* \param self The schedule state
* \param sref The sref whose scope is to be checked
* \param require_stage_pipeline A boolean indicating whether to check stage pipeline
* \param require_subtree_compact_dataflow A boolean indicating whether to check
* subtree compact dataflow property. The scope root may have one or more subtrees rooted at
* its direct children, and this property requires all the blocks of the subtree
* that the specified sref is in to be complete block or reduction block.
* \throw ScheduleError if
* 1) the sref has been the root of the AST (so it has no scope root), or
* 2) require_stage_pipeline = true, but its scope root is not a stage pipeline
* 3) require_subtree_compact_dataflow = true, but the subtree that the sref is in doesn't satisfy
* the compact dataflow condition, i.e. a block in the subtree is neither complete block nor
* reduction block
* \return The block sref to the scope root
*/
StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, bool require_stage_pipeline,
bool require_subtree_compact_dataflow);
StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, bool require_stage_pipeline);

/*!
* \brief The information of a block scope, including the leaf blocks,
Expand Down Expand Up @@ -173,6 +165,19 @@ void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref,
void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& scope_root_sref);

/*!
* \brief Check the subtree compact dataflow property. The scope root may have one or more subtrees
* rooted at its direct children, and this property requires all the blocks of the subtree
* that the specified sref is in to be complete block or reduction block.
* \param self The schedule state
* \param subtree_root The sref of the subtree root to be checked
* \param scope_root_sref The scope root of the block
* \throw ScheduleError If the subtree that the sref is in doesn't satisfy the compact
* dataflow condition, i.e. a block in the subtree is neither complete block nor
* reduction block
*/
void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subtree_root,
const StmtSRef& scope_root_sref);
/*!
* \brief Check if the block is an output block, i.e. the block writes to at least a buffer that is
* not allocated under the current scope
Expand Down
85 changes: 42 additions & 43 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,8 @@ const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_bl

/******** Scope ********/

StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, //
bool require_stage_pipeline, //
bool require_subtree_compact_dataflow) {
StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref,
bool require_stage_pipeline) {
class RootBlockError : public ScheduleError {
public:
explicit RootBlockError(IRModule mod) : mod_(mod) {}
Expand Down Expand Up @@ -85,31 +84,6 @@ Definition of a scope that is a stage pipeline:
Block block_;
};

class NotCompactDataFlowError : public ScheduleError {
public:
explicit NotCompactDataFlowError(IRModule mod, Stmt subtree_root, Block violate_block)
: mod_(std::move(mod)),
subtree_root_(std::move(subtree_root)),
violate_block_(std::move(violate_block)) {
ICHECK(subtree_root_->IsInstance<BlockNode>() || subtree_root_->IsInstance<ForNode>());
}
String FastErrorString() const final {
return "ScheduleError: The queried subtree root in SRef tree does not have compact dataflow, "
"because some of its child block on SRef tree is neither a complete block nor a "
"reduction block";
}
String DetailRenderTemplate() const final {
return "The queried subtree root {0} in SRef tree does not have compact dataflow, because "
"its child block {1} on SRef tree is neither a complete block nor a reduction block";
}
IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {subtree_root_, violate_block_}; }

IRModule mod_;
Stmt subtree_root_;
Block violate_block_;
};

StmtSRef scope_root_sref{nullptr};
StmtSRef scope_root_subtree{nullptr};
// Step 1. Find the scope root and the subtree that the given sref is in
Expand All @@ -135,18 +109,6 @@ Definition of a scope that is a stage pipeline:
throw NotStagePipelineError(self->mod, GetRef<Block>(block));
}
}
// Step 3. Handle `require_subtree_compact_dataflow`
if (require_subtree_compact_dataflow) {
Array<StmtSRef> child_block_srefs = GetChildBlockSRefOnSRefTree(self, scope_root_subtree);
for (const StmtSRef& block_sref : child_block_srefs) {
if (!IsCompleteBlock(self, block_sref, scope_root_sref) &&
!IsReductionBlock(self, block_sref, scope_root_sref)) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
throw NotCompactDataFlowError(self->mod, GetRef<Stmt>(scope_root_subtree->stmt),
GetRef<Block>(block));
}
}
}
return scope_root_sref;
}

Expand Down Expand Up @@ -401,6 +363,44 @@ void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& bl
reduction_block_error_code);
}

void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subtree_root,
const StmtSRef& scope_root_sref) {
class NotCompactDataFlowError : public ScheduleError {
public:
explicit NotCompactDataFlowError(IRModule mod, Stmt subtree_root, Block violate_block)
: mod_(std::move(mod)),
subtree_root_(std::move(subtree_root)),
violate_block_(std::move(violate_block)) {
ICHECK(subtree_root_->IsInstance<BlockNode>() || subtree_root_->IsInstance<ForNode>());
}
String FastErrorString() const final {
return "ScheduleError: The queried subtree root in SRef tree does not have compact dataflow, "
"because some of its child block on SRef tree is neither a complete block nor a "
"reduction block";
}
String DetailRenderTemplate() const final {
return "The queried subtree root {0} in SRef tree does not have compact dataflow, because "
"its child block {1} on SRef tree is neither a complete block nor a reduction block";
}
IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {subtree_root_, violate_block_}; }

IRModule mod_;
Stmt subtree_root_;
Block violate_block_;
};

Array<StmtSRef> child_block_srefs = GetChildBlockSRefOnSRefTree(self, subtree_root);
for (const StmtSRef& block_sref : child_block_srefs) {
if (!IsCompleteBlock(self, block_sref, scope_root_sref) &&
!IsReductionBlock(self, block_sref, scope_root_sref)) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
throw NotCompactDataFlowError(self->mod, GetRef<Stmt>(subtree_root->stmt),
GetRef<Block>(block));
}
}
}

bool IsOutputBlock(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& scope_root_sref) {
const BlockNode* scope_root = TVM_SREF_TO_BLOCK(scope_root, scope_root_sref);
Expand Down Expand Up @@ -1843,9 +1843,8 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, //
}

// Cond 2. The block is a reduction block and has trivial binding.
const StmtSRef& scope_sref = GetScopeRoot(self, block_sref, //
/*require_stage_pipeline=*/false, //
/*require_subtree_compact_dataflow=*/false);
const StmtSRef& scope_sref = GetScopeRoot(self, block_sref,
/*require_stage_pipeline=*/false);
if (!IsReductionBlock(self, block_sref, scope_sref) //
|| !IsTrivialBinding(self, block_sref) //
|| HasBeenMultiLevelTiled(block_sref)) {
Expand Down
6 changes: 2 additions & 4 deletions src/tir/schedule/primitive/blockize_tensorize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -515,8 +515,7 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) {

// Step 8: Update the cached flags
StmtSRef outer_block_sref = self->stmt2ref.at(outer_realize->block.get());
StmtSRef scope_root = tir::GetScopeRoot(self, outer_block_sref, /*require_stage_pipeline=*/false,
/*require_subtree_compact_dataflow=*/false);
StmtSRef scope_root = tir::GetScopeRoot(self, outer_block_sref, /*require_stage_pipeline=*/false);
bool scope_block_affine_binding = self->IsAffineBlockBinding(scope_root);
self->UpdateScopeBlockInfo(tir::GetBlockRealize(self, scope_root));
self->block_info[scope_root].affine_binding = scope_block_affine_binding;
Expand Down Expand Up @@ -629,8 +628,7 @@ void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref,
self->Replace(block_sref, new_block, {{block_realize->block, new_block}});

// Step 6: Update the cached flags.
StmtSRef scope_root = tir::GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false,
/*require_subtree_compact_dataflow=*/false);
StmtSRef scope_root = tir::GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
self->UpdateScopeBlockInfo(static_cast<const BlockNode*>(scope_root->stmt)->body);
}

Expand Down
6 changes: 2 additions & 4 deletions src/tir/schedule/primitive/cache_read_write.cc
Original file line number Diff line number Diff line change
Expand Up @@ -631,8 +631,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
Buffer read_buffer =
GetNthAccessBuffer(self, GetRef<Block>(block), read_buffer_index, /*is_write=*/false);
StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true,
/*require_subtree_compact_dataflow=*/false);
StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);
const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref);

// Step 2. Create CacheStageInfo
Expand Down Expand Up @@ -703,8 +702,7 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
Buffer write_buffer =
GetNthAccessBuffer(self, GetRef<Block>(block), write_buffer_index, /*is_write=*/true);
StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true,
/*require_subtree_compact_dataflow=*/false);
StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);

// Step 2. Creating CacheStageInfo
CacheStageInfo info;
Expand Down
7 changes: 4 additions & 3 deletions src/tir/schedule/primitive/compute_at.cc
Original file line number Diff line number Diff line change
Expand Up @@ -456,14 +456,15 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
// Step 1. Bunch of checks
// Check condition 1) and 2): stage pipeline and subtree compact dataflow
// Check condition 1) : scope stage pipeline
StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
/*require_stage_pipeline=*/true,
/*require_subtree_compact_dataflow=*/true);
/*require_stage_pipeline=*/true);
Block scope_root = GetRef<Block>(scope_root_sref->StmtAs<BlockNode>());
BlockScope scope = self->GetBlockScope(scope_root_sref);
Array<StmtSRef> producer_srefs = GetProducers(block_sref, scope);
Array<StmtSRef> consumer_srefs = GetConsumers(block_sref, scope);
// Check condition 2) : `block` is a complete or reduction block
CheckCompleteOrReductionBlock(self, block_sref, scope_root_sref);
// Check condition 3): `block` and `loop` are under the same scope,
// and `loop` is not the ancestor of `block`
NotInSameScopeError::CheckAndBindLoopDomain(self, block_sref, loop_sref, scope_root_sref,
Expand Down
8 changes: 3 additions & 5 deletions src/tir/schedule/primitive/compute_inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -548,9 +548,8 @@ void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref,
Block producer_block = GetRef<Block>(_producer_block);
Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, producer_block);
// Step 1. Get the scope block
StmtSRef scope_root_sref = GetScopeRoot(self, producer_block_sref, //
/*require_stage_pipeline=*/true,
/*require_subtree_compact_dataflow=*/false);
StmtSRef scope_root_sref = GetScopeRoot(self, producer_block_sref,
/*require_stage_pipeline=*/true);
// Step 2. Check completeness
CheckNotOutputBlock(self, producer_block_sref, scope_root_sref);
CheckCompleteBlock(self, producer_block_sref, scope_root_sref);
Expand Down Expand Up @@ -593,8 +592,7 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block
Block consumer_block = GetRef<Block>(_consumer_block);
// Step 1. Get the scope block
StmtSRef scope_root_sref = GetScopeRoot(self, consumer_block_sref, //
/*require_stage_pipeline=*/true,
/*require_subtree_compact_dataflow=*/false);
/*require_stage_pipeline=*/true);
Buffer inlined_buffer =
NotSingleReadWriteBuffer::GetSingleRead(self, consumer_block, scope_root_sref);
// Step 2. Check completeness
Expand Down
6 changes: 3 additions & 3 deletions src/tir/schedule/primitive/for_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,9 @@ void ParallelizeComputation(const ScheduleState& self, const StmtSRef& loop_sref
* parallelized/vectorized/bound.
*/
// Step 1. Check whether the subtree rooted from the `loop` in sref tree has compact data flow.
GetScopeRoot(self, loop_sref, //
/*require_stage_pipeline=*/true,
/*require_subtree_compact_dataflow=*/true);
StmtSRef scope_root_sref = GetScopeRoot(self, loop_sref,
/*require_stage_pipeline=*/true);
CheckSubtreeCompactDataflow(self, loop_sref, scope_root_sref);

// Step 2. Check whether the loop can be parallelized/vectorized/bound with regard to each
// underlying block.
Expand Down
6 changes: 2 additions & 4 deletions src/tir/schedule/primitive/get_block_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,7 @@ Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef& parent
}

Array<StmtSRef> GetProducers(const ScheduleState& self, const StmtSRef& block_sref) {
StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false,
/*require_stage_pipeline=*/false);
StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
Array<Dependency> edges = self->GetBlockScope(scope_root)->GetDepsByDst(block_sref);
Array<StmtSRef> results;
results.reserve(edges.size());
Expand All @@ -92,8 +91,7 @@ Array<StmtSRef> GetProducers(const ScheduleState& self, const StmtSRef& block_sr
}

Array<StmtSRef> GetConsumers(const ScheduleState& self, const StmtSRef& block_sref) {
StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false,
/*require_stage_pipeline=*/false);
StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
Array<Dependency> edges = self->GetBlockScope(scope_root)->GetDepsBySrc(block_sref);
Array<StmtSRef> results;
results.reserve(edges.size());
Expand Down
6 changes: 2 additions & 4 deletions src/tir/schedule/primitive/reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,7 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
}
// Cond 1. Check block is reduction
StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
/*require_stage_pipeline=*/false,
/*require_subtree_compact_dataflow=*/false);
/*require_stage_pipeline=*/false);
CheckReductionBlock(self, block_sref, scope_root_sref);
// Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
Expand Down Expand Up @@ -1009,8 +1008,7 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax
const StmtSRef& block_sref = self->stmt2ref.at(block_realize->block.get());
const Block& block = block_realize->block;
StmtSRef scope_root = GetScopeRoot(self, block_sref, //
/*require_stage_pipeline=*/true,
/*require_subtree_compact_dataflow=*/false);
/*require_stage_pipeline=*/true);
CheckReductionBlock(self, block_sref, scope_root);
const ForNode* rf_loop = TVM_SREF_TO_FOR(rf_loop, rf_loop_sref);
if (rf_loop->kind != ForKind::kSerial) {
Expand Down
Loading

0 comments on commit 96416c4

Please sign in to comment.