Skip to content

Commit

Permalink
Enforce server side parameters to be strings (#827)
Browse files Browse the repository at this point in the history
* Enfore server side parameters to be strings

* Add changie

* Iterate over items

* Test for casting server_side_parameters to strings

---------

Co-authored-by: colin-rogers-dbt <111200756+colin-rogers-dbt@users.noreply.github.com>
  • Loading branch information
JCZuurmond and colin-rogers-dbt authored Jul 26, 2023
1 parent 955564d commit 53809c3
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 1 deletion.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20230707-114650.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Enforce server side parameters keys and values to be strings
time: 2023-07-07T11:46:50.390918+02:00
custom:
Author: Fokko,JCZuurmond
Issue: "826"
6 changes: 5 additions & 1 deletion dbt/adapters/spark/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class SparkCredentials(Credentials):
connect_retries: int = 0
connect_timeout: int = 10
use_ssl: bool = False
server_side_parameters: Dict[str, Any] = field(default_factory=dict)
server_side_parameters: Dict[str, str] = field(default_factory=dict)
retry_all: bool = False

@classmethod
Expand Down Expand Up @@ -142,6 +142,10 @@ def __post_init__(self) -> None:
if self.method != SparkConnectionMethod.SESSION:
self.host = self.host.rstrip("/")

self.server_side_parameters = {
str(key): str(value) for key, value in self.server_side_parameters.items()
}

@property
def type(self) -> str:
return "spark"
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/test_credentials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from dbt.adapters.spark.connections import SparkConnectionMethod, SparkCredentials


def test_credentials_server_side_parameters_keys_and_values_are_strings() -> None:
credentials = SparkCredentials(
host="localhost",
method=SparkConnectionMethod.THRIFT,
database="tests",
schema="tests",
server_side_parameters={"spark.configuration": 10},
)
assert credentials.server_side_parameters["spark.configuration"] == "10"

0 comments on commit 53809c3

Please sign in to comment.