From 4c6e8489f77438e3452fb464e64623f8be5188c8 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Fri, 10 Nov 2023 16:47:24 -0800 Subject: [PATCH] Add some more validation checks for torch.linalg.eigh and torch.compile --- check_binary.sh | 6 ++++++ test_example_code/torch_compile_smoke.py | 12 ++++++++++++ 2 files changed, 18 insertions(+) create mode 100644 test_example_code/torch_compile_smoke.py diff --git a/check_binary.sh b/check_binary.sh index 30b44b535..42ee0e997 100755 --- a/check_binary.sh +++ b/check_binary.sh @@ -404,6 +404,12 @@ if [[ "$DESIRED_CUDA" != 'cpu' && "$DESIRED_CUDA" != 'cpu-cxx11-abi' && "$DESIRE echo "Test that linalg works" python -c "import torch;x=torch.rand(3,3,device='cuda');print(torch.linalg.svd(torch.mm(x.t(), x)))" + echo "Test that linalg.eigh works" + python -c "import torch;x=torch.rand(3,3,device='cuda');print(torch.linalg.eigh(x))" + + echo "Checking that basic torch.compile works" + python ${TEST_CODE_DIR}/torch_compile_smoke.py + popd fi # if libtorch fi # if cuda diff --git a/test_example_code/torch_compile_smoke.py b/test_example_code/torch_compile_smoke.py new file mode 100644 index 000000000..7a12a013e --- /dev/null +++ b/test_example_code/torch_compile_smoke.py @@ -0,0 +1,12 @@ +import torch + + +def foo(x: torch.Tensor) -> torch.Tensor: + return torch.sin(x) + torch.cos(x) + + +if __name__ == "__main__": + x = torch.rand(3, 3, device="cuda") + x_eager = foo(x) + x_pt2 = torch.compile(foo)(x) + print(torch.allclose(x_eager, x_pt2))