[push_to_hub] Add data_files in yaml #1

Merged
58 changes: 53 additions & 5 deletions src/datasets/arrow_dataset.py
@@ -31,6 +31,7 @@
from collections import Counter
from collections.abc import Mapping
from copy import deepcopy
from fnmatch import fnmatch
from functools import partial, wraps
from io import BytesIO
from math import ceil, floor
@@ -64,6 +65,7 @@
from . import config
from .arrow_reader import ArrowReader
from .arrow_writer import ArrowWriter, OptimizedTypedSequence
from .data_files import sanitize_patterns
from .download.download_config import DownloadConfig
from .download.streaming_download_manager import xgetsize
from .features import Audio, ClassLabel, Features, Image, Sequence, Value
@@ -112,7 +114,7 @@
from .utils.hub import hf_hub_url
from .utils.info_utils import is_small_dataset
from .utils.metadata import DatasetMetadata, MetadataConfigs
from .utils.py_utils import asdict, convert_file_size_to_int, iflatmap_unordered, unique_values
from .utils.py_utils import asdict, convert_file_size_to_int, iflatmap_unordered, string_to_dict, unique_values
from .utils.stratify import stratified_shuffle_split_generate_indices
from .utils.tf_utils import dataset_to_tf, minimal_tf_collate_fn, multiprocess_dataset_to_tf
from .utils.typing import PathLike
@@ -133,6 +135,10 @@

logger = logging.get_logger(__name__)

PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED = (
"data/{split}-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.parquet"
)


class DatasetInfoMixin:
"""This base class exposes some attributes of DatasetInfo
@@ -5316,7 +5322,8 @@ def push_to_hub(
raise ValueError(
"Failed to push_to_hub: please specify either max_shard_size or num_shards, but not both."
)
data_dir = f"{config_name}/data" if config_name != "default" else "data" # for backward compatibility
data_dir = config_name if config_name != "default" else "data" # for backward compatibility

repo_id, split, uploaded_size, dataset_nbytes, repo_files, deleted_size = self._push_parquet_shards_to_hub(
repo_id=repo_id,
data_dir=data_dir,
@@ -5349,10 +5356,12 @@ def push_to_hub(
)
dataset_metadata = DatasetMetadata.from_readme(Path(dataset_readme_path))
dataset_infos: DatasetInfosDict = DatasetInfosDict.from_metadata(dataset_metadata)
metadata_configs = MetadataConfigs.from_metadata(dataset_metadata)
repo_info = dataset_infos.get(config_name, None)
# get the deprecated dataset_infos.json to update them
elif config.DATASETDICT_INFOS_FILENAME in repo_files:
dataset_metadata = DatasetMetadata()
metadata_configs = MetadataConfigs()
download_config = DownloadConfig()
download_config.download_desc = "Downloading metadata"
download_config.use_auth_token = token
@@ -5366,6 +5375,7 @@
repo_info = DatasetInfo.from_dict(dataset_info) if dataset_info else None
else:
dataset_metadata = DatasetMetadata()
metadata_configs = MetadataConfigs()
repo_info = None
# update the total info to dump from existing info
if repo_info is not None:
@@ -5388,6 +5398,46 @@
split, num_bytes=dataset_nbytes, num_examples=len(self), dataset_name=dataset_name
)
info_to_dump = repo_info
# create the metadata configs if it was uploaded with push_to_hub before metadata configs existed
if not metadata_configs:
_matched_paths = [
p
for p in repo_files
if fnmatch(p, PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED.replace("{split}", "*"))
]
if len(_matched_paths) > 0:
# it was uploaded with push_to_hub before metadata configs existed
_resolved_splits = {
string_to_dict(p, PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED)["split"]
for p in _matched_paths
}
default_metadata_configs_to_dump = {
"data_files": [
{"split": _resolved_split, "pattern": f"data/{_resolved_split}-*"}
for _resolved_split in _resolved_splits
]
}
MetadataConfigs({"default": default_metadata_configs_to_dump}).to_metadata(dataset_metadata)
# update the metadata configs
if config_name in metadata_configs:
metadata_config = metadata_configs[config_name]
if "data_files" in metadata_config:
data_files_to_dump = sanitize_patterns(metadata_config["data_files"])
else:
data_files_to_dump = {}
# add the new split
data_files_to_dump[split] = f"{data_dir}/{split}-*"
metadata_config_to_dump = {
"data_files": [
{
"split": _split,
"pattern": _pattern[0] if isinstance(_pattern, list) and len(_pattern) == 1 else _pattern,
}
for _split, _pattern in data_files_to_dump.items()
]
}
else:
metadata_config_to_dump = {"data_files": [{"split": split, "pattern": f"{data_dir}/{split}-*"}]}
# push to the deprecated dataset_infos.json
if config.DATASETDICT_INFOS_FILENAME in repo_files:
download_config = DownloadConfig()
@@ -5412,9 +5462,7 @@
)
# push to README
DatasetInfosDict({config_name: info_to_dump}).to_metadata(dataset_metadata)
MetadataConfigs({config_name: {"data_dir": config_name if config_name != "default" else "./"}}).to_metadata(
dataset_metadata
)
MetadataConfigs({config_name: metadata_config_to_dump}).to_metadata(dataset_metadata)
if "README.md" in repo_files:
with open(dataset_readme_path, encoding="utf-8") as readme_file:
readme_content = readme_file.read()
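To make the new README metadata branch above concrete, here is a plain-Python sketch of the `data_files` structure that `push_to_hub` now records for a config when a split is pushed. The helper name and inputs are invented for illustration, but the output mirrors `metadata_config_to_dump` built in the diff:

```python
# Illustrative only: mirrors the shape of metadata_config_to_dump above.
def build_data_files_entry(existing_data_files, split, data_dir):
    """existing_data_files: dict of split name -> list of patterns already in the config."""
    data_files = dict(existing_data_files)
    data_files[split] = f"{data_dir}/{split}-*"  # add or overwrite the pushed split
    return {
        "data_files": [
            {
                "split": s,
                "pattern": p[0] if isinstance(p, list) and len(p) == 1 else p,
            }
            for s, p in data_files.items()
        ]
    }


print(build_data_files_entry({"train": ["data/train-*"]}, "test", "data"))
# {'data_files': [{'split': 'train', 'pattern': 'data/train-*'},
#                 {'split': 'test', 'pattern': 'data/test-*'}]}
```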
25 changes: 22 additions & 3 deletions src/datasets/data_files.py
@@ -105,16 +105,35 @@ def sanitize_patterns(patterns: Union[Dict, List, str]) -> Dict[str, Union[List[
The default split is "train".

Returns:
patterns: dictionary of split_name -> list_of _atterns
patterns: dictionary of split_name -> list_of patterns
"""
if isinstance(patterns, dict):
return {str(key): value if isinstance(value, list) else [value] for key, value in patterns.items()}
elif isinstance(patterns, str):
return {SANITIZED_DEFAULT_SPLIT: [patterns]}
elif isinstance(patterns, list):
return {SANITIZED_DEFAULT_SPLIT: patterns}
if any(isinstance(pattern, dict) for pattern in patterns):
for pattern in patterns:
if not isinstance(pattern, dict) or sorted(pattern) != ["pattern", "split"]:
raise ValueError(
f"Expected data_files in YAML to be a string or a list, but got {pattern}\nExamples:\n"
" data_files: data.csv\n data_files: data/*.png\n"
" data_files:\n - part0/*\n - part1/*\n"
" data_files:\n - split: train\n pattern: train/*\n - split: test\n pattern: test/*"
)
splits = [pattern["split"] for pattern in patterns]
if len(set(splits)) != len(splits):
raise ValueError(f"Some splits are duplicated in data_files: {splits}")
return {
str(pattern["split"]): pattern["pattern"]
if isinstance(pattern["pattern"], list)
else [pattern["pattern"]]
for pattern in patterns
}
else:
return {SANITIZED_DEFAULT_SPLIT: patterns}
else:
return {SANITIZED_DEFAULT_SPLIT: list(patterns)}
return sanitize_patterns(list(patterns))


def _is_inside_unrequested_special_dir(matched_rel_path: str, pattern: str) -> bool:
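A quick usage sketch of the list-of-dicts form that `sanitize_patterns` accepts after this change (assuming a datasets version with this PR applied; the older string and plain-list forms keep working):

```python
from datasets.data_files import sanitize_patterns

# New YAML-style form: a list of {"split", "pattern"} dicts.
print(sanitize_patterns([
    {"split": "train", "pattern": "data/train-*"},
    {"split": "test", "pattern": ["data/test-0*", "data/test-1*"]},
]))
# {'train': ['data/train-*'], 'test': ['data/test-0*', 'data/test-1*']}

# Existing forms are unchanged ("train" is the sanitized default split).
print(sanitize_patterns("data/*.csv"))            # {'train': ['data/*.csv']}
print(sanitize_patterns(["part0/*", "part1/*"]))  # {'train': ['part0/*', 'part1/*']}
```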
72 changes: 53 additions & 19 deletions src/datasets/dataset_dict.py
@@ -5,6 +5,7 @@
import posixpath
import re
import warnings
from fnmatch import fnmatch
from io import BytesIO
from pathlib import Path
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
@@ -16,7 +17,7 @@
from datasets.utils.metadata import DatasetMetadata, MetadataConfigs

from . import config
from .arrow_dataset import Dataset
from .arrow_dataset import PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED, Dataset
from .download import DownloadConfig
from .features import Features
from .features.features import FeatureType
@@ -30,7 +31,7 @@
from .utils.doc_utils import is_documented_by
from .utils.file_utils import cached_path
from .utils.hub import hf_hub_url
from .utils.py_utils import asdict
from .utils.py_utils import asdict, string_to_dict
from .utils.typing import PathLike


@@ -1580,7 +1581,7 @@ def push_to_hub(
if not re.match(_split_re, split):
raise ValueError(f"Split name should match '{_split_re}' but got '{split}'.")

data_dir = f"{config_name}/data" if config_name != "default" else "data" # for backward compatibility
data_dir = config_name if config_name != "default" else "data" # for backward compatibility
for split in self.keys():
logger.warning(f"Pushing split {split} to the Hub.")
# The split=key needs to be removed before merging
@@ -1603,18 +1604,62 @@ def push_to_hub(
info_to_dump.dataset_size = total_dataset_nbytes
info_to_dump.size_in_bytes = total_uploaded_size + total_dataset_nbytes

metadata_config_to_dump = {"data_files": [{"split": split, "pattern": f"{data_dir}/{split}-*"}]}

api = HfApi(endpoint=config.HF_ENDPOINT)
repo_files = api.list_repo_files(repo_id, repo_type="dataset", revision=branch, token=token)

# push to the deprecated dataset_infos.json
if config.DATASETDICT_INFOS_FILENAME in repo_files:
# get the info from the README to update them
if "README.md" in repo_files:
download_config = DownloadConfig()
download_config.download_desc = "Downloading metadata"
download_config.use_auth_token = token
dataset_readme_path = cached_path(
hf_hub_url(repo_id, "README.md"),
download_config=download_config,
)
dataset_metadata = DatasetMetadata.from_readme(Path(dataset_readme_path))
dataset_infos: DatasetInfosDict = DatasetInfosDict.from_metadata(dataset_metadata)
metadata_configs = MetadataConfigs.from_metadata(dataset_metadata)
# get the deprecated dataset_infos.json to update them
elif config.DATASETDICT_INFOS_FILENAME in repo_files:
dataset_metadata = DatasetMetadata()
metadata_configs = MetadataConfigs()
download_config = DownloadConfig()
download_config.download_desc = "Updating deprecated dataset_infos.json"
download_config.download_desc = "Downloading metadata"
download_config.use_auth_token = token
dataset_infos_path = cached_path(
hf_hub_url(repo_id, config.DATASETDICT_INFOS_FILENAME),
download_config=download_config,
)
with open(dataset_infos_path, encoding="utf-8") as f:
dataset_infos: dict = json.load(f)
dataset_infos.get(config_name, None) if dataset_infos else None
else:
dataset_metadata = DatasetMetadata()
metadata_configs = MetadataConfigs()
# create the metadata configs if it was uploaded with push_to_hub before metadata configs existed
if not metadata_configs:
_matched_paths = [
p
for p in repo_files
if fnmatch(p, PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED.replace("{split}", "*"))
]
if len(_matched_paths) > 0:
# it was uploaded with push_to_hub before metadata configs existed
_resolved_splits = {
string_to_dict(p, PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED)["split"]
for p in _matched_paths
}
default_metadata_configs_to_dump = {
"data_files": [
{"split": _resolved_split, "pattern": f"data/{_resolved_split}-*"}
for _resolved_split in _resolved_splits
]
}
MetadataConfigs({"default": default_metadata_configs_to_dump}).to_metadata(dataset_metadata)
# push to the deprecated dataset_infos.json
if config.DATASETDICT_INFOS_FILENAME in repo_files:
with open(dataset_infos_path, encoding="utf-8") as f:
dataset_infos: DatasetInfosDict = json.load(f)
dataset_infos[config_name] = asdict(info_to_dump)
@@ -1629,24 +1674,13 @@ def push_to_hub(
revision=branch,
)
# push to README
DatasetInfosDict({config_name: info_to_dump}).to_metadata(dataset_metadata)
MetadataConfigs({config_name: metadata_config_to_dump}).to_metadata(dataset_metadata)
if "README.md" in repo_files:
download_config = DownloadConfig()
download_config.download_desc = "Downloading metadata"
download_config.use_auth_token = token
dataset_readme_path = cached_path(
hf_hub_url(repo_id, "README.md"),
download_config=download_config,
)
dataset_metadata = DatasetMetadata.from_readme(Path(dataset_readme_path))
with open(dataset_readme_path, encoding="utf-8") as readme_file:
readme_content = readme_file.read()
else:
dataset_metadata = DatasetMetadata()
readme_content = f'# Dataset Card for "{repo_id.split("/")[-1]}"\n\n[More Information needed](https://github.com/huggingface/datasets/blob/main/CONTRIBUTING.md#how-to-contribute-to-the-dataset-cards)'
DatasetInfosDict({config_name: info_to_dump}).to_metadata(dataset_metadata)
MetadataConfigs({config_name: {"data_dir": config_name if config_name != "default" else "./"}}).to_metadata(
dataset_metadata
)
HfApi(endpoint=config.HF_ENDPOINT).upload_file(
path_or_fileobj=dataset_metadata._to_readme(readme_content).encode(),
path_in_repo="README.md",
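The legacy-repo handling above reduces to reconstructing a "default" config entry from whatever split names could be recovered from the old shard filenames. A small sketch of the resulting structure (the split names here are made up):

```python
# Illustrative: the "default" config entry rebuilt for a repo that was pushed
# before metadata configs existed, given the splits resolved from its shards.
resolved_splits = {"train", "validation", "test"}

default_metadata_configs_to_dump = {
    "data_files": [
        {"split": split, "pattern": f"data/{split}-*"}
        for split in sorted(resolved_splits)
    ]
}

print(default_metadata_configs_to_dump)
# {'data_files': [{'split': 'test', 'pattern': 'data/test-*'},
#                 {'split': 'train', 'pattern': 'data/train-*'},
#                 {'split': 'validation', 'pattern': 'data/validation-*'}]}
```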
12 changes: 10 additions & 2 deletions src/datasets/load.py
@@ -153,7 +153,10 @@ def __call__(self, builder_cls, metadata_configs, name):


def configure_builder_class(
builder_cls: Type[DatasetBuilder], builder_configs: List["BuilderConfig"], dataset_name: str
builder_cls: Type[DatasetBuilder],
builder_configs: List["BuilderConfig"],
default_config_name: Optional[str],
dataset_name: str,
) -> Type[DatasetBuilder]:
"""
Dynamically create a builder class with custom builder configs parsed from README.md file,
@@ -162,6 +165,7 @@

class ConfiguredDatasetBuilder(builder_cls):
BUILDER_CONFIGS = builder_configs
DEFAULT_CONFIG_NAME = default_config_name

__module__ = builder_cls.__module__ # so that the actual packaged builder can be imported

@@ -187,14 +191,18 @@ def __reduce__(self): # to make dynamically created class pickable, see _Initia
return ConfiguredDatasetBuilder


def get_dataset_builder_class(dataset_module, dataset_name: Optional[str] = None) -> Type[DatasetBuilder]:
def get_dataset_builder_class(
dataset_module: "DatasetModule", dataset_name: Optional[str] = None
) -> Type[DatasetBuilder]:
builder_cls = import_main_class(dataset_module.module_path)
if dataset_module.metadata_configs:
config_cls = builder_cls.BUILDER_CONFIG_CLASS
builder_configs_list = dataset_module.metadata_configs.to_builder_configs(builder_config_cls=config_cls)
default_config_name = dataset_module.metadata_configs.get_default_config_name()
builder_cls = configure_builder_class(
builder_cls,
builder_configs=builder_configs_list,
default_config_name=default_config_name,
dataset_name=dataset_name,
)
return builder_cls
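Finally, the load.py change threads a default config name into the dynamically created builder class. Below is a standalone sketch of that pattern, using simplified stand-in classes rather than the real datasets builders:

```python
from dataclasses import dataclass
from typing import List, Optional, Type


@dataclass
class BuilderConfig:            # stand-in for datasets' BuilderConfig
    name: str = "default"


class DatasetBuilder:           # stand-in for datasets' DatasetBuilder
    BUILDER_CONFIGS: List[BuilderConfig] = []
    DEFAULT_CONFIG_NAME: Optional[str] = None


def configure_builder_class(
    builder_cls: Type[DatasetBuilder],
    builder_configs: List[BuilderConfig],
    default_config_name: Optional[str],
) -> Type[DatasetBuilder]:
    # Same idea as the diff: subclass the packaged builder and pin the configs
    # parsed from the README plus the default config name.
    class ConfiguredDatasetBuilder(builder_cls):
        BUILDER_CONFIGS = builder_configs
        DEFAULT_CONFIG_NAME = default_config_name

    return ConfiguredDatasetBuilder


Configured = configure_builder_class(
    DatasetBuilder,
    [BuilderConfig(name="en"), BuilderConfig(name="fr")],
    default_config_name="en",
)
print(Configured.DEFAULT_CONFIG_NAME)                # en
print([c.name for c in Configured.BUILDER_CONFIGS])  # ['en', 'fr']
```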