diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 43bbcee12563d..4855dd1af3cca 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -95,6 +95,7 @@ DatasetInvalidPermissionEvaluationException, QueryClauseValidationException, QueryObjectValidationError, + SupersetSecurityException, ) from superset.extensions import feature_flag_manager from superset.jinja_context import ( @@ -655,19 +656,19 @@ def _process_sql_expression( expression: Optional[str], database_id: int, schema: str, - template_processor: Optional[BaseTemplateProcessor], + template_processor: Optional[BaseTemplateProcessor] = None, ) -> Optional[str]: if template_processor and expression: expression = template_processor.process_template(expression) if expression: - expression = validate_adhoc_subquery( - expression, - database_id, - schema, - ) try: + expression = validate_adhoc_subquery( + expression, + database_id, + schema, + ) expression = sanitize_clause(expression) - except QueryClauseValidationException as ex: + except (QueryClauseValidationException, SupersetSecurityException) as ex: raise QueryObjectValidationError(ex.message) from ex return expression @@ -1672,6 +1673,11 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma msg=ex.message, ) ) from ex + where = _process_sql_expression( + expression=where, + database_id=self.database_id, + schema=self.schema, + ) where_clause_and += [self.text(where)] having = extras.get("having") if having: @@ -1684,7 +1690,13 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma msg=ex.message, ) ) from ex + having = _process_sql_expression( + expression=having, + database_id=self.database_id, + schema=self.schema, + ) having_clause_and += [self.text(having)] + if apply_fetch_values_predicate and self.fetch_values_predicate: qry = qry.where(self.get_fetch_values_predicate()) if granularity: diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py index acf44be6f561e..8b2fd993886ce 100644 --- a/tests/integration_tests/charts/data/api_tests.py +++ b/tests/integration_tests/charts/data/api_tests.py @@ -1098,3 +1098,53 @@ def test_chart_cache_timeout_chart_not_found( rv = test_client.post(CHART_DATA_URI, json=physical_query_context) assert rv.json["result"][0]["cache_timeout"] == 1010 + + +@pytest.mark.parametrize( + "status_code,extras", + [ + (200, {"where": "1 = 1"}), + (200, {"having": "count(*) > 0"}), + (400, {"where": "col1 in (select distinct col1 from physical_dataset)"}), + (400, {"having": "count(*) > (select count(*) from physical_dataset)"}), + ], +) +@with_feature_flags(ALLOW_ADHOC_SUBQUERY=False) +@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") +def test_chart_data_subquery_not_allowed( + test_client, + login_as_admin, + physical_dataset, + physical_query_context, + status_code, + extras, +): + physical_query_context["queries"][0]["extras"] = extras + rv = test_client.post(CHART_DATA_URI, json=physical_query_context) + + assert rv.status_code == status_code + + +@pytest.mark.parametrize( + "status_code,extras", + [ + (200, {"where": "1 = 1"}), + (200, {"having": "count(*) > 0"}), + (200, {"where": "col1 in (select distinct col1 from physical_dataset)"}), + (200, {"having": "count(*) > (select count(*) from physical_dataset)"}), + ], +) +@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True) +@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") +def test_chart_data_subquery_allowed( + test_client, + login_as_admin, + physical_dataset, + physical_query_context, + status_code, + extras, +): + physical_query_context["queries"][0]["extras"] = extras + rv = test_client.post(CHART_DATA_URI, json=physical_query_context) + + assert rv.status_code == status_code diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py index fc37e17b57f2d..b3ec031fc80de 100644 --- a/tests/integration_tests/sqla_models_tests.py +++ b/tests/integration_tests/sqla_models_tests.py @@ -262,7 +262,7 @@ def test_adhoc_metrics_and_calc_columns(self): ) db.session.commit() - with pytest.raises(SupersetSecurityException): + with pytest.raises(QueryObjectValidationError): table.get_sqla_query(**base_query_obj) # Cleanup db.session.delete(table)