From 3d2a23eafaef8a3d9186b096640659a53d035b7a Mon Sep 17 00:00:00 2001
From: Serge Smertin
Date: Wed, 11 Sep 2024 18:17:11 +0200
Subject: [PATCH] ...

---
 src/databricks/labs/lsql/backends.py | 26 +++++++-------------
 src/databricks/labs/lsql/structs.py  | 28 +++++++++++++++++++++-
 tests/integration/test_structs.py    | 36 ++++++++++++++++++++++++++++
 tests/unit/test_backends.py          | 33 +++++++++++++++++++++----
 tests/unit/test_structs.py           |  7 +++---
 5 files changed, 104 insertions(+), 26 deletions(-)
 create mode 100644 tests/integration/test_structs.py

diff --git a/src/databricks/labs/lsql/backends.py b/src/databricks/labs/lsql/backends.py
index d856381a..f8b9230b 100644
--- a/src/databricks/labs/lsql/backends.py
+++ b/src/databricks/labs/lsql/backends.py
@@ -43,6 +43,8 @@ class SqlBackend(ABC):
     execute SQL statements, fetch results from SQL statements, and save data
     to tables."""

+    # singleton shared across all SQL backends, used to infer schema from dataclasses.
+    # no state is stored in this class, so it can be shared across all instances.
     _STRUCTS = StructInference()

     @abstractmethod
@@ -61,17 +63,6 @@ def create_table(self, full_name: str, klass: Dataclass):
         ddl = f"CREATE TABLE IF NOT EXISTS {full_name} ({self._STRUCTS.as_schema(klass)}) USING DELTA"
         self.execute(ddl)

-    @classmethod
-    def _field_type(cls, field: dataclasses.Field):
-        # workaround rare (Python?) issue where f.type is the type name instead of the type itself
-        # this seems to happen when the dataclass is first used from a file importing it
-        if isinstance(field.type, str):
-            try:
-                return __builtins__[field.type]
-            except TypeError as e:
-                logger.warning(f"Could not load type {field.type}", exc_info=e)
-        return field.type
-
     @classmethod
     def _filter_none_rows(cls, rows, klass):
         if len(rows) == 0:
@@ -161,21 +152,22 @@ def _row_to_sql(cls, row: DataclassInstance, fields: tuple[dataclasses.Field[Any

     @classmethod
     def _value_to_sql(cls, value: Any) -> str:
+        """Converts a Python value to a SQL string representation."""
         if value is None:
             return "NULL"
+        if isinstance(value, bool):
+            return "TRUE" if value else "FALSE"
         if isinstance(value, int):
             return f"{value}"
         if isinstance(value, float):
-            return f"{value:0.2f}"
+            return f"{value}"
         if isinstance(value, str):
             value = str(value).replace("'", "''")
             return f"'{value}'"
-        if isinstance(value, bool):
-            return "TRUE" if value else "FALSE"
-        if isinstance(value, datetime.date):
-            return f"'{value.year}-{value.month}-{value.day}'"
         if isinstance(value, datetime.datetime):
-            return f"'{value.strftime('%Y-%m-%d %H:%M:%S')}'"
+            return f"TIMESTAMP '{value.strftime('%Y-%m-%d %H:%M:%S%z')}'"
+        if isinstance(value, datetime.date):
+            return f"DATE '{value.year}-{value.month}-{value.day}'"
         if isinstance(value, list):
             values = ", ".join(cls._value_to_sql(v) for v in value)
             return f"ARRAY({values})"
diff --git a/src/databricks/labs/lsql/structs.py b/src/databricks/labs/lsql/structs.py
index ef1f9aa0..b7f129b1 100644
--- a/src/databricks/labs/lsql/structs.py
+++ b/src/databricks/labs/lsql/structs.py
@@ -11,11 +11,15 @@ class StructInferError(TypeError):


 class SqlType(Protocol):
+    """Represents a Spark SQL type."""
+
     def as_sql(self) -> str: ...


 @dataclass
 class NullableType(SqlType):
+    """Represents a nullable type."""
+
     inner_type: SqlType

     def as_sql(self) -> str:
@@ -24,6 +28,8 @@ def as_sql(self) -> str:

 @dataclass
 class ArrayType(SqlType):
+    """Represents an array type."""
+
     element_type: SqlType

     def as_sql(self) -> str:
@@ -32,6 +38,8 @@ def as_sql(self) -> str:

 @dataclass
 class MapType(SqlType):
+    """Represents a map type."""
+
     key_type: SqlType
     value_type: SqlType

@@ -41,6 +49,8 @@ def as_sql(self) -> str:

 @dataclass
 class PrimitiveType(SqlType):
+    """Represents a primitive type."""
+
     name: str

     def as_sql(self) -> str:
@@ -49,6 +59,8 @@ def as_sql(self) -> str:

 @dataclass
 class StructField:
+    """Represents a field in a struct type."""
+
     name: str
     type: SqlType

@@ -62,13 +74,17 @@ def as_sql(self) -> str:

 @dataclass
 class StructType(SqlType):
+    """Represents a struct type."""
+
     fields: list[StructField]

     def as_sql(self) -> str:
+        """Returns a DDL representation of the struct type."""
         fields = ",".join(f.as_sql() for f in self.fields)
         return f"STRUCT<{fields}>"

     def as_schema(self) -> str:
+        """Returns a schema representation of the struct type."""
         fields = []
         for field in self.fields:
             not_null = "" if field.nullable else " NOT NULL"
@@ -77,6 +93,8 @@ def as_schema(self) -> str:


 class StructInference:
+    """Infers Spark SQL types from Python types."""
+
     _PRIMITIVES: ClassVar[dict[type, str]] = {
         str: "STRING",
         int: "LONG",
@@ -87,16 +105,19 @@ class StructInference:
     }

     def as_ddl(self, type_ref: type) -> str:
+        """Returns a DDL representation of the type."""
         v = self._infer(type_ref, [])
         return v.as_sql()

     def as_schema(self, type_ref: type) -> str:
+        """Returns a schema representation of the type."""
         v = self._infer(type_ref, [])
         if hasattr(v, "as_schema"):
             return v.as_schema()
         raise StructInferError(f"Cannot generate schema for {type_ref}")

     def _infer(self, type_ref: type, path: list[str]) -> SqlType:
+        """Infers the SQL type from the Python type. Raises StructInferError if the type is not supported."""
         if dataclasses.is_dataclass(type_ref):
             return self._infer_struct(type_ref, path)
         if isinstance(type_ref, enum.EnumMeta):
@@ -112,11 +133,13 @@ def _infer(self, type_ref: type, path: list[str]) -> SqlType:
         return self._infer_generic(type_ref, path)

     def _infer_primitive(self, type_ref: type, path: list[str]) -> PrimitiveType:
+        """Infers the primitive SQL type from the Python type. Raises StructInferError if the type is not supported."""
         if type_ref in self._PRIMITIVES:
             return PrimitiveType(self._PRIMITIVES[type_ref])
         raise StructInferError(f'{".".join(path)}: unknown: {type_ref}')

     def _infer_generic(self, type_ref: type, path: list[str]) -> SqlType:
+        """Infers the SQL type from the generic Python type. Uses internal APIs to handle generic types."""
         # pylint: disable-next=import-outside-toplevel
         from typing import (  # type: ignore[attr-defined]
             _GenericAlias,
@@ -125,7 +148,7 @@ def _infer_generic(self, type_ref: type, path: list[str]) -> SqlType:

         if isinstance(type_ref, (types.UnionType, _UnionGenericAlias)):  # type: ignore[attr-defined]
             return self._infer_nullable(type_ref, path)
-        if isinstance(type_ref, (_GenericAlias, types.GenericAlias)):  # type: ignore[attr-defined]
+        if isinstance(type_ref, (types.GenericAlias, _GenericAlias)):  # type: ignore[attr-defined]
             if type_ref.__origin__ in (dict, list) or isinstance(type_ref, types.GenericAlias):
                 return self._infer_container(type_ref, path)
         prefix = ".".join(path)
@@ -134,6 +157,7 @@ def _infer_generic(self, type_ref: type, path: list[str]) -> SqlType:
         raise StructInferError(f"{prefix}unsupported type: {type_ref.__name__}")

     def _infer_nullable(self, type_ref: type, path: list[str]) -> SqlType:
+        """Infers nullability from Optional[x] or `x | None` types."""
         type_args = get_args(type_ref)
         if len(type_args) > 2:
             raise StructInferError(f'{".".join(path)}: union: too many variants: {type_args}')
@@ -144,6 +168,7 @@ def _infer_nullable(self, type_ref: type, path: list[str]) -> SqlType:
         return NullableType(first_type)

     def _infer_container(self, type_ref: type, path: list[str]) -> SqlType:
+        """Infers the SQL type from the generic container Python type."""
         type_args = get_args(type_ref)
         if not type_args:
             raise StructInferError(f"Missing type arguments: {type_args} in {type_ref}")
@@ -156,6 +181,7 @@ def _infer_container(self, type_ref: type, path: list[str]) -> SqlType:
         return ArrayType(element_type)

     def _infer_struct(self, type_ref: type, path: list[str]) -> StructType:
+        """Infers the struct type from the Python dataclass type."""
         fields = []
         for field, hint in get_type_hints(type_ref).items():
             origin = getattr(hint, "__origin__", None)
diff --git a/tests/integration/test_structs.py b/tests/integration/test_structs.py
new file mode 100644
index 00000000..076e3225
--- /dev/null
+++ b/tests/integration/test_structs.py
@@ -0,0 +1,36 @@
+import datetime
+from dataclasses import dataclass
+
+from databricks.labs.lsql.backends import StatementExecutionBackend
+
+
+@dataclass
+class Foo:
+    first: str
+    second: bool | None
+
+
+@dataclass
+class Nested:
+    foo: Foo
+    since: datetime.date
+    created: datetime.datetime
+    mapping: dict[str, int]
+    array: list[int]
+
+
+def test_appends_complex_types(ws, env_or_skip, make_random) -> None:
+    sql_backend = StatementExecutionBackend(ws, env_or_skip("TEST_DEFAULT_WAREHOUSE_ID"))
+    today = datetime.date.today()
+    now = datetime.datetime.now()
+    full_name = f"hive_metastore.default.t{make_random(4)}"
+    sql_backend.save_table(
+        full_name,
+        [
+            Nested(Foo("a", True), today, now, {"a": 1, "b": 2}, [1, 2, 3]),
+            Nested(Foo("b", False), today, now, {"c": 3, "d": 4}, [4, 5, 6]),
+        ],
+        Nested,
+    )
+    rows = list(sql_backend.fetch(f"SELECT * FROM {full_name}"))
+    assert len(rows) == 2
diff --git a/tests/unit/test_backends.py b/tests/unit/test_backends.py
index 7ff133d8..2d95160b 100644
--- a/tests/unit/test_backends.py
+++ b/tests/unit/test_backends.py
@@ -1,3 +1,4 @@
+import datetime
 import os
 import sys
 from dataclasses import dataclass
@@ -219,7 +220,7 @@ def test_statement_execution_backend_save_table_two_records():
     )


-def test_statement_execution_backend_save_table_in_batches_of_two(mocker):
+def test_statement_execution_backend_save_table_in_batches_of_two():
     ws = create_autospec(WorkspaceClient)
     ws.statement_execution.execute_statement.return_value = StatementResponse(
@@ -435,12 +436,34 @@ def test_mock_backend_overwrite():
 @dataclass
 class Nested:
     foo: Foo
+    since: datetime.date
+    created: datetime.datetime
     mapping: dict[str, int]
     array: list[int]
+    some: float | None = None


 def test_supports_complex_types():
-    mock_backend = MockBackend()
-    mock_backend.create_table("nested", Nested)
-    expected = [...]
-    assert expected == mock_backend.queries
+    ws = create_autospec(WorkspaceClient)
+
+    ws.statement_execution.execute_statement.return_value = StatementResponse(
+        status=StatementStatus(state=StatementState.SUCCEEDED)
+    )
+
+    seb = StatementExecutionBackend(ws, "abc", max_records_per_batch=2)
+
+    today = datetime.date(2024, 9, 11)
+    now = datetime.datetime(2024, 9, 11, 12, 13, 14, tzinfo=datetime.timezone.utc)
+    seb.save_table(
+        "x",
+        [
+            Nested(Foo("a", True), today, now, {"a": 1, "b": 2}, [1, 2, 3], 0.342532),
+        ],
+        Nested,
+    )
+
+    queries = [_.kwargs["statement"] for _ in ws.statement_execution.method_calls]
+    assert [
+        "CREATE TABLE IF NOT EXISTS x (foo STRUCT<first:STRING,second:BOOLEAN> NOT NULL, since DATE NOT NULL, created TIMESTAMP NOT NULL, mapping MAP<STRING,LONG> NOT NULL, array ARRAY<LONG> NOT NULL, some FLOAT) USING DELTA",
+        "INSERT INTO x (foo, since, created, mapping, array, some) VALUES (STRUCT('a' AS first, TRUE AS second), DATE '2024-9-11', TIMESTAMP '2024-09-11 12:13:14+0000', MAP('a', 1, 'b', 2), ARRAY(1, 2, 3), 0.342532)",
+    ] == queries
diff --git a/tests/unit/test_structs.py b/tests/unit/test_structs.py
index 9dc1e956..10c90570 100644
--- a/tests/unit/test_structs.py
+++ b/tests/unit/test_structs.py
@@ -36,19 +36,20 @@ class NotDataclass:
         (datetime.date, "DATE"),
         (datetime.datetime, "TIMESTAMP"),
         (list[str], "ARRAY<STRING>"),
+        (set[str], "ARRAY<STRING>"),
         (dict[str, int], "MAP<STRING,LONG>"),
         (dict[int, list[str]], "MAP<LONG,ARRAY<STRING>>"),
         (Foo, "STRUCT<first:STRING,second:BOOLEAN>"),
         (Nested, "STRUCT<foo:STRUCT<first:STRING,second:BOOLEAN>,mapping:MAP<STRING,LONG>,array:ARRAY<LONG>>"),
     ],
 )
-def test_struct_inference(type_ref, ddl):
+def test_struct_inference(type_ref, ddl) -> None:
     inference = StructInference()
     assert inference.as_ddl(type_ref) == ddl


 @pytest.mark.parametrize("type_ref", [type(None), list, set, tuple, dict, object, NotDataclass])
-def test_struct_inference_raises_on_unknown_type(type_ref):
+def test_struct_inference_raises_on_unknown_type(type_ref) -> None:
     inference = StructInference()
     with pytest.raises(StructInferError):
         inference.as_ddl(type_ref)
@@ -65,6 +66,6 @@ def test_struct_inference_raises_on_unknown_type(type_ref):
         ),
     ],
 )
-def test_as_schema(type_ref, ddl):
+def test_as_schema(type_ref, ddl) -> None:
     inference = StructInference()
     assert inference.as_schema(type_ref) == ddl
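
Note on the `_value_to_sql` reordering in the first hunk: in Python, `bool` is a subclass of `int` and `datetime.datetime` is a subclass of `datetime.date`, so the old branch order sent `True` through the int branch and collapsed timestamps into bare dates, losing the time component; it also truncated floats to two decimals. A minimal standalone sketch of the corrected dispatch order (a hypothetical `value_to_sql` helper, not the library's actual method):

import datetime

def value_to_sql(value) -> str:
    # subclass traps: check bool before int, and datetime before date,
    # otherwise the broader base-class branch captures the value first
    if value is None:
        return "NULL"
    if isinstance(value, bool):
        return "TRUE" if value else "FALSE"
    if isinstance(value, int):
        return f"{value}"
    if isinstance(value, float):
        return f"{value}"  # full precision instead of the old :0.2f truncation
    if isinstance(value, str):
        escaped = value.replace("'", "''")  # double single quotes for SQL
        return f"'{escaped}'"
    if isinstance(value, datetime.datetime):
        return f"TIMESTAMP '{value.strftime('%Y-%m-%d %H:%M:%S%z')}'"
    if isinstance(value, datetime.date):
        return f"DATE '{value.year}-{value.month}-{value.day}'"
    raise ValueError(f"unsupported value: {value!r}")

assert value_to_sql(True) == "TRUE"  # handled before the int branch can catch it
assert value_to_sql(0.342532) == "0.342532"  # previously rendered as "0.34"
assert value_to_sql(datetime.date(2024, 9, 11)) == "DATE '2024-9-11'"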
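For reference, the DDL strings asserted in the tests follow mechanically from `StructInference`: primitives map through the `_PRIMITIVES` table (str to STRING, int to LONG, and, per the expectations above, bool to BOOLEAN), parameterized containers become parameterized Spark SQL types, and dataclasses become structs. A usage sketch, assuming only the public `as_ddl`/`as_schema` entry points shown in this diff:

from dataclasses import dataclass

from databricks.labs.lsql.structs import StructInference


@dataclass
class Foo:
    first: str
    second: bool | None


inference = StructInference()
# container type arguments are inferred recursively
assert inference.as_ddl(dict[int, list[str]]) == "MAP<LONG,ARRAY<STRING>>"
# dataclasses become STRUCT<...>; field nullability is dropped in DDL,
# while as_schema expresses it as a missing NOT NULL constraint
assert inference.as_ddl(Foo) == "STRUCT<first:STRING,second:BOOLEAN>"
assert inference.as_schema(Foo) == "first STRING NOT NULL, second BOOLEAN"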