Skip to content

Commit

Permalink
feat: add trino sqlalchemy dialect
Browse files Browse the repository at this point in the history
Signed-off-by: Đặng Minh Dũng <dungdm93@live.com>
  • Loading branch information
dungdm93 committed Apr 12, 2021
1 parent 4c50265 commit b7aa393
Show file tree
Hide file tree
Showing 6 changed files with 596 additions and 3 deletions.
40 changes: 37 additions & 3 deletions trino/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,44 @@
import trino.exceptions
import trino.client
import trino.logging
from trino.transaction import Transaction, IsolationLevel, NO_TRANSACTION

from trino.transaction import (
Transaction,
IsolationLevel,
NO_TRANSACTION
)
from trino.exceptions import (
Warning,
Error,
InterfaceError,
DatabaseError,
DataError,
OperationalError,
IntegrityError,
InternalError,
ProgrammingError,
NotSupportedError,
)

__all__ = ["connect", "Connection", "Cursor"]
# Public names exported by this module, grouped by the PEP 249 section
# that requires them.
__all__ = [
    # Module-level globals and constructors.
    # https://www.python.org/dev/peps/pep-0249/#globals
    "apilevel",
    "threadsafety",
    "paramstyle",
    "connect",
    "Connection",
    "Cursor",
    # Exception hierarchy, re-exported from trino.exceptions.
    # https://www.python.org/dev/peps/pep-0249/#exceptions
    "Warning",
    "Error",
    "InterfaceError",
    "DatabaseError",
    "DataError",
    "OperationalError",
    "IntegrityError",
    "InternalError",
    "ProgrammingError",
    "NotSupportedError",
]


apilevel = "2.0"
Expand Down
4 changes: 4 additions & 0 deletions trino/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from sqlalchemy.dialects import registry

__version__ = '0.2.0'

# registry.register(name, modulepath, objname) imports ``modulepath`` as a
# module and then looks up ``objname`` on it.  The module path must therefore
# stop at the module that defines the dialect class; appending the class name
# to it makes the lazy import fail when the "trino" dialect is first loaded.
# NOTE(review): assumes the dialect class lives in trino/sqlalchemy/dialect.py.
registry.register("trino", "trino.sqlalchemy.dialect", "TrinoDialect")
92 changes: 92 additions & 0 deletions trino/sqlalchemy/compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from sqlalchemy.sql import compiler

# https://trino.io/docs/current/language/reserved.html
# Lower-case Trino reserved keywords, grouped by initial letter.
RESERVED_WORDS = {
    "alter", "and", "as",
    "between", "by",
    "case", "cast", "constraint", "create", "cross", "cube",
    "current_date", "current_path", "current_role",
    "current_time", "current_timestamp", "current_user",
    "deallocate", "delete", "describe", "distinct", "drop",
    "else", "end", "escape", "except", "execute", "exists", "extract",
    "false", "for", "from", "full",
    "group", "grouping",
    "having",
    "in", "inner", "insert", "intersect", "into", "is",
    "join",
    "left", "like", "localtime", "localtimestamp",
    "natural", "normalize", "not", "null",
    "on", "or", "order", "outer",
    "prepare",
    "recursive", "right", "rollup",
    "select",
    "table", "then", "true",
    "uescape", "union", "unnest", "using",
    "values",
    "when", "where", "with",
}


class TrinoSQLCompiler(compiler.SQLCompiler):
    """Statement compiler for Trino.

    Adds nothing of its own; every behavior is inherited from
    ``compiler.SQLCompiler``.
    """


class TrinoDDLCompiler(compiler.DDLCompiler):
    """DDL compiler for Trino.

    Adds nothing of its own; every behavior is inherited from
    ``compiler.DDLCompiler``.
    """


class TrinoTypeCompiler(compiler.GenericTypeCompiler):
    """Type compiler for Trino.

    Adds nothing of its own; every behavior is inherited from
    ``compiler.GenericTypeCompiler``.
    """


class TrinoIdentifierPreparer(compiler.IdentifierPreparer):
    """Identifier preparer that uses Trino's reserved-word list."""

    # Replace the base preparer's reserved words with the Trino set
    # defined in this module.
    reserved_words = RESERVED_WORDS
