Inherit directly from DataClassJsonMixin instead of using @dataclass_json for improved static type checking (#1801)

* Inherit directly from DataClassJsonMixin instead of @dataclass_json for improved static type checking

As the dataclasses-json README notes (https://github.com/lidatong/dataclasses-json/blob/89578cb9ebed290e70dba8946bfdb68ff6746755/README.md?plain=1#L111-L129), inheritance can be used for improved static type checking; this one change eliminates roughly 467 pyright errors from the flytekit module.

Signed-off-by: Matthew Hoffman <matthew@protopia.ai>
ringohoffman committed Aug 21, 2023
1 parent 72fcac9 commit 2164d4e
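For readers unfamiliar with the pattern, here is a minimal before/after sketch of what the change applies throughout the codebase (class and field names are illustrative, not taken from the diff). The @dataclass_json decorator attaches to_json/from_json at runtime, so static type checkers such as pyright cannot see them; inheriting from DataClassJsonMixin declares the same methods in a way the checker can resolve:

from dataclasses import dataclass

from dataclasses_json import DataClassJsonMixin, dataclass_json


# Before: methods are injected at runtime; pyright flags cfg.to_json() as unknown.
@dataclass_json
@dataclass
class BeforeConfig:
    name: str
    retries: int = 0


# After: the mixin declares to_json/from_json/to_dict/from_dict statically.
@dataclass
class AfterConfig(DataClassJsonMixin):
    name: str
    retries: int = 0


cfg = AfterConfig(name="example")
payload = cfg.to_json()                  # type-checks cleanly
restored = AfterConfig.from_json(payload)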
Showing 29 changed files with 131 additions and 215 deletions.
19 changes: 7 additions & 12 deletions flytekit/configuration/__init__.py
@@ -144,7 +144,7 @@
from typing import Dict, List, Optional

import yaml
from dataclasses_json import dataclass_json
from dataclasses_json import DataClassJsonMixin

from flytekit.configuration import internal as _internal
from flytekit.configuration.default_images import DefaultImages
@@ -164,9 +164,8 @@
SERIALIZED_CONTEXT_ENV_VAR = "_F_SS_C"


@dataclass_json
@dataclass(init=True, repr=True, eq=True, frozen=True)
class Image(object):
class Image(DataClassJsonMixin):
"""
Image is a structured wrapper for task container images used in object serialization.
@@ -224,9 +223,8 @@ def look_up_image_info(name: str, tag: str, optional_tag: bool = False) -> Image
return Image(name=name, fqn=ref["name"], tag=ref["tag"])


@dataclass_json
@dataclass(init=True, repr=True, eq=True, frozen=True)
class ImageConfig(object):
class ImageConfig(DataClassJsonMixin):
"""
We recommend you to use ImageConfig.auto(img_name=None) to create an ImageConfig.
For example, ImageConfig.auto(img_name="ghcr.io/flyteorg/flytecookbook:v1.0.0") will create an ImageConfig.
@@ -671,9 +669,8 @@ def for_endpoint(
return c.with_params(platform=PlatformConfig.for_endpoint(endpoint, insecure), data_config=data_config)


@dataclass_json
@dataclass
class EntrypointSettings(object):
class EntrypointSettings(DataClassJsonMixin):
"""
This object carries information about the path of the entrypoint command that will be invoked at runtime.
This is where `pyflyte-execute` code can be found. This is useful for cases like pyspark execution.
@@ -682,9 +679,8 @@ class EntrypointSettings(object):
path: Optional[str] = None


@dataclass_json
@dataclass
class FastSerializationSettings(object):
class FastSerializationSettings(DataClassJsonMixin):
"""
This object holds information about settings necessary to serialize an object so that it can be fast-registered.
"""
@@ -698,9 +694,8 @@ class FastSerializationSettings(object):


# TODO: ImageConfig, python_interpreter, venv_root, fast_serialization_settings.destination_dir should be combined.
@dataclass_json
@dataclass()
class SerializationSettings(object):
@dataclass
class SerializationSettings(DataClassJsonMixin):
"""
These settings are provided while serializing a workflow and task, before registration. This is required to get
runtime information at serialization time, as well as some defaults.
39 changes: 19 additions & 20 deletions flytekit/core/type_engine.py
@@ -12,7 +12,7 @@
import typing
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import Dict, NamedTuple, Optional, Type, cast
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, cast

from dataclasses_json import DataClassJsonMixin, dataclass_json
from google.protobuf import json_format as _json_format
@@ -220,8 +220,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:

class DataclassTransformer(TypeTransformer[object]):
"""
The Dataclass Transformer, provides a type transformer for arbitrary Python dataclasses, that have
@dataclass and @dataclass_json decorators.
The Dataclass Transformer provides a type transformer for dataclasses_json dataclasses.
The Dataclass is converted to and from json and is transported between tasks using the proto.Structpb representation
Also the type declaration will try to extract the JSON Schema for the object if possible and pass it with the
@@ -233,9 +232,8 @@ class DataclassTransformer(TypeTransformer[object]):
.. code-block:: python
@dataclass_json
@dataclass
class Test():
class Test(DataClassJsonMixin):
a: int
b: str
@@ -270,9 +268,8 @@ def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T):
if type(v) == expected_type:
return

# @dataclass_json
# @dataclass
# class Foo(object):
# class Foo(DataClassJsonMixin):
# a: int = 0
#
# @task
@@ -318,7 +315,8 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:

if not issubclass(t, DataClassJsonMixin):
raise AssertionError(
f"Dataclass {t} should be decorated with @dataclass_json to be " f"serialized correctly"
f"Dataclass {t} should be decorated with @dataclass_json or be a subclass of DataClassJsonMixin to be "
"serialized correctly"
)
schema = None
try:
@@ -349,7 +347,8 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
)
if not issubclass(type(python_val), DataClassJsonMixin):
raise TypeTransformerFailedError(
f"Dataclass {python_type} should be decorated with @dataclass_json to be " f"serialized correctly"
f"Dataclass {python_type} should be decorated with @dataclass_json or be a subclass of "
"DataClassJsonMixin to be serialized correctly"
)
self._serialize_flyte_type(python_val, python_type)
return Literal(
@@ -429,10 +428,10 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.A
or issubclass(python_type, StructuredDataset)
):
lv = TypeEngine.to_literal(FlyteContext.current_context(), python_val, python_type, None)
# dataclass_json package will extract the "path" from FlyteFile, FlyteDirectory, and write it to a
# dataclasses_json package will extract the "path" from FlyteFile, FlyteDirectory, and write it to a
# JSON which will be stored in IDL. The path here should always be a remote path, but sometimes the
# path in FlyteFile and FlyteDirectory could be a local path. Therefore, reset the python value here,
# so that dataclass_json can always get a remote path.
# so that dataclasses_json can always get a remote path.
# In other words, the file transformer has special code that handles the fact that if remote_source is
# set, then the real uri in the literal should be the remote source, not the path (which may be an
# auto-generated random local path). To be sure we're writing the right path to the json, use the uri
@@ -596,12 +595,12 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
if not dataclasses.is_dataclass(expected_python_type):
raise TypeTransformerFailedError(
f"{expected_python_type} is not of type @dataclass, only Dataclasses are supported for "
f"user defined datatypes in Flytekit"
"user defined datatypes in Flytekit"
)
if not issubclass(expected_python_type, DataClassJsonMixin):
raise TypeTransformerFailedError(
f"Dataclass {expected_python_type} should be decorated with @dataclass_json to be "
f"serialized correctly"
f"Dataclass {expected_python_type} should be decorated with @dataclass_json or be a subclass of "
"DataClassJsonMixin to be serialized correctly"
)
json_str = _json_format.MessageToJson(lv.scalar.generic)
dc = cast(DataClassJsonMixin, expected_python_type).from_json(json_str)
@@ -1520,32 +1519,32 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
return expected_python_type(lv.scalar.primitive.string_value) # type: ignore


def convert_json_schema_to_python_class(schema: dict, schema_name) -> Type[dataclasses.dataclass()]: # type: ignore
def convert_json_schema_to_python_class(schema: Dict[str, Any], schema_name: str) -> Type[Any]:
"""
Generate a model class based on the provided JSON Schema
:param schema: dict representing valid JSON schema
:param schema_name: dataclass name of return type
"""
attribute_list = []
attribute_list: List[Tuple[str, type]] = []
for property_key, property_val in schema[schema_name]["properties"].items():
property_type = property_val["type"]
# Handle list
if property_val["type"] == "array":
attribute_list.append((property_key, typing.List[_get_element_type(property_val["items"])])) # type: ignore
attribute_list.append((property_key, List[_get_element_type(property_val["items"])])) # type: ignore[misc,index]
# Handle dataclass and dict
elif property_type == "object":
if property_val.get("$ref"):
name = property_val["$ref"].split("/")[-1]
attribute_list.append((property_key, convert_json_schema_to_python_class(schema, name)))
elif property_val.get("additionalProperties"):
attribute_list.append(
(property_key, typing.Dict[str, _get_element_type(property_val["additionalProperties"])]) # type: ignore
(property_key, Dict[str, _get_element_type(property_val["additionalProperties"])]) # type: ignore[misc,index]
)
else:
attribute_list.append((property_key, typing.Dict[str, _get_element_type(property_val)])) # type: ignore
attribute_list.append((property_key, Dict[str, _get_element_type(property_val)])) # type: ignore[misc,index]
# Handle int, float, bool or str
else:
attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore
attribute_list.append((property_key, _get_element_type(property_val)))

return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list))

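To illustrate what the DataclassTransformer's updated docstring describes — a DataClassJsonMixin dataclass serialized to JSON and carried between tasks — here is a small usage sketch (assuming the standard flytekit task/workflow decorators; the Test class mirrors the docstring example above):

from dataclasses import dataclass

from dataclasses_json import DataClassJsonMixin

from flytekit import task, workflow


@dataclass
class Test(DataClassJsonMixin):
    a: int
    b: str


@task
def build(a: int, b: str) -> Test:
    # The DataclassTransformer serializes the returned dataclass (via its
    # JSON form) into a protobuf Struct when the task output is recorded.
    return Test(a=a, b=b)


@task
def consume(t: Test) -> str:
    # On the consuming side the Struct is converted back into a Test instance.
    return f"{t.a}-{t.b}"


@workflow
def wf(a: int = 1, b: str = "flyte") -> str:
    return consume(t=build(a=a, b=b))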
5 changes: 2 additions & 3 deletions flytekit/extras/pytorch/checkpoint.py
@@ -4,7 +4,7 @@
from typing import Any, Callable, Dict, NamedTuple, Optional, Type, Union

import torch
from dataclasses_json import dataclass_json
from dataclasses_json import DataClassJsonMixin
from typing_extensions import Protocol

from flytekit.core.context_manager import FlyteContext
@@ -20,9 +20,8 @@ class IsDataclass(Protocol):
__post_init__: Optional[Callable]


@dataclass_json
@dataclass
class PyTorchCheckpoint:
class PyTorchCheckpoint(DataClassJsonMixin):
"""
This class is helpful to save a checkpoint.
"""
5 changes: 2 additions & 3 deletions flytekit/extras/tensorflow/record.py
@@ -3,7 +3,7 @@
from typing import Optional, Tuple, Type, Union

import tensorflow as tf
from dataclasses_json import dataclass_json
from dataclasses_json import DataClassJsonMixin
from tensorflow.python.data.ops.readers import TFRecordDatasetV2
from typing_extensions import Annotated, get_args, get_origin

@@ -16,9 +16,8 @@
from flytekit.types.file import TFRecordFile


@dataclass_json
@dataclass
class TFRecordDatasetConfig:
class TFRecordDatasetConfig(DataClassJsonMixin):
"""
TFRecordDatasetConfig can be used while creating tf.data.TFRecordDataset comprising
record of one or more TFRecord files.
5 changes: 2 additions & 3 deletions flytekit/types/directory/types.py
@@ -10,7 +10,7 @@
from uuid import UUID

import fsspec
from dataclasses_json import config, dataclass_json
from dataclasses_json import DataClassJsonMixin, config
from fsspec.utils import get_protocol
from marshmallow import fields

@@ -30,9 +30,8 @@ def noop():
...


@dataclass_json
@dataclass
class FlyteDirectory(os.PathLike, typing.Generic[T]):
class FlyteDirectory(DataClassJsonMixin, os.PathLike, typing.Generic[T]):
path: PathType = field(default=None, metadata=config(mm_field=fields.String())) # type: ignore
"""
.. warning::
5 changes: 2 additions & 3 deletions flytekit/types/file/file.py
@@ -6,7 +6,7 @@
from contextlib import contextmanager
from dataclasses import dataclass, field

from dataclasses_json import config, dataclass_json
from dataclasses_json import DataClassJsonMixin, config
from marshmallow import fields

from flytekit.core.context_manager import FlyteContext, FlyteContextManager
@@ -25,9 +25,8 @@ def noop():
T = typing.TypeVar("T")


@dataclass_json
@dataclass
class FlyteFile(os.PathLike, typing.Generic[T]):
class FlyteFile(DataClassJsonMixin, os.PathLike, typing.Generic[T]):
path: typing.Union[str, os.PathLike] = field(
default=None, metadata=config(mm_field=fields.String())
) # type: ignore
5 changes: 2 additions & 3 deletions flytekit/types/schema/types.py
@@ -11,7 +11,7 @@

import numpy as _np
import pandas
from dataclasses_json import config, dataclass_json
from dataclasses_json import DataClassJsonMixin, config
from marshmallow import fields

from flytekit.core.context_manager import FlyteContext, FlyteContextManager
@@ -179,9 +179,8 @@ def get_handler(cls, t: Type) -> SchemaHandler:
return cls._SCHEMA_HANDLERS[t]


@dataclass_json
@dataclass
class FlyteSchema(object):
class FlyteSchema(DataClassJsonMixin):
remote_path: typing.Optional[str] = field(default=None, metadata=config(mm_field=fields.String()))
"""
This is the main schema class that users should use.
5 changes: 2 additions & 3 deletions flytekit/types/structured/structured_dataset.py
@@ -8,7 +8,7 @@
from typing import Dict, Generator, Optional, Type, Union

import _datetime
from dataclasses_json import config, dataclass_json
from dataclasses_json import DataClassJsonMixin, config
from fsspec.utils import get_protocol
from marshmallow import fields
from typing_extensions import Annotated, TypeAlias, get_args, get_origin
@@ -43,9 +43,8 @@
GENERIC_PROTOCOL: str = "generic protocol"


@dataclass_json
@dataclass
class StructuredDataset(object):
class StructuredDataset(DataClassJsonMixin):
"""
This is the user facing StructuredDataset class. Please don't confuse it with the literals.StructuredDataset
class (that is just a model, a Python class representation of the protobuf).
7 changes: 3 additions & 4 deletions plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional

from dataclasses_json import dataclass_json
from dataclasses_json import DataClassJsonMixin
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Struct

@@ -10,9 +10,8 @@
from flytekit.extend import TaskPlugins


@dataclass_json
@dataclass
class AWSBatchConfig(object):
class AWSBatchConfig(DataClassJsonMixin):
"""
Use this to configure SubmitJobInput for a AWS batch job. Task's marked with this will automatically execute
natively onto AWS batch service.
@@ -27,7 +26,7 @@ class AWSBatchConfig(object):

def to_dict(self):
s = Struct()
s.update(self.to_dict())
s.update(super().to_dict())
return json_format.MessageToDict(s)


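One subtlety in the plugin change above: now that AWSBatchConfig inherits from DataClassJsonMixin, the to_dict defined in the class body overrides the mixin's method, so it must call super().to_dict() to obtain the plain dataclass dictionary — calling self.to_dict() there would recurse into the override. A minimal sketch of the same pattern (hypothetical field names, not the plugin's actual config):

from dataclasses import dataclass
from typing import Dict, Optional

from dataclasses_json import DataClassJsonMixin
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Struct


@dataclass
class ExampleBatchConfig(DataClassJsonMixin):
    # Hypothetical fields for illustration only.
    parameters: Optional[Dict[str, str]] = None
    tags: Optional[Dict[str, str]] = None

    def to_dict(self):
        s = Struct()
        # super().to_dict() is DataClassJsonMixin's serializer; using
        # self.to_dict() here would call this override again and recurse.
        s.update(super().to_dict())
        return json_format.MessageToDict(s)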
