Skip to content

Commit

Permalink
add test && fix lint
Browse files Browse the repository at this point in the history
Signed-off-by: hhcs9527 <hhcs9527@gmail.com>
  • Loading branch information
hhcs9527 committed Aug 6, 2023
1 parent fcc0a0d commit 892396c
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 74 deletions.
23 changes: 17 additions & 6 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1544,12 +1544,13 @@ def to_literal(
def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
return expected_python_type(lv.scalar.primitive.string_value) # type: ignore


def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name: typing.Any):
attribute_list = []
for property_key, property_val in schema["properties"].items():
if property_val.get("anyOf"):
property_type = property_val["anyOf"][0]["type"]
elif property_val.get("enum") :
elif property_val.get("enum"):
property_type = "enum"
else:
property_type = property_val["type"]
Expand All @@ -1559,20 +1560,25 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name:
# Handle dataclass and dict
elif property_type == "object":
if property_val.get("anyOf"):
attribute_list.append((property_key, convert_json_schema_to_python_class(property_val["anyOf"][0], schema_name, True)))
attribute_list.append(
(property_key, convert_json_schema_to_python_class(property_val["anyOf"][0], schema_name, True))
)
elif property_val.get("additionalProperties"):
attribute_list.append(
(property_key, typing.Dict[str, _get_element_type(property_val["additionalProperties"])]) # type: ignore
)
else:
attribute_list.append((property_key, convert_json_schema_to_python_class(property_val, schema_name, True)))
attribute_list.append(
(property_key, convert_json_schema_to_python_class(property_val, schema_name, True))
)
elif property_type == "enum":
attribute_list.append([property_key, str]) # type: ignore
# Handle int, float, bool or str
else:
attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore
return attribute_list


