From 7674ea84fea6de441b37927ec8f59679ae2f2927 Mon Sep 17 00:00:00 2001
From: Noah Verke
Date: Thu, 15 Dec 2022 17:24:55 -0800
Subject: [PATCH] =?UTF-8?q?Add=20check=20for=20non-contiguous=20memory=20a?=
 =?UTF-8?q?ccess=20when=20lowering=20to=20async=20dma=E2=80=A6=20(#13613)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Add check for non-contiguous memory access when lowering to async dma copies.

* lint

* lint and nits

* lint
---
 src/tir/transforms/lower_async_dma.cc        |  28 +++
 .../test_hexagon/test_async_dma_pipeline.py  | 206 ++++++++++++++++++
 2 files changed, 234 insertions(+)

diff --git a/src/tir/transforms/lower_async_dma.cc b/src/tir/transforms/lower_async_dma.cc
index 9a950c10c776..94769dae0899 100644
--- a/src/tir/transforms/lower_async_dma.cc
+++ b/src/tir/transforms/lower_async_dma.cc
@@ -22,6 +22,7 @@
  */
 
 #include <tvm/arith/analyzer.h>
+#include <tvm/arith/iter_affine_map.h>
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/stmt_functor.h>
@@ -34,6 +35,12 @@ class AsyncDMALowerer : public StmtExprMutator {
  public:
   explicit AsyncDMALowerer(bool dma_bypass_cache) : dma_bypass_cache_(dma_bypass_cache) {}
 
+  // Track the mapping from iter var to iter range for each enclosing loop
+  Stmt VisitStmt_(const ForNode* op) final {
+    input_iters.Set(op->loop_var, Range(op->min, op->extent));
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
   Stmt VisitStmt_(const AttrStmtNode* op) final {
     // Convert this, for example:
     // attr [0] "async_wait_queue_scope" = 0;
@@ -146,6 +153,17 @@ class AsyncDMALowerer : public StmtExprMutator {
 
       // map loop variable to zero for the store index & simplify
       Array<PrimExpr> store_index = bufferstorenode->indices;
+
+      // Use DetectIterMap to detect whether store index is non-contiguous.
+      arith::Analyzer analyzer;
+      auto store_iter_map = DetectIterMap(store_index, input_iters, 1, arith::IterMapLevel::NoCheck,
+                                          &analyzer, false);
+      if (!store_iter_map->errors.empty()) {
+        LOG(FATAL)
+            << "Unable to lower async dma for non contiguous memory access with store index: "
+            << store_index;
+      }
+
       store_index.MutateByApply([&](PrimExpr expr) {
         arith::Analyzer analyzer;
         return analyzer.Simplify(Substitute(std::move(expr), loop_var_remap));
@@ -153,6 +171,15 @@ class AsyncDMALowerer : public StmtExprMutator {
 
       // map loop variable to zero for the load index & simplify
      Array<PrimExpr> load_index = bufferloadnode->indices;
+
+      // Use DetectIterMap to detect whether load index is non-contiguous.
+      auto load_iter_map =
+          DetectIterMap(load_index, input_iters, 1, arith::IterMapLevel::NoCheck, &analyzer, false);
+      if (!load_iter_map->errors.empty()) {
+        LOG(FATAL) << "Unable to lower async dma for non contiguous memory access with load index: "
+                   << load_index;
+      }
+
       load_index.MutateByApply([&](PrimExpr expr) {
         arith::Analyzer analyzer;
         return analyzer.Simplify(Substitute(std::move(expr), loop_var_remap));
@@ -176,6 +203,7 @@ class AsyncDMALowerer : public StmtExprMutator {
  private:
   std::set<int> queue_ids_;
   bool dma_bypass_cache_;
+  Map<Var, Range> input_iters = Map<Var, Range>();
 };
 
 namespace transform {

diff --git a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
index 51427f18f6f4..914a26c51180 100644
--- a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
+++ b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
@@ -25,6 +25,193 @@
 VRMPY_SIZE_B = 128
 VRMPY_SIZE_INT32 = 32
 
+# pylint: disable=invalid-name
+@T.prim_func
+def conv2d_async_non_contig(
+    p0: T.Buffer[(T.int64(1), T.int64(1), T.int64(56), T.int64(56), T.int64(4)), "uint8"],
+    fused_constant_1: T.Buffer[
+        (T.int64(1), T.int64(1), T.int64(3), T.int64(3), T.int64(1), T.int64(32), T.int64(4)),
+        "uint8",
+    ],
+    conv2d_NCHWc_int8: T.Buffer[
+        (T.int64(1), T.int64(1), T.int64(54), T.int64(54), T.int64(32)), "int32"
+    ],
+):
+    """Non-contiguous memory access is used in this conv2d, taken from MetaSchedule."""
+    # pylint: disable=no-self-argument
+    # function attr dict
+    T.func_attr({"tir.noalias": True, "global_symbol": "main"})
+    # body
+    # with T.block("root")
+    p0_global_vtcm = T.alloc_buffer(
+        [T.int64(1), T.int64(1), T.int64(56), T.int64(56), T.int64(4)],
+        dtype="uint8",
+        scope="global.vtcm",
+    )
+    fused_constant_global_vtcm = T.alloc_buffer(
+        [T.int64(1), T.int64(1), T.int64(3), T.int64(3), T.int64(1), T.int64(32), T.int64(4)],
+        dtype="uint8",
+        scope="global.vtcm",
+    )
+    for oh_0 in T.serial(T.int64(3)):
+        for ow_0 in T.serial(
+            T.int64(3),
+            annotations={
+                "software_pipeline_async_stages": [0],
+                "software_pipeline_order": [0, 1, 2],
+                "software_pipeline_stage": [0, 0, 1],
+            },
+        ):
+            for ax0_ax1_ax2_ax3_ax4_fused in T.serial(T.int64(1600)):
+                with T.block("p0_global.vtcm"):
+                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                    v1 = T.axis.spatial(T.int64(1), T.int64(0))
+                    v2 = T.axis.spatial(
+                        T.int64(56), oh_0 * T.int64(18) + ax0_ax1_ax2_ax3_ax4_fused // T.int64(80)
+                    )
+                    v3 = T.axis.spatial(
+                        T.int64(56),
+                        ow_0 * T.int64(18) + ax0_ax1_ax2_ax3_ax4_fused % T.int64(80) // T.int64(4),
+                    )
+                    v4 = T.axis.spatial(T.int64(4), ax0_ax1_ax2_ax3_ax4_fused % T.int64(4))
+                    T.reads(p0[v0, v1, v2, v3, v4])
+                    T.writes(p0_global_vtcm[v0, v1, v2, v3, v4])
+                    p0_global_vtcm[v0, v1, v2, v3, v4] = p0[v0, v1, v2, v3, v4]
+            for ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused in T.serial(T.int64(1152)):
+                with T.block("fused_constant_global.vtcm"):
+                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                    v1 = T.axis.spatial(T.int64(1), T.int64(0))
+                    v2 = T.axis.spatial(
+                        T.int64(3), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused // T.int64(384)
+                    )
+                    v3 = T.axis.spatial(
+                        T.int64(3), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(384) // T.int64(128)
+                    )
+                    v4 = T.axis.spatial(T.int64(1), T.int64(0))
+                    v5 = T.axis.spatial(
+                        T.int64(32), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(128) // T.int64(4)
+                    )
+                    v6 = T.axis.spatial(T.int64(4), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(4))
+                    T.reads(fused_constant_1[v0, v1, v2, v3, v4, v5, v6])
+                    T.writes(fused_constant_global_vtcm[v0, v1, v2, v3, v4, v5, v6])
+                    fused_constant_global_vtcm[v0, v1, v2, v3, v4, v5, v6] = fused_constant_1[
+                        v0, v1, v2, v3, v4, v5, v6
+                    ]
+            for oh_1, ow_1 in T.grid(T.int64(3), T.int64(6)):
+                for oh_2_init, ow_2_init in T.grid(T.int64(6), T.int64(3)):
+                    with T.block("conv2d_NCHWc_int8_o_init"):
+                        v_n = T.axis.spatial(T.int64(1), T.int64(0))
+                        v_oc_chunk = T.axis.spatial(T.int64(1), T.int64(0))
+                        v_oh = T.axis.spatial(
+                            T.int64(54), oh_0 * T.int64(18) + oh_1 * T.int64(6) + oh_2_init
+                        )
+                        v_ow = T.axis.spatial(
+                            T.int64(54), ow_0 * T.int64(18) + ow_1 * T.int64(3) + ow_2_init
+                        )
+                        T.reads()
+                        T.writes(
+                            conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)]
+                        )
+                        for oc_block_1 in T.vectorized(T.int64(32)):
+                            with T.block("conv2d_NCHWc_int8_init"):
+                                v_oc_block_i_init = T.axis.spatial(T.int64(32), oc_block_1)
+                                T.reads()
+                                T.writes(
+                                    conv2d_NCHWc_int8[
+                                        v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i_init
+                                    ]
+                                )
+                                conv2d_NCHWc_int8[
+                                    v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i_init
+                                ] = 0
+                for kh_1, kw_1, oh_2, ow_2 in T.grid(
+                    T.int64(3), T.int64(3), T.int64(6), T.int64(3)
+                ):
+                    with T.block("conv2d_NCHWc_int8_o_update"):
+                        v_n = T.axis.spatial(T.int64(1), T.int64(0))
+                        v_oc_chunk = T.axis.spatial(T.int64(1), T.int64(0))
+                        v_oh = T.axis.spatial(
+                            T.int64(54), oh_0 * T.int64(18) + oh_1 * T.int64(6) + oh_2
+                        )
+                        v_ow = T.axis.spatial(
+                            T.int64(54), ow_0 * T.int64(18) + ow_1 * T.int64(3) + ow_2
+                        )
+                        v_kh, v_kw = T.axis.remap("RR", [kh_1, kw_1])
+                        v_ic_outer = T.axis.reduce(T.int64(1), T.int64(0))
+                        v_ic_f_inner = T.axis.reduce(T.int64(1), T.int64(0))
+                        T.reads(
+                            conv2d_NCHWc_int8[
+                                v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)
+                            ],
+                            p0_global_vtcm[
+                                v_n,
+                                v_ic_outer,
+                                v_oh + v_kh,
+                                v_ow + v_kw,
+                                v_ic_f_inner * T.int64(4) : v_ic_f_inner * T.int64(4) + T.int64(4),
+                            ],
+                            fused_constant_global_vtcm[
+                                v_oc_chunk,
+                                v_ic_outer,
+                                v_kh,
+                                v_kw,
+                                v_ic_f_inner,
+                                T.int64(0) : T.int64(32),
+                                T.int64(0) : T.int64(4),
+                            ],
+                        )
+                        T.writes(
+                            conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)]
+                        )
+                        A = T.match_buffer(
+                            p0_global_vtcm[
+                                v_n,
+                                v_ic_outer,
+                                v_oh + v_kh,
+                                v_ow + v_kw,
+                                v_ic_f_inner * T.int64(4) : v_ic_f_inner * T.int64(4) + T.int64(4),
+                            ],
+                            [T.int64(4)],
+                            dtype="uint8",
+                            scope="global.vtcm",
+                            offset_factor=1,
+                        )
+                        B = T.match_buffer(
+                            fused_constant_global_vtcm[
+                                v_oc_chunk,
+                                v_ic_outer,
+                                v_kh,
+                                v_kw,
+                                v_ic_f_inner,
+                                T.int64(0) : T.int64(32),
+                                T.int64(0) : T.int64(4),
+                            ],
+                            [T.int64(32), T.int64(4)],
+                            dtype="uint8",
+                            scope="global.vtcm",
+                            offset_factor=1,
+                        )
+                        C = T.match_buffer(
+                            conv2d_NCHWc_int8[
+                                v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)
+                            ],
+                            [T.int64(32)],
+                            dtype="int32",
+                            offset_factor=1,
+                        )
+                        A_u8x4: T.uint8x4 = A[T.int64(0) : T.int64(4)]
+                        A_i32: T.int32 = T.reinterpret(A_u8x4, dtype="int32")
+                        B_i8x128 = B[T.int64(0), T.int64(0) : T.int64(128)]
+                        B_i32x32: T.int32x32 = T.reinterpret(B_i8x128, dtype="int32x32")
+                        C[0:32] = T.call_llvm_pure_intrin(
+                            T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.acc.128B"),
+                            T.uint32(3),
+                            C[0:32],
+                            B_i32x32,
+                            A_i32,
+                            dtype="int32x32",
+                        )
+
+
 def conv_approximation(size_a, size_w):
     """Conv approximation."""
@@ -695,5 +882,24 @@ def test_meta(hexagon_session):
     )
 
 
+def test_non_contiguous():
+    """Test non-contiguous memory lowering."""
+    sch = tvm.tir.Schedule(conv2d_async_non_contig)
+    target_hexagon = tvm.target.hexagon("v68", link_params=True)
+    err_rgx = r"Unable to lower async dma for non contiguous memory access with load index: "
+    # Currently we do not support lowering non-contiguous memory access to
+    # async dma, so an error is raised.
+    with pytest.raises(tvm.TVMError, match=err_rgx):
+        with tvm.transform.PassContext(
+            config={
+                "tir.use_async_copy": 1,
+                "tir.merge_async_commit_queue_scope": 0,
+            }
+        ):
+            tvm.build(
+                sch.mod["main"], target=tvm.target.Target(target_hexagon, host=target_hexagon)
+            )
+
+
 if __name__ == "__main__":
     tvm.testing.main()
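
For readers, a minimal sketch (plain Python, independent of TVM; it is not part of the patch, and the helper name flat_offset is hypothetical) of why the schedule above trips the new check: it replays the index math of the p0_global.vtcm copy block for oh_0 = ow_0 = 0 and shows that the flattened source offsets are not one contiguous span, which is the condition DetectIterMap reports as an error and the new LOG(FATAL) surfaces.

    # Sketch: replay the p0_global.vtcm copy indices from conv2d_async_non_contig
    # (oh_0 = ow_0 = 0) and flatten them into linear offsets over the (56, 56, 4)
    # source buffer p0.
    def flat_offset(fused):  # hypothetical helper, fused in [0, 1600)
        v2 = fused // 80          # height index, 20 rows per tile
        v3 = (fused % 80) // 4    # width index, 20 columns per tile
        v4 = fused % 4            # innermost channel index
        return (v2 * 56 + v3) * 4 + v4

    offsets = [flat_offset(i) for i in range(1600)]
    print({offsets[i + 1] - offsets[i] for i in range(1599)})
    # -> {1, 145}: each tile row is contiguous, but the offset jumps by 145 at
    #    every row boundary (only 80 of the 224 bytes per source row are read),
    #    so the copy cannot be lowered to one contiguous async DMA transfer.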