diff --git a/tests/sqlalchemy/test_dialect.py b/tests/sqlalchemy/test_dialect.py index 1f47d2ff..24213ebf 100644 --- a/tests/sqlalchemy/test_dialect.py +++ b/tests/sqlalchemy/test_dialect.py @@ -19,12 +19,12 @@ def setup(self): @pytest.mark.parametrize( 'url, expected_args, expected_kwargs', [ - (make_url('trino://localhost'), - list(), dict(host='localhost', catalog='system', user='anonymous')), - (make_url('trino://1.2.3.4:4321/mysql/sakila'), - list(), dict(host='1.2.3.4', port=4321, catalog='mysql', schema='sakila', user='anonymous')), + (make_url('trino://user@localhost'), + list(), dict(host='localhost', catalog='system', user='user')), + (make_url('trino://user@localhost:8080'), list(), dict(host='localhost', port=8080, catalog='system', user='user')), + (make_url('trino://user:pass@localhost:8080'), list(), dict(host='localhost', port=8080, catalog='system', user='user', auth=BasicAuthentication('user', 'pass'), http_scheme='https')), @@ -36,6 +36,18 @@ def test_create_connect_args(self, url: URL, expected_args: List[Any], expected_ assert_that(actual_args).is_equal_to(expected_args) assert_that(actual_kwargs).is_equal_to(expected_kwargs) + def test_create_connect_args_missing_user(self): + url = make_url('trino://localhost') + assert_that(self.dialect.create_connect_args).raises(ValueError) \ + .when_called_with(url) \ + .is_equal_to('user is required') + + def test_create_connect_args_wrong_db_format(self): + url = make_url('trino://abc@localhost/catalog/schema/foobar') + assert_that(self.dialect.create_connect_args).raises(ValueError) \ + .when_called_with(url) \ + .is_equal_to('Unexpected database format catalog/schema/foobar') + def test_get_default_isolation_level(self): isolation_level = self.dialect.get_default_isolation_level(mock.Mock()) assert_that(isolation_level).is_equal_to('AUTOCOMMIT') diff --git a/trino/sqlalchemy/dialect.py b/trino/sqlalchemy/dialect.py index 969238b8..1fd22f84 100644 --- a/trino/sqlalchemy/dialect.py +++ b/trino/sqlalchemy/dialect.py @@ -10,7 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from textwrap import dedent -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple from sqlalchemy import exc, sql from sqlalchemy.engine.base import Connection @@ -58,10 +58,14 @@ def dbapi(cls): """ return trino_dbapi - def create_connect_args(self, url: URL) -> Tuple[List[Any], Dict[str, Any]]: - args, kwargs = super(TrinoDialect, self).create_connect_args(url) # type: List[Any], Dict[str, Any] + def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any]]: + args = list() + kwargs = dict(host=url.host) - db_parts = kwargs.pop('database', 'system').split('/') + if url.port: + kwargs['port'] = url.port + + db_parts = (url.database or 'system').split('/') if len(db_parts) == 1: kwargs['catalog'] = db_parts[0] elif len(db_parts) == 2: @@ -70,13 +74,13 @@ def create_connect_args(self, url: URL) -> Tuple[List[Any], Dict[str, Any]]: else: raise ValueError(f'Unexpected database format {url.database}') - username = kwargs.pop('username', 'anonymous') - kwargs['user'] = username + if not url.username: + raise ValueError(f"user is required") + kwargs['user'] = url.username - password = kwargs.pop('password', None) - if password: + if url.password: kwargs['http_scheme'] = 'https' - kwargs['auth'] = BasicAuthentication(username, password) + kwargs['auth'] = BasicAuthentication(url.username, url.password) return args, kwargs