[MetaSchedule] Add MultiLevelTilingTensorCore rule for auto-tensorization on CUDA #12059

Merged
24 changes: 24 additions & 0 deletions include/tvm/meta_schedule/schedule_rule.h
@@ -173,6 +173,30 @@ class ScheduleRule : public runtime::ObjectRef {
Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);

/*!
* \brief Extension of MultiLevelTiling for auto-tensorizing with a single group of tensor core
* intrinsics
* \param intrin_group A group of tensor core intrinsics. The map should contain the keys "init",
* "load_a", "load_b", "compute" and "store", which represent the tensor intrins for
* initialization, loading operand A, loading operand B, tensor core computation and storing the
* result. The values of the map should be the names of tensor intrinsics, which must be
* registered via TensorIntrin.register(...) beforehand.
* \param structure The tiling structure. Recommended:
* - 'SSSRRSRS' on GPU
* \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended:
* - [blockIdx.y, blockIdx.x, threadIdx.y] on GPU
* \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
* \param vector_load_lens The length of vector lane in vectorized cooperative fetching.
* NullOpt means disable vectorization
* \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
* \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule MultiLevelTilingTensorCore(
Map<String, String> intrin_group, String structure, Optional<Array<String>> tile_binds,
Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);

/*!
* \brief Create a rule: add-rfactor to some blocks if needed
* \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the
4 changes: 4 additions & 0 deletions include/tvm/tir/stmt.h
@@ -1524,6 +1524,10 @@ constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensori

/*! \brief Mark that a block is a preprocessor block for layout rewrite. */
constexpr const char* meta_schedule_layout_rewrite_preproc = "meta_schedule.layout_rewrite_preproc";
/*!
* \brief Mark that the init statement of a block should be further rewritten using tensorization.
*/
constexpr const char* meta_schedule_auto_tensorize_init = "meta_schedule.auto_tensorize_init";

/*!
 * \brief Mark that a block is executed by a warp. This implies the extent of threadIdx.x is
7 changes: 6 additions & 1 deletion python/tvm/meta_schedule/schedule_rule/__init__.py
@@ -23,7 +23,12 @@
from .auto_bind import AutoBind
from .auto_inline import AutoInline
from .cross_thread_reduction import CrossThreadReduction
from .multi_level_tiling import MultiLevelTiling, MultiLevelTilingWithIntrin, ReuseType
from .multi_level_tiling import (
MultiLevelTiling,
MultiLevelTilingWithIntrin,
ReuseType,
MultiLevelTilingTensorCore,
)
from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll
from .random_compute_location import RandomComputeLocation
from .schedule_rule import PyScheduleRule, ScheduleRule
54 changes: 53 additions & 1 deletion python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Multi-level tiling with reuse."""
from typing import Any, Dict, List, NamedTuple, Optional
from typing import Any, Dict, List, Mapping, NamedTuple, Optional

from tvm._ffi import register_object

@@ -131,3 +131,55 @@ def __init__(
reuse_read.as_dict() if reuse_read is not None else None,
reuse_write.as_dict() if reuse_write is not None else None,
)


@register_object("meta_schedule.MultiLevelTilingTensorCore")
class MultiLevelTilingTensorCore(ScheduleRule):
"""Extension of MultiLevelTiling for auto-tensorizing with a single group of tensor core
intrinsics.

Parameters
----------
intrin_group : Mapping[str, str]
A group of tensor core intrinsics. The map should contain the keys "init", "load_a",
"load_b", "compute" and "store", which represent the tensor intrins for initialization,
loading operand A, loading operand B, tensor core computation and storing the result.
The values of the map should be the names of tensor intrinsics, which must be registered
via TensorIntrin.register(...) beforehand.
structure : str
The tiling structure. Recommended:
- 'SSSRRSRS' on GPU
tile_binds : Optional[List[str]]
For each level of tiles, which thread axis it is bound to. Recommended:
- [blockIdx.y, blockIdx.x, threadIdx.y] on GPU
max_innermost_factor : Optional[int]
The maximum size of the innermost factor. None means no limit
vector_load_lens : Optional[List[int]]
The length of vector lane in vectorized cooperative fetching.
None means disable vectorization
reuse_read : Optional[ReuseType]
Data reuse configuration for reading. None means no reuse.
reuse_write : Optional[ReuseType]
Data reuse configuration for writing. None means no reuse.
"""

def __init__(
self,
intrin_group: Mapping[str, str],
structure: str,
tile_binds: Optional[List[str]] = None,
max_innermost_factor: Optional[int] = None,
vector_load_lens: Optional[List[int]] = None,
reuse_read: Optional[ReuseType] = None,
reuse_write: Optional[ReuseType] = None,
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.ScheduleRuleMultiLevelTilingTensorCore, # type: ignore # pylint: disable=no-member
intrin_group,
structure,
tile_binds,
max_innermost_factor,
vector_load_lens,
reuse_read.as_dict() if reuse_read is not None else None,
reuse_write.as_dict() if reuse_write is not None else None,
)
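
For reference, a minimal sketch of constructing the new rule directly in Python, mirroring the defaults used by the testing helper added later in this PR (the choice of dtypes, write scope and tile binds below is an assumption for illustration, not the only valid configuration):

from tvm.meta_schedule.schedule_rule import MultiLevelTilingTensorCore, ReuseType
from tvm.tir import tensor_intrin

# Build the intrinsic group for f16 inputs with f32 accumulation, storing through shared memory.
intrin_group = tensor_intrin.get_wmma_intrin_group(
    store_scope="shared", in_dtype="float16", out_dtype="float32", trans_b=False
)

# Assemble the schedule rule with the recommended GPU tiling structure and thread bindings.
rule = MultiLevelTilingTensorCore(
    intrin_group=intrin_group,
    structure="SSSRRSRS",
    tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"],
    max_innermost_factor=4,
    vector_load_lens=[1, 2, 3, 4],
    reuse_read=ReuseType(req="must", levels=[4], scope="shared"),
    reuse_write=ReuseType(req="must", levels=[2], scope="shared"),
)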
34 changes: 34 additions & 0 deletions python/tvm/meta_schedule/testing/schedule_rule.py
@@ -26,6 +26,8 @@
ReuseType,
ScheduleRule,
)
from tvm.meta_schedule.schedule_rule.multi_level_tiling import MultiLevelTilingTensorCore
from tvm.tir import tensor_intrin
from tvm.target import Target


@@ -110,6 +112,38 @@ def multi_level_tiling(target: Target) -> ScheduleRule:
raise NotImplementedError(f"{target.kind.name} is not supported")


def multi_level_tiling_tensor_core(
target: Target,
write_reuse_scope="shared",
in_dtype="float16",
out_dtype="float32",
trans_b=False,
) -> ScheduleRule:
"""Default schedule rules for with multi-level tiling reuse for tensor core"""
assert write_reuse_scope in ["shared", "global"]
if target.kind.name == "cuda":
return MultiLevelTilingTensorCore(
intrin_group=tensor_intrin.get_wmma_intrin_group(
write_reuse_scope, in_dtype, out_dtype, trans_b
),
structure="SSSRRSRS",
tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"],
max_innermost_factor=4, # 64 // tensor intrin size
vector_load_lens=[1, 2, 3, 4],
reuse_read=ReuseType(
req="must",
levels=[4],
scope="shared",
),
reuse_write=ReuseType(
req="must" if write_reuse_scope == "shared" else "no",
levels=[2],
scope=write_reuse_scope,
),
)
raise NotImplementedError(f"{target.kind.name} is not supported")
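
A hedged usage sketch of the testing helper above (the target tag is an assumption; any CUDA target with Tensor Core support would do):

from tvm.target import Target
from tvm.meta_schedule.testing.schedule_rule import multi_level_tiling_tensor_core

# Default Tensor Core tiling rule for a CUDA target, f16 inputs accumulated in f32.
rule = multi_level_tiling_tensor_core(
    Target("nvidia/geforce-rtx-3070"),  # assumed target tag
    write_reuse_scope="shared",
    in_dtype="float16",
    out_dtype="float32",
    trans_b=False,
)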


def random_compute_location(target: Target) -> ScheduleRule:
"""Default schedule rules for with random-compute-location"""
if target.kind.name == "llvm":
73 changes: 68 additions & 5 deletions python/tvm/tir/tensor_intrin/cuda.py
@@ -16,7 +16,7 @@
# under the License.
# pylint: disable=invalid-name,missing-function-docstring
"""Intrinsics for tensorization on NVIDIA GPU."""
from typing import Tuple
from typing import Tuple, Dict
from tvm.script import tir as T
from tvm.tir.function import PrimFunc
from .. import IntImm, Cast
@@ -769,15 +769,15 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
*get_wmma_load_intrin(16, 16, 16, "float16", "shared", True, False),
)

WMMA_LOAD_16x16x16_F16_A_INTRIN = "wmma_load_16x16x16_f16_a_trans"
WMMA_LOAD_16x16x16_F16_A_TRANS_INTRIN = "wmma_load_16x16x16_f16_a_trans"
TensorIntrin.register(
WMMA_LOAD_16x16x16_F16_A_INTRIN,
WMMA_LOAD_16x16x16_F16_A_TRANS_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "float16", "shared", False, True),
)

WMMA_LOAD_16x16x16_F16_B_INTRIN = "wmma_load_16x16x16_f16_b_trans"
WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN = "wmma_load_16x16x16_f16_b_trans"
TensorIntrin.register(
WMMA_LOAD_16x16x16_F16_B_INTRIN,
WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "float16", "shared", True, True),
)

@@ -806,3 +806,66 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
TensorIntrin.register(
WMMA_STORE_16x16x16_F16_GLOBAL_INTRIN, *get_wmma_store_intrin(16, 16, 16, "float16", "global")
)


def get_wmma_intrin_group(
store_scope: str, in_dtype: str, out_dtype: str, trans_b: bool
) -> Dict[str, str]:
"""Get a group of intrinsics for wmma tensor core with the given configurations

Parameters
----------
store_scope : str
Must be one of ["global", "shared"]. The memory scope of the result buffer.

in_dtype : str
The input data type.

out_dtype : str
The output data type.

trans_b : bool
Whether the input matrix B is transposed.

Returns
-------
ret : Dict[str, str]
A group of tensor intrinsics.
"""
assert store_scope in ["global", "shared"]
assert in_dtype in ["float16"]
assert out_dtype in ["float16", "float32"]

load_a_intrins = {"float16": WMMA_LOAD_16x16x16_F16_A_INTRIN}
load_b_intrins = {
"float16": WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN
if trans_b
else WMMA_LOAD_16x16x16_F16_B_INTRIN
}
compute_intrins = {
"float16": WMMA_SYNC_16x16x16_f16f16f16_TRANS_INTRIN
if trans_b
else WMMA_SYNC_16x16x16_f16f16f16_INTRIN,
"float32": WMMA_SYNC_16x16x16_f16f16f32_TRANS_INTRIN
if trans_b
else WMMA_SYNC_16x16x16_f16f16f32_INTRIN,
}
init_intrins = {
"float16": WMMA_FILL_16x16x16_F16_INTRIN,
"float32": WMMA_FILL_16x16x16_F32_INTRIN,
}
store_intrins = {
"float16": WMMA_STORE_16x16x16_F16_SHARED_INTRIN
if store_scope == "shared"
else WMMA_STORE_16x16x16_F16_GLOBAL_INTRIN,
"float32": WMMA_STORE_16x16x16_F32_SHARED_INTRIN
if store_scope == "shared"
else WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN,
}
return {
"init": init_intrins[out_dtype],
"load_a": load_a_intrins[in_dtype],
"load_b": load_b_intrins[in_dtype],
"compute": compute_intrins[out_dtype],
"store": store_intrins[out_dtype],
}
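
A short illustration of what get_wmma_intrin_group returns (a sketch; the concrete names are the values of the registered intrinsic constants defined above):

from tvm.tir.tensor_intrin.cuda import get_wmma_intrin_group

group = get_wmma_intrin_group(store_scope="shared", in_dtype="float16",
                              out_dtype="float32", trans_b=False)
# The result maps the five roles to registered intrinsic names; for this configuration they are
# WMMA_FILL_16x16x16_F32_INTRIN, WMMA_LOAD_16x16x16_F16_A_INTRIN,
# WMMA_LOAD_16x16x16_F16_B_INTRIN, WMMA_SYNC_16x16x16_f16f16f32_INTRIN and
# WMMA_STORE_16x16x16_F32_SHARED_INTRIN.
assert set(group) == {"init", "load_a", "load_b", "compute", "store"}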
17 changes: 17 additions & 0 deletions src/meta_schedule/postproc/rewrite_reduction_block.cc
@@ -135,6 +135,23 @@ bool RewriteReductionBlockNode::Apply(const tir::Schedule& sch) {
tir::BlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name);
Array<tir::LoopRV> loop_rvs = sch->GetLoops(block_rv);
tir::BlockRV init_block_rv = sch->DecomposeReduction(block_rv, loop_rvs[decompose_point]);

// Rewrite auto tensorization related annotations
if (tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize).defined()) {
// Remove tensorization annotation as it shouldn't be propagated to the init block.
sch->Unannotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize);
Optional<String> tensorize_init =
tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize_init);
// The annotation of tensorization of the init statement should be moved to the init block
// after 'DecomposeReduction'.
// Annotate to hint `RewriteTensorize` postprocessor even if tensorize_init is NullOpt.
sch->Annotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize,
tensorize_init.value_or(""));
if (tensorize_init.defined()) {
sch->Unannotate(block_rv, tir::attr::meta_schedule_auto_tensorize_init);
sch->Unannotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize_init);
}
}
++rewritten;
}
if (rewritten == 0) {
6 changes: 3 additions & 3 deletions src/meta_schedule/postproc/rewrite_tensorize.cc
@@ -35,18 +35,18 @@ void CollectTensorizationJobs(
tir::PostOrderVisit(func->body, [=, &jobs](const ObjectRef& obj) {
if (const auto* block = obj.as<tir::BlockNode>()) {
tir::StmtSRef block_sref = sch->GetSRef(block);
std::string block_name = block_sref->StmtAs<tir::BlockNode>()->name_hint;
if (Optional<String> intrin_name =
tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize)) {
std::string block_name = block_sref->StmtAs<tir::BlockNode>()->name_hint;
if (block_name.find("init") == std::string::npos) {
if (intrin_name.value() != "") {
jobs->emplace_back(block_name, func_name, [sch, intrin_name](tir::BlockRV block) {
try {
sch->Tensorize(block, intrin_name.value());
} catch (const std::exception& e) {
LOG(WARNING) << "Tensorize failed with error " << e.what();
}
});
} else if (vectorize_init_loop) {
} else if (block_name.find("init") && vectorize_init_loop) {
jobs->emplace_back(block_name, func_name, [sch](tir::BlockRV block) {
Array<BlockRV> child_blocks = sch->GetChildBlocks(block);
ICHECK(child_blocks.size() == 1);
3 changes: 2 additions & 1 deletion src/meta_schedule/schedule_rule/auto_inline.cc
@@ -143,7 +143,8 @@ inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch,
Array<tir::StmtSRef> producer_srefs = GetProducers(state, block_sref);
if (producer_srefs.size() == 1 &&
tir::IsCompleteBlock(sch->state(), producer_srefs[0], scope_block) &&
CanReverseComputeInline(state, block_sref)) {
CanReverseComputeInline(state, block_sref) &&
!GetAnn<String>(producer_srefs[0], tir::attr::meta_schedule_auto_tensorize).defined()) {
return InlineType::kInlineIntoProducer;
}
}
28 changes: 18 additions & 10 deletions src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -61,6 +61,8 @@ using tir::IterVarType;
using tir::LoopRV;
using tir::Schedule;

TVM_REGISTER_OBJECT_TYPE(StateNode);

State::State(tir::Schedule sch, tir::BlockRV block_rv, Array<Array<tir::LoopRV>> tiles) {
ObjectPtr<StateNode> node = make_object<StateNode>();
node->sch = std::move(sch);
@@ -133,6 +135,7 @@ std::vector<State> MultiLevelTilingNode::AddWriteReuse(State state) const {
new_state->sch->ReverseComputeAt(consumer_rvs[0], loop_rv, true);
results.push_back(std::move(new_state));
}
state->write_reuse.emplace(0, consumer_rvs[0]);
results.push_back(state);
return results;
} else {
@@ -146,6 +149,7 @@ std::vector<State> MultiLevelTilingNode::AddWriteReuse(State state) const {
BlockRV write_cache =
state->sch->CacheWrite(/*block_rv=*/state->block_rv, /*read_buffer_index=*/0,
/*storage_scope=*/config.scope);
state->write_reuse.emplace(0, write_cache);
for (int level : levels) {
State new_state = state->Copy();
const LoopRV& loop_rv = new_state->tiles[level - 1].back();
@@ -247,22 +251,26 @@ std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const {
Array<LoopRV> buffer_loops = sch->GetLoops(cache_read_block);
LoopRV fused = sch->Fuse(Array<LoopRV>{buffer_loops.end() - buffer_ndim, //
buffer_loops.end()});
// Annotate cooperative fetching
if (!vector_load_lens.empty()) {
int n = vector_load_lens.size();
double prob = 1.0 / n;
tir::ExprRV vector_load_len =
sch->SampleCategorical(support::AsArray<int, Integer>(vector_load_lens),
Array<FloatImm>(n, FloatImm(DataType::Float(64), prob)));
sch->Annotate(cache_read_block, tir::attr::meta_schedule_cooperative_fetch,
vector_load_len);
}
AnnotateCooperativeFetching(&sch, cache_read_block);
new_state->read_reuse.emplace(i, cache_read_block);
}
results.push_back(std::move(new_state));
}
return results;
}

void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch,
const tir::BlockRV& block) const {
if (!vector_load_lens.empty()) {
int n = vector_load_lens.size();
double prob = 1.0 / n;
tir::ExprRV vector_load_len =
(*sch)->SampleCategorical(support::AsArray<int, Integer>(vector_load_lens),
Array<FloatImm>(n, FloatImm(DataType::Float(64), prob)));
(*sch)->Annotate(block, tir::attr::meta_schedule_cooperative_fetch, vector_load_len);
}
}

// Constructor

ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional<Array<String>> tile_binds,