Skip to content

Commit

Permalink
grr
Browse files Browse the repository at this point in the history
  • Loading branch information
martindurant committed Aug 21, 2024
1 parent a11da62 commit 38e95c5
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions src/dask_awkward/lib/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,28 +339,30 @@ def rewrite_layer_chains(dsk: HighLevelGraph, keys: Sequence[Key]) -> HighLevelG
deps[outkey] = deps[chain[0]] # type: ignore
[deps.pop(ch) for ch in chain[:-1]] # type: ignore

subgraph = layer0.dsk.copy()
subgraph = layer0.dsk.copy() # mypy: ignore
indices = list(layer0.indices)
parent = chain[0]

outlayer.io_deps = layer0.io_deps
outlayer.io_deps = layer0.io_deps # mypy: ignore
for chain_member in chain[1:]:
layer = dsk.layers[chain_member]
for k in layer.io_deps:
for k in layer.io_deps: # mypy: ignore
outlayer.io_deps[k] = layer.io_deps[k] # type: ignore
func, *args = layer.dsk[chain_member]
func, *args = layer.dsk[chain_member] # mypy: ignore
args2 = _recursive_replace(args, layer, parent, indices)
subgraph[chain_member] = (func,) + tuple(args2)
parent = chain_member
outlayer.numblocks = {i[0]: (numblocks,) for i in indices if i[1] is not None}
outlayer.dsk = subgraph
outlayer.numblocks = {
i[0]: (numblocks,) for i in indices if i[1] is not None
} # mypy: ignore
outlayer.dsk = subgraph # mypy: ignore
if hasattr(outlayer, "_dims"):
del outlayer._dims
outlayer.indices = tuple(
outlayer.indices = tuple( # mypy: ignore
(i[0], (".0",) if i[1] is not None else None) for i in indices
)
outlayer.output_indices = (".0",)
outlayer.inputs = getattr(layer0, "inputs", set())
outlayer.output_indices = (".0",) # mypy: ignore
outlayer.inputs = getattr(layer0, "inputs", set()) # mypy: ignore
if hasattr(outlayer, "_cached_dict"):
del outlayer._cached_dict # reset, since original can be mutated
return HighLevelGraph(layers, deps)
Expand Down

0 comments on commit 38e95c5

Please sign in to comment.