Skip to content

Commit

Permalink
[ROCm] Bring up clang support for JAX+XLA
Browse files Browse the repository at this point in the history
* Add clang path

* bazelrc env fixes

* Fix wheelhouse installation and preserve wheels

* dockerfile changes

* Add target.lst

* Change target architectures

* Install bzip2 and sqlite packages
  • Loading branch information
Ruturaj4 committed Oct 10, 2024
1 parent ddf8524 commit ec4a8d6
Show file tree
Hide file tree
Showing 9 changed files with 123 additions and 24 deletions.
16 changes: 12 additions & 4 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,18 @@ build:cuda_nvcc --@local_config_cuda//:cuda_compiler=nvcc
build:cuda_nvcc --action_env=TF_NVCC_CLANG="1"
build:cuda_nvcc --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang"

build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
build:rocm --repo_env TF_NEED_ROCM=1
build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx1030"
build:rocm_base --crosstool_top=@local_config_rocm//crosstool:toolchain
build:rocm_base --define=using_rocm=true --define=using_rocm_hipcc=true
build:rocm_base --repo_env TF_NEED_ROCM=1
build:rocm_base --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100"

# Build with hipcc for ROCm and clang for the host.
build:rocm --config=rocm_base
build:rocm --action_env=TF_ROCM_CLANG="1"
build:rocm --action_env=CLANG_COMPILER_PATH="/usr/lib/llvm-18/bin/clang"
build:rocm --copt=-Wno-gnu-offsetof-extensions
build:rocm --copt=-Qunused-arguments
build:rocm --action_env=TF_HIPCC_CLANG="1"

build:nonccl --define=no_nccl_support=true

