Skip to content

Commit

Permalink
Fixes of rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
zamar-roura-fever committed Nov 26, 2022
1 parent 3d02618 commit 58705fa
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 90 deletions.
13 changes: 8 additions & 5 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1296,11 +1296,14 @@ def estimate_query_cost(
statements = parsed_query.get_statements()

costs = []
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
for statement in statements:
processed_statement = cls.process_statement(statement, database)
costs.append(cls.estimate_statement_cost(processed_statement, cursor))
with cls.get_engine(database, schema=schema, source=source) as engine:
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
for statement in statements:
processed_statement = cls.process_statement(statement, database)
costs.append(
cls.estimate_statement_cost(processed_statement, cursor)
)
return costs

@classmethod
Expand Down
89 changes: 4 additions & 85 deletions superset/db_engine_specs/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,89 +211,6 @@ def convert_dttm(
return f"""CAST('{dttm.isoformat(timespec="microseconds")}' AS TIMESTAMP)"""
return None

@classmethod
def estimate_query_cost(
cls, database: "Database", schema: str, sql: str, source: Optional[str] = None
) -> List[Dict[str, Any]]:
"""
Estimate the cost of a multiple statement SQL query.
:param database: Database instance
:param schema: Database schema
:param sql: SQL query with possibly multiple statements
:param source: Source of the query (eg, "sql_lab")
"""
extra = database.get_extra() or {}
if not cls.get_allow_cost_estimate(extra):
raise Exception("Database does not support cost estimation")

parsed_query = sql_parse.ParsedQuery(sql)
statements = parsed_query.get_statements()
engine = cls.get_engine(database, schema=schema, source=source)
costs = []
for statement in statements:
processed_statement = cls.process_statement(statement, database)
costs.append(cls.estimate_statement_cost(processed_statement, database))
return costs

@classmethod
def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
return True

@classmethod
def estimate_statement_cost(
cls, statement: str, database: "Database"
) -> Dict[str, Any]:
try:
# pylint: disable=import-outside-toplevel
# It's the only way to perfom a dry-run estimate cost
from google.cloud import bigquery
from google.oauth2 import service_account
except ImportError as ex:
raise Exception(
"Could not import libraries `pygibquery` or `google.oauth2`, which are "
"required to be installed in your environment in order "
"to upload data to BigQuery"
) from ex
engine = cls.get_engine(database)

creds = engine.dialect.credentials_info
creds = service_account.Credentials.from_service_account_info(creds)
client = bigquery.Client(credentials=creds)
job_config = bigquery.QueryJobConfig(dry_run=True, use_query_cache=True)

query_job = client.query(
(statement),
job_config=job_config,
) # Make an API request.

# Format Bytes.
if query_job.total_bytes_processed // 1000 == 0:
byte_type = "B"
total_bytes_processed = query_job.total_bytes_processed
elif query_job.total_bytes_processed // (1000**2) == 0:
byte_type = "KB"
total_bytes_processed = round(query_job.total_bytes_processed / 1000, 2)
elif query_job.total_bytes_processed // (1000**3) == 0:
byte_type = "MB"
total_bytes_processed = round(
query_job.total_bytes_processed / (1000**2), 2
)
else:
byte_type = "GB"
total_bytes_processed = round(
query_job.total_bytes_processed / (1000**3), 2
)

return {f"{byte_type} Processed": total_bytes_processed}

@classmethod
def query_cost_formatter(
cls, raw_cost: List[Dict[str, Any]]
) -> List[Dict[str, str]]:
print([{k: str(v) for k, v in row.items()} for row in raw_cost])
return [{k: str(v) for k, v in row.items()} for row in raw_cost]

@classmethod
def fetch_data(
cls, cursor: Any, limit: Optional[int] = None
Expand Down Expand Up @@ -473,6 +390,7 @@ def estimate_query_cost(
costs = []
for statement in statements:
processed_statement = cls.process_statement(statement, database)

costs.append(cls.estimate_statement_cost(processed_statement, database))
return costs

Expand All @@ -493,9 +411,10 @@ def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]:
"required to be installed in your environment in order "
"to upload data to BigQuery"
) from ex
engine = cls.get_engine(cursor)

creds = engine.dialect.credentials_info
with cls.get_engine(cursor) as engine:
creds = engine.dialect.credentials_info

creds = service_account.Credentials.from_service_account_info(creds)
client = bigquery.Client(credentials=creds)
job_config = bigquery.QueryJobConfig(dry_run=True)
Expand Down

0 comments on commit 58705fa

Please sign in to comment.