Skip to content

Commit

Permalink
fix pagerank dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
Binh Vu committed Sep 7, 2023
1 parent 3f89215 commit 92fe071
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 14 deletions.
23 changes: 13 additions & 10 deletions kgdata/wikidata/datasets/entity_pagerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,15 @@ def entity_pagerank() -> Dataset[EntityPageRank]:
.sortBy(lambda x: (x[0], int(x[1:]))) # type: ignore
.zipWithIndex()
.map(tab_ser)
.save_like_dataset(idmap_ds, checksum=False)
.save_like_dataset(
idmap_ds,
checksum=False,
auto_coalesce=True,
max_num_partitions=512,
)
)
    # write the total number of entities
(cfg.entity_pagerank / "idmap.txt").write_text(str(idmap_ds.get_rdd().count()))

edges_dataset = Dataset(
cfg.entity_pagerank / "graph/*.gz",
Expand Down Expand Up @@ -137,12 +144,10 @@ def create_edges_npy(infiles: List[str], outfile: str):
)
if not pagerank_ds.has_complete_data():
assert does_result_dir_exist(
cfg.entity_pagerank / "graphtool_pagerank_en", allow_override=False
cfg.entity_pagerank / "graphtool_pagerank", allow_override=False
), "Must run graph-tool pagerank at `kgdata/scripts/pagerank_v2.py` first"

n_files = len(
glob(str(cfg.entity_pagerank / "graphtool_pagerank_en" / "*.npz"))
)
n_files = len(glob(str(cfg.entity_pagerank / "graphtool_pagerank" / "*.npz")))

def deserialize_np(dat: bytes) -> List[Tuple[int, float]]:
f = BytesIO(dat)
Expand All @@ -158,7 +163,7 @@ def process_join(

(
ExtendedRDD.binaryFiles(
cfg.entity_pagerank / "graphtool_pagerank_en" / "*.npz",
cfg.entity_pagerank / "graphtool_pagerank" / "*.npz",
)
.repartition(n_files)
.flatMap(lambda x: deserialize_np(x[1]))
Expand All @@ -170,9 +175,7 @@ def process_join(

pagerank_stat_outfile = cfg.entity_pagerank / f"pagerank.pkl"
if not pagerank_stat_outfile.exists():
n_files = len(
glob(str(cfg.entity_pagerank / "graphtool_pagerank_en" / "*.npz"))
)
n_files = len(glob(str(cfg.entity_pagerank / "graphtool_pagerank" / "*.npz")))

def deserialize_np2(dat: bytes) -> np.ndarray:
f = BytesIO(dat)
Expand All @@ -182,7 +185,7 @@ def deserialize_np2(dat: bytes) -> np.ndarray:
rdd = (
get_spark_context()
.binaryFiles(
str(cfg.entity_pagerank / "graphtool_pagerank_en" / "*.npz"),
str(cfg.entity_pagerank / "graphtool_pagerank" / "*.npz"),
)
.repartition(n_files)
.map(lambda x: deserialize_np2(x[1]))
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "kgdata"
version = "5.0.0a6"
version = "5.0.0a7"
description = "Library to process dumps of knowledge graphs (Wikipedia, DBpedia, Wikidata)"
readme = "README.md"
authors = [{ name = "Binh Vu", email = "binh@toan2.com" }]
Expand Down
6 changes: 3 additions & 3 deletions scripts/pagerank_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@

wd_dir = Path(os.environ["WD_DIR"])
working_dir = wd_dir / "entity_pagerank"
pagerank_outdir = working_dir / "graphtool_pagerank_en"
edge_files = sorted(glob.glob(str((working_dir / "graphtool_en" / "part-*.npz"))))
pagerank_outdir = working_dir / "graphtool_pagerank"
edge_files = sorted(glob.glob(str((working_dir / "graphtool" / "part-*.npz"))))

logger.info("Creating graph from data...")
g = Graph(directed=True)
n_vertices = int((working_dir / "idmap_en.txt").read_text())
n_vertices = int((working_dir / "idmap.txt").read_text())
g.add_vertex(n=n_vertices)
logger.info("Creating graph... added vertices")

Expand Down

0 comments on commit 92fe071

Please sign in to comment.