Expand Down
7 changes: 5 additions & 2 deletions build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,9 +301,12 @@ def write_bazelrc(*, remote_build,
f.write(
f'build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="{cuda_compute_capabilities}"\n')
if enable_rocm:
f.write("build --config=rocm\n")
f.write("build --config=rocm_base\n")
if not enable_nccl:
f.write("build --config=nonccl\n")
if use_clang:
f.write("build --config=rocm\n")
f.write(f"build --action_env=CLANG_COMPILER_PATH={clang_path}\n")
if python_version:
f.write(
"build --repo_env HERMETIC_PYTHON_VERSION=\"{python_version}\"".format(
Expand Down Expand Up @@ -495,7 +498,7 @@ def main():
help="A comma-separated list of CUDA compute capabilities to support.")
parser.add_argument(
"--rocm_amdgpu_targets",
default="gfx900,gfx906,gfx908,gfx90a,gfx1030",
default="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100",
help="A comma-separated list of ROCm amdgpu targets to support.")
parser.add_argument(
"--rocm_path",
Expand Down
19 changes: 17 additions & 2 deletions build/rocm/Dockerfile.ms
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,14 @@ FROM ubuntu:20.04 AS rocm_base
RUN --mount=type=cache,target=/var/cache/apt \
apt-get update && apt-get install -y python3 python-is-python3

# Install bzip2 and sqlite3 packages
RUN apt-get update && apt-get install -y \
sqlite3 libsqlite3-dev \
libbz2-dev \
&& rm -rf /var/lib/apt/lists/*
# Add target file to help determine which device(s) to build for
ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100"
ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100"
ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS}
# Install ROCm
Expand Down Expand Up @@ -70,14 +76,23 @@ FROM rocm_base AS rt_build
ARG JAX_VERSION
ARG JAX_COMMIT
ARG XLA_COMMIT
ARG JAX_USE_CLANG
LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \
com.amdgpu.python_version="$PYTHON_VERSION" \
com.amdgpu.jax_version="$JAX_VERSION" \
com.amdgpu.jax_commit="$JAX_COMMIT" \
com.amdgpu.xla_commit="$XLA_COMMIT"
# Create a directory to copy and retain the wheels in the image.
RUN mkdir -p /rocm_jax_wheels
RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \
--mount=type=bind,source=wheelhouse,target=/wheelhouse \
pip install --find-links /wheelhouse jax jaxlib jax_rocm60_plugin jax_rocm60_pjrt
cp /wheelhouse/* /rocm_jax_wheels/ && \
ls -lah /wheelhouse && \
pip3 install wheelhouse/*none*.whl wheelhouse/*jaxlib*.whl && \
pip3 install wheelhouse/*rocm60*.whl
10 changes: 10 additions & 0 deletions build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,13 @@ ARG ROCM_BUILD_NUM
RUN --mount=type=cache,target=/var/cache/dnf \
--mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \
python3 get_rocm.py --rocm-version=$ROCM_VERSION --job-name=$ROCM_BUILD_JOB --build-num=$ROCM_BUILD_NUM

ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100"
RUN printf '%s\n' > /opt/rocm/bin/target.lst ${GPU_DEVICE_TARGETS}

# Install LLVM 18 and dependencies.
RUN --mount=type=cache,target=/var/cache/dnf \
dnf install -y wget && dnf clean all
RUN mkdir /tmp/llvm-project && wget -qO - https://github.com/llvm/llvm-project/archive/refs/tags/llvmorg-18.1.8.tar.gz | tar -xz -C /tmp/llvm-project --strip-components 1 && \
mkdir /tmp/llvm-project/build && cd /tmp/llvm-project/build && cmake -DLLVM_ENABLE_PROJECTS='clang;lld' -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=/usr/lib/llvm-18/ ../llvm && \
make -j$(nproc) && make -j$(nproc) install && rm -rf /tmp/llvm-project
2 changes: 1 addition & 1 deletion build/rocm/ci_build
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def _fetch_jax_metadata(xla_path):

cmd = ["python3", "setup.py", "-V"]
env = dict(os.environ)
env["JAX_RELEASE"] = "1"
#env["JAX_RELEASE"] = "1"

jax_version = subprocess.check_output(cmd, env=env)

Expand Down
12 changes: 10 additions & 2 deletions build/rocm/docker/Dockerfile.jax-ubu22
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,14 @@ FROM ubuntu:22.04
RUN --mount=type=cache,target=/var/cache/apt \
apt-get update && apt-get install -y python3 python-is-python3

# Install bzip2 and sqlite3 packages
RUN apt-get update && apt-get install -y \
sqlite3 libsqlite3-dev \
libbz2-dev \
&& rm -rf /var/lib/apt/lists/*

# Add target file to help determine which device(s) to build for
ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100"
ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100"
ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS}

# Install ROCM
Expand Down Expand Up @@ -61,4 +67,6 @@ LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \

RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \
--mount=type=bind,source=wheelhouse,target=/wheelhouse \
pip3 install --find-links /wheelhouse jax jaxlib jax_rocm60_plugin jax_rocm60_pjrt
ls -lah /wheelhouse && \
pip3 install wheelhouse/*none*.whl wheelhouse/*jaxlib*.whl && \
pip3 install wheelhouse/*rocm60*.whl
10 changes: 9 additions & 1 deletion build/rocm/docker/Dockerfile.jax-ubu24
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@ FROM ubuntu:24.04
RUN --mount=type=cache,target=/var/cache/apt \
apt-get update && apt-get install -y python3 python-is-python3 python3-pip

# Install bzip2 and sqlite3 packages
RUN apt-get update && apt-get install -y \
sqlite3 libsqlite3-dev \
libbz2-dev \
&& rm -rf /var/lib/apt/lists/*

# Add target file to help determine which device(s) to build for
ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100"
ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS}
Expand Down Expand Up @@ -60,4 +66,6 @@ LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \

RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \
--mount=type=bind,source=wheelhouse,target=/wheelhouse \
pip3 install --break-system-packages --find-links /wheelhouse jax jaxlib jax_rocm60_plugin jax_rocm60_pjrt
ls -lah /wheelhouse && \
pip3 install wheelhouse/*none*.whl wheelhouse/*jaxlib*.whl && \
pip3 install wheelhouse/*rocm60*.whl
70 changes: 58 additions & 12 deletions build/rocm/tools/build_wheels.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,36 @@ def update_rocm_targets(rocm_path, targets):
open(version_fp, "a").close()


def find_clang_path(llvm_base_path="/usr/lib/"):
    """Locate a clang binary inside the newest LLVM install directory.

    Directories named ``llvm-<version>`` directly under ``llvm_base_path``
    are tried from the highest version to the lowest. Within the ``bin``
    directory of each one, a versioned binary (``clang-<N>``) is preferred
    over the plain ``clang`` binary.

    Args:
      llvm_base_path: Directory that contains the ``llvm-*`` directories.
        Defaults to ``/usr/lib/``.

    Returns:
      The path to the chosen clang binary, or ``None`` when no ``llvm-*``
      directory under ``llvm_base_path`` contains one.
    """

    def _version_key(dirname):
        # Compare versions numerically so that "llvm-18" outranks "llvm-9";
        # a plain string sort would order them the other way round.
        parts = dirname[len("llvm-"):].split(".")
        if all(p.isdecimal() for p in parts):
            return (1, tuple(int(p) for p in parts))
        # Names without a numeric suffix sort below every real version.
        return (0, ())

    def _listdir(path):
        # A missing or unreadable directory simply offers no candidates.
        try:
            return os.listdir(path)
        except OSError:
            return []

    llvm_dirs = sorted(
        (d for d in _listdir(llvm_base_path) if d.startswith("llvm-")),
        key=lambda d: (_version_key(d), d),
        reverse=True,
    )

    for llvm_dir in llvm_dirs:
        # Library-only installs have no bin/ directory; fall through to the
        # next lower version instead of failing.
        clang_bin_dir = os.path.join(llvm_base_path, llvm_dir, "bin")
        names = _listdir(clang_bin_dir)

        # Prefer versioned clang binaries (e.g., clang-18); when several
        # exist, pick the highest one so the result is deterministic.
        versioned = [
            f for f in names if f.startswith("clang-") and f[6:].isdecimal()
        ]
        if versioned:
            newest = max(versioned, key=lambda f: int(f[6:]))
            return os.path.join(clang_bin_dir, newest)

        # Fallback to non-versioned clang.
        if "clang" in names:
            return os.path.join(clang_bin_dir, "clang")

    return None


def build_jaxlib_wheel(
jax_path, rocm_path, python_version, xla_path=None, compiler="gcc"
):
Expand All @@ -70,6 +100,14 @@ def build_jaxlib_wheel(
"--use_clang=%s" % use_clang,
]

# Add clang path if clang is used.
if compiler == "clang":
clang_path = find_clang_path()
if clang_path:
cmd.append("--clang_path=%s" % clang_path)
else:
raise RuntimeError("Clang binary not found in /usr/lib/llvm-*")

if xla_path:
cmd.append("--bazel_options=--override_repository=xla=%s" % xla_path)

Expand Down Expand Up @@ -168,18 +206,26 @@ def to_cpy_ver(python_version):


def fix_wheel(path, jax_path):
    """Run JAX's ``fixwheel.py`` tool on the wheel at ``path``.

    Installs ``auditwheel>=6`` with the CPython 3.10 found under
    ``/opt/python`` and then runs ``build/rocm/tools/fixwheel.py`` from the
    JAX checkout against the wheel. Both steps run exactly once.

    Args:
      path: Path of the wheel file to fix.
      jax_path: Root of the JAX source tree that contains fixwheel.py.

    Raises:
      subprocess.CalledProcessError: If either subprocess exits non-zero.
      Exception: Any other failure is logged and re-raised unchanged.
    """
    try:
        # NOTE(mrodden): fixwheel needs auditwheel 6.0.0, which has a min python of 3.8
        # so use one of the CPythons in /opt to run
        env = dict(os.environ)
        py_bin = "/opt/python/cp310-cp310/bin"
        # Put that interpreter's bin directory first; tolerate an unset PATH.
        inherited_path = env.get("PATH")
        env["PATH"] = "%s:%s" % (py_bin, inherited_path) if inherited_path else py_bin

        cmd = ["pip", "install", "auditwheel>=6"]
        subprocess.run(cmd, check=True, env=env)

        fixwheel_path = os.path.join(jax_path, "build/rocm/tools/fixwheel.py")
        cmd = ["python", fixwheel_path, path]
        subprocess.run(cmd, check=True, env=env)
        LOG.info("Wheel fix completed successfully.")
    except subprocess.CalledProcessError as cpe:
        LOG.error("Subprocess failed with error: %s", cpe)
        raise
    except Exception as e:
        # Log for the build output, then let the caller decide what to do.
        LOG.error("An unexpected error occurred: %s", e)
        raise


def parse_args():
Expand Down
1 change: 1 addition & 0 deletions jaxlib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ pybind_extension(
"//jaxlib:kernel_nanobind_helpers",
"@com_google_absl//absl/status",
"@local_config_rocm//rocm:rocm_headers",
"@local_config_rocm//rocm:hip",
"@nanobind",
"@xla//third_party/python_runtime:headers",
"@xla//xla:status",
Expand Down

0 comments on commit ec4a8d6

Please sign in to comment.