diff --git a/python/tvm/topi/x86/concat.py b/python/tvm/topi/x86/concat.py index 5cb3cd3f57d5..28f650bca95f 100644 --- a/python/tvm/topi/x86/concat.py +++ b/python/tvm/topi/x86/concat.py @@ -19,11 +19,12 @@ import tvm from tvm import te import numpy as np -from ..utils import get_const_int, const_vector +from ..utils import get_const_int def concatenate(data: tvm.te.Tensor, axis: Optional[int] = 0): - """Join a sequence of arrays along an existing axis. Optimized for CPU exeution. + """Join a sequence of arrays along an existing axis. + Optimized for CPU execution. Parameters ---------- @@ -38,48 +39,45 @@ def concatenate(data: tvm.te.Tensor, axis: Optional[int] = 0): ret : tvm.te.Tensor """ - def gen_ir_1d(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf): - """Custom conactenation execution.""" + in_outers = [int(np.prod(i.shape[axis:])) for i in data] + in_outers_cumsum = [0, *np.cumsum(in_outers, dtype="int64")[0:-1]] + + def gen_ir_1d(data_bufs, out_buf): + """Custom concatenation execution.""" i_b = tvm.tir.ir_builder.create() data_bufs1 = [i_b.buffer_ptr(data_buf) for data_buf in data_bufs] out_buf = i_b.buffer_ptr(out_buf) - outers = i_b.buffer_ptr(in_outers_tensor) - cumsum = i_b.buffer_ptr(in_cumsum_tensor) + for i in range(len(data)): - with i_b.for_range(0, outers[i], name="j") as j: - out_buf[cumsum[i] + j] = data_bufs1[i][j] + with i_b.for_range(0, in_outers[i], name="j") as j: + out_buf[in_outers_cumsum[i] + j] = data_bufs1[i][j] return i_b.get() - def gen_ir(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf, inner, outer): - """Common case of conactenation execution.""" + def gen_ir(data_bufs, out_buf, inner, outer): + """Common case of concatenation execution.""" i_b = tvm.tir.ir_builder.create() data_bufs1 = [i_b.buffer_ptr(data_buf) for data_buf in data_bufs] out_buf = i_b.buffer_ptr(out_buf) - outers = i_b.buffer_ptr(in_outers_tensor) - cumsum = i_b.buffer_ptr(in_cumsum_tensor) if inner > 1: with i_b.for_range(0, inner, name="inn", kind="parallel") as inn: pos = inn * outer for i in range(len(data)): - offset = inn * outers[i] - with i_b.for_range(0, outers[i], name="j") as j: - out_buf[pos + cumsum[i] + j] = data_bufs1[i][offset + j] + offset = inn * in_outers[i] + with i_b.for_range(0, in_outers[i], name="j") as j: + out_buf[pos + in_outers_cumsum[i] + j] = data_bufs1[i][offset + j] else: for i in range(len(data)): - with i_b.for_range(0, outers[i], name="j", kind="parallel") as j: - out_buf[cumsum[i] + j] = data_bufs1[i][j] + with i_b.for_range(0, in_outers[i], name="j", kind="parallel") as j: + out_buf[in_outers_cumsum[i] + j] = data_bufs1[i][j] return i_b.get() if axis < 0: axis += len(data[0].shape) concat_axis_sizes = [int(t.shape[axis]) for t in data] join_size = int(np.sum(concat_axis_sizes)) - in_outers = [int(np.prod(i.shape[axis:])) for i in data] - in_outers_cumsum = [0, *np.cumsum(in_outers, dtype="int64")[0:-1]] + dtype = data[0].dtype out_shape = data[0].shape[:axis] + [join_size] + data[0].shape[axis + 1 :] - in_outers_tensor = const_vector(in_outers) - in_cumsum_tensor = const_vector(in_outers_cumsum, name="cumsum") right_val = np.prod(out_shape[axis:]) left_val = np.prod(out_shape[:axis]) @@ -92,8 +90,8 @@ def gen_ir(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf, inner, outer) # badly parallelized case return te.extern( [out_shape], - list(data) + [in_outers_tensor, in_cumsum_tensor], - lambda ins, outs: gen_ir_1d(ins, ins[-2], ins[-1], outs[0]), + list(data), + lambda ins, outs: gen_ir_1d(ins, outs[0]), dtype=dtype, name="concatenate_ext", ) @@ -102,8 +100,8 @@ def gen_ir(data_bufs, in_outers_tensor, in_cumsum_tensor, out_buf, inner, outer) outer = get_const_int(int(right_val)) return te.extern( [out_shape], - list(data) + [in_outers_tensor, in_cumsum_tensor], - lambda ins, outs: gen_ir(ins, ins[-2], ins[-1], outs[0], inner, outer), + list(data), + lambda ins, outs: gen_ir(ins, outs[0], inner, outer), dtype=dtype, name="concatenate_ext", ) diff --git a/tests/python/unittest/test_micro_model_library_format.py b/tests/python/unittest/test_micro_model_library_format.py index 0caae1cdd9d4..7be5037478b1 100644 --- a/tests/python/unittest/test_micro_model_library_format.py +++ b/tests/python/unittest/test_micro_model_library_format.py @@ -460,7 +460,7 @@ def test_export_byoc_c_module(): "constants_size_bytes": 0, "device": 1, "io_size_bytes": 4800, - "workspace_size_bytes": 1264, + "workspace_size_bytes": 1200, } ] else: @@ -469,7 +469,7 @@ def test_export_byoc_c_module(): "constants_size_bytes": 0, "device": 1, "io_size_bytes": 4800, - "workspace_size_bytes": 1248, + "workspace_size_bytes": 1200, } ]