Skip to content

Commit

Permalink
Generate JSON schema (#16)
Browse files Browse the repository at this point in the history
This schema can be used to validate JSON data.
  • Loading branch information
KyleFromNVIDIA authored Jun 27, 2024
1 parent 9ca0136 commit 5407752
Show file tree
Hide file tree
Showing 7 changed files with 196 additions and 31 deletions.
13 changes: 11 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@ repos:
- id: trailing-whitespace
exclude: |
(?x)
rapids-metadata[.]json$
rapids-metadata[.]json$|
schemas/rapids-metadata-v[0-9]+[.]json$
- id: end-of-file-fixer
exclude: |
(?x)
rapids-metadata[.]json$
rapids-metadata[.]json$|
schemas/rapids-metadata-v[0-9]+[.]json$
- repo: https://github.com/rapidsai/dependency-file-generator
rev: v1.13.11
hooks:
Expand Down Expand Up @@ -53,6 +55,13 @@ repos:
pass_filenames: false
additional_dependencies:
- pydantic
- id: generate-json-schema
name: generate-json-schema
entry: ./ci/generate_json_schema.py
language: python
pass_filenames: false
additional_dependencies:
- pydantic

default_language_version:
python: python3
2 changes: 1 addition & 1 deletion ci/generate_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
repo_root = os.path.join(os.path.dirname(__file__), "..")
sys.path.append(os.path.join(repo_root, "src"))

from rapids_metadata import json as rapids_json # noqa: E402
import rapids_metadata.json as rapids_json # noqa: E402

if __name__ == "__main__":
rapids_json.main(
Expand Down
20 changes: 20 additions & 0 deletions ci/generate_json_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#!/usr/bin/env python3
# Copyright (c) 2024, NVIDIA CORPORATION.

import os.path
import sys

repo_root = os.path.join(os.path.dirname(__file__), "..")
sys.path.append(os.path.join(repo_root, "src"))

import rapids_metadata.json as rapids_json # noqa: E402

if __name__ == "__main__":
rapids_json.main(
[
"--output",
os.path.join(repo_root, "schemas/rapids-metadata-v1.json"),
"--pretty",
"--schema",
]
)
67 changes: 67 additions & 0 deletions schemas/rapids-metadata-v1.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
{
"$defs": {
"RAPIDSPackage": {
"description": "Package published by a RAPIDS repository. Includes both Python packages and Conda packages.",
"properties": {
"has_cuda_suffix": {
"default": true,
"description": "Whether or not the package has a CUDA suffix.",
"title": "Has Cuda Suffix",
"type": "boolean"
},
"publishes_prereleases": {
"default": true,
"description": "Whether or not the package publishes prereleases.",
"title": "Publishes Prereleases",
"type": "boolean"
}
},
"title": "RAPIDSPackage",
"type": "object"
},
"RAPIDSRepository": {
"description": "RAPIDS Git repository. Can publish more than one package.",
"properties": {
"packages": {
"additionalProperties": {
"$ref": "#/$defs/RAPIDSPackage"
},
"description": "Dictionary of packages in this repository by name.",
"title": "Packages",
"type": "object"
}
},
"title": "RAPIDSRepository",
"type": "object"
},
"RAPIDSVersion": {
"description": "Version of RAPIDS, which contains many Git repositories.",
"properties": {
"repositories": {
"additionalProperties": {
"$ref": "#/$defs/RAPIDSRepository"
},
"description": "Dictionary of repositories in this version by name.",
"title": "Repositories",
"type": "object"
}
},
"title": "RAPIDSVersion",
"type": "object"
}
},
"$id": "https://raw.githubusercontent.com/rapidsai/rapids-metadata/main/schemas/rapids-metadata-v1.json",
"description": "All RAPIDS metadata.",
"properties": {
"versions": {
"additionalProperties": {
"$ref": "#/$defs/RAPIDSVersion"
},
"description": "Dictionary of RAPIDS versions by <major>.<minor> version string.",
"title": "Versions",
"type": "object"
}
},
"title": "RAPIDSMetadata",
"type": "object"
}
43 changes: 27 additions & 16 deletions src/rapids_metadata/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import json
import os
import sys
from typing import Union
from typing import Any, TextIO, Union

from pydantic import TypeAdapter

Expand All @@ -41,6 +41,11 @@ def main(argv: Union[list[str], None] = None):
action="store_true",
help="Output all versions, ignoring local VERSION file",
)
parser.add_argument(
"--schema",
action="store_true",
help="Output a JSON schema for the data instead of the data itself",
)
parser.add_argument(
"--pretty", action="store_true", help="Pretty-print JSON output"
)
Expand All @@ -52,21 +57,10 @@ def main(argv: Union[list[str], None] = None):
)

parsed = parser.parse_args(argv)
metadata = (
all_metadata
if parsed.all_versions
else RAPIDSMetadata(
versions={
get_rapids_version(os.getcwd()): all_metadata.get_current_version(
os.getcwd()
)
}
)
)

