Skip to content

Commit

Permalink
Add some more validation checks for torch.linalg.eigh and torch.compile
Browse files Browse the repository at this point in the history
  • Loading branch information
huydhn committed Nov 11, 2023
1 parent ca0040f commit 4c6e848
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
6 changes: 6 additions & 0 deletions check_binary.sh
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,12 @@ if [[ "$DESIRED_CUDA" != 'cpu' && "$DESIRED_CUDA" != 'cpu-cxx11-abi' && "$DESIRE
echo "Test that linalg works"
python -c "import torch;x=torch.rand(3,3,device='cuda');print(torch.linalg.svd(torch.mm(x.t(), x)))"

echo "Test that linalg.eigh works"
python -c "import torch;x=torch.rand(3,3,device='cuda');print(torch.linalg.eigh(x))"

echo "Checking that basic torch.compile works"
python ${TEST_CODE_DIR}/torch_compile_smoke.py

popd
fi # if libtorch
fi # if cuda
Expand Down
12 changes: 12 additions & 0 deletions test_example_code/torch_compile_smoke.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import torch


def foo(x: torch.Tensor) -> torch.Tensor:
return torch.sin(x) + torch.cos(x)


if __name__ == "__main__":
x = torch.rand(3, 3, device="cuda")
x_eager = foo(x)
x_pt2 = torch.compile(foo)(x)
print(torch.allclose(x_eager, x_pt2))

0 comments on commit 4c6e848

Please sign in to comment.