Skip to content

Commit

Permalink
Remove async query and execute methods from SqlClient
Browse files Browse the repository at this point in the history
We will not be using async query dispatch mechanisms going forward,
so this removes the relevant methods and all supporting methods and
classes.
  • Loading branch information
tlento committed Jun 9, 2023
1 parent f9bf74e commit 069fd0d
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 335 deletions.
15 changes: 6 additions & 9 deletions metricflow/execution/execution_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from metricflow.protocols.sql_client import SqlClient
from metricflow.protocols.sql_request import SqlJsonTag
from metricflow.sql.sql_bind_parameters import SqlBindParameters
from metricflow.sql_clients.sql_utils import sync_execute, sync_query
from metricflow.visitor import Visitable

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -126,11 +125,10 @@ def bind_parameters(self) -> SqlBindParameters: # noqa: D
def execute(self) -> TaskExecutionResult: # noqa: D
start_time = time.time()

df = sync_query(
self._sql_client,
df = self._sql_client.query(
self._sql_query,
bind_parameters=self.bind_parameters,
extra_sql_tags=self._extra_sql_tags,
sql_bind_parameters=self.bind_parameters,
extra_tags=self._extra_sql_tags,
)

end_time = time.time()
Expand Down Expand Up @@ -191,11 +189,10 @@ def execute(self) -> TaskExecutionResult: # noqa: D
logger.info(f"Creating table {self._output_table} using a SELECT query")
sql_query = self.sql_query
assert sql_query
sync_execute(
self._sql_client,
self._sql_client.execute(
sql_query.sql_query,
bind_parameters=sql_query.bind_parameters,
extra_sql_tags=self._extra_sql_tags,
sql_bind_parameters=sql_query.bind_parameters,
extra_tags=self._extra_sql_tags,
)

end_time = time.time()
Expand Down
36 changes: 1 addition & 35 deletions metricflow/protocols/sql_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pandas import DataFrame

from metricflow.dataflow.sql_table import SqlTable
from metricflow.protocols.sql_request import SqlJsonTag, SqlRequestId, SqlRequestResult
from metricflow.protocols.sql_request import SqlJsonTag
from metricflow.sql.render.sql_plan_renderer import SqlQueryPlanRenderer
from metricflow.sql.sql_bind_parameters import SqlBindParameters

Expand Down Expand Up @@ -167,40 +167,6 @@ def render_bind_parameter_key(self, bind_parameter_key: str) -> str:
"""Wrap the bind parameter key with syntax accepted by engine."""
raise NotImplementedError

def async_query(
self,
statement: str,
bind_parameters: SqlBindParameters = SqlBindParameters(),
extra_tags: SqlJsonTag = SqlJsonTag(),
isolation_level: Optional[SqlIsolationLevel] = None,
) -> SqlRequestId:
"""Execute a query asynchronously."""
raise NotImplementedError

@abstractmethod
def async_request_result(self, request_id: SqlRequestId) -> SqlRequestResult:
"""Wait until a async query has finished, and then return the result."""
raise NotImplementedError

@abstractmethod
def async_execute(
self,
statement: str,
bind_parameters: SqlBindParameters = SqlBindParameters(),
extra_tags: SqlJsonTag = SqlJsonTag(),
isolation_level: Optional[SqlIsolationLevel] = None,
) -> SqlRequestId:
"""Execute a statement that does not return values asynchronously."""
raise NotImplementedError

@abstractmethod
def active_requests(self) -> Sequence[SqlRequestId]:
"""Return requests that are still in progress.
If the results for a request have not yet been fetched with async_request_result(), it's considered in progress.
"""
raise NotImplementedError


class SqlEngineAttributes(Protocol):
"""Base interface for SQL engine-specific attributes and features.
Expand Down
158 changes: 2 additions & 156 deletions metricflow/sql_clients/base_sql_client_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@

import logging
import textwrap
import threading
import time
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Sequence, Tuple
from typing import Any, Dict, List, Optional, Tuple

import jinja2
import pandas as pd
Expand All @@ -18,10 +17,9 @@
SqlEngineAttributes,
SqlIsolationLevel,
)
from metricflow.protocols.sql_request import SqlJsonTag, SqlRequestId, SqlRequestResult, SqlRequestTagSet
from metricflow.protocols.sql_request import SqlJsonTag, SqlRequestId, SqlRequestTagSet
from metricflow.random_id import random_id
from metricflow.sql.sql_bind_parameters import SqlBindParameters
from metricflow.sql_clients.common_client import check_isolation_level
from metricflow.sql_clients.sql_statement_metadata import CombinedSqlTags, SqlStatementCommentMetadata

logger = logging.getLogger(__name__)
Expand All @@ -36,10 +34,6 @@ class SqlClientException(Exception):
class BaseSqlClientImplementation(ABC, SqlClient):
"""Abstract implementation that other SQL clients are based on."""

def __init__(self) -> None: # noqa: D
self._request_id_to_thread: Dict[SqlRequestId, BaseSqlClientImplementation.SqlRequestExecutorThread] = {}
self._state_lock = threading.Lock()

def generate_health_check_tests(self, schema_name: str) -> List[Tuple[str, Any]]: # type: ignore
"""List of base health checks we want to perform."""
table_name = "health_report"
Expand Down Expand Up @@ -264,158 +258,10 @@ def render_bind_parameter_key(self, bind_parameter_key: str) -> str:
"""Wrap execution parameter key with syntax accepted by engine."""
return f":{bind_parameter_key}"

def async_query( # noqa: D
self,
statement: str,
bind_parameters: SqlBindParameters = SqlBindParameters(),
extra_tags: SqlJsonTag = SqlJsonTag(),
isolation_level: Optional[SqlIsolationLevel] = None,
) -> SqlRequestId:
check_isolation_level(self, isolation_level)
with self._state_lock:
request_id = SqlRequestId(f"mf_rid__{random_id()}")
thread = BaseSqlClientImplementation.SqlRequestExecutorThread(
sql_client=self,
request_id=request_id,
statement=statement,
bind_parameters=bind_parameters,
extra_tag=extra_tags,
isolation_level=isolation_level,
)
self._request_id_to_thread[request_id] = thread
self._request_id_to_thread[request_id].start()
return request_id

def async_execute( # noqa: D
self,
statement: str,
bind_parameters: SqlBindParameters = SqlBindParameters(),
extra_tags: SqlJsonTag = SqlJsonTag(),
isolation_level: Optional[SqlIsolationLevel] = None,
) -> SqlRequestId:
check_isolation_level(self, isolation_level)
with self._state_lock:
request_id = SqlRequestId(f"mf_rid__{random_id()}")
thread = BaseSqlClientImplementation.SqlRequestExecutorThread(
sql_client=self,
request_id=request_id,
statement=statement,
bind_parameters=bind_parameters,
extra_tag=extra_tags,
is_query=False,
isolation_level=isolation_level,
)
self._request_id_to_thread[request_id] = thread
self._request_id_to_thread[request_id].start()
return request_id

def async_request_result(self, query_id: SqlRequestId) -> SqlRequestResult: # noqa: D
thread: Optional[BaseSqlClientImplementation.SqlRequestExecutorThread] = None
with self._state_lock:
thread = self._request_id_to_thread.get(query_id)
if thread is None:
raise RuntimeError(
f"Query ID: {query_id} is not known. Either the query ID is invalid, or results for the query ID "
f"were already fetched."
)

thread.join()
with self._state_lock:
del self._request_id_to_thread[query_id]
return thread.result

def active_requests(self) -> Sequence[SqlRequestId]: # noqa: D
with self._state_lock:
return tuple(executor_thread.request_id for executor_thread in self._request_id_to_thread.values())

@staticmethod
def _consolidate_tags(json_tags: SqlJsonTag, request_id: SqlRequestId) -> CombinedSqlTags:
"""Consolidates json tags and request ID into a single set of tags."""
return CombinedSqlTags(
system_tags=SqlRequestTagSet().add_request_id(request_id=request_id),
extra_tag=json_tags,
)

class SqlRequestExecutorThread(threading.Thread):
"""Thread that helps to execute a request to the SQL engine asynchronously."""

def __init__( # noqa: D
self,
sql_client: BaseSqlClientImplementation,
request_id: SqlRequestId,
statement: str,
bind_parameters: SqlBindParameters,
extra_tag: SqlJsonTag = SqlJsonTag(),
is_query: bool = True,
isolation_level: Optional[SqlIsolationLevel] = None,
) -> None:
"""Initializer.
Args:
sql_client: SQL client used to execute statements.
request_id: The request ID associated with the statement.
statement: The statement to execute.
bind_parameters: The parameters to use for the statement.
extra_tag: Tags that should be associated with the request for the statement.
is_query: Whether the request is for .query (returns data) or .execute (does not return data)
isolation_level: The isolation level to use for the query.
"""
self._sql_client = sql_client
self._request_id = request_id
self._statement = statement
self._bind_parameters = bind_parameters
self._extra_tag = extra_tag
self._result: Optional[SqlRequestResult] = None
self._is_query = is_query
self._isolation_level = isolation_level
super().__init__(name=f"Async Execute SQL Request ID: {request_id}", daemon=True)

def run(self) -> None: # noqa: D
start_time = time.time()
try:
combined_tags = BaseSqlClientImplementation._consolidate_tags(
json_tags=self._extra_tag, request_id=self._request_id
)
statement = SqlStatementCommentMetadata.add_tag_metadata_as_comment(
sql_statement=self._statement, combined_tags=combined_tags
)

logger.info(
BaseSqlClientImplementation._format_run_query_log_message(
statement=self._statement, sql_bind_parameters=self._bind_parameters
)
)

if self._is_query:
df = self._sql_client._engine_specific_query_implementation(
statement,
bind_params=self._bind_parameters,
isolation_level=self._isolation_level,
system_tags=combined_tags.system_tags,
extra_tags=self._extra_tag,
)
self._result = SqlRequestResult(df=df)
else:
self._sql_client._engine_specific_execute_implementation(
statement,
bind_params=self._bind_parameters,
isolation_level=self._isolation_level,
system_tags=combined_tags.system_tags,
extra_tags=self._extra_tag,
)
self._result = SqlRequestResult(df=pd.DataFrame())
logger.info(f"Successfully executed {self._request_id} in {time.time() - start_time:.2f}s")
except Exception as e:
logger.exception(
f"Unsuccessfully executed {self._request_id} in {time.time() - start_time:.2f}s with exception:"
)
self._result = SqlRequestResult(exception=e)

@property
def result(self) -> SqlRequestResult: # noqa: D
assert self._result is not None, ".result() should only be called once the thread is finished running"
return self._result

@property
def request_id(self) -> SqlRequestId: # noqa: D
return self._request_id
50 changes: 1 addition & 49 deletions metricflow/sql_clients/sql_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,7 @@
CONFIG_DWH_WAREHOUSE,
)
from metricflow.configuration.yaml_handler import YamlFileHandler
from metricflow.protocols.sql_client import SqlClient, SqlIsolationLevel
from metricflow.protocols.sql_request import SqlJsonTag
from metricflow.sql.sql_bind_parameters import SqlBindParameters
from metricflow.sql_clients.base_sql_client_implementation import SqlClientException
from metricflow.protocols.sql_client import SqlClient
from metricflow.sql_clients.big_query import BigQuerySqlClient
from metricflow.sql_clients.common_client import SqlDialect, not_empty
from metricflow.sql_clients.databricks import DatabricksSqlClient
Expand Down Expand Up @@ -157,48 +154,3 @@ def make_sql_client_from_config(handler: YamlFileHandler) -> SqlClient:
else:
supported_dialects = [x.value for x in SqlDialect]
raise ValueError(f"Invalid dialect '{dialect}', must be one of {supported_dialects} in {url}")


def sync_execute( # noqa: D
sql_client: SqlClient,
statement: str,
bind_parameters: SqlBindParameters = SqlBindParameters(),
extra_sql_tags: SqlJsonTag = SqlJsonTag(),
isolation_level: Optional[SqlIsolationLevel] = None,
) -> None:
request_id = sql_client.async_execute(
statement=statement,
bind_parameters=bind_parameters,
extra_tags=extra_sql_tags,
isolation_level=isolation_level,
)

result = sql_client.async_request_result(request_id)
if result.exception:
raise SqlClientException(
f"Got an exception when trying to execute a statement: {result.exception}"
) from result.exception
return


def sync_query( # noqa: D
sql_client: SqlClient,
statement: str,
bind_parameters: SqlBindParameters = SqlBindParameters(),
extra_sql_tags: SqlJsonTag = SqlJsonTag(),
isolation_level: Optional[SqlIsolationLevel] = None,
) -> pd.DataFrame:
request_id = sql_client.async_query(
statement=statement,
bind_parameters=bind_parameters,
extra_tags=extra_sql_tags,
isolation_level=isolation_level,
)

result = sql_client.async_request_result(request_id)
if result.exception:
raise SqlClientException(
f"Got an exception when trying to execute a statement: {result.exception}"
) from result.exception
assert result.df is not None, "A dataframe should have been returned if there was no error"
return result.df
Loading

0 comments on commit 069fd0d

Please sign in to comment.