Skip to content

Commit

Permalink
update vendored CUDA includes to match CuPy >= 13.3 (#781)
Browse files Browse the repository at this point in the history
CuPy 13.3 removed dependence on jitify (see: cupy/cupy#8473)

Make the same change in our vendored code (but only if CuPy >= 13.3).

Authors:
  - Gregory Lee (https://github.com/grlee77)

Approvers:
  - https://github.com/jakirkham

URL: #781
  • Loading branch information
grlee77 authored Sep 26, 2024
1 parent cc42019 commit 6e3fd31
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
22 changes: 20 additions & 2 deletions python/cucim/src/cucim/skimage/_vendored/_ndimage_filters_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@

import cupy
import numpy
from packaging.version import parse

from cucim.skimage._vendored import (
_internal as internal,
_ndimage_util as _util,
)

CUPY_GTE_13_3_0 = parse(cupy.__version__) >= parse("13.3.0")


def _origins_to_offsets(origins, w_shape):
return tuple(x // 2 + o for x, o in zip(w_shape, origins))
Expand Down Expand Up @@ -182,13 +185,24 @@ def _call_kernel(
return output


_ndimage_includes = r"""
if CUPY_GTE_13_3_0:
_includes = r"""
#include <cupy/cuda_workaround.h> // provide std:: coverage
"""
else:
_includes = r"""
#include <type_traits> // let Jitify handle this
"""

_ndimage_includes = (
_includes
+ r"""
#include <cupy/math_constants.h>
template<> struct std::is_floating_point<float16> : std::true_type {};
template<> struct std::is_signed<float16> : std::true_type {};
"""
)


_ndimage_CAST_FUNCTION = """
Expand Down Expand Up @@ -352,7 +366,11 @@ def _generate_nd_kernel(
if has_mask:
name += "_with_mask"
preamble = _ndimage_includes + _ndimage_CAST_FUNCTION + preamble
options += ("--std=c++11", "-DCUPY_USE_JITIFY")

if CUPY_GTE_13_3_0:
options += ("--std=c++11",)
else:
options += ("--std=c++11", "-DCUPY_USE_JITIFY")
return cupy.ElementwiseKernel(
in_params,
out_params,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from cucim.skimage._vendored import _ndimage_util as util
from cucim.skimage._vendored._internal import _normalize_axis_index
from cucim.skimage._vendored._ndimage_filters_core import (
CUPY_GTE_13_3_0,
_ndimage_CAST_FUNCTION,
_ndimage_includes,
)
Expand Down Expand Up @@ -917,7 +918,10 @@ def _get_separable_conv_kernel(
patch_per_block=patch_per_block,
flip_kernel=flip_kernel,
)
options = ("--std=c++11", "-DCUPY_USE_JITIFY")
if CUPY_GTE_13_3_0:
options = ("--std=c++11",)
else:
options = ("--std=c++11", "-DCUPY_USE_JITIFY")
m = cp.RawModule(code=code, options=options)
return m.get_function(func_name), block, patch_per_block

Expand Down

0 comments on commit 6e3fd31

Please sign in to comment.