[BugFix] Fix CrossThreadReduction on CUDA #13

Merged
include/tvm/tir/stmt.h (1 addition, 1 deletion)
@@ -1358,7 +1358,7 @@ constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_
constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint";

/*!
- * \brief Mark that the block need to add predicate for block var bounds during lowering
+ * \brief Mark that the block needs to add predicate for block var bounds during lowering
*/
constexpr const char* require_block_var_bound_predicate = "require_bound_predicate";

include/tvm/tir/transform.h (7 additions, 0 deletions)
@@ -390,6 +390,13 @@ TVM_DLL Pass PlanAndUpdateBufferAllocationLocation();
*/
TVM_DLL Pass ApplyBlockBoundPredicate();

+/*!
+ * \brief Narrow the extents of some loops by checking whether some constraints in the block iter
+ * bound predicates can be directly applied on the loops.
+ * \return The pass.
+ */
+TVM_DLL Pass ApplyBlockBoundPredicate();

/*!
* \brief Substitute all the block vars with the PrimExprs they are bound to, indicated by the
* corresponding iter_values in BlockRealize, for opaque blocks by removing all
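To make the new pass's contract concrete, here is a minimal sketch of the pattern it targets, assuming TVMScript syntax; the workload and names below are hypothetical and not part of this PR. The block's bound predicate constrains the loop variable directly, so the loop extent can be narrowed and the predicate dropped.

import tvm
from tvm.script import tir as T


@T.prim_func
def padded_copy(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (100,), "float32")
    B = T.match_buffer(b, (100,), "float32")
    # The loop runs to 128, but the block predicate limits useful iterations to i < 100.
    for i in T.serial(128):
        with T.block("copy"):
            vi = T.axis.spatial(100, i)
            T.where(i < 100)
            B[vi] = A[vi]

Since the constraint i < 100 applies directly to the loop variable, narrowing the loop extent from 128 to 100 would make the predicate trivially true, which is the kind of simplification the pass description above promises.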
python/tvm/meta_schedule/testing/schedule_rule.py (1 addition, 1 deletion)
@@ -42,9 +42,9 @@ def get(target: Target) -> List[ScheduleRule]:
        ]
    if target.kind.name == "cuda":
        return [
-            cross_thread_reduction(target),
            multi_level_tiling(target),
            auto_inline_after_tiling(target),
+            cross_thread_reduction(target),
            parallel_vectorize_unroll(target),
        ]
    raise NotImplementedError(f"{target.kind.name} is not supported")
python/tvm/meta_schedule/tune.py (1 addition, 1 deletion)
@@ -150,7 +150,6 @@ def _sch_rules() -> List[ScheduleRule]:
    )

    return [
-        M.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512]),
        M.MultiLevelTiling(
            structure="SSSRRSRS",
            tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
@@ -177,6 +176,7 @@ def _sch_rules() -> List[ScheduleRule]:
            require_ordered=False,
            disallow_op=None,
        ),
+        M.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512]),
        M.ParallelizeVectorizeUnroll(
            max_jobs_per_core=-1,  # disable parallelize
            max_vectorize_extent=-1,  # disable vectorize
src/meta_schedule/schedule_rule/random_compute_location.cc (2 additions, 2 deletions)
@@ -47,8 +47,8 @@ class RandomComputeLocationNode : public ScheduleRuleNode {
    if (tir::GetChildBlockSRefOnSRefTree(sch->state(), loop_srefs[0]).size() > 1) {
      return false;
    }
-    // Cond 5. The block is not tiled. We check this condition by examine the block's annotation.
-    if (tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_tiling_structure).defined()) {
+    // Cond 5. The block is not tiled.
+    if (tir::HasBeenMultiLevelTiled(block_sref)) {
      return false;
    }
    // Cond 6. The block has at least one consumer.
src/tir/schedule/analysis.h (8 additions, 0 deletions)
@@ -629,6 +629,14 @@ AnalyzeReadWritePattern(const BufferRegion& read_region, const BufferRegion& wri
*/
bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref);

+/*!
+ * \brief Checks if the given block has been multi-level tiled, which we determine by examining
+ * the block's annotation.
+ * \param block_sref The block to be checked
+ * \return A boolean indicating whether the block has been multi-level tiled.
+ */
+bool HasBeenMultiLevelTiled(const StmtSRef& block_sref);

/*!
* \brief Checks if the rfactor or cross thread reduction is beneficial to the given block.
* \param self The schedule state.
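The annotation this helper inspects is the one the multi-level-tiling schedule rule leaves on blocks it has tiled. Below is a minimal sketch of how a block comes to carry it; the workload is hypothetical, and only the annotation key (tir::attr::meta_schedule_tiling_structure, i.e. "meta_schedule.tiling_structure") comes from this diff.

import tvm
from tvm import tir
from tvm.script import tir as T


@T.prim_func
def add_one(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (16,), "float32")
    B = T.match_buffer(b, (16,), "float32")
    for i in T.serial(16):
        with T.block("B"):
            vi = T.axis.spatial(16, i)
            B[vi] = A[vi] + 1.0


sch = tir.Schedule(add_one)
block = sch.get_block("B")
# The multi-level-tiling rule records its structure string under this key;
# HasBeenMultiLevelTiled() only checks that the key is present.
sch.annotate(block, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")
print(sch.get(block).annotations)  # {"meta_schedule.tiling_structure": "SSSRRSRS"}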
src/tir/schedule/analysis/analysis.cc (11 additions, 3 deletions)
@@ -1927,6 +1927,10 @@ bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref
return total_unused_block_vars >= 1;
}

+bool HasBeenMultiLevelTiled(const StmtSRef& block_sref) {
+  return tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_tiling_structure).defined();
+}

std::pair<int64_t, int64_t> GetCumulativeSpaceAndReductionLength(const tir::ScheduleState& self,
const tir::StmtSRef& block_sref) {
Array<tir::StmtSRef> loops = tir::GetLoops(block_sref);
@@ -1976,12 +1980,16 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, //
    return false;
  }

-  // Cond 3. The block is a reduction block and has trivial binding.
+  // Cond 3. The block satisfies all the following properties
+  //   - it is a reduction block;
+  //   - it has trivial bindings;
+  //   - it has not been tiled by multi-level tiling.
  const StmtSRef& scope_sref = GetScopeRoot(self, block_sref,  //
                                            /*require_stage_pipeline=*/false,  //
                                            /*require_subtree_compact_dataflow=*/false);
-  if (!(IsReductionBlock(self, block_sref, scope_sref) &&  //
-        IsTrivialBinding(self, block_sref))) {
+  if (!IsReductionBlock(self, block_sref, scope_sref)  //
+      || !IsTrivialBinding(self, block_sref)  //
+      || HasBeenMultiLevelTiled(block_sref)) {
    return false;
  }

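For reference, the shape of block that the rewritten Cond 3 still accepts is a reduction block with trivial bindings (each block var bound one-to-one to a loop var) that has not been multi-level tiled. A hypothetical example in TVMScript, not taken from this PR:

import tvm
from tvm.script import tir as T


@T.prim_func
def row_sum(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (256, 256), "float32")
    B = T.match_buffer(b, (256,), "float32")
    for i, k in T.grid(256, 256):
        with T.block("B"):
            # Trivial binding: vi = i (spatial), vk = k (reduction).
            vi, vk = T.axis.remap("SR", [i, k])
            with T.init():
                B[vi] = T.float32(0)
            B[vi] = B[vi] + A[vi, vk]

Once the multi-level-tiling rule has tagged such a block, the new HasBeenMultiLevelTiled check rejects it, which appears to be the failure mode on CUDA that this PR addresses.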
tests/python/unittest/test_meta_schedule_sketch_cuda.py (130 additions, 0 deletions)
@@ -26,6 +26,10 @@ def _target() -> Target:
    return Target("cuda", host="llvm")


+def _target_with_max_threads_per_block() -> Target:
+    return Target("nvidia/geforce-rtx-3080")
+
+
def test_meta_schedule_cuda_sketch_matmul():
    # pylint: disable=line-too-long
    expected = [
@@ -289,8 +293,134 @@ def test_meta_schedule_cuda_sketch_conv2d_nchw_bias_bn_relu(): # pylint: disabl
    check_trace(spaces, expected)


def test_meta_schedule_cuda_sketch_batchnorm():
    # pylint: disable=line-too-long
    expected = [
        [
            'b0 = sch.get_block(name="C", func_name="main")',
            'b1 = sch.get_block(name="root", func_name="main")',
            "b2, = sch.get_consumers(block=b0)",
            "l3, = sch.get_loops(block=b2)",
            "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
            "l5, l6 = sch.split(loop=l3, factors=[None, v4])",
            'sch.bind(loop=l6, thread_axis="threadIdx.x")',
            "sch.compute_at(block=b0, loop=l5, preserve_unit_loops=True)",
            'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
            "l7, l8, l9, l10 = sch.get_loops(block=b0)",
            "l11 = sch.fuse(l9, l10)",
            "l12, l13 = sch.split(loop=l11, factors=[None, v4])",
            'sch.bind(loop=l13, thread_axis="threadIdx.x")',
            "v14 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])",
            'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v14)',
        ],
        [
            'b0 = sch.get_block(name="root", func_name="main")',
            "v1 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])",
            'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.unroll_explicit", ann_val=v1)',
        ],
    ]
    # pylint: enable=line-too-long
    ctx = create_context(
        create_prim_func(
            te_workload.norm_bmn(
                B=1,
                M=256,
                N=256,
            )
        ),
        target=_target_with_max_threads_per_block(),
    )
    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
    assert len(spaces) == 2
    check_trace(spaces, expected)


def test_meta_schedule_cuda_sketch_softmax():
    # pylint: disable=line-too-long
    expected = [
        [
            'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")',
            'b1 = sch.get_block(name="T_softmax_exp", func_name="main")',
            'b2 = sch.get_block(name="T_softmax_expsum", func_name="main")',
            'b3 = sch.get_block(name="root", func_name="main")',
            "sch.compute_inline(block=b1)",
            "b4, = sch.get_consumers(block=b2)",
            "l5, l6 = sch.get_loops(block=b4)",
            "v7 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
            "l8, l9 = sch.split(loop=l6, factors=[None, v7])",
            'sch.bind(loop=l9, thread_axis="threadIdx.x")',
            "sch.compute_at(block=b2, loop=l5, preserve_unit_loops=True)",
            'sch.set_scope(block=b2, buffer_index=0, storage_scope="shared")',
            "l10, l11, l12 = sch.get_loops(block=b2)",
            "l13, l14 = sch.split(loop=l12, factors=[None, v7])",
            'sch.bind(loop=l14, thread_axis="threadIdx.x")',
            "b15, b16 = sch.get_consumers(block=b0)",
            "l17, l18, l19, l20 = sch.get_loops(block=b15)",
            "sch.compute_at(block=b0, loop=l17, preserve_unit_loops=True)",
            'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
            "l21, l22, l23 = sch.get_loops(block=b0)",
            "l24, l25 = sch.split(loop=l23, factors=[None, v7])",
            'sch.bind(loop=l25, thread_axis="threadIdx.x")',
            "v26 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])",
            'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v26)',
        ],
        [
            'b0 = sch.get_block(name="T_softmax_exp", func_name="main")',
            'b1 = sch.get_block(name="T_softmax_expsum", func_name="main")',
            'b2 = sch.get_block(name="root", func_name="main")',
            "sch.compute_inline(block=b0)",
            "b3, = sch.get_consumers(block=b1)",
            "l4, l5 = sch.get_loops(block=b3)",
            "v6 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
            "l7, l8 = sch.split(loop=l5, factors=[None, v6])",
            'sch.bind(loop=l8, thread_axis="threadIdx.x")',
            "sch.compute_at(block=b1, loop=l4, preserve_unit_loops=True)",
            'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")',
            "l9, l10, l11 = sch.get_loops(block=b1)",
            "l12, l13 = sch.split(loop=l11, factors=[None, v6])",
            'sch.bind(loop=l13, thread_axis="threadIdx.x")',
            "v14 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])",
            'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v14)',
        ],
        [
            'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")',
            'b1 = sch.get_block(name="T_softmax_exp", func_name="main")',
            'b2 = sch.get_block(name="root", func_name="main")',
            "sch.compute_inline(block=b1)",
            "v3 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
            "l4, l5 = sch.get_loops(block=b0)",
            "l6, l7 = sch.split(loop=l5, factors=[None, v3])",
            'sch.bind(loop=l7, thread_axis="threadIdx.x")',
            "v8 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])",
            'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v8)',
        ],
        [
            'b0 = sch.get_block(name="T_softmax_exp", func_name="main")',
            'b1 = sch.get_block(name="root", func_name="main")',
            "sch.compute_inline(block=b0)",
            "v2 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])",
            'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v2)',
        ],
    ]
    # pylint: enable=line-too-long
    ctx = create_context(
        create_prim_func(
            te_workload.softmax_mn(
                m=256,
                n=256,
            )
        ),
        target=_target_with_max_threads_per_block(),
    )
    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
    assert len(spaces) == 4
    check_trace(spaces, expected)


if __name__ == "__main__":
    test_meta_schedule_cuda_sketch_matmul()
    test_meta_schedule_cuda_sketch_matmul_relu()
    test_meta_schedule_cuda_sketch_conv2d_nchw()
    test_meta_schedule_cuda_sketch_conv2d_nchw_bias_bn_relu()
    test_meta_schedule_cuda_sketch_batchnorm()
    test_meta_schedule_cuda_sketch_softmax()
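A note on the new _target_with_max_threads_per_block helper: target tags such as "nvidia/geforce-rtx-3080" pre-populate device attributes, including the maximum number of threads per block, which is presumably what the cross-thread-reduction sketches in these tests rely on; the bare Target("cuda", host="llvm") used by the other tests does not pin this to a specific device. A quick illustration (the printed value is an assumption about this tag, not taken from the PR):

from tvm.target import Target

target = Target("nvidia/geforce-rtx-3080")
# Tags carry device attributes; max_num_threads is the Target attribute that
# bounds threadIdx.x extents.
print(target.max_num_threads)  # expected: 1024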