From cde325a927987c7e35c42c25b240023c8b3a7e8c Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Wed, 13 Jul 2022 10:52:58 -0700
Subject: [PATCH] address comments

---
 .../tvm/meta_schedule/testing/schedule_rule.py     | 16 +++++++++++-----
 python/tvm/tir/tensor_intrin/cuda.py               | 10 +++++-----
 .../multi_level_tiling_tensor_core.cc              | 17 ++++++++++++++---
 3 files changed, 30 insertions(+), 13 deletions(-)

diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py
index 65161d356fce2..9b6f5ba24257e 100644
--- a/python/tvm/meta_schedule/testing/schedule_rule.py
+++ b/python/tvm/meta_schedule/testing/schedule_rule.py
@@ -113,13 +113,19 @@ def multi_level_tiling(target: Target) -> ScheduleRule:
 
 
 def multi_level_tiling_tensor_core(
-    target: Target, scope="shared", in_dtype="float16", out_dtype="float32", trans_b=False
+    target: Target,
+    write_reuse_scope="shared",
+    in_dtype="float16",
+    out_dtype="float32",
+    trans_b=False,
 ) -> ScheduleRule:
     """Default schedule rules for with multi-level tiling reuse for tensor core"""
-    assert scope in ["shared", "global"]
+    assert write_reuse_scope in ["shared", "global"]
     if target.kind.name == "cuda":
         return MultiLevelTilingTensorCore(
-            intrin_group=tensor_intrin.get_wmma_intrin_group(scope, in_dtype, out_dtype, trans_b),
+            intrin_group=tensor_intrin.get_wmma_intrin_group(
+                write_reuse_scope, in_dtype, out_dtype, trans_b
+            ),
             structure="SSSRRSRS",
             tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"],
             max_innermost_factor=4,  # 64 // tensor intrin size
@@ -130,9 +136,9 @@ def multi_level_tiling_tensor_core(
                 scope="shared",
             ),
             reuse_write=ReuseType(
-                req="must" if scope == "shared" else "no",
+                req="must" if write_reuse_scope == "shared" else "no",
                 levels=[2],
-                scope="shared",
+                scope=write_reuse_scope,
             ),
         )
     raise NotImplementedError(f"{target.kind.name} is not supported")
diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py
index e5014d9e9e44a..e7d5defcf3215 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -809,13 +809,13 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
 
 
 def get_wmma_intrin_group(
-    scope: str, in_dtype: str, out_dtype: str, trans_b: bool
+    store_scope: str, in_dtype: str, out_dtype: str, trans_b: bool
 ) -> Dict[str, str]:
     """Get a group of intrinsics for wmma tensor core with the given configurations
 
     Parameters
     ----------
-    scope : str
+    store_scope : str
         Must be one of ["global", "shared"]. The memory scope of the result buffer.
 
     in_dtype : str
@@ -832,7 +832,7 @@ def get_wmma_intrin_group(
     ret : Dict[str, str]
         A group of tensor intrinsics.
     """
-    assert scope in ["global", "shared"]
+    assert store_scope in ["global", "shared"]
    assert in_dtype in ["float16"]
    assert out_dtype in ["float16", "float32"]
 
@@ -856,10 +856,10 @@ def get_wmma_intrin_group(
     }
     store_intrins = {
         "float16": WMMA_STORE_16x16x16_F16_SHARED_INTRIN
-        if scope == "shared"
+        if store_scope == "shared"
         else WMMA_STORE_16x16x16_F16_GLOBAL_INTRIN,
         "float32": WMMA_STORE_16x16x16_F32_SHARED_INTRIN
-        if scope == "shared"
+        if store_scope == "shared"
         else WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN,
     }
     return {
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
index bc4ea64243ee1..7465f1307c7cc 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -80,7 +80,8 @@ State TensorCoreStateNode::Copy() const {
 }
 
 /*!
- * \brief Extension of MultiLevelTiling for auto-tensorizing with a single intrinsic.
+ * \brief Extension of MultiLevelTiling for auto-tensorizing with a single group of tensor core
+ * intrinsics.
  */
 class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode {
  private:
@@ -216,8 +217,18 @@ std::vector<State> MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore(
 
   sch->ComputeInline(state->tensor_core_reindex_A);
   sch->ComputeInline(state->tensor_core_reindex_B);
-  sch->StorageAlign(state->read_reuse.at(0), 0, -2, 32, 8);
-  sch->StorageAlign(state->read_reuse.at(1), 0, -2, 32, 8);
+  for (int i = 0; i < 2; ++i) {
+    const tir::BlockRV cache_read = state->read_reuse.at(i);
+    const tir::BlockNode* cache_read_block = sch->GetSRef(cache_read)->StmtAs<tir::BlockNode>();
+    tir::Buffer cache_read_buffer = tir::GetNthAccessBuffer(
+        sch->state(), GetRef<tir::Block>(cache_read_block), 0, tir::BufferIndexType::kWrite);
+    const DataType& dtype = cache_read_buffer->dtype;
+    if (dtype.is_float16()) {
+      sch->StorageAlign(cache_read, 0, -2, 32, 8);
+    } else if (dtype.is_int() && dtype.bits() == 8) {
+      sch->StorageAlign(cache_read, 0, -2, 32, 16);
+    }
+  }
   return {state};
 }