chore(sqla): refactor query utils (#21811)
Co-authored-by: Ville Brofeldt <ville.brofeldt@apple.com>
2 people authored and AAfghahi committed Nov 18, 2022
1 parent aa11e54 commit d779789
Showing 3 changed files with 209 additions and 9 deletions.
26 changes: 19 additions & 7 deletions superset/connectors/sqla/models.py
@@ -92,6 +92,7 @@
     DatasetInvalidPermissionEvaluationException,
     QueryClauseValidationException,
     QueryObjectValidationError,
+    SupersetSecurityException,
 )
 from superset.extensions import feature_flag_manager
 from superset.jinja_context import (
@@ -647,19 +648,19 @@ def _process_sql_expression(
     expression: Optional[str],
     database_id: int,
     schema: str,
-    template_processor: Optional[BaseTemplateProcessor],
+    template_processor: Optional[BaseTemplateProcessor] = None,
 ) -> Optional[str]:
     if template_processor and expression:
         expression = template_processor.process_template(expression)
     if expression:
-        expression = validate_adhoc_subquery(
-            expression,
-            database_id,
-            schema,
-        )
         try:
+            expression = validate_adhoc_subquery(
+                expression,
+                database_id,
+                schema,
+            )
             expression = sanitize_clause(expression)
-        except QueryClauseValidationException as ex:
+        except (QueryClauseValidationException, SupersetSecurityException) as ex:
             raise QueryObjectValidationError(ex.message) from ex
     return expression
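
The net effect of the refactor above: subquery-validation failures (SupersetSecurityException) and clause-sanitization failures (QueryClauseValidationException) now surface uniformly as QueryObjectValidationError. A minimal, self-contained sketch of that translation pattern, with simplified stand-ins for Superset's exception classes:

class QueryClauseValidationException(Exception):
    def __init__(self, message: str) -> None:
        super().__init__(message)
        self.message = message


class SupersetSecurityException(Exception):
    def __init__(self, message: str) -> None:
        super().__init__(message)
        self.message = message


class QueryObjectValidationError(Exception):
    pass


def process(expression: str) -> str:
    # Stand-in for validate_adhoc_subquery() + sanitize_clause(): a failure
    # in either step is reported to the caller as the same exception type.
    try:
        if "select" in expression.lower():
            raise SupersetSecurityException("Adhoc subqueries are not allowed")
        return expression.rstrip(";")
    except (QueryClauseValidationException, SupersetSecurityException) as ex:
        raise QueryObjectValidationError(ex.message) from ex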

@@ -1639,6 +1640,11 @@ def get_sqla_query(  # pylint: disable=too-many-arguments,too-many-locals,too-ma
                             msg=ex.message,
                         )
                     ) from ex
+                where = _process_sql_expression(
+                    expression=where,
+                    database_id=self.database_id,
+                    schema=self.schema,
+                )
                 where_clause_and += [self.text(where)]
             having = extras.get("having")
             if having:
@@ -1651,7 +1657,13 @@ def get_sqla_query(  # pylint: disable=too-many-arguments,too-many-locals,too-ma
                             msg=ex.message,
                         )
                     ) from ex
+                having = _process_sql_expression(
+                    expression=having,
+                    database_id=self.database_id,
+                    schema=self.schema,
+                )
                 having_clause_and += [self.text(having)]
+
         if apply_fetch_values_predicate and self.fetch_values_predicate:
             qry = qry.where(self.get_fetch_values_predicate())
         if granularity:
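
Both hunks route the raw WHERE and HAVING extras through the new helper before wrapping them in self.text(...). Note that template_processor is omitted at these call sites because Jinja templating already ran in the preceding try blocks, which is why the parameter gained a None default. A hedged sketch of the shared path (the stand-in helper below only trims the clause; the real one also validates adhoc subqueries and can raise QueryObjectValidationError):

from typing import Optional


def _process_sql_expression_sketch(
    expression: Optional[str], database_id: int, schema: str
) -> Optional[str]:
    # Stand-in for sanitize_clause(); database_id and schema would feed
    # subquery validation in the real helper.
    return expression.strip().rstrip(";") if expression else expression


extras = {"where": "1 = 1;", "having": "count(*) > 0;"}
where = _process_sql_expression_sketch(extras.get("where"), 1, "public")
having = _process_sql_expression_sketch(extras.get("having"), 1, "public")
assert (where, having) == ("1 = 1", "count(*) > 0")
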
190 changes: 189 additions & 1 deletion tests/integration_tests/charts/data/api_tests.py
@@ -21,7 +21,7 @@
 import copy
 from datetime import datetime
 from io import BytesIO
-from typing import Optional
+from typing import Any, Dict, List, Optional
 from unittest import mock
 from zipfile import ZipFile

@@ -963,3 +963,191 @@ def test_chart_data_with_adhoc_column(self):
         unique_genders = {row["male_or_female"] for row in data}
         assert unique_genders == {"male", "female"}
         assert result["applied_filters"] == [{"column": "male_or_female"}]


+@pytest.fixture()
+def physical_query_context(physical_dataset) -> Dict[str, Any]:
+    return {
+        "datasource": {
+            "type": physical_dataset.type,
+            "id": physical_dataset.id,
+        },
+        "queries": [
+            {
+                "columns": ["col1"],
+                "metrics": ["count"],
+                "orderby": [["col1", True]],
+            }
+        ],
+        "result_type": ChartDataResultType.FULL,
+        "force": True,
+    }
+
+
+@mock.patch(
+    "superset.common.query_context_processor.config",
+    {
+        **app.config,
+        "CACHE_DEFAULT_TIMEOUT": 1234,
+        "DATA_CACHE_CONFIG": {
+            **app.config["DATA_CACHE_CONFIG"],
+            "CACHE_DEFAULT_TIMEOUT": None,
+        },
+    },
+)
+def test_cache_default_timeout(test_client, login_as_admin, physical_query_context):
+    rv = test_client.post(CHART_DATA_URI, json=physical_query_context)
+    assert rv.json["result"][0]["cache_timeout"] == 1234
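
The patching pattern used throughout these tests swaps the module-level config mapping for a merged copy for the duration of the test, then restores the original. A standalone illustration of that mechanic (the module below is a stand-in, not Superset's):

from types import SimpleNamespace
from unittest import mock

fake_module = SimpleNamespace(config={"CACHE_DEFAULT_TIMEOUT": 300})

with mock.patch.object(
    fake_module, "config", {**fake_module.config, "CACHE_DEFAULT_TIMEOUT": 1234}
):
    # Inside the patch, code reading fake_module.config sees the override.
    assert fake_module.config["CACHE_DEFAULT_TIMEOUT"] == 1234

# On exit, the original mapping is restored.
assert fake_module.config["CACHE_DEFAULT_TIMEOUT"] == 300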


+def test_custom_cache_timeout(test_client, login_as_admin, physical_query_context):
+    physical_query_context["custom_cache_timeout"] = 5678
+    rv = test_client.post(CHART_DATA_URI, json=physical_query_context)
+    assert rv.json["result"][0]["cache_timeout"] == 5678
+
+
+@mock.patch(
+    "superset.common.query_context_processor.config",
+    {
+        **app.config,
+        "CACHE_DEFAULT_TIMEOUT": 100000,
+        "DATA_CACHE_CONFIG": {
+            **app.config["DATA_CACHE_CONFIG"],
+            "CACHE_DEFAULT_TIMEOUT": 3456,
+        },
+    },
+)
+def test_data_cache_default_timeout(
+    test_client,
+    login_as_admin,
+    physical_query_context,
+):
+    rv = test_client.post(CHART_DATA_URI, json=physical_query_context)
+    assert rv.json["result"][0]["cache_timeout"] == 3456
+
+
+def test_chart_cache_timeout(
+    test_client,
+    login_as_admin,
+    physical_query_context,
+    load_energy_table_with_slice: List[Slice],
+):
+    # should override the datasource cache timeout
+
+    slice_with_cache_timeout = load_energy_table_with_slice[0]
+    slice_with_cache_timeout.cache_timeout = 20
+    db.session.merge(slice_with_cache_timeout)
+
+    datasource: SqlaTable = (
+        db.session.query(SqlaTable)
+        .filter(SqlaTable.id == physical_query_context["datasource"]["id"])
+        .first()
+    )
+    datasource.cache_timeout = 1254
+    db.session.merge(datasource)
+
+    db.session.commit()
+
+    physical_query_context["form_data"] = {"slice_id": slice_with_cache_timeout.id}
+
+    rv = test_client.post(CHART_DATA_URI, json=physical_query_context)
+    assert rv.json["result"][0]["cache_timeout"] == 20
+
+
+@mock.patch(
+    "superset.common.query_context_processor.config",
+    {
+        **app.config,
+        "DATA_CACHE_CONFIG": {
+            **app.config["DATA_CACHE_CONFIG"],
+            "CACHE_DEFAULT_TIMEOUT": 1010,
+        },
+    },
+)
+def test_chart_cache_timeout_not_present(
+    test_client, login_as_admin, physical_query_context
+):
+    # should use the datasource cache timeout, if it's present
+
+    datasource: SqlaTable = (
+        db.session.query(SqlaTable)
+        .filter(SqlaTable.id == physical_query_context["datasource"]["id"])
+        .first()
+    )
+    datasource.cache_timeout = 1980
+    db.session.merge(datasource)
+    db.session.commit()
+
+    rv = test_client.post(CHART_DATA_URI, json=physical_query_context)
+    assert rv.json["result"][0]["cache_timeout"] == 1980
+
+
+@mock.patch(
+    "superset.common.query_context_processor.config",
+    {
+        **app.config,
+        "DATA_CACHE_CONFIG": {
+            **app.config["DATA_CACHE_CONFIG"],
+            "CACHE_DEFAULT_TIMEOUT": 1010,
+        },
+    },
+)
+def test_chart_cache_timeout_chart_not_found(
+    test_client, login_as_admin, physical_query_context
+):
+    # should fall back to the default timeout
+
+    physical_query_context["form_data"] = {"slice_id": 0}
+
+    rv = test_client.post(CHART_DATA_URI, json=physical_query_context)
+    assert rv.json["result"][0]["cache_timeout"] == 1010
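
Taken together, the five cache-timeout tests pin down a precedence chain: the per-request custom_cache_timeout beats the chart's cache_timeout, which beats the datasource's, which beats DATA_CACHE_CONFIG["CACHE_DEFAULT_TIMEOUT"], which beats the global CACHE_DEFAULT_TIMEOUT. A sketch of that resolution order (illustrative names, not Superset's actual API):

from typing import Optional


def resolve_cache_timeout(
    custom_cache_timeout: Optional[int],      # per-request override (5678)
    chart_cache_timeout: Optional[int],       # slice-level setting (20)
    datasource_cache_timeout: Optional[int],  # table-level setting (1980)
    data_cache_default: Optional[int],        # DATA_CACHE_CONFIG default (3456, 1010)
    cache_default: int,                       # global default (1234)
) -> int:
    # The first value that is set wins, mirroring the assertions above.
    for candidate in (
        custom_cache_timeout,
        chart_cache_timeout,
        datasource_cache_timeout,
        data_cache_default,
    ):
        if candidate is not None:
            return candidate
    return cache_default


assert resolve_cache_timeout(None, None, None, None, 1234) == 1234
assert resolve_cache_timeout(5678, 20, 1254, 3456, 100000) == 5678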


+@pytest.mark.parametrize(
+    "status_code,extras",
+    [
+        (200, {"where": "1 = 1"}),
+        (200, {"having": "count(*) > 0"}),
+        (400, {"where": "col1 in (select distinct col1 from physical_dataset)"}),
+        (400, {"having": "count(*) > (select count(*) from physical_dataset)"}),
+    ],
+)
+@with_feature_flags(ALLOW_ADHOC_SUBQUERY=False)
+@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+def test_chart_data_subquery_not_allowed(
+    test_client,
+    login_as_admin,
+    physical_dataset,
+    physical_query_context,
+    status_code,
+    extras,
+):
+    physical_query_context["queries"][0]["extras"] = extras
+    rv = test_client.post(CHART_DATA_URI, json=physical_query_context)
+
+    assert rv.status_code == status_code
+
+
+@pytest.mark.parametrize(
+    "status_code,extras",
+    [
+        (200, {"where": "1 = 1"}),
+        (200, {"having": "count(*) > 0"}),
+        (200, {"where": "col1 in (select distinct col1 from physical_dataset)"}),
+        (200, {"having": "count(*) > (select count(*) from physical_dataset)"}),
+    ],
+)
+@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
+@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+def test_chart_data_subquery_allowed(
+    test_client,
+    login_as_admin,
+    physical_dataset,
+    physical_query_context,
+    status_code,
+    extras,
+):
+    physical_query_context["queries"][0]["extras"] = extras
+    rv = test_client.post(CHART_DATA_URI, json=physical_query_context)
+
+    assert rv.status_code == status_code
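
These two parametrized tests assert that identical payloads flip from 400 to 200 once ALLOW_ADHOC_SUBQUERY is enabled. A deliberately crude sketch of the gate they exercise (the real validate_adhoc_subquery parses the SQL rather than substring-matching):

class SecurityError(Exception):
    """Stand-in for SupersetSecurityException."""


def validate_adhoc_expression(expression: str, allow_adhoc_subquery: bool) -> str:
    # Crude nested-SELECT detection; the real validator parses the statement.
    if not allow_adhoc_subquery and "select" in expression.lower():
        raise SecurityError("Custom SQL fields cannot contain subqueries.")
    return expression


# "1 = 1" passes regardless of the flag; a nested SELECT needs the flag on.
assert validate_adhoc_expression("1 = 1", allow_adhoc_subquery=False) == "1 = 1"
assert validate_adhoc_expression(
    "col1 in (select distinct col1 from physical_dataset)",
    allow_adhoc_subquery=True,
)
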
2 changes: 1 addition & 1 deletion tests/integration_tests/sqla_models_tests.py
@@ -261,7 +261,7 @@ def test_adhoc_metrics_and_calc_columns(self):
         )
         db.session.commit()
 
-        with pytest.raises(SupersetSecurityException):
+        with pytest.raises(QueryObjectValidationError):
             table.get_sqla_query(**base_query_obj)
         # Cleanup
         db.session.delete(table)
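
The expected exception changes here because _process_sql_expression (first file above) now catches SupersetSecurityException during get_sqla_query's clause handling and re-raises it as QueryObjectValidationError, so that is what callers observe.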