158 changes: 158 additions & 0 deletions trino/sqlalchemy/datatype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
import re
from typing import *

from sqlalchemy import util
from sqlalchemy.sql import sqltypes
from sqlalchemy.sql.type_api import TypeEngine

# https://trino.io/docs/current/language/types.html
# Lookup table from a lower-case Trino type name (without parameters) to the
# SQLAlchemy type class used to represent it.
_type_map = {
    # === Boolean ===
    'boolean': sqltypes.BOOLEAN,

    # === Integer ===
    # tinyint is represented by SMALLINT, the same class as smallint.
    'tinyint': sqltypes.SMALLINT,
    'smallint': sqltypes.SMALLINT,
    'integer': sqltypes.INTEGER,
    'bigint': sqltypes.BIGINT,

    # === Floating-point ===
    # real and double both map to the same FLOAT class.
    'real': sqltypes.FLOAT,
    'double': sqltypes.FLOAT,

    # === Fixed-precision ===
    'decimal': sqltypes.DECIMAL,

    # === String ===
    'varchar': sqltypes.VARCHAR,
    'char': sqltypes.CHAR,
    'varbinary': sqltypes.VARBINARY,
    'json': sqltypes.JSON,

    # === Date and time ===
    'date': sqltypes.DATE,
    'time': sqltypes.TIME,
    'timestamp': sqltypes.TIMESTAMP,

    # The Trino types below have no entry in this table.
    #
    # 'interval year to month':
    # 'interval day to second':
    #
    # === Structural ===
    # 'array': ARRAY,
    # 'map': MAP
    # 'row': ROW
    #
    # === Mixed ===
    # 'ipaddress': IPADDRESS
    # 'uuid': UUID,
    # 'hyperloglog': HYPERLOGLOG,
    # 'p4hyperloglog': P4HYPERLOGLOG,
    # 'qdigest': QDIGEST,
    # 'tdigest': TDIGEST,
}

# A SQL type given either as a TypeEngine instance or as a TypeEngine subclass.
SQLType = Union[TypeEngine, Type[TypeEngine]]


class MAP(TypeEngine):
    """SQLAlchemy type describing a Trino ``map`` with typed keys and values.

    :param key_type: type of the map keys, as a ``TypeEngine`` instance or
        a ``TypeEngine`` subclass
    :param value_type: type of the map values, in the same two forms

    A class passed for either argument is instantiated with no arguments,
    so ``key_type`` and ``value_type`` always end up holding instances.
    """

    __visit_name__ = "MAP"

    def __init__(self, key_type: SQLType, value_type: SQLType):
        self.key_type: TypeEngine = (
            key_type() if isinstance(key_type, type) else key_type
        )
        self.value_type: TypeEngine = (
            value_type() if isinstance(value_type, type) else value_type
        )

    @property
    def python_type(self):
        """Python type reported for values of this SQL type."""
        return dict


class ROW(TypeEngine):
    """SQLAlchemy type describing a Trino ``row`` with named, typed fields.

    :param attr_types: mapping of field name to field type; each type may be
        a ``TypeEngine`` instance or a ``TypeEngine`` subclass

    A class given as a field type is instantiated with no arguments.  The
    normalized mapping is stored in ``attr_types`` as a new dict, so the
    dict passed by the caller is left untouched.
    """

    __visit_name__ = "ROW"

    def __init__(self, attr_types: Dict[str, SQLType]):
        # Build a fresh dict instead of rewriting the caller's mapping in
        # place while iterating over it.
        self.attr_types: Dict[str, TypeEngine] = {
            name: attr_type() if isinstance(attr_type, type) else attr_type
            for name, attr_type in attr_types.items()
        }

    @property
    def python_type(self):
        """Python type reported for values of this SQL type."""
        return dict


