Skip to content

Commit

Permalink
Use mashumaro to serialize/deserialize dataclass (#1735)
Browse files Browse the repository at this point in the history
Signed-off-by: HH <hhcs9527@gmail.com>
Signed-off-by: hhcs9527 <hhcs9527@gmail.com>
Signed-off-by: Matthew Hoffman <matthew@protopia.ai>
Co-authored-by: Matthew Hoffman <matthew@protopia.ai>
  • Loading branch information
hhcs9527 and ringohoffman committed Sep 12, 2023
1 parent 38ab18c commit e0346fb
Show file tree
Hide file tree
Showing 7 changed files with 743 additions and 50 deletions.
5 changes: 3 additions & 2 deletions doc-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ absl-py==1.4.0
# tensorflow
adlfs==2023.4.0
# via flytekit
aiobotocore==2.5.2
aiobotocore==2.5.3
# via s3fs
aiohttp==3.8.5
# via
Expand Down Expand Up @@ -98,7 +98,7 @@ bleach==6.0.0
# via nbconvert
blinker==1.6.2
# via flask
botocore==1.29.161
botocore==1.31.17
# via
# -r doc-requirements.in
# aiobotocore
Expand Down Expand Up @@ -1284,6 +1284,7 @@ typing-extensions==4.5.0
# flytekit
# great-expectations
# ipython
# mashumaro
# pydantic
# python-utils
# sqlalchemy
Expand Down
146 changes: 109 additions & 37 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import typing
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, cast
from typing import Dict, List, NamedTuple, Optional, Type, cast

from dataclasses_json import DataClassJsonMixin, dataclass_json
from google.protobuf import json_format as _json_format
Expand All @@ -22,6 +22,7 @@
from google.protobuf.message import Message
from google.protobuf.struct_pb2 import Struct
from marshmallow_enum import EnumField, LoadDumpOptions
from mashumaro.mixins.json import DataClassJSONMixin
from typing_extensions import Annotated, get_args, get_origin

from flytekit.core.annotation import FlyteAnnotation
Expand Down Expand Up @@ -53,6 +54,7 @@

T = typing.TypeVar("T")
DEFINITIONS = "definitions"
TITLE = "title"


class BatchSize:
Expand Down Expand Up @@ -344,22 +346,28 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:
f"Type {t} cannot be parsed."
)

if not issubclass(t, DataClassJsonMixin):
if not issubclass(t, DataClassJsonMixin) and not issubclass(t, DataClassJSONMixin):
raise AssertionError(
f"Dataclass {t} should be decorated with @dataclass_json or be a subclass of DataClassJsonMixin to be "
"serialized correctly"
f"Dataclass {t} should be decorated with @dataclass_json or mixin with DataClassJSONMixin to be "
f"serialized correctly"
)
schema = None
try:
s = cast(DataClassJsonMixin, self._get_origin_type_in_annotation(t)).schema()
for _, v in s.fields.items():
# marshmallow-jsonschema only supports enums loaded by name.
# https://github.com/fuhrysteve/marshmallow-jsonschema/blob/81eada1a0c42ff67de216923968af0a6b54e5dcb/marshmallow_jsonschema/base.py#L228
if isinstance(v, EnumField):
v.load_by = LoadDumpOptions.name
from marshmallow_jsonschema import JSONSchema

schema = JSONSchema().dump(s)
if issubclass(t, DataClassJsonMixin):
s = cast(DataClassJsonMixin, self._get_origin_type_in_annotation(t)).schema()
for _, v in s.fields.items():
# marshmallow-jsonschema only supports enums loaded by name.
# https://github.com/fuhrysteve/marshmallow-jsonschema/blob/81eada1a0c42ff67de216923968af0a6b54e5dcb/marshmallow_jsonschema/base.py#L228
if isinstance(v, EnumField):
v.load_by = LoadDumpOptions.name
# check if DataClass mixin
from marshmallow_jsonschema import JSONSchema

schema = JSONSchema().dump(s)
else: # DataClassJSONMixin
from mashumaro.jsonschema import build_json_schema

