Skip to content

Commit

Permalink
conv1d_transpose speedup. (apache#6840)
Browse files Browse the repository at this point in the history
Improve performance of transposed convolution by avoiding
redundant multiplication by zero values from dilated data.

Co-authored-by: Ubuntu <ubuntu@ip-172-31-74-104.ec2.internal>
  • Loading branch information
2 people authored and Trevor Morris committed Dec 4, 2020
1 parent 12ab455 commit d8fb974
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 39 deletions.
75 changes: 36 additions & 39 deletions python/tvm/topi/cuda/conv1d_transpose_ncw.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,29 +65,46 @@ def conv1d_transpose_ncw(cfg, data, kernel, stride, padding, out_dtype, output_p
out_width = (inp_width - 1) * stride + kernel_size - pad_left - pad_right + output_padding
pad_left = kernel_size - 1 - pad_left
pad_right = kernel_size - 1 - pad_right + output_padding
dilated_width = stride * (inp_width - 1) + 1
data = te.compute(
(batch, inp_channels, pad_left + dilated_width + pad_right),
padded_width = pad_left + inp_width + pad_right

padded_data = te.compute(
(batch, inp_channels, padded_width),
lambda n, c, x: tvm.tir.if_then_else(
tvm.tir.all(
x >= pad_left,
x < pad_left + dilated_width,
tvm.tir.indexmod(x - pad_left, stride).equal(0),
),
data[n, c, tvm.tir.indexdiv(x - pad_left, stride)],
tvm.tir.all(x >= pad_left, x < pad_left + inp_width),
data[n, c, x - pad_left],
tvm.tir.const(0.0, "float32"),
),
name="data_pad",
)

dc = te.reduce_axis((0, inp_channels), name="dc")
dw = te.reduce_axis((0, kernel_size), name="dw")
padded_kernel = te.compute(
(inp_channels, out_channels, kernel_size + stride - 1),
lambda ci, co, k: tvm.tir.if_then_else(
tvm.tir.all(k < kernel_size),
kernel[ci, co, kernel_size - k - 1],
tvm.tir.const(0.0, "float32"),
),
name="kernel_pad",
)

ci = te.reduce_axis((0, inp_channels), name="ci")
k = te.reduce_axis((0, tvm.tir.indexdiv(kernel_size + stride - 1, stride)), name="k")
border = pad_left * (stride - 1)

# Skip multiplication by 0 values in the input data inserted when stride is greater then 1.
# During multiplication of kernel by padded data:
# Kernel indices are: 0, 1 * stride, 2 * stride, ..., ceil(kernel_size / stride) plus
# data offset mod stride
data_out = te.compute(
(batch, out_channels, out_width),
lambda b, c, w: te.sum(
data[b, dc, w + dw].astype(out_dtype)
* kernel[dc, c, kernel_size - 1 - dw].astype(out_dtype),
axis=[dc, dw],
lambda b, co, w: te.sum(
padded_data[b, ci, tvm.tir.indexdiv(border + w + stride - 1, stride) + k].astype(
out_dtype
)
* padded_kernel[
ci, co, k * stride + tvm.tir.indexmod(stride - w - border, stride)
].astype(out_dtype),
axis=[ci, k],
),
tag="conv1d_transpose_ncw",
)
Expand Down Expand Up @@ -118,8 +135,8 @@ def schedule_conv1d_transpose_ncw(cfg, outs):

def _callback(op):
if op.tag == "conv1d_transpose_ncw":
pad_data = op.input_tensors[0]
kernel = op.input_tensors[1]
padded_data = op.input_tensors[0]
padded_kernel = op.input_tensors[1]
conv = op.output(0)

##### space definition begin #####
Expand All @@ -139,9 +156,6 @@ def _callback(op):

##### space definition end #####

if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
s[kernel].compute_inline()

if conv.op in s.outputs:
output = conv
OL = s.cache_write(conv, "local")
Expand All @@ -150,10 +164,8 @@ def _callback(op):
s[conv].set_scope("local")
OL = conv

# create cache stage
s[pad_data].set_scope("shared")
AA = pad_data
WW = s.cache_read(kernel, "shared", [OL])
s[padded_kernel].compute_inline()
s[padded_data].compute_inline()

# tile and bind spatial axes
n, f, x = s[output].op.axis
Expand All @@ -172,28 +184,13 @@ def _callback(op):

s[output].bind(tx, te.thread_axis("threadIdx.x"))
s[OL].compute_at(s[output], tx)
# number of threads
n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
n_tx = cfg["tile_x"].size[2]

# tile reduction axes
n, f, x = s[OL].op.axis
rc, rx = s[OL].op.reduce_axis
rco, rcm, rci = cfg["tile_rc"].apply(s, OL, rc)
s[OL].reorder(rco, rcm, rx, rci, n, f, x)

s[AA].compute_at(s[OL], rx)
s[WW].compute_at(s[OL], rx)

# cooperative fetching
for load in [AA, WW]:
n, f, x = s[load].op.axis
fused = s[load].fuse(f, x)
tz, fused = s[load].split(fused, nparts=n_tz)
tx, fused = s[load].split(fused, nparts=n_tx)
s[load].bind(tz, te.thread_axis("threadIdx.y"))
s[load].bind(tx, te.thread_axis("threadIdx.x"))

s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)

Expand Down
4 changes: 4 additions & 0 deletions tests/python/topi/python/test_topi_conv1d_transpose_ncw.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,13 @@ def test_conv1d_transpose_ncw():
verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 2, 256, (0,))
verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 5, 256, (0,))
verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 5, 256, (3,))
verify_conv1d_transpose_ncw(1, 2, 1024, 1, 128, 128, 0, (0,))
verify_conv1d_transpose_ncw(1, 1, 1024, 2, 128, 128, 0, (0,))
verify_conv1d_transpose_ncw(1, 1, 1024, 2, 2, 2, 0, (0,))
verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (0, 3), (0,))
verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (1, 3), (0,))
verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (2, 3), (0,))
verify_conv1d_transpose_ncw(1, 257, 128, 1, 512, 128, 256, (0,))


if __name__ == "__main__":
Expand Down

0 comments on commit d8fb974

Please sign in to comment.