diff --git a/3rdparty/LICENSE.pba+ b/3rdparty/LICENSE.pba+ new file mode 100644 index 000000000..9d0b4030a --- /dev/null +++ b/3rdparty/LICENSE.pba+ @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 School of Computing, National University of Singapore + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/LICENSE-3rdparty.md b/LICENSE-3rdparty.md index 0a3789af7..984d0e952 100644 --- a/LICENSE-3rdparty.md +++ b/LICENSE-3rdparty.md @@ -281,3 +281,9 @@ StainTools - https://github.com/Peter554/StainTools/blob/master/LICENSE.txt - Copyright: Peter Byfield - Usage: reference for stain color normalization algorithm + +PBA+ +- License: MIT License + - https://github.com/orzzzjq/Parallel-Banding-Algorithm-plus/blob/master/LICENSE +- Copyright: School of Computing, National University of Singapore +- Usage: PBA+ is used to implement the Euclidean distance transform. diff --git a/python/cucim/src/cucim/core/operations/morphology/__init__.py b/python/cucim/src/cucim/core/operations/morphology/__init__.py new file mode 100644 index 000000000..a9c676edb --- /dev/null +++ b/python/cucim/src/cucim/core/operations/morphology/__init__.py @@ -0,0 +1,3 @@ +from ._distance_transform import distance_transform_edt + +__all__ = ["distance_transform_edt"] diff --git a/python/cucim/src/cucim/core/operations/morphology/_distance_transform.py b/python/cucim/src/cucim/core/operations/morphology/_distance_transform.py new file mode 100644 index 000000000..e781938b2 --- /dev/null +++ b/python/cucim/src/cucim/core/operations/morphology/_distance_transform.py @@ -0,0 +1,182 @@ +import numpy as np + +from ._pba_2d import _pba_2d +from ._pba_3d import _pba_3d + +# TODO: support sampling distances +# support the distances and indices output arguments +# support chamfer, chessboard and l1/manhattan distances too? + + +def distance_transform_edt(image, sampling=None, return_distances=True, + return_indices=False, distances=None, indices=None, + *, block_params=None, float64_distances=False): + """Exact Euclidean distance transform. + + This function calculates the distance transform of the `input`, by + replacing each foreground (non-zero) element, with its shortest distance to + the background (any zero-valued element). + + In addition to the distance transform, the feature transform can be + calculated. In this case the index of the closest background element to + each foreground element is returned in a separate array. + + Parameters + ---------- + image : array_like + Input data to transform. Can be any type but will be converted into + binary: 1 wherever image equates to True, 0 elsewhere. + sampling : float, or sequence of float, optional + Spacing of elements along each dimension. If a sequence, must be of + length equal to the image rank; if a single number, this is used for + all axes. If not specified, a grid spacing of unity is implied. + return_distances : bool, optional + Whether to calculate the distance transform. + return_indices : bool, optional + Whether to calculate the feature transform. + distances : float32 cupy.ndarray, optional + An output array to store the calculated distance transform, instead of + returning it. `return_distances` must be True. It must be the same + shape as `image`. + indices : int32 cupy.ndarray, optional + An output array to store the calculated feature transform, instead of + returning it. `return_indicies` must be True. Its shape must be + `(image.ndim,) + image.shape`. + + Other Parameters + ---------------- + block_params : 3-tuple of int + The m1, m2, m3 algorithm parameters as described in [2]_. If None, + suitable defaults will be chosen. Note: This parameter is specific to + cuCIM and does not exist in SciPy. + float64_distances : bool, optional + If True, use double precision in the distance computation (to match + SciPy behavior). Otherwise, single precision will be used for + efficiency. Note: This parameter is specific to cuCIM and does not + exist in SciPy. + + Returns + ------- + distances : float64 ndarray, optional + The calculated distance transform. Returned only when + `return_distances` is True and `distances` is not supplied. It will + have the same shape as `image`. + indices : int32 ndarray, optional + The calculated feature transform. It has an image-shaped array for each + dimension of the image. See example below. Returned only when + `return_indices` is True and `indices` is not supplied. + + Notes + ----- + The Euclidean distance transform gives values of the Euclidean distance:: + + n + y_i = sqrt(sum (x[i]-b[i])**2) + i + + where b[i] is the background point (value 0) with the smallest Euclidean + distance to input points x[i], and n is the number of dimensions. + + Note that the `indices` output may differ from the one given by + `scipy.ndimage.distance_transform_edt` in the case of input pixels that are + equidistant from multiple background points. + + The parallel banding algorithm implemented here was originally described in + [1]_. The kernels used here correspond to the revised PBA+ implementation + that is described on the author's website [2]_. The source code of the + author's PBA+ implementation is available at [3]_. + + References + ---------- + ..[1] Thanh-Tung Cao, Ke Tang, Anis Mohamed, and Tiow-Seng Tan. 2010. + Parallel Banding Algorithm to compute exact distance transform with the + GPU. In Proceedings of the 2010 ACM SIGGRAPH symposium on Interactive + 3D Graphics and Games (I3D ’10). Association for Computing Machinery, + New York, NY, USA, 83–90. + DOI:https://doi.org/10.1145/1730804.1730818 + .. [2] https://www.comp.nus.edu.sg/~tants/pba.html + .. [3] https://github.com/orzzzjq/Parallel-Banding-Algorithm-plus + + Examples + -------- + >>> import cupy as cp + >>> from cucim.core.operations import morphology + >>> a = cp.array(([0,1,1,1,1], + ... [0,0,1,1,1], + ... [0,1,1,1,1], + ... [0,1,1,1,0], + ... [0,1,1,0,0])) + >>> morphology.distance_transform_edt(a) + array([[ 0. , 1. , 1.4142, 2.2361, 3. ], + [ 0. , 0. , 1. , 2. , 2. ], + [ 0. , 1. , 1.4142, 1.4142, 1. ], + [ 0. , 1. , 1.4142, 1. , 0. ], + [ 0. , 1. , 1. , 0. , 0. ]]) + + With a sampling of 2 units along x, 1 along y: + + >>> morphology.distance_transform_edt(a, sampling=[2,1]) + array([[ 0. , 1. , 2. , 2.8284, 3.6056], + [ 0. , 0. , 1. , 2. , 3. ], + [ 0. , 1. , 2. , 2.2361, 2. ], + [ 0. , 1. , 2. , 1. , 0. ], + [ 0. , 1. , 1. , 0. , 0. ]]) + + Asking for indices as well: + + >>> edt, inds = morphology.distance_transform_edt(a, return_indices=True) + >>> inds + array([[[0, 0, 1, 1, 3], + [1, 1, 1, 1, 3], + [2, 2, 1, 3, 3], + [3, 3, 4, 4, 3], + [4, 4, 4, 4, 4]], + [[0, 0, 1, 1, 4], + [0, 1, 1, 1, 4], + [0, 0, 1, 4, 4], + [0, 0, 3, 3, 4], + [0, 0, 3, 3, 4]]]) + + """ + if distances is not None: + raise NotImplementedError( + "preallocated distances image is not supported" + ) + if indices is not None: + raise NotImplementedError( + "preallocated indices image is not supported" + ) + scalar_sampling = None + if sampling is not None: + sampling = np.unique(np.atleast_1d(sampling)) + if len(sampling) == 1: + scalar_sampling = float(sampling) + sampling = None + else: + raise NotImplementedError( + "non-uniform values in sampling is not currently supported" + ) + + if image.ndim == 3: + pba_func = _pba_3d + elif image.ndim == 2: + pba_func = _pba_2d + else: + raise NotImplementedError( + "Only 2D and 3D distance transforms are supported.") + + vals = pba_func( + image, + sampling=sampling, + return_distances=return_distances, + return_indices=return_indices, + block_params=block_params + ) + + if return_distances and scalar_sampling is not None: + vals = (vals[0] * scalar_sampling,) + vals[1:] + + if len(vals) == 1: + vals = vals[0] + + return vals diff --git a/python/cucim/src/cucim/core/operations/morphology/_pba_2d.py b/python/cucim/src/cucim/core/operations/morphology/_pba_2d.py new file mode 100644 index 000000000..edb10a983 --- /dev/null +++ b/python/cucim/src/cucim/core/operations/morphology/_pba_2d.py @@ -0,0 +1,300 @@ +import math +import os + +import cupy + +pba2d_defines_template = """ + +// MARKER is used to mark blank pixels in the texture. +// Any uncolored pixels will have x = MARKER. +// Input texture should have x = MARKER for all pixels other than sites +#define MARKER {marker} +#define BLOCKSIZE {block_size_2d} +#define pixel_int2_t {pixel_int2_t} // typically short2 (int2 for images with > 32k pixels per side) +#define make_pixel(x, y) {make_pixel_func}(x, y) // typically make_short2 (make_int2 images with > 32k pixels per side + +""" # noqa + + +def _init_marker(int_dtype): + """use a minimum value that is appropriate to the integer dtype""" + if int_dtype == cupy.int16: + # marker = cupy.iinfo(int_dtype).min + marker = -32768 + elif int_dtype == cupy.int32: + # divide by two so we don't have to promote other intermediate int + # variables to 64-bit int + marker = -2147483648 // 2 + else: + raise ValueError( + "expected int_dtype to be either cupy.int16 or cupy.int32" + ) + return marker + + +@cupy.memoize(True) +def get_pba2d_src(block_size_2d=64, marker=-32768, pixel_int2_t='short2'): + make_pixel_func = 'make_' + pixel_int2_t + + pba2d_code = pba2d_defines_template.format( + block_size_2d=block_size_2d, + marker=marker, + pixel_int2_t=pixel_int2_t, + make_pixel_func=make_pixel_func + ) + kernel_directory = os.path.join(os.path.dirname(__file__), 'cuda') + with open(os.path.join(kernel_directory, 'pba_kernels_2d.h'), 'rt') as f: + pba2d_kernels = '\n'.join(f.readlines()) + + pba2d_code += pba2d_kernels + return pba2d_code + + +def _get_block_size(check_warp_size=False): + if check_warp_size: + dev = cupy.cuda.runtime.getDevice() + device_properties = cupy.cuda.runtime.getDeviceProperties(dev) + return int(device_properties['warpSize']) + else: + return 32 + + +def _pack_int2(arr, marker=-32768, int_dtype=cupy.int16): + if arr.ndim != 2: + raise ValueError("only 2d arr suppported") + input_x = cupy.zeros(arr.shape, dtype=int_dtype) + input_y = cupy.zeros(arr.shape, dtype=int_dtype) + # TODO: create custom kernel for setting values in input_x, input_y + cond = arr == 0 + y, x = cupy.where(cond) + input_x[cond] = x + mask = arr != 0 + input_x[mask] = marker # 1 << 32 + input_y[cond] = y + input_y[mask] = marker # 1 << 32 + int2_dtype = cupy.dtype({'names': ['x', 'y'], 'formats': [int_dtype] * 2}) + # in C++ code x is the contiguous axis and corresponds to width + # y is the non-contiguous axis and corresponds to height + # given that, store input_x as the last axis here + return cupy.squeeze( + cupy.stack((input_x, input_y), axis=-1).view(int2_dtype) + ) + + +def _unpack_int2(img, make_copy=False, int_dtype=cupy.int16): + temp = img.view(int_dtype).reshape(img.shape + (2,)) + if make_copy: + temp = temp.copy() + return temp + + +def _determine_padding(shape, padded_size, block_size): + # all kernels assume equal size along both axes, so pad up to equal size if + # shape is not isotropic + orig_sy, orig_sx = shape + if orig_sx != padded_size or orig_sy != padded_size: + padding_width = ( + (0, padded_size - orig_sy), (0, padded_size - orig_sx) + ) + else: + padding_width = None + return padding_width + + +def _pba_2d(arr, sampling=None, return_distances=True, return_indices=False, + block_params=None, check_warp_size=False, *, + float64_distances=False): + + # input_arr: a 2D image + # For each site at (x, y), the pixel at coordinate (x, y) should contain + # the pair (x, y). Pixels that are not sites should contain the pair + # (MARKER, MARKER) + + # Note: could query warp size here, but for now just assume 32 to avoid + # overhead of querying properties + block_size = _get_block_size(check_warp_size) + + if sampling is not None: + raise NotImplementedError("sampling not yet supported") + # if len(sampling) != 2: + # raise ValueError("sampling must be a sequence of two values.") + + padded_size = math.ceil(max(arr.shape) / block_size) * block_size + if block_params is None: + # should be <= size / 64. sy must be a multiple of m1 + m1 = max(1, min(padded_size // block_size, 32)) + + # size must be a multiple of m2 + m2 = max(1, min(padded_size // block_size, 32)) + # m2 must also be a power of two + m2 = 2**math.floor(math.log2(m2)) + if padded_size % m2 != 0: + raise RuntimeError("error in setting default m2") + + # should be <= 64. texture size must be a multiple of m3 + m3 = min(min(m1, m2), 2) + else: + m1, m2, m3 = block_params + + if m1 > padded_size // block_size: + raise ValueError("m1 too large. must be <= arr.shape[0] // 32") + if m2 > padded_size // block_size: + raise ValueError("m2 too large. must be <= arr.shape[1] // 32") + for m in (m1, m2, m3): + if padded_size % m != 0: + raise ValueError( + f"Largest dimension of image ({padded_size}) must be evenly " + f"disivible by each element of block_params: {(m1, m2, m3)}." + ) + + shape_max = max(arr.shape) + if shape_max <= 32768: + int_dtype = cupy.int16 + pixel_int2_type = 'short2' + else: + if shape_max > (1 << 24): + # limit to coordinate range to 2**24 due to use of __mul24 in + # coordinate TOID macro + raise ValueError( + f"maximum axis size of {1 << 24} exceeded, for image with " + f"shape {arr.shape}" + ) + int_dtype = cupy.int32 + pixel_int2_type = 'int2' + + marker = _init_marker(int_dtype) + + orig_sy, orig_sx = arr.shape + padding_width = _determine_padding(arr.shape, padded_size, block_size) + if padding_width is not None: + arr = cupy.pad(arr, padding_width, mode='constant', constant_values=1) + size = arr.shape[0] + + input_arr = _pack_int2(arr, marker=marker, int_dtype=int_dtype) + output = cupy.zeros_like(input_arr) + + int2_dtype = cupy.dtype({'names': ['x', 'y'], 'formats': [int_dtype] * 2}) + margin = cupy.empty((2 * m1 * size,), dtype=int2_dtype) + + # phase 1 of PBA. m1 must divide texture size and be <= 64 + pba2d = cupy.RawModule( + code=get_pba2d_src( + block_size_2d=block_size, + marker=marker, + pixel_int2_t=pixel_int2_type, + ) + ) + kernelFloodDown = pba2d.get_function('kernelFloodDown') + kernelFloodUp = pba2d.get_function('kernelFloodUp') + kernelPropagateInterband = pba2d.get_function('kernelPropagateInterband') + kernelUpdateVertical = pba2d.get_function('kernelUpdateVertical') + kernelProximatePoints = pba2d.get_function('kernelProximatePoints') + kernelCreateForwardPointers = pba2d.get_function( + 'kernelCreateForwardPointers' + ) + kernelMergeBands = pba2d.get_function('kernelMergeBands') + kernelDoubleToSingleList = pba2d.get_function('kernelDoubleToSingleList') + kernelColor = pba2d.get_function('kernelColor') + + block = (block_size, 1, 1) + grid = (math.ceil(size / block[0]), m1, 1) + bandSize1 = size // m1 + # kernelFloodDown modifies input_arr in-place + kernelFloodDown( + grid, + block, + (input_arr, input_arr, size, bandSize1), + ) + # kernelFloodUp modifies input_arr in-place + kernelFloodUp( + grid, + block, + (input_arr, input_arr, size, bandSize1), + ) + # kernelFloodUp fills values into margin + kernelPropagateInterband( + grid, + block, + (input_arr, margin, size, bandSize1), + ) + # kernelUpdateVertical stores output into an intermediate array of + # transposed shape + kernelUpdateVertical( + grid, + block, + (input_arr, margin, output, size, bandSize1), + ) + + # phase 2 + block = (block_size, 1, 1) + grid = (math.ceil(size / block[0]), m2, 1) + bandSize2 = size // m2 + kernelProximatePoints( + grid, + block, + (output, input_arr, size, bandSize2), + ) + kernelCreateForwardPointers( + grid, + block, + (input_arr, input_arr, size, bandSize2), + ) + # Repeatly merging two bands into one + noBand = m2 + while noBand > 1: + grid = (math.ceil(size / block[0]), noBand // 2) + kernelMergeBands( + grid, + block, + (output, input_arr, input_arr, size, size // noBand), + ) + noBand //= 2 + # Replace the forward link with the X coordinate of the seed to remove + # the need of looking at the other texture. We need it for coloring. + grid = (math.ceil(size / block[0]), size) + kernelDoubleToSingleList( + grid, + block, + (output, input_arr, input_arr, size), + ) + + # Phase 3 of PBA + block = (block_size, m3, 1) + grid = (math.ceil(size / block[0]), 1, 1) + kernelColor( + grid, + block, + (input_arr, output, size), + ) + + output = _unpack_int2(output, make_copy=False, int_dtype=int_dtype) + # make sure to crop any padding that was added here! + x = output[:orig_sy, :orig_sx, 0] + y = output[:orig_sy, :orig_sx, 1] + + # raise NotImplementedError("TODO") + vals = () + if return_distances: + # TODO: custom kernel for more efficient distance computation + y0, x0 = cupy.meshgrid( + *( + cupy.arange(s, dtype=cupy.int32) + for s in (orig_sy, orig_sx) + ), + indexing='ij', + sparse=True, + ) + tmp = (x - x0) + dist = tmp * tmp + tmp = (y - y0) + dist += tmp * tmp + if float64_distances: + dist = cupy.sqrt(dist) + else: + dist = dist.astype(cupy.float32) + cupy.sqrt(dist, out=dist) + vals = vals + (dist,) + if return_indices: + indices = cupy.stack((y, x), axis=0) + vals = vals + (indices,) + return vals diff --git a/python/cucim/src/cucim/core/operations/morphology/_pba_3d.py b/python/cucim/src/cucim/core/operations/morphology/_pba_3d.py new file mode 100644 index 000000000..dab484e90 --- /dev/null +++ b/python/cucim/src/cucim/core/operations/morphology/_pba_3d.py @@ -0,0 +1,311 @@ +import functools +import math +import numbers +import os + +import cupy +import numpy as np + +from ._pba_2d import _get_block_size + +try: + # math.lcm was introduced in Python 3.9 + from math import lcm +except ImportError: + + """Fallback implementation of least common multiple (lcm) + + TODO: remove once minimum Python requirement is >= 3.9 + """ + + def _lcm(a, b): + return abs(b * (a // math.gcd(a, b))) + + @functools.lru_cache() + def lcm(*args): + nargs = len(args) + if not all(isinstance(a, numbers.Integral) for a in args): + raise TypeError("all arguments must be integers") + if nargs == 0: + return 1 + res = int(args[0]) + if nargs == 1: + return abs(res) + for i in range(1, nargs): + x = int(args[i]) + res = _lcm(res, x) + return res + + +pba3d_defines_template = """ + +#define MARKER {marker} +#define MAX_INT {max_int} +#define BLOCKSIZE {block_size_3d} + +""" + +# For efficiency, the original PBA+ packs three 10-bit integers and two binary +# flags into a single 32-bit integer. The defines in +# `pba3d_defines_encode_32bit` handle this format. +pba3d_defines_encode_32bit = """ +// Sites : ENCODE(x, y, z, 0, 0) +// Not sites : ENCODE(0, 0, 0, 1, 0) or MARKER +#define ENCODED_INT_TYPE int +#define ZERO 0 +#define ONE 1 +#define ENCODE(x, y, z, a, b) (((x) << 20) | ((y) << 10) | (z) | ((a) << 31) | ((b) << 30)) +#define DECODE(value, x, y, z) \ + x = ((value) >> 20) & 0x3ff; \ + y = ((value) >> 10) & 0x3ff; \ + z = (value) & 0x3ff + +#define NOTSITE(value) (((value) >> 31) & 1) +#define HASNEXT(value) (((value) >> 30) & 1) + +#define GET_X(value) (((value) >> 20) & 0x3ff) +#define GET_Y(value) (((value) >> 10) & 0x3ff) +#define GET_Z(value) ((NOTSITE((value))) ? MAX_INT : ((value) & 0x3ff)) + +""" # noqa + + +# 64bit version of ENCODE/DECODE to allow a 20-bit integer per coordinate axis. +pba3d_defines_encode_64bit = """ +// Sites : ENCODE(x, y, z, 0, 0) +// Not sites : ENCODE(0, 0, 0, 1, 0) or MARKER +#define ENCODED_INT_TYPE long long +#define ZERO 0L +#define ONE 1L +#define ENCODE(x, y, z, a, b) (((x) << 40) | ((y) << 20) | (z) | ((a) << 61) | ((b) << 60)) +#define DECODE(value, x, y, z) \ + x = ((value) >> 40) & 0xfffff; \ + y = ((value) >> 20) & 0xfffff; \ + z = (value) & 0xfffff + +#define NOTSITE(value) (((value) >> 61) & 1) +#define HASNEXT(value) (((value) >> 60) & 1) + +#define GET_X(value) (((value) >> 40) & 0xfffff) +#define GET_Y(value) (((value) >> 20) & 0xfffff) +#define GET_Z(value) ((NOTSITE((value))) ? MAX_INT : ((value) & 0xfffff)) + +""" # noqa + + +@cupy.memoize(True) +def get_pba3d_src(block_size_3d=32, marker=-2147483648, max_int=2147483647, + size_max=1024): + pba3d_code = pba3d_defines_template.format( + block_size_3d=block_size_3d, marker=marker, max_int=max_int + ) + if size_max > 1024: + pba3d_code += pba3d_defines_encode_64bit + else: + pba3d_code += pba3d_defines_encode_32bit + kernel_directory = os.path.join(os.path.dirname(__file__), 'cuda') + with open(os.path.join(kernel_directory, 'pba_kernels_3d.h'), 'rt') as f: + pba3d_kernels = '\n'.join(f.readlines()) + pba3d_code += pba3d_kernels + return pba3d_code + + +# TODO: custom kernel for encode3d +def encode3d(arr, marker=-2147483648, bit_depth=32, size_max=1024): + if arr.ndim != 3: + raise ValueError("only 3d arr suppported") + if bit_depth not in [32, 64]: + raise ValueError("only bit_depth of 32 or 64 is supported") + if size_max > 1024: + dtype = np.int64 + else: + dtype = np.int32 + image = cupy.zeros(arr.shape, dtype=dtype, order='C') + cond = arr == 0 + z, y, x = cupy.where(cond) + # z, y, x so that x is the contiguous axis + # (must match TOID macro in the C++ code!) + if size_max > 1024: + image[cond] = (((x) << 40) | ((y) << 20) | (z)) + else: + image[cond] = (((x) << 20) | ((y) << 10) | (z)) + image[arr != 0] = marker # 1 << 32 + return image + + +# TODO: custom kernel for decode3d +def decode3d(output, size_max=1024): + if size_max > 1024: + x = (output >> 40) & 0xfffff + y = (output >> 20) & 0xfffff + z = output & 0xfffff + else: + x = (output >> 20) & 0x3ff + y = (output >> 10) & 0x3ff + z = output & 0x3ff + return (x, y, z) + + +def _determine_padding(shape, block_size, m1, m2, m3, blockx, blocky): + # TODO: can possibly revise to consider only particular factors for LCM on + # a given axis + LCM = lcm(block_size, m1, m2, m3, blockx, blocky) + orig_sz, orig_sy, orig_sx = shape + round_up = False + if orig_sx % LCM != 0: + # round up size to a multiple of the band size + round_up = True + sx = LCM * math.ceil(orig_sx / LCM) + else: + sx = orig_sx + if orig_sy % LCM != 0: + # round up size to a multiple of the band size + round_up = True + sy = LCM * math.ceil(orig_sy / LCM) + else: + sy = orig_sy + if orig_sz % LCM != 0: + # round up size to a multiple of the band size + round_up = True + sz = LCM * math.ceil(orig_sz / LCM) + else: + sz = orig_sz + + aniso = not (sx == sy == sz) + if aniso or round_up: + smax = max(sz, sy, sx) + padding_width = ( + (0, smax - orig_sz), (0, smax - orig_sy), (0, smax - orig_sx) + ) + else: + padding_width = None + return padding_width + + +def _pba_3d(arr, sampling=None, return_distances=True, return_indices=False, + block_params=None, check_warp_size=False, *, + float64_distances=False): + if arr.ndim != 3: + raise ValueError(f"expected a 3D array, got {arr.ndim}D") + + if sampling is not None: + raise NotImplementedError("sampling not yet supported") + # if len(sampling) != 3: + # raise ValueError("sampling must be a sequence of three values.") + + if block_params is None: + m1 = 1 + m2 = 1 + m3 = 2 + else: + m1, m2, m3 = block_params + + # reduce blockx for small inputs + s_min = min(arr.shape) + if s_min <= 4: + blockx = 4 + elif s_min <= 8: + blockx = 8 + elif s_min <= 16: + blockx = 16 + else: + blockx = 32 + blocky = 4 + + block_size = _get_block_size(check_warp_size) + + orig_sz, orig_sy, orig_sx = arr.shape + padding_width = _determine_padding( + arr.shape, block_size, m1, m2, m3, blockx, blocky + ) + if padding_width is not None: + arr = cupy.pad(arr, padding_width, mode='constant', constant_values=1) + size = arr.shape[0] + + # pba algorithm was implemented to use 32-bit integer to store compressed + # coordinates. input_arr will be C-contiguous, int32 + size_max = max(arr.shape) + input_arr = encode3d(arr, size_max=size_max) + buffer_idx = 0 + output = cupy.zeros_like(input_arr) + pba_images = [input_arr, output] + + block = (blockx, blocky, 1) + grid = (size // block[0], size // block[1], 1) + pba3d = cupy.RawModule( + code=get_pba3d_src(block_size_3d=block_size, size_max=size_max) + ) + + kernelFloodZ = pba3d.get_function('kernelFloodZ') + kernelMaurerAxis = pba3d.get_function('kernelMaurerAxis') + kernelColorAxis = pba3d.get_function('kernelColorAxis') + + kernelFloodZ( + grid, + block, + (pba_images[buffer_idx], pba_images[1 - buffer_idx], size) + ) + buffer_idx = 1 - buffer_idx + + block = (blockx, blocky, 1) + grid = (size // block[0], size // block[1], 1) + kernelMaurerAxis( + grid, + block, + (pba_images[buffer_idx], pba_images[1 - buffer_idx], size), + ) + + block = (block_size, m3, 1) + grid = (size // block[0], size, 1) + kernelColorAxis( + grid, + block, + (pba_images[1 - buffer_idx], pba_images[buffer_idx], size), + ) + + block = (blockx, blocky, 1) + grid = (size // block[0], size // block[1], 1) + kernelMaurerAxis( + grid, + block, + (pba_images[buffer_idx], pba_images[1 - buffer_idx], size), + ) + + block = (block_size, m3, 1) + grid = (size // block[0], size, 1) + kernelColorAxis( + grid, + block, + (pba_images[1 - buffer_idx], pba_images[buffer_idx], size), + ) + + output = pba_images[buffer_idx] + if return_distances or return_indices: + x, y, z = decode3d(output[:orig_sz, :orig_sy, :orig_sx], + size_max=size_max) + + vals = () + if return_distances: + # TODO: custom kernel for more efficient distance computation + orig_shape = (orig_sz, orig_sy, orig_sx) + z0, y0, x0 = cupy.meshgrid( + *(cupy.arange(s, dtype=cupy.int32) for s in orig_shape), + indexing='ij', + sparse=True + ) + tmp = (x - x0) + dist = tmp * tmp + tmp = (y - y0) + dist += tmp * tmp + tmp = (z - z0) + dist += tmp * tmp + if float64_distances: + dist = cupy.sqrt(dist) + else: + dist = dist.astype(cupy.float32) + cupy.sqrt(dist, out=dist) + vals = vals + (dist,) + if return_indices: + indices = cupy.stack((z, y, x), axis=0) + vals = vals + (indices,) + return vals diff --git a/python/cucim/src/cucim/core/operations/morphology/cuda/pba_kernels_2d.h b/python/cucim/src/cucim/core/operations/morphology/cuda/pba_kernels_2d.h new file mode 100644 index 000000000..61677c682 --- /dev/null +++ b/python/cucim/src/cucim/core/operations/morphology/cuda/pba_kernels_2d.h @@ -0,0 +1,451 @@ +// Euclidean Distance Transform +// +// Kernels for the 2D version of the Parallel Banding Algorithm (PBA+). +// +// MIT license: see 3rdparty/LICENSE.pba+ +// Copyright: (c) 2019 School of Computing, National University of Singapore +// +// Modifications by Gregory Lee (2022) (NVIDIA) +// - add user-defined pixel_int2_t to enable +// - replace __mul24 operations with standard multiplication operator + + +// START OF DEFINITIONS OVERRIDDEN BY THE PYTHON SCRIPT + +// The values included in this header file are those defined in the original +// PBA+ implementation + +// However, the Python code generation can potentially generate a different +// ENCODE/DECODE that use 20 bits per coordinates instead of 10 bits per +// coordinate with ENCODED_INT_TYPE as `long long`. + +#ifndef MARKER +#define MARKER -32768 +#endif + +#ifndef BLOCKSIZE +#define BLOCKSIZE 32 +#endif + +#ifndef pixel_int2_t +#define pixel_int2_t short2 +#define make_pixel(x, y) make_short2(x, y) +#endif + +// END OF DEFINITIONS OVERRIDDEN BY THE PYTHON SCRIPT + + +#define TOID(x, y, size) ((y) * (size) + (x)) + +#define LL long long +__device__ bool dominate(LL x1, LL y1, LL x2, LL y2, LL x3, LL y3, LL x0) +{ + LL k1 = y2 - y1, k2 = y3 - y2; + return (k1 * (y1 + y2) + (x2 - x1) * ((x1 + x2) - (x0 << 1))) * k2 > \ + (k2 * (y2 + y3) + (x3 - x2) * ((x2 + x3) - (x0 << 1))) * k1; +} +#undef LL + + +extern "C"{ + +__global__ void kernelFloodDown(pixel_int2_t *input, pixel_int2_t *output, int size, int bandSize) +{ + int tx = blockIdx.x * blockDim.x + threadIdx.x; + int ty = blockIdx.y * bandSize; + int id = TOID(tx, ty, size); + + pixel_int2_t pixel1, pixel2; + + pixel1 = make_pixel(MARKER, MARKER); + + for (int i = 0; i < bandSize; i++, id += size) { + pixel2 = input[id]; + + if (pixel2.x != MARKER) + pixel1 = pixel2; + + output[id] = pixel1; + } +} + +__global__ void kernelFloodUp(pixel_int2_t *input, pixel_int2_t *output, int size, int bandSize) +{ + int tx = blockIdx.x * blockDim.x + threadIdx.x; + int ty = (blockIdx.y+1) * bandSize - 1; + int id = TOID(tx, ty, size); + + pixel_int2_t pixel1, pixel2; + int dist1, dist2; + + pixel1 = make_pixel(MARKER, MARKER); + + for (int i = 0; i < bandSize; i++, id -= size) { + dist1 = abs(pixel1.y - ty + i); + + pixel2 = input[id]; + dist2 = abs(pixel2.y - ty + i); + + if (dist2 < dist1) + pixel1 = pixel2; + + output[id] = pixel1; + } +} + +__global__ void kernelPropagateInterband(pixel_int2_t *input, pixel_int2_t *margin_out, int size, int bandSize) +{ + int tx = blockIdx.x * blockDim.x + threadIdx.x; + int inc = bandSize * size; + int ny, nid, nDist; + pixel_int2_t pixel; + + // Top row, look backward + int ty = blockIdx.y * bandSize; + int topId = TOID(tx, ty, size); + int bottomId = TOID(tx, ty + bandSize - 1, size); + int tid = blockIdx.y * size + tx; + int bid = tid + (size * size / bandSize); + + pixel = input[topId]; + int myDist = abs(pixel.y - ty); + margin_out[tid] = pixel; + + for (nid = bottomId - inc; nid >= 0; nid -= inc) { + pixel = input[nid]; + + if (pixel.x != MARKER) { + nDist = abs(pixel.y - ty); + + if (nDist < myDist) + margin_out[tid] = pixel; + + break; + } + } + + // Last row, look downward + ty = ty + bandSize - 1; + pixel = input[bottomId]; + myDist = abs(pixel.y - ty); + margin_out[bid] = pixel; + + for (ny = ty + 1, nid = topId + inc; ny < size; ny += bandSize, nid += inc) { + pixel = input[nid]; + + if (pixel.x != MARKER) { + nDist = abs(pixel.y - ty); + + if (nDist < myDist) + margin_out[bid] = pixel; + + break; + } + } +} + +__global__ void kernelUpdateVertical(pixel_int2_t *color, pixel_int2_t *margin, pixel_int2_t *output, int size, int bandSize) +{ + __shared__ pixel_int2_t block[BLOCKSIZE][BLOCKSIZE]; + + int tx = blockIdx.x * blockDim.x + threadIdx.x; + int ty = blockIdx.y * bandSize; + + pixel_int2_t top = margin[blockIdx.y * size + tx]; + pixel_int2_t bottom = margin[(blockIdx.y + size / bandSize) * size + tx]; + pixel_int2_t pixel; + + int dist, myDist; + + int id = TOID(tx, ty, size); + + int n_step = bandSize / blockDim.x; + for(int step = 0; step < n_step; ++step) { + int y_start = blockIdx.y * bandSize + step * blockDim.x; + int y_end = y_start + blockDim.x; + + for (ty = y_start; ty < y_end; ++ty, id += size) { + pixel = color[id]; + myDist = abs(pixel.y - ty); + + dist = abs(top.y - ty); + if (dist < myDist) { myDist = dist; pixel = top; } + + dist = abs(bottom.y - ty); + if (dist < myDist) pixel = bottom; + + // temporary result is stored in block + block[threadIdx.x][ty - y_start] = make_pixel(pixel.y, pixel.x); + } + + __syncthreads(); + + // block is written to a transposed location in the output + + int tid = TOID(blockIdx.y * bandSize + step * blockDim.x + threadIdx.x, \ + blockIdx.x * blockDim.x, size); + + for(int i = 0; i < blockDim.x; ++i, tid += size) { + output[tid] = block[i][threadIdx.x]; + } + + __syncthreads(); + } +} + +__global__ void kernelProximatePoints(pixel_int2_t *input, pixel_int2_t *stack, int size, int bandSize) +{ + int tx = blockIdx.x * blockDim.x + threadIdx.x; + int ty = blockIdx.y * bandSize; + int id = TOID(tx, ty, size); + int lasty = -1; + pixel_int2_t last1, last2, current; + + last1.y = -1; last2.y = -1; + + for (int i = 0; i < bandSize; i++, id += size) { + current = input[id]; + + if (current.x != MARKER) { + while (last2.y >= 0) { + if (!dominate(last1.x, last2.y, last2.x, \ + lasty, current.x, current.y, tx)) + break; + + lasty = last2.y; last2 = last1; + + if (last1.y >= 0) + last1 = stack[TOID(tx, last1.y, size)]; + } + + last1 = last2; last2 = make_pixel(current.x, lasty); lasty = current.y; + + stack[id] = last2; + } + } + + // Store the pointer to the tail at the last pixel of this band + if (lasty != ty + bandSize - 1) + stack[TOID(tx, ty + bandSize - 1, size)] = make_pixel(MARKER, lasty); +} + +__global__ void kernelCreateForwardPointers(pixel_int2_t *input, pixel_int2_t *output, int size, int bandSize) +{ + int tx = blockIdx.x * blockDim.x + threadIdx.x; + int ty = (blockIdx.y+1) * bandSize - 1; + int id = TOID(tx, ty, size); + int lasty = -1, nexty; + pixel_int2_t current; + + // Get the tail pointer + current = input[id]; + + if (current.x == MARKER) + nexty = current.y; + else + nexty = ty; + + for (int i = 0; i < bandSize; i++, id -= size) + if (ty - i == nexty) { + current = make_pixel(lasty, input[id].y); + output[id] = current; + + lasty = nexty; + nexty = current.y; + } + + // Store the pointer to the head at the first pixel of this band + if (lasty != ty - bandSize + 1) + output[id + size] = make_pixel(lasty, MARKER); +} + +__global__ void kernelMergeBands(pixel_int2_t *color, pixel_int2_t *link, pixel_int2_t *output, int size, int bandSize) +{ + int tx = blockIdx.x * blockDim.x + threadIdx.x; + int band1 = blockIdx.y * 2; + int band2 = band1 + 1; + int firsty, lasty; + pixel_int2_t last1, last2, current; + // last1 and last2: x component store the x coordinate of the site, + // y component store the backward pointer + // current: y component store the x coordinate of the site, + // x component store the forward pointer + + // Get the two last items of the first list + lasty = band2 * bandSize - 1; + last2 = make_pixel(color[TOID(tx, lasty, size)].x, + link[TOID(tx, lasty, size)].y); + + if (last2.x == MARKER) { + lasty = last2.y; + + if (lasty >= 0) + last2 = make_pixel(color[TOID(tx, lasty, size)].x, + link[TOID(tx, lasty, size)].y); + else + last2 = make_pixel(MARKER, MARKER); + } + + if (last2.y >= 0) { + // Second item at the top of the stack + last1 = make_pixel(color[TOID(tx, last2.y, size)].x, + link[TOID(tx, last2.y, size)].y); + } + + // Get the first item of the second band + firsty = band2 * bandSize; + current = make_pixel(link[TOID(tx, firsty, size)].x, + color[TOID(tx, firsty, size)].x); + + if (current.y == MARKER) { + firsty = current.x; + + if (firsty >= 0) + current = make_pixel(link[TOID(tx, firsty, size)].x, + color[TOID(tx, firsty, size)].x); + else + current = make_pixel(MARKER, MARKER); + } + + // Count the number of item in the second band that survive so far. + // Once it reaches 2, we can stop. + int top = 0; + + while (top < 2 && current.y >= 0) { + // While there's still something on the left + while (last2.y >= 0) { + + if (!dominate(last1.x, last2.y, last2.x, \ + lasty, current.y, firsty, tx)) + break; + + lasty = last2.y; last2 = last1; + top--; + + if (last1.y >= 0) + last1 = make_pixel(color[TOID(tx, last1.y, size)].x, + link[TOID(tx, last1.y, size)].y); + } + + // Update the current pointer + output[TOID(tx, firsty, size)] = make_pixel(current.x, lasty); + + if (lasty >= 0) + output[TOID(tx, lasty, size)] = make_pixel(firsty, last2.y); + + last1 = last2; last2 = make_pixel(current.y, lasty); lasty = firsty; + firsty = current.x; + + top = max(1, top + 1); + + // Advance the current pointer to the next one + if (firsty >= 0) + current = make_pixel(link[TOID(tx, firsty, size)].x, + color[TOID(tx, firsty, size)].x); + else + current = make_pixel(MARKER, MARKER); + } + + // Update the head and tail pointer. + firsty = band1 * bandSize; + lasty = band2 * bandSize; + current = link[TOID(tx, firsty, size)]; + + if (current.y == MARKER && current.x < 0) { // No head? + last1 = link[TOID(tx, lasty, size)]; + + if (last1.y == MARKER) + current.x = last1.x; + else + current.x = lasty; + + output[TOID(tx, firsty, size)] = current; + } + + firsty = band1 * bandSize + bandSize - 1; + lasty = band2 * bandSize + bandSize - 1; + current = link[TOID(tx, lasty, size)]; + + if (current.x == MARKER && current.y < 0) { // No tail? + last1 = link[TOID(tx, firsty, size)]; + + if (last1.x == MARKER) + current.y = last1.y; + else + current.y = firsty; + + output[TOID(tx, lasty, size)] = current; + } +} + +__global__ void kernelDoubleToSingleList(pixel_int2_t *color, pixel_int2_t *link, pixel_int2_t *output, int size) +{ + int tx = blockIdx.x * blockDim.x + threadIdx.x; + int ty = blockIdx.y; + int id = TOID(tx, ty, size); + + output[id] = make_pixel(color[id].x, link[id].y); +} + +__global__ void kernelColor(pixel_int2_t *input, pixel_int2_t *output, int size) +{ + __shared__ pixel_int2_t block[BLOCKSIZE][BLOCKSIZE]; + + int col = threadIdx.x; + int tid = threadIdx.y; + int tx = blockIdx.x * blockDim.x + col; + int dx, dy, lasty; + unsigned int best, dist; + pixel_int2_t last1, last2; + + lasty = size - 1; + + last2 = input[TOID(tx, lasty, size)]; + + if (last2.x == MARKER) { + lasty = last2.y; + last2 = input[TOID(tx, lasty, size)]; + } + + if (last2.y >= 0) + last1 = input[TOID(tx, last2.y, size)]; + + int y_start, y_end, n_step = size / blockDim.x; + for(int step = 0; step < n_step; ++step) { + y_start = size - step * blockDim.x - 1; + y_end = size - (step + 1) * blockDim.x; + + for (int ty = y_start - tid; ty >= y_end; ty -= blockDim.y) { + dx = last2.x - tx; dy = lasty - ty; + best = dist = dx * dx + dy * dy; + + while (last2.y >= 0) { + dx = last1.x - tx; dy = last2.y - ty; + dist = dx * dx + dy * dy; + + if (dist > best) + break; + + best = dist; lasty = last2.y; last2 = last1; + + if (last2.y >= 0) + last1 = input[TOID(tx, last2.y, size)]; + } + + block[threadIdx.x][ty - y_end] = make_pixel(lasty, last2.x); + } + + __syncthreads(); + + // note: transposes back to original shape here + if(!threadIdx.y) { + int id = TOID(y_end + threadIdx.x, blockIdx.x * blockDim.x, size); + for(int i = 0; i < blockDim.x; ++i, id+=size) { + output[id] = block[i][threadIdx.x]; + } + } + + __syncthreads(); + } +} +} // extern C diff --git a/python/cucim/src/cucim/core/operations/morphology/cuda/pba_kernels_3d.h b/python/cucim/src/cucim/core/operations/morphology/cuda/pba_kernels_3d.h new file mode 100644 index 000000000..c09f4b51f --- /dev/null +++ b/python/cucim/src/cucim/core/operations/morphology/cuda/pba_kernels_3d.h @@ -0,0 +1,237 @@ +// Euclidean Distance Transform +// +// Kernels for the 3D version of the Parallel Banding Algorithm (PBA+). +// +// MIT license: see 3rdparty/LICENSE.pba+ +// +// Modifications by Gregory Lee (2022) (NVIDIA) +// - allow user-defined ENCODED_INT_TYPE, ENCODE, DECODE + + +// START OF DEFINITIONS OVERRIDDEN BY THE PYTHON SCRIPT + +// The values included in this header file are those defined in the original +// PBA+ implementation + +// However, the Python code generation can potentially generate a different +// ENCODE/DECODE that use 20 bits per coordinates instead of 10 bits per +// coordinate with ENCODED_INT_TYPE as `long long`. + + +#ifndef MARKER +#define MARKER -2147483648 +#endif // MARKER + +#ifndef MAX_INT +#define MAX_INT 2147483647 +#endif + +#ifndef BLOCKSIZE +#define BLOCKSIZE 32 +#endif + +#ifndef ENCODE + +// Sites : ENCODE(x, y, z, 0, 0) +// Not sites : ENCODE(0, 0, 0, 1, 0) or MARKER +#define ENCODED_INT_TYPE int +#define ZERO 0 +#define ONE 1 +#define ENCODE(x, y, z, a, b) (((x) << 20) | ((y) << 10) | (z) | ((a) << 31) | ((b) << 30)) +#define DECODE(value, x, y, z) \ + x = ((value) >> 20) & 0x3ff; \ + y = ((value) >> 10) & 0x3ff; \ + z = (value) & 0x3ff + +#define NOTSITE(value) (((value) >> 31) & 1) +#define HASNEXT(value) (((value) >> 30) & 1) + +#define GET_X(value) (((value) >> 20) & 0x3ff) +#define GET_Y(value) (((value) >> 10) & 0x3ff) +#define GET_Z(value) ((NOTSITE((value))) ? MAX_INT : ((value) & 0x3ff)) + +#endif // ENCODE + +// END OF DEFINITIONS DEFINED IN THE PYTHON SCRIPT + + +#define LL long long +__device__ bool dominate(LL x_1, LL y_1, LL z_1, LL x_2, LL y_2, LL z_2, LL x_3, LL y_3, LL z_3, LL x_0, LL z_0) +{ + LL k_1 = y_2 - y_1, k_2 = y_3 - y_2; + + return (((y_1 + y_2) * k_1 + ((x_2 - x_1) * (x_1 + x_2 - (x_0 << 1)) + (z_2 - z_1) * (z_1 + z_2 - (z_0 << 1)))) * k_2 > \ + ((y_2 + y_3) * k_2 + ((x_3 - x_2) * (x_2 + x_3 - (x_0 << 1)) + (z_3 - z_2) * (z_2 + z_3 - (z_0 << 1)))) * k_1); +} +#undef LL + +#define TOID(x, y, z, size) ((((z) * (size)) + (y)) * (size) + (x)) + + +extern "C"{ + +__global__ void kernelFloodZ(ENCODED_INT_TYPE *input, ENCODED_INT_TYPE *output, int size) +{ + + int tx = blockIdx.x * blockDim.x + threadIdx.x; + int ty = blockIdx.y * blockDim.y + threadIdx.y; + int tz = 0; + + int plane = size * size; + int id = TOID(tx, ty, tz, size); + ENCODED_INT_TYPE pixel1, pixel2; + + pixel1 = ENCODE(ZERO,ZERO,ZERO,ONE,ZERO); + + // Sweep down + for (int i = 0; i < size; i++, id += plane) { + pixel2 = input[id]; + + if (!NOTSITE(pixel2)) + pixel1 = pixel2; + + output[id] = pixel1; + } + + ENCODED_INT_TYPE dist1, dist2, nz; + + id -= plane + plane; + + // Sweep up + for (int i = size - 2; i >= 0; i--, id -= plane) { + nz = GET_Z(pixel1); + dist1 = abs(nz - (tz + i)); + + pixel2 = output[id]; + nz = GET_Z(pixel2); + dist2 = abs(nz - (tz + i)); + + if (dist2 < dist1) + pixel1 = pixel2; + + output[id] = pixel1; + } +} + + +__global__ void kernelMaurerAxis(ENCODED_INT_TYPE *input, ENCODED_INT_TYPE *stack, int size) +{ + int tx = blockIdx.x * blockDim.x + threadIdx.x; + int tz = blockIdx.y * blockDim.y + threadIdx.y; + int ty = 0; + + int id = TOID(tx, ty, tz, size); + + ENCODED_INT_TYPE lasty = 0; + ENCODED_INT_TYPE x1, y1, z1, x2, y2, z2, nx, ny, nz; + ENCODED_INT_TYPE p = ENCODE(ZERO,ZERO,ZERO,ONE,ZERO), s1 = ENCODE(ZERO,ZERO,ZERO,ONE,ZERO), s2 = ENCODE(ZERO,ZERO,ZERO,ONE,ZERO); + ENCODED_INT_TYPE flag = 0; + + for (ty = 0; ty < size; ++ty, id += size) { + p = input[id]; + + if (!NOTSITE(p)) { + + while (HASNEXT(s2)) { + DECODE(s1, x1, y1, z1); + DECODE(s2, x2, y2, z2); + DECODE(p, nx, ny, nz); + + if (!dominate(x1, y2, z1, x2, lasty, z2, nx, ty, nz, tx, tz)) + break; + + lasty = y2; s2 = s1; y2 = y1; + + if (HASNEXT(s2)) + s1 = stack[TOID(tx, y2, tz, size)]; + } + + DECODE(p, nx, ny, nz); + s1 = s2; + s2 = ENCODE(nx, lasty, nz, ZERO, flag); + y2 = lasty; + lasty = ty; + + stack[id] = s2; + + flag = ONE; + } + } + + if (NOTSITE(p)) + stack[TOID(tx, ty - 1, tz, size)] = ENCODE(ZERO, lasty, ZERO, ONE, flag); +} + +__global__ void kernelColorAxis(ENCODED_INT_TYPE *input, ENCODED_INT_TYPE *output, int size) +{ + __shared__ ENCODED_INT_TYPE block[BLOCKSIZE][BLOCKSIZE]; + + int col = threadIdx.x; + int tid = threadIdx.y; + int tx = blockIdx.x * blockDim.x + col; + int tz = blockIdx.y; + + ENCODED_INT_TYPE x1, y1, z1, x2, y2, z2; + ENCODED_INT_TYPE last1 = ENCODE(ZERO,ZERO,ZERO,ONE,ZERO), last2 = ENCODE(ZERO,ZERO,ZERO,ONE,ZERO), lasty; + long long dx, dy, dz, best, dist; + + lasty = size - 1; + + last2 = input[TOID(tx, lasty, tz, size)]; + DECODE(last2, x2, y2, z2); + + if (NOTSITE(last2)) { + lasty = y2; + if(HASNEXT(last2)) { + last2 = input[TOID(tx, lasty, tz, size)]; + DECODE(last2, x2, y2, z2); + } + } + + if (HASNEXT(last2)) { + last1 = input[TOID(tx, y2, tz, size)]; + DECODE(last1, x1, y1, z1); + } + + int y_start, y_end, n_step = size / blockDim.x; + for(int step = 0; step < n_step; ++step) { + y_start = size - step * blockDim.x - 1; + y_end = size - (step + 1) * blockDim.x; + + for (int ty = y_start - tid; ty >= y_end; ty -= blockDim.y) { + dx = x2 - tx; dy = lasty - ty; dz = z2 - tz; + best = dx * dx + dy * dy + dz * dz; + + while (HASNEXT(last2)) { + dx = x1 - tx; dy = y2 - ty; dz = z1 - tz; + dist = dx * dx + dy * dy + dz * dz; + + if(dist > best) break; + + best = dist; lasty = y2; last2 = last1; + DECODE(last2, x2, y2, z2); + + if (HASNEXT(last2)) { + last1 = input[TOID(tx, y2, tz, size)]; + DECODE(last1, x1, y1, z1); + } + } + + block[threadIdx.x][ty - y_end] = ENCODE(lasty, x2, z2, NOTSITE(last2), ZERO); + } + + __syncthreads(); + + if(!threadIdx.y) { + int id = TOID(y_end + threadIdx.x, blockIdx.x * blockDim.x, tz, size); + for(int i = 0; i < blockDim.x; i++, id+=size) { + output[id] = block[i][threadIdx.x]; + } + } + + __syncthreads(); + } +} + + +} // extern C diff --git a/python/cucim/src/cucim/core/operations/morphology/tests/test_distance_transform.py b/python/cucim/src/cucim/core/operations/morphology/tests/test_distance_transform.py new file mode 100644 index 000000000..ab46ae1fb --- /dev/null +++ b/python/cucim/src/cucim/core/operations/morphology/tests/test_distance_transform.py @@ -0,0 +1,145 @@ +from copy import copy + +import cupy as cp +import numpy as np +import pytest +import scipy.ndimage as ndi_cpu + +from cucim.core.operations.morphology import distance_transform_edt + + +def binary_image(shape, pct_true=50): + rng = cp.random.default_rng(123) + x = rng.integers(0, 100, size=shape, dtype=cp.uint8) + return x >= pct_true + + +def assert_percentile_equal(arr1, arr2, pct=95): + """Assert that at least pct% of the entries in arr1 and arr2 are equal.""" + pct_mismatch = (100 - pct) / 100 + arr1 = cp.asnumpy(arr1) + arr2 = cp.asnumpy(arr2) + mismatch = np.sum(arr1 != arr2) / arr1.size + assert mismatch < pct_mismatch + + +@pytest.mark.parametrize('return_indices', [False, True]) +@pytest.mark.parametrize('return_distances', [False, True]) +@pytest.mark.parametrize( + 'shape, sampling', + [ + ((256, 128), None), + ((384, 256), (1.5, 1.5)), + ((14, 32, 50), None), + ((50, 32, 24), (2, 2, 2)), + ] +) +@pytest.mark.parametrize('density', [5, 50, 95]) +@pytest.mark.parametrize('block_params', [None, (1, 1, 1)]) +def test_distance_transform_edt( + shape, sampling, return_distances, return_indices, density, block_params +): + + if not (return_indices or return_distances): + return + + kwargs_scipy = dict( + sampling=sampling, + return_distances=return_distances, + return_indices=return_indices, + ) + kwargs_cucim = copy(kwargs_scipy) + kwargs_cucim['block_params'] = block_params + img = binary_image(shape, pct_true=density) + out = distance_transform_edt(img, **kwargs_cucim) + expected = ndi_cpu.distance_transform_edt(cp.asnumpy(img), **kwargs_scipy) + if return_indices and return_distances: + assert len(out) == 2 + cp.testing.assert_allclose(out[0], expected[0]) + # May differ at a small % of coordinates where multiple points were + # equidistant. + assert_percentile_equal(out[1], expected[1], pct=95) + elif return_distances: + cp.testing.assert_allclose(out, expected) + elif return_indices: + assert_percentile_equal(out, expected, pct=95) + + +@pytest.mark.parametrize('return_indices', [False, True]) +@pytest.mark.parametrize('return_distances', [False, True]) +@pytest.mark.parametrize( + 'shape, sampling', + [ + ((384, 256), (1, 3)), + ((50, 32, 24), (1, 2, 4)), + ] +) +@pytest.mark.parametrize('density', [5, 50, 95]) +def test_distance_transform_edt_nonuniform_sampling( + shape, sampling, return_distances, return_indices, density +): + + if not (return_indices or return_distances): + return + + kwargs_scipy = dict( + sampling=sampling, + return_distances=return_distances, + return_indices=return_indices, + ) + kwargs_cucim = copy(kwargs_scipy) + img = binary_image(shape, pct_true=density) + if sampling is not None and len(np.unique(sampling)) != 1: + with pytest.raises(NotImplementedError): + distance_transform_edt(img, **kwargs_cucim) + return + + +@pytest.mark.parametrize('value', [0, 1, 3]) +@pytest.mark.parametrize('ndim', [2, 3]) +def test_distance_transform_edt_uniform_valued(value, ndim): + """ensure default block_params is robust to anisotropic shape.""" + img = cp.full((48, ) * ndim, value, dtype=cp.uint8) + # ensure there is at least 1 pixel at background intensity + img[(slice(24, 25),) * ndim] = 0 + out = distance_transform_edt(img) + expected = ndi_cpu.distance_transform_edt(cp.asnumpy(img)) + cp.testing.assert_allclose(out, expected) + + +@pytest.mark.parametrize('sx', list(range(16))) +@pytest.mark.parametrize('sy', list(range(16))) +def test_distance_transform_edt_2d_aniso(sx, sy): + """ensure default block_params is robust to anisotropic shape.""" + shape = (128 + sy, 128 + sx) + img = binary_image(shape, pct_true=80) + out = distance_transform_edt(img) + expected = ndi_cpu.distance_transform_edt(cp.asnumpy(img)) + cp.testing.assert_allclose(out, expected) + + +@pytest.mark.parametrize('sx', list(range(4))) +@pytest.mark.parametrize('sy', list(range(4))) +@pytest.mark.parametrize('sz', list(range(4))) +def test_distance_transform_edt_3d_aniso(sx, sy, sz): + """ensure default block_params is robust to anisotropic shape.""" + shape = (16 + sz, 32 + sy, 48 + sx) + img = binary_image(shape, pct_true=80) + out = distance_transform_edt(img) + expected = ndi_cpu.distance_transform_edt(cp.asnumpy(img)) + cp.testing.assert_allclose(out, expected) + + +@pytest.mark.parametrize('ndim', [1, 4, 5]) +def test_distance_transform_edt_unsupported_ndim(ndim): + with pytest.raises(NotImplementedError): + distance_transform_edt(cp.zeros((8,) * ndim)) + + +@pytest.mark.skip(reason="excessive memory requirement") +def test_distance_transform_edt_3d_int64(): + shape = (1280, 1280, 1280) + img = binary_image(shape, pct_true=80) + distance_transform_edt(img) + # Note: no validation vs. scipy.ndimage due to excessive run time + return