Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: dask chunking on frame level implemented #242

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 63 additions & 5 deletions src/nd2/nd2file.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,7 +888,12 @@ def write_tiff(
modify_ome=modify_ome,
)

def to_dask(self, wrapper: bool = True, copy: bool = True) -> dask.array.core.Array:
def to_dask(
self,
wrapper: bool = True,
copy: bool = True,
frame_chunks: int | tuple | None = None,
) -> dask.array.core.Array:
"""Create dask array (delayed reader) representing image.

This generally works well, but it remains to be seen whether performance
Expand All @@ -913,6 +918,11 @@ def to_dask(self, wrapper: bool = True, copy: bool = True) -> dask.array.core.Ar
If `True` (the default), the dask chunk-reading function will return
an array copy. This can avoid segfaults in certain cases, though it
may also add overhead.
frame_chunks : tuple | int | None
If `None` (the default), the file will not be chunked on the frame level.
Otherwise expects the dask compatible chunks to chunk the frames along
channel, y, and x axis. If a tuple, must have same length as
`self._frame_shape`.

Returns
-------
Expand All @@ -922,7 +932,46 @@ def to_dask(self, wrapper: bool = True, copy: bool = True) -> dask.array.core.Ar
from dask.array.core import map_blocks

chunks = [(1,) * x for x in self._coord_shape]
chunks += [(x,) for x in self._frame_shape]
if frame_chunks is None:
chunks += [(x,) for x in self._frame_shape]
elif isinstance(frame_chunks, int):
for frame_len in self._frame_shape:
div = frame_len // frame_chunks
if div == 0:
chunks.append((frame_len,))
else:
_chunks = (frame_chunks,) * div
if frame_len % frame_chunks != 0:
_chunks += (frame_len - div * frame_chunks,)
chunks.append(_chunks)
elif len(frame_chunks) != len(self._frame_shape):
raise ValueError(
f"frame_chunks must be of length {len(self._frame_shape)}."
)
elif isinstance(frame_chunks[0], int):
if not all(isinstance(frame_chunk, int) for frame_chunk in frame_chunks):
raise ValueError(
"frame_chunks must be a tuple of ints or tuple of tuple of ints."
)
for frame_len, frame_chunk in zip(self._frame_shape, frame_chunks):
div = frame_len // frame_chunk
if div == 0:
chunks.append((frame_len,))
else:
_chunks = (frame_chunk,) * div
if frame_len % frame_chunk != 0:
_chunks += (frame_len - div * frame_chunk,)
chunks.append(_chunks)
else:
if not all(
sum(frame_chunk) == frame_len
for frame_chunk, frame_len in zip(frame_chunks, self._frame_shape)
):
raise ValueError(
"Sum of frame_chunks does not align with frame shape of file."
)
chunks.extend(frame_chunks)

dask_arr = map_blocks(
self._dask_block,
copy=copy,
Expand All @@ -944,15 +993,16 @@ def _seq_index_from_coords(self, coords: Sequence) -> Sequence[int] | SupportsIn
return self._NO_IDX
return np.ravel_multi_index(coords, self._coord_shape) # type: ignore

def _dask_block(self, copy: bool, block_id: tuple[int]) -> np.ndarray:
if isinstance(block_id, np.ndarray):
def _dask_block(self, copy: bool, block_info: dict) -> np.ndarray:
if isinstance(block_info, np.ndarray):
return None
with self._lock:
was_closed = self.closed
if self.closed:
self.open()
try:
ncoords = len(self._coord_shape)
block_id = block_info[None]["chunk-location"]
idx = self._seq_index_from_coords(block_id[:ncoords])

if idx == self._NO_IDX:
Expand All @@ -962,6 +1012,11 @@ def _dask_block(self, copy: bool, block_id: tuple[int]) -> np.ndarray:
)
idx = 0
data = self.read_frame(int(idx)) # type: ignore
slices = tuple(
slice(al[0], al[1])
for al in block_info[None]["array-location"][ncoords:]
)
data = data[slices]
data = data.copy() if copy else data
return data[(np.newaxis,) * ncoords]
finally:
Expand Down Expand Up @@ -1207,7 +1262,10 @@ def binary_data(self) -> BinaryLayers | None:
return self._rdr.binary_data()

def ome_metadata(
self, *, include_unstructured: bool = True, tiff_file_name: str | None = None
self,
*,
include_unstructured: bool = True,
tiff_file_name: str | None = None,
) -> OME:
"""Return `ome_types.OME` metadata object for this file.

