Skip to content

Commit

Permalink
fix: improve get_db_engine_spec_for_backend (#21171)
Browse files Browse the repository at this point in the history
* fix: improve get_db_engine_spec_for_backend

* Fix tests

* Fix docs

* fix lint

* fix fallback

* Fix engine validation

* Fix test
  • Loading branch information
betodealmeida authored Aug 29, 2022
1 parent 710a8ce commit 8772e2c
Show file tree
Hide file tree
Showing 13 changed files with 306 additions and 127 deletions.
4 changes: 2 additions & 2 deletions superset/databases/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1083,8 +1083,8 @@ def available(self) -> Response:
"preferred": engine_spec.engine_name in preferred_databases,
}

if hasattr(engine_spec, "default_driver"):
payload["default_driver"] = engine_spec.default_driver # type: ignore
if engine_spec.default_driver:
payload["default_driver"] = engine_spec.default_driver

# show configuration parameters for DBs that support it
if (
Expand Down
27 changes: 3 additions & 24 deletions superset/databases/commands/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@
)
from superset.databases.dao import DatabaseDAO
from superset.databases.utils import make_url_safe
from superset.db_engine_specs import get_engine_specs
from superset.db_engine_specs.base import BasicParametersMixin
from superset.db_engine_specs import get_engine_spec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.extensions import event_logger
from superset.models.core import Database
Expand All @@ -45,25 +44,13 @@ def __init__(self, parameters: Dict[str, Any]):

def run(self) -> None:
engine = self._properties["engine"]
engine_specs = get_engine_specs()
driver = self._properties.get("driver")

if engine in BYPASS_VALIDATION_ENGINES:
# Skip engines that are only validated onCreate
return

if engine not in engine_specs:
raise InvalidEngineError(
SupersetError(
message=__(
'Engine "%(engine)s" is not a valid engine.',
engine=engine,
),
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
level=ErrorLevel.ERROR,
extra={"allowed": list(engine_specs), "provided": engine},
),
)
engine_spec = engine_specs[engine]
engine_spec = get_engine_spec(engine, driver)
if not hasattr(engine_spec, "parameters_schema"):
raise InvalidEngineError(
SupersetError(
Expand All @@ -73,14 +60,6 @@ def run(self) -> None:
),
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
level=ErrorLevel.ERROR,
extra={
"allowed": [
name
for name, engine_spec in engine_specs.items()
if issubclass(engine_spec, BasicParametersMixin)
],
"provided": engine,
},
),
)

Expand Down
43 changes: 16 additions & 27 deletions superset/databases/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
import inspect
import json
from typing import Any, Dict, Optional, Type
from typing import Any, Dict

