Added support for generic types in SqlBackend
This PR adds the ability to use rich dataclasses like:

```python
@dataclass
class Foo:
    first: str
    second: bool | None

@dataclass
class Nested:
    foo: Foo
    mapping: dict[str, int]
    array: list[int]
```
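
These example classes match the ones used in the unit tests added by this change. As a reference, a minimal sketch of how the new `StructInference` helper maps them to Spark SQL type strings (the expected outputs below are the ones asserted in `tests/unit/test_structs.py` further down):

```python
from databricks.labs.lsql.structs import StructInference

inference = StructInference()
inference.as_ddl(Foo)      # 'STRUCT<first:STRING,second:BOOLEAN>'
inference.as_schema(Foo)   # 'first STRING NOT NULL, second BOOLEAN'
inference.as_ddl(Nested)   # 'STRUCT<foo:STRUCT<first:STRING,second:BOOLEAN>,mapping:MAP<STRING,LONG>,array:ARRAY<LONG>>'
```

`SqlBackend.create_table` now wraps the `as_schema()` output in a `CREATE TABLE IF NOT EXISTS ... USING DELTA` statement, so nested dataclasses can be saved without hand-written DDL.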
nfx committed Sep 11, 2024
1 parent e21699f commit 1ef5fd4
Showing 4 changed files with 293 additions and 42 deletions.
85 changes: 43 additions & 42 deletions src/databricks/labs/lsql/backends.py
@@ -1,10 +1,10 @@
import dataclasses
import datetime
import logging
import os
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Iterator, Sequence
from types import UnionType
from typing import Any, ClassVar, Protocol, TypeVar

from databricks.labs.blueprint.commands import CommandExecutor
@@ -20,6 +20,7 @@
from databricks.sdk.service.compute import Language

from databricks.labs.lsql.core import Row, StatementExecutionExt
from databricks.labs.lsql.structs import StructInference

logger = logging.getLogger(__name__)

@@ -42,6 +43,8 @@ class SqlBackend(ABC):
    execute SQL statements, fetch results from SQL statements, and save data
    to tables."""

    _STRUCTS = StructInference()

    @abstractmethod
    def execute(self, sql: str, *, catalog: str | None = None, schema: str | None = None) -> None:
        raise NotImplementedError
@@ -55,33 +58,9 @@ def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: D
        raise NotImplementedError

    def create_table(self, full_name: str, klass: Dataclass):
        ddl = f"CREATE TABLE IF NOT EXISTS {full_name} ({self._schema_for(klass)}) USING DELTA"
        ddl = f"CREATE TABLE IF NOT EXISTS {full_name} ({self._STRUCTS.as_schema(klass)}) USING DELTA"
        self.execute(ddl)

    _builtin_type_mapping: ClassVar[dict[type, str]] = {
        str: "STRING",
        int: "LONG",
        bool: "BOOLEAN",
        float: "FLOAT",
    }

    @classmethod
    def _schema_for(cls, klass: Dataclass):
        fields = []
        for f in dataclasses.fields(klass):
            field_type = cls._field_type(f)
            if isinstance(field_type, UnionType):
                field_type = field_type.__args__[0]
            if field_type not in cls._builtin_type_mapping:
                msg = f"Cannot auto-convert {field_type}"
                raise SyntaxError(msg)
            not_null = " NOT NULL"
            if f.default is None:
                not_null = ""
            spark_type = cls._builtin_type_mapping[field_type]
            fields.append(f"{f.name} {spark_type}{not_null}")
        return ", ".join(fields)

    @classmethod
    def _field_type(cls, field: dataclasses.Field):
        # workaround rare (Python?) issue where f.type is the type name instead of the type itself
@@ -177,23 +156,45 @@ def _row_to_sql(cls, row: DataclassInstance, fields: tuple[dataclasses.Field[Any
        data = []
        for f in fields:
            value = getattr(row, f.name)
            field_type = cls._field_type(f)
            if isinstance(field_type, UnionType):
                field_type = field_type.__args__[0]
            if value is None:
                data.append("NULL")
            elif field_type is bool:
                data.append("TRUE" if value else "FALSE")
            elif field_type is str:
                value = str(value).replace("'", "''")
                data.append(f"'{value}'")
            elif field_type is int:
                data.append(f"{value}")
            else:
                msg = f"unknown type: {field_type}"
                raise ValueError(msg)
            data.append(cls._value_to_sql(value))
        return ", ".join(data)

    @classmethod
    def _value_to_sql(cls, value: Any) -> str:
        if value is None:
            return "NULL"
        # bool must be checked before int: bool is a subclass of int in Python
        if isinstance(value, bool):
            return "TRUE" if value else "FALSE"
        if isinstance(value, int):
            return f"{value}"
        if isinstance(value, float):
            return f"{value:0.2f}"
        if isinstance(value, str):
            value = str(value).replace("'", "''")
            return f"'{value}'"
        # datetime must be checked before date: datetime is a subclass of date
        if isinstance(value, datetime.datetime):
            return f"'{value.strftime('%Y-%m-%d %H:%M:%S')}'"
        if isinstance(value, datetime.date):
            return f"'{value.year}-{value.month}-{value.day}'"
        if isinstance(value, list):
            values = ", ".join(cls._value_to_sql(v) for v in value)
            return f"ARRAY({values})"
        if isinstance(value, dict):
            map_values: list[str] = []
            for k, v in value.items():
                map_values.append(cls._value_to_sql(k))
                map_values.append(cls._value_to_sql(v))
            return f"MAP({', '.join(map_values)})"
        if dataclasses.is_dataclass(value):
            struct = []
            for f in dataclasses.fields(value):
                v = getattr(value, f.name)
                sql_value = f"{cls._value_to_sql(v)} AS {f.name}"
                struct.append(sql_value)
            return f"STRUCT({', '.join(struct)})"
        msg = f"unsupported: {value}"
        raise ValueError(msg)


class StatementExecutionBackend(ExecutionBackend):
    def __init__(self, ws: WorkspaceClient, warehouse_id, *, max_records_per_batch: int = 1000, **kwargs):
@@ -273,7 +274,7 @@ def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: D
            self.create_table(full_name, klass)
            return
        # pyspark deals well with lists of dataclass instances, as long as schema is provided
        df = self._spark.createDataFrame(rows, self._schema_for(klass))
        df = self._spark.createDataFrame(rows, self._STRUCTS.as_schema(klass))
        df.write.saveAsTable(full_name, mode=mode)


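The new `_value_to_sql` helper above serializes Python values into SQL literals recursively. A rough sketch of what it produces, assuming it lives on `ExecutionBackend` as the diff context suggests (the values are hypothetical; `Foo` is the dataclass from the PR description; outputs are derived from the code above):

```python
from databricks.labs.lsql.backends import ExecutionBackend

ExecutionBackend._value_to_sql(None)            # 'NULL'
ExecutionBackend._value_to_sql(True)            # 'TRUE'
ExecutionBackend._value_to_sql("it's")          # "'it''s'"  (single quotes are escaped)
ExecutionBackend._value_to_sql([1, 2])          # 'ARRAY(1, 2)'
ExecutionBackend._value_to_sql({"a": 1})        # "MAP('a', 1)"
ExecutionBackend._value_to_sql(Foo("x", None))  # "STRUCT('x' AS first, NULL AS second)"
```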
166 changes: 166 additions & 0 deletions src/databricks/labs/lsql/structs.py
@@ -0,0 +1,166 @@
import dataclasses
import datetime
import enum
import types
from dataclasses import dataclass
from typing import ClassVar, Protocol, get_args, get_type_hints


class StructInferError(TypeError):
    pass


class SqlType(Protocol):
    def as_sql(self) -> str: ...


@dataclass
class NullableType(SqlType):
    inner_type: SqlType

    def as_sql(self) -> str:
        return self.inner_type.as_sql()


@dataclass
class ArrayType(SqlType):
    element_type: SqlType

    def as_sql(self) -> str:
        return f"ARRAY<{self.element_type.as_sql()}>"


@dataclass
class MapType(SqlType):
    key_type: SqlType
    value_type: SqlType

    def as_sql(self) -> str:
        return f"MAP<{self.key_type.as_sql()},{self.value_type.as_sql()}>"


@dataclass
class PrimitiveType(SqlType):
    name: str

    def as_sql(self) -> str:
        return self.name


@dataclass
class StructField:
    name: str
    type: SqlType

    @property
    def nullable(self) -> bool:
        return isinstance(self.type, NullableType)

    def as_sql(self) -> str:
        return f"{self.name}:{self.type.as_sql()}"


@dataclass
class StructType(SqlType):
    fields: list[StructField]

    def as_sql(self) -> str:
        fields = ",".join(f.as_sql() for f in self.fields)
        return f"STRUCT<{fields}>"

    def as_schema(self) -> str:
        fields = []
        for field in self.fields:
            not_null = "" if field.nullable else " NOT NULL"
            fields.append(f"{field.name} {field.type.as_sql()}{not_null}")
        return ", ".join(fields)


class StructInference:
    _PRIMITIVES: ClassVar[dict[type, str]] = {
        str: "STRING",
        int: "LONG",
        bool: "BOOLEAN",
        float: "FLOAT",
        datetime.date: "DATE",
        datetime.datetime: "TIMESTAMP",
    }

    def as_ddl(self, type_ref: type) -> str:
        v = self._infer(type_ref, [])
        return v.as_sql()

    def as_schema(self, type_ref: type) -> str:
        v = self._infer(type_ref, [])
        if hasattr(v, "as_schema"):
            return v.as_schema()
        raise StructInferError(f"Cannot generate schema for {type_ref}")

    def _infer(self, type_ref: type, path: list[str]) -> SqlType:
        if dataclasses.is_dataclass(type_ref):
            return self._infer_struct(type_ref, path)
        if isinstance(type_ref, enum.EnumMeta):
            return self._infer_primitive(str, path)
        if type_ref in self._PRIMITIVES:
            return self._infer_primitive(type_ref, path)
        if type_ref is list:
            raise StructInferError("Cannot determine element type of list. Rewrite as: list[XXX]")
        if type_ref is set:
            raise StructInferError("Cannot determine element type of set. Rewrite as: set[XXX]")
        if type_ref is dict:
            raise StructInferError("Cannot determine key and value types of dict. Rewrite as: dict[XXX, YYY]")
        return self._infer_generic(type_ref, path)

    def _infer_primitive(self, type_ref: type, path: list[str]) -> PrimitiveType:
        if type_ref in self._PRIMITIVES:
            return PrimitiveType(self._PRIMITIVES[type_ref])
        raise StructInferError(f'{".".join(path)}: unknown: {type_ref}')

    def _infer_generic(self, type_ref: type, path: list[str]) -> SqlType:
        # pylint: disable-next=import-outside-toplevel
        from typing import (  # type: ignore[attr-defined]
            _GenericAlias,
            _UnionGenericAlias,
        )

        if isinstance(type_ref, (types.UnionType, _UnionGenericAlias)):  # type: ignore[attr-defined]
            return self._infer_nullable(type_ref, path)
        if isinstance(type_ref, (_GenericAlias, types.GenericAlias)):  # type: ignore[attr-defined]
            if type_ref.__origin__ in (dict, list) or isinstance(type_ref, types.GenericAlias):
                return self._infer_container(type_ref, path)
        prefix = ".".join(path)
        if prefix:
            prefix = f"{prefix}: "
        raise StructInferError(f"{prefix}unsupported type: {type_ref.__name__}")

    def _infer_nullable(self, type_ref: type, path: list[str]) -> SqlType:
        type_args = get_args(type_ref)
        if len(type_args) > 2:
            raise StructInferError(f'{".".join(path)}: union: too many variants: {type_args}')
        first_type = self._infer(type_args[0], [*path, "(first)"])
        if type_args[1] is not type(None):
            msg = f'{".".join(path)}.(second): not a NoneType: {type_args[1]}'
            raise StructInferError(msg)
        return NullableType(first_type)

    def _infer_container(self, type_ref: type, path: list[str]) -> SqlType:
        type_args = get_args(type_ref)
        if not type_args:
            raise StructInferError(f"Missing type arguments: {type_args} in {type_ref}")
        if len(type_args) == 2:
            key_type = self._infer(type_args[0], [*path, "key"])
            value_type = self._infer(type_args[1], [*path, "value"])
            return MapType(key_type, value_type)
        # here we make a simple assumption that not two type arguments means a list
        element_type = self._infer(type_args[0], path)
        return ArrayType(element_type)

    def _infer_struct(self, type_ref: type, path: list[str]) -> StructType:
        fields = []
        for field, hint in get_type_hints(type_ref).items():
            origin = getattr(hint, "__origin__", None)
            if origin is ClassVar:
                continue
            field_type = self._infer(hint, [*path, field])
            fields.append(StructField(field, field_type))
        return StructType(fields)
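
For orientation, the inference first builds a small tree of `SqlType` nodes and only then flattens it with `as_sql()` or `as_schema()`. A minimal sketch of that intermediate tree (illustrative only; `_infer` is an internal method, and the reprs shown are the default dataclass reprs):

```python
from databricks.labs.lsql.structs import StructInference

inference = StructInference()
inference._infer(bool | None, [])     # NullableType(inner_type=PrimitiveType(name='BOOLEAN'))
inference._infer(dict[str, int], [])  # MapType(key_type=PrimitiveType(name='STRING'), value_type=PrimitiveType(name='LONG'))
inference._infer(list[int], [])       # ArrayType(element_type=PrimitiveType(name='LONG'))
```

Nullability only affects the schema form: `NullableType.as_sql()` delegates to the inner type, while `StructType.as_schema()` checks `StructField.nullable` to decide whether to append ` NOT NULL`.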
14 changes: 14 additions & 0 deletions tests/unit/test_backends.py
@@ -430,3 +430,17 @@ def test_mock_backend_overwrite():
Row(first="aaa", second=True),
Row(first="bbb", second=False),
]


@dataclass
class Nested:
    foo: Foo
    mapping: dict[str, int]
    array: list[int]


def test_supports_complex_types():
    mock_backend = MockBackend()
    mock_backend.create_table("nested", Nested)
    expected = [...]
    assert expected == mock_backend.queries
70 changes: 70 additions & 0 deletions tests/unit/test_structs.py
@@ -0,0 +1,70 @@
import datetime
from dataclasses import dataclass
from typing import Optional

import pytest

from databricks.labs.lsql.structs import StructInference, StructInferError


@dataclass
class Foo:
    first: str
    second: bool | None


@dataclass
class Nested:
    foo: Foo
    mapping: dict[str, int]
    array: list[int]


class NotDataclass:
    x: int


@pytest.mark.parametrize(
    "type_ref, ddl",
    [
        (int, "LONG"),
        (int | None, "LONG"),
        (Optional[int], "LONG"),
        (float, "FLOAT"),
        (str, "STRING"),
        (bool, "BOOLEAN"),
        (datetime.date, "DATE"),
        (datetime.datetime, "TIMESTAMP"),
        (list[str], "ARRAY<STRING>"),
        (dict[str, int], "MAP<STRING,LONG>"),
        (dict[int, list[str]], "MAP<LONG,ARRAY<STRING>>"),
        (Foo, "STRUCT<first:STRING,second:BOOLEAN>"),
        (Nested, "STRUCT<foo:STRUCT<first:STRING,second:BOOLEAN>,mapping:MAP<STRING,LONG>,array:ARRAY<LONG>>"),
    ],
)
def test_struct_inference(type_ref, ddl):
    inference = StructInference()
    assert inference.as_ddl(type_ref) == ddl


@pytest.mark.parametrize("type_ref", [type(None), list, set, tuple, dict, object, NotDataclass])
def test_struct_inference_raises_on_unknown_type(type_ref):
    inference = StructInference()
    with pytest.raises(StructInferError):
        inference.as_ddl(type_ref)


@pytest.mark.parametrize(
    "type_ref,ddl",
    [
        (Foo, "first STRING NOT NULL, second BOOLEAN"),
        (
            Nested,
            "foo STRUCT<first:STRING,second:BOOLEAN> NOT NULL, "
            "mapping MAP<STRING,LONG> NOT NULL, array ARRAY<LONG> NOT NULL",
        ),
    ],
)
def test_as_schema(type_ref, ddl):
    inference = StructInference()
    assert inference.as_schema(type_ref) == ddl
