diff --git a/superset-frontend/src/SqlLab/components/TableElement/TableElement.test.tsx b/superset-frontend/src/SqlLab/components/TableElement/TableElement.test.tsx
index 7b95b6d0f3492..368c098bad406 100644
--- a/superset-frontend/src/SqlLab/components/TableElement/TableElement.test.tsx
+++ b/superset-frontend/src/SqlLab/components/TableElement/TableElement.test.tsx
@@ -47,9 +47,9 @@ jest.mock(
       {column.name}
     ),
 );
-const getTableMetadataEndpoint = 'glob:**/api/v1/database/*/table_metadata/*';
+const getTableMetadataEndpoint = /\/api\/v1\/database\/\d+\/table_metadata\/(?:\?.*)?$/;
 const getExtraTableMetadataEndpoint =
-  'glob:**/api/v1/database/*/table_metadata/extra/*';
+  /\/api\/v1\/database\/\d+\/table_metadata\/extra\/(?:\?.*)?$/;
 const updateTableSchemaEndpoint = 'glob:*/tableschemaview/*/expanded';
 
 beforeEach(() => {
diff --git a/superset-frontend/src/features/datasets/AddDataset/DatasetPanel/index.tsx b/superset-frontend/src/features/datasets/AddDataset/DatasetPanel/index.tsx
index c964fc32faaf0..b3f8aec8f99a0 100644
--- a/superset-frontend/src/features/datasets/AddDataset/DatasetPanel/index.tsx
+++ b/superset-frontend/src/features/datasets/AddDataset/DatasetPanel/index.tsx
@@ -75,7 +75,7 @@ const DatasetPanelWrapper = ({
     setLoading(true);
     setHasColumns?.(false);
     const path = schema
-      ? `/api/v1/database/${dbId}/table_metadata/?name=${tableName}&schema=${schema}/`
+      ? `/api/v1/database/${dbId}/table_metadata/?name=${tableName}&schema=${schema}`
       : `/api/v1/database/${dbId}/table_metadata/?name=${tableName}`;
     try {
       const response = await SupersetClient.get({
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index d3a681eb1f387..524ba71210b66 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -331,7 +331,7 @@ def short_data(self) -> dict[str, Any]:
             "edit_url": self.url,
             "id": self.id,
             "uid": self.uid,
-            "schema": self.schema,
+            "schema": self.schema or None,
             "name": self.name,
             "type": self.type,
             "connection": self.connection,
@@ -385,7 +385,7 @@ def data(self) -> dict[str, Any]:
             "datasource_name": self.datasource_name,
             "table_name": self.datasource_name,
             "type": self.type,
-            "schema": self.schema,
+            "schema": self.schema or None,
             "offset": self.offset,
             "cache_timeout": self.cache_timeout,
             "params": self.params,
@@ -1266,7 +1266,7 @@ def link(self) -> Markup:
 
     def get_schema_perm(self) -> str | None:
         """Returns schema permission if present, database one otherwise."""
-        return security_manager.get_schema_perm(self.database, self.schema)
+        return security_manager.get_schema_perm(self.database, self.schema or None)
 
     def get_perm(self) -> str:
         """
@@ -1323,7 +1323,7 @@ def external_metadata(self) -> list[ResultSetColumnType]:
             return get_virtual_table_metadata(dataset=self)
         return get_physical_table_metadata(
             database=self.database,
-            table=Table(self.table_name, self.schema, self.catalog),
+            table=Table(self.table_name, self.schema or None, self.catalog),
             normalize_columns=self.normalize_columns,
         )
 
@@ -1339,7 +1339,7 @@ def select_star(self) -> str | None:
         # show_cols and latest_partition set to false to avoid
         # the expensive cost of inspecting the DB
         return self.database.select_star(
-            Table(self.table_name, self.schema, self.catalog),
+            Table(self.table_name, self.schema or None, self.catalog),
             show_cols=False,
             latest_partition=False,
         )
@@ -1545,7 +1545,7 @@ def adhoc_column_to_sqla(  # pylint: disable=too-many-locals
             col_desc = get_columns_description(
                 self.database,
                 self.catalog,
-                self.schema,
+                self.schema or None,
                 sql,
             )
             if not col_desc:
@@ -1752,7 +1752,9 @@ def assign_column_label(df: pd.DataFrame) -> pd.DataFrame | None:
             return df
 
         try:
-            df = self.database.get_df(sql, self.schema, mutator=assign_column_label)
+            df = self.database.get_df(
+                sql, self.schema or None, mutator=assign_column_label
+            )
         except (SupersetErrorException, SupersetErrorsException) as ex:
             # SupersetError(s) exception should not be captured; instead, they should
             # bubble up to the Flask error handler so they are returned as proper SIP-40
@@ -1789,7 +1791,7 @@ def get_sqla_table_object(self) -> Table:
         return self.database.get_table(
             Table(
                 self.table_name,
-                self.schema,
+                self.schema or None,
                 self.catalog,
             )
         )
@@ -1807,7 +1809,7 @@ def fetch_metadata(self, commit: bool = True) -> MetadataResult:
             for metric in self.database.get_metrics(
                 Table(
                     self.table_name,
-                    self.schema,
+                    self.schema or None,
                     self.catalog,
                 )
             )
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index c73e46f33c6f9..1df52fbe27b57 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -640,11 +640,11 @@ def supports_backend(cls, backend: str, driver: str | None = None) -> bool:
         return driver in cls.drivers
 
     @classmethod
-    def get_default_schema(cls, database: Database) -> str | None:
+    def get_default_schema(cls, database: Database, catalog: str | None) -> str | None:
         """
         Return the default schema in a given database.
         """
-        with database.get_inspector() as inspector:
+        with database.get_inspector(catalog=catalog) as inspector:
             return inspector.default_schema_name
 
     @classmethod
@@ -699,7 +699,7 @@ def get_default_schema_for_query(
             return schema
 
         # return the default schema of the database
-        return cls.get_default_schema(database)
+        return cls.get_default_schema(database, query.catalog)
 
     @classmethod
     def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]:
diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py
index 3245cdca4b73b..08a38894e6645 100644
--- a/superset/db_engine_specs/trino.py
+++ b/superset/db_engine_specs/trino.py
@@ -40,12 +40,12 @@
 )
 from superset.db_engine_specs.presto import PrestoBaseEngineSpec
 from superset.models.sql_lab import Query
+from superset.sql_parse import Table
 from superset.superset_typing import ResultSetColumnType
 from superset.utils import core as utils
 
 if TYPE_CHECKING:
     from superset.models.core import Database
-    from superset.sql_parse import Table
 
 with contextlib.suppress(ImportError):  # trino may not be installed
     from trino.dbapi import Cursor
@@ -96,8 +96,11 @@ def get_extra_table_metadata(
             ),
         }
 
-        if database.has_view_by_name(table.table, table.schema):
-            with database.get_inspector() as inspector:
+        if database.has_view(Table(table.table, table.schema)):
+            with database.get_inspector(
+                catalog=table.catalog,
+                schema=table.schema,
+            ) as inspector:
                 metadata["view"] = inspector.get_view_definition(
                     table.table,
                     table.schema,
diff --git a/tests/integration_tests/databases/commands/csv_upload_test.py b/tests/integration_tests/databases/commands/csv_upload_test.py
index 18cc6f4a8da78..99182ff29740d 100644
--- a/tests/integration_tests/databases/commands/csv_upload_test.py
+++ b/tests/integration_tests/databases/commands/csv_upload_test.py
@@ -85,7 +85,7 @@ def _setup_csv_upload(allowed_schemas: list[str] | None = None):
     yield
 
     upload_db = get_upload_db()
-    with upload_db.get_sqla_engine_with_context() as engine:
+    with upload_db.get_sqla_engine() as engine:
         engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE}")
         engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE_W_SCHEMA}")
         db.session.delete(upload_db)
@@ -221,7 +221,7 @@ def test_csv_upload_options(csv_data, options, table_data):
         create_csv_file(csv_data),
         options=options,
     ).run()
-    with upload_database.get_sqla_engine_with_context() as engine:
+    with upload_database.get_sqla_engine() as engine:
         data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE}").fetchall()
         assert data == table_data
diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py
index 142b7c085eca3..0465503a86a0f 100644
--- a/tests/integration_tests/datasets/api_tests.py
+++ b/tests/integration_tests/datasets/api_tests.py
@@ -773,13 +773,13 @@ def test_create_dataset_validate_tables_exists(self):
 
     @patch("superset.models.core.Database.get_columns")
     @patch("superset.models.core.Database.has_table")
-    @patch("superset.models.core.Database.has_view_by_name")
+    @patch("superset.models.core.Database.has_view")
     @patch("superset.models.core.Database.get_table")
     def test_create_dataset_validate_view_exists(
         self,
         mock_get_table,
         mock_has_table,
-        mock_has_view_by_name,
+        mock_has_view,
         mock_get_columns,
     ):
         """
@@ -796,7 +796,7 @@ def test_create_dataset_validate_view_exists(
         ]
 
         mock_has_table.return_value = False
-        mock_has_view_by_name.return_value = True
+        mock_has_view.return_value = True
         mock_get_table.return_value = None
 
         example_db = get_example_database()
diff --git a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py
index ababce38e5c69..47d649a32eda3 100644
--- a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py
+++ b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py
@@ -30,7 +30,7 @@
 from superset.db_engine_specs.mysql import MySQLEngineSpec
 from superset.db_engine_specs.sqlite import SqliteEngineSpec
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
-from superset.sql_parse import ParsedQuery
+from superset.sql_parse import ParsedQuery, Table
 from superset.utils.database import get_example_database
 from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
 from tests.integration_tests.test_app import app
@@ -238,7 +238,7 @@ def test_get_table_names(self):
     @pytest.mark.usefixtures("load_energy_table_with_slice")
     def test_column_datatype_to_string(self):
         example_db = get_example_database()
-        sqla_table = example_db.get_table("energy_usage")
+        sqla_table = example_db.get_table(Table("energy_usage"))
         dialect = example_db.get_dialect()
 
         # TODO: fix column type conversion for presto.
@@ -540,8 +540,7 @@ def test_get_indexes():
         BaseEngineSpec.get_indexes(
             database=mock.Mock(),
             inspector=inspector,
-            table_name="bar",
-            schema="foo",
+            table=Table("bar", "foo"),
         )
         == indexes
     )
diff --git a/tests/integration_tests/db_engine_specs/bigquery_tests.py b/tests/integration_tests/db_engine_specs/bigquery_tests.py
index c012376b3999e..dc5264aa791b6 100644
--- a/tests/integration_tests/db_engine_specs/bigquery_tests.py
+++ b/tests/integration_tests/db_engine_specs/bigquery_tests.py
@@ -165,8 +165,7 @@ def test_get_indexes(self):
             BigQueryEngineSpec.get_indexes(
                 database,
                 inspector,
-                table_name,
-                schema,
+                Table(table_name, schema),
             )
             == []
         )
@@ -184,8 +183,7 @@ def test_get_indexes(self):
         assert BigQueryEngineSpec.get_indexes(
             database,
             inspector,
-            table_name,
-            schema,
+            Table(table_name, schema),
         ) == [
             {
                 "name": "partition",
@@ -207,8 +205,7 @@ def test_get_indexes(self):
         assert BigQueryEngineSpec.get_indexes(
             database,
             inspector,
-            table_name,
-            schema,
+            Table(table_name, schema),
         ) == [
             {
                 "name": "partition",
diff --git a/tests/integration_tests/db_engine_specs/hive_tests.py b/tests/integration_tests/db_engine_specs/hive_tests.py
index d4b2e14d5820f..ff2a9050d5a68 100644
--- a/tests/integration_tests/db_engine_specs/hive_tests.py
+++ b/tests/integration_tests/db_engine_specs/hive_tests.py
@@ -23,7 +23,7 @@
 from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3
 from superset.exceptions import SupersetException
-from superset.sql_parse import Table, ParsedQuery
+from superset.sql_parse import ParsedQuery, Table
 from tests.integration_tests.test_app import app
@@ -344,7 +344,10 @@ def test_where_latest_partition(mock_method):
     columns = [{"name": "ds"}, {"name": "hour"}]
     with app.app_context():
         result = HiveEngineSpec.where_latest_partition(
-            "test_table", "test_schema", database, select(), columns
+            database,
+            Table("test_table", "test_schema"),
+            select(),
+            columns,
         )
     query_result = str(result.compile(compile_kwargs={"literal_binds": True}))
     assert "SELECT \nWHERE ds = '01-01-19' AND hour = 1" == query_result
@@ -357,7 +360,10 @@ def test_where_latest_partition_super_method_exception(mock_method):
     columns = [{"name": "ds"}, {"name": "hour"}]
     with app.app_context():
         result = HiveEngineSpec.where_latest_partition(
-            "test_table", "test_schema", database, select(), columns
+            database,
+            Table("test_table", "test_schema"),
+            select(),
+            columns,
         )
     assert result is None
     mock_method.assert_called()
@@ -369,7 +375,9 @@ def test_where_latest_partition_no_columns_no_values(mock_method):
     db = mock.Mock()
     with app.app_context():
         result = HiveEngineSpec.where_latest_partition(
-            "test_table", "test_schema", db, select()
+            db,
+            Table("test_table", "test_schema"),
+            select(),
         )
     assert result is None
diff --git a/tests/integration_tests/db_engine_specs/presto_tests.py b/tests/integration_tests/db_engine_specs/presto_tests.py
index 2bbd911668efb..8fb751620ae24 100644
--- a/tests/integration_tests/db_engine_specs/presto_tests.py
+++ b/tests/integration_tests/db_engine_specs/presto_tests.py
@@ -82,7 +82,7 @@ def verify_presto_column(self, column, expected_results):
         row = mock.Mock()
         row.Column, row.Type, row.Null = column
         inspector.bind.execute.return_value.fetchall = mock.Mock(return_value=[row])
-        results = PrestoEngineSpec.get_columns(inspector, "", "")
+        results = PrestoEngineSpec.get_columns(inspector, Table("", ""))
         self.assertEqual(len(expected_results), len(results))
         for expected_result, result in zip(expected_results, results):
             self.assertEqual(expected_result[0], result["column_name"])
@@ -573,7 +573,10 @@ def test_presto_where_latest_partition(self):
         db.get_df = mock.Mock(return_value=df)
         columns = [{"name": "ds"}, {"name": "hour"}]
         result = PrestoEngineSpec.where_latest_partition(
-            "test_table", "test_schema", db, select(), columns
+            db,
+            Table("test_table", "test_schema"),
+            select(),
+            columns,
         )
         query_result = str(result.compile(compile_kwargs={"literal_binds": True}))
         self.assertEqual("SELECT \nWHERE ds = '01-01-19' AND hour = 1", query_result)
@@ -802,7 +805,7 @@ def test_show_columns(self):
             return_value=["a", "b"]
         )
         table_name = "table_name"
-        result = PrestoEngineSpec._show_columns(inspector, table_name, None)
+        result = PrestoEngineSpec._show_columns(inspector, Table(table_name))
         assert result == ["a", "b"]
         inspector.bind.execute.assert_called_once_with(
             f'SHOW COLUMNS FROM "{table_name}"'
@@ -818,7 +821,7 @@ def test_show_columns_with_schema(self):
         )
         table_name = "table_name"
         schema = "schema"
-        result = PrestoEngineSpec._show_columns(inspector, table_name, schema)
+        result = PrestoEngineSpec._show_columns(inspector, Table(table_name, schema))
         assert result == ["a", "b"]
         inspector.bind.execute.assert_called_once_with(
             f'SHOW COLUMNS FROM "{schema}"."{table_name}"'
@@ -848,7 +851,14 @@ def test_select_star_no_presto_expand_data(self, mock_select_star):
         ]
         PrestoEngineSpec.select_star(database, Table(table_name), engine, cols=cols)
         mock_select_star.assert_called_once_with(
-            database, table_name, engine, None, 100, False, True, True, cols
+            database,
+            Table(table_name),
+            engine,
+            100,
+            False,
+            True,
+            True,
+            cols,
         )
 
     @mock.patch("superset.db_engine_specs.presto.is_feature_enabled")
@@ -877,9 +887,8 @@ def test_select_star_presto_expand_data(
         )
         mock_select_star.assert_called_once_with(
             database,
-            table_name,
+            Table(table_name),
             engine,
-            None,
             100,
             True,
             True,
diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py
index 0a1902e78730f..1036e5a0154fe 100644
--- a/tests/unit_tests/db_engine_specs/test_trino.py
+++ b/tests/unit_tests/db_engine_specs/test_trino.py
@@ -311,15 +311,15 @@ def test_convert_dttm(
     assert_convert_dttm(TrinoEngineSpec, target_type, expected_result, dttm)
 
 
-def test_get_extra_table_metadata() -> None:
+def test_get_extra_table_metadata(mocker: MockerFixture) -> None:
     from superset.db_engine_specs.trino import TrinoEngineSpec
 
-    db_mock = Mock()
+    db_mock = mocker.MagicMock()
     db_mock.get_indexes = Mock(
         return_value=[{"column_names": ["ds", "hour"], "name": "partition"}]
     )
     db_mock.get_extra = Mock(return_value={})
-    db_mock.has_view_by_name = Mock(return_value=None)
+    db_mock.has_view = Mock(return_value=None)
    db_mock.get_df = Mock(return_value=pd.DataFrame({"ds": ["01-01-19"], "hour": [1]}))
     result = TrinoEngineSpec.get_extra_table_metadata(
         db_mock,