from flask import current_app
from flask_babel import lazy_gettext as _
Expand All @@ -28,7 +28,7 @@
from superset import db
from superset.databases.commands.exceptions import DatabaseInvalidError
from superset.databases.utils import make_url_safe
from superset.db_engine_specs import BaseEngineSpec, get_engine_specs
from superset.db_engine_specs import get_engine_spec
from superset.exceptions import CertificateException, SupersetSecurityException
from superset.models.core import ConfigurationMethod, Database, PASSWORD_MASK
from superset.security.analytics_db_safety import check_sqlalchemy_uri
Expand Down Expand Up @@ -150,7 +150,7 @@ def sqlalchemy_uri_validator(value: str) -> str:
[
_(
"Invalid connection string, a valid string usually follows: "
"driver://user:password@database-host/database-name"
"backend+driver://user:password@database-host/database-name"
)
]
) from ex
Expand Down Expand Up @@ -231,6 +231,7 @@ class DatabaseParametersSchemaMixin: # pylint: disable=too-few-public-methods
"""

engine = fields.String(allow_none=True, description="SQLAlchemy engine to use")
driver = fields.String(allow_none=True, description="SQLAlchemy driver to use")
parameters = fields.Dict(
keys=fields.String(),
values=fields.Raw(),
Expand Down Expand Up @@ -262,10 +263,20 @@ def build_sqlalchemy_uri(
or parameters.pop("engine", None)
or data.pop("backend", None)
)
driver = data.pop("driver", None)

configuration_method = data.get("configuration_method")
if configuration_method == ConfigurationMethod.DYNAMIC_FORM:
engine_spec = get_engine_spec(engine)
if not engine:
raise ValidationError(
[
_(
"An engine must be specified when passing "
"individual parameters to a database."
)
]
)
engine_spec = get_engine_spec(engine, driver)

if not hasattr(engine_spec, "build_sqlalchemy_uri") or not hasattr(
engine_spec, "parameters_schema"
Expand Down Expand Up @@ -295,34 +306,12 @@ def build_sqlalchemy_uri(
return data


def get_engine_spec(engine: Optional[str]) -> Type[BaseEngineSpec]:
if not engine:
raise ValidationError(
[
_(
"An engine must be specified when passing "
"individual parameters to a database."
)
]
)
engine_specs = get_engine_specs()
if engine not in engine_specs:
raise ValidationError(
[
_(
'Engine "%(engine)s" is not a valid engine.',
engine=engine,
)
]
)
return engine_specs[engine]


class DatabaseValidateParametersSchema(Schema):
class Meta: # pylint: disable=too-few-public-methods
unknown = EXCLUDE

engine = fields.String(required=True, description="SQLAlchemy engine to use")
driver = fields.String(allow_none=True, description="SQLAlchemy driver to use")
parameters = fields.Dict(
keys=fields.String(),
values=fields.Raw(allow_none=True),
Expand Down
48 changes: 33 additions & 15 deletions superset/db_engine_specs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,27 +33,34 @@
from collections import defaultdict
from importlib import import_module
from pathlib import Path
from typing import Any, Dict, List, Set, Type
from typing import Any, Dict, List, Optional, Set, Type

import sqlalchemy.databases
import sqlalchemy.dialects
from pkg_resources import iter_entry_points
from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy.engine.url import URL

from superset.db_engine_specs.base import BaseEngineSpec

logger = logging.getLogger(__name__)


def is_engine_spec(attr: Any) -> bool:
def is_engine_spec(obj: Any) -> bool:
"""
Return true if a given object is a DB engine spec.
"""
return (
inspect.isclass(attr)
and issubclass(attr, BaseEngineSpec)
and attr != BaseEngineSpec
inspect.isclass(obj)
and issubclass(obj, BaseEngineSpec)
and obj != BaseEngineSpec
)


def load_engine_specs() -> List[Type[BaseEngineSpec]]:
"""
Load all engine specs, native and 3rd party.
"""
engine_specs: List[Type[BaseEngineSpec]] = []

# load standard engines
Expand All @@ -78,20 +85,31 @@ def load_engine_specs() -> List[Type[BaseEngineSpec]]:
return engine_specs


def get_engine_specs() -> Dict[str, Type[BaseEngineSpec]]:
def get_engine_spec(backend: str, driver: Optional[str] = None) -> Type[BaseEngineSpec]:
"""
Return the DB engine spec associated with a given SQLAlchemy backend and (optional) driver.
Note that if a driver is not specified the function returns the first DB engine spec
that supports the backend. Also, if a driver is specified but no DB engine explicitly
supporting that driver exists then a backend-only match is done, in order to allow new
drivers to work with Superset even if they are not listed in the DB engine spec
drivers.
"""
engine_specs = load_engine_specs()

# build map from name/alias -> spec
engine_specs_map: Dict[str, Type[BaseEngineSpec]] = {}
for engine_spec in engine_specs:
names = [engine_spec.engine]
if engine_spec.engine_aliases:
names.extend(engine_spec.engine_aliases)
if driver is not None:
for engine_spec in engine_specs:
if engine_spec.supports_backend(backend, driver):
return engine_spec

for name in names:
engine_specs_map[name] = engine_spec
# check ignoring the driver, in order to support new drivers; this will return
# the first loaded DB engine spec that supports the backend
for engine_spec in engine_specs:
if engine_spec.supports_backend(backend):
return engine_spec

return engine_specs_map
# default to the generic DB engine spec
return BaseEngineSpec


# there's a mismatch between the dialect name reported by the driver in these
Expand Down
62 changes: 60 additions & 2 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,15 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
having to add the same aggregation in SELECT.
"""

engine_name: Optional[str] = None # for user messages, overridden in child classes

# These attributes map the DB engine spec to one or more SQLAlchemy dialects/drivers;
# see the ``supports_url`` and ``supports_backend`` methods below.
engine = "base" # str as defined in sqlalchemy.engine.engine
engine_aliases: Set[str] = set()
engine_name: Optional[str] = None # for user messages, overridden in child classes
drivers: Dict[str, str] = {}
default_driver: Optional[str] = None

_date_trunc_functions: Dict[str, str] = {}
_time_grain_expressions: Dict[Optional[str], str] = {}
column_type_mappings: Tuple[ColumnTypeMapping, ...] = (
Expand Down Expand Up @@ -355,6 +361,58 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]
] = {}

