From 70b38a7a4c1bcb8e07cc7780643b5b9c7a63d33c Mon Sep 17 00:00:00 2001 From: Kyle Edwards Date: Mon, 24 Jun 2024 14:52:06 -0400 Subject: [PATCH] Use Pydantic instead of asdict() We will use Pydantic later on for JSON decoding, so use it for encoding now. --- .pre-commit-config.yaml | 3 +++ pyproject.toml | 1 + src/rapids_metadata/json.py | 22 +++++----------------- tests/test_json.py | 3 ++- 4 files changed, 11 insertions(+), 18 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6e1ba1c..3c6de3e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -43,6 +43,7 @@ repos: pass_filenames: false additional_dependencies: - packaging + - pydantic - repo: local hooks: - id: generate-json @@ -50,6 +51,8 @@ repos: entry: ./ci/generate_json.py language: python pass_filenames: false + additional_dependencies: + - pydantic default_language_version: python: python3 diff --git a/pyproject.toml b/pyproject.toml index 0549b87..e2d1874 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ classifiers = [ requires-python = ">=3.9" dependencies = [ "packaging", + "pydantic", ] [project.scripts] diff --git a/src/rapids_metadata/json.py b/src/rapids_metadata/json.py index e33c1bf..e83dc59 100644 --- a/src/rapids_metadata/json.py +++ b/src/rapids_metadata/json.py @@ -13,19 +13,15 @@ # limitations under the License. import argparse -import dataclasses import json import os import sys -from typing import Any, Union +from typing import Union + +from pydantic import TypeAdapter from . import all_metadata -from .metadata import ( - RAPIDSMetadata, - RAPIDSPackage, - RAPIDSRepository, - RAPIDSVersion, -) +from .metadata import RAPIDSMetadata from .rapids_version import get_rapids_version @@ -34,13 +30,6 @@ ] -class _RAPIDSMetadataEncoder(json.JSONEncoder): - def default( - self, o: Union[RAPIDSMetadata, RAPIDSPackage, RAPIDSRepository, RAPIDSVersion] - ) -> dict[str, Any]: - return dataclasses.asdict(o) - - def main(argv: Union[list[str], None] = None): if argv is None: argv = sys.argv[1:] @@ -77,9 +66,8 @@ def main(argv: Union[list[str], None] = None): def write_file(f): json.dump( - metadata, + TypeAdapter(RAPIDSMetadata).dump_python(metadata), f, - cls=_RAPIDSMetadataEncoder, sort_keys=True, separators=(",", ": ") if parsed.pretty else (",", ":"), indent=" " if parsed.pretty else None, diff --git a/tests/test_json.py b/tests/test_json.py index 5025c85..6b7d719 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -19,6 +19,7 @@ from unittest.mock import patch import pytest +from pydantic import TypeAdapter from rapids_metadata import json as rapids_json from rapids_metadata.metadata import ( RAPIDSMetadata, @@ -133,7 +134,7 @@ def set_cwd(cwd: os.PathLike) -> Generator: ], ) def test_metadata_encoder(unencoded, encoded): - assert rapids_json._RAPIDSMetadataEncoder().default(unencoded) == encoded + assert TypeAdapter(type(unencoded)).dump_python(unencoded) == encoded @pytest.mark.parametrize(