Skip to content

Commit

Permalink
change partition
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Sep 9, 2024
1 parent 0819607 commit b2907b4
Showing 1 changed file with 68 additions and 59 deletions.
127 changes: 68 additions & 59 deletions python/dgl/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,25 +88,23 @@ def _dump_part_config(part_config, part_metadata):
json.dump(part_metadata, outfile, sort_keys=False, indent=4)


def _process_partitions(g_list, formats=None, sort_etypes=False):
def _process_partitions(g, formats=None, sort_etypes=False):
"""Preprocess partitions before saving:
1. format data types.
2. sort csc/csr by tag.
"""
for g in g_list:
for k, dtype in RESERVED_FIELD_DTYPE.items():
if k in g.ndata:
g.ndata[k] = F.astype(g.ndata[k], dtype)
if k in g.edata:
g.edata[k] = F.astype(g.edata[k], dtype)
for g in g_list:
if (not sort_etypes) or (formats is None):
continue
for k, dtype in RESERVED_FIELD_DTYPE.items():
if k in g.ndata:
g.ndata[k] = F.astype(g.ndata[k], dtype)
if k in g.edata:
g.edata[k] = F.astype(g.edata[k], dtype)

if (sort_etypes) and (formats is not None):
if "csr" in formats:
g = sort_csr_by_tag(g, tag=g.edata[ETYPE], tag_type="edge")
if "csc" in formats:
g = sort_csc_by_tag(g, tag=g.edata[ETYPE], tag_type="edge")
return g_list
return g