def generate_attribute_list_from_dataclass_json(schema: dict, schema_name: typing.Any):
attribute_list = []
for property_key, property_val in schema[schema_name]["properties"].items():
Expand All @@ -1596,7 +1602,8 @@ def generate_attribute_list_from_dataclass_json(schema: dict, schema_name: typin
attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore
return attribute_list

def convert_json_schema_to_python_class(schema: dict, schema_name: typing.Any, is_dataclass_json_mixin:bool=False) -> Type[dataclasses.dataclass()]: # type: ignore

def convert_json_schema_to_python_class(schema: dict, schema_name: typing.Any, is_dataclass_json_mixin: bool = False) -> Type[dataclasses.dataclass()]: # type: ignore
"""
Generate a model class based on the provided JSON Schema
:param schema: dict representing valid JSON schema
Expand All @@ -1610,8 +1617,12 @@ def convert_json_schema_to_python_class(schema: dict, schema_name: typing.Any, i
return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list))


def _get_element_type(element_property: typing.Dict[str, str]) -> Type:
element_type = [e_property["type"] for e_property in element_property["anyOf"]] if element_property.get("anyOf") else element_property["type"]
def _get_element_type(element_property: typing.Dict[str, typing.Any]) -> Type:
element_type = (
[e_property["type"] for e_property in element_property["anyOf"]]
if element_property.get("anyOf")
else element_property["type"]
)
element_format = element_property["format"] if "format" in element_property else None

if type(element_type) == list:
Expand Down
6 changes: 3 additions & 3 deletions flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from contextlib import contextmanager
from dataclasses import dataclass, field

from dataclasses_json import config, dataclass_json
from dataclasses_json import config
from marshmallow import fields
from mashumaro.mixins.json import DataClassJSONMixin

from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError, get_underlying_type
Expand All @@ -25,9 +26,8 @@ def noop():
T = typing.TypeVar("T")


@dataclass_json
@dataclass
class FlyteFile(os.PathLike, typing.Generic[T]):
class FlyteFile(os.PathLike, typing.Generic[T], DataClassJSONMixin):
path: typing.Union[str, os.PathLike] = field(
default=None, metadata=config(mm_field=fields.String())
) # type: ignore
Expand Down
4 changes: 2 additions & 2 deletions flytekit/types/schema/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@

import numpy as _np
import pandas
from dataclasses_json import config, dataclass_json
from dataclasses_json import config
from marshmallow import fields
from mashumaro.mixins.json import DataClassJSONMixin

from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.loggers import logger
from flytekit.models.literals import Literal, Scalar, Schema
from flytekit.models.types import LiteralType, SchemaType
from mashumaro.mixins.json import DataClassJSONMixin

T = typing.TypeVar("T")

Expand Down
4 changes: 2 additions & 2 deletions flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
from typing import Dict, Generator, Optional, Type, Union

import _datetime
from dataclasses_json import config, dataclass_json
from dataclasses_json import config
from fsspec.utils import get_protocol
from marshmallow import fields
from mashumaro.mixins.json import DataClassJSONMixin
from typing_extensions import Annotated, TypeAlias, get_args, get_origin

from flytekit import lazy_module
Expand All @@ -22,7 +23,6 @@
from flytekit.models import types as type_models
from flytekit.models.literals import Literal, Scalar, StructuredDatasetMetadata
from flytekit.models.types import LiteralType, SchemaType, StructuredDatasetType
from mashumaro.mixins.json import DataClassJSONMixin

if typing.TYPE_CHECKING:
import pandas as pd
Expand Down
154 changes: 93 additions & 61 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from marshmallow_enum import LoadDumpOptions
from marshmallow_jsonschema import JSONSchema
from mashumaro.mixins.json import DataClassJSONMixin
import mashumaro
from pandas._testing import assert_frame_equal
from typing_extensions import Annotated, get_args, get_origin

Expand Down Expand Up @@ -486,7 +485,6 @@ class Foo(DataClassJSONMixin):
x: int
y: str

# schema = JSONSchema().dump(typing.cast(DataClassJSONMixin, Foo).schema())
from mashumaro.jsonschema import build_json_schema

schema = build_json_schema(typing.cast(DataClassJSONMixin, Foo)).to_dict()
Expand Down Expand Up @@ -728,9 +726,15 @@ class TestStructD_transformer(DataClassJSONMixin):
m: typing.Dict[str, typing.List[int]]


@dataclass # to ask => not support => failed right away
@dataclass
class UnsupportedSchemaType_transformer:
_a:str="Hello"
_a: str = "Hello"


# Mixin-based dataclass whose field ``s`` is typed as
# UnsupportedSchemaType_transformer, a plain ``@dataclass`` that does not
# inherit DataClassJSONMixin.
@dataclass
class UnsupportedNestedStruct_transformer(DataClassJSONMixin):
    a: int
    # Nested plain dataclass (no DataClassJSONMixin base).
    s: UnsupportedSchemaType_transformer


def test_dataclass_transformer_with_dataclassjsonmixin():
Expand All @@ -742,48 +746,17 @@ def test_dataclass_transformer_with_dataclassjsonmixin():
"type": "object",
"title": "InnerStruct_transformer",
"properties": {
"a": {
"type": "integer"
},
"b": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
]
},
"c": {
"type": "array",
"items": {
"type": "integer"
}
}
"a": {"type": "integer"},
"b": {"anyOf": [{"type": "string"}, {"type": "null"}]},
"c": {"type": "array", "items": {"type": "integer"}},
},
"additionalProperties": False,
"required": [
"a",
"b",
"c"
]
"required": ["a", "b", "c"],
},
"m": {
"type": "object",
"additionalProperties": {
"type": "string"
},
"propertyNames": {
"type": "string"
}
}
"m": {"type": "object", "additionalProperties": {"type": "string"}, "propertyNames": {"type": "string"}},
},
"additionalProperties": False,
"required": [
"s",
"m"
]
"required": ["s", "m"],
}

tf = DataclassTransformer()
Expand All @@ -801,17 +774,13 @@ def test_dataclass_transformer_with_dataclassjsonmixin():
assert t.metadata is not None
assert t.metadata == schema

t = tf.get_literal_type(UnsupportedNestedStruct)
assert t is not None
assert t.simple is not None
assert t.simple == SimpleType.STRUCT
assert t.metadata is None

@pytest.mark.xfail(raises=mashumaro.exceptions.UnserializableField)
def test_unsupported_schema_type():
# The code that is expected to raise the exception during class definition
@dataclass
class UnsupportedNestedStruct_transformer(DataClassJSONMixin):
a: int
s: UnsupportedSchemaType_transformer

tf = DataclassTransformer()
t = tf.get_literal_type(UnsupportedNestedStruct_transformer)
def test_dataclass_int_preserving():
ctx = FlyteContext.current_context()

Expand Down Expand Up @@ -974,20 +943,20 @@ def test_optional_flytefile_in_dataclassjsonmixin(mock_upload_dir):
lt = tf.get_literal_type(TestFileStruct_optional_flytefile)
lv = tf.to_literal(ctx, o, TestFileStruct_optional_flytefile, lt)

assert lv.scalar.generic["a"] == remote_path
assert lv.scalar.generic["b"] == remote_path
assert lv.scalar.generic["a"].fields["path"].string_value == remote_path
assert lv.scalar.generic["b"].fields["path"].string_value == remote_path
assert lv.scalar.generic["b_prime"] is None
assert lv.scalar.generic["c"] == remote_path
assert lv.scalar.generic["d"].values[0].string_value == remote_path
assert lv.scalar.generic["e"].values[0].string_value == remote_path
assert lv.scalar.generic["c"].fields["path"].string_value == remote_path
assert lv.scalar.generic["d"].values[0].struct_value.fields["path"].string_value == remote_path
assert lv.scalar.generic["e"].values[0].struct_value.fields["path"].string_value == remote_path
assert lv.scalar.generic["e_prime"].values[0].WhichOneof("kind") == "null_value"
assert lv.scalar.generic["f"]["a"] == remote_path
assert lv.scalar.generic["g"]["a"] == remote_path
assert lv.scalar.generic["f"]["a"].fields["path"].string_value == remote_path
assert lv.scalar.generic["g"]["a"].fields["path"].string_value == remote_path
assert lv.scalar.generic["g_prime"]["a"] is None
assert lv.scalar.generic["h"] == remote_path
assert lv.scalar.generic["h"].fields["path"].string_value == remote_path
assert lv.scalar.generic["h_prime"] is None
assert lv.scalar.generic["i"]["a"] == 42
assert lv.scalar.generic["i_prime"]["a"] == 99
assert lv.scalar.generic["i"].fields["a"].number_value == 42
assert lv.scalar.generic["i_prime"].fields["a"].number_value == 99

ot = tf.to_python_value(ctx, lv=lv, expected_python_type=TestFileStruct_optional_flytefile)

Expand Down Expand Up @@ -1233,6 +1202,41 @@ class DatasetStruct(object):
assert "parquet" == ot.b.c["hello"].file_format


# DataClassJSONMixin dataclass holding StructuredDataset values in three
# shapes: a bare field, a list of annotated datasets, and a str-keyed dict of
# annotated datasets. It is nested inside the dataclass built by
# test_structured_dataset_in_dataclassjsonmixin.
@dataclass
class InnerDatasetStruct_dataclassjsonmixin(DataClassJSONMixin):
    # Un-annotated structured dataset.
    a: StructuredDataset
    # List elements are annotated with the "parquet" format string.
    b: typing.List[Annotated[StructuredDataset, "parquet"]]
    # Dict values are annotated with column types built via kwtypes(Name=str, Age=int).
    c: typing.Dict[str, Annotated[StructuredDataset, kwtypes(Name=str, Age=int)]]


def test_structured_dataset_in_dataclassjsonmixin():
    """Round-trip StructuredDataset fields nested in DataClassJSONMixin dataclasses.

    The same parquet-backed dataset is placed in a top-level field, a nested
    field, a list and a dict; after converting to a literal and back, every
    copy must yield the original dataframe and report the "parquet" format.
    """
    frame = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
    People = Annotated[StructuredDataset, "parquet"]

    @dataclass
    class DatasetStruct_dataclassjsonmixin(DataClassJSONMixin):
        a: People
        b: InnerDatasetStruct_dataclassjsonmixin

    dataset = StructuredDataset(dataframe=frame, file_format="parquet")
    nested = InnerDatasetStruct_dataclassjsonmixin(a=dataset, b=[dataset], c={"hello": dataset})
    original = DatasetStruct_dataclassjsonmixin(a=dataset, b=nested)

    context = FlyteContext.current_context()
    transformer = DataclassTransformer()
    literal_type = transformer.get_literal_type(DatasetStruct_dataclassjsonmixin)
    literal = transformer.to_literal(context, original, DatasetStruct_dataclassjsonmixin, literal_type)
    restored = transformer.to_python_value(
        context, lv=literal, expected_python_type=DatasetStruct_dataclassjsonmixin
    )

    # Every place the dataset was stored, in the order the checks are run.
    recovered = (restored.a, restored.b.a, restored.b.b[0], restored.b.c["hello"])

    # First verify the data of all copies, then the recorded file formats.
    for item in recovered:
        assert_frame_equal(frame, item.open(pd.DataFrame).all())
    for item in recovered:
        assert "parquet" == item.file_format


# Enums should have string values
class Color(Enum):
RED = "red"
Expand Down Expand Up @@ -1391,6 +1395,7 @@ class ArgsAssert(DataClassJSONMixin):
x: int
y: typing.Optional[str]


@dataclass
class SchemaArgsAssert(DataClassJSONMixin):
x: typing.Optional[ArgsAssert]
Expand All @@ -1410,7 +1415,8 @@ class Bar(DataClassJSONMixin):

pv = Bar(x=3)
with pytest.raises(
TypeTransformerFailedError, match="Type of Val '<class 'int'>' is not an instance of <class 'types.SchemaArgsAssert'>"
TypeTransformerFailedError,
match="Type of Val '<class 'int'>' is not an instance of <class 'types.SchemaArgsAssert'>",
):
DataclassTransformer().assert_type(gt, pv)

Expand Down Expand Up @@ -2017,6 +2023,32 @@ def test_schema_in_dataclass():
assert o == ot


# DataClassJSONMixin dataclass pairing an integer with a TestSchema value;
# used as the ``result`` field type of Result_dataclassjsonmixin.
@dataclass
class InnerResult_dataclassjsonmixin(DataClassJSONMixin):
    number: int
    schema: TestSchema  # type: ignore


# Outer DataClassJSONMixin dataclass: nests InnerResult_dataclassjsonmixin and
# additionally carries its own TestSchema field.
@dataclass
class Result_dataclassjsonmixin(DataClassJSONMixin):
    result: InnerResult_dataclassjsonmixin
    schema: TestSchema  # type: ignore


def test_schema_in_dataclassjsonmixin():
    """Round-trip nested ``DataClassJSONMixin`` dataclasses that hold ``TestSchema`` fields.

    Builds ``Result_dataclassjsonmixin`` / ``InnerResult_dataclassjsonmixin``
    so that the mixin-based dataclasses are what ``DataclassTransformer``
    converts to a literal and back. The restored object must compare equal to
    the original.
    """
    schema = TestSchema()
    df = pd.DataFrame(data={"some_str": ["a", "b", "c"]})
    # Write the dataframe into the schema before it is serialized.
    schema.open().write(df)

    o = Result_dataclassjsonmixin(
        result=InnerResult_dataclassjsonmixin(number=1, schema=schema),
        schema=schema,
    )
    ctx = FlyteContext.current_context()
    tf = DataclassTransformer()
    lt = tf.get_literal_type(Result_dataclassjsonmixin)
    lv = tf.to_literal(ctx, o, Result_dataclassjsonmixin, lt)
    ot = tf.to_python_value(ctx, lv=lv, expected_python_type=Result_dataclassjsonmixin)

    assert o == ot


def test_guess_of_dataclass():
@dataclass_json
@dataclass()
Expand Down

0 comments on commit 892396c

Please sign in to comment.