diff --git a/test/smoke_test/smoke_test.py b/test/smoke_test/smoke_test.py index 5b1681c48..3eda797a3 100644 --- a/test/smoke_test/smoke_test.py +++ b/test/smoke_test/smoke_test.py @@ -177,6 +177,12 @@ def smoke_test_cuda(package: str, runtime_error_check: str, torch_compile_check: print(f"torch cudnn: {torch.backends.cudnn.version()}") print(f"cuDNN enabled? {torch.backends.cudnn.enabled}") + torch.cuda.init() + print(f"CUDA initialized successfully") + print(f"Number of CUDA devices: {torch.cuda.device_count()}") + for i in range(torch.cuda.device_count()): + print(f"Device {i}: {torch.cuda.get_device_name(i)}") + # nccl is availbale only on Linux if (sys.platform in ["linux", "linux2"]): print(f"torch nccl version: {torch.cuda.nccl.version()}")