From d7797896523461a2b3d1ca4aa7f48645cd4789e4 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt <33317356+villebro@users.noreply.github.com> Date: Mon, 17 Oct 2022 10:40:42 +0100 Subject: [PATCH] chore(sqla): refactor query utils (#21811) Co-authored-by: Ville Brofeldt --- superset/connectors/sqla/models.py | 26 ++- .../charts/data/api_tests.py | 190 +++++++++++++++++- tests/integration_tests/sqla_models_tests.py | 2 +- 3 files changed, 209 insertions(+), 9 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index efd67cdd9a953..d7a6c206777ca 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -92,6 +92,7 @@ DatasetInvalidPermissionEvaluationException, QueryClauseValidationException, QueryObjectValidationError, + SupersetSecurityException, ) from superset.extensions import feature_flag_manager from superset.jinja_context import ( @@ -647,19 +648,19 @@ def _process_sql_expression( expression: Optional[str], database_id: int, schema: str, - template_processor: Optional[BaseTemplateProcessor], + template_processor: Optional[BaseTemplateProcessor] = None, ) -> Optional[str]: if template_processor and expression: expression = template_processor.process_template(expression) if expression: - expression = validate_adhoc_subquery( - expression, - database_id, - schema, - ) try: + expression = validate_adhoc_subquery( + expression, + database_id, + schema, + ) expression = sanitize_clause(expression) - except QueryClauseValidationException as ex: + except (QueryClauseValidationException, SupersetSecurityException) as ex: raise QueryObjectValidationError(ex.message) from ex return expression @@ -1639,6 +1640,11 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma msg=ex.message, ) ) from ex + where = _process_sql_expression( + expression=where, + database_id=self.database_id, + schema=self.schema, + ) where_clause_and += [self.text(where)] having = extras.get("having") if having: @@ -1651,7 +1657,13 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma msg=ex.message, ) ) from ex + having = _process_sql_expression( + expression=having, + database_id=self.database_id, + schema=self.schema, + ) having_clause_and += [self.text(having)] + if apply_fetch_values_predicate and self.fetch_values_predicate: qry = qry.where(self.get_fetch_values_predicate()) if granularity: diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py index 6bbed00759c6e..b53e3b18d6564 100644 --- a/tests/integration_tests/charts/data/api_tests.py +++ b/tests/integration_tests/charts/data/api_tests.py @@ -21,7 +21,7 @@ import copy from datetime import datetime from io import BytesIO -from typing import Optional +from typing import Any, Dict, List, Optional from unittest import mock from zipfile import ZipFile @@ -963,3 +963,191 @@ def test_chart_data_with_adhoc_column(self): unique_genders = {row["male_or_female"] for row in data} assert unique_genders == {"male", "female"} assert result["applied_filters"] == [{"column": "male_or_female"}] + + +@pytest.fixture() +def physical_query_context(physical_dataset) -> Dict[str, Any]: + return { + "datasource": { + "type": physical_dataset.type, + "id": physical_dataset.id, + }, + "queries": [ + { + "columns": ["col1"], + "metrics": ["count"], + "orderby": [["col1", True]], + } + ], + "result_type": ChartDataResultType.FULL, + "force": True, + } + + +@mock.patch( + "superset.common.query_context_processor.config", + { + **app.config, + "CACHE_DEFAULT_TIMEOUT": 1234, + "DATA_CACHE_CONFIG": { + **app.config["DATA_CACHE_CONFIG"], + "CACHE_DEFAULT_TIMEOUT": None, + }, + }, +) +def test_cache_default_timeout(test_client, login_as_admin, physical_query_context): + rv = test_client.post(CHART_DATA_URI, json=physical_query_context) + assert rv.json["result"][0]["cache_timeout"] == 1234 + + +def test_custom_cache_timeout(test_client, login_as_admin, physical_query_context): + physical_query_context["custom_cache_timeout"] = 5678 + rv = test_client.post(CHART_DATA_URI, json=physical_query_context) + assert rv.json["result"][0]["cache_timeout"] == 5678 + + +@mock.patch( + "superset.common.query_context_processor.config", + { + **app.config, + "CACHE_DEFAULT_TIMEOUT": 100000, + "DATA_CACHE_CONFIG": { + **app.config["DATA_CACHE_CONFIG"], + "CACHE_DEFAULT_TIMEOUT": 3456, + }, + }, +) +def test_data_cache_default_timeout( + test_client, + login_as_admin, + physical_query_context, +): + rv = test_client.post(CHART_DATA_URI, json=physical_query_context) + assert rv.json["result"][0]["cache_timeout"] == 3456 + + +def test_chart_cache_timeout( + test_client, + login_as_admin, + physical_query_context, + load_energy_table_with_slice: List[Slice], +): + # should override datasource cache timeout + + slice_with_cache_timeout = load_energy_table_with_slice[0] + slice_with_cache_timeout.cache_timeout = 20 + db.session.merge(slice_with_cache_timeout) + + datasource: SqlaTable = ( + db.session.query(SqlaTable) + .filter(SqlaTable.id == physical_query_context["datasource"]["id"]) + .first() + ) + datasource.cache_timeout = 1254 + db.session.merge(datasource) + + db.session.commit() + + physical_query_context["form_data"] = {"slice_id": slice_with_cache_timeout.id} + + rv = test_client.post(CHART_DATA_URI, json=physical_query_context) + assert rv.json["result"][0]["cache_timeout"] == 20 + + +@mock.patch( + "superset.common.query_context_processor.config", + { + **app.config, + "DATA_CACHE_CONFIG": { + **app.config["DATA_CACHE_CONFIG"], + "CACHE_DEFAULT_TIMEOUT": 1010, + }, + }, +) +def test_chart_cache_timeout_not_present( + test_client, login_as_admin, physical_query_context +): + # should use datasource cache, if it's present + + datasource: SqlaTable = ( + db.session.query(SqlaTable) + .filter(SqlaTable.id == physical_query_context["datasource"]["id"]) + .first() + ) + datasource.cache_timeout = 1980 + db.session.merge(datasource) + db.session.commit() + + rv = test_client.post(CHART_DATA_URI, json=physical_query_context) + assert rv.json["result"][0]["cache_timeout"] == 1980 + + +@mock.patch( + "superset.common.query_context_processor.config", + { + **app.config, + "DATA_CACHE_CONFIG": { + **app.config["DATA_CACHE_CONFIG"], + "CACHE_DEFAULT_TIMEOUT": 1010, + }, + }, +) +def test_chart_cache_timeout_chart_not_found( + test_client, login_as_admin, physical_query_context +): + # should use default timeout + + physical_query_context["form_data"] = {"slice_id": 0} + + rv = test_client.post(CHART_DATA_URI, json=physical_query_context) + assert rv.json["result"][0]["cache_timeout"] == 1010 + + +@pytest.mark.parametrize( + "status_code,extras", + [ + (200, {"where": "1 = 1"}), + (200, {"having": "count(*) > 0"}), + (400, {"where": "col1 in (select distinct col1 from physical_dataset)"}), + (400, {"having": "count(*) > (select count(*) from physical_dataset)"}), + ], +) +@with_feature_flags(ALLOW_ADHOC_SUBQUERY=False) +@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") +def test_chart_data_subquery_not_allowed( + test_client, + login_as_admin, + physical_dataset, + physical_query_context, + status_code, + extras, +): + physical_query_context["queries"][0]["extras"] = extras + rv = test_client.post(CHART_DATA_URI, json=physical_query_context) + + assert rv.status_code == status_code + + +@pytest.mark.parametrize( + "status_code,extras", + [ + (200, {"where": "1 = 1"}), + (200, {"having": "count(*) > 0"}), + (200, {"where": "col1 in (select distinct col1 from physical_dataset)"}), + (200, {"having": "count(*) > (select count(*) from physical_dataset)"}), + ], +) +@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True) +@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") +def test_chart_data_subquery_allowed( + test_client, + login_as_admin, + physical_dataset, + physical_query_context, + status_code, + extras, +): + physical_query_context["queries"][0]["extras"] = extras + rv = test_client.post(CHART_DATA_URI, json=physical_query_context) + + assert rv.status_code == status_code diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py index 6c5b6736d1a15..5614ad263a4f2 100644 --- a/tests/integration_tests/sqla_models_tests.py +++ b/tests/integration_tests/sqla_models_tests.py @@ -261,7 +261,7 @@ def test_adhoc_metrics_and_calc_columns(self): ) db.session.commit() - with pytest.raises(SupersetSecurityException): + with pytest.raises(QueryObjectValidationError): table.get_sqla_query(**base_query_obj) # Cleanup db.session.delete(table)