From 8ea1597f3757adef4c940e7b0cf88622c46b5322 Mon Sep 17 00:00:00 2001
From: Fokko Driesprong
Date: Sat, 24 Jun 2023 20:41:07 +0200
Subject: [PATCH] Disallow untyped `def`'s (#767)

* Disallow untyped `def`'s

---------

Co-authored-by: Mike Alfare <13974384+mikealfare@users.noreply.github.com>
---
 .../unreleased/Fixes-20230510-163110.yaml |  6 ++
 .pre-commit-config.yaml                   |  4 +-
 dbt/adapters/spark/column.py              |  2 +-
 dbt/adapters/spark/connections.py         | 80 +++++++++++--------
 dbt/adapters/spark/impl.py                | 31 +++----
 dbt/adapters/spark/python_submissions.py  | 18 ++--
 dbt/adapters/spark/relation.py            |  4 +-
 dbt/adapters/spark/session.py             | 28 ++++---
 8 files changed, 101 insertions(+), 72 deletions(-)
 create mode 100644 .changes/unreleased/Fixes-20230510-163110.yaml

diff --git a/.changes/unreleased/Fixes-20230510-163110.yaml b/.changes/unreleased/Fixes-20230510-163110.yaml
new file mode 100644
index 000000000..06672ac91
--- /dev/null
+++ b/.changes/unreleased/Fixes-20230510-163110.yaml
@@ -0,0 +1,6 @@
+kind: Fixes
+body: Disallow untyped `def`'s
+time: 2023-05-10T16:31:10.593358+02:00
+custom:
+  Author: Fokko
+  Issue: "760"
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index ddec9b665..5e7fdbd04 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -39,7 +39,7 @@ repos:
     alias: flake8-check
     stages: [manual]
 - repo: https://github.com/pre-commit/mirrors-mypy
-  rev: v1.1.1
+  rev: v1.2.0
   hooks:
   - id: mypy
     # N.B.: Mypy is... a bit fragile.
@@ -52,7 +52,7 @@
     # of our control to the mix. Unfortunately, there's nothing we can
     # do about per pre-commit's author.
     # See https://github.com/pre-commit/pre-commit/issues/730 for details.
-    args: [--show-error-codes, --ignore-missing-imports, --explicit-package-bases, --warn-unused-ignores]
+    args: [--show-error-codes, --ignore-missing-imports, --explicit-package-bases, --warn-unused-ignores, --disallow-untyped-defs]
     files: ^dbt/adapters/.*
     language: system
   - id: mypy
diff --git a/dbt/adapters/spark/column.py b/dbt/adapters/spark/column.py
index 8100fa450..bde49a492 100644
--- a/dbt/adapters/spark/column.py
+++ b/dbt/adapters/spark/column.py
@@ -26,7 +26,7 @@ def can_expand_to(self: Self, other_column: Self) -> bool:  # type: ignore
         """returns True if both columns are strings"""
         return self.is_string() and other_column.is_string()

-    def literal(self, value):
+    def literal(self, value: Any) -> str:
         return "cast({} as {})".format(value, self.dtype)

     @property
diff --git a/dbt/adapters/spark/connections.py b/dbt/adapters/spark/connections.py
index 9d3e385b0..bde614fa7 100644
--- a/dbt/adapters/spark/connections.py
+++ b/dbt/adapters/spark/connections.py
@@ -1,5 +1,4 @@
 from contextlib import contextmanager
-from typing import Tuple

 import dbt.exceptions
 from dbt.adapters.base import Credentials
@@ -23,10 +22,10 @@
     pyodbc = None
 from datetime import datetime
 import sqlparams
-
+from dbt.contracts.connection import Connection
 from hologram.helpers import StrEnum
 from dataclasses import dataclass, field
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Union, Tuple, List, Generator, Iterable

 try:
     from thrift.transport.TSSLSocket import TSSLSocket
@@ -45,7 +44,7 @@
 NUMBERS = DECIMALS + (int, float)


-def _build_odbc_connnection_string(**kwargs) -> str:
+def _build_odbc_connnection_string(**kwargs: Any) -> str:
     return ";".join([f"{k}={v}" for k, v in kwargs.items()])


@@ -78,17 +77,17 @@ class SparkCredentials(Credentials):
     retry_all: bool = False

     @classmethod
-    def __pre_deserialize__(cls, data):
+    def __pre_deserialize__(cls, data: Any) -> Any:
         data = super().__pre_deserialize__(data)
         if "database" not in data:
             data["database"] = None
         return data

     @property
-    def cluster_id(self):
+    def cluster_id(self) -> Optional[str]:
         return self.cluster

-    def __post_init__(self):
+    def __post_init__(self) -> None:
         # spark classifies database and schema as the same thing
         if self.database is not None and self.database != self.schema:
             raise dbt.exceptions.DbtRuntimeError(
@@ -141,15 +140,15 @@ def __post_init__(self):
             ) from e

     @property
-    def type(self):
+    def type(self) -> str:
         return "spark"

     @property
-    def unique_field(self):
+    def unique_field(self) -> str:
         return self.host

     def _connection_keys(self) -> Tuple[str, ...]:
-        return ("host", "port", "cluster", "endpoint", "schema", "organization")
+        return "host", "port", "cluster", "endpoint", "schema", "organization"


 class PyhiveConnectionWrapper(object):
@@ -157,15 +156,18 @@ class PyhiveConnectionWrapper(object):

     # https://forums.databricks.com/questions/2157/in-apache-spark-sql-can-we-roll-back-the-transacti.html  # noqa

-    def __init__(self, handle):
+    handle: "pyodbc.Connection"
+    _cursor: "Optional[pyodbc.Cursor]"
+
+    def __init__(self, handle: "pyodbc.Connection") -> None:
         self.handle = handle
         self._cursor = None

-    def cursor(self):
+    def cursor(self) -> "PyhiveConnectionWrapper":
         self._cursor = self.handle.cursor()
         return self

-    def cancel(self):
+    def cancel(self) -> None:
         if self._cursor:
             # Handle bad response in the pyhive lib when
             # the connection is cancelled
@@ -174,7 +176,7 @@ def cancel(self):
             except EnvironmentError as exc:
                 logger.debug("Exception while cancelling query: {}".format(exc))

-    def close(self):
+    def close(self) -> None:
         if self._cursor:
             # Handle bad response in the pyhive lib when
             # the connection is cancelled
@@ -184,13 +186,14 @@ def close(self):
                 logger.debug("Exception while closing cursor: {}".format(exc))
         self.handle.close()

-    def rollback(self, *args, **kwargs):
+    def rollback(self, *args: Any, **kwargs: Any) -> None:
         logger.debug("NotImplemented: rollback")

-    def fetchall(self):
+    def fetchall(self) -> List["pyodbc.Row"]:
+        assert self._cursor, "Cursor not available"
         return self._cursor.fetchall()

-    def execute(self, sql, bindings=None):
+    def execute(self, sql: str, bindings: Optional[List[Any]] = None) -> None:
         if sql.strip().endswith(";"):
             sql = sql.strip()[:-1]

@@ -212,6 +215,8 @@ def execute(self, sql, bindings=None):
         if bindings is not None:
             bindings = [self._fix_binding(binding) for binding in bindings]

+        assert self._cursor, "Cursor not available"
+
         self._cursor.execute(sql, bindings, async_=True)
         poll_state = self._cursor.poll()
         state = poll_state.operationState
@@ -245,7 +250,7 @@ def execute(self, sql, bindings=None):
         logger.debug("Poll status: {}, query complete".format(state))

     @classmethod
-    def _fix_binding(cls, value):
+    def _fix_binding(cls, value: Any) -> Union[float, str]:
         """Convert complex datatypes to primitives that can be loaded by the
         Spark driver"""
         if isinstance(value, NUMBERS):
@@ -256,12 +261,14 @@ def _fix_binding(cls, value):
         return value

     @property
-    def description(self):
+    def description(self) -> Tuple[Tuple[str, Any, int, int, int, int, bool]]:
+        assert self._cursor, "Cursor not available"
         return self._cursor.description


 class PyodbcConnectionWrapper(PyhiveConnectionWrapper):
-    def execute(self, sql, bindings=None):
+    def execute(self, sql: str, bindings: Optional[List[Any]] = None) -> None:
+        assert self._cursor, "Cursor not available"
         if sql.strip().endswith(";"):
             sql = sql.strip()[:-1]
         # pyodbc does not handle a None type binding!
@@ -282,7 +289,7 @@ class SparkConnectionManager(SQLConnectionManager):
     SPARK_CONNECTION_URL = "{host}:{port}" + SPARK_CLUSTER_HTTP_PATH

     @contextmanager
-    def exception_handler(self, sql):
+    def exception_handler(self, sql: str) -> Generator[None, None, None]:
         try:
             yield

@@ -299,30 +306,30 @@ def exception_handler(self, sql):
             else:
                 raise dbt.exceptions.DbtRuntimeError(str(exc))

-    def cancel(self, connection):
+    def cancel(self, connection: Connection) -> None:
         connection.handle.cancel()

     @classmethod
-    def get_response(cls, cursor) -> AdapterResponse:
+    def get_response(cls, cursor: Any) -> AdapterResponse:
         # https://github.com/dbt-labs/dbt-spark/issues/142
         message = "OK"
         return AdapterResponse(_message=message)

     # No transactions on Spark....
-    def add_begin_query(self, *args, **kwargs):
+    def add_begin_query(self, *args: Any, **kwargs: Any) -> None:
         logger.debug("NotImplemented: add_begin_query")

-    def add_commit_query(self, *args, **kwargs):
+    def add_commit_query(self, *args: Any, **kwargs: Any) -> None:
         logger.debug("NotImplemented: add_commit_query")

-    def commit(self, *args, **kwargs):
+    def commit(self, *args: Any, **kwargs: Any) -> None:
         logger.debug("NotImplemented: commit")

-    def rollback(self, *args, **kwargs):
+    def rollback(self, *args: Any, **kwargs: Any) -> None:
         logger.debug("NotImplemented: rollback")

     @classmethod
-    def validate_creds(cls, creds, required):
+    def validate_creds(cls, creds: Any, required: Iterable[str]) -> None:
         method = creds.method

         for key in required:
@@ -333,7 +340,7 @@ def validate_creds(cls, creds, required):
             )

     @classmethod
-    def open(cls, connection):
+    def open(cls, connection: Connection) -> Connection:
         if connection.state == ConnectionState.OPEN:
             logger.debug("Connection is already open, skipping open.")
             return connection
@@ -450,7 +457,7 @@ def open(cls, connection):
                     SessionConnectionWrapper,
                 )
-                handle = SessionConnectionWrapper(Connection())
+                handle = SessionConnectionWrapper(Connection())  # type: ignore
             else:
                 raise dbt.exceptions.DbtProfileError(
                     f"invalid credential method: {creds.method}"
                 )
@@ -487,7 +494,7 @@ def open(cls, connection):
                 else:
                     raise dbt.exceptions.FailedToConnectError("failed to connect") from e
             else:
-                raise exc
+                raise exc  # type: ignore

         connection.handle = handle
         connection.state = ConnectionState.OPEN
@@ -507,7 +514,14 @@ def data_type_code_to_name(cls, type_code: Union[type, str]) -> str:  # type: ig
         return type_code.__name__.upper()


-def build_ssl_transport(host, port, username, auth, kerberos_service_name, password=None):
+def build_ssl_transport(
+    host: str,
+    port: int,
+    username: str,
+    auth: str,
+    kerberos_service_name: str,
+    password: Optional[str] = None,
+) -> "thrift_sasl.TSaslClientTransport":
     transport = None
     if port is None:
         port = 10000
@@ -531,7 +545,7 @@ def build_ssl_transport(host, port, username, auth, kerberos_service_name, passw
                 # to be nonempty.
password = "x" - def sasl_factory(): + def sasl_factory() -> sasl.Client: sasl_client = sasl.Client() sasl_client.setAttr("host", host) if sasl_auth == "GSSAPI": diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index 1d4a64973..2864c4f30 100644 --- a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -1,7 +1,10 @@ import re from concurrent.futures import Future from dataclasses import dataclass -from typing import Any, Dict, Iterable, List, Optional, Union, Type, Tuple, Callable +from typing import Any, Dict, Iterable, List, Optional, Union, Type, Tuple, Callable, Set + +from dbt.adapters.base.relation import InformationSchema +from dbt.contracts.graph.manifest import Manifest from typing_extensions import TypeAlias @@ -109,27 +112,27 @@ def date_function(cls) -> str: return "current_timestamp()" @classmethod - def convert_text_type(cls, agate_table, col_idx): + def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str: return "string" @classmethod - def convert_number_type(cls, agate_table, col_idx): + def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str: decimals = agate_table.aggregate(agate.MaxPrecision(col_idx)) return "double" if decimals else "bigint" @classmethod - def convert_date_type(cls, agate_table, col_idx): + def convert_date_type(cls, agate_table: agate.Table, col_idx: int) -> str: return "date" @classmethod - def convert_time_type(cls, agate_table, col_idx): + def convert_time_type(cls, agate_table: agate.Table, col_idx: int) -> str: return "time" @classmethod - def convert_datetime_type(cls, agate_table, col_idx): + def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str: return "timestamp" - def quote(self, identifier): + def quote(self, identifier: str) -> str: # type: ignore return "`{}`".format(identifier) def _get_relation_information(self, row: agate.Row) -> RelationInfo: @@ -344,7 +347,7 @@ def _get_columns_for_catalog(self, relation: BaseRelation) -> Iterable[Dict[str, as_dict["table_database"] = None yield as_dict - def get_catalog(self, manifest): + def get_catalog(self, manifest: Manifest) -> Tuple[agate.Table, List[Exception]]: schema_map = self._get_catalog_schemas(manifest) if len(schema_map) > 1: raise dbt.exceptions.CompilationError( @@ -370,9 +373,9 @@ def get_catalog(self, manifest): def _get_one_catalog( self, - information_schema, - schemas, - manifest, + information_schema: InformationSchema, + schemas: Set[str], + manifest: Manifest, ) -> agate.Table: if len(schemas) != 1: raise dbt.exceptions.CompilationError( @@ -388,7 +391,7 @@ def _get_one_catalog( columns.extend(self._get_columns_for_catalog(relation)) return agate.Table.from_object(columns, column_types=DEFAULT_TYPE_TESTER) - def check_schema_exists(self, database, schema): + def check_schema_exists(self, database: str, schema: str) -> bool: results = self.execute_macro(LIST_SCHEMAS_MACRO_NAME, kwargs={"database": database}) exists = True if schema in [row[0] for row in results] else False @@ -425,7 +428,7 @@ def get_rows_different_sql( # This is for use in the test suite # Spark doesn't have 'commit' and 'rollback', so this override # doesn't include those commands. 
- def run_sql_for_tests(self, sql, fetch, conn): + def run_sql_for_tests(self, sql, fetch, conn): # type: ignore cursor = conn.handle.cursor() try: cursor.execute(sql) @@ -477,7 +480,7 @@ def standardize_grants_dict(self, grants_table: agate.Table) -> dict: grants_dict.update({privilege: [grantee]}) return grants_dict - def debug_query(self): + def debug_query(self) -> None: """Override for DebugTask method""" self.execute("select 1 as id") diff --git a/dbt/adapters/spark/python_submissions.py b/dbt/adapters/spark/python_submissions.py index 47529e079..89831ca7f 100644 --- a/dbt/adapters/spark/python_submissions.py +++ b/dbt/adapters/spark/python_submissions.py @@ -1,7 +1,7 @@ import base64 import time import requests -from typing import Any, Dict +from typing import Any, Dict, Callable, Iterable import uuid import dbt.exceptions @@ -149,18 +149,18 @@ def submit(self, compiled_code: str) -> None: def polling( self, - status_func, - status_func_kwargs, - get_state_func, - terminal_states, - expected_end_state, - get_state_msg_func, + status_func: Callable, + status_func_kwargs: Dict, + get_state_func: Callable, + terminal_states: Iterable[str], + expected_end_state: str, + get_state_msg_func: Callable, ) -> Dict: state = None start = time.time() exceeded_timeout = False - response = {} - while state not in terminal_states: + response: Dict = {} + while state is None or state not in terminal_states: if time.time() - start > self.timeout: exceeded_timeout = True break diff --git a/dbt/adapters/spark/relation.py b/dbt/adapters/spark/relation.py index f5a3e3e15..e80f2623f 100644 --- a/dbt/adapters/spark/relation.py +++ b/dbt/adapters/spark/relation.py @@ -36,11 +36,11 @@ class SparkRelation(BaseRelation): # TODO: make this a dict everywhere information: Optional[str] = None - def __post_init__(self): + def __post_init__(self) -> None: if self.database != self.schema and self.database: raise DbtRuntimeError("Cannot set database in spark!") - def render(self): + def render(self) -> str: if self.include_policy.database and self.include_policy.schema: raise DbtRuntimeError( "Got a spark relation with schema and database set to " diff --git a/dbt/adapters/spark/session.py b/dbt/adapters/spark/session.py index d275c73c5..5e4bcc492 100644 --- a/dbt/adapters/spark/session.py +++ b/dbt/adapters/spark/session.py @@ -4,7 +4,7 @@ import datetime as dt from types import TracebackType -from typing import Any, List, Optional, Tuple +from typing import Any, List, Optional, Tuple, Union from dbt.events import AdapterLogger from dbt.utils import DECIMALS @@ -172,33 +172,38 @@ def cursor(self) -> Cursor: class SessionConnectionWrapper(object): - """Connection wrapper for the sessoin connection method.""" + """Connection wrapper for the session connection method.""" - def __init__(self, handle): + handle: Connection + _cursor: Optional[Cursor] + + def __init__(self, handle: Connection) -> None: self.handle = handle self._cursor = None - def cursor(self): + def cursor(self) -> "SessionConnectionWrapper": self._cursor = self.handle.cursor() return self - def cancel(self): + def cancel(self) -> None: logger.debug("NotImplemented: cancel") - def close(self): + def close(self) -> None: if self._cursor: self._cursor.close() - def rollback(self, *args, **kwargs): + def rollback(self, *args: Any, **kwargs: Any) -> None: logger.debug("NotImplemented: rollback") - def fetchall(self): + def fetchall(self) -> Optional[List[Row]]: + assert self._cursor, "Cursor not available" return self._cursor.fetchall() - def execute(self, 
sql, bindings=None): + def execute(self, sql: str, bindings: Optional[List[Any]] = None) -> None: if sql.strip().endswith(";"): sql = sql.strip()[:-1] + assert self._cursor, "Cursor not available" if bindings is None: self._cursor.execute(sql) else: @@ -206,11 +211,12 @@ def execute(self, sql, bindings=None): self._cursor.execute(sql, *bindings) @property - def description(self): + def description(self) -> List[Tuple[str, str, None, None, None, None, bool]]: + assert self._cursor, "Cursor not available" return self._cursor.description @classmethod - def _fix_binding(cls, value): + def _fix_binding(cls, value: Any) -> Union[str, float]: """Convert complex datatypes to primitives that can be loaded by the Spark driver""" if isinstance(value, NUMBERS):
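
Note (illustrative, not part of the patch above): once `--disallow-untyped-defs` is in the mypy args, any `def` without annotations is rejected, which is why every signature in this diff gains parameter and return types. A minimal sketch of the before/after pattern, using a hypothetical helper rather than code from this repository:

    from typing import Any, List

    # Rejected under --disallow-untyped-defs, roughly:
    #   error: Function is missing a type annotation  [no-untyped-def]
    # def fetch_rows(cursor):
    #     return cursor.fetchall()

    # Accepted: parameters and the return value are annotated.
    def fetch_rows(cursor: Any) -> List[Any]:
        return cursor.fetchall()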