Skip to content

Commit

Permalink
user is required
Browse files Browse the repository at this point in the history
Signed-off-by: Đặng Minh Dũng <dungdm93@live.com>
  • Loading branch information
dungdm93 committed Sep 12, 2021
1 parent 6a99009 commit 60c27d1
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 13 deletions.
20 changes: 16 additions & 4 deletions tests/sqlalchemy/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ def setup(self):
@pytest.mark.parametrize(
'url, expected_args, expected_kwargs',
[
(make_url('trino://localhost'),
list(), dict(host='localhost', catalog='system', user='anonymous')),
(make_url('trino://1.2.3.4:4321/mysql/sakila'),
list(), dict(host='1.2.3.4', port=4321, catalog='mysql', schema='sakila', user='anonymous')),
(make_url('trino://user@localhost'),
list(), dict(host='localhost', catalog='system', user='user')),
(make_url('trino://user@localhost:8080'),
list(), dict(host='localhost', port=8080, catalog='system', user='user')),
(make_url('trino://user:pass@localhost:8080'),
list(), dict(host='localhost', port=8080, catalog='system', user='user',
auth=BasicAuthentication('user', 'pass'), http_scheme='https')),
Expand All @@ -36,6 +36,18 @@ def test_create_connect_args(self, url: URL, expected_args: List[Any], expected_
assert_that(actual_args).is_equal_to(expected_args)
assert_that(actual_kwargs).is_equal_to(expected_kwargs)

def test_create_connect_args_missing_user(self):
url = make_url('trino://localhost')
assert_that(self.dialect.create_connect_args).raises(ValueError) \
.when_called_with(url) \
.is_equal_to('user is required')

def test_create_connect_args_wrong_db_format(self):
url = make_url('trino://abc@localhost/catalog/schema/foobar')
assert_that(self.dialect.create_connect_args).raises(ValueError) \
.when_called_with(url) \
.is_equal_to('Unexpected database format catalog/schema/foobar')

def test_get_default_isolation_level(self):
isolation_level = self.dialect.get_default_isolation_level(mock.Mock())
assert_that(isolation_level).is_equal_to('AUTOCOMMIT')
Expand Down
22 changes: 13 additions & 9 deletions trino/sqlalchemy/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from textwrap import dedent
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple

from sqlalchemy import exc, sql
from sqlalchemy.engine.base import Connection
Expand Down Expand Up @@ -58,10 +58,14 @@ def dbapi(cls):
"""
return trino_dbapi

def create_connect_args(self, url: URL) -> Tuple[List[Any], Dict[str, Any]]:
args, kwargs = super(TrinoDialect, self).create_connect_args(url) # type: List[Any], Dict[str, Any]
def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any]]:
args = list()
kwargs = dict(host=url.host)

db_parts = kwargs.pop('database', 'system').split('/')
if url.port:
kwargs['port'] = url.port

db_parts = (url.database or 'system').split('/')
if len(db_parts) == 1:
kwargs['catalog'] = db_parts[0]
elif len(db_parts) == 2:
Expand All @@ -70,13 +74,13 @@ def create_connect_args(self, url: URL) -> Tuple[List[Any], Dict[str, Any]]:
else:
raise ValueError(f'Unexpected database format {url.database}')

username = kwargs.pop('username', 'anonymous')
kwargs['user'] = username
if not url.username:
raise ValueError(f"user is required")
kwargs['user'] = url.username

password = kwargs.pop('password', None)
if password:
if url.password:
kwargs['http_scheme'] = 'https'
kwargs['auth'] = BasicAuthentication(username, password)
kwargs['auth'] = BasicAuthentication(url.username, url.password)

return args, kwargs

Expand Down

0 comments on commit 60c27d1

Please sign in to comment.