Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Attempt gpu ci fix #125

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
24 changes: 20 additions & 4 deletions .buildkite/gpu_pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ steps:
queue: "juliagpu"
cuda: "*"
if: build.message !~ /\[skip tests\]/
timeout_in_minutes: 60
timeout_in_minutes: 180
commands: |
pwd
env
echo "--- Setup :python: Dependencies"
mkdir -p .local/bin
export PATH="`pwd`/.local/bin:`pwd`/conda/bin:\$PATH"
Expand All @@ -16,15 +18,29 @@ steps:

mv bazel* .local/bin/bazel
chmod +x .local/bin/bazel
export PATH="`pwd`/.local/bin:\$PATH"

mkdir -p .baztmp

echo "--- :python: Test"

HERMETIC_PYTHON_VERSION="3.12" bazel --output_user_root=`pwd`/.baztmp test --test_output=errors //test/...
HERMETIC_PYTHON_VERSION="3.12" bazel --output_user_root=`pwd`/.baztmp test --cache_test_results=no //test:bench_vs_xla
HERMETIC_PYTHON_VERSION="3.12" bazel --output_user_root=`pwd`/.baztmp test --cache_test_results=no //test:llama
export CUDA_DIR=`pwd`/bazel-bin/test/llama.runfiles/pypi_nvidia_cuda_nvcc_cu12/site-packages/nvidia/cuda_nvcc
export XLA_FLAGS=--xla_gpu_cuda_data_dir=\$CUDA_DIR
export LD_LIBRARY_PATH="`pwd`/bazel-bin/test/llama.runfiles/pypi_nvidia_cusolver_cu12/site-packages/nvidia/cusolver:\$LD_LIBRARY_PATH"
export LD_LIBRARY_PATH="`pwd`/bazel-bin/test/llama.runfiles/pypi_nvidia_cudnn_cu12/site-packages/nvidia/cudnn/lib:\$LD_LIBRARY_PATH"
export LD_LIBRARY_PATH="`pwd`/bazel-bin/test/test.runfiles/pypi_nvidia_cublas_cu12/site-packages/nvidia/cublas/lib:\$LD_LIBRARY_PATH"
export LD_LIBRARY_PATH="`pwd`/bazel-bin/test/llama.runfiles/pypi_nvidia_cuda_cupti_cu12/site-packages/nvidia/cuda_cupti/lib:\$LD_LIBRARY_PATH"
export LD_LIBRARY_PATH="`pwd`/bazel-bin/test/llama.runfiles/pypi_nvidia_cuda_runtime_cu12/site-packages/nvidia/cuda_runtime/lib:\$LD_LIBRARY_PATH"
export PATH="`pwd`/bazel-bin/test/llama.runfiles/pypi_nvidia_cuda_nvcc_cu12/site-packages/nvidia/cuda_nvcc/bin:\$PATH"
export TF_CPP_MIN_LOG_LEVEL=0
HERMETIC_PYTHON_VERSION="3.12" .local/bin/bazel --output_user_root=`pwd`/.baztmp run --repo_env CUDA_DIR --repo_env XLA_FLAGS --action_env XLA_FLAGS --repo_env TF_CPP_MIN_LOG_LEVEL --action_env TF_CPP_MIN_LOG_LEVEL //builddeps:requirements.update
HERMETIC_PYTHON_VERSION="3.12" .local/bin/bazel --output_user_root=`pwd`/.baztmp test --repo_env CUDA_DIR --repo_env XLA_FLAGS --action_env XLA_FLAGS --repo_env TF_CPP_MIN_LOG_LEVEL --action_env TF_CPP_MIN_LOG_LEVEL --test_output=errors //test/... || echo "fail1"
find `pwd`/bazel-bin/test/llama.runfiles > finds.txt
HERMETIC_PYTHON_VERSION="3.12" .local/bin/bazel --output_user_root=`pwd`/.baztmp test --repo_env CUDA_DIR --repo_env XLA_FLAGS --action_env XLA_FLAGS --repo_env TF_CPP_MIN_LOG_LEVEL --action_env TF_CPP_MIN_LOG_LEVEL --cache_test_results=no -s //test:bench_vs_xla || echo "fail2"
HERMETIC_PYTHON_VERSION="3.12" .local/bin/bazel --output_user_root=`pwd`/.baztmp test --repo_env CUDA_DIR --repo_env XLA_FLAGS --action_env XLA_FLAGS --repo_env TF_CPP_MIN_LOG_LEVEL --action_env TF_CPP_MIN_LOG_LEVEL --cache_test_results=no -s //test:llama || echo "fail3"
HERMETIC_PYTHON_VERSION="3.12" bazel-bin/test/llama
cat bazel-out/*/testlogs/test/llama/test.log
artifact_paths:
- "finds.txt"
- "bazel-out/*/testlogs/test/llama/test.log"
- "bazel-out/*/testlogs/test/llama/bench_vs_xla.log"
63 changes: 60 additions & 3 deletions builddeps/requirements_lock_3_11.txt
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ idna==3.8 \
--hash=sha256:050b4e5baadcd44d760cedbd2b8e639f2ff89bbc7a5730fcc662954303377aac \
--hash=sha256:d838c2c0ed6fced7693d5e8ab8e734d5f8fda53a039c0164afb0b82e771e3603
# via requests
jax==0.4.31 \
jax[cuda12-pip]==0.4.31 ; sys_platform == "linux" \
--hash=sha256:5688703735133d0dc537e99a1d646198a49c9d472d4715fde4bd437c44151bd7 \
--hash=sha256:fd2d470643a0073d822737f0788f71391656af7e62cc5b2e7995ee390ceac287
# via
Expand All @@ -122,14 +122,14 @@ jax-cuda12-pjrt==0.4.31 \
--hash=sha256:3e77d1cfebeca06517254eb568f082037e1a2aa3ed8f63c543492ad8ab5a1585 \
--hash=sha256:8961abb381d893a3c2392ad76ab2067a81f8f2514f3b47d2da3ac24283293fe0
# via jax-cuda12-plugin
jax-cuda12-plugin==0.4.31 ; sys_platform == "linux" \
jax-cuda12-plugin[with-cuda]==0.4.31 \
--hash=sha256:146f26928ca719a0daa14bb9a9f5a5cbfa20211e76ea05b7bb534277a658a995 \
--hash=sha256:5048acdf29755303b2887d948137fd82891b48cdbc086e67640ebb976457cfcc \
--hash=sha256:5cfa46d4106f70c31944c9fafcfa7b04c6f9886a1b35df5477cde7d9c43eb8bd \
--hash=sha256:7a179e5e80dd9890972d777d597f3c902c6876e42bcf2edcfe4f3ec5a610472e \
--hash=sha256:a3727a332fbeac625ab6d5ae63d0ed9e62d1e0b5011f72130c9bc96e797395d5 \
--hash=sha256:cdb6d0c4009438a6a6bd7997ab8a1194beda9ae7322b8d265eea9e551b6c2b4e
# via -r builddeps/test-requirements.txt
# via jax
jaxlib==0.4.31 \
--hash=sha256:185fb615ab6bd95315fbcbd951d84e71f9835d603db8c03c91faee98ce95ff4d \
--hash=sha256:1b8e9e6970ecc08bd8b4d80c03d882f4dcd4ac119cb2959811ebc58fce1c263d \
Expand Down Expand Up @@ -234,6 +234,63 @@ numpy==2.1.0 \
# ml-dtypes
# opt-einsum
# scipy
nvidia-cublas-cu12==12.6.1.4 \
--hash=sha256:5dd125ece5469dbdceebe2e9536ad8fc4abd38aa394a7ace42fc8a930a1e81e3 \
--hash=sha256:5e5d384583d72ac364064ced3dd92a5caa59a8a57568595c9f82e83d255b2481 \
--hash=sha256:c25ab29026a265d46c1063b5fb3cb9440f5f2eb88041c6b7c6711bcb3361789f
# via
# jax-cuda12-plugin
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
nvidia-cuda-cupti-cu12==12.6.68 \
--hash=sha256:13408a021727de6473d138a0c5e8080b23437f761508e2b11d2530fed24f4ea0 \
--hash=sha256:5ad6a1fcfcb42c8628f7e547079575116d428d0cb3b4fab98362e08a9ea0b842 \
--hash=sha256:7487f59d73a090bf661fa8da84bad649f019a249dbac3a6cc58b039e15c28d91
# via jax-cuda12-plugin
nvidia-cuda-nvcc-cu12==12.6.68 \
--hash=sha256:3999aa4a42ac8723c09a8aafd06bc4a6ec1a0b05c53bc96c8d6cf195e84f6935 \
--hash=sha256:9c0a18d76f0d1de99ba1d5fd70cffb32c0249e4abc42de9c0504e34d90ff421c \
--hash=sha256:d2faca18a3d5dd48865ad259262f7da43358d0940d53554026102d70c14ea2f9
# via jax-cuda12-plugin
nvidia-cuda-runtime-cu12==12.6.68 \
--hash=sha256:3d421aa4ff608b2d8c650e0208a0fb28b4b6792a35b42bd2769d802149f85238 \
--hash=sha256:806b51a1dd266aac41ae09ca6142faee1686d119ced006cb9b76dfd331c75ab8 \
--hash=sha256:846987485889786d257f6d7bdcf7544a36452936514e20dd710527b896c0fe12
# via jax-cuda12-plugin
nvidia-cudnn-cu12==9.3.0.75 \
--hash=sha256:9ad9c6929ebb5295eb4a1728024666d1c88283373e265a0c5c883e6f9d5cd76d \
--hash=sha256:c5cf7ff3415e446adf195a5b7dd2ba56cd00c3ee78bfdc566e51698931aa4b7f \
--hash=sha256:c819e82eed8cf564b9d37478ea4eab9e87194bb3b7f7f8098bc1f67c9b80f1b6
# via jax-cuda12-plugin
nvidia-cufft-cu12==11.2.6.59 \
--hash=sha256:251df5b20b11bb2af6d3964ac01b85a94094222d081c90f27e8df3bf533d3257 \
--hash=sha256:2ea19d2101d309228daeb1045397d8e28eb3ec1ec45f226bdc12ac6e9c1c59d4 \
--hash=sha256:998bbd77799dc427f9c48e5d57a316a7370d231fd96121fb018b370f67fc4909
# via jax-cuda12-plugin
nvidia-cusolver-cu12==11.6.4.69 \
--hash=sha256:07d9a1fc00049cba615ec3475eca5320943df3175b05d358d2559286bb7f1fa6 \
--hash=sha256:1c799e473bbd369a34490322ebf6bbf8862831e199f5d6da6868d5f6f7332fff \
--hash=sha256:ec0419e653587d25f399736eaf1d26a6562d8bcaeb44b1e3daef87e13b669963
# via jax-cuda12-plugin
nvidia-cusparse-cu12==12.5.3.3 \
--hash=sha256:76030755020d3a969b40273f43b8c496c4e122ee2a01fd724cf1398421bcadd8 \
--hash=sha256:bfa07cb86edfd6112dbead189c182a924fd9cb3e48ae117b1ac4cd3084078bc0 \
--hash=sha256:c9d0ff7870672b1e0c7ffc1e47e9b87b51e38ad32ae39e05f08fc68933983a80
# via
# jax-cuda12-plugin
# nvidia-cusolver-cu12
nvidia-nccl-cu12==2.22.3 \
--hash=sha256:f9f5e03c00269dee2cd1aa57019f9a024478a74ae6e9b32d5341c849fe6f6302
# via jax-cuda12-plugin
nvidia-nvjitlink-cu12==12.6.68 \
--hash=sha256:125a6c2a44e96386dda634e13d944e60b07a0402d391a070e8fb4104b34ea1ab \
--hash=sha256:a55744c98d70317c5e23db14866a8cc2b733f7324509e941fc96276f9f37801d \
--hash=sha256:b3fd0779845f68b92063ab1393abab1ed0a23412fc520df79a8190d098b5cd6b
# via
# jax-cuda12-plugin
# nvidia-cufft-cu12
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
opt-einsum==3.3.0 \
--hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \
--hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549
Expand Down
2 changes: 1 addition & 1 deletion builddeps/test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ absl-py
jax
numpy
jaxlib
jax-cuda12-plugin; sys_platform == 'linux'
jax[cuda12_pip]; sys_platform == 'linux'
requests; sys_platform == 'linux'
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
libtpu-nightly == 0.1.dev20240729; sys_platform == 'linux'
4 changes: 2 additions & 2 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
argv = ("-I/usr/include/c++/11", "-I/usr/include/x86_64-linux-gnu/c++/11")

devices = []
CurBackends = [jax.default_backend()]
CurBackends = ["cuda"] #jax.default_backend()]

if jax.default_backend() != "cpu":
if "cuda" != "cpu":
devices = CurBackends

AllBackends = ["cpu"] + devices
Expand Down
Loading