Skip to content

Commit

Permalink
chore: Change get_table_names/get_view_names return type (#22085)
Browse files Browse the repository at this point in the history
  • Loading branch information
john-bodley authored Nov 18, 2022
1 parent e990690 commit 7e54b88
Show file tree
Hide file tree
Showing 12 changed files with 76 additions and 62 deletions.
16 changes: 8 additions & 8 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,7 +1034,7 @@ def get_table_names( # pylint: disable=unused-argument
database: "Database",
inspector: Inspector,
schema: Optional[str],
) -> List[str]:
) -> Set[str]:
"""
Get all the real table names within the specified schema.
Expand All @@ -1048,21 +1048,21 @@ def get_table_names( # pylint: disable=unused-argument
"""

try:
tables = inspector.get_table_names(schema)
tables = set(inspector.get_table_names(schema))
except Exception as ex:
raise cls.get_dbapi_mapped_exception(ex) from ex

if schema and cls.try_remove_schema_from_table_name:
tables = [re.sub(f"^{schema}\\.", "", table) for table in tables]
return sorted(tables)
tables = {re.sub(f"^{schema}\\.", "", table) for table in tables}
return tables

@classmethod
def get_view_names( # pylint: disable=unused-argument
cls,
database: "Database",
inspector: Inspector,
schema: Optional[str],
) -> List[str]:
) -> Set[str]:
"""
Get all the view names within the specified schema.
Expand All @@ -1076,13 +1076,13 @@ def get_view_names( # pylint: disable=unused-argument
"""

try:
views = inspector.get_view_names(schema)
views = set(inspector.get_view_names(schema))
except Exception as ex:
raise cls.get_dbapi_mapped_exception(ex) from ex

if schema and cls.try_remove_schema_from_table_name:
views = [re.sub(f"^{schema}\\.", "", view) for view in views]
return sorted(views)
views = {re.sub(f"^{schema}\\.", "", view) for view in views}
return views

@classmethod
def get_table_comment(
Expand Down
12 changes: 5 additions & 7 deletions superset/db_engine_specs/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.

from datetime import datetime
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from typing import Any, Dict, Optional, Set, TYPE_CHECKING

from sqlalchemy.engine.reflection import Inspector

Expand Down Expand Up @@ -103,9 +103,7 @@ def get_table_names(
database: "Database",
inspector: Inspector,
schema: Optional[str],
) -> List[str]:
tables = set(super().get_table_names(database, inspector, schema))
views = set(cls.get_view_names(database, inspector, schema))
actual_tables = tables - views

return list(actual_tables)
) -> Set[str]:
return super().get_table_names(
database, inspector, schema
) - cls.get_view_names(database, inspector, schema)
6 changes: 3 additions & 3 deletions superset/db_engine_specs/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING
from typing import Any, Dict, Optional, Pattern, Set, Tuple, TYPE_CHECKING

from flask_babel import gettext as __
from sqlalchemy.engine.reflection import Inspector
Expand Down Expand Up @@ -75,5 +75,5 @@ def convert_dttm(
@classmethod
def get_table_names(
cls, database: Database, inspector: Inspector, schema: Optional[str]
) -> List[str]:
return sorted(inspector.get_table_names(schema))
) -> Set[str]:
return set(inspector.get_table_names(schema))
10 changes: 5 additions & 5 deletions superset/db_engine_specs/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import logging
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING
from typing import Any, Dict, List, Optional, Pattern, Set, Tuple, TYPE_CHECKING

from flask_babel import gettext as __
from sqlalchemy.dialects.postgresql import ARRAY, DOUBLE_PRECISION, ENUM, JSON
Expand Down Expand Up @@ -228,11 +228,11 @@ def query_cost_formatter(
@classmethod
def get_table_names(
cls, database: "Database", inspector: PGInspector, schema: Optional[str]
) -> List[str]:
) -> Set[str]:
"""Need to consider foreign tables for PostgreSQL"""
tables = inspector.get_table_names(schema)
tables.extend(inspector.get_foreign_table_names(schema))
return sorted(tables)
return set(inspector.get_table_names(schema)) | set(
inspector.get_foreign_table_names(schema)
)

