From ff00675aa658dbe7362c5c6d9462c949b4f76819 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C4=90=E1=BA=B7ng=20Minh=20D=C5=A9ng?= Date: Fri, 2 Apr 2021 15:00:42 +0700 Subject: [PATCH] feat: add trino sqlalchemy dialect MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Đặng Minh Dũng --- integration_tests/__init__.py | 11 + integration_tests/test_dbapi.py | 4 +- setup.py | 28 ++- tests/__init__.py | 0 tests/sqlalchemy/__init__.py | 11 + tests/sqlalchemy/conftest.py | 45 ++++ tests/sqlalchemy/test_datatype_parse.py | 127 ++++++++++ tests/sqlalchemy/test_datatype_split.py | 64 +++++ trino/__init__.py | 9 - trino/dbapi.py | 40 ++- trino/sqlalchemy/__init__.py | 14 ++ trino/sqlalchemy/compiler.py | 140 +++++++++++ trino/sqlalchemy/datatype.py | 174 +++++++++++++ trino/sqlalchemy/dialect.py | 310 ++++++++++++++++++++++++ trino/sqlalchemy/error.py | 24 ++ 15 files changed, 975 insertions(+), 26 deletions(-) create mode 100644 integration_tests/__init__.py create mode 100644 tests/__init__.py create mode 100644 tests/sqlalchemy/__init__.py create mode 100644 tests/sqlalchemy/conftest.py create mode 100644 tests/sqlalchemy/test_datatype_parse.py create mode 100644 tests/sqlalchemy/test_datatype_split.py create mode 100644 trino/sqlalchemy/__init__.py create mode 100644 trino/sqlalchemy/compiler.py create mode 100644 trino/sqlalchemy/datatype.py create mode 100644 trino/sqlalchemy/dialect.py create mode 100644 trino/sqlalchemy/error.py diff --git a/integration_tests/__init__.py b/integration_tests/__init__.py new file mode 100644 index 00000000..4d9a9249 --- /dev/null +++ b/integration_tests/__init__.py @@ -0,0 +1,11 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/integration_tests/test_dbapi.py b/integration_tests/test_dbapi.py index be4b027a..997a6687 100644 --- a/integration_tests/test_dbapi.py +++ b/integration_tests/test_dbapi.py @@ -15,10 +15,10 @@ import pytest import pytz -import trino -from conftest import TRINO_VERSION +import trino.dbapi from trino.exceptions import TrinoQueryError from trino.transaction import IsolationLevel +from .conftest import TRINO_VERSION @pytest.fixture diff --git a/setup.py b/setup.py index 99ed6598..3d546c6e 100755 --- a/setup.py +++ b/setup.py @@ -14,24 +14,23 @@ import ast import re -from setuptools import setup import textwrap +from setuptools import setup _version_re = re.compile(r"__version__\s+=\s+(.*)") - with open("trino/__init__.py", "rb") as f: trino_version = _version_re.search(f.read().decode("utf-8")) assert trino_version is not None version = str(ast.literal_eval(trino_version.group(1))) - kerberos_require = ["requests_kerberos"] +sqlalchemy_require = ["sqlalchemy~=1.3"] -all_require = kerberos_require + [] +all_require = kerberos_require + sqlalchemy_require -tests_require = all_require + ["httpretty", "pytest", "pytest-runner", "pytz", "click"] +tests_require = all_require + ["httpretty", "pytest", "pytest-runner", "pytz", "click", "assertpy"] setup( name="trino", @@ -44,19 +43,17 @@ description="Client for the Trino distributed SQL Engine", long_description=textwrap.dedent( """ - Client for Trino (https://trino.io), a distributed SQL engine for - interactive and batch big data processing. Provides a low-level client and - a DBAPI 2.0 implementation. - """ + Client for Trino (https://trino.io), a distributed SQL engine for + interactive and batch big data processing. Provides a low-level client and + a DBAPI 2.0 implementation. + """ ), license="Apache 2.0", classifiers=[ "Development Status :: 4 - Beta", "Intended Audience :: Developers", "License :: OSI Approved :: Apache Software License", - "Operating System :: MacOS :: MacOS X", - "Operating System :: POSIX", - "Operating System :: Microsoft :: Windows", + "Operating System :: OS Independent", "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.6", @@ -65,6 +62,7 @@ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", + "Topic :: Database", "Topic :: Database :: Front-Ends", ], python_requires='>=3.6', @@ -72,6 +70,12 @@ extras_require={ "all": all_require, "kerberos": kerberos_require, + "sqlalchemy": sqlalchemy_require, "tests": tests_require, }, + entry_points={ + "sqlalchemy.dialects": [ + "trino = trino.sqlalchemy.dialect:TrinoDialect", + ] + }, ) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/sqlalchemy/__init__.py b/tests/sqlalchemy/__init__.py new file mode 100644 index 00000000..4d9a9249 --- /dev/null +++ b/tests/sqlalchemy/__init__.py @@ -0,0 +1,11 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/sqlalchemy/conftest.py b/tests/sqlalchemy/conftest.py new file mode 100644 index 00000000..da618e5a --- /dev/null +++ b/tests/sqlalchemy/conftest.py @@ -0,0 +1,45 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from assertpy import add_extension, assert_that +from sqlalchemy.sql.sqltypes import ARRAY + +from trino.sqlalchemy.datatype import MAP, ROW, SQLType + + +def assert_sqltype(this: SQLType, that: SQLType): + if isinstance(this, type): + this = this() + if isinstance(that, type): + that = that() + assert_that(type(this)).is_same_as(type(that)) + if isinstance(this, ARRAY): + assert_sqltype(this.item_type, that.item_type) + if this.dimensions is None or this.dimensions == 1: + assert_that(that.dimensions).is_in(None, 1) + else: + assert_that(this.dimensions).is_equal_to(this.dimensions) + elif isinstance(this, MAP): + assert_sqltype(this.key_type, that.key_type) + assert_sqltype(this.value_type, that.value_type) + elif isinstance(this, ROW): + assert_that(len(this.attr_types)).is_equal_to(len(that.attr_types)) + for name, this_attr in this.attr_types.items(): + that_attr = this.attr_types[name] + assert_sqltype(this_attr, that_attr) + else: + assert_that(str(this)).is_equal_to(str(that)) + + +@add_extension +def is_sqltype(self, that): + this = self.val + assert_sqltype(this, that) diff --git a/tests/sqlalchemy/test_datatype_parse.py b/tests/sqlalchemy/test_datatype_parse.py new file mode 100644 index 00000000..a7408f40 --- /dev/null +++ b/tests/sqlalchemy/test_datatype_parse.py @@ -0,0 +1,127 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +from assertpy import assert_that +from sqlalchemy.sql.sqltypes import ( + CHAR, VARCHAR, + ARRAY, + INTEGER, DECIMAL, + DATE, TIME, TIMESTAMP +) +from sqlalchemy.sql.type_api import TypeEngine + +from trino.sqlalchemy import datatype +from trino.sqlalchemy.datatype import MAP, ROW + + +@pytest.mark.parametrize( + 'type_str, sql_type', + datatype._type_map.items(), + ids=datatype._type_map.keys() +) +def test_parse_simple_type(type_str: str, sql_type: TypeEngine): + actual_type = datatype.parse_sqltype(type_str) + if not isinstance(actual_type, type): + actual_type = type(actual_type) + assert_that(actual_type).is_equal_to(sql_type) + + +parse_type_options_testcases = { + 'VARCHAR(10)': VARCHAR(10), + 'DECIMAL(20)': DECIMAL(20), + 'DECIMAL(20, 3)': DECIMAL(20, 3), +} + + +@pytest.mark.parametrize( + 'type_str, sql_type', + parse_type_options_testcases.items(), + ids=parse_type_options_testcases.keys() +) +def test_parse_type_options(type_str: str, sql_type: TypeEngine): + actual_type = datatype.parse_sqltype(type_str) + assert_that(actual_type).is_sqltype(sql_type) + + +parse_array_testcases = { + 'array(integer)': ARRAY(INTEGER()), + 'array(varchar(10))': ARRAY(VARCHAR(10)), + 'array(decimal(20,3))': ARRAY(DECIMAL(20, 3)), + 'array(array(varchar(10)))': ARRAY(VARCHAR(10), dimensions=2), +} + + +@pytest.mark.parametrize( + 'type_str, sql_type', + parse_array_testcases.items(), + ids=parse_array_testcases.keys() +) +def test_parse_array(type_str: str, sql_type: ARRAY): + actual_type = datatype.parse_sqltype(type_str) + assert_that(actual_type).is_sqltype(sql_type) + + +parse_map_testcases = { + 'map(char, integer)': MAP(CHAR(), INTEGER()), + 'map(varchar(10), varchar(10))': MAP(VARCHAR(10), VARCHAR(10)), + 'map(varchar(10), decimal(20,3))': MAP(VARCHAR(10), DECIMAL(20, 3)), + 'map(char, array(varchar(10)))': MAP(CHAR(), ARRAY(VARCHAR(10))), + 'map(varchar(10), array(varchar(10)))': MAP(VARCHAR(10), ARRAY(VARCHAR(10))), + 'map(varchar(10), array(array(varchar(10))))': MAP(VARCHAR(10), ARRAY(VARCHAR(10), dimensions=2)), +} + + +@pytest.mark.parametrize( + 'type_str, sql_type', + parse_map_testcases.items(), + ids=parse_map_testcases.keys() +) +def test_parse_map(type_str: str, sql_type: ARRAY): + actual_type = datatype.parse_sqltype(type_str) + assert_that(actual_type).is_sqltype(sql_type) + + +parse_row_testcases = { + 'row(a integer, b varchar)': ROW(dict(a=INTEGER(), b=VARCHAR())), + 'row(a varchar(20), b decimal(20,3))': ROW(dict(a=VARCHAR(20), b=DECIMAL(20, 3))), + 'row(x array(varchar(10)), y array(array(varchar(10))), z decimal(20,3))': + ROW(dict(x=ARRAY(VARCHAR(10)), y=ARRAY(VARCHAR(10), dimensions=2), z=DECIMAL(20, 3))), +} + + +@pytest.mark.parametrize( + 'type_str, sql_type', + parse_row_testcases.items(), + ids=parse_row_testcases.keys() +) +def test_parse_row(type_str: str, sql_type: ARRAY): + actual_type = datatype.parse_sqltype(type_str) + assert_that(actual_type).is_sqltype(sql_type) + + +parse_datetime_testcases = { + 'date': DATE(), + 'time': TIME(), + 'time with time zone': TIME(timezone=True), + 'timestamp': TIMESTAMP(), + 'timestamp with time zone': TIMESTAMP(timezone=True), +} + + +@pytest.mark.parametrize( + 'type_str, sql_type', + parse_datetime_testcases.items(), + ids=parse_datetime_testcases.keys() +) +def test_parse_datetime(type_str: str, sql_type: ARRAY): + actual_type = datatype.parse_sqltype(type_str) + assert_that(actual_type).is_sqltype(sql_type) diff --git a/tests/sqlalchemy/test_datatype_split.py b/tests/sqlalchemy/test_datatype_split.py new file mode 100644 index 00000000..4cce0df6 --- /dev/null +++ b/tests/sqlalchemy/test_datatype_split.py @@ -0,0 +1,64 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List + +import pytest +from assertpy import assert_that + +from trino.sqlalchemy import datatype + +split_string_testcases = { + '10': ['10'], + '10,3': ['10', '3'], + 'varchar': ['varchar'], + 'varchar,int': ['varchar', 'int'], + 'varchar,int,float': ['varchar', 'int', 'float'], + 'array(varchar)': ['array(varchar)'], + 'array(varchar),int': ['array(varchar)', 'int'], + 'array(varchar(20))': ['array(varchar(20))'], + 'array(varchar(20)),int': ['array(varchar(20))', 'int'], + 'array(varchar(20)),array(varchar(20))': ['array(varchar(20))', 'array(varchar(20))'], + 'map(varchar, integer),int': ['map(varchar, integer)', 'int'], + 'map(varchar(20), integer),int': ['map(varchar(20), integer)', 'int'], + 'map(varchar(20), varchar(20)),int': ['map(varchar(20), varchar(20))', 'int'], + 'map(varchar(20), varchar(20)),array(varchar)': ['map(varchar(20), varchar(20))', 'array(varchar)'], + 'row(first_name varchar(20), last_name varchar(20)),int': + ['row(first_name varchar(20), last_name varchar(20))', 'int'], +} + + +@pytest.mark.parametrize( + 'input_string, output_strings', + split_string_testcases.items(), + ids=split_string_testcases.keys() +) +def test_split_string(input_string: str, output_strings: List[str]): + actual = list(datatype.split(input_string)) + assert_that(actual).is_equal_to(output_strings) + + +split_delimiter_testcases = [ + ('first,second', ',', ['first', 'second']), + ('first second', ' ', ['first', 'second']), + ('first|second', '|', ['first', 'second']), + ('first,second third', ',', ['first', 'second third']), + ('first,second third', ' ', ['first,second', 'third']), +] + + +@pytest.mark.parametrize( + 'input_string, delimiter, output_strings', + split_delimiter_testcases, +) +def test_split_delimiter(input_string: str, delimiter: str, output_strings: List[str]): + actual = list(datatype.split(input_string, delimiter=delimiter)) + assert_that(actual).is_equal_to(output_strings) diff --git a/trino/__init__.py b/trino/__init__.py index 4798a223..405fbf6b 100644 --- a/trino/__init__.py +++ b/trino/__init__.py @@ -10,13 +10,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import auth -from . import dbapi -from . import client -from . import constants -from . import exceptions -from . import logging - -__all__ = ['auth', 'dbapi', 'client', 'constants', 'exceptions', 'logging'] - __version__ = "0.305.0" diff --git a/trino/dbapi.py b/trino/dbapi.py index 616ae341..e3b18ae8 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -29,10 +29,44 @@ import trino.exceptions import trino.client import trino.logging -from trino.transaction import Transaction, IsolationLevel, NO_TRANSACTION - +from trino.transaction import ( + Transaction, + IsolationLevel, + NO_TRANSACTION +) +from trino.exceptions import ( + Warning, + Error, + InterfaceError, + DatabaseError, + DataError, + OperationalError, + IntegrityError, + InternalError, + ProgrammingError, + NotSupportedError, +) -__all__ = ["connect", "Connection", "Cursor"] +__all__ = [ + # https://www.python.org/dev/peps/pep-0249/#globals + "apilevel", + "threadsafety", + "paramstyle", + "connect", + "Connection", + "Cursor", + # https://www.python.org/dev/peps/pep-0249/#exceptions + "Warning", + "Error", + "InterfaceError", + "DatabaseError", + "DataError", + "OperationalError", + "IntegrityError", + "InternalError", + "ProgrammingError", + "NotSupportedError", +] apilevel = "2.0" diff --git a/trino/sqlalchemy/__init__.py b/trino/sqlalchemy/__init__.py new file mode 100644 index 00000000..000d3e08 --- /dev/null +++ b/trino/sqlalchemy/__init__.py @@ -0,0 +1,14 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from sqlalchemy.dialects import registry + +registry.register("trino", "trino.sqlalchemy.dialect", "TrinoDialect") diff --git a/trino/sqlalchemy/compiler.py b/trino/sqlalchemy/compiler.py new file mode 100644 index 00000000..f4b85692 --- /dev/null +++ b/trino/sqlalchemy/compiler.py @@ -0,0 +1,140 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from sqlalchemy.sql import compiler + +# https://trino.io/docs/current/language/reserved.html +RESERVED_WORDS = { + "alter", + "and", + "as", + "between", + "by", + "case", + "cast", + "constraint", + "create", + "cross", + "cube", + "current_date", + "current_path", + "current_role", + "current_time", + "current_timestamp", + "current_user", + "deallocate", + "delete", + "describe", + "distinct", + "drop", + "else", + "end", + "escape", + "except", + "execute", + "exists", + "extract", + "false", + "for", + "from", + "full", + "group", + "grouping", + "having", + "in", + "inner", + "insert", + "intersect", + "into", + "is", + "join", + "left", + "like", + "localtime", + "localtimestamp", + "natural", + "normalize", + "not", + "null", + "on", + "or", + "order", + "outer", + "prepare", + "recursive", + "right", + "rollup", + "select", + "table", + "then", + "true", + "uescape", + "union", + "unnest", + "using", + "values", + "when", + "where", + "with", +} + + +class TrinoSQLCompiler(compiler.SQLCompiler): + pass + + +class TrinoDDLCompiler(compiler.DDLCompiler): + pass + + +class TrinoTypeCompiler(compiler.GenericTypeCompiler): + def visit_FLOAT(self, type_, **kw): + precision = type_.precision or 32 + if 0 <= precision <= 32: + return self.visit_REAL(type_, **kw) + elif 32 < precision <= 64: + return self.visit_DOUBLE(type_, **kw) + else: + raise ValueError(f"type.precision must be in range [0, 64], got {type_.precision}") + + def visit_DOUBLE(self, type_, **kw): + return "DOUBLE" + + def visit_NUMERIC(self, type_, **kw): + return self.visit_DECIMAL(type_, **kw) + + def visit_NCHAR(self, type_, **kw): + return self.visit_CHAR(type_, **kw) + + def visit_NVARCHAR(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) + + def visit_TEXT(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) + + def visit_BINARY(self, type_, **kw): + return self.visit_VARBINARY(type_, **kw) + + def visit_CLOB(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) + + def visit_NCLOB(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) + + def visit_BLOB(self, type_, **kw): + return self.visit_VARBINARY(type_, **kw) + + def visit_DATETIME(self, type_, **kw): + return self.visit_TIMESTAMP(type_, **kw) + + +class TrinoIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = RESERVED_WORDS diff --git a/trino/sqlalchemy/datatype.py b/trino/sqlalchemy/datatype.py new file mode 100644 index 00000000..df71650d --- /dev/null +++ b/trino/sqlalchemy/datatype.py @@ -0,0 +1,174 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import re +from typing import Dict, Iterator, Type, Union + +from sqlalchemy import util +from sqlalchemy.sql import sqltypes +from sqlalchemy.sql.type_api import TypeEngine + +SQLType = Union[TypeEngine, Type[TypeEngine]] + + +class DOUBLE(sqltypes.Float): + __visit_name__ = "DOUBLE" + + +class MAP(TypeEngine): + __visit_name__ = "MAP" + + def __init__(self, key_type: SQLType, value_type: SQLType): + if isinstance(key_type, type): + key_type = key_type() + self.key_type: TypeEngine = key_type + + if isinstance(value_type, type): + value_type = value_type() + self.value_type: TypeEngine = value_type + + @property + def python_type(self): + return dict + + +class ROW(TypeEngine): + __visit_name__ = "ROW" + + def __init__(self, attr_types: Dict[str, SQLType]): + for name, attr_type in attr_types.items(): + if isinstance(attr_type, type): + attr_type = attr_type() + attr_types[name] = attr_type + self.attr_types: Dict[str, TypeEngine] = attr_types + + @property + def python_type(self): + return dict + + +# https://trino.io/docs/current/language/types.html +_type_map = { + # === Boolean === + 'boolean': sqltypes.BOOLEAN, + + # === Integer === + 'tinyint': sqltypes.SMALLINT, + 'smallint': sqltypes.SMALLINT, + 'integer': sqltypes.INTEGER, + 'bigint': sqltypes.BIGINT, + + # === Floating-point === + 'real': sqltypes.REAL, + 'double': DOUBLE, + + # === Fixed-precision === + 'decimal': sqltypes.DECIMAL, + + # === String === + 'varchar': sqltypes.VARCHAR, + 'char': sqltypes.CHAR, + 'varbinary': sqltypes.VARBINARY, + 'json': sqltypes.JSON, + + # === Date and time === + 'date': sqltypes.DATE, + 'time': sqltypes.TIME, + 'timestamp': sqltypes.TIMESTAMP, + + # 'interval year to month': + # 'interval day to second': + # + # === Structural === + # 'array': ARRAY, + # 'map': MAP + # 'row': ROW + # + # === Mixed === + # 'ipaddress': IPADDRESS + # 'uuid': UUID, + # 'hyperloglog': HYPERLOGLOG, + # 'p4hyperloglog': P4HYPERLOGLOG, + # 'qdigest': QDIGEST, + # 'tdigest': TDIGEST, +} + + +def split(string: str, delimiter: str = ',', + quote: str = '"', escaped_quote: str = r'\"', + open_bracket: str = '(', close_bracket: str = ')') -> Iterator[str]: + """ + A split function that is aware of quotes and brackets/parentheses. + + :param string: string to split + :param delimiter: string defining where to split, usually a comma or space + :param quote: string, either a single or a double quote + :param escaped_quote: string representing an escaped quote + :param open_bracket: string, either [, {, < or ( + :param close_bracket: string, either ], }, > or ) + """ + parens = 0 + quotes = False + i = 0 + for j, character in enumerate(string): + complete = parens == 0 and not quotes + if complete and character == delimiter: + yield string[i:j] + i = j + len(delimiter) + elif character == open_bracket: + parens += 1 + elif character == close_bracket: + parens -= 1 + elif character == quote: + if quotes and string[j - len(escaped_quote) + 1: j + 1] != escaped_quote: + quotes = False + elif not quotes: + quotes = True + yield string[i:] + + +def parse_sqltype(type_str: str) -> TypeEngine: + type_str = type_str.strip().lower() + match = re.match(r'^(?P\w+)\s*(?:\((?P.*)\))?', type_str) + if not match: + util.warn(f"Could not parse type name '{type_str}'") + return sqltypes.NULLTYPE + type_name = match.group("type") + type_opts = match.group("options") + + if type_name == "array": + item_type = parse_sqltype(type_opts) + if isinstance(item_type, sqltypes.ARRAY): + dimensions = (item_type.dimensions or 1) + 1 + return sqltypes.ARRAY(item_type.item_type, dimensions=dimensions) + return sqltypes.ARRAY(item_type) + elif type_name == "map": + key_type_str, value_type_str = split(type_opts) + key_type = parse_sqltype(key_type_str) + value_type = parse_sqltype(value_type_str) + return MAP(key_type, value_type) + elif type_name == "row": + attr_types: Dict[str, SQLType] = {} + for attr_str in split(type_opts): + name, attr_type_str = split(attr_str.strip(), delimiter=' ') + attr_type = parse_sqltype(attr_type_str) + attr_types[name] = attr_type + return ROW(attr_types) + + if type_name not in _type_map: + util.warn(f"Did not recognize type '{type_name}'") + return sqltypes.NULLTYPE + type_class = _type_map[type_name] + type_args = [int(o.strip()) for o in type_opts.split(',')] if type_opts else [] + if type_name in ('time', 'timestamp'): + type_kwargs = dict(timezone=type_str.endswith("with time zone")) + return type_class(**type_kwargs) # TODO: handle time/timestamp(p) precision + return type_class(*type_args) diff --git a/trino/sqlalchemy/dialect.py b/trino/sqlalchemy/dialect.py new file mode 100644 index 00000000..a1505f57 --- /dev/null +++ b/trino/sqlalchemy/dialect.py @@ -0,0 +1,310 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from textwrap import dedent +from typing import Any, Dict, List, Optional, Tuple + +from sqlalchemy import exc, sql +from sqlalchemy.engine.base import Connection +from sqlalchemy.engine.default import DefaultDialect, DefaultExecutionContext +from sqlalchemy.engine.url import URL + +from trino import dbapi as trino_dbapi +from trino.auth import BasicAuthentication +from trino.dbapi import Cursor +from . import compiler, datatype, error + + +class TrinoDialect(DefaultDialect): + name = 'trino' + driver = 'rest' + + statement_compiler = compiler.TrinoSQLCompiler + ddl_compiler = compiler.TrinoDDLCompiler + type_compiler = compiler.TrinoTypeCompiler + preparer = compiler.TrinoIdentifierPreparer + + # Data Type + supports_native_enum = False + supports_native_boolean = True + supports_native_decimal = True + + # Column options + supports_sequences = False + supports_comments = True + inline_comments = True + supports_default_values = False + + # DDL + supports_alter = True + + # DML + supports_empty_insert = False + supports_multivalues_insert = True + postfetch_lastrowid = False + + @classmethod + def dbapi(cls): + """ + ref: https://www.python.org/dev/peps/pep-0249/#module-interface + """ + return trino_dbapi + + def create_connect_args(self, url: URL) -> Tuple[List[Any], Dict[str, Any]]: + args, kwargs = super(TrinoDialect, self).create_connect_args(url) # type: List[Any], Dict[str, Any] + + db_parts = kwargs.pop('database', 'system').split('/') + if len(db_parts) == 1: + kwargs['catalog'] = db_parts[0] + elif len(db_parts) == 2: + kwargs['catalog'] = db_parts[0] + kwargs['schema'] = db_parts[1] + else: + raise ValueError(f'Unexpected database format {url.database}') + + username = kwargs.pop('username', 'anonymous') + kwargs['user'] = username + + password = kwargs.pop('password', None) + if password: + kwargs['http_scheme'] = 'https' + kwargs['auth'] = BasicAuthentication(username, password) + + return args, kwargs + + def get_columns(self, connection: Connection, + table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]: + if not self.has_table(connection, table_name, schema): + raise exc.NoSuchTableError(f'schema={schema}, table={table_name}') + return self._get_columns(connection, table_name, schema, **kw) + + def _get_columns(self, connection: Connection, + table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]: + schema = schema or self._get_default_schema_name(connection) + query = dedent(''' + SELECT + "column_name", + "column_default", + UPPER("is_nullable"), + "data_type" + FROM "information_schema"."columns" + WHERE "table_schema" = :schema AND "table_name" = :table + ORDER BY "ordinal_position" ASC + ''').strip() + res = connection.execute(sql.text(query), schema=schema, table=table_name) + columns = [] + for record in res: + column = dict( + name=record.column_name, + type=datatype.parse_sqltype(record.data_type), + nullable=record.is_nullable == 'YES', + default=record.column_default, + ) + columns.append(column) + return columns + + def get_pk_constraint(self, connection: Connection, + table_name: str, schema: str = None, **kw) -> Dict[str, Any]: + """Trino has no support for primary keys. Returns a dummy""" + return dict(name=None, constrained_columns=[]) + + def get_primary_keys(self, connection: Connection, + table_name: str, schema: str = None, **kw) -> List[str]: + pk = self.get_pk_constraint(connection, table_name, schema) + return pk.get('constrained_columns') # type: List[str] + + def get_foreign_keys(self, connection: Connection, + table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]: + """Trino has no support for foreign keys. Returns an empty list.""" + return [] + + def get_schema_names(self, connection: Connection, **kw) -> List[str]: + query = 'SHOW SCHEMAS' + res = connection.execute(sql.text(query)) + return [row.Schema for row in res] + + def get_table_names(self, connection: Connection, schema: str = None, **kw) -> List[str]: + query = 'SHOW TABLES' + if schema: + query = f'{query} FROM {self.identifier_preparer.quote_identifier(schema)}' + res = connection.execute(sql.text(query)) + return [row.Table for row in res] + + def get_temp_table_names(self, connection: Connection, schema: str = None, **kw) -> List[str]: + """Trino has no support for temporary tables. Returns an empty list.""" + return [] + + def get_view_names(self, connection: Connection, schema: str = None, **kw) -> List[str]: + schema = schema or self._get_default_schema_name(connection) + if schema is None: + raise exc.NoSuchTableError('schema is required') + query = dedent(''' + SELECT "table_name" + FROM "information_schema"."views" + WHERE "table_schema" = :schema + ''').strip() + res = connection.execute(sql.text(query), schema=schema) + return [row.table_name for row in res] + + def get_temp_view_names(self, connection: Connection, schema: str = None, **kw) -> List[str]: + """Trino has no support for temporary views. Returns an empty list.""" + return [] + + def get_view_definition(self, connection: Connection, view_name: str, schema: str = None, **kw) -> str: + full_view = self._get_full_table(view_name, schema) + query = f'SHOW CREATE VIEW {full_view}' + try: + res = connection.execute(sql.text(query)) + return res.scalar() + except error.TrinoQueryError as e: + if e.error_name in ( + error.TABLE_NOT_FOUND, + error.SCHEMA_NOT_FOUND, + error.CATALOG_NOT_FOUND, + ): + raise exc.NoSuchTableError(full_view) from e + raise + + def get_indexes(self, connection: Connection, + table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]: + if not self.has_table(connection, table_name, schema): + raise exc.NoSuchTableError(f'schema={schema}, table={table_name}') + + partitioned_columns = self._get_columns(connection, f'{table_name}$partitions', schema, **kw) + partition_index = dict( + name='partition', + column_names=[col['name'] for col in partitioned_columns], + unique=False + ) + return [partition_index, ] + + def get_unique_constraints(self, connection: Connection, + table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]: + """Trino has no support for unique constraints. Returns an empty list.""" + return [] + + def get_check_constraints(self, connection: Connection, + table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]: + """Trino has no support for check constraints. Returns an empty list.""" + return [] + + def get_table_comment(self, connection: Connection, + table_name: str, schema: str = None, **kw) -> Dict[str, Any]: + properties_table = self._get_full_table(f"{table_name}$properties", schema) + query = f'SELECT "comment" FROM {properties_table}' + try: + res = connection.execute(sql.text(query)) + return dict(text=res.scalar()) + except error.TrinoQueryError as e: + if e.error_name in ( + error.NOT_FOUND, + error.COLUMN_NOT_FOUND, + error.TABLE_NOT_FOUND, + ): + return dict(text=None) + raise + + def has_schema(self, connection: Connection, schema: str) -> bool: + query = f"SHOW SCHEMAS LIKE '{schema}'" + try: + res = connection.execute(sql.text(query)) + return res.first() is not None + except error.TrinoQueryError as e: + if e.error_name in ( + error.TABLE_NOT_FOUND, + error.SCHEMA_NOT_FOUND, + error.CATALOG_NOT_FOUND, + ): + return False + raise + + def has_table(self, connection: Connection, + table_name: str, schema: str = None) -> bool: + query = 'SHOW TABLES' + if schema: + query = f'{query} FROM {self.identifier_preparer.quote_identifier(schema)}' + query = f"{query} LIKE '{table_name}'" + try: + res = connection.execute(sql.text(query)) + return res.first() is not None + except error.TrinoQueryError as e: + if e.error_name in ( + error.TABLE_NOT_FOUND, + error.SCHEMA_NOT_FOUND, + error.CATALOG_NOT_FOUND, + error.MISSING_SCHEMA_NAME, + ): + return False + raise + + def has_sequence(self, connection: Connection, + sequence_name: str, schema: str = None) -> bool: + """Trino has no support for sequence. Returns False indicate that given sequence does not exists.""" + return False + + def _get_server_version_info(self, connection: Connection) -> Tuple[int, ...]: + query = 'SELECT version()' + res = connection.execute(sql.text(query)) + version = res.scalar() + return tuple([version]) + + def _get_default_schema_name(self, connection: Connection) -> Optional[str]: + dbapi_connection: trino_dbapi.Connection = connection.connection + return dbapi_connection.schema + + def do_execute(self, cursor: Cursor, statement: str, parameters: Tuple[Any, ...], + context: DefaultExecutionContext = None): + cursor.execute(statement, parameters) + if context and context.should_autocommit: + # SQL statement only submitted to Trino server when cursor.fetch*() is called. + # For DDL (CREATE/ALTER/DROP) and DML (INSERT/UPDATE/DELETE) statement, call cursor.description + # to force submit statement immediately. + cursor.description # noqa + + def do_rollback(self, dbapi_connection): + if dbapi_connection.transaction is not None: + dbapi_connection.rollback() + + def do_begin_twophase(self, connection: Connection, xid): + pass + + def do_prepare_twophase(self, connection: Connection, xid): + pass + + def do_rollback_twophase(self, connection: Connection, xid, + is_prepared: bool = True, recover: bool = False) -> None: + pass + + def do_commit_twophase(self, connection: Connection, xid, + is_prepared: bool = True, recover: bool = False) -> None: + pass + + def do_recover_twophase(self, connection: Connection) -> None: + pass + + def set_isolation_level(self, dbapi_conn: trino_dbapi.Connection, level) -> None: + dbapi_conn._isolation_level = getattr(trino_dbapi.IsolationLevel, level) + + def get_isolation_level(self, dbapi_conn: trino_dbapi.Connection) -> str: + level_names = ['AUTOCOMMIT', + 'READ_UNCOMMITTED', + 'READ_COMMITTED', + 'REPEATABLE_READ', + 'SERIALIZABLE'] + return level_names[dbapi_conn.isolation_level] + + def _get_full_table(self, table_name: str, schema: str = None, quote: bool = True) -> str: + table_part = self.identifier_preparer.quote_identifier(table_name) if quote else table_name + if schema: + schema_part = self.identifier_preparer.quote_identifier(schema) if quote else schema + return f'{schema_part}.{table_part}' + + return table_part diff --git a/trino/sqlalchemy/error.py b/trino/sqlalchemy/error.py new file mode 100644 index 00000000..3079d6eb --- /dev/null +++ b/trino/sqlalchemy/error.py @@ -0,0 +1,24 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from trino.exceptions import TrinoQueryError # noqa + +# ref: https://github.com/trinodb/trino/blob/master/core/trino-spi/src/main/java/io/trino/spi/StandardErrorCode.java +NOT_FOUND = 'NOT_FOUND' +COLUMN_NOT_FOUND = 'COLUMN_NOT_FOUND' +TABLE_NOT_FOUND = 'TABLE_NOT_FOUND' +SCHEMA_NOT_FOUND = 'SCHEMA_NOT_FOUND' +CATALOG_NOT_FOUND = 'CATALOG_NOT_FOUND' + +MISSING_TABLE = 'MISSING_TABLE' +MISSING_COLUMN_NAME = 'MISSING_COLUMN_NAME' +MISSING_SCHEMA_NAME = 'MISSING_SCHEMA_NAME' +MISSING_CATALOG_NAME = 'MISSING_CATALOG_NAME'