schema = build_json_schema(cast(DataClassJSONMixin, self._get_origin_type_in_annotation(t))).to_dict()
except Exception as e:
# https://github.com/lovasoa/marshmallow_dataclass/issues/13
logger.warning(
Expand All @@ -376,15 +384,18 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
f"{type(python_val)} is not of type @dataclass, only Dataclasses are supported for "
f"user defined datatypes in Flytekit"
)
if not issubclass(type(python_val), DataClassJsonMixin):
if not issubclass(type(python_val), DataClassJsonMixin) and not issubclass(
type(python_val), DataClassJSONMixin
):
raise TypeTransformerFailedError(
f"Dataclass {python_type} should be decorated with @dataclass_json or be a subclass of "
"DataClassJsonMixin to be serialized correctly"
f"Dataclass {python_type} should be decorated with @dataclass_json or inherit DataClassJSONMixin to be "
f"serialized correctly"
)
self._serialize_flyte_type(python_val, python_type)
return Literal(
scalar=Scalar(generic=_json_format.Parse(cast(DataClassJsonMixin, python_val).to_json(), _struct.Struct()))
)

json_str = python_val.to_json() # type: ignore

return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) # type: ignore

def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]:
# dataclass will try to hash python type when calling dataclass.schema(), but some types in the annotation is
Expand Down Expand Up @@ -628,13 +639,16 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
f"{expected_python_type} is not of type @dataclass, only Dataclasses are supported for "
"user defined datatypes in Flytekit"
)
if not issubclass(expected_python_type, DataClassJsonMixin):
if not issubclass(expected_python_type, DataClassJsonMixin) and not issubclass(
expected_python_type, DataClassJSONMixin
):
raise TypeTransformerFailedError(
f"Dataclass {expected_python_type} should be decorated with @dataclass_json or be a subclass of "
"DataClassJsonMixin to be serialized correctly"
f"Dataclass {expected_python_type} should be decorated with @dataclass_json or mixin with DataClassJSONMixin to be "
f"serialized correctly"
)
json_str = _json_format.MessageToJson(lv.scalar.generic)
dc = cast(DataClassJsonMixin, expected_python_type).from_json(json_str)
dc = expected_python_type.from_json(json_str) # type: ignore

dc = self._fix_structured_dataset_type(expected_python_type, dc)
return self._fix_dataclass_int(expected_python_type, self._deserialize_flyte_type(dc, expected_python_type))

Expand All @@ -645,10 +659,15 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
# Results are memoized; typed=True caches arguments of different types separately.
@lru_cache(typed=True)
def guess_python_type(self, literal_type: LiteralType) -> Type[T]: # type: ignore
# Rebuilds a dataclass type from the JSON schema stored in the literal type's metadata.
# Only STRUCT literal types are handled; everything else reaches the ValueError at the end.
if literal_type.simple == SimpleType.STRUCT:
# NOTE(review): the next three lines appear to be the pre-change side of this diff hunk (they call
# the old name convert_json_schema_to_python_class) — confirm before treating them as live code.
if literal_type.metadata is not None and DEFINITIONS in literal_type.metadata:
schema_name = literal_type.metadata["$ref"].split("/")[-1]
return convert_json_schema_to_python_class(literal_type.metadata[DEFINITIONS], schema_name)

if literal_type.metadata is not None:
# A DEFINITIONS key routes to the marshmallow converter; the class name is the last
# "/"-separated segment of the "$ref" entry.
if DEFINITIONS in literal_type.metadata:
schema_name = literal_type.metadata["$ref"].split("/")[-1]
return convert_marshmallow_json_schema_to_python_class(
literal_type.metadata[DEFINITIONS], schema_name
)
# Otherwise a top-level TITLE key routes to the mashumaro converter, which receives the
# whole metadata dict and uses the title as the class name.
elif TITLE in literal_type.metadata:
schema_name = literal_type.metadata[TITLE]
return convert_mashumaro_json_schema_to_python_class(literal_type.metadata, schema_name)
# Not a STRUCT, no metadata, or metadata with neither key: the type cannot be reversed.
raise ValueError(f"Dataclass transformer cannot reverse {literal_type}")


