Skip to content

Commit

Permalink
feat: add trino sqlalchemy dialect
Browse files Browse the repository at this point in the history
Signed-off-by: Đặng Minh Dũng <dungdm93@live.com>
  • Loading branch information
dungdm93 committed Sep 2, 2021
1 parent c8aa76c commit ff00675
Show file tree
Hide file tree
Showing 15 changed files with 975 additions and 26 deletions.
11 changes: 11 additions & 0 deletions integration_tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
4 changes: 2 additions & 2 deletions integration_tests/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
import pytest
import pytz

import trino
from conftest import TRINO_VERSION
import trino.dbapi
from trino.exceptions import TrinoQueryError
from trino.transaction import IsolationLevel
from .conftest import TRINO_VERSION


@pytest.fixture
Expand Down
28 changes: 16 additions & 12 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,23 @@

import ast
import re
from setuptools import setup
import textwrap

from setuptools import setup

_version_re = re.compile(r"__version__\s+=\s+(.*)")


with open("trino/__init__.py", "rb") as f:
trino_version = _version_re.search(f.read().decode("utf-8"))
assert trino_version is not None
version = str(ast.literal_eval(trino_version.group(1)))


kerberos_require = ["requests_kerberos"]
sqlalchemy_require = ["sqlalchemy~=1.3"]

all_require = kerberos_require + []
all_require = kerberos_require + sqlalchemy_require

tests_require = all_require + ["httpretty", "pytest", "pytest-runner", "pytz", "click"]
tests_require = all_require + ["httpretty", "pytest", "pytest-runner", "pytz", "click", "assertpy"]

