Add check for non-contiguous memory access when lowering to async dma copies (#13613)

* Add check for non-contiguous memory access when lowering to async dma copies.

* lint

* lint and nits

* lint
nverke committed Dec 16, 2022 · 1 parent cdb4eea · commit 7674ea8
Showing 2 changed files with 234 additions and 0 deletions.
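The substance of the change: before rewriting a copy loop into an async DMA, the lowering pass now runs DetectIterMap over the store and load indices and refuses to lower when the indices cannot be expressed as an iter map over the surrounding loops. Below is a minimal sketch of that check from the Python side — assuming the tvm.arith.detect_iter_map binding mirrors the C++ DetectIterMap call used in the pass; the helper name is illustrative, not part of the commit.

import tvm
from tvm import arith, tir


def dma_indices_are_contiguous(indices, input_iters):
    # Mirror of the pass's check: an empty `errors` list means the indices
    # form a valid iter map over the loops, so the copy is safe to lower.
    result = arith.detect_iter_map(
        indices, input_iters, check_level=arith.IterMapLevel.NoCheck
    )
    return len(result.errors) == 0


f = tir.Var("f", "int32")
iters = {f: tvm.ir.Range.from_min_extent(0, 1600)}

# A fused copy loop split into (row, col, channel) indices, as in the test
# below, forms a valid iter map.
print(dma_indices_are_contiguous([f // 80, f % 80 // 4, f % 4], iters))  # True

The pass applies the same test to both the store and the load indices; the conv2d test added below exercises the failing path, where overlapping windows (stride 18, window extent 20) keep the load index from forming a valid iter map.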
28 changes: 28 additions & 0 deletions src/tir/transforms/lower_async_dma.cc
@@ -22,6 +22,7 @@
*/

#include <tvm/arith/analyzer.h>
#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

@@ -34,6 +35,12 @@ class AsyncDMALowerer : public StmtExprMutator {
public:
explicit AsyncDMALowerer(bool dma_bypass_cache) : dma_bypass_cache_(dma_bypass_cache) {}

// Record a mapping from each loop's iter var to its iter range for later iter-map analysis
Stmt VisitStmt_(const ForNode* op) final {
input_iters.Set(op->loop_var, Range(op->min, op->extent));
return StmtExprMutator::VisitStmt_(op);
}

Stmt VisitStmt_(const AttrStmtNode* op) final {
// Convert this, for example:
// attr [0] "async_wait_queue_scope" = 0;
@@ -146,13 +153,33 @@

// map loop variable to zero for the store index & simplify
Array<PrimExpr> store_index = bufferstorenode->indices;

// Use DetectIterMap to detect whether store index is non-contiguous.
arith::Analyzer analyzer;
auto store_iter_map = DetectIterMap(store_index, input_iters, 1, arith::IterMapLevel::NoCheck,
&analyzer, false);
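// A non-empty error list means the store indices cannot be expressed as an
// iter map over the surrounding loops, so contiguity cannot be proven.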
if (!store_iter_map->errors.empty()) {
LOG(FATAL)
<< "Unable to lower async dma for non contiguous memory access with store index: "
<< store_index;
}

store_index.MutateByApply([&](PrimExpr expr) {
arith::Analyzer analyzer;
return analyzer.Simplify(Substitute(std::move(expr), loop_var_remap));
});

// map loop variable to zero for the load index & simplify
Array<PrimExpr> load_index = bufferloadnode->indices;

// Use DetectIterMap to detect whether load index is non-contiguous.
auto load_iter_map =
DetectIterMap(load_index, input_iters, 1, arith::IterMapLevel::NoCheck, &analyzer, false);
if (!load_iter_map->errors.empty()) {
LOG(FATAL) << "Unable to lower async dma for non contiguous memory access with load index: "
<< load_index;
}

load_index.MutateByApply([&](PrimExpr expr) {
arith::Analyzer analyzer;
return analyzer.Simplify(Substitute(std::move(expr), loop_var_remap));
@@ -176,6 +203,7 @@
private:
std::set<int> queue_ids_;
bool dma_bypass_cache_;
Map<Var, Range> input_iters = Map<Var, Range>();
};

namespace transform {
206 changes: 206 additions & 0 deletions tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
@@ -25,6 +25,193 @@
VRMPY_SIZE_B = 128
VRMPY_SIZE_INT32 = 32

# pylint: disable=invalid-name
@T.prim_func
def conv2d_async_non_contig(
p0: T.Buffer[(T.int64(1), T.int64(1), T.int64(56), T.int64(56), T.int64(4)), "uint8"],
fused_constant_1: T.Buffer[
(T.int64(1), T.int64(1), T.int64(3), T.int64(3), T.int64(1), T.int64(32), T.int64(4)),
"uint8",
],
conv2d_NCHWc_int8: T.Buffer[
(T.int64(1), T.int64(1), T.int64(54), T.int64(54), T.int64(32)), "int32"
],
):
"""Non contiguous memory access is used in this conv2d taken from MS."""
# pylint: disable=no-self-argument
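# The p0 -> VTCM copy below reads overlapping 20x20x4 windows of p0 (window
# stride 18), so its load index cannot be expressed as a contiguous iter map.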
# function attr dict
T.func_attr({"tir.noalias": True, "global_symbol": "main"})
# body
# with T.block("root")
p0_global_vtcm = T.alloc_buffer(
[T.int64(1), T.int64(1), T.int64(56), T.int64(56), T.int64(4)],
dtype="uint8",
scope="global.vtcm",
)
fused_constant_global_vtcm = T.alloc_buffer(
[T.int64(1), T.int64(1), T.int64(3), T.int64(3), T.int64(1), T.int64(32), T.int64(4)],
dtype="uint8",
scope="global.vtcm",
)
for oh_0 in T.serial(T.int64(3)):
for ow_0 in T.serial(
T.int64(3),
annotations={
"software_pipeline_async_stages": [0],
"software_pipeline_order": [0, 1, 2],
"software_pipeline_stage": [0, 0, 1],
},
):
for ax0_ax1_ax2_ax3_ax4_fused in T.serial(T.int64(1600)):
with T.block("p0_global.vtcm"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
v1 = T.axis.spatial(T.int64(1), T.int64(0))
v2 = T.axis.spatial(
T.int64(56), oh_0 * T.int64(18) + ax0_ax1_ax2_ax3_ax4_fused // T.int64(80)
)
v3 = T.axis.spatial(
T.int64(56),
ow_0 * T.int64(18) + ax0_ax1_ax2_ax3_ax4_fused % T.int64(80) // T.int64(4),
)
v4 = T.axis.spatial(T.int64(4), ax0_ax1_ax2_ax3_ax4_fused % T.int64(4))
T.reads(p0[v0, v1, v2, v3, v4])
T.writes(p0_global_vtcm[v0, v1, v2, v3, v4])
p0_global_vtcm[v0, v1, v2, v3, v4] = p0[v0, v1, v2, v3, v4]
for ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused in T.serial(T.int64(1152)):
with T.block("fused_constant_global.vtcm"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
v1 = T.axis.spatial(T.int64(1), T.int64(0))
v2 = T.axis.spatial(
T.int64(3), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused // T.int64(384)
)
v3 = T.axis.spatial(
T.int64(3), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(384) // T.int64(128)
)
v4 = T.axis.spatial(T.int64(1), T.int64(0))
v5 = T.axis.spatial(
T.int64(32), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(128) // T.int64(4)
)
v6 = T.axis.spatial(T.int64(4), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(4))
T.reads(fused_constant_1[v0, v1, v2, v3, v4, v5, v6])
T.writes(fused_constant_global_vtcm[v0, v1, v2, v3, v4, v5, v6])
fused_constant_global_vtcm[v0, v1, v2, v3, v4, v5, v6] = fused_constant_1[
v0, v1, v2, v3, v4, v5, v6
]
for oh_1, ow_1 in T.grid(T.int64(3), T.int64(6)):
for oh_2_init, ow_2_init in T.grid(T.int64(6), T.int64(3)):
with T.block("conv2d_NCHWc_int8_o_init"):
v_n = T.axis.spatial(T.int64(1), T.int64(0))
v_oc_chunk = T.axis.spatial(T.int64(1), T.int64(0))
v_oh = T.axis.spatial(
T.int64(54), oh_0 * T.int64(18) + oh_1 * T.int64(6) + oh_2_init
)
v_ow = T.axis.spatial(
T.int64(54), ow_0 * T.int64(18) + ow_1 * T.int64(3) + ow_2_init
)
T.reads()
T.writes(
conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)]
)
for oc_block_1 in T.vectorized(T.int64(32)):
with T.block("conv2d_NCHWc_int8_init"):
v_oc_block_i_init = T.axis.spatial(T.int64(32), oc_block_1)
T.reads()
T.writes(
conv2d_NCHWc_int8[
v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i_init
]
)
conv2d_NCHWc_int8[
v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i_init
] = 0
for kh_1, kw_1, oh_2, ow_2 in T.grid(
T.int64(3), T.int64(3), T.int64(6), T.int64(3)
):
with T.block("conv2d_NCHWc_int8_o_update"):
v_n = T.axis.spatial(T.int64(1), T.int64(0))
v_oc_chunk = T.axis.spatial(T.int64(1), T.int64(0))
v_oh = T.axis.spatial(
T.int64(54), oh_0 * T.int64(18) + oh_1 * T.int64(6) + oh_2
)
v_ow = T.axis.spatial(
T.int64(54), ow_0 * T.int64(18) + ow_1 * T.int64(3) + ow_2
)
v_kh, v_kw = T.axis.remap("RR", [kh_1, kw_1])
v_ic_outer = T.axis.reduce(T.int64(1), T.int64(0))
v_ic_f_inner = T.axis.reduce(T.int64(1), T.int64(0))
T.reads(
conv2d_NCHWc_int8[
v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)
],
p0_global_vtcm[
v_n,
v_ic_outer,
v_oh + v_kh,
v_ow + v_kw,
v_ic_f_inner * T.int64(4) : v_ic_f_inner * T.int64(4) + T.int64(4),
],
fused_constant_global_vtcm[
v_oc_chunk,
v_ic_outer,
v_kh,
v_kw,
v_ic_f_inner,
T.int64(0) : T.int64(32),
T.int64(0) : T.int64(4),
],
)
T.writes(
conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)]
)
A = T.match_buffer(
p0_global_vtcm[
v_n,
v_ic_outer,
v_oh + v_kh,
v_ow + v_kw,
v_ic_f_inner * T.int64(4) : v_ic_f_inner * T.int64(4) + T.int64(4),
],
[T.int64(4)],
dtype="uint8",
scope="global.vtcm",
offset_factor=1,
)
B = T.match_buffer(
fused_constant_global_vtcm[
v_oc_chunk,
v_ic_outer,
v_kh,
v_kw,
v_ic_f_inner,
T.int64(0) : T.int64(32),
T.int64(0) : T.int64(4),
],
[T.int64(32), T.int64(4)],
dtype="uint8",
scope="global.vtcm",
offset_factor=1,
)
C = T.match_buffer(
conv2d_NCHWc_int8[
v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)
],
[T.int64(32)],
dtype="int32",
offset_factor=1,
)
A_u8x4: T.uint8x4 = A[T.int64(0) : T.int64(4)]
A_i32: T.int32 = T.reinterpret(A_u8x4, dtype="int32")
B_i8x128 = B[T.int64(0), T.int64(0) : T.int64(128)]
B_i32x32: T.int32x32 = T.reinterpret(B_i8x128, dtype="int32x32")
C[0:32] = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.acc.128B"),
T.uint32(3),
C[0:32],
B_i32x32,
A_i32,
dtype="int32x32",
)


def conv_approximation(size_a, size_w):
"""Conv approximation."""
@@ -695,5 +882,24 @@ def test_meta(hexagon_session):
)


def test_non_contiguous():
"""Test Non Contiguous memory lowering."""
sch = tvm.tir.Schedule(conv2d_async_non_contig)
target_hexagon = tvm.target.hexagon("v68", link_params=True)
err_rgx = r"Unable to lower async dma for non contiguous memory access with load index: "
# Currently we do not support lowering non-contiguous memory access to
# async dma, so we throw an error.
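# The failure comes from the DetectIterMap-based check added in
# lower_async_dma.cc, which reports errors for this kernel's load index.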
with pytest.raises(tvm.TVMError, match=err_rgx):
with tvm.transform.PassContext(
config={
"tir.use_async_copy": 1,
"tir.merge_async_commit_queue_scope": 0,
}
):
tvm.build(
sch.mod["main"], target=tvm.target.Target(target_hexagon, host=target_hexagon)
)


if __name__ == "__main__":
tvm.testing.main()
