fix bugs
Binh Vu committed Jan 22, 2024
1 parent d497560 commit 0575363
Showing 17 changed files with 226 additions and 35 deletions.
4 changes: 2 additions & 2 deletions kgdata/dataset.py
@@ -32,7 +32,7 @@
from kgdata.config import init_dbdir_from_env
from kgdata.misc.query import PropQuery, every
from kgdata.spark import ExtendedRDD, SparkLikeInterface, get_spark_context
from kgdata.spark.common import does_result_dir_exist
from kgdata.spark.common import does_result_dir_exist, text_file
from kgdata.spark.extended_rdd import DatasetSignature

V = TypeVar("V")
@@ -95,7 +95,7 @@ def get_files(
return files

def get_rdd(self) -> RDD[T_co]:
rdd = get_spark_context().textFile(str(self.file_pattern))
rdd = text_file(Path(self.file_pattern))
if self.prefilter is not None:
rdd = rdd.filter(self.prefilter)

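This change routes Dataset.get_rdd through the zstd-aware text_file helper (added in kgdata/spark/common.py below) instead of SparkContext.textFile. A minimal sketch of what that enables, with a hypothetical file pattern:

from pathlib import Path

from kgdata.spark.common import text_file

# the same call now handles both plain/gzip text and part-*.zst partitions
rdd = text_file(Path("/data/example/part-*.zst"))  # hypothetical pattern
print(rdd.take(1))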
1 change: 0 additions & 1 deletion kgdata/dbpedia/datasets/entity_types.py
@@ -24,7 +24,6 @@ def entity_types(lang: str = "en") -> Dataset[tuple[str, list[str]]]:
.get_extended_rdd()
.map(get_instanceof)
.map(orjson.dumps)
.auto_coalesce(cache=True)
.save_like_dataset(ds, auto_coalesce=True, shuffle=True)
)

2 changes: 1 addition & 1 deletion kgdata/dbpedia/datasets/generic_extractor_dump.py
@@ -41,7 +41,7 @@ def generic_extractor_dump(lang: str = "en") -> Dataset[RDFResource]:
)

(
ExtendedRDD.textFile(str(split_dump_dir / "*/*.gz"))
ExtendedRDD.textFile(split_dump_dir / "*/*.gz")
.filter(ignore_comment)
.map(ntriple_loads)
.groupBy(lambda x: x[0])
2 changes: 1 addition & 1 deletion kgdata/dbpedia/datasets/mapping_extractor_dump.py
@@ -40,7 +40,7 @@ def mapping_extractor_dump(lang: str = "en") -> Dataset[RDFResource]:
)

(
ExtendedRDD.textFile(str(split_dump_dir / "*/*.gz"))
ExtendedRDD.textFile(split_dump_dir / "*/*.gz")
.filter(ignore_comment)
.map(ntriple_loads)
.groupBy(lambda x: x[0])
5 changes: 3 additions & 2 deletions kgdata/dbpedia/datasets/ontology_dump.py
@@ -6,13 +6,14 @@
from functools import lru_cache
from typing import Any, Callable, Iterable

from rdflib import OWL, RDF, RDFS, BNode, URIRef

from kgdata.dataset import Dataset
from kgdata.dbpedia.config import DBpediaDirCfg
from kgdata.misc.ntriples_parser import Triple, ignore_comment, ntriple_loads
from kgdata.misc.resource import RDFResource
from kgdata.spark import ExtendedRDD
from kgdata.splitter import split_a_file, split_a_list
from rdflib import OWL, RDF, RDFS, BNode, URIRef

rdf_type = str(RDF.type)
rdfs_label = str(RDFS.label)
@@ -65,7 +66,7 @@ def ontology_dump() -> Dataset[RDFResource]:
)

(
ExtendedRDD.textFile(str(step1_dir / "*.gz"))
ExtendedRDD.textFile(step1_dir / "*.gz")
.filter(ignore_comment)
.map(ntriple_loads)
.groupBy(lambda x: x[0])
2 changes: 1 addition & 1 deletion kgdata/dbpedia/datasets/redirection_dump.py
@@ -29,7 +29,7 @@ def redirection_dump(lang: str = "en"):
)

(
ExtendedRDD.textFile(str(cfg.redirection_dump / f"raw-{lang}/*.gz"))
ExtendedRDD.textFile(cfg.redirection_dump / f"raw-{lang}/*.gz")
.filter(ignore_comment)
.map(ntriple_loads)
.map(norm_redirection) # extracted redirection (source -> target)
6 changes: 4 additions & 2 deletions kgdata/misc/funcs.py
@@ -1,7 +1,7 @@
from __future__ import annotations

import importlib
from io import BytesIO
from io import BufferedReader, BytesIO, TextIOWrapper
from typing import Type

import zstandard as zstd
@@ -46,4 +46,6 @@ def import_attr(attr_ident: str):
def deser_zstd_records(dat: bytes):
cctx = zstd.ZstdDecompressor()
datobj = BytesIO(dat)
return [x.decode() for x in cctx.stream_reader(datobj).readall().splitlines()]
# splitlines (unlike readlines) drops the trailing \n from each record,
# which keeps the output identical to Spark's textFile
return cctx.stream_reader(datobj).readall().splitlines()
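For illustration (not part of the diff), a small round-trip assuming only the zstandard package, showing why splitlines is the right call here: each record comes back as bytes without its trailing newline and without an empty final entry, matching what Spark's textFile yields for text input.

from io import BytesIO

import zstandard as zstd

payload = zstd.ZstdCompressor().compress(b"record-1\nrecord-2\n")
raw = zstd.ZstdDecompressor().stream_reader(BytesIO(payload)).readall()
assert raw.splitlines() == [b"record-1", b"record-2"]  # no empty record for the final \n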
96 changes: 84 additions & 12 deletions kgdata/spark/common.py
@@ -1,6 +1,7 @@
"""Utility functions for Apache Spark."""
from __future__ import annotations

import hashlib
import math
import os
import random
@@ -20,6 +21,7 @@
Tuple,
TypeVar,
Union,
cast,
)

import orjson
@@ -32,7 +34,7 @@

# SparkContext singleton
_sc = None

StrPath = Union[Path, str]

R1 = TypeVar("R1")
R2 = TypeVar("R2")
@@ -506,26 +508,47 @@ def save_as_text_file(

def save_partition(partition: Iterable[str] | Iterable[bytes]):
partition_id = assert_not_null(TaskContext.get()).partitionId()
lst = []
it = iter(partition)
first_val = next(it, None)
if first_val is None:
# empty partition
return

lst = []

if isinstance(first_val, str):
lst.append(first_val.encode())
first_val = first_val.encode()
if not first_val.endswith(b"\n"):
lst.append(first_val + b"\n")
else:
lst.append(first_val)

for x in it:
lst.append(x.encode()) # type: ignore
x = x.encode() # type: ignore
if not x.endswith(b"\n"):
x = x + b"\n"
lst.append(x)
else:
lst.append(first_val)
if not first_val.endswith(b"\n"):
lst.append(first_val + b"\n")
else:
lst.append(first_val)
for x in it:
if not x.endswith(b"\n"):
x = x + b"\n"
lst.append(x)
datasize = sum(len(x) + 1 for x in lst) # 1 for newline

lst[-1] = lst[-1][:-1] # exclude last \n

datasize = sum(len(x) for x in lst)
cctx = zstd.ZstdCompressor(level=compression_level, write_content_size=True)

with open(outdir / f"part-{partition_id:05d}.zst", "wb") as fh:
with cctx.stream_writer(fh, size=datasize) as f:
for record in lst:
f.write(record)
f.write(b"\n")
for x in lst:
f.write(x)

outdir.mkdir(parents=True, exist_ok=True)
rdd.foreachPartition(save_partition)
(outdir / "_SUCCESS").touch()
return
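The reworked save_partition forces every record to end with a newline and then strips the final one, so the concatenated partition decompresses back into the same records with no empty trailing entry. A tiny worked example of that layout (illustration only, not the commit's code):

records = [b"a", b"b\n", b"c"]  # callers may or may not append the newline themselves
normalized = [r if r.endswith(b"\n") else r + b"\n" for r in records]
normalized[-1] = normalized[-1][:-1]  # exclude last \n, as in save_partition
assert b"".join(normalized) == b"a\nb\nc"
assert b"".join(normalized).splitlines() == [b"a", b"b", b"c"]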
@@ -534,7 +557,7 @@ def save_partition(partition: Iterable[str] | Iterable[bytes]):


def text_file(
filepattern: Path, min_partitions: Optional[int] = None, use_unicode: bool = True
filepattern: StrPath, min_partitions: Optional[int] = None, use_unicode: bool = True
):
"""Drop-in replacement for SparkContext.textFile that supports zstd files."""
filepattern = Path(filepattern)
@@ -546,15 +569,64 @@
for file in filepattern.iterdir()
)
) or filepattern.name.endswith(".zst"):
if filepattern.is_dir():
n_parts = sum(
1 for file in filepattern.iterdir() if file.name.startswith("part-")
)
else:
n_parts = sum(1 for _ in filepattern.parent.glob("*.zst"))

return (
get_spark_context()
.binaryFiles(str(filepattern), min_partitions)
.flatMap(lambda x: deser_zstd_records(x[1]))
.binaryFiles(str(filepattern))
.repartition(n_parts)
.flatMap(lambda x: deser_zstd_records(x[1]), preservesPartitioning=True)
)

return get_spark_context().textFile(str(filepattern), min_partitions, use_unicode)
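The zstd branch now counts the part files and repartitions to that number: binaryFiles tends to pack many small files into few partitions, and the repartition restores roughly the parallelism the dataset was written with. A hedged usage sketch, with a hypothetical directory of part-*.zst files:

from kgdata.spark.common import text_file

rdd = text_file("/data/wikidata/entities/en")  # hypothetical dir of part-00000.zst, part-00001.zst, ...
print(rdd.getNumPartitions(), rdd.take(1))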


def diff_rdd(rdd1: RDD[str], rdd2: RDD[str], key: Callable[[str], str]):
"""Compare content of two RDDs
Parameters
----------
rdd1 : RDD[str]
first RDD
rdd2 : RDD[str]
second RDD
key : Callable[[str], str]
function that extracts the join key from a record

Notes
-----
Prints up to 100 records whose content differs between the two RDDs; returns None.
"""

def convert(x):
k = key(x)
if not isinstance(x, bytes):
x = x.encode()
return k, hashlib.sha256(x).digest().hex()

max_size = 100
records = (
rdd1.map(convert)
.fullOuterJoin(rdd2.map(convert))
.filter(lambda x: x[1][0] != x[1][1])
.take(max_size)
)
if len(records) == 0:
print("No difference")
return
print(
f"Found {'at least ' if len(records) >= max_size else ''}{len(records)} differences:"
)
for r in records:
print(r[0], r[1][0], r[1][1])
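A hedged example of calling the new diff_rdd on two line-oriented datasets, keyed by a JSON id field (paths and field name are hypothetical):

import orjson

from kgdata.spark.common import diff_rdd, get_spark_context

sc = get_spark_context()
old = sc.textFile("/data/dataset_v1/*.gz")  # hypothetical inputs
new = sc.textFile("/data/dataset_v2/*.gz")
diff_rdd(old, new, lambda line: orjson.loads(line)["id"])  # prints up to 100 differing records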


@dataclass
class EmptyBroadcast(Generic[V]):
value: V
13 changes: 8 additions & 5 deletions kgdata/spark/extended_rdd.py
@@ -28,6 +28,7 @@

from kgdata.misc.funcs import deser_zstd_records
from kgdata.spark.common import (
StrPath,
are_records_unique,
estimate_num_partitions,
get_spark_context,
@@ -57,7 +58,6 @@ def __lt__(self, other: SupportsOrdering) -> bool:


S = TypeVar("S", bound=SupportsOrdering)
StrPath = Union[Path, str]
NEW_DATASET_NAME = "__new__"
NO_CHECKSUM = (b"\x00" * 32).hex()

@@ -335,9 +335,11 @@ def save_as_dataset(
save_as_text_file(self.rdd, Path(outdir), compression, compression_level)
else:
tmp_dir = str(outdir) + "_tmp"
if os.path.exists(tmp_dir):
shutil.rmtree(tmp_dir)
save_as_text_file(self.rdd, Path(tmp_dir), compression, compression_level)

rdd = text_file(Path(tmp_dir))
rdd = text_file(tmp_dir)
num_partitions = math.ceil(
sum((os.path.getsize(file) for file in glob.glob(tmp_dir + "/part-*")))
/ partition_size
@@ -356,9 +358,7 @@ def save_partition(partition: Iterable[str] | Iterable[bytes]):
name = name or os.path.basename(outdir)
if checksum:
# compute checksum and save it to a file -- reload from the file so we do not have to process the data again.
ds_checksum = ExtendedRDD(
get_spark_context().textFile(outdir), self.sig
).hash()
ds_checksum = ExtendedRDD(text_file(outdir), self.sig).hash()
else:
ds_checksum = b"\x00" * 32

@@ -543,6 +543,9 @@ def parallelize(
def take(self: ExtendedRDD[T], num: int) -> list[T]:
return self.rdd.take(num)

def count(self) -> int:
return self.rdd.count()

def union(self: ExtendedRDD[T], other: ExtendedRDD[U]) -> ExtendedRDD[T | U]:
return ExtendedRDD(self.rdd.union(other.rdd), self.sig.use(other.sig))

14 changes: 12 additions & 2 deletions kgdata/wikidata/datasets/entities.py
@@ -58,7 +58,7 @@ def entities(lang: str = "en") -> Dataset[WDEntity]:
dependencies=[entity_dump()],
)
fixed_ds = Dataset(
cfg.entities / lang / "*.gz",
cfg.entities / lang / "*.zst",
deserialize=deser_entity,
name=f"entities/{lang}/fixed",
dependencies=[entity_dump(), entity_ids(), entity_redirections()],
@@ -130,6 +130,7 @@ def entities(lang: str = "en") -> Dataset[WDEntity]:
auto_coalesce=True,
shuffle=True,
trust_dataset_dependencies=True,
compression_level=9,
)
)
need_verification = True
@@ -300,5 +301,14 @@ def extract_invalid_qualifier(ent: WDEntity) -> Optional[WDEntity]:


if __name__ == "__main__":
WikidataDirCfg.init("~/kgdata/wikidata/20211213")
WikidataDirCfg.init("/var/tmp/kgdata/wikidata/20230619")
# from sm.misc.ray_helper import ray_map

# def deser(infile):
# (infile, "rb")

# ray_map(
# (WikidataDirCfg.get_instance().entities / "en").glob("*.zst"),
# )

print("Total:", entities().get_rdd().count())
2 changes: 1 addition & 1 deletion kgdata/wikidata/datasets/entity_dump.py
@@ -60,4 +60,4 @@ def _record_postprocess(record: str):

if __name__ == "__main__":
WikidataDirCfg.init("/var/tmp/kgdata/wikidata/20230619")
entity_dump()
print(entity_dump().get_extended_rdd().count())
1 change: 1 addition & 0 deletions kgdata/wikidata/datasets/entity_redirections.py
@@ -152,6 +152,7 @@ def entity_redirections() -> Dataset[tuple[str, str]]:
dataset=unk_target_ds,
checksum=False,
auto_coalesce=True,
trust_dataset_dependencies=True,
)
)

7 changes: 4 additions & 3 deletions kgdata/wikidata/datasets/meta_graph.py
@@ -5,23 +5,24 @@
from typing import Dict, Iterable, List, Optional, Tuple, TypeAlias, Union

import orjson
from sm.misc.funcs import filter_duplication

from kgdata.dataset import Dataset
from kgdata.wikidata.config import WikidataDirCfg
from kgdata.wikidata.datasets.entities import entities
from kgdata.wikidata.datasets.entity_outlinks import entity_outlinks
from kgdata.wikidata.datasets.entity_types import entity_types
from kgdata.wikidata.models.wdentity import WDEntity
from kgdata.wikidata.models.wdvalue import WDValue, WDValueKind
from sm.misc.funcs import filter_duplication


def meta_graph():
cfg = WikidataDirCfg.get_instance()

ds = Dataset(
cfg.entity_types / "*.gz",
cfg.meta_graph / "*.gz",
deserialize=orjson.loads,
name="entity-types",
name="meta-graph",
dependencies=[entities(), entity_outlinks(), entity_types()],
)

Empty file.
2 changes: 1 addition & 1 deletion scripts/build.sh
@@ -75,7 +75,7 @@ wikidata_dataset entity_outlinks
wikidata_dataset entity_pagerank
wikidata_dataset entity_wiki_aliases

wikidata_dataset main_property_connections
wikidata_dataset meta_graph

# ======================================================================
# WIKIPEDIA Datasets