Concatenate small input chunks before P2P rechunking #8832

Merged 22 commits on Aug 23, 2024
173 changes: 158 additions & 15 deletions distributed/shuffle/_rechunk.py
@@ -96,6 +96,7 @@

from __future__ import annotations

import math
import mmap
import os
from collections import defaultdict
@@ -111,7 +112,7 @@
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from itertools import product
from itertools import chain, product
from pathlib import Path
from typing import TYPE_CHECKING, Any, NamedTuple, cast

@@ -124,6 +125,7 @@
from dask.highlevelgraph import HighLevelGraph
from dask.layers import Layer
from dask.typing import Key
from dask.utils import parse_bytes

from distributed.core import PooledRPCCall
from distributed.metrics import context_meter
@@ -220,7 +222,7 @@
return da.empty(x.shape, chunks=chunks, dtype=x.dtype)
from dask.array.core import new_da_object

prechunked = _calculate_prechunking(x.chunks, chunks)
prechunked = _calculate_prechunking(x.chunks, chunks, x.dtype, block_size_limit)
if prechunked != x.chunks:
x = cast(
"da.Array",
@@ -433,8 +435,140 @@


def _calculate_prechunking(
old_chunks: ChunkedAxes, new_chunks: ChunkedAxes
old_chunks: ChunkedAxes,
new_chunks: ChunkedAxes,
dtype: np.dtype,
block_size_limit: int | None,
) -> ChunkedAxes:
"""Calculate how to perform the pre-rechunking step

During the pre-rechunking step, we
1. Split input chunks along partial boundaries to make partials completely independent of one another
2. Merge small chunks within partials to reduce the number of transfer tasks and corresponding overhead
"""
split_axes = _split_chunks_along_partial_boundaries(old_chunks, new_chunks)

# We can only determine how to concatenate chunks if we can calculate block sizes.
has_nans = (any(math.isnan(y) for y in x) for x in old_chunks)

if len(new_chunks) <= 1 or not all(new_chunks) or any(has_nans):
return tuple(tuple(chain(*axis)) for axis in split_axes)

if dtype is None or dtype.hasobject or dtype.itemsize == 0:
return tuple(tuple(chain(*axis)) for axis in split_axes)

# We made sure that there are no NaNs in split_axes above
return _concatenate_small_chunks(
split_axes, old_chunks, new_chunks, dtype, block_size_limit # type: ignore[arg-type]
)
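As a rough end-to-end illustration of what this function produces (mirroring test_calculate_prechunking_concatenation further down), here is a minimal usage sketch; it assumes the int16 dtype and the array.chunk-size value used in that test:

import numpy as np
import dask
from distributed.shuffle._rechunk import _calculate_prechunking

old = ((10,), (1,) * 10)
new = ((2,) * 5, (5, 5))

# With a 100 B budget and 2-byte elements, the ten 1-element chunks along
# the second axis get concatenated into the two 5-element output chunks.
with dask.config.set({"array.chunk-size": "100 B"}):
    assert _calculate_prechunking(old, new, np.dtype(np.int16), None) == ((10,), (5, 5))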


def _concatenate_small_chunks(
split_axes: list[list[list[int]]],
old_chunks: ChunkedAxes,
new_chunks: ChunkedAxes,
dtype: np.dtype,
block_size_limit: int | None,
) -> ChunkedAxes:
"""Concatenate small chunks within partials.

By concatenating chunks within partials, we reduce the number of P2P transfer tasks and their
corresponding overhead.

The algorithm used in this function is very similar to :func:`dask.array.rechunk.find_merge_rechunk`;
the main difference is that we have to make sure to only merge chunks within partials.
"""
import numpy as np

block_size_limit = block_size_limit or dask.config.get("array.chunk-size")

if isinstance(block_size_limit, str):
block_size_limit = parse_bytes(block_size_limit)

# Make it a number of elements
block_size_limit //= dtype.itemsize

# We verified earlier that we do not have any NaNs
largest_old_block = _largest_block_size(old_chunks) # type: ignore[arg-type]
largest_new_block = _largest_block_size(new_chunks) # type: ignore[arg-type]
block_size_limit = max([block_size_limit, largest_old_block, largest_new_block])

old_largest_width = [max(chain(*axis)) for axis in split_axes]
new_largest_width = [max(c) for c in new_chunks]

# This represents how much each dimension increases (>1) or reduces (<1)
# the graph size during rechunking
graph_size_effect = {
dim: len(new_axis) / sum(map(len, split_axis))
for dim, (split_axis, new_axis) in enumerate(zip(split_axes, new_chunks))
}

ndim = len(old_chunks)

# This represents how much each dimension increases (>1) or reduces (<1) the
# largest block size during rechunking
block_size_effect = {
dim: new_largest_width[dim] / (old_largest_width[dim] or 1)
for dim in range(ndim)
}

# Our goal is to reduce the number of nodes in the rechunk graph
# by concatenating some adjacent chunks, so consider dimensions where we can
# reduce the # of chunks
candidates = [dim for dim in range(ndim) if graph_size_effect[dim] <= 1.0]

# Concatenating along each dimension reduces the graph size by a certain factor
# and increases the largest block size in memory by a certain factor.
# We want to optimize the graph size while staying below the given
# block_size_limit. This is in effect a knapsack problem, except with
# multiplicative values and weights. Just use a greedy algorithm
# by trying dimensions in decreasing value / weight order.
def key(k: int) -> float:
gse = graph_size_effect[k]
bse = block_size_effect[k]
Collaborator: Can you add a sentence about what these two variables represent when you define them above? It took me a bit to figure this out.
Member Author: Done
if bse == 1:
bse = 1 + 1e-9
return (np.log(gse) / np.log(bse)) if bse > 0 else 0

sorted_candidates = sorted(candidates, key=key)

concatenated_axes: list[list[int]] = [[] for i in range(ndim)]

# Skip all the axes that are not candidates
for i in range(ndim):
if i in candidates:
continue
concatenated_axes[i] = list(chain(*split_axes[i]))

# We want to concatenate chunks
for axis_index in sorted_candidates:
concatenated_axis = concatenated_axes[axis_index]
multiplier = math.prod(
old_largest_width[:axis_index] + old_largest_width[axis_index + 1 :]
)
axis_limit = block_size_limit // multiplier

for partial in split_axes[axis_index]:
current = partial[0]
for chunk in partial[1:]:
if (current + chunk) > axis_limit:
concatenated_axis.append(current)
current = chunk
else:
current += chunk
concatenated_axis.append(current)
old_largest_width[axis_index] = max(concatenated_axis)
return tuple(tuple(axis) for axis in concatenated_axes)
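The loop over partials above boils down to a greedy pass: accumulate adjacent chunks until adding the next one would push past the per-axis element budget. A standalone sketch of just that pass (the helper name is hypothetical, not part of this module):

def greedy_merge(partial: list[int], axis_limit: int) -> list[int]:
    # Hypothetical helper mirroring the inner loop of _concatenate_small_chunks:
    # merge adjacent chunk sizes without exceeding axis_limit per merged chunk.
    merged: list[int] = []
    current = partial[0]
    for chunk in partial[1:]:
        if current + chunk > axis_limit:
            merged.append(current)
            current = chunk
        else:
            current += chunk
    merged.append(current)
    return merged

# Matches the "40 B" case of test_calculate_prechunking_concatenation below:
# each partial of five 1-element chunks merges to (2, 2, 1) under a limit of 2.
assert greedy_merge([1, 1, 1, 1, 1], 2) == [2, 2, 1]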


def _split_chunks_along_partial_boundaries(
old_chunks: ChunkedAxes, new_chunks: ChunkedAxes
) -> list[list[list[float]]]:
"""Split the old chunks along the boundaries of partials, i.e., groups of new chunks that share the same inputs.

By splitting along the boundaries before rechunking, the partials' input tasks become disjoint and each partial conceptually
operates on an independent sub-array.
"""
from dask.array.rechunk import old_to_new

_old_to_new = old_to_new(old_chunks, new_chunks)
@@ -443,10 +577,13 @@

split_axes = []

# Along each axis, we want to figure out how we have to split input chunks in order to make
# partials disjoint. We then group the resulting input chunks per partial before returning.
for axis_index, slices in enumerate(partials):
old_to_new_axis = _old_to_new[axis_index]
old_axis = old_chunks[axis_index]
split_axis = []
partial_chunks = []
for slice_ in slices:
first_new_chunk = slice_.start
first_old_chunk, first_old_slice = old_to_new_axis[first_new_chunk][0]
@@ -465,22 +602,28 @@
chunk_size = last_old_slice.stop
if first_old_slice.start != 0:
chunk_size -= first_old_slice.start
split_axis.append(chunk_size)
continue

split_axis.append(first_chunk_size - first_old_slice.start)

split_axis.extend(old_axis[first_old_chunk + 1 : last_old_chunk])

if last_old_slice.stop is not None:
chunk_size = last_old_slice.stop
partial_chunks.append(chunk_size)
else:
chunk_size = last_chunk_size
partial_chunks.append(first_chunk_size - first_old_slice.start)

split_axis.append(chunk_size)
partial_chunks.extend(old_axis[first_old_chunk + 1 : last_old_chunk])

if last_old_slice.stop is not None:
chunk_size = last_old_slice.stop
else:
chunk_size = last_chunk_size

Codecov / codecov/patch warning: added line distributed/shuffle/_rechunk.py#L614 was not covered by tests.

partial_chunks.append(chunk_size)
split_axis.append(partial_chunks)
partial_chunks = []
if partial_chunks:
split_axis.append(partial_chunks)

Codecov / codecov/patch warning: added line distributed/shuffle/_rechunk.py#L620 was not covered by tests.
split_axes.append(split_axis)
return tuple(tuple(axis) for axis in split_axes)
return split_axes
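To see the splitting step in isolation, note that object dtypes never reach the concatenation step, so _calculate_prechunking returns exactly the split chunks; this sketch reuses a case from test_calculate_prechunking_splitting below:

import numpy as np
from distributed.shuffle._rechunk import _calculate_prechunking

old = ((2, 2, 2), (3, 3, 3))
new = ((1, 2, 2, 1), (2, 3, 4))

# The old chunks are cut along partial boundaries, so the resulting partials
# operate on independent sub-arrays.
assert _calculate_prechunking(old, new, np.dtype(object), None) == (
    (1, 1, 1, 1, 1, 1),
    (2, 1, 2, 1, 3),
)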


def _largest_block_size(chunks: tuple[tuple[int, ...], ...]) -> int:
return math.prod(map(max, chunks))
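In other words, _largest_block_size is the element count of the largest possible block: the product of the per-axis maximum chunk sizes. For example:

# max chunk along axis 0 is 2, along axis 1 is 3, so the largest block holds 6 elements
assert _largest_block_size(((2, 2), (3, 3))) == 6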


def _split_partials(
97 changes: 90 additions & 7 deletions distributed/shuffle/tests/test_rechunk.py
@@ -850,7 +850,8 @@ async def test_rechunk_avoid_needless_chunking(c, s, *ws):
x = da.ones(16, chunks=2)
y = x.rechunk(8, method="p2p")
dsk = y.__dask_graph__()
assert len(dsk) <= 8 + 2
# 8 inputs, 2 concatenations of small inputs, 2 outputs
assert len(dsk) <= 8 + 2 + 2


@pytest.mark.parametrize(
@@ -1340,7 +1341,7 @@ async def test_partial_rechunk_taskgroups(c, s):
),
timeout=5,
)
assert len(s.task_groups) < 6
assert len(s.task_groups) < 7


@pytest.mark.parametrize(
@@ -1354,25 +1355,107 @@
],
)
def test_calculate_prechunking_1d(old, new, expected):
actual = _calculate_prechunking(old, new)
actual = _calculate_prechunking(old, new, np.dtype, None)
assert actual == expected


@pytest.mark.parametrize(
["old", "new", "expected"],
[
[((2, 2), (3, 3)), ((2, 2), (3, 3)), ((2, 2), (3, 3))],
[((2, 2), (3, 3)), ((4,), (3, 3)), ((2, 2), (3, 3))],
[((2, 2), (3, 3)), ((4,), (3, 3)), ((4,), (3, 3))],
[((2, 2), (3, 3)), ((1, 1, 1, 1), (3, 3)), ((2, 2), (3, 3))],
[
((2, 2, 2), (3, 3, 3)),
((1, 2, 2, 1), (2, 3, 4)),
((1, 1, 1, 1, 1, 1), (2, 1, 2, 1, 3)),
((1, 2, 2, 1), (2, 3, 4)),
],
[((1, np.nan), (3, 3)), ((1, np.nan), (2, 2, 2)), ((1, np.nan), (2, 1, 1, 2))],
[((4,), (1, 1, 1)), ((1, 1, 1, 1), (3,)), ((4,), (1, 1, 1))],
[((4,), (1, 1, 1)), ((1, 1, 1, 1), (3,)), ((4,), (3,))],
],
)
def test_calculate_prechunking_2d(old, new, expected):
actual = _calculate_prechunking(old, new)
actual = _calculate_prechunking(old, new, np.dtype(np.int16), None)
assert actual == expected


@pytest.mark.parametrize(
["old", "new", "expected"],
[
(
((2, 2), (1, 1, 1, 1), (1, 1, 1, 1)),
((1, 1, 1, 1), (4,), (2, 2)),
((2, 2), (4,), (1, 1, 1, 1)),
),
(
((2, 2), (1, 1, 1, 1), (1, 1, 1, 1)),
((1, 1, 1, 1), (2, 2), (2, 2)),
((2, 2), (2, 2), (2, 2)),
Collaborator: This one worries me a little bit. The max input chunk is 2 and the max output chunk is 4, but the algorithm concatenates in a way that we end up with 8, which is not great.
Collaborator: Is the block size limit the upper bound here?
),
(
((2, 2), (1, 1, 1, 1), (1, 1, 1, 1)),
((1, 1, 1, 1), (2, 2), (4,)),
((2, 2), (2, 2), (2, 2)),
),
(
((1, 1, 1, 1), (1, 1, 1, 1), (2, 2)),
((2, 2), (4,), (1, 1, 1, 1)),
((2, 2), (2, 2), (2, 2)),
),
],
)
def test_calculate_prechunking_3d(old, new, expected):
with dask.config.set({"array.chunk-size": "16 B"}):
actual = _calculate_prechunking(old, new, np.dtype(np.int16), None)
assert actual == expected


@pytest.mark.parametrize(
["chunk_size", "expected"],
[
("1 B", ((10,), (1,) * 10)),
("20 B", ((10,), (1,) * 10)),
("40 B", ((10,), (2, 2, 1, 2, 2, 1))),
("100 B", ((10,), (5, 5))),
],
)
def test_calculate_prechunking_concatenation(chunk_size, expected):
old = ((10,), (1,) * 10)
new = ((2,) * 5, (5, 5))
with dask.config.set({"array.chunk-size": chunk_size}):
actual = _calculate_prechunking(old, new, np.dtype(np.int16), None)
assert actual == expected
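For reference, the "40 B" row works out roughly as follows (with the 2-byte int16 dtype): 40 B buys 20 elements, which is at least as large as both the largest old block (10 elements) and the largest new block (10 elements), so 20 is the budget. Dividing that budget by the largest width of the other axis (10) leaves room for 2 elements per chunk along the second axis, hence (2, 2, 1) within each partial of five 1-element chunks.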


def test_calculate_prechunking_does_not_concatenate_object_type():
old = ((10,), (1,) * 10)
new = ((2,) * 5, (5, 5))

# Ensure that int dtypes get concatenated
new = ((2,) * 5, (5, 5))
with dask.config.set({"array.chunk-size": "100 B"}):
actual = _calculate_prechunking(old, new, np.dtype(np.int16), None)
assert actual == ((10,), (5, 5))

# Ensure object dtype chunks do not get concatenated
with dask.config.set({"array.chunk-size": "100 B"}):
actual = _calculate_prechunking(old, new, np.dtype(object), None)
assert actual == old


@pytest.mark.parametrize(
["old", "new", "expected"],
[
[((2, 2), (3, 3)), ((4,), (3, 3)), ((2, 2), (3, 3))],
[
((2, 2, 2), (3, 3, 3)),
((1, 2, 2, 1), (2, 3, 4)),
((1, 1, 1, 1, 1, 1), (2, 1, 2, 1, 3)),
],
[((4,), (1, 1, 1)), ((1, 1, 1, 1), (3,)), ((4,), (1, 1, 1))],
],
)
def test_calculate_prechunking_splitting(old, new, expected):
# _calculate_prechunking does not concatenate on object
actual = _calculate_prechunking(old, new, np.dtype(object), None)
assert actual == expected