Skip to content

Commit

Permalink
Add ConnectionWrapper base class (#828)
Browse files Browse the repository at this point in the history
* Add ConnectionWrapper base class

* Changie

* Rename to SparkConnectionWrapper

* Cleanup
  • Loading branch information
Fokko authored Aug 10, 2023
1 parent 0b80b47 commit 5ed503a
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 9 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20230707-135442.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Add SparkConnectionWrapper base class
time: 2023-07-07T13:54:42.41341+02:00
custom:
Author: Fokko
Issue: "829"
49 changes: 45 additions & 4 deletions dbt/adapters/spark/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
from dbt.contracts.connection import Connection
from hologram.helpers import StrEnum
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Union, Tuple, List, Generator, Iterable
from typing import Any, Dict, Optional, Union, Tuple, List, Generator, Iterable, Sequence

from abc import ABC, abstractmethod

try:
from thrift.transport.TSSLSocket import TSSLSocket
Expand Down Expand Up @@ -158,7 +160,42 @@ def _connection_keys(self) -> Tuple[str, ...]:
return "host", "port", "cluster", "endpoint", "schema", "organization"


class PyhiveConnectionWrapper(object):
class SparkConnectionWrapper(ABC):
    """Abstract interface for the Spark connection wrappers.

    Every member below is abstract, so this class cannot be instantiated
    directly; a concrete wrapper has to override all of them. The abstract
    bodies do nothing and return ``None``.
    """

    @abstractmethod
    def cursor(self) -> "SparkConnectionWrapper":
        """Return the wrapper object that statements are run through."""

    @abstractmethod
    def cancel(self) -> None:
        """Cancel the operation currently in progress."""

    @abstractmethod
    def close(self) -> None:
        """Close the wrapped connection."""

    @abstractmethod
    def rollback(self) -> None:
        """Roll back the current transaction."""

    @abstractmethod
    def fetchall(self) -> Optional[List]:
        """Return the rows of the last statement, or ``None``."""

    @abstractmethod
    def execute(self, sql: str, bindings: Optional[List[Any]] = None) -> None:
        """Run ``sql``, optionally with the given ``bindings``."""

    @property
    @abstractmethod
    def description(
        self,
    ) -> Sequence[
        Tuple[
            str,
            Any,
            Optional[int],
            Optional[int],
            Optional[int],
            Optional[int],
            bool,
        ]
    ]:
        """Return one 7-item tuple per result column.

        NOTE(review): the tuple shape looks like the DB-API 2.0
        ``cursor.description`` layout -- confirm against the implementations.
        """


class PyhiveConnectionWrapper(SparkConnectionWrapper):
"""Wrap a Spark connection in a way that no-ops transactions"""

# https://forums.databricks.com/questions/2157/in-apache-spark-sql-can-we-roll-back-the-transacti.html # noqa
Expand Down Expand Up @@ -268,7 +305,11 @@ def _fix_binding(cls, value: Any) -> Union[float, str]:
return value

@property
def description(self) -> Tuple[Tuple[str, Any, int, int, int, int, bool]]:
def description(
self,
) -> Sequence[
Tuple[str, Any, Optional[int], Optional[int], Optional[int], Optional[int], bool]
]:
assert self._cursor, "Cursor not available"
return self._cursor.description

Expand Down Expand Up @@ -354,7 +395,7 @@ def open(cls, connection: Connection) -> Connection:

creds = connection.credentials
exc = None
handle: Any
handle: SparkConnectionWrapper

for i in range(1 + creds.connect_retries):
try:
Expand Down
17 changes: 12 additions & 5 deletions dbt/adapters/spark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@

import datetime as dt
from types import TracebackType
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union, Sequence

from dbt.events import AdapterLogger
from dbt.utils import DECIMALS
from pyspark.sql import DataFrame, Row, SparkSession
from dbt.adapters.spark.connections import SparkConnectionWrapper


logger = AdapterLogger("Spark")
Expand Down Expand Up @@ -44,13 +45,15 @@ def __exit__(
@property
def description(
self,
) -> List[Tuple[str, str, None, None, None, None, bool]]:
) -> Sequence[
Tuple[str, Any, Optional[int], Optional[int], Optional[int], Optional[int], bool]
]:
"""
Get the description.
Returns
-------
out : List[Tuple[str, str, None, None, None, None, bool]]
out : Sequence[Tuple[str, str, None, None, None, None, bool]]
The description.
Source
Expand Down Expand Up @@ -180,7 +183,7 @@ def cursor(self) -> Cursor:
return Cursor(server_side_parameters=self.server_side_parameters)


class SessionConnectionWrapper(object):
class SessionConnectionWrapper(SparkConnectionWrapper):
"""Connection wrapper for the session connection method."""

handle: Connection
Expand Down Expand Up @@ -220,7 +223,11 @@ def execute(self, sql: str, bindings: Optional[List[Any]] = None) -> None:
self._cursor.execute(sql, *bindings)

@property
def description(self) -> List[Tuple[str, str, None, None, None, None, bool]]:
def description(
self,
) -> Sequence[
Tuple[str, Any, Optional[int], Optional[int], Optional[int], Optional[int], bool]
]:
assert self._cursor, "Cursor not available"
return self._cursor.description

Expand Down

0 comments on commit 5ed503a

Please sign in to comment.