Skip to content

Commit

Permalink
[Metaschedule] Auto tensorization for CPU / GPU dot product (#11088)
Browse files Browse the repository at this point in the history
* [Metaschedule] Auto-tensorization for CPU / GPU dot product

Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>

* doc update

* add vnni conv2d test

* add dp4a test

* adding tests for rewrite_tensorize

* add rewrite_tensorize test

* add missing pydoc

* black

* more doc

* adding auto tensorize integration test

* add dp4a test

* fix target name

* fix dtype in test

* skip bert test

* replace hard-coded llvm intrinsic id in test with look up

* remove unnecessary include, add doc for the rest of params

* update postproc.h

* update doc

* fix shape in te matmul workload

* fix newline in cppdoc

Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
  • Loading branch information
6 people committed Apr 26, 2022
1 parent 4dc47df commit 6846484
Show file tree
Hide file tree
Showing 15 changed files with 1,457 additions and 27 deletions.
6 changes: 4 additions & 2 deletions include/tvm/meta_schedule/postproc.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,12 @@ class Postproc : public runtime::ObjectRef {
*/
TVM_DLL static Postproc RewriteUnboundBlock(int max_threadblock);
/*!
* \brief Create a postprocessor that tensorize Tensor Core related components
* \brief Create a postprocessor that applies tensorization to annotated blocks
* \param vectorize_init_loop Whether or not to vectorize the initialization loop produced by
* DecomposeReduction
* \return The postprocessor created.
*/
TVM_DLL static Postproc RewriteTensorCore();
TVM_DLL static Postproc RewriteTensorize(bool vectorize_init_loop = false);

/*!
* \brief Creates a postprocessor that verifies if the GPU code is correct
Expand Down
23 changes: 23 additions & 0 deletions include/tvm/meta_schedule/schedule_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,29 @@ class ScheduleRule : public runtime::ObjectRef {
Optional<Array<Integer>> vector_load_lens, //
Optional<Map<String, ObjectRef>> reuse_read, //
Optional<Map<String, ObjectRef>> reuse_write);

/*!
* \brief Extension of MultiLevelTiling for auto-tensorizing with a single intrinsic.
* \param intrin_name The name of a tensor intrinsic, must be registered via
* TensorIntrin.register(...) beforehand
* \param structure The tiling structure. Recommended:
* - 'SSRSRS' on CPU
* - 'SSSRRSRS' on GPU
* \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended:
* - NullOpt on CPU
* - [blockIdx.x, vthread.x, threadIdx.x] on GPU
* \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
* \param vector_load_lens The length of vector lane in vectorized cooperative fetching.
* NullOpt means disable vectorization
* \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
* \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule MultiLevelTilingWithIntrin(
String intrin_name, String structure, Optional<Array<String>> tile_binds,
Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);

/*!
* \brief Create a rule: add-rfactor to some blocks if needed
* \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the
Expand Down
5 changes: 5 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1509,6 +1509,11 @@ constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_expl
/*! \brief Mark auto-unroll setting on the block. */
constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit";

/*!
* \brief Mark that a block should be further rewritten using tensorization.
* The RewriteTensorize postprocessor reads the annotation value as a string holding the name
* of the tensor intrinsic to tensorize the block with.
*/
constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize";

/*!
* \brief Check if attr_key is a pragma key extension
* \param attr_key The attr key to be compared
Expand Down
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/postproc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@
from .rewrite_reduction_block import RewriteReductionBlock
from .rewrite_unbound_block import RewriteUnboundBlock
from .verify_gpu_code import VerifyGPUCode
from .rewrite_tensorize import RewriteTensorize
38 changes: 38 additions & 0 deletions python/tvm/meta_schedule/postproc/rewrite_tensorize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""A postprocessor that tensorize related components."""

from tvm._ffi.registry import register_object
from .. import _ffi_api
from .postproc import Postproc


@register_object("meta_schedule.RewriteTensorize")
class RewriteTensorize(Postproc):
    """A postprocessor that applies tensorization to annotated blocks.

    Parameters
    ----------
    vectorize_init_loop : bool
        Whether or not to vectorize the initialization loop produced by DecomposeReduction
    """

    def __init__(self, vectorize_init_loop=False) -> None:
        # The object is constructed on the C++ side through the registered FFI function.
        constructor = _ffi_api.PostprocRewriteTensorize  # type: ignore # pylint: disable=no-member
        self.__init_handle_by_constructor__(constructor, vectorize_init_loop)
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/schedule_rule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .add_rfactor import AddRFactor
from .auto_inline import AutoInline
from .cross_thread_reduction import CrossThreadReduction
from .multi_level_tiling import MultiLevelTiling, ReuseType
from .multi_level_tiling import MultiLevelTiling, MultiLevelTilingWithIntrin, ReuseType
from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll
from .random_compute_location import RandomComputeLocation
from .schedule_rule import PyScheduleRule, ScheduleRule
49 changes: 49 additions & 0 deletions python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,52 @@ def __init__(
reuse_read.as_dict() if reuse_read is not None else None,
reuse_write.as_dict() if reuse_write is not None else None,
)


@register_object("meta_schedule.MultiLevelTilingWithIntrin")
class MultiLevelTilingWithIntrin(ScheduleRule):
    """Extension of MultiLevelTiling for auto-tensorizing with a single intrinsic.

    Parameters
    ----------
    intrin_name : str
        The name of a tensor intrinsic, must be registered via TensorIntrin.register(...) beforehand
    structure : str
        The tiling structure. Recommended:
        - 'SSRSRS' on CPU
        - 'SSSRRSRS' on GPU
    tile_binds : Optional[List[str]]
        For each level of tiles, which thread axis it is bound to. Recommended:
        - None on CPU
        - [blockIdx.x, vthread.x, threadIdx.x] on GPU
    max_innermost_factor : Optional[int]
        The maximum size of the innermost factor. None means no limit
    vector_load_lens : Optional[List[int]]
        The length of vector lane in vectorized cooperative fetching.
        None means disable vectorization
    reuse_read : Optional[ReuseType]
        Data reuse configuration for reading. None means no reuse.
    reuse_write : Optional[ReuseType]
        Data reuse configuration for writing. None means no reuse.
    """

    def __init__(
        self,
        intrin_name: str,
        structure: str,
        tile_binds: Optional[List[str]] = None,
        max_innermost_factor: Optional[int] = None,
        vector_load_lens: Optional[List[int]] = None,
        reuse_read: Optional[ReuseType] = None,
        reuse_write: Optional[ReuseType] = None,
    ) -> None:
        # Reuse configurations cross the FFI boundary as plain dicts; None is passed through.
        read_config = None if reuse_read is None else reuse_read.as_dict()
        write_config = None if reuse_write is None else reuse_write.as_dict()
        constructor = (
            _ffi_api.ScheduleRuleMultiLevelTilingWithIntrin  # type: ignore # pylint: disable=no-member
        )
        self.__init_handle_by_constructor__(
            constructor,
            intrin_name,
            structure,
            tile_binds,
            max_innermost_factor,
            vector_load_lens,
            read_config,
            write_config,
        )
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/testing/te_workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,7 @@ def f_compute(i, j):

def matmul_relu(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
a = te.placeholder((n, k), name="A")
b = te.placeholder((m, k), name="B")
b = te.placeholder((k, m), name="B")
k = te.reduce_axis((0, k), name="k")
c = te.compute(
(n, m),
Expand Down
105 changes: 105 additions & 0 deletions src/meta_schedule/postproc/rewrite_tensorize.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/meta_schedule/postproc.h>

#include <algorithm>

#include "../utils.h"

namespace tvm {
namespace meta_schedule {

using tir::BlockRV;
using tir::LoopRV;

/*!
 * \brief Rewrite every block of a PrimFunc that carries the `meta_schedule.auto_tensorize`
 *        annotation.
 *
 * Annotated blocks whose name does not contain "init" are tensorized with the intrinsic named
 * by the annotation value. Annotated blocks whose name contains "init" have the loop around
 * their single child block vectorized when `vectorize_init_loop` is set, and are otherwise
 * left untouched (annotation included).
 *
 * \param sch The schedule to be rewritten in place.
 * \param func_name The name of the function being processed; used to look blocks up by name.
 * \param func The PrimFunc whose body is scanned for annotated blocks.
 * \param vectorize_init_loop Whether to vectorize the loop of annotated "init" blocks.
 */
void ApplyTensorization(const tir::Schedule& sch, const String& func_name,
                        const tir::PrimFuncNode* func, bool vectorize_init_loop) {
  // Rewrites are recorded during the visit as (block name, action) pairs and are only applied
  // once the visit over `func->body` has finished.
  std::vector<std::pair<std::string, std::function<void(tir::BlockRV)>>> jobs;

  tir::PostOrderVisit(func->body, [=, &jobs](const ObjectRef& obj) {
    const auto* block = obj.as<tir::BlockNode>();
    if (block == nullptr) {
      return;
    }
    tir::StmtSRef block_sref = sch->GetSRef(block);
    Optional<String> intrin_name =
        tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize);
    if (!intrin_name.defined()) {
      return;
    }
    std::string block_name = block->name_hint;
    if (block_name.find("init") == std::string::npos) {
      jobs.emplace_back(block_name, [sch, intrin_name](tir::BlockRV block) {
        // A failed tensorization is reported but does not abort the postprocessor.
        try {
          sch->Tensorize(block, intrin_name.value());
        } catch (const std::exception& e) {
          LOG(WARNING) << "Tensorize failed with error " << e.what();
        }
      });
    } else if (vectorize_init_loop) {
      jobs.emplace_back(block_name, [sch](tir::BlockRV block) {
        // Exactly one child block surrounded by exactly one loop is expected here.
        Array<BlockRV> child_blocks = sch->GetChildBlocks(block);
        ICHECK(child_blocks.size() == 1);
        Array<LoopRV> init_loops = sch->GetLoops(child_blocks[0]);
        ICHECK(init_loops.size() == 1);
        sch->Vectorize(init_loops[0]);
      });
    }
  });

  // Iterate by const reference: copying a pair would copy the std::string and the std::function.
  for (const auto& kv : jobs) {
    tir::BlockRV block = sch->GetBlock(kv.first, func_name);
    // The annotation is dropped before the recorded action runs on the block.
    sch->Unannotate(block, tir::attr::meta_schedule_auto_tensorize);
    kv.second(block);
  }
}

/*! \brief A postprocessor that applies tensorization to blocks annotated for auto-tensorization. */
class RewriteTensorizeNode : public PostprocNode {
 public:
  /*! \brief Whether to vectorize the loop of annotated blocks whose name contains "init". */
  bool vectorize_init_loop = false;

  /*! \brief No attribute is visited for this node. */
  void VisitAttrs(tvm::AttrVisitor* v) {}

  /*! \brief Nothing is read from the tuning context. */
  void InitializeWithTuneContext(const TuneContext& context) final {}

  /*!
   * \brief Apply the rewrite to every PrimFunc of the scheduled module.
   * \param sch The schedule to be post-processed.
   * \return Always true.
   */
  bool Apply(const tir::Schedule& sch) final;

  static constexpr const char* _type_key = "meta_schedule.RewriteTensorize";
  TVM_DECLARE_FINAL_OBJECT_INFO(RewriteTensorizeNode, PostprocNode);
};

/*!
 * \brief Run ApplyTensorization on every PrimFunc found in the module of the schedule.
 * \param sch The schedule to be post-processed.
 * \return Always true.
 */
bool RewriteTensorizeNode::Apply(const tir::Schedule& sch) {
  for (const auto& entry : sch->mod()->functions) {
    const auto* prim_func = entry.second.as<tir::PrimFuncNode>();
    if (prim_func == nullptr) {
      // Functions that are not PrimFuncs are left untouched.
      continue;
    }
    ApplyTensorization(sch, entry.first->name_hint, prim_func, vectorize_init_loop);
  }
  return true;
}

/*!
 * \brief Create the tensorization postprocessor.
 * \param vectorize_init_loop Stored on the created node as its `vectorize_init_loop` flag.
 * \return The postprocessor created.
 */
Postproc Postproc::RewriteTensorize(bool vectorize_init_loop) {
  auto node = make_object<RewriteTensorizeNode>();
  node->vectorize_init_loop = vectorize_init_loop;
  return Postproc(node);
}

// Register the node type, and expose Postproc::RewriteTensorize as the global function
// "meta_schedule.PostprocRewriteTensorize", which the Python RewriteTensorize class calls.
TVM_REGISTER_NODE_TYPE(RewriteTensorizeNode);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteTensorize")
.set_body_typed(Postproc::RewriteTensorize);

} // namespace meta_schedule
} // namespace tvm
25 changes: 3 additions & 22 deletions src/meta_schedule/schedule_rule/multi_level_tiling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -260,28 +260,9 @@ ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional<Array<Str
Optional<Array<Integer>> vector_load_lens,
Optional<Map<String, ObjectRef>> reuse_read,
Optional<Map<String, ObjectRef>> reuse_write) {
ObjectPtr<MultiLevelTilingNode> n = make_object<MultiLevelTilingNode>();
n->structure = structure;
n->tile_binds = tile_binds.value_or({});
n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value;
n->vector_load_lens = vector_load_lens.defined()
? support::AsVector<Integer, int>(vector_load_lens.value())
: std::vector<int>();
n->reuse_read_ = reuse_read.defined() ? ReuseConfig(reuse_read.value()) : ReuseConfig();
n->reuse_write_ = reuse_write.defined() ? ReuseConfig(reuse_write.value()) : ReuseConfig();
for (int i = 0, len = structure.size(); i < len; ++i) {
char c = structure.data()[i];
if (c == 'S') {
n->s_indices_.push_back(i);
} else if (c == 'R') {
n->r_indices_.push_back(i);
} else {
LOG(FATAL) << "ValueError: Invalid tiling structure: " << structure;
}
}
n->thread_warp_size_ = -1;
n->max_threads_per_block_ = -1;
return ScheduleRule(n);
auto node = MultiLevelTilingInitCommon<MultiLevelTilingNode>(
structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write);
return ScheduleRule(node);
}

