Skip to content

Commit

Permalink
fix(Jinja): Extra cache keys to consider vars with set (#30549)
Browse files Browse the repository at this point in the history
  • Loading branch information
geido authored Oct 9, 2024
1 parent 211564a commit 318eff7
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 75 deletions.
15 changes: 8 additions & 7 deletions superset/jinja_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,14 @@ class ExtraCache:
# Regular expression for detecting the presence of templated methods which could
# be added to the cache key.
regex = re.compile(
r"\{\{.*("
r"current_user_id\(.*\)|"
r"current_username\(.*\)|"
r"current_user_email\(.*\)|"
r"cache_key_wrapper\(.*\)|"
r"url_param\(.*\)"
r").*\}\}"
r"(\{\{|\{%)[^{}]*?("
r"current_user_id\([^()]*\)|"
r"current_username\([^()]*\)|"
r"current_user_email\([^()]*\)|"
r"cache_key_wrapper\([^()]*\)|"
r"url_param\([^()]*\)"
r")"
r"[^{}]*?(\}\}|\%\})"
)

def __init__( # pylint: disable=too-many-arguments
Expand Down
194 changes: 126 additions & 68 deletions tests/integration_tests/sqla_models_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,74 +133,6 @@ def test_db_column_types(self):
col = TableColumn(column_name="foo", type=str_type, table=tbl, is_dttm=True)
assert col.is_temporal

@patch("superset.jinja_context.get_user_id", return_value=1)
@patch("superset.jinja_context.get_username", return_value="abc")
@patch("superset.jinja_context.get_user_email", return_value="abc@test.com")
def test_extra_cache_keys(self, mock_user_email, mock_username, mock_user_id):
base_query_obj = {
"granularity": None,
"from_dttm": None,
"to_dttm": None,
"groupby": ["id", "username", "email"],
"metrics": [],
"is_timeseries": False,
"filter": [],
}

# Table with Jinja callable.
table1 = SqlaTable(
table_name="test_has_extra_cache_keys_table",
sql="""
SELECT
'{{ current_user_id() }}' as id,
'{{ current_username() }}' as username,
'{{ current_user_email() }}' as email
""",
database=get_example_database(),
)

query_obj = dict(**base_query_obj, extras={})
extra_cache_keys = table1.get_extra_cache_keys(query_obj)
assert table1.has_extra_cache_key_calls(query_obj)
assert set(extra_cache_keys) == {1, "abc", "abc@test.com"}

# Table with Jinja callable disabled.
table2 = SqlaTable(
table_name="test_has_extra_cache_keys_disabled_table",
sql="""
SELECT
'{{ current_user_id(False) }}' as id,
'{{ current_username(False) }}' as username,
'{{ current_user_email(False) }}' as email,
""",
database=get_example_database(),
)
query_obj = dict(**base_query_obj, extras={})
extra_cache_keys = table2.get_extra_cache_keys(query_obj)
assert table2.has_extra_cache_key_calls(query_obj)
self.assertListEqual(extra_cache_keys, []) # noqa: PT009

# Table with no Jinja callable.
query = "SELECT 'abc' as user"
table3 = SqlaTable(
table_name="test_has_no_extra_cache_keys_table",
sql=query,
database=get_example_database(),
)

query_obj = dict(**base_query_obj, extras={"where": "(user != 'abc')"})
extra_cache_keys = table3.get_extra_cache_keys(query_obj)
assert not table3.has_extra_cache_key_calls(query_obj)
self.assertListEqual(extra_cache_keys, []) # noqa: PT009

# With Jinja callable in SQL expression.
query_obj = dict(
**base_query_obj, extras={"where": "(user != '{{ current_username() }}')"}
)
extra_cache_keys = table3.get_extra_cache_keys(query_obj)
assert table3.has_extra_cache_key_calls(query_obj)
assert extra_cache_keys == ["abc"]

@patch("superset.jinja_context.get_username", return_value="abc")
def test_jinja_metrics_and_calc_columns(self, mock_username):
base_query_obj = {
Expand Down Expand Up @@ -859,6 +791,132 @@ def test_none_operand_in_filter(login_as_admin, physical_dataset):
)


@pytest.mark.usefixtures("app_context")
@pytest.mark.parametrize(
"table_name,sql,expected_cache_keys,has_extra_cache_keys",
[
(
"test_has_extra_cache_keys_table",
"""
SELECT
'{{ current_user_id() }}' as id,
'{{ current_username() }}' as username,
'{{ current_user_email() }}' as email
""",
{1, "abc", "abc@test.com"},
True,
),
(
"test_has_extra_cache_keys_table_with_set",
"""
{% set user_email = current_user_email() %}
SELECT
'{{ current_user_id() }}' as id,
'{{ current_username() }}' as username,
'{{ user_email }}' as email
""",
{1, "abc", "abc@test.com"},
True,
),
(
"test_has_extra_cache_keys_table_with_se_multiple",
"""
{% set user_conditional_id = current_user_email() and current_user_id() %}
SELECT
'{{ user_conditional_id }}' as conditional
""",
{1, "abc@test.com"},
True,
),
(
"test_has_extra_cache_keys_disabled_table",
"""
SELECT
'{{ current_user_id(False) }}' as id,
'{{ current_username(False) }}' as username,
'{{ current_user_email(False) }}' as email
""",
[],
True,
),
("test_has_no_extra_cache_keys_table", "SELECT 'abc' as user", [], False),
],
)
@patch("superset.jinja_context.get_user_id", return_value=1)
@patch("superset.jinja_context.get_username", return_value="abc")
@patch("superset.jinja_context.get_user_email", return_value="abc@test.com")
def test_extra_cache_keys(
mock_user_email,
mock_username,
mock_user_id,
table_name,
sql,
expected_cache_keys,
has_extra_cache_keys,
):
table = SqlaTable(
table_name=table_name,
sql=sql,
database=get_example_database(),
)
base_query_obj = {
"granularity": None,
"from_dttm": None,
"to_dttm": None,
"groupby": ["id", "username", "email"],
"metrics": [],
"is_timeseries": False,
"filter": [],
}

query_obj = dict(**base_query_obj, extras={})

extra_cache_keys = table.get_extra_cache_keys(query_obj)
assert table.has_extra_cache_key_calls(query_obj) == has_extra_cache_keys
assert set(extra_cache_keys) == set(expected_cache_keys)


@pytest.mark.usefixtures("app_context")
@pytest.mark.parametrize(
"sql_expression,expected_cache_keys,has_extra_cache_keys",
[
("(user != '{{ current_username() }}')", ["abc"], True),
("(user != 'abc')", [], False),
],
)
@patch("superset.jinja_context.get_user_id", return_value=1)
@patch("superset.jinja_context.get_username", return_value="abc")
@patch("superset.jinja_context.get_user_email", return_value="abc@test.com")
def test_extra_cache_keys_in_sql_expression(
mock_user_email,
mock_username,
mock_user_id,
sql_expression,
expected_cache_keys,
has_extra_cache_keys,
):
table = SqlaTable(
table_name="test_has_no_extra_cache_keys_table",
sql="SELECT 'abc' as user",
database=get_example_database(),
)
base_query_obj = {
"granularity": None,
"from_dttm": None,
"to_dttm": None,
"groupby": ["id", "username", "email"],
"metrics": [],
"is_timeseries": False,
"filter": [],
}

query_obj = dict(**base_query_obj, extras={"where": sql_expression})

extra_cache_keys = table.get_extra_cache_keys(query_obj)
assert table.has_extra_cache_key_calls(query_obj) == has_extra_cache_keys
assert extra_cache_keys == expected_cache_keys


@pytest.mark.usefixtures("app_context")
@pytest.mark.parametrize(
"row,dimension,result",
Expand Down

0 comments on commit 318eff7

Please sign in to comment.