diff --git a/metricflow/execution/execution_plan.py b/metricflow/execution/execution_plan.py index de8829515b..e5fdba5b8b 100644 --- a/metricflow/execution/execution_plan.py +++ b/metricflow/execution/execution_plan.py @@ -16,7 +16,6 @@ from metricflow.protocols.sql_client import SqlClient from metricflow.protocols.sql_request import SqlJsonTag from metricflow.sql.sql_bind_parameters import SqlBindParameters -from metricflow.sql_clients.sql_utils import sync_execute, sync_query from metricflow.visitor import Visitable logger = logging.getLogger(__name__) @@ -126,11 +125,10 @@ def bind_parameters(self) -> SqlBindParameters: # noqa: D def execute(self) -> TaskExecutionResult: # noqa: D start_time = time.time() - df = sync_query( - self._sql_client, + df = self._sql_client.query( self._sql_query, - bind_parameters=self.bind_parameters, - extra_sql_tags=self._extra_sql_tags, + sql_bind_parameters=self.bind_parameters, + extra_tags=self._extra_sql_tags, ) end_time = time.time() @@ -191,11 +189,10 @@ def execute(self) -> TaskExecutionResult: # noqa: D logger.info(f"Creating table {self._output_table} using a SELECT query") sql_query = self.sql_query assert sql_query - sync_execute( - self._sql_client, + self._sql_client.execute( sql_query.sql_query, - bind_parameters=sql_query.bind_parameters, - extra_sql_tags=self._extra_sql_tags, + sql_bind_parameters=sql_query.bind_parameters, + extra_tags=self._extra_sql_tags, ) end_time = time.time() diff --git a/metricflow/protocols/sql_client.py b/metricflow/protocols/sql_client.py index 1a247032d3..1f90712fd6 100644 --- a/metricflow/protocols/sql_client.py +++ b/metricflow/protocols/sql_client.py @@ -7,7 +7,7 @@ from pandas import DataFrame from metricflow.dataflow.sql_table import SqlTable -from metricflow.protocols.sql_request import SqlJsonTag, SqlRequestId, SqlRequestResult +from metricflow.protocols.sql_request import SqlJsonTag from metricflow.sql.render.sql_plan_renderer import SqlQueryPlanRenderer from metricflow.sql.sql_bind_parameters import SqlBindParameters @@ -167,40 +167,6 @@ def render_bind_parameter_key(self, bind_parameter_key: str) -> str: """Wrap the bind parameter key with syntax accepted by engine.""" raise NotImplementedError - def async_query( - self, - statement: str, - bind_parameters: SqlBindParameters = SqlBindParameters(), - extra_tags: SqlJsonTag = SqlJsonTag(), - isolation_level: Optional[SqlIsolationLevel] = None, - ) -> SqlRequestId: - """Execute a query asynchronously.""" - raise NotImplementedError - - @abstractmethod - def async_request_result(self, request_id: SqlRequestId) -> SqlRequestResult: - """Wait until a async query has finished, and then return the result.""" - raise NotImplementedError - - @abstractmethod - def async_execute( - self, - statement: str, - bind_parameters: SqlBindParameters = SqlBindParameters(), - extra_tags: SqlJsonTag = SqlJsonTag(), - isolation_level: Optional[SqlIsolationLevel] = None, - ) -> SqlRequestId: - """Execute a statement that does not return values asynchronously.""" - raise NotImplementedError - - @abstractmethod - def active_requests(self) -> Sequence[SqlRequestId]: - """Return requests that are still in progress. - - If the results for a request have not yet been fetched with async_request_result(), it's considered in progress. - """ - raise NotImplementedError - class SqlEngineAttributes(Protocol): """Base interface for SQL engine-specific attributes and features. diff --git a/metricflow/sql_clients/base_sql_client_implementation.py b/metricflow/sql_clients/base_sql_client_implementation.py index a3eda72cf7..bc2160ca4f 100644 --- a/metricflow/sql_clients/base_sql_client_implementation.py +++ b/metricflow/sql_clients/base_sql_client_implementation.py @@ -2,10 +2,9 @@ import logging import textwrap -import threading import time from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Sequence, Tuple +from typing import Any, Dict, List, Optional, Tuple import jinja2 import pandas as pd @@ -18,10 +17,9 @@ SqlEngineAttributes, SqlIsolationLevel, ) -from metricflow.protocols.sql_request import SqlJsonTag, SqlRequestId, SqlRequestResult, SqlRequestTagSet +from metricflow.protocols.sql_request import SqlJsonTag, SqlRequestId, SqlRequestTagSet from metricflow.random_id import random_id from metricflow.sql.sql_bind_parameters import SqlBindParameters -from metricflow.sql_clients.common_client import check_isolation_level from metricflow.sql_clients.sql_statement_metadata import CombinedSqlTags, SqlStatementCommentMetadata logger = logging.getLogger(__name__) @@ -36,10 +34,6 @@ class SqlClientException(Exception): class BaseSqlClientImplementation(ABC, SqlClient): """Abstract implementation that other SQL clients are based on.""" - def __init__(self) -> None: # noqa: D - self._request_id_to_thread: Dict[SqlRequestId, BaseSqlClientImplementation.SqlRequestExecutorThread] = {} - self._state_lock = threading.Lock() - def generate_health_check_tests(self, schema_name: str) -> List[Tuple[str, Any]]: # type: ignore """List of base health checks we want to perform.""" table_name = "health_report" @@ -264,70 +258,6 @@ def render_bind_parameter_key(self, bind_parameter_key: str) -> str: """Wrap execution parameter key with syntax accepted by engine.""" return f":{bind_parameter_key}" - def async_query( # noqa: D - self, - statement: str, - bind_parameters: SqlBindParameters = SqlBindParameters(), - extra_tags: SqlJsonTag = SqlJsonTag(), - isolation_level: Optional[SqlIsolationLevel] = None, - ) -> SqlRequestId: - check_isolation_level(self, isolation_level) - with self._state_lock: - request_id = SqlRequestId(f"mf_rid__{random_id()}") - thread = BaseSqlClientImplementation.SqlRequestExecutorThread( - sql_client=self, - request_id=request_id, - statement=statement, - bind_parameters=bind_parameters, - extra_tag=extra_tags, - isolation_level=isolation_level, - ) - self._request_id_to_thread[request_id] = thread - self._request_id_to_thread[request_id].start() - return request_id - - def async_execute( # noqa: D - self, - statement: str, - bind_parameters: SqlBindParameters = SqlBindParameters(), - extra_tags: SqlJsonTag = SqlJsonTag(), - isolation_level: Optional[SqlIsolationLevel] = None, - ) -> SqlRequestId: - check_isolation_level(self, isolation_level) - with self._state_lock: - request_id = SqlRequestId(f"mf_rid__{random_id()}") - thread = BaseSqlClientImplementation.SqlRequestExecutorThread( - sql_client=self, - request_id=request_id, - statement=statement, - bind_parameters=bind_parameters, - extra_tag=extra_tags, - is_query=False, - isolation_level=isolation_level, - ) - self._request_id_to_thread[request_id] = thread - self._request_id_to_thread[request_id].start() - return request_id - - def async_request_result(self, query_id: SqlRequestId) -> SqlRequestResult: # noqa: D - thread: Optional[BaseSqlClientImplementation.SqlRequestExecutorThread] = None - with self._state_lock: - thread = self._request_id_to_thread.get(query_id) - if thread is None: - raise RuntimeError( - f"Query ID: {query_id} is not known. Either the query ID is invalid, or results for the query ID " - f"were already fetched." - ) - - thread.join() - with self._state_lock: - del self._request_id_to_thread[query_id] - return thread.result - - def active_requests(self) -> Sequence[SqlRequestId]: # noqa: D - with self._state_lock: - return tuple(executor_thread.request_id for executor_thread in self._request_id_to_thread.values()) - @staticmethod def _consolidate_tags(json_tags: SqlJsonTag, request_id: SqlRequestId) -> CombinedSqlTags: """Consolidates json tags and request ID into a single set of tags.""" @@ -335,87 +265,3 @@ def _consolidate_tags(json_tags: SqlJsonTag, request_id: SqlRequestId) -> Combin system_tags=SqlRequestTagSet().add_request_id(request_id=request_id), extra_tag=json_tags, ) - - class SqlRequestExecutorThread(threading.Thread): - """Thread that helps to execute a request to the SQL engine asynchronously.""" - - def __init__( # noqa: D - self, - sql_client: BaseSqlClientImplementation, - request_id: SqlRequestId, - statement: str, - bind_parameters: SqlBindParameters, - extra_tag: SqlJsonTag = SqlJsonTag(), - is_query: bool = True, - isolation_level: Optional[SqlIsolationLevel] = None, - ) -> None: - """Initializer. - - Args: - sql_client: SQL client used to execute statements. - request_id: The request ID associated with the statement. - statement: The statement to execute. - bind_parameters: The parameters to use for the statement. - extra_tag: Tags that should be associated with the request for the statement. - is_query: Whether the request is for .query (returns data) or .execute (does not return data) - isolation_level: The isolation level to use for the query. - """ - self._sql_client = sql_client - self._request_id = request_id - self._statement = statement - self._bind_parameters = bind_parameters - self._extra_tag = extra_tag - self._result: Optional[SqlRequestResult] = None - self._is_query = is_query - self._isolation_level = isolation_level - super().__init__(name=f"Async Execute SQL Request ID: {request_id}", daemon=True) - - def run(self) -> None: # noqa: D - start_time = time.time() - try: - combined_tags = BaseSqlClientImplementation._consolidate_tags( - json_tags=self._extra_tag, request_id=self._request_id - ) - statement = SqlStatementCommentMetadata.add_tag_metadata_as_comment( - sql_statement=self._statement, combined_tags=combined_tags - ) - - logger.info( - BaseSqlClientImplementation._format_run_query_log_message( - statement=self._statement, sql_bind_parameters=self._bind_parameters - ) - ) - - if self._is_query: - df = self._sql_client._engine_specific_query_implementation( - statement, - bind_params=self._bind_parameters, - isolation_level=self._isolation_level, - system_tags=combined_tags.system_tags, - extra_tags=self._extra_tag, - ) - self._result = SqlRequestResult(df=df) - else: - self._sql_client._engine_specific_execute_implementation( - statement, - bind_params=self._bind_parameters, - isolation_level=self._isolation_level, - system_tags=combined_tags.system_tags, - extra_tags=self._extra_tag, - ) - self._result = SqlRequestResult(df=pd.DataFrame()) - logger.info(f"Successfully executed {self._request_id} in {time.time() - start_time:.2f}s") - except Exception as e: - logger.exception( - f"Unsuccessfully executed {self._request_id} in {time.time() - start_time:.2f}s with exception:" - ) - self._result = SqlRequestResult(exception=e) - - @property - def result(self) -> SqlRequestResult: # noqa: D - assert self._result is not None, ".result() should only be called once the thread is finished running" - return self._result - - @property - def request_id(self) -> SqlRequestId: # noqa: D - return self._request_id diff --git a/metricflow/sql_clients/sql_utils.py b/metricflow/sql_clients/sql_utils.py index 03b60d0233..41330b1f01 100644 --- a/metricflow/sql_clients/sql_utils.py +++ b/metricflow/sql_clients/sql_utils.py @@ -21,10 +21,7 @@ CONFIG_DWH_WAREHOUSE, ) from metricflow.configuration.yaml_handler import YamlFileHandler -from metricflow.protocols.sql_client import SqlClient, SqlIsolationLevel -from metricflow.protocols.sql_request import SqlJsonTag -from metricflow.sql.sql_bind_parameters import SqlBindParameters -from metricflow.sql_clients.base_sql_client_implementation import SqlClientException +from metricflow.protocols.sql_client import SqlClient from metricflow.sql_clients.big_query import BigQuerySqlClient from metricflow.sql_clients.common_client import SqlDialect, not_empty from metricflow.sql_clients.databricks import DatabricksSqlClient @@ -157,48 +154,3 @@ def make_sql_client_from_config(handler: YamlFileHandler) -> SqlClient: else: supported_dialects = [x.value for x in SqlDialect] raise ValueError(f"Invalid dialect '{dialect}', must be one of {supported_dialects} in {url}") - - -def sync_execute( # noqa: D - sql_client: SqlClient, - statement: str, - bind_parameters: SqlBindParameters = SqlBindParameters(), - extra_sql_tags: SqlJsonTag = SqlJsonTag(), - isolation_level: Optional[SqlIsolationLevel] = None, -) -> None: - request_id = sql_client.async_execute( - statement=statement, - bind_parameters=bind_parameters, - extra_tags=extra_sql_tags, - isolation_level=isolation_level, - ) - - result = sql_client.async_request_result(request_id) - if result.exception: - raise SqlClientException( - f"Got an exception when trying to execute a statement: {result.exception}" - ) from result.exception - return - - -def sync_query( # noqa: D - sql_client: SqlClient, - statement: str, - bind_parameters: SqlBindParameters = SqlBindParameters(), - extra_sql_tags: SqlJsonTag = SqlJsonTag(), - isolation_level: Optional[SqlIsolationLevel] = None, -) -> pd.DataFrame: - request_id = sql_client.async_query( - statement=statement, - bind_parameters=bind_parameters, - extra_tags=extra_sql_tags, - isolation_level=isolation_level, - ) - - result = sql_client.async_request_result(request_id) - if result.exception: - raise SqlClientException( - f"Got an exception when trying to execute a statement: {result.exception}" - ) from result.exception - assert result.df is not None, "A dataframe should have been returned if there was no error" - return result.df diff --git a/metricflow/test/sql_clients/test_async.py b/metricflow/test/sql_clients/test_async.py deleted file mode 100644 index 6618e897c4..0000000000 --- a/metricflow/test/sql_clients/test_async.py +++ /dev/null @@ -1,86 +0,0 @@ -from __future__ import annotations - -import json -import logging - -import pytest -from dbt_semantic_interfaces.enum_extension import assert_values_exhausted - -from metricflow.dataflow.sql_table import SqlTable -from metricflow.protocols.sql_client import SqlClient, SqlEngine -from metricflow.protocols.sql_request import MF_EXTRA_TAGS_KEY, SqlJsonTag -from metricflow.sql_clients.sql_utils import make_df -from metricflow.test.compare_df import assert_dataframes_equal -from metricflow.test.fixtures.setup_fixtures import MetricFlowTestSessionState - -logger = logging.getLogger(__name__) - - -def create_table_with_n_rows(sql_client: SqlClient, schema_name: str, num_rows: int) -> SqlTable: - """Create a table with a specific number of rows.""" - sql_table = SqlTable( - schema_name=schema_name, - table_name=f"table_with_{num_rows}_rows", - ) - sql_client.drop_table(sql_table) - sql_client.create_table_from_dataframe( - sql_table=sql_table, - df=make_df(sql_client=sql_client, columns=["example_string"], data=(("foo",) for _ in range(num_rows))), - ) - return sql_table - - -def test_async_query(sql_client: SqlClient, mf_test_session_state: MetricFlowTestSessionState) -> None: # noqa: D - request_id = sql_client.async_query("SELECT 1 AS foo") - result = sql_client.async_request_result(request_id) - assert_dataframes_equal( - actual=result.df, - expected=make_df(sql_client=sql_client, columns=["foo"], data=((1,),)), - ) - assert result.exception is None - - -def test_async_execute(sql_client: SqlClient, mf_test_session_state: MetricFlowTestSessionState) -> None: # noqa: D - request_id = sql_client.async_execute("SELECT 1 AS foo") - result = sql_client.async_request_result(request_id) - assert result.exception is None - - -def test_isolation_level(mf_test_session_state: MetricFlowTestSessionState, sql_client: SqlClient) -> None: # noqa: D - for isolation_level in sql_client.sql_engine_attributes.supported_isolation_levels: - logger.info(f"Testing isolation level: {isolation_level}") - request_id = sql_client.async_query("SELECT 1", isolation_level=isolation_level) - sql_client.async_request_result(request_id) - - -def test_request_tags( - mf_test_session_state: MetricFlowTestSessionState, - sql_client: SqlClient, -) -> None: - """Test whether request tags are appropriately used in queries to the SQL engine.""" - engine_type = sql_client.sql_engine_attributes.sql_engine_type - extra_tags = SqlJsonTag({"example_key": "example_value"}) - if engine_type is SqlEngine.SNOWFLAKE: - request_id0 = sql_client.async_query( - "SHOW PARAMETERS LIKE 'QUERY_TAG'", - extra_tags=extra_tags, - ) - result0 = sql_client.async_request_result(request_id0) - df = result0.df - assert df is not None - assert result0.exception is None - - assert len(df.index) == 1 - tag_json = json.loads(df.iloc[0]["value"]) - assert MF_EXTRA_TAGS_KEY in tag_json - assert tag_json[MF_EXTRA_TAGS_KEY] == {"example_key": "example_value"} - elif ( - engine_type is SqlEngine.DUCKDB - or engine_type is SqlEngine.BIGQUERY - or engine_type is SqlEngine.REDSHIFT - or engine_type is SqlEngine.DATABRICKS - or engine_type is SqlEngine.POSTGRES - ): - pytest.skip(f"Testing tags not supported in {engine_type}") - else: - assert_values_exhausted(engine_type)