Commit

address comments

vinx13 committed Jul 13, 2022
1 parent d27dd7b commit cde325a
Showing 3 changed files with 29 additions and 12 deletions.
14 changes: 10 additions & 4 deletions python/tvm/meta_schedule/testing/schedule_rule.py
@@ -113,13 +113,19 @@ def multi_level_tiling(target: Target) -> ScheduleRule:
 
 
 def multi_level_tiling_tensor_core(
-    target: Target, scope="shared", in_dtype="float16", out_dtype="float32", trans_b=False
+    target: Target,
+    write_reuse_scope="shared",
+    in_dtype="float16",
+    out_dtype="float32",
+    trans_b=False,
 ) -> ScheduleRule:
     """Default schedule rules for with multi-level tiling reuse for tensor core"""
     assert scope in ["shared", "global"]
     if target.kind.name == "cuda":
         return MultiLevelTilingTensorCore(
-            intrin_group=tensor_intrin.get_wmma_intrin_group(scope, in_dtype, out_dtype, trans_b),
+            intrin_group=tensor_intrin.get_wmma_intrin_group(
+                write_reuse_scope, in_dtype, out_dtype, trans_b
+            ),
             structure="SSSRRSRS",
             tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"],
             max_innermost_factor=4,  # 64 // tensor intrin size
@@ -130,9 +136,9 @@ def multi_level_tiling_tensor_core(
             scope="shared",
         ),
         reuse_write=ReuseType(
-            req="must" if scope == "shared" else "no",
+            req="must" if write_reuse_scope == "shared" else "no",
             levels=[2],
-            scope="shared",
+            scope=write_reuse_scope,
         ),
     )
     raise NotImplementedError(f"{target.kind.name} is not supported")
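For reference, a minimal sketch (not part of this commit) of how the renamed write_reuse_scope parameter is intended to be used; the target tag is illustrative:

from tvm.target import Target
from tvm.meta_schedule.testing.schedule_rule import multi_level_tiling_tensor_core

# Illustrative CUDA target; any target whose kind is "cuda" takes this code path.
cuda_target = Target("nvidia/geforce-rtx-3080")

# Write the tensor core result back through shared memory (reuse_write req becomes "must") ...
rule_shared = multi_level_tiling_tensor_core(cuda_target, write_reuse_scope="shared")
# ... or store it directly to global memory (reuse_write req becomes "no").
rule_global = multi_level_tiling_tensor_core(cuda_target, write_reuse_scope="global")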
10 changes: 5 additions & 5 deletions python/tvm/tir/tensor_intrin/cuda.py
@@ -809,13 +809,13 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
 
 
 def get_wmma_intrin_group(
-    scope: str, in_dtype: str, out_dtype: str, trans_b: bool
+    store_scope: str, in_dtype: str, out_dtype: str, trans_b: bool
 ) -> Dict[str, str]:
     """Get a group of intrinsics for wmma tensor core with the given configurations
 
     Parameters
     ----------
-    scope : str
+    store_scope : str
         Must be one of ["global", "shared"]. The memory scope of the result buffer.
 
     in_dtype : str
@@ -832,7 +832,7 @@ def get_wmma_intrin_group(
     ret : Dict[str, str]
         A group of tensor intrinsics.
     """
-    assert scope in ["global", "shared"]
+    assert store_scope in ["global", "shared"]
    assert in_dtype in ["float16"]
    assert out_dtype in ["float16", "float32"]
 
@@ -856,10 +856,10 @@
     }
     store_intrins = {
         "float16": WMMA_STORE_16x16x16_F16_SHARED_INTRIN
-        if scope == "shared"
+        if store_scope == "shared"
         else WMMA_STORE_16x16x16_F16_GLOBAL_INTRIN,
         "float32": WMMA_STORE_16x16x16_F32_SHARED_INTRIN
-        if scope == "shared"
+        if store_scope == "shared"
         else WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN,
     }
     return {
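A minimal usage sketch for the renamed store_scope parameter. The keys of the returned dict are not shown in this hunk, so the loop below avoids assuming them; the call goes through the same tensor_intrin namespace the schedule rule above uses:

from tvm.tir import tensor_intrin

# Select the intrinsic variants for fp16 inputs, fp32 accumulation,
# non-transposed B, and a result stored to shared memory.
group = tensor_intrin.get_wmma_intrin_group(
    store_scope="shared", in_dtype="float16", out_dtype="float32", trans_b=False
)
for stage, intrin in group.items():
    print(stage, "->", intrin)  # the store entry should be one of the *_SHARED_INTRIN names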
17 changes: 14 additions & 3 deletions src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -80,7 +80,8 @@ State TensorCoreStateNode::Copy() const {
 }
 
 /*!
- * \brief Extension of MultiLevelTiling for auto-tensorizing with a single intrinsic.
+ * \brief Extension of MultiLevelTiling for auto-tensorizing with a single group of tensor core
+ * intrinsics.
  */
 class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode {
  private:
@@ -216,8 +217,18 @@ std::vector<State> MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore(
   sch->ComputeInline(state->tensor_core_reindex_A);
   sch->ComputeInline(state->tensor_core_reindex_B);
 
-  sch->StorageAlign(state->read_reuse.at(0), 0, -2, 32, 8);
-  sch->StorageAlign(state->read_reuse.at(1), 0, -2, 32, 8);
+  for (int i = 0; i < 2; ++i) {
+    const tir::BlockRV cache_read = state->read_reuse.at(i);
+    const tir::BlockNode* cache_read_block = sch->GetSRef(cache_read)->StmtAs<tir::BlockNode>();
+    tir::Buffer cache_read_buffer = tir::GetNthAccessBuffer(
+        sch->state(), GetRef<tir::Block>(cache_read_block), 0, tir::BufferIndexType::kWrite);
+    const DataType& dtype = cache_read_buffer->dtype;
+    if (dtype.is_float16()) {
+      sch->StorageAlign(cache_read, 0, -2, 32, 8);
+    } else if (dtype.is_int() && dtype.bits() == 8) {
+      sch->StorageAlign(cache_read, 0, -2, 32, 16);
+    }
+  }
   return {state};
 }
 
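The loop above replaces two hard-coded StorageAlign calls with a dtype-dependent padding offset: 8 elements for float16 and 16 for int8, i.e. 16 bytes of padding either way, which is commonly used to reduce shared-memory bank conflicts for the cache-read buffers. A rough Python-side sketch of the same decision using the tir.Schedule API (the helper name and dtype strings are illustrative, not part of the commit):

import tvm

def align_cache_read(
    sch: tvm.tir.Schedule, cache_read_block: tvm.tir.schedule.BlockRV, dtype: str
) -> None:
    # storage_align requests strides with stride % factor == offset on the given
    # axis (here the second-to-last); the padding works out to 16 bytes for both types.
    if dtype == "float16":
        sch.storage_align(cache_read_block, buffer_index=0, axis=-2, factor=32, offset=8)
    elif dtype == "int8":
        sch.storage_align(cache_read_block, buffer_index=0, axis=-2, factor=32, offset=16)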