Expand Down Expand Up @@ -1563,13 +1582,45 @@ def guess_python_type(self, literal_type: LiteralType) -> Type[enum.Enum]:
raise ValueError(f"Enum transformer cannot reverse {literal_type}")


def convert_json_schema_to_python_class(schema: Dict[str, Any], schema_name: str) -> Type[Any]:
"""
Generate a model class based on the provided JSON Schema
:param schema: dict representing valid JSON schema
:param schema_name: dataclass name of return type
"""
attribute_list: List[Tuple[str, type]] = []
def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name: typing.Any):
    """
    Build the field list for a dataclass described by a mashumaro-generated JSON schema.

    :param schema: JSON schema dict of a single dataclass; only its ``properties`` mapping is read.
    :param schema_name: name of the dataclass being rebuilt. Not used here; kept so the signature
        stays parallel to ``generate_attribute_list_from_dataclass_json``.
    :return: list of ``(field_name, python_type)`` tuples, one per property, in schema order,
        suitable for ``dataclasses.make_dataclass``.
    :raises KeyError: if a property carries neither ``anyOf``, ``enum`` nor ``type``, or a nested
        object schema has no ``title``.
    """
    attribute_list = []
    for property_key, property_val in schema["properties"].items():
        variants = property_val.get("anyOf")
        # With ``anyOf`` the wrapper itself has no ``items`` / ``additionalProperties`` / ``title``;
        # those details live on the first variant, so read them from there.
        effective = variants[0] if variants else property_val

        if variants:
            property_type = effective["type"]
        elif property_val.get("enum"):
            property_type = "enum"
        else:
            property_type = property_val["type"]

        # Handle list
        if property_type == "array":
            attribute_list.append(
                (property_key, typing.List[_get_element_type(effective["items"])])  # type: ignore
            )
        # Handle dataclass and dict
        elif property_type == "object":
            if effective.get("additionalProperties"):
                # A truthy ``additionalProperties`` schema marks a plain ``Dict[str, X]``.
                attribute_list.append(
                    (property_key, typing.Dict[str, _get_element_type(effective["additionalProperties"])])  # type: ignore
                )
            else:
                # Nested dataclass: rebuild it recursively under its schema title.
                sub_schema_name = effective["title"]
                attribute_list.append(
                    (property_key, convert_mashumaro_json_schema_to_python_class(effective, sub_schema_name))
                )
        elif property_type == "enum":
            # Enum members are surfaced as plain strings.
            attribute_list.append((property_key, str))
        # Handle int, float, bool or str
        else:
            # Pass the whole property so _get_element_type can see ``anyOf`` and ``format``.
            attribute_list.append((property_key, _get_element_type(property_val)))
    return attribute_list


