Skip to content

Commit

Permalink
Fix concat form comparison
Browse files — browse the repository at this point in the history
  • Loading branch information
martindurant committed Jul 31, 2024
1 parent d1922ab commit fc9589b
Showing 1 changed file with 6 additions and 17 deletions.
23 changes: 6 additions & 17 deletions src/dask_awkward/lib/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from awkward.operations.ak_concatenate import (
enforce_concatenated_form as enforce_layout_to_concatenated_form,
)
from awkward.typetracer import typetracer_from_form
from dask.base import tokenize
from dask.highlevelgraph import HighLevelGraph

Expand Down Expand Up @@ -65,29 +64,18 @@ def concatenate(
if len(metas) == 0:
raise ValueError("Need at least one array to concatenate")

# Are we performing a _logical_ concatenation?
if axis == 0:
# There are two possible cases here:
# 1. all arrays have identical metas — just grow the Dask collection
# 2. some arrays have different metas — coerce arrays to same form

# Drop reports from metas to avoid later touching any buffers
metas_no_report = [
typetracer_from_form(x.layout.form, behavior=x.behavior, attrs=x._attrs)
for x in metas
]
# Concatenate metas to determine result form
meta_no_report = ak.concatenate(
metas_no_report, axis=0, behavior=behavior, attrs=attrs
)
intended_form = meta_no_report.layout.form
intended_form = metas[0].layout.form

# If any forms aren't equal to this form, we must enforce each form to the same type
if any(
not m.layout.form.is_equal_to(
intended_form, all_parameters=True, form_key=False
)
for m in metas_no_report[1:]
for m in metas[1:]
):
arrays = [
map_partitions(
Expand Down Expand Up @@ -122,13 +110,14 @@ def concatenate(

aml = AwkwardMaterializedLayer(g, previous_layer_names=[arrays[0].name])

new_meta = ak.copy(metas[0])
new_meta._report = report
hlg = HighLevelGraph.from_collections(name, aml, dependencies=arrays)
meta_no_report._report = report
aml.meta = meta_no_report
aml.meta = new_meta
return new_array_object(
hlg,
name,
meta=meta_no_report,
meta=new_meta,
npartitions=sum(a.npartitions for a in arrays),
)

Expand Down

0 comments on commit fc9589b

Please sign in to comment.