Commit 3d2a23e

nfx committed Sep 11, 2024
1 parent 1ef5fd4 commit 3d2a23e
Showing 5 changed files with 104 additions and 26 deletions.
26 changes: 9 additions & 17 deletions src/databricks/labs/lsql/backends.py
@@ -43,6 +43,8 @@ class SqlBackend(ABC):
     execute SQL statements, fetch results from SQL statements, and save data
     to tables."""
 
+    # singleton shared across all SQL backends, used to infer schema from dataclasses.
+    # no state is stored in this class, so it can be shared across all instances.
     _STRUCTS = StructInference()
 
     @abstractmethod
@@ -61,17 +63,6 @@ def create_table(self, full_name: str, klass: Dataclass):
         ddl = f"CREATE TABLE IF NOT EXISTS {full_name} ({self._STRUCTS.as_schema(klass)}) USING DELTA"
         self.execute(ddl)
 
-    @classmethod
-    def _field_type(cls, field: dataclasses.Field):
-        # workaround rare (Python?) issue where f.type is the type name instead of the type itself
-        # this seems to happen when the dataclass is first used from a file importing it
-        if isinstance(field.type, str):
-            try:
-                return __builtins__[field.type]
-            except TypeError as e:
-                logger.warning(f"Could not load type {field.type}", exc_info=e)
-        return field.type
-
     @classmethod
     def _filter_none_rows(cls, rows, klass):
         if len(rows) == 0:
@@ -161,21 +152,22 @@ def _row_to_sql(cls, row: DataclassInstance, fields: tuple[dataclasses.Field[Any
 
     @classmethod
     def _value_to_sql(cls, value: Any) -> str:
+        """Converts a Python value to a SQL string representation."""
         if value is None:
             return "NULL"
+        if isinstance(value, bool):
+            return "TRUE" if value else "FALSE"
         if isinstance(value, int):
             return f"{value}"
         if isinstance(value, float):
-            return f"{value:0.2f}"
+            return f"{value}"
         if isinstance(value, str):
             value = str(value).replace("'", "''")
             return f"'{value}'"
-        if isinstance(value, bool):
-            return "TRUE" if value else "FALSE"
-        if isinstance(value, datetime.date):
-            return f"'{value.year}-{value.month}-{value.day}'"
         if isinstance(value, datetime.datetime):
-            return f"'{value.strftime('%Y-%m-%d %H:%M:%S')}'"
+            return f"TIMESTAMP '{value.strftime('%Y-%m-%d %H:%M:%S%z')}'"
+        if isinstance(value, datetime.date):
+            return f"DATE '{value.year}-{value.month}-{value.day}'"
         if isinstance(value, list):
             values = ", ".join(cls._value_to_sql(v) for v in value)
             return f"ARRAY({values})"
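
The reordering inside `_value_to_sql` follows from Python's subclass relationships: `bool` is a subclass of `int`, and `datetime.datetime` is a subclass of `datetime.date`, so the narrower isinstance check has to run first. With the old order, booleans fell into the int branch and datetimes fell into the date branch, losing their time component. The float branch also drops the old `:0.2f` format, which silently rounded values to two decimal places. A minimal standalone sketch of the new dispatch order (an illustration distilled from the diff above, not the library's public API):

import datetime


def value_to_sql(value) -> str:
    # Mirrors the rules visible in the diff above; the order of the
    # isinstance checks is significant.
    if value is None:
        return "NULL"
    if isinstance(value, bool):  # bool subclasses int, so check it before int
        return "TRUE" if value else "FALSE"
    if isinstance(value, int):
        return f"{value}"
    if isinstance(value, float):
        return f"{value}"
    if isinstance(value, str):
        escaped = value.replace("'", "''")  # escape single quotes for SQL
        return f"'{escaped}'"
    if isinstance(value, datetime.datetime):  # datetime subclasses date, so check it before date
        return f"TIMESTAMP '{value.strftime('%Y-%m-%d %H:%M:%S%z')}'"
    if isinstance(value, datetime.date):
        return f"DATE '{value.year}-{value.month}-{value.day}'"
    raise ValueError(f"unsupported value: {value!r}")


assert value_to_sql(True) == "TRUE"  # previously rendered via the int branch
assert value_to_sql(datetime.date(2024, 9, 11)) == "DATE '2024-9-11'"
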
28 changes: 27 additions & 1 deletion src/databricks/labs/lsql/structs.py
@@ -11,11 +11,15 @@ class StructInferError(TypeError):
 
 
 class SqlType(Protocol):
+    """Represents a Spark SQL type."""
+
     def as_sql(self) -> str: ...
 
 
 @dataclass
 class NullableType(SqlType):
+    """Represents a nullable type."""
+
     inner_type: SqlType
 
     def as_sql(self) -> str:
@@ -24,6 +28,8 @@ def as_sql(self) -> str:
 
 @dataclass
 class ArrayType(SqlType):
+    """Represents an array type."""
+
     element_type: SqlType
 
     def as_sql(self) -> str:
@@ -32,6 +38,8 @@ def as_sql(self) -> str:
 
 @dataclass
 class MapType(SqlType):
+    """Represents a map type."""
+
     key_type: SqlType
     value_type: SqlType
 
@@ -41,6 +49,8 @@ def as_sql(self) -> str:
 
 @dataclass
 class PrimitiveType(SqlType):
+    """Represents a primitive type."""
+
     name: str
 
     def as_sql(self) -> str:
@@ -49,6 +59,8 @@ def as_sql(self) -> str:
 
 @dataclass
 class StructField:
+    """Represents a field in a struct type."""
+
     name: str
     type: SqlType
 
@@ -62,13 +74,17 @@ def as_sql(self) -> str:
 
 @dataclass
 class StructType(SqlType):
+    """Represents a struct type."""
+
     fields: list[StructField]
 
     def as_sql(self) -> str:
+        """Returns a DDL representation of the struct type."""
         fields = ",".join(f.as_sql() for f in self.fields)
         return f"STRUCT<{fields}>"
 
     def as_schema(self) -> str:
+        """Returns a schema representation of the struct type."""
         fields = []
         for field in self.fields:
             not_null = "" if field.nullable else " NOT NULL"
@@ -77,6 +93,8 @@ def as_schema(self) -> str:
 
 
 class StructInference:
+    """Infers Spark SQL types from Python types."""
+
     _PRIMITIVES: ClassVar[dict[type, str]] = {
         str: "STRING",
         int: "LONG",
@@ -87,16 +105,19 @@ class StructInference:
     }
 
     def as_ddl(self, type_ref: type) -> str:
+        """Returns a DDL representation of the type."""
         v = self._infer(type_ref, [])
         return v.as_sql()
 
     def as_schema(self, type_ref: type) -> str:
+        """Returns a schema representation of the type."""
         v = self._infer(type_ref, [])
         if hasattr(v, "as_schema"):
             return v.as_schema()
         raise StructInferError(f"Cannot generate schema for {type_ref}")
 
     def _infer(self, type_ref: type, path: list[str]) -> SqlType:
+        """Infers the SQL type from the Python type. Raises StructInferError if the type is not supported."""
         if dataclasses.is_dataclass(type_ref):
             return self._infer_struct(type_ref, path)
         if isinstance(type_ref, enum.EnumMeta):
@@ -112,11 +133,13 @@ def _infer(self, type_ref: type, path: list[str]) -> SqlType:
         return self._infer_generic(type_ref, path)
 
     def _infer_primitive(self, type_ref: type, path: list[str]) -> PrimitiveType:
+        """Infers the primitive SQL type from the Python type. Raises StructInferError if the type is not supported."""
        if type_ref in self._PRIMITIVES:
             return PrimitiveType(self._PRIMITIVES[type_ref])
         raise StructInferError(f'{".".join(path)}: unknown: {type_ref}')
 
     def _infer_generic(self, type_ref: type, path: list[str]) -> SqlType:
+        """Infers the SQL type from the generic Python type. Uses internal APIs to handle generic types."""
         # pylint: disable-next=import-outside-toplevel
         from typing import (  # type: ignore[attr-defined]
             _GenericAlias,
@@ -125,7 +148,7 @@ def _infer_generic(self, type_ref: type, path: list[str]) -> SqlType:
 
         if isinstance(type_ref, (types.UnionType, _UnionGenericAlias)):  # type: ignore[attr-defined]
             return self._infer_nullable(type_ref, path)
-        if isinstance(type_ref, (_GenericAlias, types.GenericAlias)):  # type: ignore[attr-defined]
+        if isinstance(type_ref, (types.GenericAlias, _GenericAlias)):  # type: ignore[attr-defined]
             if type_ref.__origin__ in (dict, list) or isinstance(type_ref, types.GenericAlias):
                 return self._infer_container(type_ref, path)
         prefix = ".".join(path)
@@ -134,6 +157,7 @@ def _infer_generic(self, type_ref: type, path: list[str]) -> SqlType:
         raise StructInferError(f"{prefix}unsupported type: {type_ref.__name__}")
 
     def _infer_nullable(self, type_ref: type, path: list[str]) -> SqlType:
+        """Infers nullability from Optional[x] or `x | None` types."""
         type_args = get_args(type_ref)
         if len(type_args) > 2:
             raise StructInferError(f'{".".join(path)}: union: too many variants: {type_args}')
@@ -144,6 +168,7 @@ def _infer_nullable(self, type_ref: type, path: list[str]) -> SqlType:
         return NullableType(first_type)
 
     def _infer_container(self, type_ref: type, path: list[str]) -> SqlType:
+        """Infers the SQL type from the generic container Python type."""
         type_args = get_args(type_ref)
         if not type_args:
             raise StructInferError(f"Missing type arguments: {type_args} in {type_ref}")
@@ -156,6 +181,7 @@ def _infer_container(self, type_ref: type, path: list[str]) -> SqlType:
         return ArrayType(element_type)
 
     def _infer_struct(self, type_ref: type, path: list[str]) -> StructType:
+        """Infers the struct type from the Python dataclass type."""
         fields = []
         for field, hint in get_type_hints(type_ref).items():
             origin = getattr(hint, "__origin__", None)
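
Taken together, the docstrings spell out the two entry points of `StructInference`: `as_ddl` renders a bare type expression, while `as_schema` renders a column list suitable for a CREATE TABLE statement. A short usage sketch; the `as_ddl` expectations are taken from the unit tests in this commit, and the exact `as_schema` string is an assumption inferred from the DDL asserted in the backend test below:

from dataclasses import dataclass

from databricks.labs.lsql.structs import StructInference


@dataclass
class Foo:
    first: str
    second: bool | None


inference = StructInference()
# A bare type expression, usable inside a larger DDL string:
assert inference.as_ddl(Foo) == "STRUCT<first:STRING,second:BOOLEAN>"
assert inference.as_ddl(dict[str, int]) == "MAP<STRING,LONG>"
# A column list; Optional fields stay nullable, everything else is NOT NULL
# (assumed output, matching the CREATE TABLE statement asserted below):
assert inference.as_schema(Foo) == "first STRING NOT NULL, second BOOLEAN"
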
36 changes: 36 additions & 0 deletions tests/integration/test_structs.py
@@ -0,0 +1,36 @@
+import datetime
+from dataclasses import dataclass
+
+from databricks.labs.lsql.backends import StatementExecutionBackend
+
+
+@dataclass
+class Foo:
+    first: str
+    second: bool | None
+
+
+@dataclass
+class Nested:
+    foo: Foo
+    since: datetime.date
+    created: datetime.datetime
+    mapping: dict[str, int]
+    array: list[int]
+
+
+def test_appends_complex_types(ws, env_or_skip, make_random) -> None:
+    sql_backend = StatementExecutionBackend(ws, env_or_skip("TEST_DEFAULT_WAREHOUSE_ID"))
+    today = datetime.date.today()
+    now = datetime.datetime.now()
+    full_name = f"hive_metastore.default.t{make_random(4)}"
+    sql_backend.save_table(
+        full_name,
+        [
+            Nested(Foo("a", True), today, now, {"a": 1, "b": 2}, [1, 2, 3]),
+            Nested(Foo("b", False), today, now, {"c": 3, "d": 4}, [4, 5, 6]),
+        ],
+        Nested,
+    )
+    rows = list(sql_backend.fetch(f"SELECT * FROM {full_name}"))
+    assert len(rows) == 2
33 changes: 28 additions & 5 deletions tests/unit/test_backends.py
@@ -1,3 +1,4 @@
+import datetime
 import os
 import sys
 from dataclasses import dataclass
@@ -219,7 +220,7 @@ def test_statement_execution_backend_save_table_two_records():
     )
 
 
-def test_statement_execution_backend_save_table_in_batches_of_two(mocker):
+def test_statement_execution_backend_save_table_in_batches_of_two():
     ws = create_autospec(WorkspaceClient)
 
     ws.statement_execution.execute_statement.return_value = StatementResponse(
@@ -435,12 +436,34 @@ def test_mock_backend_overwrite():
 @dataclass
 class Nested:
     foo: Foo
+    since: datetime.date
+    created: datetime.datetime
     mapping: dict[str, int]
     array: list[int]
+    some: float | None = None
 
 
 def test_supports_complex_types():
-    mock_backend = MockBackend()
-    mock_backend.create_table("nested", Nested)
-    expected = [...]
-    assert expected == mock_backend.queries
+    ws = create_autospec(WorkspaceClient)
+
+    ws.statement_execution.execute_statement.return_value = StatementResponse(
+        status=StatementStatus(state=StatementState.SUCCEEDED)
+    )
+
+    seb = StatementExecutionBackend(ws, "abc", max_records_per_batch=2)
+
+    today = datetime.date(2024, 9, 11)
+    now = datetime.datetime(2024, 9, 11, 12, 13, 14, tzinfo=datetime.timezone.utc)
+    seb.save_table(
+        "x",
+        [
+            Nested(Foo("a", True), today, now, {"a": 1, "b": 2}, [1, 2, 3], 0.342532),
+        ],
+        Nested,
+    )
+
+    queries = [_.kwargs["statement"] for _ in ws.statement_execution.method_calls]
+    assert [
+        "CREATE TABLE IF NOT EXISTS x (foo STRUCT<first:STRING,second:BOOLEAN> NOT NULL, since DATE NOT NULL, created TIMESTAMP NOT NULL, mapping MAP<STRING,LONG> NOT NULL, array ARRAY<LONG> NOT NULL, some FLOAT) USING DELTA",
+        "INSERT INTO x (foo, since, created, mapping, array, some) VALUES (STRUCT('a' AS first, TRUE AS second), DATE '2024-9-11', TIMESTAMP '2024-09-11 12:13:14+0000', MAP('a', 1, 'b', 2), ARRAY(1, 2, 3), 0.342532)",
+    ] == queries
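
The rewritten test drives a real `StatementExecutionBackend` against an autospecced `WorkspaceClient` and asserts on the exact SQL statements it emits. The same DDL expectation can also be checked with `MockBackend`, which the previous version of this test used and which records every executed statement in its `queries` attribute; a sketch along the lines of the old test:

from databricks.labs.lsql.backends import MockBackend

mock_backend = MockBackend()
mock_backend.create_table("nested", Nested)
# MockBackend records generated statements, so the DDL can be asserted
# without mocking the workspace client:
assert mock_backend.queries == [
    "CREATE TABLE IF NOT EXISTS nested (foo STRUCT<first:STRING,second:BOOLEAN> NOT NULL, "
    "since DATE NOT NULL, created TIMESTAMP NOT NULL, mapping MAP<STRING,LONG> NOT NULL, "
    "array ARRAY<LONG> NOT NULL, some FLOAT) USING DELTA"
]
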
7 changes: 4 additions & 3 deletions tests/unit/test_structs.py
@@ -36,19 +36,20 @@ class NotDataclass:
         (datetime.date, "DATE"),
         (datetime.datetime, "TIMESTAMP"),
         (list[str], "ARRAY<STRING>"),
+        (set[str], "ARRAY<STRING>"),
         (dict[str, int], "MAP<STRING,LONG>"),
         (dict[int, list[str]], "MAP<LONG,ARRAY<STRING>>"),
         (Foo, "STRUCT<first:STRING,second:BOOLEAN>"),
         (Nested, "STRUCT<foo:STRUCT<first:STRING,second:BOOLEAN>,mapping:MAP<STRING,LONG>,array:ARRAY<LONG>>"),
     ],
 )
-def test_struct_inference(type_ref, ddl):
+def test_struct_inference(type_ref, ddl) -> None:
     inference = StructInference()
     assert inference.as_ddl(type_ref) == ddl
 
 
 @pytest.mark.parametrize("type_ref", [type(None), list, set, tuple, dict, object, NotDataclass])
-def test_struct_inference_raises_on_unknown_type(type_ref):
+def test_struct_inference_raises_on_unknown_type(type_ref) -> None:
     inference = StructInference()
     with pytest.raises(StructInferError):
         inference.as_ddl(type_ref)
@@ -65,6 +66,6 @@ def test_struct_inference_raises_on_unknown_type(type_ref):
         ),
     ],
 )
-def test_as_schema(type_ref, ddl):
+def test_as_schema(type_ref, ddl) -> None:
     inference = StructInference()
     assert inference.as_schema(type_ref) == ddl