TVM_REGISTER_NODE_TYPE(MultiLevelTilingNode);
Expand Down
30 changes: 30 additions & 0 deletions src/meta_schedule/schedule_rule/multi_level_tiling.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,36 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
TVM_DECLARE_BASE_OBJECT_INFO(MultiLevelTilingNode, ScheduleRuleNode);
};

/*!
 * \brief Create a node of type `NodeType` and fill in the fields shared by multi-level-tiling
 *        rules from the user-facing arguments.
 * \tparam NodeType The node type to instantiate; it must provide the fields assigned below.
 * \param structure The tiling structure, a string made of 'S' and 'R' characters only.
 * \param tile_binds The thread axes to bind tiles to; an empty array is stored for NullOpt.
 * \param max_innermost_factor The maximum innermost factor; -1 is stored for NullOpt.
 * \param vector_load_lens The candidate vector lengths; an empty vector is stored for NullOpt.
 * \param reuse_read The read reuse configuration; a default ReuseConfig is stored for NullOpt.
 * \param reuse_write The write reuse configuration; a default ReuseConfig is stored for NullOpt.
 * \return The initialized node.
 */
template <typename NodeType>
ObjectPtr<NodeType> MultiLevelTilingInitCommon(String structure, Optional<Array<String>> tile_binds,
                                               Optional<Integer> max_innermost_factor,
                                               Optional<Array<Integer>> vector_load_lens,
                                               Optional<Map<String, ObjectRef>> reuse_read,
                                               Optional<Map<String, ObjectRef>> reuse_write) {
  ObjectPtr<NodeType> node = make_object<NodeType>();
  node->structure = structure;
  node->tile_binds = tile_binds.value_or({});
  node->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value;

  if (vector_load_lens.defined()) {
    node->vector_load_lens = support::AsVector<Integer, int>(vector_load_lens.value());
  } else {
    node->vector_load_lens = std::vector<int>();
  }

  if (reuse_read.defined()) {
    node->reuse_read_ = ReuseConfig(reuse_read.value());
  } else {
    node->reuse_read_ = ReuseConfig();
  }

  if (reuse_write.defined()) {
    node->reuse_write_ = ReuseConfig(reuse_write.value());
  } else {
    node->reuse_write_ = ReuseConfig();
  }

  // Record the positions of the 'S' and 'R' tiles; any other character is rejected.
  const char* tiles = structure.data();
  const int num_tiles = static_cast<int>(structure.size());
  for (int idx = 0; idx < num_tiles; ++idx) {
    switch (tiles[idx]) {
      case 'S':
        node->s_indices_.push_back(idx);
        break;
      case 'R':
        node->r_indices_.push_back(idx);
        break;
      default:
        LOG(FATAL) << "ValueError: Invalid tiling structure: " << structure;
        break;
    }
  }

  node->thread_warp_size_ = -1;
  node->max_threads_per_block_ = -1;
  return node;
}

} // namespace meta_schedule
} // namespace tvm

Expand Down
Loading

0 comments on commit 6846484

Please sign in to comment.