def generate_attribute_list_from_dataclass_json(schema: dict, schema_name: typing.Any):
attribute_list = []
for property_key, property_val in schema[schema_name]["properties"].items():
property_type = property_val["type"]
# Handle list
Expand All @@ -1579,7 +1630,7 @@ def convert_json_schema_to_python_class(schema: Dict[str, Any], schema_name: str
elif property_type == "object":
if property_val.get("$ref"):
name = property_val["$ref"].split("/")[-1]
attribute_list.append((property_key, convert_json_schema_to_python_class(schema, name)))
attribute_list.append((property_key, convert_marshmallow_json_schema_to_python_class(schema, name)))
elif property_val.get("additionalProperties"):
attribute_list.append(
(property_key, Dict[str, _get_element_type(property_val["additionalProperties"])]) # type: ignore[misc,index]
Expand All @@ -1588,13 +1639,34 @@ def convert_json_schema_to_python_class(schema: Dict[str, Any], schema_name: str
attribute_list.append((property_key, Dict[str, _get_element_type(property_val)])) # type: ignore[misc,index]
# Handle int, float, bool or str
else:
attribute_list.append((property_key, _get_element_type(property_val)))
attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore
return attribute_list


def convert_marshmallow_json_schema_to_python_class(schema: dict, schema_name: typing.Any) -> Type[dataclasses.dataclass()]: # type: ignore
    """
    Create a dataclass type from a marshmallow-style JSON schema.

    :param schema: mapping of schema names to their JSON schema definitions
    :param schema_name: key of the definition to convert; also used as the generated class name
    :return: the generated dataclass, wrapped with ``dataclass_json``
    """
    fields_spec = generate_attribute_list_from_dataclass_json(schema, schema_name)
    generated_cls = dataclasses.make_dataclass(schema_name, fields_spec)
    return dataclass_json(generated_cls)


def convert_mashumaro_json_schema_to_python_class(schema: dict, schema_name: typing.Any) -> Type[dataclasses.dataclass()]: # type: ignore
    """
    Create a dataclass type from a mashumaro-style JSON schema.

    :param schema: JSON schema dict of the dataclass to rebuild
    :param schema_name: name given to the generated class
    :return: the generated dataclass, wrapped with ``dataclass_json``
    """
    fields_spec = generate_attribute_list_from_dataclass_json_mixin(schema, schema_name)
    generated_cls = dataclasses.make_dataclass(schema_name, fields_spec)
    return dataclass_json(generated_cls)


def _get_element_type(element_property: typing.Dict[str, str]) -> Type:
element_type = element_property["type"]
element_type = [e_property["type"] for e_property in element_property["anyOf"]] if element_property.get("anyOf") else element_property["type"] # type: ignore
element_format = element_property["format"] if "format" in element_property else None

if type(element_type) == list:
Expand Down
5 changes: 3 additions & 2 deletions flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from contextlib import contextmanager
from dataclasses import dataclass, field

from dataclasses_json import DataClassJsonMixin, config
from dataclasses_json import config
from marshmallow import fields
from mashumaro.mixins.json import DataClassJSONMixin

from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError, get_underlying_type
Expand All @@ -26,7 +27,7 @@ def noop():


@dataclass
class FlyteFile(DataClassJsonMixin, os.PathLike, typing.Generic[T]):
class FlyteFile(os.PathLike, typing.Generic[T], DataClassJSONMixin):
path: typing.Union[str, os.PathLike] = field(
default=None, metadata=config(mm_field=fields.String())
) # type: ignore
Expand Down
5 changes: 3 additions & 2 deletions flytekit/types/schema/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@

import numpy as _np
import pandas
from dataclasses_json import DataClassJsonMixin, config
from dataclasses_json import config
from marshmallow import fields
from mashumaro.mixins.json import DataClassJSONMixin

from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
Expand Down Expand Up @@ -180,7 +181,7 @@ def get_handler(cls, t: Type) -> SchemaHandler:


@dataclass
class FlyteSchema(DataClassJsonMixin):
class FlyteSchema(DataClassJSONMixin):
remote_path: typing.Optional[str] = field(default=None, metadata=config(mm_field=fields.String()))
"""
This is the main schema class that users should use.
Expand Down
5 changes: 3 additions & 2 deletions flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
from typing import Dict, Generator, Optional, Type, Union

import _datetime
from dataclasses_json import DataClassJsonMixin, config
from dataclasses_json import config
from fsspec.utils import get_protocol
from marshmallow import fields
from mashumaro.mixins.json import DataClassJSONMixin
from typing_extensions import Annotated, TypeAlias, get_args, get_origin

from flytekit import lazy_module
Expand Down Expand Up @@ -44,7 +45,7 @@


@dataclass
class StructuredDataset(DataClassJsonMixin):
class StructuredDataset(DataClassJSONMixin):
"""
This is the user facing StructuredDataset class. Please don't confuse it with the literals.StructuredDataset
class (that is just a model, a Python class representation of the protobuf).
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
# TODO: remove upper-bound after fixing change in contract
"dataclasses-json>=0.5.2,<0.5.12",
"marshmallow-jsonschema>=0.12.0",
"mashumaro>=3.9.1",
"marshmallow-enum",
"natsort>=7.0.1",
"docker-image-py>=0.1.10",
Expand Down
Loading

0 comments on commit e0346fb

Please sign in to comment.