From 3cf1ac6fcfb25d6ec2f29f3927a971d0d41c5db4 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Thu, 18 Apr 2024 14:43:24 -0400 Subject: [PATCH] Small fixes --- .../TableElement/TableElement.test.tsx | 5 ++-- .../AddDataset/DatasetPanel/index.tsx | 2 +- superset/connectors/sqla/models.py | 20 ++++++++-------- superset/db_engine_specs/base.py | 6 ++--- superset/db_engine_specs/db2.py | 3 +-- superset/db_engine_specs/presto.py | 6 ++--- superset/db_engine_specs/trino.py | 9 +++++--- ...1_15-41_5f57af97bc3f_add_catalog_column.py | 4 ++-- superset/models/core.py | 4 ++-- .../databases/commands/upload_test.py | 4 ++-- tests/integration_tests/datasets/api_tests.py | 6 ++--- .../db_engine_specs/base_engine_spec_tests.py | 7 +++--- .../db_engine_specs/bigquery_tests.py | 9 +++----- .../db_engine_specs/hive_tests.py | 16 +++++++++---- .../db_engine_specs/presto_tests.py | 23 +++++++++++++------ tests/unit_tests/db_engine_specs/test_db2.py | 5 ++-- .../unit_tests/db_engine_specs/test_trino.py | 6 ++--- 17 files changed, 76 insertions(+), 59 deletions(-) diff --git a/superset-frontend/src/SqlLab/components/TableElement/TableElement.test.tsx b/superset-frontend/src/SqlLab/components/TableElement/TableElement.test.tsx index 7b95b6d0f3492..1489f23a13a06 100644 --- a/superset-frontend/src/SqlLab/components/TableElement/TableElement.test.tsx +++ b/superset-frontend/src/SqlLab/components/TableElement/TableElement.test.tsx @@ -47,9 +47,10 @@ jest.mock(
{column.name}
), ); -const getTableMetadataEndpoint = 'glob:**/api/v1/database/*/table_metadata/*'; +const getTableMetadataEndpoint = + /\/api\/v1\/database\/\d+\/table_metadata\/(?:\?.*)?$/; const getExtraTableMetadataEndpoint = - 'glob:**/api/v1/database/*/table_metadata/extra/*'; + /\/api\/v1\/database\/\d+\/table_metadata\/extra\/(?:\?.*)?$/; const updateTableSchemaEndpoint = 'glob:*/tableschemaview/*/expanded'; beforeEach(() => { diff --git a/superset-frontend/src/features/datasets/AddDataset/DatasetPanel/index.tsx b/superset-frontend/src/features/datasets/AddDataset/DatasetPanel/index.tsx index c964fc32faaf0..b3f8aec8f99a0 100644 --- a/superset-frontend/src/features/datasets/AddDataset/DatasetPanel/index.tsx +++ b/superset-frontend/src/features/datasets/AddDataset/DatasetPanel/index.tsx @@ -75,7 +75,7 @@ const DatasetPanelWrapper = ({ setLoading(true); setHasColumns?.(false); const path = schema - ? `/api/v1/database/${dbId}/table_metadata/?name=${tableName}&schema=${schema}/` + ? `/api/v1/database/${dbId}/table_metadata/?name=${tableName}&schema=${schema}` : `/api/v1/database/${dbId}/table_metadata/?name=${tableName}`; try { const response = await SupersetClient.get({ diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index e48f0b9bd8eca..719d5af588852 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -329,7 +329,7 @@ def short_data(self) -> dict[str, Any]: "edit_url": self.url, "id": self.id, "uid": self.uid, - "schema": self.schema, + "schema": self.schema or None, "name": self.name, "type": self.type, "connection": self.connection, @@ -383,7 +383,7 @@ def data(self) -> dict[str, Any]: "datasource_name": self.datasource_name, "table_name": self.datasource_name, "type": self.type, - "schema": self.schema, + "schema": self.schema or None, "offset": self.offset, "cache_timeout": self.cache_timeout, "params": self.params, @@ -1263,7 +1263,7 @@ def link(self) -> Markup: def get_schema_perm(self) -> str | None: """Returns schema permission if present, database one otherwise.""" - return security_manager.get_schema_perm(self.database, self.schema) + return security_manager.get_schema_perm(self.database, self.schema or None) def get_perm(self) -> str: """ @@ -1320,7 +1320,7 @@ def external_metadata(self) -> list[ResultSetColumnType]: return get_virtual_table_metadata(dataset=self) return get_physical_table_metadata( database=self.database, - table=Table(self.table_name, self.schema, self.catalog), + table=Table(self.table_name, self.schema or None, self.catalog), normalize_columns=self.normalize_columns, ) @@ -1336,7 +1336,7 @@ def select_star(self) -> str | None: # show_cols and latest_partition set to false to avoid # the expensive cost of inspecting the DB return self.database.select_star( - Table(self.table_name, self.schema, self.catalog), + Table(self.table_name, self.schema or None, self.catalog), show_cols=False, latest_partition=False, ) @@ -1528,7 +1528,7 @@ def adhoc_column_to_sqla( # pylint: disable=too-many-locals col_desc = get_columns_description( self.database, self.catalog, - self.schema, + self.schema or None, sql, ) if not col_desc: @@ -1735,7 +1735,9 @@ def assign_column_label(df: pd.DataFrame) -> pd.DataFrame | None: return df try: - df = self.database.get_df(sql, self.schema, mutator=assign_column_label) + df = self.database.get_df( + sql, self.schema or None, mutator=assign_column_label + ) except (SupersetErrorException, SupersetErrorsException) as ex: # SupersetError(s) exception should not be captured; instead, they should # bubble up to the Flask error handler so they are returned as proper SIP-40 @@ -1772,7 +1774,7 @@ def get_sqla_table_object(self) -> Table: return self.database.get_table( Table( self.table_name, - self.schema, + self.schema or None, self.catalog, ) ) @@ -1790,7 +1792,7 @@ def fetch_metadata(self, commit: bool = True) -> MetadataResult: for metric in self.database.get_metrics( Table( self.table_name, - self.schema, + self.schema or None, self.catalog, ) ) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index f458b62165143..3cc1315129571 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -639,11 +639,11 @@ def supports_backend(cls, backend: str, driver: str | None = None) -> bool: return driver in cls.drivers @classmethod - def get_default_schema(cls, database: Database) -> str | None: + def get_default_schema(cls, database: Database, catalog: str | None) -> str | None: """ Return the default schema in a given database. """ - with database.get_inspector() as inspector: + with database.get_inspector(catalog=catalog) as inspector: return inspector.default_schema_name @classmethod @@ -698,7 +698,7 @@ def get_default_schema_for_query( return schema # return the default schema of the database - return cls.get_default_schema(database) + return cls.get_default_schema(database, query.catalog) @classmethod def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]: diff --git a/superset/db_engine_specs/db2.py b/superset/db_engine_specs/db2.py index 8a04ee5d3b0f2..b2151767d2d72 100644 --- a/superset/db_engine_specs/db2.py +++ b/superset/db_engine_specs/db2.py @@ -75,8 +75,7 @@ def get_table_comment( Ibm Db2 return comments as tuples, so we need to get the first element :param inspector: SqlAlchemy Inspector instance - :param table_name: Table name - :param schema: Schema name. If omitted, uses default schema for database + :param table: Table instance :return: comment of table """ comment = None diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 4b7efcfe4921d..34c47eb522c00 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -969,8 +969,7 @@ def _show_columns( """ Show presto column names :param inspector: object that performs database schema inspection - :param table_name: table name - :param schema: schema name + :param table: table instance :return: list of column objects """ quote = inspector.engine.dialect.identifier_preparer.quote_identifier @@ -990,8 +989,7 @@ def get_columns( Get columns from a Presto data source. This includes handling row and array data types :param inspector: object that performs database schema inspection - :param table_name: table name - :param schema: schema name + :param table: table instance :param options: Extra configuration options, not used by this backend :return: a list of results that contain column info (i.e. column name and data type) diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 3245cdca4b73b..08a38894e6645 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -40,12 +40,12 @@ ) from superset.db_engine_specs.presto import PrestoBaseEngineSpec from superset.models.sql_lab import Query +from superset.sql_parse import Table from superset.superset_typing import ResultSetColumnType from superset.utils import core as utils if TYPE_CHECKING: from superset.models.core import Database - from superset.sql_parse import Table with contextlib.suppress(ImportError): # trino may not be installed from trino.dbapi import Cursor @@ -96,8 +96,11 @@ def get_extra_table_metadata( ), } - if database.has_view_by_name(table.table, table.schema): - with database.get_inspector() as inspector: + if database.has_view(Table(table.table, table.schema)): + with database.get_inspector( + catalog=table.catalog, + schema=table.schema, + ) as inspector: metadata["view"] = inspector.get_view_definition( table.table, table.schema, diff --git a/superset/migrations/versions/2024-04-11_15-41_5f57af97bc3f_add_catalog_column.py b/superset/migrations/versions/2024-04-11_15-41_5f57af97bc3f_add_catalog_column.py index 5fa35fb963b2e..ec5733e151044 100644 --- a/superset/migrations/versions/2024-04-11_15-41_5f57af97bc3f_add_catalog_column.py +++ b/superset/migrations/versions/2024-04-11_15-41_5f57af97bc3f_add_catalog_column.py @@ -17,7 +17,7 @@ """Add catalog column Revision ID: 5f57af97bc3f -Revises: 5ad7321c2169 +Revises: d60591c5515f Create Date: 2024-04-11 15:41:34.663989 """ @@ -27,7 +27,7 @@ # revision identifiers, used by Alembic. revision = "5f57af97bc3f" -down_revision = "5ad7321c2169" +down_revision = "d60591c5515f" def upgrade(): diff --git a/superset/models/core.py b/superset/models/core.py index 509e29ff475b4..9a4a1de40376c 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -# pylint: disable=line-too-long, too-many-lines, too-many-arguments +# pylint: disable=too-many-lines, too-many-arguments """A collection of ORM sqlalchemy models for Superset""" @@ -607,7 +607,7 @@ def mutate_sql_based_on_config(self, sql_: str, is_split: bool = False) -> str: ) return sql_ - def get_df( + def get_df( # pylint: disable=too-many-locals self, sql: str, catalog: str | None = None, diff --git a/tests/integration_tests/databases/commands/upload_test.py b/tests/integration_tests/databases/commands/upload_test.py index 26379aa9769fb..1af85c3ab1fe1 100644 --- a/tests/integration_tests/databases/commands/upload_test.py +++ b/tests/integration_tests/databases/commands/upload_test.py @@ -73,7 +73,7 @@ def _setup_csv_upload(allowed_schemas: list[str] | None = None): yield upload_db = get_upload_db() - with upload_db.get_sqla_engine_with_context() as engine: + with upload_db.get_sqla_engine() as engine: engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE}") engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE_W_SCHEMA}") db.session.delete(upload_db) @@ -107,7 +107,7 @@ def test_csv_upload_with_nulls(): None, CSVReader({"null_values": ["N/A", "None"]}), ).run() - with upload_database.get_sqla_engine_with_context() as engine: + with upload_database.get_sqla_engine() as engine: data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE}").fetchall() assert data == [ ("name1", None, "city1", "1-1-1980"), diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 359547118e9b0..c10d589d97fd9 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -774,13 +774,13 @@ def test_create_dataset_validate_tables_exists(self): @patch("superset.models.core.Database.get_columns") @patch("superset.models.core.Database.has_table") - @patch("superset.models.core.Database.has_view_by_name") + @patch("superset.models.core.Database.has_view") @patch("superset.models.core.Database.get_table") def test_create_dataset_validate_view_exists( self, mock_get_table, mock_has_table, - mock_has_view_by_name, + mock_has_view, mock_get_columns, ): """ @@ -797,7 +797,7 @@ def test_create_dataset_validate_view_exists( ] mock_has_table.return_value = False - mock_has_view_by_name.return_value = True + mock_has_view.return_value = True mock_get_table.return_value = None example_db = get_example_database() diff --git a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py index d7498dc4fee84..c8db1f912ad21 100644 --- a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py +++ b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py @@ -30,7 +30,7 @@ from superset.db_engine_specs.mysql import MySQLEngineSpec from superset.db_engine_specs.sqlite import SqliteEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType -from superset.sql_parse import ParsedQuery +from superset.sql_parse import ParsedQuery, Table from superset.utils.database import get_example_database from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec from tests.integration_tests.test_app import app @@ -238,7 +238,7 @@ def test_get_table_names(self): @pytest.mark.usefixtures("load_energy_table_with_slice") def test_column_datatype_to_string(self): example_db = get_example_database() - sqla_table = example_db.get_table("energy_usage") + sqla_table = example_db.get_table(Table("energy_usage")) dialect = example_db.get_dialect() # TODO: fix column type conversion for presto. @@ -540,8 +540,7 @@ def test_get_indexes(): BaseEngineSpec.get_indexes( database=mock.Mock(), inspector=inspector, - table_name="bar", - schema="foo", + table=Table("bar", "foo"), ) == indexes ) diff --git a/tests/integration_tests/db_engine_specs/bigquery_tests.py b/tests/integration_tests/db_engine_specs/bigquery_tests.py index ce184685db540..53f9137076bb8 100644 --- a/tests/integration_tests/db_engine_specs/bigquery_tests.py +++ b/tests/integration_tests/db_engine_specs/bigquery_tests.py @@ -165,8 +165,7 @@ def test_get_indexes(self): BigQueryEngineSpec.get_indexes( database, inspector, - table_name, - schema, + Table(table_name, schema), ) == [] ) @@ -184,8 +183,7 @@ def test_get_indexes(self): assert BigQueryEngineSpec.get_indexes( database, inspector, - table_name, - schema, + Table(table_name, schema), ) == [ { "name": "partition", @@ -207,8 +205,7 @@ def test_get_indexes(self): assert BigQueryEngineSpec.get_indexes( database, inspector, - table_name, - schema, + Table(table_name, schema), ) == [ { "name": "partition", diff --git a/tests/integration_tests/db_engine_specs/hive_tests.py b/tests/integration_tests/db_engine_specs/hive_tests.py index 39d2c30fd1162..4d1a84508167b 100644 --- a/tests/integration_tests/db_engine_specs/hive_tests.py +++ b/tests/integration_tests/db_engine_specs/hive_tests.py @@ -23,7 +23,7 @@ from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3 from superset.exceptions import SupersetException -from superset.sql_parse import Table, ParsedQuery +from superset.sql_parse import ParsedQuery, Table from tests.integration_tests.test_app import app @@ -328,7 +328,10 @@ def test_where_latest_partition(mock_method): columns = [{"name": "ds"}, {"name": "hour"}] with app.app_context(): result = HiveEngineSpec.where_latest_partition( - "test_table", "test_schema", database, select(), columns + database, + Table("test_table", "test_schema"), + select(), + columns, ) query_result = str(result.compile(compile_kwargs={"literal_binds": True})) assert "SELECT \nWHERE ds = '01-01-19' AND hour = 1" == query_result @@ -341,7 +344,10 @@ def test_where_latest_partition_super_method_exception(mock_method): columns = [{"name": "ds"}, {"name": "hour"}] with app.app_context(): result = HiveEngineSpec.where_latest_partition( - "test_table", "test_schema", database, select(), columns + database, + Table("test_table", "test_schema"), + select(), + columns, ) assert result is None mock_method.assert_called() @@ -353,7 +359,9 @@ def test_where_latest_partition_no_columns_no_values(mock_method): db = mock.Mock() with app.app_context(): result = HiveEngineSpec.where_latest_partition( - "test_table", "test_schema", db, select() + db, + Table("test_table", "test_schema"), + select(), ) assert result is None diff --git a/tests/integration_tests/db_engine_specs/presto_tests.py b/tests/integration_tests/db_engine_specs/presto_tests.py index 3f7bc52a57d2a..607afa6953fcd 100644 --- a/tests/integration_tests/db_engine_specs/presto_tests.py +++ b/tests/integration_tests/db_engine_specs/presto_tests.py @@ -82,7 +82,7 @@ def verify_presto_column(self, column, expected_results): row = mock.Mock() row.Column, row.Type, row.Null = column inspector.bind.execute.return_value.fetchall = mock.Mock(return_value=[row]) - results = PrestoEngineSpec.get_columns(inspector, "", "") + results = PrestoEngineSpec.get_columns(inspector, Table("", "")) self.assertEqual(len(expected_results), len(results)) for expected_result, result in zip(expected_results, results): self.assertEqual(expected_result[0], result["column_name"]) @@ -573,7 +573,10 @@ def test_presto_where_latest_partition(self): db.get_df = mock.Mock(return_value=df) columns = [{"name": "ds"}, {"name": "hour"}] result = PrestoEngineSpec.where_latest_partition( - "test_table", "test_schema", db, select(), columns + db, + Table("test_table", "test_schema"), + select(), + columns, ) query_result = str(result.compile(compile_kwargs={"literal_binds": True})) self.assertEqual("SELECT \nWHERE ds = '01-01-19' AND hour = 1", query_result) @@ -802,7 +805,7 @@ def test_show_columns(self): return_value=["a", "b"] ) table_name = "table_name" - result = PrestoEngineSpec._show_columns(inspector, table_name, None) + result = PrestoEngineSpec._show_columns(inspector, Table(table_name)) assert result == ["a", "b"] inspector.bind.execute.assert_called_once_with( f'SHOW COLUMNS FROM "{table_name}"' @@ -818,7 +821,7 @@ def test_show_columns_with_schema(self): ) table_name = "table_name" schema = "schema" - result = PrestoEngineSpec._show_columns(inspector, table_name, schema) + result = PrestoEngineSpec._show_columns(inspector, Table(table_name, schema)) assert result == ["a", "b"] inspector.bind.execute.assert_called_once_with( f'SHOW COLUMNS FROM "{schema}"."{table_name}"' @@ -848,7 +851,14 @@ def test_select_star_no_presto_expand_data(self, mock_select_star): ] PrestoEngineSpec.select_star(database, Table(table_name), engine, cols=cols) mock_select_star.assert_called_once_with( - database, table_name, engine, None, 100, False, True, True, cols + database, + Table(table_name), + engine, + 100, + False, + True, + True, + cols, ) @mock.patch("superset.db_engine_specs.presto.is_feature_enabled") @@ -877,9 +887,8 @@ def test_select_star_presto_expand_data( ) mock_select_star.assert_called_once_with( database, - table_name, + Table(table_name), engine, - None, 100, True, True, diff --git a/tests/unit_tests/db_engine_specs/test_db2.py b/tests/unit_tests/db_engine_specs/test_db2.py index fa215357951b7..017fcd7b80e7e 100644 --- a/tests/unit_tests/db_engine_specs/test_db2.py +++ b/tests/unit_tests/db_engine_specs/test_db2.py @@ -60,8 +60,9 @@ def test_get_table_comment_empty(mocker: MockerFixture): mock_inspector = mocker.MagicMock() mock_inspector.get_table_comment.return_value = {} - assert Db2EngineSpec.get_table_comment( - mock_inspector, Table("my_table", "my_schema") + assert ( + Db2EngineSpec.get_table_comment(mock_inspector, Table("my_table", "my_schema")) + is None ) diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py index d7aeaf1c5f036..5bd83828ed2c6 100644 --- a/tests/unit_tests/db_engine_specs/test_trino.py +++ b/tests/unit_tests/db_engine_specs/test_trino.py @@ -311,15 +311,15 @@ def test_convert_dttm( assert_convert_dttm(TrinoEngineSpec, target_type, expected_result, dttm) -def test_get_extra_table_metadata() -> None: +def test_get_extra_table_metadata(mocker: MockerFixture) -> None: from superset.db_engine_specs.trino import TrinoEngineSpec - db_mock = Mock() + db_mock = mocker.MagicMock() db_mock.get_indexes = Mock( return_value=[{"column_names": ["ds", "hour"], "name": "partition"}] ) db_mock.get_extra = Mock(return_value={}) - db_mock.has_view_by_name = Mock(return_value=None) + db_mock.has_view = Mock(return_value=None) db_mock.get_df = Mock(return_value=pd.DataFrame({"ds": ["01-01-19"], "hour": [1]})) result = TrinoEngineSpec.get_extra_table_metadata( db_mock,