Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove support for async queries from SqlClient #589

Merged
merged 2 commits into from
Jun 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Breaking Changes-20230608-182212.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Breaking Changes
body: Removes async query and query cancel methods from SqlClient protocols
time: 2023-06-08T18:22:12.793133-07:00
custom:
Author: tlento
Issue: "577"
6 changes: 2 additions & 4 deletions metricflow/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,8 @@ def exit_signal_handler(signal_type: int, frame) -> None: # type: ignore
return

try:
if cfg.sql_client.sql_engine_attributes.cancel_submitted_queries_supported:
logger.info("Cancelling submitted queries")
cfg.sql_client.cancel_submitted_queries()
cfg.sql_client.close()
# Note: we may wish to add support for canceling all queries if zombie queries are a problem
cfg.sql_client.close()
finally:
sys.exit(-1)

Expand Down
15 changes: 6 additions & 9 deletions metricflow/execution/execution_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from metricflow.protocols.sql_client import SqlClient
from metricflow.protocols.sql_request import SqlJsonTag
from metricflow.sql.sql_bind_parameters import SqlBindParameters
from metricflow.sql_clients.sql_utils import sync_execute, sync_query
from metricflow.visitor import Visitable

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -126,11 +125,10 @@ def bind_parameters(self) -> SqlBindParameters: # noqa: D
def execute(self) -> TaskExecutionResult: # noqa: D
start_time = time.time()

