Skip to content

Commit

Permalink
fix(Explore): Apply RLS at column values (#30490)
Browse files Browse the repository at this point in the history
Co-authored-by: Beto Dealmeida <roberto@dealmeida.net>
  • Loading branch information
geido and betodealmeida authored Oct 4, 2024
1 parent 0b34197 commit f314685
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 1 deletion.
5 changes: 4 additions & 1 deletion superset/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1309,7 +1309,7 @@ def get_time_filter( # pylint: disable=too-many-arguments
)
return and_(*l)

def values_for_column(
def values_for_column( # pylint: disable=too-many-locals
self,
column_name: str,
limit: int = 10000,
Expand Down Expand Up @@ -1345,6 +1345,9 @@ def values_for_column(
if self.fetch_values_predicate:
qry = qry.where(self.get_fetch_values_predicate(template_processor=tp))

rls_filters = self.get_sqla_row_level_filters(template_processor=tp)
qry = qry.where(and_(*rls_filters))

with self.database.get_sqla_engine() as engine:
sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))
sql = self._apply_cte(sql, cte)
Expand Down
29 changes: 29 additions & 0 deletions tests/integration_tests/datasource/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from unittest.mock import ANY, patch

import pytest
from sqlalchemy.sql.elements import TextClause

from superset import db, security_manager
from superset.connectors.sqla.models import SqlaTable
Expand Down Expand Up @@ -176,3 +177,31 @@ def test_get_column_values_denormalize_column(self, denormalize_name_mock):
table.normalize_columns = False
self.client.get(f"api/v1/datasource/table/{table.id}/column/col2/values/") # noqa: F841
denormalize_name_mock.assert_called_with(ANY, "col2")

@pytest.mark.usefixtures("app_context", "virtual_dataset")
def test_get_column_values_with_rls(self):
self.login(ADMIN_USERNAME)
table = self.get_virtual_dataset()
with patch.object(
table, "get_sqla_row_level_filters", return_value=[TextClause("col2 = 'b'")]
):
rv = self.client.get(
f"api/v1/datasource/table/{table.id}/column/col2/values/"
)
self.assertEqual(rv.status_code, 200)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(response["result"], ["b"])

@pytest.mark.usefixtures("app_context", "virtual_dataset")
def test_get_column_values_with_rls_no_values(self):
self.login(ADMIN_USERNAME)
table = self.get_virtual_dataset()
with patch.object(
table, "get_sqla_row_level_filters", return_value=[TextClause("col2 = 'q'")]
):
rv = self.client.get(
f"api/v1/datasource/table/{table.id}/column/col2/values/"
)
self.assertEqual(rv.status_code, 200)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(response["result"], [])
26 changes: 26 additions & 0 deletions tests/integration_tests/sqla_models_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,32 @@ def test_values_for_column_on_text_column(text_column_table):
assert len(with_null) == 8


def test_values_for_column_on_text_column_with_rls(text_column_table):
with patch.object(
text_column_table,
"get_sqla_row_level_filters",
return_value=[
TextClause("foo = 'foo'"),
],
):
with_rls = text_column_table.values_for_column(column_name="foo", limit=10000)
assert with_rls == ["foo"]
assert len(with_rls) == 1


def test_values_for_column_on_text_column_with_rls_no_values(text_column_table):
with patch.object(
text_column_table,
"get_sqla_row_level_filters",
return_value=[
TextClause("foo = 'bar'"),
],
):
with_rls = text_column_table.values_for_column(column_name="foo", limit=10000)
assert with_rls == []
assert len(with_rls) == 0


def test_filter_on_text_column(text_column_table):
table = text_column_table
# null value should be replaced
Expand Down
53 changes: 53 additions & 0 deletions tests/unit_tests/models/helpers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from contextlib import contextmanager
from typing import TYPE_CHECKING
from unittest.mock import patch

import pytest
from pytest_mock import MockerFixture
Expand Down Expand Up @@ -85,6 +86,58 @@ def test_values_for_column(database: Database) -> None:
assert table.values_for_column("a") == [1, None]


def test_values_for_column_with_rls(database: Database) -> None:
"""
Test the `values_for_column` method with RLS enabled.
"""
from sqlalchemy.sql.elements import TextClause

from superset.connectors.sqla.models import SqlaTable, TableColumn

table = SqlaTable(
database=database,
schema=None,
table_name="t",
columns=[
TableColumn(column_name="a"),
],
)
with patch.object(
table,
"get_sqla_row_level_filters",
return_value=[
TextClause("a = 1"),
],
):
assert table.values_for_column("a") == [1]


def test_values_for_column_with_rls_no_values(database: Database) -> None:
"""
Test the `values_for_column` method with RLS enabled and no values.
"""
from sqlalchemy.sql.elements import TextClause

from superset.connectors.sqla.models import SqlaTable, TableColumn

table = SqlaTable(
database=database,
schema=None,
table_name="t",
columns=[
TableColumn(column_name="a"),
],
)
with patch.object(
table,
"get_sqla_row_level_filters",
return_value=[
TextClause("a = 2"),
],
):
assert table.values_for_column("a") == []


def test_values_for_column_calculated(
mocker: MockerFixture,
database: Database,
Expand Down

0 comments on commit f314685

Please sign in to comment.