diff --git a/.bazelrc b/.bazelrc index f9f5cfc1b8f6..0dd18d5071d4 100644 --- a/.bazelrc +++ b/.bazelrc @@ -101,10 +101,18 @@ build:cuda_nvcc --@local_config_cuda//:cuda_compiler=nvcc build:cuda_nvcc --action_env=TF_NVCC_CLANG="1" build:cuda_nvcc --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" -build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain -build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true -build:rocm --repo_env TF_NEED_ROCM=1 -build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx1030" +build:rocm_base --crosstool_top=@local_config_rocm//crosstool:toolchain +build:rocm_base --define=using_rocm=true --define=using_rocm_hipcc=true +build:rocm_base --repo_env TF_NEED_ROCM=1 +build:rocm_base --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100" + +# Build with hipcc for ROCm and clang for the host. +build:rocm --config=rocm_base +build:rocm --action_env=TF_ROCM_CLANG="1" +build:rocm --action_env=CLANG_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" +build:rocm --copt=-Wno-gnu-offsetof-extensions +build:rocm --copt=-Qunused-arguments +build:rocm --action_env=TF_HIPCC_CLANG="1" build:nonccl --define=no_nccl_support=true diff --git a/build/build.py b/build/build.py index de0d5a9817fb..402e7a96f9a4 100755 --- a/build/build.py +++ b/build/build.py @@ -301,9 +301,12 @@ def write_bazelrc(*, remote_build, f.write( f'build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="{cuda_compute_capabilities}"\n') if enable_rocm: - f.write("build --config=rocm\n") + f.write("build --config=rocm_base\n") if not enable_nccl: f.write("build --config=nonccl\n") + if use_clang: + f.write("build --config=rocm\n") + f.write(f"build --action_env=CLANG_COMPILER_PATH={clang_path}\n") if python_version: f.write( "build --repo_env HERMETIC_PYTHON_VERSION=\"{python_version}\"".format( @@ -495,7 +498,7 @@ def main(): help="A comma-separated list of CUDA compute capabilities to support.") parser.add_argument( "--rocm_amdgpu_targets", - default="gfx900,gfx906,gfx908,gfx90a,gfx1030", + default="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100", help="A comma-separated list of ROCm amdgpu targets to support.") parser.add_argument( "--rocm_path", diff --git a/build/rocm/Dockerfile.ms b/build/rocm/Dockerfile.ms index d1867e6a5c1a..e20291cefd63 100644 --- a/build/rocm/Dockerfile.ms +++ b/build/rocm/Dockerfile.ms @@ -5,8 +5,14 @@ FROM ubuntu:20.04 AS rocm_base RUN --mount=type=cache,target=/var/cache/apt \ apt-get update && apt-get install -y python3 python-is-python3 +# Install bzip2 and sqlite3 packages +RUN apt-get update && apt-get install -y \ + sqlite3 libsqlite3-dev \ + libbz2-dev \ + && rm -rf /var/lib/apt/lists/* + # Add target file to help determine which device(s) to build for -ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100" +ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100" ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS} # Install ROCm @@ -70,6 +76,7 @@ FROM rocm_base AS rt_build ARG JAX_VERSION ARG JAX_COMMIT ARG XLA_COMMIT +ARG JAX_USE_CLANG LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \ com.amdgpu.python_version="$PYTHON_VERSION" \ @@ -77,7 +84,15 @@ LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \ com.amdgpu.jax_commit="$JAX_COMMIT" \ com.amdgpu.xla_commit="$XLA_COMMIT" + +# Create a directory to copy and retain the wheels in the image. +RUN mkdir -p /rocm_jax_wheels + RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \ --mount=type=bind,source=wheelhouse,target=/wheelhouse \ - pip install --find-links /wheelhouse jax jaxlib jax_rocm60_plugin jax_rocm60_pjrt + cp /wheelhouse/* /rocm_jax_wheels/ && \ + ls -lah /wheelhouse && \ + pip3 install wheelhouse/*none*.whl wheelhouse/*jaxlib*.whl && \ + pip3 install wheelhouse/*rocm60*.whl + diff --git a/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm index caf303d45ff3..a67a7ecb2e22 100644 --- a/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm +++ b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm @@ -7,3 +7,13 @@ ARG ROCM_BUILD_NUM RUN --mount=type=cache,target=/var/cache/dnf \ --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \ python3 get_rocm.py --rocm-version=$ROCM_VERSION --job-name=$ROCM_BUILD_JOB --build-num=$ROCM_BUILD_NUM + +ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100" +RUN printf '%s\n' > /opt/rocm/bin/target.lst ${GPU_DEVICE_TARGETS} + +# Install LLVM 18 and dependencies. +RUN --mount=type=cache,target=/var/cache/dnf \ + dnf install -y wget && dnf clean all +RUN mkdir /tmp/llvm-project && wget -qO - https://github.com/llvm/llvm-project/archive/refs/tags/llvmorg-18.1.8.tar.gz | tar -xz -C /tmp/llvm-project --strip-components 1 && \ + mkdir /tmp/llvm-project/build && cd /tmp/llvm-project/build && cmake -DLLVM_ENABLE_PROJECTS='clang;lld' -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=/usr/lib/llvm-18/ ../llvm && \ + make -j$(nproc) && make -j$(nproc) install && rm -rf /tmp/llvm-project diff --git a/build/rocm/ci_build b/build/rocm/ci_build index 9fb0ebd77f87..e3a6609095ed 100755 --- a/build/rocm/ci_build +++ b/build/rocm/ci_build @@ -130,7 +130,7 @@ def _fetch_jax_metadata(xla_path): cmd = ["python3", "setup.py", "-V"] env = dict(os.environ) - env["JAX_RELEASE"] = "1" + #env["JAX_RELEASE"] = "1" jax_version = subprocess.check_output(cmd, env=env) diff --git a/build/rocm/docker/Dockerfile.jax-ubu22 b/build/rocm/docker/Dockerfile.jax-ubu22 index ba64efbbc682..eb971482f708 100644 --- a/build/rocm/docker/Dockerfile.jax-ubu22 +++ b/build/rocm/docker/Dockerfile.jax-ubu22 @@ -3,8 +3,14 @@ FROM ubuntu:22.04 RUN --mount=type=cache,target=/var/cache/apt \ apt-get update && apt-get install -y python3 python-is-python3 +# Install bzip2 and sqlite3 packages +RUN apt-get update && apt-get install -y \ + sqlite3 libsqlite3-dev \ + libbz2-dev \ + && rm -rf /var/lib/apt/lists/* + # Add target file to help determine which device(s) to build for -ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100" +ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100" ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS} # Install ROCM @@ -61,4 +67,6 @@ LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \ RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \ --mount=type=bind,source=wheelhouse,target=/wheelhouse \ - pip3 install --find-links /wheelhouse jax jaxlib jax_rocm60_plugin jax_rocm60_pjrt + ls -lah /wheelhouse && \ + pip3 install wheelhouse/*none*.whl wheelhouse/*jaxlib*.whl && \ + pip3 install wheelhouse/*rocm60*.whl diff --git a/build/rocm/docker/Dockerfile.jax-ubu24 b/build/rocm/docker/Dockerfile.jax-ubu24 index 44c59b1b7e6b..da714542e5f8 100644 --- a/build/rocm/docker/Dockerfile.jax-ubu24 +++ b/build/rocm/docker/Dockerfile.jax-ubu24 @@ -3,6 +3,12 @@ FROM ubuntu:24.04 RUN --mount=type=cache,target=/var/cache/apt \ apt-get update && apt-get install -y python3 python-is-python3 python3-pip +# Install bzip2 and sqlite3 packages +RUN apt-get update && apt-get install -y \ + sqlite3 libsqlite3-dev \ + libbz2-dev \ + && rm -rf /var/lib/apt/lists/* + # Add target file to help determine which device(s) to build for ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100" ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS} @@ -60,4 +66,6 @@ LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \ RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \ --mount=type=bind,source=wheelhouse,target=/wheelhouse \ - pip3 install --break-system-packages --find-links /wheelhouse jax jaxlib jax_rocm60_plugin jax_rocm60_pjrt + ls -lah /wheelhouse && \ + pip3 install wheelhouse/*none*.whl wheelhouse/*jaxlib*.whl && \ + pip3 install wheelhouse/*rocm60*.whl diff --git a/build/rocm/tools/build_wheels.py b/build/rocm/tools/build_wheels.py index f0631f099a35..deb6ab703391 100644 --- a/build/rocm/tools/build_wheels.py +++ b/build/rocm/tools/build_wheels.py @@ -56,6 +56,36 @@ def update_rocm_targets(rocm_path, targets): open(version_fp, "a").close() +def find_clang_path(): + llvm_base_path = "/usr/lib/" + # Search for llvm directories and pick the highest version. + llvm_dirs = [d for d in os.listdir(llvm_base_path) if d.startswith("llvm-")] + if llvm_dirs: + # Sort to get the highest llvm version. + llvm_dirs.sort(reverse=True) + clang_bin_dir = os.path.join(llvm_base_path, llvm_dirs[0], "bin") + + # Prefer versioned clang binaries (e.g., clang-18). + versioned_clang = None + generic_clang = None + + for f in os.listdir(clang_bin_dir): + # Checks for versioned clang binaries. + if f.startswith("clang-") and f[6:].isdigit(): + versioned_clang = os.path.join(clang_bin_dir, f) + # Fallback to non-versioned clang. + elif f == "clang": + generic_clang = os.path.join(clang_bin_dir, f) + + # Return versioned clang if available, otherwise return generic clang. + if versioned_clang: + return versioned_clang + elif generic_clang: + return generic_clang + + return None + + def build_jaxlib_wheel( jax_path, rocm_path, python_version, xla_path=None, compiler="gcc" ): @@ -70,6 +100,14 @@ def build_jaxlib_wheel( "--use_clang=%s" % use_clang, ] + # Add clang path if clang is used. + if compiler == "clang": + clang_path = find_clang_path() + if clang_path: + cmd.append("--clang_path=%s" % clang_path) + else: + raise RuntimeError("Clang binary not found in /usr/lib/llvm-*") + if xla_path: cmd.append("--bazel_options=--override_repository=xla=%s" % xla_path) @@ -168,18 +206,26 @@ def to_cpy_ver(python_version): def fix_wheel(path, jax_path): - # NOTE(mrodden): fixwheel needs auditwheel 6.0.0, which has a min python of 3.8 - # so use one of the CPythons in /opt to run - env = dict(os.environ) - py_bin = "/opt/python/cp310-cp310/bin" - env["PATH"] = "%s:%s" % (py_bin, env["PATH"]) - - cmd = ["pip", "install", "auditwheel>=6"] - subprocess.run(cmd, check=True, env=env) - - fixwheel_path = os.path.join(jax_path, "build/rocm/tools/fixwheel.py") - cmd = ["python", fixwheel_path, path] - subprocess.run(cmd, check=True, env=env) + try: + # NOTE(mrodden): fixwheel needs auditwheel 6.0.0, which has a min python of 3.8 + # so use one of the CPythons in /opt to run + env = dict(os.environ) + py_bin = "/opt/python/cp310-cp310/bin" + env["PATH"] = "%s:%s" % (py_bin, env["PATH"]) + + cmd = ["pip", "install", "auditwheel>=6"] + subprocess.run(cmd, check=True, env=env) + + fixwheel_path = os.path.join(jax_path, "build/rocm/tools/fixwheel.py") + cmd = ["python", fixwheel_path, path] + subprocess.run(cmd, check=True, env=env) + LOG.info("Wheel fix completed successfully.") + except subprocess.CalledProcessError as cpe: + LOG.error(f"Subprocess failed with error: {cpe}") + raise + except Exception as e: + LOG.error(f"An unexpected error occurred: {e}") + raise def parse_args(): diff --git a/jaxlib/BUILD b/jaxlib/BUILD index ab60b3fadd37..23845369000f 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -240,6 +240,7 @@ pybind_extension( "//jaxlib:kernel_nanobind_helpers", "@com_google_absl//absl/status", "@local_config_rocm//rocm:rocm_headers", + "@local_config_rocm//rocm:hip", "@nanobind", "@xla//third_party/python_runtime:headers", "@xla//xla:status",