[MetaSchedule] Add MultiLevelTilingTensorCore rule for auto-tensorization on CUDA

Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
3 people committed Jul 11, 2022
1 parent c4dc41a commit 8dd2de5
Showing 15 changed files with 995 additions and 55 deletions.
26 changes: 26 additions & 0 deletions include/tvm/meta_schedule/schedule_rule.h
@@ -173,6 +173,32 @@ class ScheduleRule : public runtime::ObjectRef {
Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);

/*!
* \brief Extension of MultiLevelTiling for auto-tensorizing with a single group of tensor core
* intrinsics
* \param intrin_group A group of tensor core intrinsics. The map should contain the keys "init",
* "load_a", "load_b", "compute", and "store", which represent the tensor intrins for
* initialization, loading operand A, loading operand B, tensor core computation, and storing the
* result. The values of the map should be names of tensor intrinsics that must be registered via
* TensorIntrin.register(...) beforehand.
* \param structure The tiling structure. Recommended:
* - 'SSRSRS' on CPU
* - 'SSSRRSRS' on GPU
* \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended:
* - NullOpt on CPU
* - [blockIdx.y, blockIdx.x, threadIdx.y] on GPU
* \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
* \param vector_load_lens The length of vector lane in vectorized cooperative fetching.
* NullOpt means disable vectorization
* \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
* \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule MultiLevelTilingTensorCore(
Map<String, String> intrin_group, String structure, Optional<Array<String>> tile_binds,
Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);

/*!
* \brief Create a rule: add-rfactor to some blocks if needed
* \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the
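For orientation, here is a minimal Python sketch of constructing the rule declared above with the WMMA intrinsics shipped in tvm.tir.tensor_intrin. It mirrors, in abbreviated form, the default CUDA configuration added to testing/schedule_rule.py later in this diff; it is a sketch for illustration, not part of the commit.

from tvm.tir import tensor_intrin  # importing this package registers the built-in WMMA intrinsics
from tvm.meta_schedule.schedule_rule import MultiLevelTilingTensorCore

rule = MultiLevelTilingTensorCore(
    intrin_group={
        "init": tensor_intrin.WMMA_FILL_16x16x16_F32_INTRIN,
        "load_a": tensor_intrin.WMMA_LOAD_16x16x16_F16_A_INTRIN,
        "load_b": tensor_intrin.WMMA_LOAD_16x16x16_F16_B_INTRIN,
        "compute": tensor_intrin.WMMA_SYNC_16x16x16_f16f16f32_INTRIN,
        "store": tensor_intrin.WMMA_STORE_16x16x16_F32_SHARED_INTRIN,
    },
    structure="SSSRRSRS",
    tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"],
)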
4 changes: 4 additions & 0 deletions include/tvm/tir/stmt.h
@@ -1524,6 +1524,10 @@ constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensori

/*! \brief Mark that a block is a preprocessor block for layout rewrite. */
constexpr const char* meta_schedule_layout_rewrite_preproc = "meta_schedule.layout_rewrite_preproc";
/*!
* \brief Mark that the init statement of a block should be further rewritten using tensorization.
*/
constexpr const char* meta_schedule_auto_tensorize_init = "meta_schedule.auto_tensorize_init";

/*!
* \brief Mark that a block is executed by a warp. This implies the extend of threadIdx.x is
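A hedged sketch of how these two annotations are attached through the schedule API; the TIR workload and the intrinsic names are illustrative (dtypes simplified to float32) and are not part of the commit.

from tvm import tir
from tvm.script import tir as T


@T.prim_func
def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float32")
    B = T.match_buffer(b, (16, 16), "float32")
    C = T.match_buffer(c, (16, 16), "float32")
    for i, j, k in T.grid(16, 16, 16):
        with T.block("C"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


sch = tir.Schedule(matmul)
blk = sch.get_block("C")
# Hint which intrinsic should tensorize the block, and which one should rewrite its init statement.
sch.annotate(blk, "meta_schedule.auto_tensorize", "wmma_sync_16x16x16_f16f16f32")
sch.annotate(blk, "meta_schedule.auto_tensorize_init", "wmma_fill_16x16x16_f32")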
7 changes: 6 additions & 1 deletion python/tvm/meta_schedule/schedule_rule/__init__.py
@@ -23,7 +23,12 @@
from .auto_bind import AutoBind
from .auto_inline import AutoInline
from .cross_thread_reduction import CrossThreadReduction
from .multi_level_tiling import MultiLevelTiling, MultiLevelTilingWithIntrin, ReuseType
from .multi_level_tiling import (
MultiLevelTiling,
MultiLevelTilingWithIntrin,
ReuseType,
MultiLevelTilingTensorCore,
)
from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll
from .random_compute_location import RandomComputeLocation
from .schedule_rule import PyScheduleRule, ScheduleRule
56 changes: 55 additions & 1 deletion python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Multi-level tiling with reuse."""
from typing import Any, Dict, List, NamedTuple, Optional
from typing import Any, Dict, List, Mapping, NamedTuple, Optional

from tvm._ffi import register_object

@@ -131,3 +131,57 @@ def __init__(
reuse_read.as_dict() if reuse_read is not None else None,
reuse_write.as_dict() if reuse_write is not None else None,
)


@register_object("meta_schedule.MultiLevelTilingTensorCore")
class MultiLevelTilingTensorCore(ScheduleRule):
"""Extension of MultiLevelTiling for auto-tensorizing with a single group of tensor core
intrinsics.

Parameters
----------
intrin_group : Mapping[str, str]
A group of tensor core intrinsics. The map should contain the keys "init", "load_a",
"load_b", "compute", and "store", which represent the tensor intrins for initialization,
loading operand A, loading operand B, tensor core computation, and storing the result.
The values of the map should be names of tensor intrinsics that must be registered via
TensorIntrin.register(...) beforehand.
structure : str
The tiling structure. Recommended:
- 'SSRSRS' on CPU
- 'SSSRRSRS' on GPU
tile_binds : Optional[List[str]]
For each level of tiles, which thread axis it is bound to. Recommended:
- None on CPU
- [blockIdx.y, blockIdx.x, threadIdx.y] on GPU
max_innermost_factor : Optional[int]
The maximum size of the innermost factor. None means no limit
vector_load_lens : Optional[List[int]]
The length of vector lane in vectorized cooperative fetching.
None means disable vectorization
reuse_read : Optional[ReuseType]
Data reuse configuration for reading. None means no reuse.
reuse_write : Optional[ReuseType]
Data reuse configuration for writing. None means no reuse.
"""

def __init__(
self,
intrin_group: Mapping[str, str],
structure: str,
tile_binds: Optional[List[str]] = None,
max_innermost_factor: Optional[int] = None,
vector_load_lens: Optional[List[int]] = None,
reuse_read: Optional[ReuseType] = None,
reuse_write: Optional[ReuseType] = None,
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.ScheduleRuleMultiLevelTilingTensorCore, # type: ignore # pylint: disable=no-member
intrin_group,
structure,
tile_binds,
max_innermost_factor,
vector_load_lens,
reuse_read.as_dict() if reuse_read is not None else None,
reuse_write.as_dict() if reuse_write is not None else None,
)
34 changes: 34 additions & 0 deletions python/tvm/meta_schedule/testing/schedule_rule.py
@@ -26,6 +26,8 @@
ReuseType,
ScheduleRule,
)
from tvm.meta_schedule.schedule_rule.multi_level_tiling import MultiLevelTilingTensorCore
from tvm.tir import tensor_intrin
from tvm.target import Target


@@ -110,6 +112,38 @@ def multi_level_tiling(target: Target) -> ScheduleRule:
raise NotImplementedError(f"{target.kind.name} is not supported")


def multi_level_tiling_tensor_core(target: Target, scope="shared") -> ScheduleRule:
"""Default schedule rules for with multi-level tiling reuse for tensor core"""
assert scope in ["shared", "global"]
if target.kind.name == "cuda":
return MultiLevelTilingTensorCore(
intrin_group={
"init": tensor_intrin.WMMA_FILL_16x16x16_F32_INTRIN,
"load_a": tensor_intrin.WMMA_LOAD_16x16x16_F16_A_INTRIN,
"load_b": tensor_intrin.WMMA_LOAD_16x16x16_F16_B_INTRIN,
"compute": tensor_intrin.WMMA_SYNC_16x16x16_f16f16f32_INTRIN,
"store": tensor_intrin.WMMA_STORE_16x16x16_F32_SHARED_INTRIN
if scope == "shared"
else tensor_intrin.WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN,
},
structure="SSSRRSRS",
tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"],
max_innermost_factor=4, # 64 // tensor intrin size
vector_load_lens=[1, 2, 3, 4],
reuse_read=ReuseType(
req="must",
levels=[4],
scope="shared",
),
reuse_write=ReuseType(
req="must" if scope == "shared" else "no",
levels=[2],
scope="shared",
),
)
raise NotImplementedError(f"{target.kind.name} is not supported")
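A hedged usage sketch of this helper; the target tag below is only one example of a tensor-core-capable CUDA target.

from tvm.target import Target

rule = multi_level_tiling_tensor_core(Target("nvidia/geforce-rtx-3070"), scope="global")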


def random_compute_location(target: Target) -> ScheduleRule:
"""Default schedule rules for with random-compute-location"""
if target.kind.name == "llvm":
8 changes: 4 additions & 4 deletions python/tvm/tir/tensor_intrin/cuda.py
@@ -769,15 +769,15 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
*get_wmma_load_intrin(16, 16, 16, "float16", "shared", True, False),
)

WMMA_LOAD_16x16x16_F16_A_INTRIN = "wmma_load_16x16x16_f16_a_trans"
WMMA_LOAD_16x16x16_F16_A_TRANS_INTRIN = "wmma_load_16x16x16_f16_a_trans"
TensorIntrin.register(
WMMA_LOAD_16x16x16_F16_A_INTRIN,
WMMA_LOAD_16x16x16_F16_A_TRANS_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "float16", "shared", False, True),
)

WMMA_LOAD_16x16x16_F16_B_INTRIN = "wmma_load_16x16x16_f16_b_trans"
WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN = "wmma_load_16x16x16_f16_b_trans"
TensorIntrin.register(
WMMA_LOAD_16x16x16_F16_B_INTRIN,
WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "float16", "shared", True, True),
)

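The renamed *_TRANS constants target operands stored in transposed layout. A hedged sketch of an intrin group for a matmul whose B operand is transposed; the transposed compute constant (WMMA_SYNC_16x16x16_f16f16f32_TRANS_INTRIN) is assumed to be exported alongside the constants shown in this diff.

from tvm.tir import tensor_intrin

intrin_group = {
    "init": tensor_intrin.WMMA_FILL_16x16x16_F32_INTRIN,
    "load_a": tensor_intrin.WMMA_LOAD_16x16x16_F16_A_INTRIN,
    # B is stored transposed, so pick the trans load variant renamed above.
    "load_b": tensor_intrin.WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN,
    "compute": tensor_intrin.WMMA_SYNC_16x16x16_f16f16f32_TRANS_INTRIN,  # assumed constant
    "store": tensor_intrin.WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN,
}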
15 changes: 15 additions & 0 deletions src/meta_schedule/postproc/rewrite_reduction_block.cc
@@ -135,6 +135,21 @@ bool RewriteReductionBlockNode::Apply(const tir::Schedule& sch) {
tir::BlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name);
Array<tir::LoopRV> loop_rvs = sch->GetLoops(block_rv);
tir::BlockRV init_block_rv = sch->DecomposeReduction(block_rv, loop_rvs[decompose_point]);

// Rewrite auto tensorization related annotations
if (tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize).defined()) {
// Remove tensorization annotation as it shouldn't be propagated to the init block.
sch->Unannotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize);
}
if (Optional<String> tensorize_init =
tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize_init)) {
// The annotation of tensorization of the init statement should be moved to the init block
// after 'DecomposeReduction'.
sch->Annotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize,
tensorize_init.value());
sch->Unannotate(block_rv, tir::attr::meta_schedule_auto_tensorize_init);
sch->Unannotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize_init);
}
++rewritten;
}
if (rewritten == 0) {
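In schedule-API terms, the rewrite above amounts to roughly the following. This is a hedged sketch that continues the small matmul example from the stmt.h section; the fill intrinsic name is illustrative.

i, j, k = sch.get_loops(blk)
init_blk = sch.decompose_reduction(blk, k)  # split the init statement into its own block
# The init block inherits "meta_schedule.auto_tensorize" from the reduction block; drop it and
# re-annotate with the intrinsic recorded under "meta_schedule.auto_tensorize_init".
sch.unannotate(init_blk, "meta_schedule.auto_tensorize")
sch.annotate(init_blk, "meta_schedule.auto_tensorize", "wmma_fill_16x16x16_f32")
sch.unannotate(blk, "meta_schedule.auto_tensorize_init")
sch.unannotate(init_blk, "meta_schedule.auto_tensorize_init")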
34 changes: 16 additions & 18 deletions src/meta_schedule/postproc/rewrite_tensorize.cc
@@ -35,26 +35,24 @@ void CollectTensorizationJobs(
tir::PostOrderVisit(func->body, [=, &jobs](const ObjectRef& obj) {
if (const auto* block = obj.as<tir::BlockNode>()) {
tir::StmtSRef block_sref = sch->GetSRef(block);
std::string block_name = block_sref->StmtAs<tir::BlockNode>()->name_hint;
if (Optional<String> intrin_name =
tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize)) {
std::string block_name = block_sref->StmtAs<tir::BlockNode>()->name_hint;
if (block_name.find("init") == std::string::npos) {
jobs->emplace_back(block_name, func_name, [sch, intrin_name](tir::BlockRV block) {
try {
sch->Tensorize(block, intrin_name.value());
} catch (const std::exception& e) {
LOG(WARNING) << "Tensorize failed with error " << e.what();
}
});
} else if (vectorize_init_loop) {
jobs->emplace_back(block_name, func_name, [sch](tir::BlockRV block) {
Array<BlockRV> child_blocks = sch->GetChildBlocks(block);
ICHECK(child_blocks.size() == 1);
Array<LoopRV> init_loops = sch->GetLoops(child_blocks[0]);
ICHECK(init_loops.size() == 1);
sch->Vectorize(init_loops[0]);
});
}
jobs->emplace_back(block_name, func_name, [sch, intrin_name](tir::BlockRV block) {
try {
sch->Tensorize(block, intrin_name.value());
} catch (const std::exception& e) {
LOG(WARNING) << "Tensorize failed with error " << e.what();
}
});
} else if (block_name.find("init") != std::string::npos && vectorize_init_loop) {
jobs->emplace_back(block_name, func_name, [sch](tir::BlockRV block) {
Array<BlockRV> child_blocks = sch->GetChildBlocks(block);
ICHECK(child_blocks.size() == 1);
Array<LoopRV> init_loops = sch->GetLoops(child_blocks[0]);
ICHECK(init_loops.size() == 1);
sch->Vectorize(init_loops[0]);
});
}
}
});
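Each collected job then simply attempts the tensorization and downgrades failures to a warning. A hedged Python equivalent, with the schedule and block as in the earlier sketches and an illustrative intrinsic name:

import logging

try:
    sch.tensorize(blk, "wmma_sync_16x16x16_f16f16f32")
except Exception as err:  # mirror the C++ postproc: warn and continue instead of aborting
    logging.warning("Tensorize failed with error %s", err)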
3 changes: 2 additions & 1 deletion src/meta_schedule/schedule_rule/auto_inline.cc
@@ -143,7 +143,8 @@ inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch,
Array<tir::StmtSRef> producer_srefs = GetProducers(state, block_sref);
if (producer_srefs.size() == 1 &&
tir::IsCompleteBlock(sch->state(), producer_srefs[0], scope_block) &&
CanReverseComputeInline(state, block_sref)) {
CanReverseComputeInline(state, block_sref) &&
!GetAnn<String>(producer_srefs[0], tir::attr::meta_schedule_auto_tensorize).defined()) {
return InlineType::kInlineIntoProducer;
}
}
28 changes: 18 additions & 10 deletions src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -61,6 +61,8 @@ using tir::IterVarType;
using tir::LoopRV;
using tir::Schedule;

TVM_REGISTER_OBJECT_TYPE(StateNode);

State::State(tir::Schedule sch, tir::BlockRV block_rv, Array<Array<tir::LoopRV>> tiles) {
ObjectPtr<StateNode> node = make_object<StateNode>();
node->sch = std::move(sch);
@@ -133,6 +135,7 @@ std::vector<State> MultiLevelTilingNode::AddWriteReuse(State state) const {
new_state->sch->ReverseComputeAt(consumer_rvs[0], loop_rv, true);
results.push_back(std::move(new_state));
}
state->write_reuse.emplace(0, consumer_rvs[0]);
results.push_back(state);
return results;
} else {
@@ -146,6 +149,7 @@ std::vector<State> MultiLevelTilingNode::AddWriteReuse(State state) const {
BlockRV write_cache =
state->sch->CacheWrite(/*block_rv=*/state->block_rv, /*read_buffer_index=*/0,
/*storage_scope=*/config.scope);
state->write_reuse.emplace(0, write_cache);
for (int level : levels) {
State new_state = state->Copy();
const LoopRV& loop_rv = new_state->tiles[level - 1].back();
@@ -247,22 +251,26 @@ std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const {
Array<LoopRV> buffer_loops = sch->GetLoops(cache_read_block);
LoopRV fused = sch->Fuse(Array<LoopRV>{buffer_loops.end() - buffer_ndim, //
buffer_loops.end()});
// Annotate cooperative fetching
if (!vector_load_lens.empty()) {
int n = vector_load_lens.size();
double prob = 1.0 / n;
tir::ExprRV vector_load_len =
sch->SampleCategorical(support::AsArray<int, Integer>(vector_load_lens),
Array<FloatImm>(n, FloatImm(DataType::Float(64), prob)));
sch->Annotate(cache_read_block, tir::attr::meta_schedule_cooperative_fetch,
vector_load_len);
}
AnnotateCooperativeFetching(&sch, cache_read_block);
new_state->read_reuse.emplace(i, cache_read_block);
}
results.push_back(std::move(new_state));
}
return results;
}

void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch,
const tir::BlockRV& block) const {
if (!vector_load_lens.empty()) {
int n = vector_load_lens.size();
double prob = 1.0 / n;
tir::ExprRV vector_load_len =
(*sch)->SampleCategorical(support::AsArray<int, Integer>(vector_load_lens),
Array<FloatImm>(n, FloatImm(DataType::Float(64), prob)));
(*sch)->Annotate(block, tir::attr::meta_schedule_cooperative_fetch, vector_load_len);
}
}

// Constructor

ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional<Array<String>> tile_binds,
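The refactored AnnotateCooperativeFetching helper boils down to sampling a vector length and annotating the cache-read block. A hedged Python sketch continuing the earlier example; the cache_read call and candidate lengths are illustrative.

cache_read_block = sch.cache_read(blk, 1, "shared")  # stage one of the block's operands through shared memory
candidates = [1, 2, 3, 4]  # matches the vector_load_lens default used in testing/schedule_rule.py
vlen = sch.sample_categorical(candidates, probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(cache_read_block, "meta_schedule.cooperative_fetch", vlen)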
35 changes: 34 additions & 1 deletion src/meta_schedule/schedule_rule/multi_level_tiling.h
@@ -22,6 +22,7 @@
#include <tvm/meta_schedule/schedule_rule.h>
#include <tvm/tir/schedule/schedule.h>

#include <unordered_map>
#include <utility>
#include <vector>

@@ -93,6 +94,10 @@ class StateNode : public Object {
tir::BlockRV block_rv;
/*! \brief The loop tiles */
Array<Array<tir::LoopRV>> tiles;
/*! \brief The mapping from buffer index to read cache block. */
std::unordered_map<int, tir::BlockRV> read_reuse;
/*! \brief The mapping from buffer index to write cache block. */
std::unordered_map<int, tir::BlockRV> write_reuse;

/*!
* \brief Create a copy of the state. The underlying schedule is copied. Schedule rules that
@@ -112,6 +117,31 @@ class State : public ObjectRef {
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
};

class TensorCoreStateNode : public StateNode {
public:
/*! \brief The Tensor Core reindex block A for Tensor Core computation */
tir::BlockRV tensor_core_reindex_A;
/*! \brief The Tensor Core reindex block B for Tensor Core computation */
tir::BlockRV tensor_core_reindex_B;
/*! \brief The Tensor Core reindex store block for Tensor Core computation */
tir::BlockRV tensor_core_reindex_store;

State Copy() const final;

static constexpr const char* _type_key = "meta_schedule.TensorCoreState";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorCoreStateNode, StateNode);
};

class TensorCoreState : public State {
public:
explicit TensorCoreState(tir::Schedule sch, tir::BlockRV block_rv,
Array<Array<tir::LoopRV>> tiles = {});

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorCoreState, State, TensorCoreStateNode);
};

struct AutoTensorizationState : public State {};

/*!
* \brief Helper to apply a sub-rule to a list of auto scheduling states
* \tparam FLambda The type of the sub-rule functor
@@ -148,11 +178,14 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
void InitializeWithTuneContext(const TuneContext& context) final;

// Entry of the mega rule; Inherited from ScheduleRuleNode
Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final;
Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) override;

protected:
virtual std::vector<State> ApplySubRules(std::vector<State> states);

// Annotate a block to use cooperative fetching
void AnnotateCooperativeFetching(tir::Schedule* sch, const tir::BlockRV& block) const;

public:
/*!
* \brief The tiling structure. Recommended: