diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py
index d01684f774710..bc7669d3d95c6 100644
--- a/tests/python/contrib/test_cudnn.py
+++ b/tests/python/contrib/test_cudnn.py
@@ -264,7 +264,17 @@ def verify_conv2d_backward_data(data_dtype, conv_dtype, tensor_format=0, tol=1e-
     dy_np = np.random.uniform(-1, 1, oshape).astype(data_dtype)
     w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype)
 
-    dx_np = ref_func(dy_np, w_np, (stride_h, stride_w), (pad_h, pad_w), (0, 0))
+    if data_dtype == "float16":
+        dx_np = ref_func(
+            dy_np.astype("float32"),
+            w_np.astype("float32"),
+            (stride_h, stride_w),
+            (pad_h, pad_w),
+            (0, 0),
+        )
+        dx_np = dx_np.astype("float16")
+    else:
+        dx_np = ref_func(dy_np, w_np, (stride_h, stride_w), (pad_h, pad_w), (0, 0))
 
     dy = te.placeholder(oshape, name="dy", dtype=data_dtype)
     w = te.placeholder(wshape, name="dw", dtype=data_dtype)
@@ -292,6 +302,7 @@ def verify_conv2d_backward_data(data_dtype, conv_dtype, tensor_format=0, tol=1e-
     f(dy, w, dx)
 
     print(np.max(np.abs(dx.numpy() - dx_np)))
+    print(np.mean(np.abs(dx.numpy() - dx_np)))
     tvm.testing.assert_allclose(dx.numpy(), dx_np, atol=tol, rtol=tol)
 
 
@@ -300,6 +311,9 @@ def verify_conv2d_backward_data(data_dtype, conv_dtype, tensor_format=0, tol=1e-
 def test_conv2d_backward_data():
     verify_conv2d_backward_data("float32", "float32", tensor_format=0, tol=1e-5)
     verify_conv2d_backward_data("float32", "float32", tensor_format=1, tol=1e-2)
+    # The scipy convolve function does not support fp16, so the reference will be computed with
+    # fp32. Use larger tolerance to be on the safe side (1e-2 also seems mostly ok).
+    verify_conv2d_backward_data("float16", "float16", tensor_format=1, tol=1e-1)
 
 
 test_kwargs_default_2d = {
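
For context (outside the patch itself), the fp16 branch added above is an instance of a common testing pattern: when a reference implementation only supports fp32, upcast the fp16 inputs, run the reference, and downcast the result before comparing at a loosened tolerance. A minimal standalone sketch of that pattern follows; the helper name reference_with_fp16_fallback is hypothetical and ref_func stands in for any NumPy/SciPy-based reference, not for the specific function used in the test file.

    # Hypothetical standalone sketch, not part of the patch above.
    import numpy as np

    def reference_with_fp16_fallback(ref_func, *arrays, **kwargs):
        """Run ref_func in fp32 when any input is fp16, then cast back."""
        has_fp16 = any(a.dtype == np.float16 for a in arrays)
        if has_fp16:
            # Upcast: the reference does not accept float16 inputs.
            arrays = [a.astype("float32") for a in arrays]
        out = ref_func(*arrays, **kwargs)
        # Downcast so the comparison happens in the dtype under test.
        return out.astype("float16") if has_fp16 else out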