Skip to content

Commit

Permalink
feat: estimate query cost in Postgres (#12130)
Browse files Browse the repository at this point in the history
* feat: estimate query cost in Postgres

* Add example in config

* Fix lint
  • Loading branch information
betodealmeida authored Dec 19, 2020
1 parent fa27ed1 commit 877b153
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 30 deletions.
5 changes: 5 additions & 0 deletions superset/commands/importers/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def run(self) -> None:
db.session.rollback()
raise self.import_error()

# pylint: disable=too-many-branches
def validate(self) -> None:
exceptions: List[ValidationError] = []

Expand Down Expand Up @@ -99,6 +100,10 @@ def validate(self) -> None:

# validate objects
for file_name, content in self.contents.items():
# skip directories
if not content:
continue

prefix = file_name.split("/")[0]
schema = self.schemas.get(f"{prefix}/")
if schema:
Expand Down
31 changes: 31 additions & 0 deletions superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,37 @@ class CeleryConfig: # pylint: disable=too-few-public-methods
# query costs before they run. These EXPLAIN queries should have a small
# timeout.
SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT = 10 # seconds
# The feature is off by default, and currently only supported in Presto and Postgres.
# It also needs to be enabled on a per-database basis, by adding the key/value pair
# `cost_estimate_enabled: true` to the database `extra` attribute.
ESTIMATE_QUERY_COST = False
# The cost returned by the databases is a relative value; in order to map the cost to
# a tangible value you need to define a custom formatter that takes into consideration
# your specific infrastructure. For example, you could analyze queries a posteriori by
# running EXPLAIN on them, and compute a histogram of relative costs to present the
# cost as a percentile:
#
# def postgres_query_cost_formatter(
# result: List[Dict[str, Any]]
# ) -> List[Dict[str, str]]:
# # 25, 50, 75% percentiles
# percentile_costs = [100.0, 1000.0, 10000.0]
#
# out = []
# for row in result:
# relative_cost = row["Total cost"]
# percentile = bisect.bisect_left(percentile_costs, relative_cost) + 1
# out.append({
# "Relative cost": relative_cost,
# "Percentile": str(percentile * 25) + "%",
# })
#
# return out
#
# DEFAULT_FEATURE_FLAGS = {
# "ESTIMATE_QUERY_COST": True,
# "QUERY_COST_FORMATTERS_BY_ENGINE": {"postgresql": postgres_query_cost_formatter},
# }

# Flag that controls if limit should be enforced on CTAS (create table as) queries.
SQLLAB_CTAS_NO_LIMIT = False
Expand Down
1 change: 1 addition & 0 deletions superset/databases/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ class ImportV1DatabaseExtraSchema(Schema):
engine_params = fields.Dict(keys=fields.Str(), values=fields.Raw())
metadata_cache_timeout = fields.Dict(keys=fields.Str(), values=fields.Integer())
schemas_allowed_for_csv_upload = fields.List(fields.String)
cost_estimate_enabled = fields.Boolean()


class ImportV1DatabaseSchema(Schema):
Expand Down
44 changes: 31 additions & 13 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from sqlalchemy.sql.expression import ColumnClause, ColumnElement, Select, TextAsFrom
from sqlalchemy.types import TypeEngine

from superset import app, sql_parse
from superset import app, security_manager, sql_parse
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.sql_lab import Query
from superset.sql_parse import ParsedQuery, Table
Expand Down Expand Up @@ -203,7 +203,7 @@ def is_db_column_type_match(
return any(pattern.match(db_column_type) for pattern in patterns)

@classmethod
def get_allow_cost_estimate(cls, version: Optional[str] = None) -> bool:
def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
return False

@classmethod
Expand Down Expand Up @@ -790,16 +790,12 @@ def select_star( # pylint: disable=too-many-arguments,too-many-locals
return sql

@classmethod
def estimate_statement_cost(
cls, statement: str, database: "Database", cursor: Any, user_name: str
) -> Dict[str, Any]:
def estimate_statement_cost(cls, statement: str, cursor: Any,) -> Dict[str, Any]:
"""
Generate a SQL query that estimates the cost of a given statement.
:param statement: A single SQL statement
:param database: Database instance
:param cursor: Cursor instance
:param username: Effective username
:return: Dictionary with different costs
"""
raise Exception("Database does not support cost estimation")
Expand All @@ -816,10 +812,31 @@ def query_cost_formatter(
"""
raise Exception("Database does not support cost estimation")

@classmethod
def process_statement(
    cls, statement: str, database: "Database", user_name: Optional[str]
) -> str:
    """
    Process a SQL statement by stripping and mutating it.

    The statement is first reduced to the form returned by
    ``ParsedQuery.stripped()``; if ``SQL_QUERY_MUTATOR`` is set in the
    config, the result is then passed through that callable.

    :param statement: A single SQL statement
    :param database: Database instance, forwarded to the mutator
    :param user_name: Effective username, forwarded to the mutator; may be
        ``None`` when there is no current user
    :return: The processed SQL statement
    """
    parsed_query = ParsedQuery(statement)
    sql = parsed_query.stripped()

    # the mutator is optional; a falsy config value leaves the SQL untouched
    sql_query_mutator = config["SQL_QUERY_MUTATOR"]
    if sql_query_mutator:
        sql = sql_query_mutator(sql, user_name, security_manager, database)

    return sql

@classmethod
def estimate_query_cost(
cls, database: "Database", schema: str, sql: str, source: Optional[str] = None
) -> List[Dict[str, str]]:
) -> List[Dict[str, Any]]:
"""
Estimate the cost of a multiple statement SQL query.
Expand All @@ -828,8 +845,8 @@ def estimate_query_cost(
:param sql: SQL query with possibly multiple statements
:param source: Source of the query (eg, "sql_lab")
"""
database_version = database.get_extra().get("version")
if not cls.get_allow_cost_estimate(database_version):
extra = database.get_extra() or {}
if not cls.get_allow_cost_estimate(extra):
raise Exception("Database does not support cost estimation")

user_name = g.user.username if g.user else None
Expand All @@ -841,10 +858,11 @@ def estimate_query_cost(
with closing(engine.raw_connection()) as conn:
with closing(conn.cursor()) as cursor:
for statement in statements:
processed_statement = cls.process_statement(
statement, database, user_name
)
costs.append(
cls.estimate_statement_cost(
statement, database, cursor, user_name
)
cls.estimate_statement_cost(processed_statement, cursor)
)
return costs

Expand Down
28 changes: 27 additions & 1 deletion superset/db_engine_specs/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import re
from datetime import datetime
from typing import Any, List, Optional, Tuple, TYPE_CHECKING
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING

from pytz import _FixedOffset # type: ignore
from sqlalchemy.dialects.postgresql.base import PGInspector
Expand Down Expand Up @@ -71,6 +72,31 @@ class PostgresEngineSpec(PostgresBaseEngineSpec):
max_column_name_length = 63
try_remove_schema_from_table_name = False

@classmethod
def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
    """
    Report whether cost estimation is allowed for this engine spec.

    :param extra: Database ``extra`` attributes; not consulted here
    :return: Always ``True``
    """
    return True

@classmethod
def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]:
    """
    Estimate the cost of a single statement by running ``EXPLAIN`` on it.

    The first column of the first row returned by the cursor is searched
    for a ``cost=<start-up>..<total>`` fragment.

    :param statement: A single SQL statement
    :param cursor: Cursor instance used to run the ``EXPLAIN`` query
    :return: Dictionary with "Start-up cost" and "Total cost" as floats, or
        an empty dictionary when no cost could be extracted
    """
    sql = f"EXPLAIN {statement}"
    cursor.execute(sql)

    # fetchone() returns None when no row is available; indexing it would
    # raise TypeError, so treat a missing row as "no estimate"
    row = cursor.fetchone()
    if not row:
        return {}

    match = re.search(r"cost=([\d\.]+)\.\.([\d\.]+)", row[0])
    if match:
        return {
            "Start-up cost": float(match.group(1)),
            "Total cost": float(match.group(2)),
        }

    return {}

@classmethod
def query_cost_formatter(
    cls, raw_cost: List[Dict[str, Any]]
) -> List[Dict[str, str]]:
    """
    Render every cost value as a string, keeping keys and row order.

    :param raw_cost: List of cost dictionaries
    :return: The same rows with each value passed through ``str``
    """
    formatted: List[Dict[str, str]] = []
    for costs in raw_cost:
        as_text: Dict[str, str] = {}
        for label, value in costs.items():
            as_text[label] = str(value)
        formatted.append(as_text)
    return formatted

@classmethod
def get_table_names(
cls, database: "Database", inspector: PGInspector, schema: Optional[str]
Expand Down
16 changes: 5 additions & 11 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import ColumnClause, Select

from superset import app, cache_manager, is_feature_enabled, security_manager
from superset import app, cache_manager, is_feature_enabled
from superset.db_engine_specs.base import BaseEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetTemplateException
Expand Down Expand Up @@ -132,7 +132,8 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
}

@classmethod
def get_allow_cost_estimate(cls, version: Optional[str] = None) -> bool:
def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
version = extra.get("version")
return version is not None and StrictVersion(version) >= StrictVersion("0.319")

@classmethod
Expand Down Expand Up @@ -484,7 +485,7 @@ def select_star( # pylint: disable=too-many-arguments

@classmethod
def estimate_statement_cost( # pylint: disable=too-many-locals
cls, statement: str, database: "Database", cursor: Any, user_name: str
cls, statement: str, cursor: Any
) -> Dict[str, Any]:
"""
Run a SQL query that estimates the cost of a given statement.
Expand All @@ -495,14 +496,7 @@ def estimate_statement_cost( # pylint: disable=too-many-locals
:param username: Effective username
:return: JSON response from Presto
"""
parsed_query = ParsedQuery(statement)
sql = parsed_query.stripped()

sql_query_mutator = config["SQL_QUERY_MUTATOR"]
if sql_query_mutator:
sql = sql_query_mutator(sql, user_name, security_manager, database)

sql = f"EXPLAIN (TYPE IO, FORMAT JSON) {sql}"
sql = f"EXPLAIN (TYPE IO, FORMAT JSON) {statement}"
cursor.execute(sql)

# the output from Presto is a single column and a single row containing
Expand Down
7 changes: 2 additions & 5 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,14 +183,11 @@ def function_names(self) -> List[str]:

@property
def allows_cost_estimate(self) -> bool:
extra = self.get_extra()

database_version = extra.get("version")
extra = self.get_extra() or {}
cost_estimate_enabled: bool = extra.get("cost_estimate_enabled") # type: ignore

return (
self.db_engine_spec.get_allow_cost_estimate(database_version)
and cost_estimate_enabled
self.db_engine_spec.get_allow_cost_estimate(extra) and cost_estimate_enabled
)

@property
Expand Down

0 comments on commit 877b153

Please sign in to comment.