Skip to content

Commit

Permalink
don't provide default username
Browse files Browse the repository at this point in the history
Signed-off-by: Đặng Minh Dũng <dungdm93@live.com>
  • Loading branch information
dungdm93 committed Sep 13, 2021
1 parent 360a652 commit 0ad84e8
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 13 deletions.
20 changes: 16 additions & 4 deletions tests/sqlalchemy/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ def setup(self):
@pytest.mark.parametrize(
'url, expected_args, expected_kwargs',
[
(make_url('trino://localhost'),
list(), dict(host='localhost', catalog='system', user='anonymous')),
(make_url('trino://1.2.3.4:4321/mysql/sakila'),
list(), dict(host='1.2.3.4', port=4321, catalog='mysql', schema='sakila', user='anonymous')),
(make_url('trino://user@localhost'),
list(), dict(host='localhost', catalog='system', user='user')),
(make_url('trino://user@localhost:8080'),
list(), dict(host='localhost', port=8080, catalog='system', user='user')),
(make_url('trino://user:pass@localhost:8080'),
list(), dict(host='localhost', port=8080, catalog='system', user='user',
auth=BasicAuthentication('user', 'pass'), http_scheme='https')),
Expand All @@ -36,6 +36,18 @@ def test_create_connect_args(self, url: URL, expected_args: List[Any], expected_
assert_that(actual_args).is_equal_to(expected_args)
assert_that(actual_kwargs).is_equal_to(expected_kwargs)

def test_create_connect_args_missing_user_when_specify_password(self):
url = make_url('trino://:pass@localhost')
assert_that(self.dialect.create_connect_args).raises(ValueError) \
.when_called_with(url) \
.is_equal_to('Username is required when specify password in connection URL')

def test_create_connect_args_wrong_db_format(self):
url = make_url('trino://abc@localhost/catalog/schema/foobar')
assert_that(self.dialect.create_connect_args).raises(ValueError) \
.when_called_with(url) \
.is_equal_to('Unexpected database format catalog/schema/foobar')

def test_get_default_isolation_level(self):
isolation_level = self.dialect.get_default_isolation_level(mock.Mock())
assert_that(isolation_level).is_equal_to('AUTOCOMMIT')
Expand Down
23 changes: 14 additions & 9 deletions trino/sqlalchemy/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from textwrap import dedent
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple

from sqlalchemy import exc, sql
from sqlalchemy.engine.base import Connection
Expand Down Expand Up @@ -58,10 +58,14 @@ def dbapi(cls):
"""
return trino_dbapi

def create_connect_args(self, url: URL) -> Tuple[List[Any], Dict[str, Any]]:
args, kwargs = super(TrinoDialect, self).create_connect_args(url) # type: List[Any], Dict[str, Any]
def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any]]:
args = list()
kwargs = dict(host=url.host)

db_parts = kwargs.pop('database', 'system').split('/')
if url.port:
kwargs['port'] = url.port

db_parts = (url.database or 'system').split('/')
if len(db_parts) == 1:
kwargs['catalog'] = db_parts[0]
elif len(db_parts) == 2:
Expand All @@ -70,13 +74,14 @@ def create_connect_args(self, url: URL) -> Tuple[List[Any], Dict[str, Any]]:
else:
raise ValueError(f'Unexpected database format {url.database}')

username = kwargs.pop('username', 'anonymous')
kwargs['user'] = username
if url.username:
kwargs['user'] = url.username

password = kwargs.pop('password', None)
if password:
if url.password:
if not url.username:
raise ValueError(f'Username is required when specify password in connection URL')
kwargs['http_scheme'] = 'https'
kwargs['auth'] = BasicAuthentication(username, password)
kwargs['auth'] = BasicAuthentication(url.username, url.password)

return args, kwargs

Expand Down

0 comments on commit 0ad84e8

Please sign in to comment.