def split(string: str, delimiter: str = ',',
          quote: str = '"', escaped_quote: str = r'\"',
          open_bracket: str = '(', close_bracket: str = ')') -> Iterator[str]:
    """
    A split function that is aware of quotes and brackets/parentheses.

    The string is cut at every ``delimiter`` that is neither inside a quoted
    section nor nested inside brackets.  Pieces are yielded verbatim (no
    stripping); at least one piece is always yielded, even for ``""``.
    Brackets and delimiters inside a quoted section are treated as plain
    text, so a quoted ``(`` or ``)`` does not affect the nesting depth.

    :param string: string to split
    :param delimiter: single character defining where to split, usually a
        comma or space
    :param quote: string, either a single or a double quote
    :param escaped_quote: string representing an escaped quote
    :param open_bracket: string, either [, {, < or (
    :param close_bracket: string, either ], }, > or )
    :return: iterator over the pieces of ``string``
    """
    depth = 0
    in_quotes = False
    start = 0
    for pos, char in enumerate(string):
        if char == quote:
            if not in_quotes:
                in_quotes = True
            else:
                # Look back far enough to see whether this quote is the tail
                # of an escaped quote; clamp so the slice never wraps around.
                tail_start = max(pos - len(escaped_quote) + 1, 0)
                if string[tail_start:pos + 1] != escaped_quote:
                    in_quotes = False
        elif in_quotes:
            # Everything inside quotes is literal text.
            continue
        elif char == delimiter and depth == 0:
            yield string[start:pos]
            start = pos + len(delimiter)
        elif char == open_bracket:
            depth += 1
        elif char == close_bracket:
            depth -= 1
    yield string[start:]


def parse_sqltype(type_str: str) -> TypeEngine:
    """
    Convert a Trino type string into a SQLAlchemy type.

    The string is stripped and lower-cased, then matched as a type name
    optionally followed by parenthesised options.  ``array``, ``map`` and
    ``row`` are parsed recursively; every other name is looked up in
    ``_type_map``.

    :param type_str: type text such as ``varchar(10)``, ``decimal(10, 2)``,
        ``array(integer)`` or ``row(a integer, b timestamp with time zone)``
    :return: a type instance, or ``sqltypes.NULLTYPE`` (after a warning) when
        the string cannot be parsed or the type name is not recognized
    :raises ValueError: if ``map`` options do not split into exactly two
        parts, or a scalar type option is not an integer
    """
    type_str = type_str.strip().lower()
    match = re.match(r'^(?P<type>\w+)\s*(?:\((?P<options>.*)\))?', type_str)
    if not match:
        util.warn(f"Could not parse type name '{type_str}'")
        return sqltypes.NULLTYPE
    type_name = match.group("type")
    type_opts = match.group("options")

    if type_name in ("array", "map", "row") and not type_opts:
        # A structural type without element types cannot be built; report it
        # instead of failing on the missing options below.
        util.warn(f"Could not parse type name '{type_str}'")
        return sqltypes.NULLTYPE

    if type_name == "array":
        item_type = parse_sqltype(type_opts)
        if isinstance(item_type, sqltypes.ARRAY):
            # Collapse array(array(x)) into one ARRAY with more dimensions.
            dimensions = (item_type.dimensions or 1) + 1
            return sqltypes.ARRAY(item_type.item_type, dimensions=dimensions)
        return sqltypes.ARRAY(item_type)
    elif type_name == "map":
        key_type_str, value_type_str = split(type_opts)
        key_type = parse_sqltype(key_type_str)
        value_type = parse_sqltype(value_type_str)
        return MAP(key_type, value_type)
    elif type_name == "row":
        attr_types: Dict[str, SQLType] = {}
        for attr_str in split(type_opts):
            # Only the first space separates the name from the type; the type
            # itself may contain spaces ("timestamp with time zone"), so the
            # remaining pieces are joined back with the same delimiter.
            name, *type_tokens = split(attr_str.strip(), delimiter=' ')
            attr_types[name] = parse_sqltype(' '.join(type_tokens))
        return ROW(attr_types)

    if type_name not in _type_map:
        util.warn(f"Did not recognize type '{type_name}'")
        return sqltypes.NULLTYPE
    type_class = _type_map[type_name]
    if type_name in ('time', 'timestamp'):
        type_kwargs = dict(timezone=type_str.endswith("with time zone"))
        return type_class(**type_kwargs)  # TODO: handle time/timestamp(p) precision
    # Remaining options are integer arguments, e.g. a length or precision/scale.
    type_args = [int(o.strip()) for o in type_opts.split(',')] if type_opts else []
    return type_class(*type_args)
Loading

0 comments on commit b7aa393

Please sign in to comment.