Skip to content

Commit

Permalink
fp16 also works
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Jan 18, 2022
1 parent c2a34d4 commit 211a58b
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion tests/python/contrib/test_cudnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,17 @@ def verify_conv2d_backward_data(data_dtype, conv_dtype, tensor_format=0, tol=1e-
dy_np = np.random.uniform(-1, 1, oshape).astype(data_dtype)
w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype)

dx_np = ref_func(dy_np, w_np, (stride_h, stride_w), (pad_h, pad_w), (0, 0))
if data_dtype == "float16":
dx_np = ref_func(
dy_np.astype("float32"),
w_np.astype("float32"),
(stride_h, stride_w),
(pad_h, pad_w),
(0, 0),
)
dx_np = dx_np.astype("float16")
else:
dx_np = ref_func(dy_np, w_np, (stride_h, stride_w), (pad_h, pad_w), (0, 0))

dy = te.placeholder(oshape, name="dy", dtype=data_dtype)
w = te.placeholder(wshape, name="dw", dtype=data_dtype)
Expand Down Expand Up @@ -292,6 +302,7 @@ def verify_conv2d_backward_data(data_dtype, conv_dtype, tensor_format=0, tol=1e-

f(dy, w, dx)
print(np.max(np.abs(dx.numpy() - dx_np)))
print(np.mean(np.abs(dx.numpy() - dx_np)))
tvm.testing.assert_allclose(dx.numpy(), dx_np, atol=tol, rtol=tol)


Expand All @@ -300,6 +311,9 @@ def verify_conv2d_backward_data(data_dtype, conv_dtype, tensor_format=0, tol=1e-
def test_conv2d_backward_data():
    """Run conv2d backward-data (dgrad) verification across dtype/layout combos."""
    # The scipy convolve function does not support fp16, so the reference will be computed with
    # fp32. Use larger tolerance to be on the safe side (1e-2 also seems mostly ok).
    cases = [
        ("float32", "float32", 0, 1e-5),
        ("float32", "float32", 1, 1e-2),
        ("float16", "float16", 1, 1e-1),
    ]
    for data_dtype, conv_dtype, fmt, tolerance in cases:
        verify_conv2d_backward_data(data_dtype, conv_dtype, tensor_format=fmt, tol=tolerance)


test_kwargs_default_2d = {
Expand Down

0 comments on commit 211a58b

Please sign in to comment.