@classmethod
def supports_url(cls, url: URL) -> bool:
    """
    Check whether this DB engine spec can handle a given SQLAlchemy URL.

    The backend and driver names are read from the URL and the decision is
    delegated to ``supports_backend``.

    For instance, a DB engine spec declared as:

        class PostgresDBEngineSpec:
            engine = "postgresql"
            engine_aliases = {"postgres"}
            drivers = {
                "psycopg2": "The default Postgres driver",
                "asyncpg": "An asynchronous Postgres driver",
            }

    would be used for every one of these SQLAlchemy URIs:

        - postgres://user:password@host/db
        - postgresql://user:password@host/db
        - postgres+asyncpg://user:password@host/db
        - postgres+psycopg2://user:password@host/db
        - postgresql+asyncpg://user:password@host/db
        - postgresql+psycopg2://user:password@host/db

    Keep in mind that SQLAlchemy reports a default driver even when the URI
    does not name one:

        >>> from sqlalchemy.engine.url import make_url
        >>> make_url('postgres://').get_driver_name()
        'psycopg2'

    :param url: SQLAlchemy URL to inspect
    :return: whether this spec supports the URL's backend/driver pair
    """
    return cls.supports_backend(url.get_backend_name(), url.get_driver_name())

@classmethod
def supports_backend(cls, backend: str, driver: Optional[str] = None) -> bool:
    """
    Check whether this DB engine spec handles a SQLAlchemy backend/driver pair.

    :param backend: backend name, matched against ``engine`` and
        ``engine_aliases``
    :param driver: optional driver name; when ``None`` only the backend is
        checked
    :return: ``True`` if the backend matches and the driver is acceptable
    """
    # the backend has to match the engine name or one of its aliases
    backend_matches = backend == cls.engine or backend in cls.engine_aliases
    if not backend_matches:
        return False

    # DB engine specs originally declared no drivers and were matched on the
    # engine alone; for backwards compatibility the driver is ignored when the
    # spec lists none, or when the caller did not provide one
    driver_unconstrained = not cls.drivers or driver is None
    return driver_unconstrained or driver in cls.drivers

@classmethod
def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
"""
Expand Down Expand Up @@ -394,7 +452,7 @@ def get_allow_cost_estimate( # pylint: disable=unused-argument
@classmethod
def get_text_clause(cls, clause: str) -> TextClause:
"""
SQLALchemy wrapper to ensure text clauses are escaped properly
SQLAlchemy wrapper to ensure text clauses are escaped properly
:param clause: string clause with potentially unescaped characters
:return: text clause with escaped characters
Expand Down
19 changes: 13 additions & 6 deletions superset/db_engine_specs/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,23 @@


class DatabricksHiveEngineSpec(HiveEngineSpec):
engine = "databricks"
engine_name = "Databricks Interactive Cluster"
driver = "pyhive"

engine = "databricks"
drivers = {"pyhive": "Hive driver for Interactive Cluster"}
default_driver = "pyhive"

_show_functions_column = "function"

_time_grain_expressions = time_grain_expressions


class DatabricksODBCEngineSpec(BaseEngineSpec):
engine = "databricks"
engine_name = "Databricks SQL Endpoint"
driver = "pyodbc"

engine = "databricks"
drivers = {"pyodbc": "ODBC driver for SQL endpoint"}
default_driver = "pyodbc"

_time_grain_expressions = time_grain_expressions

Expand All @@ -74,9 +79,11 @@ def epoch_to_dttm(cls) -> str:


class DatabricksNativeEngineSpec(DatabricksODBCEngineSpec):
engine = "databricks"
engine_name = "Databricks Native Connector"
driver = "connector"

engine = "databricks"
drivers = {"connector": "Native all-purpose driver"}
default_driver = "connector"

@staticmethod
def get_extra_params(database: "Database") -> Dict[str, Any]:
Expand Down
6 changes: 5 additions & 1 deletion superset/db_engine_specs/shillelagh.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
class ShillelaghEngineSpec(SqliteEngineSpec):
"""Engine for shillelagh"""

engine = "shillelagh"
engine_name = "Shillelagh"
engine = "shillelagh"
drivers = {"apsw": "SQLite driver"}
default_driver = "apsw"
sqlalchemy_uri_placeholder = "shillelagh://"

allows_joins = True
allows_subqueries = True
Loading

0 comments on commit 8772e2c

Please sign in to comment.