diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index 7ba7b77842e3..2153f8c979f3 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -403,6 +403,31 @@ class ScheduleNode : public runtime::Object {
    */
   virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
                              const String& storage_scope) = 0;
+  /*!
+   * \brief Create a block that reads a buffer region into a read cache. It requires:
+   * 1) There is at most one block that writes the buffer in the scope.
+   * 2) The scope block has the stage-pipeline property.
+   * Compared to cache_read, the index mapping is performed at the producer rather than the consumer.
+   * \param block_rv The consumer block of the target buffer.
+   * \param read_buffer_index The index of the buffer in the block's read region.
+   * \param storage_scope The target storage scope.
+   * \return The cache stage block.
+   */
+  virtual BlockRV ReverseCacheRead(const BlockRV& block_rv, int read_buffer_index,
+                                   const String& storage_scope) = 0;
+  /*!
+   * \brief Create a block that writes a buffer region into a write cache. It requires:
+   * 1) There is only one block that writes the target buffer.
+   * 2) The scope block has the stage-pipeline property.
+   * Compared to cache_write, the index mapping is performed at the producer rather than the consumer.
+   * \param block_rv The producer block of the target buffer.
+   * \param write_buffer_index The index of the buffer in the block's write region.
+   * \param storage_scope The target storage scope.
+   * \return The cache stage block.
+   */
+  virtual BlockRV ReverseCacheWrite(const BlockRV& block_rv, int write_buffer_index,
+                                    const String& storage_scope) = 0;
   /******** Schedule: Compute location ********/
   /*!
    * \brief Move a producer block under the specific loop, and regenerate the
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 3b951cbe68c7..cde051b9e636 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -1083,6 +1083,32 @@ def after_cache_write(a: T.handle, b: T.handle) -> None:
         return _ffi_api.ScheduleCacheWrite(  # type: ignore # pylint: disable=no-member
             self, block, write_buffer_index, storage_scope
         )
+
+    @type_checked
+    def reverse_cache_read(
+        self, block: BlockRV, read_buffer_index: int, storage_scope: str
+    ) -> BlockRV:
+        """Create a block that reads a buffer region into a read cache.
+        Compared to :py:meth:`cache_read`, the index mapping is performed
+        at the producer rather than the consumer; see :py:meth:`cache_read`
+        for the parameter semantics.
+        """
+        return _ffi_api.ScheduleReverseCacheRead(  # type: ignore # pylint: disable=no-member
+            self, block, read_buffer_index, storage_scope
+        )
+
+    @type_checked
+    def reverse_cache_write(
+        self, block: BlockRV, write_buffer_index: int, storage_scope: str
+    ) -> BlockRV:
+        """Create a block that writes a buffer region into a write cache.
+        Compared to :py:meth:`cache_write`, the index mapping is performed
+        at the producer rather than the consumer; see :py:meth:`cache_write`
+        for the parameter semantics.
+        """
+        return _ffi_api.ScheduleReverseCacheWrite(  # type: ignore # pylint: disable=no-member
+            self, block, write_buffer_index, storage_scope
+        )
 
 ########## Schedule: Compute location ##########
 
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index ae8eaa15b58c..c9a79c06ff63 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -519,6 +519,27 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buff
   return CreateRV<BlockRV>(result);
 }
 
+BlockRV ConcreteScheduleNode::ReverseCacheRead(const BlockRV& block_rv, int read_buffer_index,
+                                               const String& storage_scope) {
+  StmtSRef result{nullptr};
+  TVM_TIR_SCHEDULE_BEGIN();
+  result = tir::ReverseCacheRead(state_, this->GetSRef(block_rv), read_buffer_index, storage_scope);
+  TVM_TIR_SCHEDULE_END("reverse-cache-read", this->error_render_level_);
+  this->state_->DebugVerify();
+  return CreateRV<BlockRV>(result);
+}
+
+BlockRV ConcreteScheduleNode::ReverseCacheWrite(const BlockRV& block_rv, int write_buffer_index,
+                                                const String& storage_scope) {
+  StmtSRef result{nullptr};
+  TVM_TIR_SCHEDULE_BEGIN();
+  result =
+      tir::ReverseCacheWrite(state_, this->GetSRef(block_rv), write_buffer_index, storage_scope);
+  TVM_TIR_SCHEDULE_END("reverse-cache-write", this->error_render_level_);
+  this->state_->DebugVerify();
+  return CreateRV<BlockRV>(result);
+}
+
 /******** Schedule: Compute location ********/
 
 void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index b675914bea38..57bec9d4f5c5 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -114,6 +114,10 @@ class ConcreteScheduleNode : public ScheduleNode {
                      const String& storage_scope) override;
   BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
                      const String& storage_scope) override;
+  BlockRV ReverseCacheRead(const BlockRV& block_rv, int read_buffer_index,
+                           const String& storage_scope) override;
+  BlockRV ReverseCacheWrite(const BlockRV& block_rv, int write_buffer_index,
+                            const String& storage_scope) override;
   /******** Schedule: Compute location ********/
   void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) override;
   void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index bd89b2481142..17316d5c0531 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -259,6 +259,32 @@ TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int r
  */
 TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
                             const String& storage_scope);
+/*!
+ * \brief Create a block that reads a buffer region into a read cache. It requires:
+ * 1) There is at most one block that writes the buffer in the scope.
+ * 2) The scope block has the stage-pipeline property.
+ * Compared to cache_read, the index mapping is performed at the producer rather than the consumer.
+ * \param self The state of the schedule.
+ * \param block_sref The consumer block of the target buffer.
+ * \param read_buffer_index The index of the buffer in the block's read region.
+ * \param storage_scope The target storage scope.
+ * \return The cache stage block.
+ */
+TVM_DLL StmtSRef ReverseCacheRead(ScheduleState self, const StmtSRef& block_sref,
+                                  int read_buffer_index, const String& storage_scope);
+/*!
+ * \brief Create a block that writes a buffer region into a write cache. It requires:
+ * 1) There is only one block that writes the target buffer.
+ * 2) The scope block has the stage-pipeline property.
+ * Compared to cache_write, the index mapping is performed at the producer rather than the consumer.
+ * \param self The state of the schedule.
+ * \param block_sref The producer block of the target buffer.
+ * \param write_buffer_index The index of the buffer in the block's write region.
+ * \param storage_scope The target storage scope.
+ * \return The cache stage block.
+ */
+TVM_DLL StmtSRef ReverseCacheWrite(ScheduleState self, const StmtSRef& block_sref,
+                                   int write_buffer_index, const String& storage_scope);
 /******** Schedule: Compute location ********/
 /*!
diff --git a/src/tir/schedule/primitive/reverse_cache_read_write.cc b/src/tir/schedule/primitive/reverse_cache_read_write.cc
new file mode 100644
index 000000000000..1cc834c92031
--- /dev/null
+++ b/src/tir/schedule/primitive/reverse_cache_read_write.cc
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+/******** Error Classes ********/
+
+/******** Helper Functions/Classes ********/
+
+/*! \brief Mutator for ReverseCacheRead */
+class ReverseCacheReadRewriter : public StmtExprMutator {};
+
+/*! \brief Mutator for ReverseCacheWrite */
+class ReverseCacheWriteRewriter : public StmtExprMutator {};
+
+/******** Implementation ********/
+
+StmtSRef ReverseCacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index,
+                          const String& storage_scope) {
+  // TODO: not implemented yet; an empty body in a function returning StmtSRef
+  // is undefined behavior, so fail loudly until the primitive is written.
+  LOG(FATAL) << "ReverseCacheRead is not implemented yet";
+  throw;
+}
+
+StmtSRef ReverseCacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
+                           const String& storage_scope) {
+  LOG(FATAL) << "ReverseCacheWrite is not implemented yet";
+  throw;
+}
+
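+// TODO: A working implementation can likely mirror CacheRead/CacheWrite in
+// cache_read_write.cc -- resolve the target buffer from the block's access
+// region, create a cache buffer in the requested storage scope, insert the
+// cache stage, and rewrite the affected accesses -- but with the index
+// mapping applied on the producer side instead of the consumer side.
+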
+/******** Instruction Registration ********/
+
+struct ReverseCacheReadTraits : public UnpackedInstTraits<ReverseCacheReadTraits> {
+  static constexpr const char* kName = "ReverseCacheRead";
+  static constexpr bool kIsPure = false;
+
+ private:
+  static constexpr size_t kNumInputs = 1;
+  static constexpr size_t kNumAttrs = 2;
+  static constexpr size_t kNumDecisions = 0;
+
+  static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer read_buffer_index,
+                                         String storage_scope) {
+    return sch->ReverseCacheRead(block, read_buffer_index->value, storage_scope);
+  }
+
+  static String UnpackedAsPython(Array<String> outputs, String block, Integer read_buffer_index,
+                                 String storage_scope) {
+    PythonAPICall py("reverse_cache_read");
+    py.Input("block", block);
+    py.Input("read_buffer_index", read_buffer_index->value);
+    py.Input("storage_scope", storage_scope);
+    py.SingleOutput(outputs);
+    return py.Str();
+  }
+
+  template <typename>
+  friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
+struct ReverseCacheWriteTraits : public UnpackedInstTraits<ReverseCacheWriteTraits> {
+  static constexpr const char* kName = "ReverseCacheWrite";
+  static constexpr bool kIsPure = false;
+
+ private:
+  static constexpr size_t kNumInputs = 1;
+  static constexpr size_t kNumAttrs = 2;
+  static constexpr size_t kNumDecisions = 0;
+
+  static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer write_buffer_index,
+                                         String storage_scope) {
+    return sch->ReverseCacheWrite(block, write_buffer_index->value, storage_scope);
+  }
+
+  static String UnpackedAsPython(Array<String> outputs, String block, Integer write_buffer_index,
+                                 String storage_scope) {
+    PythonAPICall py("reverse_cache_write");
+    py.Input("block", block);
+    py.Input("write_buffer_index", write_buffer_index->value);
+    py.Input("storage_scope", storage_scope);
+    py.SingleOutput(outputs);
+    return py.Str();
+  }
+
+  template <typename>
+  friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
+TVM_REGISTER_INST_KIND_TRAITS(ReverseCacheReadTraits);
+TVM_REGISTER_INST_KIND_TRAITS(ReverseCacheWriteTraits);
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index 79f25c004077..d7d708b143ce 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -275,6 +275,32 @@ BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer
   return result;
 }
 
+BlockRV TracedScheduleNode::ReverseCacheRead(const BlockRV& block_rv, int read_buffer_index,
+                                             const String& storage_scope) {
+  BlockRV result =
+      ConcreteScheduleNode::ReverseCacheRead(block_rv, read_buffer_index, storage_scope);
+
+  static const InstructionKind& kind = InstructionKind::Get("ReverseCacheRead");
+  trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
+                                      /*inputs=*/{block_rv},
+                                      /*attrs=*/{Integer(read_buffer_index), storage_scope},
+                                      /*outputs=*/{result}));
+  return result;
+}
+
+BlockRV TracedScheduleNode::ReverseCacheWrite(const BlockRV& block_rv, int write_buffer_index,
+                                              const String& storage_scope) {
+  BlockRV result =
+      ConcreteScheduleNode::ReverseCacheWrite(block_rv, write_buffer_index, storage_scope);
+
+  static const InstructionKind& kind = InstructionKind::Get("ReverseCacheWrite");
+  trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
+                                      /*inputs=*/{block_rv},
+                                      /*attrs=*/{Integer(write_buffer_index), storage_scope},
+                                      /*outputs=*/{result}));
+  return result;
+}
+
 /******** Schedule: Compute location ********/
 
 void TracedScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index 7ab522530786..56a5ef09e928 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -74,6 +74,10 @@ class TracedScheduleNode : public ConcreteScheduleNode {
                      const String& storage_scope) final;
   BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
                      const String& storage_scope) final;
+  BlockRV ReverseCacheRead(const BlockRV& block_rv, int read_buffer_index,
+                           const String& storage_scope) final;
+  BlockRV ReverseCacheWrite(const BlockRV& block_rv, int write_buffer_index,
+                            const String& storage_scope) final;
   /******** Schedule: Compute location ********/
   void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) final;
   void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
diff --git a/tests/python/sparsetir/tc-spmm.py b/tests/python/sparsetir/tc-spmm.py
new file mode 100644
index 000000000000..f4d02d290950
--- /dev/null
+++ b/tests/python/sparsetir/tc-spmm.py
@@ -0,0 +1,340 @@
+"""Tensor-Core SpMM
+Related work: https://arxiv.org/pdf/2112.02052.pdf
+"""
+
+import dgl
+import tvm
+import tvm.testing
+import tvm.tir as tir
+import argparse
+from tvm.script import tir as T
+from tvm.sparse import lower_sparse_buffer, lower_sparse_iter
+
+
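+# Each WMMA intrinsic below is given as a (desc, impl) pair: the *_desc
+# function describes the computation that tensorize pattern-matches against,
+# and the *_impl function supplies the TensorCore calls that replace it. The
+# pairs are registered through tir.TensorIntrin.register further below.
+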
+@T.prim_func
+def wmma_sync_desc(a_frag: T.handle, b_frag: T.handle, c_frag: T.handle) -> None:
+    A_frag = T.match_buffer(
+        a_frag, (16, 16), "float16", align=128, offset_factor=1, scope="wmma.matrix_a"
+    )
+    B_frag = T.match_buffer(
+        b_frag, (16, 16), "float16", align=128, offset_factor=1, scope="wmma.matrix_b"
+    )
+    C_frag = T.match_buffer(
+        c_frag, (16, 16), "float16", align=128, offset_factor=1, scope="wmma.accumulator"
+    )
+
+    with T.block("root"):
+        for i, j, k in T.grid(16, 16, 16):
+            with T.block("update"):
+                vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
+                T.block_attr({"sparse": True})
+                C_frag[vii, vjj] = C_frag[vii, vjj] + A_frag[vii, vkk] * B_frag[vkk, vjj]
+
+
+@T.prim_func
+def wmma_sync_impl(a_frag: T.handle, b_frag: T.handle, c_frag: T.handle) -> None:
+    A_frag = T.match_buffer(
+        a_frag, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a"
+    )
+    B_frag = T.match_buffer(
+        b_frag, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b"
+    )
+    C_frag = T.match_buffer(
+        c_frag, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.accumulator"
+    )
+
+    with T.block("root"):
+        T.reads(
+            [
+                C_frag[0:16, 0:16],
+                A_frag[0:16, 0:16],
+                B_frag[0:16, 0:16],
+            ]
+        )
+        T.writes(C_frag[0:16, 0:16])
+        for tx in T.thread_binding(0, 32, "threadIdx.x"):
+            T.evaluate(
+                T.tvm_mma_sync(
+                    C_frag.data,
+                    C_frag.elem_offset // 256 + T.floordiv(T.floormod(C_frag.elem_offset, 256), 16),
+                    A_frag.data,
+                    A_frag.elem_offset // 256 + T.floordiv(T.floormod(A_frag.elem_offset, 256), 16),
+                    B_frag.data,
+                    B_frag.elem_offset // 256 + T.floordiv(T.floormod(B_frag.elem_offset, 256), 16),
+                    C_frag.data,
+                    C_frag.elem_offset // 256 + T.floordiv(T.floormod(C_frag.elem_offset, 256), 16),
+                    dtype="handle",
+                )
+            )
+
+
+@T.prim_func
+def wmma_load_a_desc(a: T.handle, a_frag: T.handle) -> None:
+    A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="global")
+    A_frag = T.match_buffer(
+        a_frag, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a"
+    )
+
+    with T.block("root"):
+        T.reads(A[0:16, 0:16])
+        T.writes(A_frag[0:16, 0:16])
+        for i, j in T.grid(16, 16):
+            with T.block("load"):
+                vii, vjj = T.axis.remap("SS", [i, j])
+                A_frag[vii, vjj] = A[vii, vjj]
+
+
+@T.prim_func
+def wmma_load_a_impl(a: T.handle, a_frag: T.handle) -> None:
+    s0 = T.var("int32")
+    s1 = T.var("int32")
+    A = T.match_buffer(
+        a, (16, 16), "float16", align=128, offset_factor=16, scope="global", strides=[s0, s1]
+    )
+    A_frag = T.match_buffer(
+        a_frag, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a"
+    )
+
+    with T.block("root"):
+        T.reads(A[0:16, 0:16])
+        T.writes(A_frag[0:16, 0:16])
+        for tx in T.thread_binding(0, 32, "threadIdx.x"):
+            T.evaluate(
+                T.tvm_load_matrix_sync(
+                    A_frag.data,
+                    16,
+                    16,
+                    16,
+                    A_frag.elem_offset // 256 + T.floordiv(T.floormod(A_frag.elem_offset, 256), 16),
+                    A.access_ptr("r"),
+                    A.strides[0],
+                    "row_major",
+                    dtype="handle",
+                )
+            )
+
+
+@T.prim_func
+def wmma_load_b_desc(b: T.handle, b_frag: T.handle) -> None:
+    B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=16, scope="global")
+    B_frag = T.match_buffer(
+        b_frag, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b"
+    )
+    with T.block("root"):
+        for i, j in T.grid(16, 16):
+            with T.block("load"):
+                vii, vjj = T.axis.remap("SS", [i, j])
+                B_frag[vii, vjj] = B[vii, vjj]
+
+
+@T.prim_func
+def wmma_load_b_impl(b: T.handle, b_frag: T.handle) -> None:
+    s0 = T.var("int32")
+    s1 = T.var("int32")
+    B = T.match_buffer(
+        b, (16, 16), "float16", align=128, offset_factor=16, scope="global", strides=[s0, s1]
+    )
+    B_frag = T.match_buffer(
+        b_frag, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b"
+    )
+    with T.block("root"):
+        T.reads(B[0:16, 0:16])
+        T.writes(B_frag[0:16, 0:16])
+        for tx in T.thread_binding(0, 32, "threadIdx.x"):
+            T.evaluate(
+                T.tvm_load_matrix_sync(
+                    B_frag.data,
+                    16,
+                    16,
+                    16,
+                    B_frag.elem_offset // 256 + T.floordiv(T.floormod(B_frag.elem_offset, 256), 16),
+                    B.access_ptr("r"),
+                    B.strides[0],
+                    "row_major",
+                    dtype="handle",
+                )
+            )
+
+
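+# In the *_impl bodies above, `elem_offset // 256 + (elem_offset % 256) // 16`
+# turns a buffer element offset into the fragment index the wmma intrinsics
+# expect: one 16x16 fp16 fragment holds 256 elements, and the second term
+# covers offsets that are not a multiple of a full fragment.
+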
+@T.prim_func
+def wmma_fill_desc(c_frag: T.handle) -> None:
+    C_frag = T.match_buffer(
+        c_frag, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.accumulator"
+    )
+    with T.block("root"):
+        for i, j in T.grid(16, 16):
+            with T.block("init"):
+                vii, vjj = T.axis.remap("SS", [i, j])
+                C_frag[vii, vjj] = T.float16(0)
+
+
+@T.prim_func
+def wmma_fill_impl(c_frag: T.handle) -> None:
+    C_frag = T.match_buffer(
+        c_frag, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.accumulator"
+    )
+    with T.block("root"):
+        T.reads([])
+        T.writes(C_frag[0:16, 0:16])
+        for tx in T.thread_binding(0, 32, "threadIdx.x"):
+            T.evaluate(
+                T.tvm_fill_fragment(
+                    C_frag.data,
+                    16,
+                    16,
+                    16,
+                    C_frag.elem_offset // 256 + T.floordiv(T.floormod(C_frag.elem_offset, 256), 16),
+                    T.float16(0),
+                    dtype="handle",
+                )
+            )
+
+
+@T.prim_func
+def wmma_store_desc(c_frag: T.handle, c: T.handle) -> None:
+    C_frag = T.match_buffer(
+        c_frag, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.accumulator"
+    )
+    C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="global")
+    with T.block("root"):
+        for i, j in T.grid(16, 16):
+            with T.block("store"):
+                vii, vjj = T.axis.remap("SS", [i, j])
+                C[vii, vjj] = C_frag[vii, vjj]
+
+
+@T.prim_func
+def wmma_store_impl(c_frag: T.handle, c: T.handle) -> None:
+    s0 = T.var("int32")
+    s1 = T.var("int32")
+    C_frag = T.match_buffer(
+        c_frag, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.accumulator"
+    )
+    C = T.match_buffer(
+        c, (16, 16), "float16", align=128, offset_factor=16, scope="global", strides=[s0, s1]
+    )
+    with T.block("root"):
+        T.reads(C_frag[0:16, 0:16])
+        T.writes(C[0:16, 0:16])
+        for tx in T.thread_binding(0, 32, "threadIdx.x"):
+            T.evaluate(
+                T.tvm_store_matrix_sync(
+                    C_frag.data,
+                    16,
+                    16,
+                    16,
+                    C_frag.elem_offset // 256 + T.floordiv(T.floormod(C_frag.elem_offset, 256), 16),
+                    C.access_ptr("w"),
+                    C.strides[0],
+                    "row_major",
+                    dtype="handle",
+                )
+            )
+
+
+WMMA_SYNC = tir.TensorIntrin.register(
+    "wmma_sync",
+    wmma_sync_desc,
+    wmma_sync_impl,
+)
+
+WMMA_LOAD_A = tir.TensorIntrin.register(
+    "wmma_load_a",
+    wmma_load_a_desc,
+    wmma_load_a_impl,
+)
+
+WMMA_LOAD_B = tir.TensorIntrin.register(
+    "wmma_load_b",
+    wmma_load_b_desc,
+    wmma_load_b_impl,
+)
+
+WMMA_FILL = tir.TensorIntrin.register(
+    "wmma_fill",
+    wmma_fill_desc,
+    wmma_fill_impl,
+)
+
+
+WMMA_STORE = tir.TensorIntrin.register(
+    "wmma_store",
+    wmma_store_desc,
+    wmma_store_impl,
+)
+
+
+@T.prim_func
+def tcspmm(
+    a: T.handle,
+    b: T.handle,
+    c: T.handle,
+    indptr: T.handle,
+    indices: T.handle,
+    mb: T.int32,
+    nb: T.int32,
+    nnzb: T.int32,
+    feat_size: T.int32,
+    block_size: T.int32,
+) -> None:
+    T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 2})
+    # Block-sparse (BSR-style) pattern: (IO, JO) enumerate the nonzero blocks
+    # of size block_size x block_size, (II, JI) index inside a block.
+    IO = T.dense_fixed(mb)
+    JO = T.dense_variable(IO, (nb, nnzb), indptr, "int32")
+    II = T.dense_fixed(block_size)
+    JI = T.sparse_fixed(JO, (nb * block_size, block_size), indices, "int32")
+    J = T.dense_fixed(nb * block_size)
+    F = T.dense_fixed(feat_size)
+    A = T.match_sparse_buffer(a, [IO, JO, II, JI], "float16")
+    B = T.match_sparse_buffer(b, [J, F], "float16")
+    C = T.match_sparse_buffer(c, [IO, II, F], "float16")
+    with T.iter([IO, JO, II, JI, F], "SRSRS", "tcspmm") as [io, jo, ii, ji, f]:
+        with T.init():
+            C[io, ii, f] = T.float16(0)
+        C[io, ii, f] = (
+            C[io, ii, f] + A[io, jo, ii, ji] * B[ji, f]
+        )
+
+
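+# The schedule below tiles the sparse iteration space for TensorCores: the
+# feature dimension is split into 16-wide tiles, the (ii, ji, fi) 16x16x16
+# tile is blockized into a unit matching the wmma_sync intrinsic, and the
+# operands are staged toward wmma fragments / shared memory via cache_read.
+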
+def bench_tc_spmm():
+    MB, NB, NNZB, F, B = tcspmm.params[-5:]
+    mod = tvm.IRModule.from_expr(tcspmm.specialize({
+        MB: 128, NB: 128, NNZB: 1024, F: 64, B: 16
+    }))
+    mod = lower_sparse_iter(mod)
+    sch = tir.Schedule(mod)
+    blk_outer = sch.get_block("tcspmm0")
+    blk_inner = sch.get_block("tcspmm1")
+    i, = sch.get_loops(blk_outer)
+    jo, ii, ji, f = sch.get_loops(blk_inner)
+    fo, fi = sch.split(f, [None, 16])
+    sch.reorder(fo, ii, ji, fi)
+    new_blk = sch.blockize(ii)
+    A_local = sch.cache_read(blk_inner, 1, "wmma.matrix_a")
+    B_shared = sch.cache_read(blk_inner, 2, "shared")
+    print(sch.mod["main"].script())
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser("TensorCore SpMM in Sparse-TIR")
+    parser.add_argument("--dataset", "-d", type=str, default="arxiv", help="dataset name")
+    args = parser.parse_args()
+    name = args.dataset
+    # g = get_dataset(name)
+    bench_tc_spmm()
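
Usage sketch for the new primitives (`mod` and the block name "update" are hypothetical, and the C++ bodies above are still stubs; the calling convention mirrors cache_read/cache_write):

    import tvm
    from tvm import tir

    sch = tir.Schedule(mod)  # assumes an IRModule `mod` with a block named "update"
    blk = sch.get_block("update")
    # Same arguments as cache_read/cache_write; the difference is that the
    # index mapping is applied at the producer (the cache stage) rather than
    # at the consumer.
    a_shared = sch.reverse_cache_read(blk, 0, "shared")
    c_local = sch.reverse_cache_write(blk, 0, "local")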