Skip to content

Commit

Permalink
refactor: change IsolationLevel to Enum
Browse files Browse the repository at this point in the history
Signed-off-by: Đặng Minh Dũng <dungdm93@live.com>
  • Loading branch information
dungdm93 committed Sep 13, 2021
1 parent c424c64 commit e79ab74
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 26 deletions.
16 changes: 16 additions & 0 deletions tests/sqlalchemy/test_dialect.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from typing import List, Any, Dict
from unittest import mock

import pytest
from assertpy import assert_that
from sqlalchemy.engine import make_url
from sqlalchemy.engine.url import URL

from trino.auth import BasicAuthentication
from trino.dbapi import Connection
from trino.sqlalchemy.dialect import TrinoDialect
from trino.transaction import IsolationLevel


class TestTrinoDialect:
Expand All @@ -32,3 +35,16 @@ def test_create_connect_args(self, url: URL, expected_args: List[Any], expected_

assert_that(actual_args).is_equal_to(expected_args)
assert_that(actual_kwargs).is_equal_to(expected_kwargs)

def test_get_default_isolation_level(self):
isolation_level = self.dialect.get_default_isolation_level(mock.Mock())
assert_that(isolation_level).is_equal_to('AUTOCOMMIT')

def test_isolation_level(self):
dbapi_conn = Connection(host="localhost")

self.dialect.set_isolation_level(dbapi_conn, "SERIALIZABLE")
assert_that(dbapi_conn._isolation_level).is_equal_to(IsolationLevel.SERIALIZABLE)

isolation_level = self.dialect.get_isolation_level(dbapi_conn)
assert_that(isolation_level).is_equal_to("SERIALIZABLE")
16 changes: 7 additions & 9 deletions trino/sqlalchemy/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from trino import dbapi as trino_dbapi
from trino.auth import BasicAuthentication
from trino.dbapi import Cursor
from . import compiler, datatype, error
from trino.sqlalchemy import compiler, datatype, error


class TrinoDialect(DefaultDialect):
Expand Down Expand Up @@ -291,16 +291,14 @@ def do_commit_twophase(self, connection: Connection, xid: str,
def do_recover_twophase(self, connection: Connection) -> None:
pass

def set_isolation_level(self, dbapi_conn: trino_dbapi.Connection, level) -> None:
dbapi_conn._isolation_level = getattr(trino_dbapi.IsolationLevel, level)
def set_isolation_level(self, dbapi_conn: trino_dbapi.Connection, level: str) -> None:
dbapi_conn._isolation_level = trino_dbapi.IsolationLevel[level]

def get_isolation_level(self, dbapi_conn: trino_dbapi.Connection) -> str:
level_names = ['AUTOCOMMIT',
'READ_UNCOMMITTED',
'READ_COMMITTED',
'REPEATABLE_READ',
'SERIALIZABLE']
return level_names[dbapi_conn.isolation_level]
return dbapi_conn.isolation_level.name

def get_default_isolation_level(self, dbapi_conn: trino_dbapi.Connection) -> str:
return trino_dbapi.IsolationLevel.AUTOCOMMIT.name

def _get_full_table(self, table_name: str, schema: str = None, quote: bool = True) -> str:
table_part = self.identifier_preparer.quote_identifier(table_name) if quote else table_name
Expand Down
28 changes: 11 additions & 17 deletions trino/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum, unique
from typing import Iterable

from trino import constants
import trino.client
import trino.exceptions
import trino.logging

from trino import constants

logger = trino.logging.get_logger(__name__)


NO_TRANSACTION = "NONE"
START_TRANSACTION = "START TRANSACTION"
ROLLBACK = "ROLLBACK"
COMMIT = "COMMIT"


class IsolationLevel(object):
@unique
class IsolationLevel(Enum):
AUTOCOMMIT = 0
READ_UNCOMMITTED = 1
READ_COMMITTED = 2
Expand All @@ -35,16 +35,16 @@ class IsolationLevel(object):

@classmethod
def levels(cls) -> Iterable[str]:
return {k for k, v in cls.__dict__.items() if not k.startswith("_") and isinstance(v, int)}
return {isolation_level.name for isolation_level in IsolationLevel}

@classmethod
def values(cls) -> Iterable[int]:
return {getattr(cls, level) for level in cls.levels()}
return {isolation_level.value for isolation_level in IsolationLevel}

@classmethod
def check(cls, level: int) -> int:
if level not in cls.values():
raise ValueError("invalid isolation level {}".format(level))
raise ValueError(f"invalid isolation level {level}")
return level


Expand All @@ -60,9 +60,7 @@ def id(self):
def begin(self):
response = self._request.post(START_TRANSACTION)
if not response.ok:
raise trino.exceptions.DatabaseError(
"failed to start transaction: {}".format(response.status_code)
)
raise trino.exceptions.DatabaseError(f"failed to start transaction: {response.status_code}")
transaction_id = response.headers.get(constants.HEADER_STARTED_TRANSACTION)
if transaction_id and transaction_id != NO_TRANSACTION:
self._id = response.headers[constants.HEADER_STARTED_TRANSACTION]
Expand All @@ -74,16 +72,14 @@ def begin(self):
self._id = response.headers[constants.HEADER_STARTED_TRANSACTION]
status = self._request.process(response)
self._request.transaction_id = self._id
logger.info("transaction started: " + self._id)
logger.info("transaction started: %s", self._id)

def commit(self):
query = trino.client.TrinoQuery(self._request, COMMIT)
try:
list(query.execute())
except Exception as err:
raise trino.exceptions.DatabaseError(
"failed to commit transaction {}: {}".format(self._id, err)
)
raise trino.exceptions.DatabaseError(f"failed to commit transaction {self._id}") from err
self._id = NO_TRANSACTION
self._request.transaction_id = self._id

Expand All @@ -92,8 +88,6 @@ def rollback(self):
try:
list(query.execute())
except Exception as err:
raise trino.exceptions.DatabaseError(
"failed to rollback transaction {}: {}".format(self._id, err)
)
raise trino.exceptions.DatabaseError(f"failed to rollback transaction {self._id}") from err
self._id = NO_TRANSACTION
self._request.transaction_id = self._id

0 comments on commit e79ab74

Please sign in to comment.