@classmethod
def convert_dttm(
Expand Down
28 changes: 18 additions & 10 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,18 @@
from datetime import datetime
from distutils.version import StrictVersion
from textwrap import dedent
from typing import Any, cast, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING, Union
from typing import (
Any,
cast,
Dict,
List,
Optional,
Pattern,
Set,
Tuple,
TYPE_CHECKING,
Union,
)
from urllib import parse

import pandas as pd
Expand Down Expand Up @@ -396,7 +407,7 @@ def get_table_names(
database: Database,
inspector: Inspector,
schema: Optional[str],
) -> List[str]:
) -> Set[str]:
"""
Get all the real table names within the specified schema.
Expand All @@ -414,20 +425,17 @@ def get_table_names(
:returns: The physical table names
"""

return sorted(
list(
set(super().get_table_names(database, inspector, schema))
- set(cls.get_view_names(database, inspector, schema))
)
)
return super().get_table_names(
database, inspector, schema
) - cls.get_view_names(database, inspector, schema)

@classmethod
def get_view_names(
cls,
database: Database,
inspector: Inspector,
schema: Optional[str],
) -> List[str]:
) -> Set[str]:
"""
Get all the view names within the specified schema.
Expand Down Expand Up @@ -468,7 +476,7 @@ def get_view_names(
cursor.execute(sql, params)
results = cursor.fetchall()

return sorted([row[0] for row in results])
return {row[0] for row in results}

@classmethod
def _create_column_info(
Expand Down
6 changes: 3 additions & 3 deletions superset/db_engine_specs/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING
from typing import Any, Dict, Optional, Pattern, Set, Tuple, TYPE_CHECKING

from flask_babel import gettext as __
from sqlalchemy.engine.reflection import Inspector
Expand Down Expand Up @@ -88,6 +88,6 @@ def convert_dttm(
@classmethod
def get_table_names(
cls, database: "Database", inspector: Inspector, schema: Optional[str]
) -> List[str]:
) -> Set[str]:
"""Need to disregard the schema for Sqlite"""
return sorted(inspector.get_table_names())
return set(inspector.get_table_names())
32 changes: 20 additions & 12 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ def get_all_table_names_in_schema( # pylint: disable=unused-argument
cache: bool = False,
cache_timeout: Optional[int] = None,
force: bool = False,
) -> List[Tuple[str, str]]:
) -> Set[Tuple[str, str]]:
"""Parameters need to be passed as keyword arguments.
For unused parameters, they are referenced in
Expand All @@ -553,13 +553,17 @@ def get_all_table_names_in_schema( # pylint: disable=unused-argument
:param cache: whether cache is enabled for the function
:param cache_timeout: timeout in seconds for the cache
:param force: whether to force refresh the cache
:return: list of tables
:return: The table/schema pairs
"""
try:
tables = self.db_engine_spec.get_table_names(
database=self, inspector=self.inspector, schema=schema
)
return [(table, schema) for table in tables]
return {
(table, schema)
for table in self.db_engine_spec.get_table_names(
database=self,
inspector=self.inspector,
schema=schema,
)
}
except Exception as ex:
raise self.db_engine_spec.get_dbapi_mapped_exception(ex)

Expand All @@ -573,7 +577,7 @@ def get_all_view_names_in_schema( # pylint: disable=unused-argument
cache: bool = False,
cache_timeout: Optional[int] = None,
force: bool = False,
) -> List[Tuple[str, str]]:
) -> Set[Tuple[str, str]]:
"""Parameters need to be passed as keyword arguments.
For unused parameters, they are referenced in
Expand All @@ -583,13 +587,17 @@ def get_all_view_names_in_schema( # pylint: disable=unused-argument
:param cache: whether cache is enabled for the function
:param cache_timeout: timeout in seconds for the cache
:param force: whether to force refresh the cache
:return: list of views
:return: set of views
"""
try:
views = self.db_engine_spec.get_view_names(
database=self, inspector=self.inspector, schema=schema
)
return [(view, schema) for view in views]
return {
(view, schema)
for view in self.db_engine_spec.get_view_names(
database=self,
inspector=self.inspector,
schema=schema,
)
}
except Exception as ex:
raise self.db_engine_spec.get_dbapi_mapped_exception(ex)

Expand Down
8 changes: 4 additions & 4 deletions superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1173,29 +1173,29 @@ def tables( # pylint: disable=no-self-use
tables = security_manager.get_datasources_accessible_by_user(
database=database,
schema=schema_parsed,
datasource_names=[
datasource_names=sorted(
utils.DatasourceName(*datasource_name)
for datasource_name in database.get_all_table_names_in_schema(
schema=schema_parsed,
force=force_refresh_parsed,
cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout,
)
],
),
)

views = security_manager.get_datasources_accessible_by_user(
database=database,
schema=schema_parsed,
datasource_names=[
datasource_names=sorted(
utils.DatasourceName(*datasource_name)
for datasource_name in database.get_all_view_names_in_schema(
schema=schema_parsed,
force=force_refresh_parsed,
cache=database.table_cache_enabled,
cache_timeout=database.table_cache_timeout,
)
],
),
)
except SupersetException as ex:
return json_error_response(ex.message, ex.status)
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/datasets/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,7 @@ def test_create_dataset_validate_view_exists(
with patch.object(
dialect, "get_view_names", wraps=dialect.get_view_names
) as patch_get_view_names:
patch_get_view_names.return_value = ["test_case_view"]
patch_get_view_names.return_value = {"test_case_view"}

self.login(username="admin")
table_data = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,11 @@ def test_get_table_names(self):

""" Make sure base engine spec removes schema name from table name
ie. when try_remove_schema_from_table_name == True. """
base_result_expected = ["table", "table_2"]
base_result_expected = {"table", "table_2"}
base_result = BaseEngineSpec.get_table_names(
database=mock.ANY, schema="schema", inspector=inspector
)
self.assertListEqual(base_result_expected, base_result)
assert base_result_expected == base_result

@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_column_datatype_to_string(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/integration_tests/db_engine_specs/postgres_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ def test_get_table_names(self):
inspector.get_table_names = mock.Mock(return_value=["schema.table", "table_2"])
inspector.get_foreign_table_names = mock.Mock(return_value=["table_3"])

pg_result_expected = ["schema.table", "table_2", "table_3"]
pg_result_expected = {"schema.table", "table_2", "table_3"}
pg_result = PostgresEngineSpec.get_table_names(
database=mock.ANY, schema="schema", inspector=inspector
)
self.assertListEqual(pg_result_expected, pg_result)
assert pg_result_expected == pg_result

def test_time_exp_literal_no_grain(self):
"""
Expand Down
10 changes: 5 additions & 5 deletions tests/integration_tests/db_engine_specs/presto_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_get_view_names_with_schema(self):
).strip(),
{"schema": schema},
)
assert result == ["a", "d"]
assert result == {"a", "d"}

def test_get_view_names_without_schema(self):
database = mock.MagicMock()
Expand All @@ -77,7 +77,7 @@ def test_get_view_names_without_schema(self):
).strip(),
{},
)
assert result == ["a", "d"]
assert result == {"a", "d"}

def verify_presto_column(self, column, expected_results):
inspector = mock.Mock()
Expand Down Expand Up @@ -670,10 +670,10 @@ def test_get_table_names(
mock_get_view_names,
mock_get_table_names,
):
mock_get_view_names.return_value = ["view1", "view2"]
mock_get_table_names.return_value = ["table1", "table2", "view1", "view2"]
mock_get_view_names.return_value = {"view1", "view2"}
mock_get_table_names.return_value = {"table1", "table2", "view1", "view2"}
tables = PrestoEngineSpec.get_table_names(mock.Mock(), mock.Mock(), None)
assert tables == ["table1", "table2"]
assert tables == {"table1", "table2"}

def test_get_full_name(self):
names = [
Expand Down

0 comments on commit 7e54b88

Please sign in to comment.