Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix][TensorIR] Non-positive constant input factors for split #9805

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,15 +239,17 @@ class ScheduleNode : public runtime::Object {
*/
virtual Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) = 0;
/*!
* \brief Get the producer of a specific block
* \brief Get the producer of a specific block, under the same block scope
* \param block_rv The block in the query
* \return A list of blocks, the producers of the given block
* \return A list of blocks, the producers of the given block under the same scope of the given
* block
*/
virtual Array<BlockRV> GetProducers(const BlockRV& block_rv) = 0;
/*!
* \brief Get the consumers of a specific block
* \brief Get the consumers of a specific block, under the same block scope
* \param block_rv The block to be queried
* \return A list of blocks, the consumers of the given block
* \return A list of blocks, the consumers of the given block under the same scope of the given
* block
*/
virtual Array<BlockRV> GetConsumers(const BlockRV& block_rv) = 0;
/******** Schedule: Transform loops ********/
Expand All @@ -266,8 +268,8 @@ class ScheduleNode : public runtime::Object {
* 1) The loop can't have annotation or thread binding.
* 2) The loop must start with 0.
* \param loop_rv The loop to be split
* \param factors The tiling factors, and at most one of which is -1, which means that
* factor is inferred.
* \param factors The positive tiling factors, and at most one of which is `NullOpt`, which means
* that factor is inferred.
* \return The new loops after split
*/
virtual Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors) = 0;
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ def split(
Potential inputs are:
- None
- ExprRV
- Non-negative constant integers
- Positive constant integers

Returns
-------
Expand Down
33 changes: 30 additions & 3 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,31 @@ Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
IRModule mod_;
For loop_;
};

class NonPositiveFactorError : public ScheduleError {
public:
explicit NonPositiveFactorError(IRModule mod, int64_t factor, size_t idx)
: mod_(std::move(mod)), factor_(factor), idx_(idx) {}

String FastErrorString() const final {
return "ScheduleError: All the constant factors are required to be positive. However, some "
"constant input factor is zero or negative.";
}
String DetailRenderTemplate() const final {
std::ostringstream os;
os << "All the constant factors are required to be positive. However, the factor at position "
<< idx_ << " is " << factor_;
return os.str();
}
IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {}; }

private:
IRModule mod_;
int64_t factor_;
size_t idx_;
};

// Prepare for the splitting
StmtSRef loop_sref = this->GetSRef(loop_rv);
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
Expand All @@ -389,13 +414,15 @@ Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
for (size_t i = 0; i < factor_rvs.size(); i++) {
if (!factor_rvs[i].defined()) {
factors.push_back(Integer(-1));
if (infer_index == -1) {
infer_index = i;
} else {
if (infer_index != -1) {
throw NotSingleInferFactorError(state_->mod);
}
infer_index = i;
} else {
PrimExpr factor = this->Get(factor_rvs[i].value());
if (is_const_int(factor) && !is_positive_const(factor)) {
throw NonPositiveFactorError(state_->mod, factor.as<IntImmNode>()->value, i);
}
factors.push_back(factor);
tot_length *= factor;
}
Expand Down
12 changes: 12 additions & 0 deletions tests/python/unittest/test_tir_schedule_split_fuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,18 @@ def test_split_with_opaque_access():
verify_trace_roundtrip(sch=sch, mod=opaque_access)


def test_split_with_non_positive_factors():
sch = tir.Schedule(elementwise, debug_mask="all")
block_b = sch.get_block("B")
i, j, k = sch.get_loops(block_b)
with pytest.raises(tvm.tir.ScheduleError):
sch.split(i, factors=[-2, -64])
with pytest.raises(tvm.tir.ScheduleError):
sch.split(j, factors=[0, None])
with pytest.raises(tvm.tir.ScheduleError):
sch.split(k, factors=[None, -16])


def test_fuse_split_fail_with_thread_binding():
sch = tir.Schedule(elementwise_with_thread_binding, debug_mask="all")
block_b = sch.get_block("B")
Expand Down