diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index ef153301f6dda..87951d396ef10 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -1296,11 +1296,14 @@ def estimate_query_cost( statements = parsed_query.get_statements() costs = [] - with closing(engine.raw_connection()) as conn: - cursor = conn.cursor() - for statement in statements: - processed_statement = cls.process_statement(statement, database) - costs.append(cls.estimate_statement_cost(processed_statement, cursor)) + with cls.get_engine(database, schema=schema, source=source) as engine: + with closing(engine.raw_connection()) as conn: + cursor = conn.cursor() + for statement in statements: + processed_statement = cls.process_statement(statement, database) + costs.append( + cls.estimate_statement_cost(processed_statement, cursor) + ) return costs @classmethod diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index ea606614387fc..57c466f4e9037 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -211,89 +211,6 @@ def convert_dttm( return f"""CAST('{dttm.isoformat(timespec="microseconds")}' AS TIMESTAMP)""" return None - @classmethod - def estimate_query_cost( - cls, database: "Database", schema: str, sql: str, source: Optional[str] = None - ) -> List[Dict[str, Any]]: - """ - Estimate the cost of a multiple statement SQL query. - - :param database: Database instance - :param schema: Database schema - :param sql: SQL query with possibly multiple statements - :param source: Source of the query (eg, "sql_lab") - """ - extra = database.get_extra() or {} - if not cls.get_allow_cost_estimate(extra): - raise Exception("Database does not support cost estimation") - - parsed_query = sql_parse.ParsedQuery(sql) - statements = parsed_query.get_statements() - engine = cls.get_engine(database, schema=schema, source=source) - costs = [] - for statement in statements: - processed_statement = cls.process_statement(statement, database) - costs.append(cls.estimate_statement_cost(processed_statement, database)) - return costs - - @classmethod - def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool: - return True - - @classmethod - def estimate_statement_cost( - cls, statement: str, database: "Database" - ) -> Dict[str, Any]: - try: - # pylint: disable=import-outside-toplevel - # It's the only way to perfom a dry-run estimate cost - from google.cloud import bigquery - from google.oauth2 import service_account - except ImportError as ex: - raise Exception( - "Could not import libraries `pygibquery` or `google.oauth2`, which are " - "required to be installed in your environment in order " - "to upload data to BigQuery" - ) from ex - engine = cls.get_engine(database) - - creds = engine.dialect.credentials_info - creds = service_account.Credentials.from_service_account_info(creds) - client = bigquery.Client(credentials=creds) - job_config = bigquery.QueryJobConfig(dry_run=True, use_query_cache=True) - - query_job = client.query( - (statement), - job_config=job_config, - ) # Make an API request. - - # Format Bytes. - if query_job.total_bytes_processed // 1000 == 0: - byte_type = "B" - total_bytes_processed = query_job.total_bytes_processed - elif query_job.total_bytes_processed // (1000**2) == 0: - byte_type = "KB" - total_bytes_processed = round(query_job.total_bytes_processed / 1000, 2) - elif query_job.total_bytes_processed // (1000**3) == 0: - byte_type = "MB" - total_bytes_processed = round( - query_job.total_bytes_processed / (1000**2), 2 - ) - else: - byte_type = "GB" - total_bytes_processed = round( - query_job.total_bytes_processed / (1000**3), 2 - ) - - return {f"{byte_type} Processed": total_bytes_processed} - - @classmethod - def query_cost_formatter( - cls, raw_cost: List[Dict[str, Any]] - ) -> List[Dict[str, str]]: - print([{k: str(v) for k, v in row.items()} for row in raw_cost]) - return [{k: str(v) for k, v in row.items()} for row in raw_cost] - @classmethod def fetch_data( cls, cursor: Any, limit: Optional[int] = None @@ -473,6 +390,7 @@ def estimate_query_cost( costs = [] for statement in statements: processed_statement = cls.process_statement(statement, database) + costs.append(cls.estimate_statement_cost(processed_statement, database)) return costs @@ -493,9 +411,10 @@ def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]: "required to be installed in your environment in order " "to upload data to BigQuery" ) from ex - engine = cls.get_engine(cursor) - creds = engine.dialect.credentials_info + with cls.get_engine(cursor) as engine: + creds = engine.dialect.credentials_info + creds = service_account.Credentials.from_service_account_info(creds) client = bigquery.Client(credentials=creds) job_config = bigquery.QueryJobConfig(dry_run=True)