def _save_dgl_graphs(filename, g_list, formats=None):
Expand Down Expand Up @@ -475,6 +473,8 @@ def load_partition_book(part_config, part_id, part_metadata=None):
The path of the partition config file.
part_id : int
The partition ID.
part_metadata : dict
The meta data of partition.
Returns
-------
Expand Down Expand Up @@ -684,7 +684,7 @@ def _partition_to_graphbolt(
graph_formats=None,
):
gpb, _, ntypes, etypes = load_partition_book(
part_config, part_i, part_metadata
part_config=part_config, part_id=part_i, part_metadata=part_metadata
)
graph = parts[part_i]
csc_graph = gb_convert_single_dgl_partition(
Expand All @@ -698,7 +698,9 @@ def _partition_to_graphbolt(
store_inner_node=store_inner_node,
graph_formats=graph_formats,
)
rel_path_result = _save_graph_gb(part_config, part_i, csc_graph)
rel_path_result = _save_graph_gb(
part_config=part_config, part_id=part_i, csc_graph=csc_graph
)
part_metadata[f"part-{part_i}"]["part_graph_graphbolt"] = rel_path_result


Expand Down Expand Up @@ -1312,9 +1314,9 @@ def get_homogeneous(g, balance_ntypes):
for name in g.edges[etype].data:
if name in [EID, "inner_edge"]:
continue

Check warning on line 1316 in python/dgl/distributed/partition.py

View workflow job for this annotation

GitHub Actions / lintrunner

UFMT format

Run `lintrunner -a` to apply this patch.
edge_feats[
_etype_tuple_to_str(etype) + "/" + name
] = F.gather_row(g.edges[etype].data[name], local_edges)
edge_feats[_etype_tuple_to_str(etype) + "/" + name] = (
F.gather_row(g.edges[etype].data[name], local_edges)
)
else:
for ntype in g.ntypes:
if len(g.ntypes) > 1:
Expand Down Expand Up @@ -1349,9 +1351,9 @@ def get_homogeneous(g, balance_ntypes):
for name in g.edges[etype].data:
if name in [EID, "inner_edge"]:
continue
edge_feats[
_etype_tuple_to_str(etype) + "/" + name
] = F.gather_row(g.edges[etype].data[name], local_edges)
edge_feats[_etype_tuple_to_str(etype) + "/" + name] = (
F.gather_row(g.edges[etype].data[name], local_edges)
)
# delete `orig_id` from ndata/edata
del part.ndata["orig_id"]
del part.edata["orig_id"]
Expand All @@ -1369,7 +1371,7 @@ def get_homogeneous(g, balance_ntypes):
"edge_feats": os.path.relpath(edge_feat_file, out_path),
}
sort_etypes = len(g.etypes) > 1
part = _process_partitions([part], graph_formats, sort_etypes)[0]
part = _process_partitions(part, graph_formats, sort_etypes)
if use_graphbolt:
# save FusedCSCSamplingGraph
kwargs["graph_formats"] = graph_formats
Expand All @@ -1383,9 +1385,9 @@ def get_homogeneous(g, balance_ntypes):
)
else:
part_graph_file = os.path.join(part_dir, "graph.dgl")
part_metadata["part-{}".format(part_id)][
"part_graph"
] = os.path.relpath(part_graph_file, out_path)
part_metadata["part-{}".format(part_id)]["part_graph"] = (
os.path.relpath(part_graph_file, out_path)
)
# save DGLGraph
_save_dgl_graphs(
part_graph_file,
Expand Down Expand Up @@ -1468,17 +1470,16 @@ def _load_part(part_config, part_id, parts=None):


def _save_graph_gb(part_config, part_id, csc_graph):
orig_feats_path = os.path.join(
csc_graph_save_dir = os.path.join(
os.path.dirname(part_config),
f"part{part_id}",
)
csc_graph_path = os.path.join(
orig_feats_path, "fused_csc_sampling_graph.pt"
csc_graph_save_dir, "fused_csc_sampling_graph.pt"
)
torch.save(csc_graph, csc_graph_path)

return os.path.relpath(csc_graph_path, os.path.dirname(part_config))
# Update graph path.


def cast_various_to_minimum_dtype_gb(
Expand Down Expand Up @@ -1686,8 +1687,35 @@ def gb_convert_single_dgl_partition(
return csc_graph


def convert_partition_to_graphbolt_multi_process(
part_config,
part_id,
graph_formats,
store_eids,
store_inner_node,
store_inner_edge,
):
gpb, _, ntypes, etypes = load_partition_book(
part_config=part_config, part_id=part_id
)
part = _load_part(part_config, part_id)
part_meta = copy.deepcopy(_load_part_config(part_config))
csc_graph = gb_convert_single_dgl_partition(
graph=part,
ntypes=ntypes,
etypes=etypes,
gpb=gpb,
part_meta=part_meta,
graph_formats=graph_formats,
store_eids=store_eids,
store_inner_node=store_inner_node,
store_inner_edge=store_inner_edge,
)
rel_path = _save_graph_gb(part_config, part_id, csc_graph)
return rel_path


def _convert_partition_to_graphbolt(
part_meta,
graph_formats,
part_config,
store_eids,
Expand All @@ -1706,11 +1734,9 @@ def _convert_partition_to_graphbolt(
# We can simply pass None to it.

# Iterate over partitions.
if part_meta is None:
part_meta = _load_part_config(part_config)
convert_with_format = partial(
gb_convert_single_dgl_partition,
part_meta=part_meta,
convert_partition_to_graphbolt_multi_process,
part_config=part_config,
graph_formats=graph_formats,
store_eids=store_eids,
store_inner_node=store_inner_node,
Expand All @@ -1727,31 +1753,16 @@ def _convert_partition_to_graphbolt(
mp_context=mp_ctx,
) as executor:
for part_id in range(num_parts):
gpb, _, ntypes, etypes = load_partition_book(
part_config, part_id
rel_path_results.append(
executor.submit(part_id=part_id).result()
)
part = _load_part(part_config, part_id)
csc_graph = executor.submit(
convert_with_format,
graph=part,
ntypes=ntypes,
etypes=etypes,
gpb=gpb,
).result()
rel_path = _save_graph_gb(part_config, part_id, csc_graph)
rel_path_results.append(rel_path)

else:
# If running single-threaded, avoid spawning new interpreter, which is slow
for part_id in range(num_parts):
gpb, _, ntypes, etypes = load_partition_book(part_config, part_id)
part = _load_part(part_config, part_id)
csc_graph = convert_with_format(
graph=part, ntypes=ntypes, etypes=etypes, gpb=gpb
)
rel_path = _save_graph_gb(part_config, part_id, csc_graph)
rel_path = convert_with_format(part_id=part_id)
rel_path_results.append(rel_path)

part_meta = _load_part_config(part_config)
for part_id in range(num_parts):
# Update graph path.
part_meta[f"part-{part_id}"]["part_graph_graphbolt"] = rel_path_results[
Expand Down Expand Up @@ -1813,16 +1824,14 @@ def dgl_partition_to_graphbolt(
" will be saved to the new format."
)
part_meta = _load_part_config(part_config)
new_part_meta = copy.deepcopy(part_meta)
num_parts = part_meta["num_parts"]
part_meta = _convert_partition_to_graphbolt(
new_part_meta,
graph_formats,
part_config,
store_eids,
store_inner_node,
store_inner_edge,
n_jobs,
num_parts,
graph_formats=graph_formats,
part_config=part_config,
store_eids=store_eids,
store_inner_node=store_inner_node,
store_inner_edge=store_inner_edge,
n_jobs=n_jobs,
num_parts=num_parts,
)
_dump_part_config(part_config, part_meta)

0 comments on commit b2907b4

Please sign in to comment.