Skip to content

Commit

Permalink
python3Packages.mmcv: fix build with CUDA support (NixOS#346358)
Browse files Browse the repository at this point in the history
  • Loading branch information
ConnorBaker authored Oct 4, 2024
2 parents 059b6fb + a639b84 commit 99db8c9
Showing 1 changed file with 20 additions and 30 deletions.
50 changes: 20 additions & 30 deletions pkgs/development/python-modules/mmcv/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -21,35 +21,12 @@
tifffile,
lmdb,
mmengine,
symlinkJoin,
}:

let
inherit (torch) cudaCapabilities cudaPackages cudaSupport;
inherit (cudaPackages) backendStdenv cudaVersion;
inherit (cudaPackages) backendStdenv;

cuda-common-redist = with cudaPackages; [
cuda_cccl # <thrust/*>
libcublas # cublas_v2.h
libcusolver # cusolverDn.h
libcusparse # cusparse.h
];

cuda-native-redist = symlinkJoin {
name = "cuda-native-redist-${cudaVersion}";
paths =
with cudaPackages;
[
cuda_cudart # cuda_runtime.h
cuda_nvcc
]
++ cuda-common-redist;
};

cuda-redist = symlinkJoin {
name = "cuda-redist-${cudaVersion}";
paths = cuda-common-redist;
};
in
buildPythonPackage rec {
pname = "mmcv";
Expand All @@ -65,6 +42,8 @@ buildPythonPackage rec {
hash = "sha256-NNF9sLJWV1q6uBE73LUW4UWwYm4TBMTBJjJkFArBmsc=";
};

env.CUDA_HOME = lib.optionalString cudaSupport (lib.getDev cudaPackages.cuda_nvcc);

preConfigure =
''
export MMCV_WITH_OPS=1
Expand All @@ -77,7 +56,7 @@ buildPythonPackage rec {
'';

postPatch = ''
substituteInPlace setup.py --replace "cpu_use = 4" "cpu_use = $NIX_BUILD_CORES"
substituteInPlace setup.py --replace-fail "cpu_use = 4" "cpu_use = $NIX_BUILD_CORES"
'';

preCheck = ''
Expand All @@ -102,12 +81,23 @@ buildPythonPackage rec {
nativeBuildInputs = [
ninja
which
] ++ lib.optionals cudaSupport [ cuda-native-redist ];
];

buildInputs = [
pybind11
torch
] ++ lib.optionals cudaSupport [ cuda-redist ];
buildInputs =
[
pybind11
torch
]
++ lib.optionals cudaSupport (
with cudaPackages;
[
cuda_cudart # cuda_runtime.h
cuda_cccl # <thrust/*>
libcublas # cublas_v2.h
libcusolver # cusolverDn.h
libcusparse # cusparse.h
]
);

nativeCheckInputs = [
pytestCheckHook
Expand Down

0 comments on commit 99db8c9

Please sign in to comment.