Commit

nchw test worked
masahi committed Jan 18, 2022
1 parent 2bf68c7 commit c0609ab
Showing 3 changed files with 15 additions and 21 deletions.
python/tvm/contrib/cudnn.py (15 changes: 4 additions & 11 deletions)
@@ -604,17 +604,10 @@ def conv_backward_data(x, w, pad, stride, dilation, conv_mode, tensor_format, co
     x_shape = list(x.shape)
 
+    assert isinstance(x.shape[0], tvm.tir.expr.IntImm), "Dynamic batch is not supported for cudnn conv2d backward data yet."
-    oshape = conv_output_shape(
-        tensor_format,
-        pad,
-        stride,
-        dilation,
-        x_shape,
-        list(w.shape),
-        x.dtype,
-        conv_dtype,
-        groups,
-    )
+    # TODO: fix oshape
+    oshape = x_shape
+    oshape[1] = w.shape[1]
 
     algo = conv_backward_data_find_algo(
         tensor_format,
         pad,
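The `# TODO: fix oshape` shortcut assumes dx has the same spatial extent as dy, which holds only when the forward convolution is shape-preserving (stride 1 with padding that keeps the spatial size); the TODO acknowledges the general case is not handled. A minimal sketch of the shape arithmetic the TODO points at, assuming standard cross-correlation conventions; `backward_data_shape` is an illustrative helper, not TVM API:

    # Sketch: dx shape for conv2d backward data in NCHW layout (illustrative, not TVM API).
    def backward_data_shape(dy_shape, w_shape, pad, stride, dilation, output_padding=(0, 0)):
        n, _, out_h, out_w = dy_shape  # dy: [batch, out_channel, H_out, W_out]
        _, in_c, kh, kw = w_shape      # w (OIHW): [out_channel, in_channel, kh, kw]
        # Invert the forward conv shape formula to recover the input extent.
        in_h = (out_h - 1) * stride[0] - 2 * pad[0] + dilation[0] * (kh - 1) + 1 + output_padding[0]
        in_w = (out_w - 1) * stride[1] - 2 * pad[1] + dilation[1] * (kw - 1) + 1 + output_padding[1]
        return [n, in_c, in_h, in_w]

    # For a shape-preserving case such as stride 1, padding 1, 3x3 kernel:
    # in_h == out_h, so `oshape = x_shape; oshape[1] = w.shape[1]` is correct there.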
python/tvm/topi/testing/conv2d_transpose_python.py (4 changes: 2 additions & 2 deletions)
@@ -73,7 +73,7 @@ def _conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding):
             dilated_a_np.shape[2] + bpad_top + bpad_bottom,
             dilated_a_np.shape[3] + bpad_left + bpad_right,
         )
-    )
+    ).astype(a_np.dtype)
     padded_a_np[
         :,
         :,
@@ -83,7 +83,7 @@ def _conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding):
     # convolution stage
     out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h + opad_h
     out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w + opad_w
-    b_np = np.zeros((batch, out_c, out_h, out_w))
+    b_np = np.zeros((batch, out_c, out_h, out_w)).astype(a_np.dtype)
     for n in range(batch):
         for f in range(out_c):
             for c in range(in_c):
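Both `.astype(a_np.dtype)` additions fix a silent dtype promotion: `np.zeros` returns float64 by default, so a float32 input previously produced a float64 padded buffer and reference output. A quick standalone NumPy illustration:

    import numpy as np

    a_np = np.random.uniform(-1, 1, (1, 3, 8, 8)).astype("float32")
    assert np.zeros(a_np.shape).dtype == np.float64  # default dtype is float64
    assert np.zeros(a_np.shape).astype(a_np.dtype).dtype == np.float32  # matches the input

This matters for the cuDNN test below, which now compares like dtypes at a 1e-5 tolerance.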
tests/python/contrib/test_cudnn.py (17 changes: 9 additions & 8 deletions)
@@ -248,47 +248,48 @@ def verify_conv2d_backward_data(data_dtype, conv_dtype, tensor_format=0):
     # schedule
     if tensor_format == 0:
         xshape = [batch, in_channel, height, width]
-        wshape = [in_channel, out_channel, filter_h, filter_w]
+        wshape = [out_channel, in_channel, filter_h, filter_w]
     else:
         xshape = [batch, height, width, in_channel]
         wshape = [out_channel, filter_h, filter_w, in_channel]
 
     oshape = xshape
     oshape[1] = out_channel
 
     dy_np = np.random.uniform(-1, 1, oshape).astype(data_dtype)
     w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype)
 
+    dx_np = tvm.topi.testing.conv2d_transpose_nchw_python(
+        dy_np, w_np, (stride_h, stride_w), (pad_h, pad_w), (0, 0)
+    )
 
     dy = te.placeholder(oshape, name="dy", dtype=data_dtype)
     w = te.placeholder(wshape, name="dw", dtype=data_dtype)
-    dx = cudnn.conv_forward_backward_data(
+    dx = cudnn.conv_backward_data(
         dy,
         w,
         [pad_h, pad_w],
         [stride_h, stride_w],
         [dilation_h, dilation_w],
-        [1, 1],
         conv_mode=1,
         tensor_format=tensor_format,
         conv_dtype=conv_dtype,
         groups=1,
     )
 
-    s = te.create_schedule(Y.op)
+    s = te.create_schedule(dx.op)
 
     # validation
     dev = tvm.cuda(0)
-    f = tvm.build(s, [dy, w, x], "cuda --host=llvm", name="conv2d_backward_data")
-
-    dx_np = np.zeros(oshape).astype(data_dtype)
+    f = tvm.build(s, [dy, w, dx], "cuda --host=llvm", name="conv2d_backward_data")
 
     dy = tvm.nd.array(dy_np, dev)
     w = tvm.nd.array(w_np, dev)
     dx = tvm.nd.array(dx_np, dev)
 
     f(dy, w, dx)
-    tvm.testing.assert_allclose(dx.numpy(), dx_np, atol=1e-2, rtol=1e-2)
+    print(np.max(np.abs(dx.numpy() - dx_np)))
+    tvm.testing.assert_allclose(dx.numpy(), dx_np, atol=1e-5, rtol=1e-5)


@tvm.testing.requires_gpu
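One subtlety in the test body: `oshape = xshape` binds a second name to the same list, so `oshape[1] = out_channel` mutates `xshape` as well. That is harmless here because `xshape` is not read afterwards, but `list(xshape)` would make the intent explicit. A small standalone illustration:

    xshape = [1, 3, 32, 32]
    oshape = xshape        # same list object, not a copy
    oshape[1] = 8
    assert xshape[1] == 8  # xshape was mutated too

    oshape = list(xshape)  # independent copy avoids the aliasing
    oshape[1] = 16
    assert xshape[1] == 8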
