Skip to content

Commit

Permalink
fix gpu build
Browse files Browse the repository at this point in the history
  • Loading branch information
smjleo committed Sep 8, 2024
1 parent 57e5754 commit ff5cf8b
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 15 deletions.
14 changes: 14 additions & 0 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,17 @@ build --define=allow_oversize_protos=true

build -c opt

build:cuda --repo_env TF_NEED_CUDA=1
build:cuda --repo_env TF_NVCC_CLANG=1
build:cuda --repo_env TF_NCCL_USE_STUB=1
build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.4.0"
build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1"
# "sm" means we emit only cubin, which is forward compatible within a GPU generation.
# "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations.
build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90"
build:cuda --crosstool_top="@local_config_cuda//crosstool:toolchain"
build:cuda --@local_config_cuda//:enable_cuda
build:cuda --@xla//xla/python:jax_cuda_pip_rpaths=true
# Default hermetic CUDA and CUDNN versions.
build:cuda --@local_config_cuda//cuda:include_cuda_libs=true
build:cuda --@local_config_cuda//:cuda_compiler=nvcc
2 changes: 1 addition & 1 deletion .buildkite/gpu_pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ steps:
# CARGO_BAZEL_REPIN=true HERMETIC_PYTHON_VERSION="3.12" bazel --output_user_root=`pwd`/.baztmp test --check_visibility=false --test_output=errors //test/...
# CARGO_BAZEL_REPIN=true HERMETIC_PYTHON_VERSION="3.12" bazel --output_user_root=`pwd`/.baztmp test --check_visibility=false --cache_test_results=no //test:bench_vs_xla
CARGO_BAZEL_REPIN=true HERMETIC_PYTHON_VERSION="3.12" bazel --output_user_root=`pwd`/.baztmp test --check_visibility=false --cache_test_results=no //test:llama
CARGO_BAZEL_REPIN=true HERMETIC_PYTHON_VERSION="3.12" bazel --output_user_root=`pwd`/.baztmp test --check_visibility=false --cache_test_results=no --config=cuda //test:llama
cat bazel-out/*/testlogs/test/llama/test.log
artifact_paths:
- "bazel-out/*/testlogs/test/llama/test.log"
Expand Down
2 changes: 1 addition & 1 deletion WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ http_archive(
)

load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256")
XLA_COMMIT = "aeb4b1c7dd12860c3e022df5c4bf4db9d06007ba"
XLA_COMMIT = "7d4f8d1e8a91e67a713ac69796a22f343d292327"
http_archive(
name = "xla",
#sha256 = XLA_SHA256,
Expand Down
21 changes: 8 additions & 13 deletions build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,21 @@ BAZEL_BUILD_FLAGS+=(--define=no_ignite_support=true)
BAZEL_BUILD_FLAGS+=(--define=grpc_no_ares=true)

BAZEL_BUILD_FLAGS+=(--define=llvm_enable_zlib=false)

BAZEL_BUILD_FLAGS+=(--verbose_failures)
BAZEL_BUILD_FLAGS+=(--cxxopt=-std=c++17 --host_cxxopt=-std=c++17)
BAZEL_BUILD_FLAGS+=(--cxxopt=-DTCP_USER_TIMEOUT=0)
BAZEL_BUILD_FLAGS+=(--check_visibility=false)
BAZEL_BUILD_FLAGS+=(--experimental_cc_shared_library)

export TMPDIR=$HOME/.eqsat-tmp
export TMP=$TMPDIR
export TEMP=$TMPDIR
BAZEL_BUILD_FLAGS+=(--action_env=TMP=$TMPDIR --action_env=TEMP=$TMPDIR --action_env=TMPDIR=$TMPDIR --sandbox_tmpfs_path=$TMPDIR)

export CUDA_HOME=/usr/local/cuda
export CUDA_HOME=$HOME/miniconda3/
export PATH=$PATH:$CUDA_HOME/bin
export CUDACXX=$CUDA_HOME/bin/nvcc
BAZEL_BUILD_FLAGS+=(--repo_env TF_NEED_CUDA=1)
BAZEL_BUILD_FLAGS+=(--repo_env TF_CUDA_VERSION=12.3)
BAZEL_BUILD_FLAGS+=(--repo_env TF_CUDA_PATHS="$CUDA_HOME,/usr/lib/x86_64-linux-gnu,/usr/include")
BAZEL_BUILD_FLAGS+=(--repo_env TF_NCCL_USE_STUB=1)
BAZEL_BUILD_FLAGS+=(--action_env TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90")
BAZEL_BUILD_FLAGS+=(--crosstool_top=@local_config_cuda//crosstool:toolchain)
BAZEL_BUILD_FLAGS+=(--@local_config_cuda//:enable_cuda)
BAZEL_BUILD_FLAGS+=(--@xla//xla/python:enable_gpu=true)
BAZEL_BUILD_FLAGS+=(--@xla//xla/python:jax_cuda_pip_rpaths=true)
BAZEL_BUILD_FLAGS+=(--define=xla_python_enable_gpu=true)
bazel build ${BAZEL_BUILD_FLAGS[@]} :enzyme_ad
pip install bazel-bin/enzyme_ad-0.0.6-py3-none-manylinux2014_x86_64.whl --force-reinstall --no-deps
BAZEL_BUILD_FLAGS+=(--config=cuda)
HERMETIC_PYTHON_VERSION=3.12 bazel build ${BAZEL_BUILD_FLAGS[@]} :wheel

0 comments on commit ff5cf8b

Please sign in to comment.