Improve tests for rechunking (#1532)
hendrikmakait authored Aug 19, 2024
1 parent daee0e0 commit 9de9f3b
Showing 2 changed files with 89 additions and 29 deletions.
34 changes: 34 additions & 0 deletions alembic/versions/59c5cc87c066_drop_outdated_rechunking_data.py
@@ -0,0 +1,34 @@
"""Drop outdated rechunking data
Revision ID: 59c5cc87c066
Revises: e11cd1aaed38
Create Date: 2024-08-16 15:16:27.114045
"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '59c5cc87c066'
down_revision = 'e11cd1aaed38'
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.execute(
        """
        delete from test_run
        where originalname in (
            'test_adjacent_groups',
            'test_heal_oversplit',
            'test_swap_axes',
            'test_tiles_to_rows'
        )
        """
    )


def downgrade() -> None:
    pass
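
For context, a sketch of how this migration could be applied and checked locally; the alembic.ini location, the database URL, and the spot-check query are assumptions for illustration, not part of the repository.

from alembic import command
from alembic.config import Config
from sqlalchemy import create_engine, text

# Run migrations up to this revision (config path is an assumption).
cfg = Config("alembic.ini")
command.upgrade(cfg, "59c5cc87c066")

# Spot-check that the outdated benchmark runs are gone (URL is an assumption).
engine = create_engine("sqlite:///benchmark.db")
with engine.connect() as conn:
    remaining = conn.execute(
        text(
            "select count(*) from test_run "
            "where originalname in ('test_swap_axes', 'test_tiles_to_rows')"
        )
    ).scalar()
assert remaining == 0
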
84 changes: 55 additions & 29 deletions tests/benchmarks/test_rechunk.py
@@ -3,12 +3,16 @@
import dask
import dask.array as da
import pytest
from dask.utils import parse_bytes

from ..conftest import requires_p2p_memory, requires_p2p_rechunk
from ..utils_test import cluster_memory, scaled_array_shape, wait


@pytest.fixture(params=["8.5 MiB", "auto"])
def input_chunk_size(request):
    return request.param


@pytest.fixture(
    params=[
        pytest.param("tasks", marks=pytest.mark.shuffle_tasks),
@@ -18,14 +22,12 @@
        ),
    ]
)
def configure_rechunking(request, memory_multiplier):
def configure_rechunking_in_memory(request):
    if request.param == "tasks":
        with dask.config.set({"array.rechunk.method": "tasks"}):
            yield
    else:
        disk = "disk" in request.param
        if not disk and memory_multiplier > 0.4:
            pytest.skip("Out of memory")
        with dask.config.set(
            {
                "array.rechunk.method": "p2p",
@@ -35,47 +37,73 @@ def configure_rechunking(request, memory_multiplier):
            yield


@pytest.fixture(params=["8 MiB", "128 MiB"])
def configure_chunksize(request, memory_multiplier):
    if memory_multiplier > 0.4 and parse_bytes(request.param) < parse_bytes("64 MiB"):
        pytest.skip("too slow")

    with dask.config.set({"array.chunk-size": request.param}):
        yield
@pytest.fixture(
    params=[
        pytest.param("tasks", marks=pytest.mark.shuffle_tasks),
        pytest.param("p2p", marks=[pytest.mark.shuffle_p2p, requires_p2p_rechunk]),
    ]
)
def configure_rechunking_out_of_core(request):
    if request.param == "tasks":
        with dask.config.set({"array.rechunk.method": "tasks"}):
            yield
    else:
        with dask.config.set(
            {
                "array.rechunk.method": "p2p",
                "distributed.p2p.disk": True,
            }
        ):
            yield
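
The two fixtures above only toggle dask configuration; as a rough, toy-scale illustration (array shape, chunk sizes, and the in-process client are assumptions, not part of the commit), the same settings can be applied directly around a rechunk:

import dask
import dask.array as da
from distributed import Client

client = Client(processes=False)  # p2p rechunking needs a distributed scheduler
x = da.random.random((2_000, 2_000), chunks=(2_000, 100))

# Task-based rechunking, as selected by the "tasks" fixture parameter.
with dask.config.set({"array.rechunk.method": "tasks"}):
    x.rechunk((100, 2_000)).sum().compute()

# P2P rechunking that may spill to disk, as selected by the out-of-core "p2p" parameter.
with dask.config.set({"array.rechunk.method": "p2p", "distributed.p2p.disk": True}):
    x.rechunk((100, 2_000)).sum().compute()

client.close()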


def test_tiles_to_rows(
    # Order matters: don't initialize client when skipping test
    memory_multiplier,
    configure_chunksize,
    configure_rechunking,
    input_chunk_size,
    configure_rechunking_in_memory,
    small_client,
):
    """2D array sliced into square tiles becomes sliced by columns.
    This use case can be broken down into N independent problems.
    In task rechunk, this generates O(N) intermediate tasks and graph edges.
    """
    memory = cluster_memory(small_client)
    shape = scaled_array_shape(memory * memory_multiplier, ("x", "x"))
    shape = scaled_array_shape(memory * 1.5, ("x", "x"))

    a = da.random.random(shape, chunks="auto")
    a = da.random.random(shape, chunks=input_chunk_size)
    a = a.rechunk((-1, "auto")).sum()
    wait(a, small_client, timeout=600)
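
For orientation, a standalone, toy-scale version of the tiles-to-rows pattern described in the docstring; the shape and chunk sizes are arbitrary assumptions, not the benchmark's values.

import dask.array as da

# Square tiles become full-height column slices: each output chunk depends
# only on the input tiles in its own column band, giving N independent
# sub-problems.
a = da.random.random((4_000, 4_000), chunks=(500, 500))  # square tiles
b = a.rechunk((-1, "auto"))                               # full-height column slices
result = b.sum()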


def test_swap_axes(
def test_swap_axes_in_memory(
    # Order matters: don't initialize client when skipping test
    input_chunk_size,
    configure_rechunking_in_memory,
    small_client,
):
    """2D array sliced by columns becomes sliced by rows.
    This is an N-to-N problem, so grouping into sub-problems is impossible.
    In task rechunk, this generates O(N^2) intermediate tasks and graph edges.
    """
    memory = cluster_memory(small_client)
    shape = scaled_array_shape(memory * 0.5, ("x", "x"))

    a = da.random.random(shape, chunks=(-1, input_chunk_size))
    a = a.rechunk(("auto", -1)).sum()
    wait(a, small_client, timeout=600)
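
Likewise, a toy-scale version of the swap-axes pattern; sizes are arbitrary assumptions.

import dask.array as da

# Column slices become row slices: every output chunk overlaps every input
# chunk, so the problem cannot be split into independent groups (N-to-N).
a = da.random.random((4_000, 4_000), chunks=(-1, 500))  # sliced by columns
b = a.rechunk(("auto", -1))                              # sliced by rows
result = b.sum()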


def test_swap_axes_out_of_core(
    # Order matters: don't initialize client when skipping test
    memory_multiplier,
    configure_chunksize,
    configure_rechunking,
    configure_rechunking_out_of_core,
    small_client,
):
    """2D array sliced by columns becomes sliced by rows.
    This is an N-to-N problem, so grouping into sub-problems is impossible.
    In task rechunk, this generates O(N^2) intermediate tasks and graph edges.
    """
    memory = cluster_memory(small_client)
    shape = scaled_array_shape(memory * memory_multiplier, ("x", "x"))
    shape = scaled_array_shape(memory * 1.5, ("x", "x"))

    a = da.random.random(shape, chunks=(-1, "auto"))
    a = a.rechunk(("auto", -1)).sum()
@@ -84,33 +112,31 @@ def test_swap_axes(

def test_adjacent_groups(
    # Order matters: don't initialize client when skipping test
    memory_multiplier,
    configure_chunksize,
    configure_rechunking,
    input_chunk_size,
    configure_rechunking_in_memory,
    small_client,
):
    """M-to-N use case, where each input task feeds into a localized but substantial
    subset of the output tasks, with partial interaction between adjacent zones.
    """
    memory = cluster_memory(small_client)
    shape = scaled_array_shape(memory * memory_multiplier, ("x", 10, 10_000))
    shape = scaled_array_shape(memory * 1.5, ("x", 10, 10_000))

    a = da.random.random(shape, chunks=("auto", 2, 5_000))
    a = da.random.random(shape, chunks=(input_chunk_size, 2, 5_000))
    a = a.rechunk(("auto", 5, 10_000)).sum()
    wait(a, small_client, timeout=600)
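
A toy-scale version of the adjacent-groups pattern, reusing the benchmark's chunk proportions along the last two axes but with an assumed, much smaller shape.

import dask.array as da

# Each output chunk is fed by a small, localized group of input chunks, with
# partial overlap between neighbouring groups along the second axis (blocks
# of 2 regrouped into blocks of 5).
a = da.random.random((1_000, 10, 10_000), chunks=(100, 2, 5_000))
b = a.rechunk(("auto", 5, 10_000))
result = b.sum()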


def test_heal_oversplit(
    # Order matters: don't initialize client when skipping test
    memory_multiplier,
    configure_rechunking,
    configure_rechunking_in_memory,
    small_client,
):
    """rechunk() is used to heal a situation where chunks are too small.
    This is a trivial N-to-1 reduction step that gets no benefit from p2p rechunking.
    """
    memory = cluster_memory(small_client)
    shape = scaled_array_shape(memory * memory_multiplier, ("x", "x"))
    shape = scaled_array_shape(memory * 1.5, ("x", "x"))
    # Avoid exact n:1 rechunking, which would be a simpler special case.
    # Dask should be smart enough to avoid splitting input chunks out to multiple output
    # chunks.
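
For orientation, a toy-scale sketch of the oversplit-healing pattern; array and chunk sizes are assumptions, chosen so the merge is not an exact n:1 special case.

import dask.array as da

# Many tiny chunks are merged into larger ones; 256 is deliberately not a
# multiple of 100, so this is not a plain n:1 merge.
a = da.random.random((4_000, 4_000), chunks=(100, 100))  # oversplit input
b = a.rechunk((256, 256))                                 # healed, coarser chunks
result = b.sum()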