df = sync_query(
self._sql_client,
df = self._sql_client.query(
self._sql_query,
bind_parameters=self.bind_parameters,
extra_sql_tags=self._extra_sql_tags,
sql_bind_parameters=self.bind_parameters,
extra_tags=self._extra_sql_tags,
)

end_time = time.time()
Expand Down Expand Up @@ -191,11 +189,10 @@ def execute(self) -> TaskExecutionResult: # noqa: D
logger.info(f"Creating table {self._output_table} using a SELECT query")
sql_query = self.sql_query
assert sql_query
sync_execute(
self._sql_client,
self._sql_client.execute(
sql_query.sql_query,
bind_parameters=sql_query.bind_parameters,
extra_sql_tags=self._extra_sql_tags,
sql_bind_parameters=sql_query.bind_parameters,
extra_tags=self._extra_sql_tags,
)

end_time = time.time()
Expand Down
54 changes: 2 additions & 52 deletions metricflow/protocols/sql_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@

from abc import abstractmethod
from enum import Enum
from typing import Callable, ClassVar, Dict, Optional, Protocol, Sequence
from typing import ClassVar, Dict, Optional, Protocol, Sequence

from pandas import DataFrame

from metricflow.dataflow.sql_table import SqlTable
from metricflow.protocols.sql_request import SqlJsonTag, SqlRequestId, SqlRequestResult
from metricflow.protocols.sql_request import SqlJsonTag
from metricflow.sql.render.sql_plan_renderer import SqlQueryPlanRenderer
from metricflow.sql.sql_bind_parameters import SqlBindParameters
from metricflow.sql_clients.sql_statement_metadata import CombinedSqlTags


class SqlEngine(Enum):
Expand Down Expand Up @@ -163,59 +162,11 @@ def close(self) -> None: # noqa: D
"""Close the connections / engines used by this client."""
raise NotImplementedError

@abstractmethod
def cancel_submitted_queries(self) -> None: # noqa: D
"""Cancel queries submitted through this client (that may be still running) with best-effort."""
raise NotImplementedError

@abstractmethod
def render_bind_parameter_key(self, bind_parameter_key: str) -> str:
"""Wrap the bind parameter key with syntax accepted by engine."""
raise NotImplementedError

def async_query(
self,
statement: str,
bind_parameters: SqlBindParameters = SqlBindParameters(),
extra_tags: SqlJsonTag = SqlJsonTag(),
isolation_level: Optional[SqlIsolationLevel] = None,
) -> SqlRequestId:
"""Execute a query asynchronously."""
raise NotImplementedError

@abstractmethod
def async_request_result(self, request_id: SqlRequestId) -> SqlRequestResult:
"""Wait until an async query has finished, and then return the result."""
raise NotImplementedError

@abstractmethod
def async_execute(
self,
statement: str,
bind_parameters: SqlBindParameters = SqlBindParameters(),
extra_tags: SqlJsonTag = SqlJsonTag(),
isolation_level: Optional[SqlIsolationLevel] = None,
) -> SqlRequestId:
"""Execute a statement that does not return values asynchronously."""
raise NotImplementedError

@abstractmethod
def cancel_request(self, match_function: Callable[[CombinedSqlTags], bool]) -> int:
"""Make a best-effort attempt at canceling requests with tags that match the supplied function.

The function arguments are the tags associated with the query, and should return a bool indicating whether
the given query should be canceled. Returns the number of cancellation commands sent.
"""
raise NotImplementedError

@abstractmethod
def active_requests(self) -> Sequence[SqlRequestId]:
"""Return requests that are still in progress.

If the results for a request have not yet been fetched with async_request_result(), it's considered in progress.
"""
raise NotImplementedError


class SqlEngineAttributes(Protocol):
"""Base interface for SQL engine-specific attributes and features.
Expand All @@ -242,7 +193,6 @@ class SqlEngineAttributes(Protocol):
multi_threading_supported: ClassVar[bool]
timestamp_type_supported: ClassVar[bool]
timestamp_to_string_comparison_supported: ClassVar[bool]
cancel_submitted_queries_supported: ClassVar[bool]
continuous_percentile_aggregation_supported: ClassVar[bool]
discrete_percentile_aggregation_supported: ClassVar[bool]
approximate_continuous_percentile_aggregation_supported: ClassVar[bool]
Expand Down
158 changes: 2 additions & 156 deletions metricflow/sql_clients/base_sql_client_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@

import logging
import textwrap
import threading
import time
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Sequence, Tuple
from typing import Any, Dict, List, Optional, Tuple

import jinja2
import pandas as pd
Expand All @@ -18,10 +17,9 @@
SqlEngineAttributes,
SqlIsolationLevel,
)
from metricflow.protocols.sql_request import SqlJsonTag, SqlRequestId, SqlRequestResult, SqlRequestTagSet
from metricflow.protocols.sql_request import SqlJsonTag, SqlRequestId, SqlRequestTagSet
from metricflow.random_id import random_id
from metricflow.sql.sql_bind_parameters import SqlBindParameters
from metricflow.sql_clients.common_client import check_isolation_level
from metricflow.sql_clients.sql_statement_metadata import CombinedSqlTags, SqlStatementCommentMetadata

logger = logging.getLogger(__name__)
Expand All @@ -36,10 +34,6 @@ class SqlClientException(Exception):
class BaseSqlClientImplementation(ABC, SqlClient):
"""Abstract implementation that other SQL clients are based on."""

def __init__(self) -> None: # noqa: D
self._request_id_to_thread: Dict[SqlRequestId, BaseSqlClientImplementation.SqlRequestExecutorThread] = {}
self._state_lock = threading.Lock()

def generate_health_check_tests(self, schema_name: str) -> List[Tuple[str, Any]]: # type: ignore
"""List of base health checks we want to perform."""
table_name = "health_report"
Expand Down Expand Up @@ -264,158 +258,10 @@ def render_bind_parameter_key(self, bind_parameter_key: str) -> str:
"""Wrap execution parameter key with syntax accepted by engine."""
return f":{bind_parameter_key}"

def async_query( # noqa: D
self,
statement: str,
bind_parameters: SqlBindParameters = SqlBindParameters(),
extra_tags: SqlJsonTag = SqlJsonTag(),
isolation_level: Optional[SqlIsolationLevel] = None,
) -> SqlRequestId:
check_isolation_level(self, isolation_level)
with self._state_lock:
request_id = SqlRequestId(f"mf_rid__{random_id()}")
thread = BaseSqlClientImplementation.SqlRequestExecutorThread(
sql_client=self,
request_id=request_id,
statement=statement,
bind_parameters=bind_parameters,
extra_tag=extra_tags,
isolation_level=isolation_level,
)
self._request_id_to_thread[request_id] = thread
self._request_id_to_thread[request_id].start()
return request_id

def async_execute( # noqa: D
self,
statement: str,
bind_parameters: SqlBindParameters = SqlBindParameters(),
extra_tags: SqlJsonTag = SqlJsonTag(),
isolation_level: Optional[SqlIsolationLevel] = None,
) -> SqlRequestId:
check_isolation_level(self, isolation_level)
with self._state_lock:
request_id = SqlRequestId(f"mf_rid__{random_id()}")
thread = BaseSqlClientImplementation.SqlRequestExecutorThread(
sql_client=self,
request_id=request_id,
statement=statement,
bind_parameters=bind_parameters,
extra_tag=extra_tags,
is_query=False,
isolation_level=isolation_level,
)
self._request_id_to_thread[request_id] = thread
self._request_id_to_thread[request_id].start()
return request_id

def async_request_result(self, query_id: SqlRequestId) -> SqlRequestResult: # noqa: D
thread: Optional[BaseSqlClientImplementation.SqlRequestExecutorThread] = None
with self._state_lock:
thread = self._request_id_to_thread.get(query_id)
if thread is None:
raise RuntimeError(
f"Query ID: {query_id} is not known. Either the query ID is invalid, or results for the query ID "
f"were already fetched."
)

thread.join()
with self._state_lock:
del self._request_id_to_thread[query_id]
return thread.result

def active_requests(self) -> Sequence[SqlRequestId]: # noqa: D
with self._state_lock:
return tuple(executor_thread.request_id for executor_thread in self._request_id_to_thread.values())

@staticmethod
def _consolidate_tags(json_tags: SqlJsonTag, request_id: SqlRequestId) -> CombinedSqlTags:
"""Consolidates json tags and request ID into a single set of tags."""
return CombinedSqlTags(
system_tags=SqlRequestTagSet().add_request_id(request_id=request_id),
extra_tag=json_tags,
)

class SqlRequestExecutorThread(threading.Thread):
"""Thread that helps to execute a request to the SQL engine asynchronously."""

def __init__( # noqa: D
self,
sql_client: BaseSqlClientImplementation,
request_id: SqlRequestId,
statement: str,
bind_parameters: SqlBindParameters,
extra_tag: SqlJsonTag = SqlJsonTag(),
is_query: bool = True,
isolation_level: Optional[SqlIsolationLevel] = None,
) -> None:
"""Initializer.

Args:
sql_client: SQL client used to execute statements.
request_id: The request ID associated with the statement.
statement: The statement to execute.
bind_parameters: The parameters to use for the statement.
extra_tag: Tags that should be associated with the request for the statement.
is_query: Whether the request is for .query (returns data) or .execute (does not return data)
isolation_level: The isolation level to use for the query.
"""
self._sql_client = sql_client
self._request_id = request_id
self._statement = statement
self._bind_parameters = bind_parameters
self._extra_tag = extra_tag
self._result: Optional[SqlRequestResult] = None
self._is_query = is_query
self._isolation_level = isolation_level
super().__init__(name=f"Async Execute SQL Request ID: {request_id}", daemon=True)

def run(self) -> None: # noqa: D
start_time = time.time()
try:
combined_tags = BaseSqlClientImplementation._consolidate_tags(
json_tags=self._extra_tag, request_id=self._request_id
)
statement = SqlStatementCommentMetadata.add_tag_metadata_as_comment(
sql_statement=self._statement, combined_tags=combined_tags
)

logger.info(
BaseSqlClientImplementation._format_run_query_log_message(
statement=self._statement, sql_bind_parameters=self._bind_parameters
)
)

if self._is_query:
df = self._sql_client._engine_specific_query_implementation(
statement,
bind_params=self._bind_parameters,
isolation_level=self._isolation_level,
system_tags=combined_tags.system_tags,
extra_tags=self._extra_tag,
)
self._result = SqlRequestResult(df=df)
else:
self._sql_client._engine_specific_execute_implementation(
statement,
bind_params=self._bind_parameters,
isolation_level=self._isolation_level,
system_tags=combined_tags.system_tags,
extra_tags=self._extra_tag,
)
self._result = SqlRequestResult(df=pd.DataFrame())
logger.info(f"Successfully executed {self._request_id} in {time.time() - start_time:.2f}s")
except Exception as e:
logger.exception(
f"Unsuccessfully executed {self._request_id} in {time.time() - start_time:.2f}s with exception:"
)
self._result = SqlRequestResult(exception=e)

@property
def result(self) -> SqlRequestResult: # noqa: D
assert self._result is not None, ".result() should only be called once the thread is finished running"
return self._result

@property
def request_id(self) -> SqlRequestId: # noqa: D
return self._request_id
29 changes: 2 additions & 27 deletions metricflow/sql_clients/big_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import json
import logging
from typing import Callable, ClassVar, Dict, Optional, Sequence
from typing import ClassVar, Dict, Optional, Sequence

import google.oauth2.service_account
import sqlalchemy
from google.cloud.bigquery import Client, QueryJob
from google.cloud.bigquery import Client

from metricflow.protocols.sql_client import (
SqlEngine,
Expand All @@ -17,7 +17,6 @@
from metricflow.sql.render.sql_plan_renderer import SqlQueryPlanRenderer
from metricflow.sql.sql_bind_parameters import SqlBindParameters
from metricflow.sql_clients.common_client import SqlDialect
from metricflow.sql_clients.sql_statement_metadata import CombinedSqlTags, SqlStatementCommentMetadata
from metricflow.sql_clients.sqlalchemy_dialect import SqlAlchemySqlClient

logger = logging.getLogger(__name__)
Expand All @@ -39,8 +38,6 @@ class BigQueryEngineAttributes:
multi_threading_supported: ClassVar[bool] = True
timestamp_type_supported: ClassVar[bool] = True
timestamp_to_string_comparison_supported: ClassVar[bool] = False
# Cancelling should be possible, but not yet implemented.
cancel_submitted_queries_supported: ClassVar[bool] = True
continuous_percentile_aggregation_supported: ClassVar[bool] = False
discrete_percentile_aggregation_supported: ClassVar[bool] = False
approximate_continuous_percentile_aggregation_supported: ClassVar[bool] = True
Expand Down Expand Up @@ -134,25 +131,3 @@ def list_tables(self, schema_name: str) -> Sequence[str]: # noqa: D
insp = sqlalchemy.inspection.inspect(conn)
schema_dot_tables = insp.get_table_names(schema=schema_name)
return [x.replace(schema_name + ".", "") for x in schema_dot_tables]

def cancel_submitted_queries(self) -> None: # noqa: D
raise NotImplementedError

def cancel_request(self, match_function: Callable[[CombinedSqlTags], bool]) -> int: # noqa: D
job: QueryJob
canceled_job_ids = []
# Couldn't find where these states were defined in the BQ libraries.
for state in ["PENDING", "RUNNING"]:
for job in self._bq_client.list_jobs(
project=self._project_id,
state_filter=state,
# Considering putting a creation_time_min filter as well.
):
parsed_tags = SqlStatementCommentMetadata.parse_tag_metadata_in_comments(job.query)

# A job can move from the pending to the running state during iteration, so dedupe.
if match_function(parsed_tags) and job.job_id not in canceled_job_ids:
logger.info(f"Canceling BQ job ID: {job.job_id}")
canceled_job_ids.append(job.job_id)
self._bq_client.cancel_job(job.job_id)
return len(canceled_job_ids)
Loading