Skip to content

Commit

Permalink
fix: handle default sizing for unbound varchar in string_size()
Browse files Browse the repository at this point in the history
set the string_size of an unbound varchar to max length of a varchar
add some unit tests to cover the bound and unbound varchar parsing
  • Loading branch information
kalvinnchau authored and mdesmet committed Feb 10, 2023
1 parent 796de9c commit a7908e4
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
11 changes: 11 additions & 0 deletions dbt/adapters/trino/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
from dbt.adapters.base.column import Column
from dbt.exceptions import DbtRuntimeError

# Taken from the MAX_LENGTH variable in
# https://github.com/trinodb/trino/blob/master/core/trino-spi/src/main/java/io/trino/spi/type/VarcharType.java
TRINO_VARCHAR_MAX_LENGTH = 2147483646


@dataclass
class TrinoColumn(Column):
Expand All @@ -28,6 +32,13 @@ def data_type(self):
def string_type(cls, size: int) -> str:
return "varchar({})".format(size)

def string_size(self) -> int:
# override the string_size function to handle the unbound varchar case
if self.dtype.lower() == "varchar" and self.char_size is None:
return TRINO_VARCHAR_MAX_LENGTH

return super().string_size()

@classmethod
def from_description(cls, name: str, raw_data_type: str) -> "Column":
# some of the Trino data types specify a type and not a precision
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from dbt.exceptions import DbtDatabaseError, DbtRuntimeError, FailedToConnectError

from dbt.adapters.trino import TrinoAdapter
from dbt.adapters.trino.column import TRINO_VARCHAR_MAX_LENGTH, TrinoColumn
from dbt.adapters.trino.connections import (
HttpScheme,
TrinoCertificateCredentials,
Expand Down Expand Up @@ -546,3 +547,28 @@ def test_convert_date_type(self):
expected = ["DATE", "DATE", "DATE"]
for col_idx, expect in enumerate(expected):
assert TrinoAdapter.convert_date_type(agate_table, col_idx) == expect


class TestTrinoColumn(unittest.TestCase):
def test_bound_varchar(self):
col = TrinoColumn.from_description("my_col", "VARCHAR(100)")
assert col.column == "my_col"
assert col.dtype == "VARCHAR"
assert col.char_size == 100
# bounded varchars get formatted to lowercase
assert col.data_type == "varchar(100)"
assert col.string_size() == 100
assert col.is_string() is True
assert col.is_number() is False
assert col.is_numeric() is False

def test_unbound_varchar(self):
col = TrinoColumn.from_description("my_col", "VARCHAR")
assert col.column == "my_col"
assert col.dtype == "VARCHAR"
assert col.char_size is None
assert col.data_type == "VARCHAR"
assert col.string_size() == TRINO_VARCHAR_MAX_LENGTH
assert col.is_string() is True
assert col.is_number() is False
assert col.is_numeric() is False

0 comments on commit a7908e4

Please sign in to comment.