setup(
name="trino",
Expand All @@ -44,19 +43,17 @@
description="Client for the Trino distributed SQL Engine",
long_description=textwrap.dedent(
"""
Client for Trino (https://trino.io), a distributed SQL engine for
interactive and batch big data processing. Provides a low-level client and
a DBAPI 2.0 implementation.
"""
Client for Trino (https://trino.io), a distributed SQL engine for
interactive and batch big data processing. Provides a low-level client and
a DBAPI 2.0 implementation.
"""
),
license="Apache 2.0",
classifiers=[
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Operating System :: MacOS :: MacOS X",
"Operating System :: POSIX",
"Operating System :: Microsoft :: Windows",
"Operating System :: OS Independent",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.6",
Expand All @@ -65,13 +62,20 @@
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
"Topic :: Database",
"Topic :: Database :: Front-Ends",
],
python_requires='>=3.6',
install_requires=["requests"],
extras_require={
"all": all_require,
"kerberos": kerberos_require,
"sqlalchemy": sqlalchemy_require,
"tests": tests_require,
},
entry_points={
"sqlalchemy.dialects": [
"trino = trino.sqlalchemy.dialect:TrinoDialect",
]
},
)
Empty file added tests/__init__.py
Empty file.
11 changes: 11 additions & 0 deletions tests/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
45 changes: 45 additions & 0 deletions tests/sqlalchemy/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from assertpy import add_extension, assert_that
from sqlalchemy.sql.sqltypes import ARRAY

from trino.sqlalchemy.datatype import MAP, ROW, SQLType


def assert_sqltype(this: SQLType, that: SQLType):
if isinstance(this, type):
this = this()
if isinstance(that, type):
that = that()
assert_that(type(this)).is_same_as(type(that))
if isinstance(this, ARRAY):
assert_sqltype(this.item_type, that.item_type)
if this.dimensions is None or this.dimensions == 1:
assert_that(that.dimensions).is_in(None, 1)
else:
assert_that(this.dimensions).is_equal_to(this.dimensions)
elif isinstance(this, MAP):
assert_sqltype(this.key_type, that.key_type)
assert_sqltype(this.value_type, that.value_type)
elif isinstance(this, ROW):
assert_that(len(this.attr_types)).is_equal_to(len(that.attr_types))
for name, this_attr in this.attr_types.items():
that_attr = this.attr_types[name]
assert_sqltype(this_attr, that_attr)
else:
assert_that(str(this)).is_equal_to(str(that))


@add_extension
def is_sqltype(self, that):
this = self.val
assert_sqltype(this, that)
127 changes: 127 additions & 0 deletions tests/sqlalchemy/test_datatype_parse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from assertpy import assert_that
from sqlalchemy.sql.sqltypes import (
CHAR, VARCHAR,
ARRAY,
INTEGER, DECIMAL,
DATE, TIME, TIMESTAMP
)
from sqlalchemy.sql.type_api import TypeEngine

from trino.sqlalchemy import datatype
from trino.sqlalchemy.datatype import MAP, ROW


@pytest.mark.parametrize(
'type_str, sql_type',
datatype._type_map.items(),
ids=datatype._type_map.keys()
)
def test_parse_simple_type(type_str: str, sql_type: TypeEngine):
actual_type = datatype.parse_sqltype(type_str)
if not isinstance(actual_type, type):
actual_type = type(actual_type)
assert_that(actual_type).is_equal_to(sql_type)


parse_type_options_testcases = {
'VARCHAR(10)': VARCHAR(10),
'DECIMAL(20)': DECIMAL(20),
'DECIMAL(20, 3)': DECIMAL(20, 3),
}


@pytest.mark.parametrize(
'type_str, sql_type',
parse_type_options_testcases.items(),
ids=parse_type_options_testcases.keys()
)
def test_parse_type_options(type_str: str, sql_type: TypeEngine):
actual_type = datatype.parse_sqltype(type_str)
assert_that(actual_type).is_sqltype(sql_type)


parse_array_testcases = {
'array(integer)': ARRAY(INTEGER()),
'array(varchar(10))': ARRAY(VARCHAR(10)),
'array(decimal(20,3))': ARRAY(DECIMAL(20, 3)),
'array(array(varchar(10)))': ARRAY(VARCHAR(10), dimensions=2),
}


@pytest.mark.parametrize(
'type_str, sql_type',
parse_array_testcases.items(),
ids=parse_array_testcases.keys()
)
def test_parse_array(type_str: str, sql_type: ARRAY):
actual_type = datatype.parse_sqltype(type_str)
assert_that(actual_type).is_sqltype(sql_type)


parse_map_testcases = {
'map(char, integer)': MAP(CHAR(), INTEGER()),
'map(varchar(10), varchar(10))': MAP(VARCHAR(10), VARCHAR(10)),
'map(varchar(10), decimal(20,3))': MAP(VARCHAR(10), DECIMAL(20, 3)),
'map(char, array(varchar(10)))': MAP(CHAR(), ARRAY(VARCHAR(10))),
'map(varchar(10), array(varchar(10)))': MAP(VARCHAR(10), ARRAY(VARCHAR(10))),
'map(varchar(10), array(array(varchar(10))))': MAP(VARCHAR(10), ARRAY(VARCHAR(10), dimensions=2)),
}


@pytest.mark.parametrize(
'type_str, sql_type',
parse_map_testcases.items(),
ids=parse_map_testcases.keys()
)
def test_parse_map(type_str: str, sql_type: ARRAY):
actual_type = datatype.parse_sqltype(type_str)
assert_that(actual_type).is_sqltype(sql_type)


parse_row_testcases = {
'row(a integer, b varchar)': ROW(dict(a=INTEGER(), b=VARCHAR())),
'row(a varchar(20), b decimal(20,3))': ROW(dict(a=VARCHAR(20), b=DECIMAL(20, 3))),
'row(x array(varchar(10)), y array(array(varchar(10))), z decimal(20,3))':
ROW(dict(x=ARRAY(VARCHAR(10)), y=ARRAY(VARCHAR(10), dimensions=2), z=DECIMAL(20, 3))),
}


@pytest.mark.parametrize(
'type_str, sql_type',
parse_row_testcases.items(),
ids=parse_row_testcases.keys()
)
def test_parse_row(type_str: str, sql_type: ARRAY):
actual_type = datatype.parse_sqltype(type_str)
assert_that(actual_type).is_sqltype(sql_type)


parse_datetime_testcases = {
'date': DATE(),
'time': TIME(),
'time with time zone': TIME(timezone=True),
'timestamp': TIMESTAMP(),
'timestamp with time zone': TIMESTAMP(timezone=True),
}


@pytest.mark.parametrize(
'type_str, sql_type',
parse_datetime_testcases.items(),
ids=parse_datetime_testcases.keys()
)
def test_parse_datetime(type_str: str, sql_type: ARRAY):
actual_type = datatype.parse_sqltype(type_str)
assert_that(actual_type).is_sqltype(sql_type)
64 changes: 64 additions & 0 deletions tests/sqlalchemy/test_datatype_split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

import pytest
from assertpy import assert_that

from trino.sqlalchemy import datatype

split_string_testcases = {
'10': ['10'],
'10,3': ['10', '3'],
'varchar': ['varchar'],
'varchar,int': ['varchar', 'int'],
'varchar,int,float': ['varchar', 'int', 'float'],
'array(varchar)': ['array(varchar)'],
'array(varchar),int': ['array(varchar)', 'int'],
'array(varchar(20))': ['array(varchar(20))'],
'array(varchar(20)),int': ['array(varchar(20))', 'int'],
'array(varchar(20)),array(varchar(20))': ['array(varchar(20))', 'array(varchar(20))'],
'map(varchar, integer),int': ['map(varchar, integer)', 'int'],
'map(varchar(20), integer),int': ['map(varchar(20), integer)', 'int'],
'map(varchar(20), varchar(20)),int': ['map(varchar(20), varchar(20))', 'int'],
'map(varchar(20), varchar(20)),array(varchar)': ['map(varchar(20), varchar(20))', 'array(varchar)'],
'row(first_name varchar(20), last_name varchar(20)),int':
['row(first_name varchar(20), last_name varchar(20))', 'int'],
}


@pytest.mark.parametrize(
'input_string, output_strings',
split_string_testcases.items(),
ids=split_string_testcases.keys()
)
def test_split_string(input_string: str, output_strings: List[str]):
actual = list(datatype.split(input_string))
assert_that(actual).is_equal_to(output_strings)


split_delimiter_testcases = [
('first,second', ',', ['first', 'second']),
('first second', ' ', ['first', 'second']),
('first|second', '|', ['first', 'second']),
('first,second third', ',', ['first', 'second third']),
('first,second third', ' ', ['first,second', 'third']),
]


@pytest.mark.parametrize(
'input_string, delimiter, output_strings',
split_delimiter_testcases,
)
def test_split_delimiter(input_string: str, delimiter: str, output_strings: List[str]):
actual = list(datatype.split(input_string, delimiter=delimiter))
assert_that(actual).is_equal_to(output_strings)
9 changes: 0 additions & 9 deletions trino/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from . import auth
from . import dbapi
from . import client
from . import constants
from . import exceptions
from . import logging

__all__ = ['auth', 'dbapi', 'client', 'constants', 'exceptions', 'logging']

__version__ = "0.305.0"
Loading

0 comments on commit ff00675

Please sign in to comment.