Expand Down
43 changes: 43 additions & 0 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,49 @@ def test_dask_closed(single_nd2):
assert isinstance(dsk.compute(), np.ndarray)


# Each case pairs the ``frame_chunks`` argument passed to ``to_dask`` with the
# dask ``chunks`` tuple we expect the resulting array to carry on the frame
# axes (channel, y, x); the leading ``(2,)``/``(1, 1)`` entry is the coord axis.
_VALID_CHUNK_CASES = [
    (None, ((2,), (32,), (32,))),
    (2, ((2,), (2,) * 16, (2,) * 16)),
    (3, ((2,), (3,) * 10 + (2,), (3,) * 10 + (2,))),
    ((3, 17, 33), ((2,), (17, 15), (32,))),
    ((2, 16, 16), ((2,), (16, 16), (16, 16))),
    (((1, 1), (8, 8, 8, 8), (16, 16)), ((1, 1), (8, 8, 8, 8), (16, 16))),
    (((2,), (20, 12), (32,)), ((2,), (20, 12), (32,))),
]


@pytest.fixture(params=_VALID_CHUNK_CASES, ids=str)
def passing_frame_chunks(request):
    """Yield one ``(frame_chunks argument, expected dask chunks)`` pair."""
    return request.param


def test_dask_chunking(single_nd2, passing_frame_chunks):
    """A frame-chunked dask array matches both the expected chunk layout and
    the data of the default (unchunked-frame) dask array."""
    requested, expected = passing_frame_chunks
    with ND2File(single_nd2) as nd:
        chunked = nd.to_dask(frame_chunks=requested)
        # one coord axis plus three frame axes (channel, y, x)
        assert len(chunked.chunks) == 4
        assert chunked.chunks[1:] == expected
        baseline = nd.to_dask()
        # chunking must not alter the pixel data
        assert (chunked.compute() == baseline.compute()).all()


@pytest.fixture(
    params=[
        # wrong tuple length (frame shape has 3 axes)
        (2, 3),
        # mixed int / tuple entries
        (2, (16, 16), (16, 16)),
        # chunk sums that do not add up to the frame shape
        ((1, 1, 1), (16, 16), (32,)),
    ],
    ids=lambda x: str(x),
)
def failing_frame_chunks(request):
    """Yield one invalid ``frame_chunks`` value that ``to_dask`` must reject."""
    # BUG FIX: pytest's FixtureRequest exposes the current parameter as
    # ``request.param`` (singular); ``request.params`` raises AttributeError,
    # which made every test consuming this fixture error out.
    return request.param


def test_value_error_dask_chunking(single_nd2, failing_frame_chunks):
    """Malformed ``frame_chunks`` arguments must raise ``ValueError``."""
    with ND2File(single_nd2) as nd:
        with pytest.raises(ValueError):
            # BUG FIX: the body previously referenced ``passing_frame_chunks``,
            # which is not a parameter of this test (NameError at runtime) —
            # the declared ``failing_frame_chunks`` fixture was never used, so
            # the invalid cases were never exercised.
            nd.to_dask(frame_chunks=failing_frame_chunks)


def test_full_read(new_nd2):
pytest.importorskip("xarray")
with ND2File(new_nd2) as nd:
Expand Down
Loading