def write_file(f):
def write_file(data: dict[str, Any], f: TextIO):
json.dump(
TypeAdapter(RAPIDSMetadata).dump_python(metadata),
data,
f,
sort_keys=True,
separators=(",", ": ") if parsed.pretty else (",", ":"),
Expand All @@ -75,11 +69,28 @@ def write_file(f):
if parsed.pretty:
f.write("\n")

type_adapter = TypeAdapter(RAPIDSMetadata)
if parsed.schema:
data = type_adapter.json_schema()
else:
metadata = (
all_metadata
if parsed.all_versions
else RAPIDSMetadata(
versions={
get_rapids_version(os.getcwd()): all_metadata.get_current_version(
os.getcwd()
)
}
)
)
data = type_adapter.dump_python(metadata)

if parsed.output:
with open(parsed.output, "w") as f:
write_file(f)
write_file(data, f)
else:
write_file(sys.stdout)
write_file(data, sys.stdout)


if __name__ == "__main__":
Expand Down
52 changes: 45 additions & 7 deletions src/rapids_metadata/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from os import PathLike
from typing import Union

from pydantic import ConfigDict, Field
from pydantic.dataclasses import dataclass

from .rapids_version import get_rapids_version

__all__ = [
Expand All @@ -28,18 +30,40 @@

@dataclass
class RAPIDSPackage:
publishes_prereleases: bool = field(default=True)
has_cuda_suffix: bool = field(default=True)
(
"""Package published by a RAPIDS repository. Includes both Python packages """
"""and Conda packages."""
)

publishes_prereleases: bool = Field(
default=True,
description="""Whether or not the package publishes prereleases.""",
)

has_cuda_suffix: bool = Field(
default=True,
description="""Whether or not the package has a CUDA suffix.""",
)


@dataclass
class RAPIDSRepository:
packages: dict[str, RAPIDSPackage] = field(default_factory=dict)
"""RAPIDS Git repository. Can publish more than one package."""

packages: dict[str, RAPIDSPackage] = Field(
default_factory=dict,
description="""Dictionary of packages in this repository by name.""",
)


@dataclass
class RAPIDSVersion:
repositories: dict[str, RAPIDSRepository] = field(default_factory=dict)
"""Version of RAPIDS, which contains many Git repositories."""

repositories: dict[str, RAPIDSRepository] = Field(
default_factory=dict,
description="""Dictionary of repositories in this version by name.""",
)

@property
def all_packages(self) -> set[str]:
Expand Down Expand Up @@ -68,9 +92,23 @@ def cuda_suffixed_packages(self) -> set[str]:
}


@dataclass
@dataclass(
config=ConfigDict(
json_schema_extra={
"$id": "https://raw.githubusercontent.com/rapidsai/rapids-metadata/main/schemas/rapids-metadata-v1.json",
},
)
)
class RAPIDSMetadata:
versions: dict[str, RAPIDSVersion] = field(default_factory=dict)
"""All RAPIDS metadata."""

versions: dict[str, RAPIDSVersion] = Field(
default_factory=dict,
description=(
"""Dictionary of RAPIDS versions by <major>.<minor> """
"""version string."""
),
)

def get_current_version(
self, directory: Union[str, PathLike[str]]
Expand Down
30 changes: 25 additions & 5 deletions tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@

import contextlib
import os.path
import re
from textwrap import dedent
from typing import Generator
from typing import Generator, Union
from unittest.mock import patch

import pytest
Expand Down Expand Up @@ -256,9 +257,22 @@ def test_metadata_encoder(unencoded, encoded):
"""
),
),
(
None,
["--schema"],
re.compile(
r'"\$id":"https://raw.githubusercontent.com/rapidsai/rapids-metadata/main/schemas/rapids-metadata-v1.json"'
),
),
],
)
def test_main(capsys, tmp_path, version, args, expected_json):
def test_main(
capsys: pytest.CaptureFixture[str],
tmp_path: str,
version: Union[str, None],
args: list[str],
expected_json: Union[str, re.Pattern],
):
mock_metadata = RAPIDSMetadata(
versions={
"24.08": RAPIDSVersion(
Expand All @@ -285,17 +299,23 @@ def test_main(capsys, tmp_path, version, args, expected_json):
with open(os.path.join(tmp_path, "VERSION"), "w") as f:
f.write(f"{version}\n")

def check_output(output: str):
if isinstance(expected_json, re.Pattern):
assert expected_json.search(output)
else:
assert output == expected_json

with set_cwd(tmp_path), patch("sys.argv", ["rapids-metadata-json", *args]), patch(
"rapids_metadata.json.all_metadata", mock_metadata
):
rapids_json.main()
captured = capsys.readouterr()
assert captured.out == expected_json
check_output(captured.out)

with set_cwd(tmp_path), patch("rapids_metadata.json.all_metadata", mock_metadata):
rapids_json.main(args)
captured = capsys.readouterr()
assert captured.out == expected_json
check_output(captured.out)

with set_cwd(tmp_path), patch(
"sys.argv", ["rapids-metadata-json", *args, "-o", "rapids-metadata.json"]
Expand All @@ -305,4 +325,4 @@ def test_main(capsys, tmp_path, version, args, expected_json):
written_json = f.read()
captured = capsys.readouterr()
assert captured.out == ""
assert written_json == expected_json
check_output(written_json)

0 comments on commit 5407752

Please sign in to comment.