From 68e6f58e73fd55db97f37d734f8269c8e84c36ca Mon Sep 17 00:00:00 2001 From: Ke Zhu Date: Tue, 31 Aug 2021 21:13:43 -0400 Subject: [PATCH] Support connection parameter sessionUser --- sqlalchemy_trino/dialect.py | 3 ++- tests/test_dialect.py | 21 +++++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 tests/test_dialect.py diff --git a/sqlalchemy_trino/dialect.py b/sqlalchemy_trino/dialect.py index d8bed22..85efbb1 100644 --- a/sqlalchemy_trino/dialect.py +++ b/sqlalchemy_trino/dialect.py @@ -85,7 +85,8 @@ def create_connect_args(self, url: URL) -> Tuple[List[Any], Dict[str, Any]]: raise ValueError(f'Unexpected database format {url.database}') username = kwargs.pop('username', 'anonymous') - kwargs['user'] = username + session_user = kwargs.pop('sessionUser', username) + kwargs['user'] = session_user password = kwargs.pop('password', None) if password: diff --git a/tests/test_dialect.py b/tests/test_dialect.py new file mode 100644 index 0000000..29d1339 --- /dev/null +++ b/tests/test_dialect.py @@ -0,0 +1,21 @@ +from sqlalchemy.engine import url +from sqlalchemy_trino.dialect import TrinoDialect + + +def test_trino_connection_string_user(): + dialect = TrinoDialect() + username = 'test-user' + u = url.make_url(f'trino://{username}@host') + _, cparams = dialect.create_connect_args(u) + + assert cparams['user'] == username + + +def test_trino_connection_string_session_user(): + dialect = TrinoDialect() + username = 'test-user' + session_user = 'sess-user' + u = url.make_url(f'trino://{username}@host/?sessionUser={session_user}') + _, cparams = dialect.create_connect_args(u) + + assert cparams['user'] == session_user