Skip to content

Commit

Permalink
Small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida committed Apr 25, 2024
1 parent 8cdf38b commit 3cf1ac6
Show file tree
Hide file tree
Showing 17 changed files with 76 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ jest.mock(
<div data-test="mock-column-element">{column.name}</div>
),
);
const getTableMetadataEndpoint = 'glob:**/api/v1/database/*/table_metadata/*';
const getTableMetadataEndpoint =
/\/api\/v1\/database\/\d+\/table_metadata\/(?:\?.*)?$/;
const getExtraTableMetadataEndpoint =
'glob:**/api/v1/database/*/table_metadata/extra/*';
/\/api\/v1\/database\/\d+\/table_metadata\/extra\/(?:\?.*)?$/;
const updateTableSchemaEndpoint = 'glob:*/tableschemaview/*/expanded';

beforeEach(() => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ const DatasetPanelWrapper = ({
setLoading(true);
setHasColumns?.(false);
const path = schema
? `/api/v1/database/${dbId}/table_metadata/?name=${tableName}&schema=${schema}/`
? `/api/v1/database/${dbId}/table_metadata/?name=${tableName}&schema=${schema}`
: `/api/v1/database/${dbId}/table_metadata/?name=${tableName}`;
try {
const response = await SupersetClient.get({
Expand Down
20 changes: 11 additions & 9 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def short_data(self) -> dict[str, Any]:
"edit_url": self.url,
"id": self.id,
"uid": self.uid,
"schema": self.schema,
"schema": self.schema or None,
"name": self.name,
"type": self.type,
"connection": self.connection,
Expand Down Expand Up @@ -383,7 +383,7 @@ def data(self) -> dict[str, Any]:
"datasource_name": self.datasource_name,
"table_name": self.datasource_name,
"type": self.type,
"schema": self.schema,
"schema": self.schema or None,
"offset": self.offset,
"cache_timeout": self.cache_timeout,
"params": self.params,
Expand Down Expand Up @@ -1263,7 +1263,7 @@ def link(self) -> Markup:

def get_schema_perm(self) -> str | None:
"""Returns schema permission if present, database one otherwise."""
return security_manager.get_schema_perm(self.database, self.schema)
return security_manager.get_schema_perm(self.database, self.schema or None)

def get_perm(self) -> str:
"""
Expand Down Expand Up @@ -1320,7 +1320,7 @@ def external_metadata(self) -> list[ResultSetColumnType]:
return get_virtual_table_metadata(dataset=self)
return get_physical_table_metadata(
database=self.database,
table=Table(self.table_name, self.schema, self.catalog),
table=Table(self.table_name, self.schema or None, self.catalog),
normalize_columns=self.normalize_columns,
)

Expand All @@ -1336,7 +1336,7 @@ def select_star(self) -> str | None:
# show_cols and latest_partition set to false to avoid
# the expensive cost of inspecting the DB
return self.database.select_star(
Table(self.table_name, self.schema, self.catalog),
Table(self.table_name, self.schema or None, self.catalog),
show_cols=False,
latest_partition=False,
)
Expand Down Expand Up @@ -1528,7 +1528,7 @@ def adhoc_column_to_sqla( # pylint: disable=too-many-locals
col_desc = get_columns_description(
self.database,
self.catalog,
self.schema,
self.schema or None,
sql,
)
if not col_desc:
Expand Down Expand Up @@ -1735,7 +1735,9 @@ def assign_column_label(df: pd.DataFrame) -> pd.DataFrame | None:
return df

try:
df = self.database.get_df(sql, self.schema, mutator=assign_column_label)
df = self.database.get_df(
sql, self.schema or None, mutator=assign_column_label
)
except (SupersetErrorException, SupersetErrorsException) as ex:
# SupersetError(s) exception should not be captured; instead, they should
# bubble up to the Flask error handler so they are returned as proper SIP-40
Expand Down Expand Up @@ -1772,7 +1774,7 @@ def get_sqla_table_object(self) -> Table:
return self.database.get_table(
Table(
self.table_name,
self.schema,
self.schema or None,
self.catalog,
)
)
Expand All @@ -1790,7 +1792,7 @@ def fetch_metadata(self, commit: bool = True) -> MetadataResult:
for metric in self.database.get_metrics(
Table(
self.table_name,
self.schema,
self.schema or None,
self.catalog,
)
)
Expand Down
6 changes: 3 additions & 3 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,11 +639,11 @@ def supports_backend(cls, backend: str, driver: str | None = None) -> bool:
return driver in cls.drivers

@classmethod
def get_default_schema(cls, database: Database) -> str | None:
def get_default_schema(cls, database: Database, catalog: str | None) -> str | None:
"""
Return the default schema in a given database.
"""
with database.get_inspector() as inspector:
with database.get_inspector(catalog=catalog) as inspector:
return inspector.default_schema_name

@classmethod
Expand Down Expand Up @@ -698,7 +698,7 @@ def get_default_schema_for_query(
return schema

# return the default schema of the database
return cls.get_default_schema(database)
return cls.get_default_schema(database, query.catalog)

@classmethod
def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]:
Expand Down
3 changes: 1 addition & 2 deletions superset/db_engine_specs/db2.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ def get_table_comment(
Ibm Db2 return comments as tuples, so we need to get the first element
:param inspector: SqlAlchemy Inspector instance
:param table_name: Table name
:param schema: Schema name. If omitted, uses default schema for database
:param table: Table instance
:return: comment of table
"""
comment = None
Expand Down
6 changes: 2 additions & 4 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,8 +969,7 @@ def _show_columns(
"""
Show presto column names
:param inspector: object that performs database schema inspection
:param table_name: table name
:param schema: schema name
:param table: table instance
:return: list of column objects
"""
quote = inspector.engine.dialect.identifier_preparer.quote_identifier
Expand All @@ -990,8 +989,7 @@ def get_columns(
Get columns from a Presto data source. This includes handling row and
array data types
:param inspector: object that performs database schema inspection
:param table_name: table name
:param schema: schema name
:param table: table instance
:param options: Extra configuration options, not used by this backend
:return: a list of results that contain column info
(i.e. column name and data type)
Expand Down
9 changes: 6 additions & 3 deletions superset/db_engine_specs/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@
)
from superset.db_engine_specs.presto import PrestoBaseEngineSpec
from superset.models.sql_lab import Query
from superset.sql_parse import Table
from superset.superset_typing import ResultSetColumnType
from superset.utils import core as utils

if TYPE_CHECKING:
from superset.models.core import Database
from superset.sql_parse import Table

with contextlib.suppress(ImportError): # trino may not be installed
from trino.dbapi import Cursor
Expand Down Expand Up @@ -96,8 +96,11 @@ def get_extra_table_metadata(
),
}

if database.has_view_by_name(table.table, table.schema):
with database.get_inspector() as inspector:
if database.has_view(Table(table.table, table.schema)):
with database.get_inspector(
catalog=table.catalog,
schema=table.schema,
) as inspector:
metadata["view"] = inspector.get_view_definition(
table.table,
table.schema,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""Add catalog column
Revision ID: 5f57af97bc3f
Revises: 5ad7321c2169
Revises: d60591c5515f
Create Date: 2024-04-11 15:41:34.663989
"""
Expand All @@ -27,7 +27,7 @@

# revision identifiers, used by Alembic.
revision = "5f57af97bc3f"
down_revision = "5ad7321c2169"
down_revision = "d60591c5515f"


def upgrade():
Expand Down
4 changes: 2 additions & 2 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.

# pylint: disable=line-too-long, too-many-lines, too-many-arguments
# pylint: disable=too-many-lines, too-many-arguments

"""A collection of ORM sqlalchemy models for Superset"""

Expand Down Expand Up @@ -607,7 +607,7 @@ def mutate_sql_based_on_config(self, sql_: str, is_split: bool = False) -> str:
)
return sql_

def get_df(
def get_df( # pylint: disable=too-many-locals
self,
sql: str,
catalog: str | None = None,
Expand Down
4 changes: 2 additions & 2 deletions tests/integration_tests/databases/commands/upload_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def _setup_csv_upload(allowed_schemas: list[str] | None = None):
yield

upload_db = get_upload_db()
with upload_db.get_sqla_engine_with_context() as engine:
with upload_db.get_sqla_engine() as engine:
engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE}")
engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE_W_SCHEMA}")
db.session.delete(upload_db)
Expand Down Expand Up @@ -107,7 +107,7 @@ def test_csv_upload_with_nulls():
None,
CSVReader({"null_values": ["N/A", "None"]}),
).run()
with upload_database.get_sqla_engine_with_context() as engine:
with upload_database.get_sqla_engine() as engine:
data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE}").fetchall()
assert data == [
("name1", None, "city1", "1-1-1980"),
Expand Down
6 changes: 3 additions & 3 deletions tests/integration_tests/datasets/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,13 +774,13 @@ def test_create_dataset_validate_tables_exists(self):

@patch("superset.models.core.Database.get_columns")
@patch("superset.models.core.Database.has_table")
@patch("superset.models.core.Database.has_view_by_name")
@patch("superset.models.core.Database.has_view")
@patch("superset.models.core.Database.get_table")
def test_create_dataset_validate_view_exists(
self,
mock_get_table,
mock_has_table,
mock_has_view_by_name,
mock_has_view,
mock_get_columns,
):
"""
Expand All @@ -797,7 +797,7 @@ def test_create_dataset_validate_view_exists(
]

mock_has_table.return_value = False
mock_has_view_by_name.return_value = True
mock_has_view.return_value = True
mock_get_table.return_value = None

example_db = get_example_database()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from superset.db_engine_specs.mysql import MySQLEngineSpec
from superset.db_engine_specs.sqlite import SqliteEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.sql_parse import ParsedQuery
from superset.sql_parse import ParsedQuery, Table
from superset.utils.database import get_example_database
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.integration_tests.test_app import app
Expand Down Expand Up @@ -238,7 +238,7 @@ def test_get_table_names(self):
@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_column_datatype_to_string(self):
example_db = get_example_database()
sqla_table = example_db.get_table("energy_usage")
sqla_table = example_db.get_table(Table("energy_usage"))
dialect = example_db.get_dialect()

# TODO: fix column type conversion for presto.
Expand Down Expand Up @@ -540,8 +540,7 @@ def test_get_indexes():
BaseEngineSpec.get_indexes(
database=mock.Mock(),
inspector=inspector,
table_name="bar",
schema="foo",
table=Table("bar", "foo"),
)
== indexes
)
9 changes: 3 additions & 6 deletions tests/integration_tests/db_engine_specs/bigquery_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,7 @@ def test_get_indexes(self):
BigQueryEngineSpec.get_indexes(
database,
inspector,
table_name,
schema,
Table(table_name, schema),
)
== []
)
Expand All @@ -184,8 +183,7 @@ def test_get_indexes(self):
assert BigQueryEngineSpec.get_indexes(
database,
inspector,
table_name,
schema,
Table(table_name, schema),
) == [
{
"name": "partition",
Expand All @@ -207,8 +205,7 @@ def test_get_indexes(self):
assert BigQueryEngineSpec.get_indexes(
database,
inspector,
table_name,
schema,
Table(table_name, schema),
) == [
{
"name": "partition",
Expand Down
16 changes: 12 additions & 4 deletions tests/integration_tests/db_engine_specs/hive_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3
from superset.exceptions import SupersetException
from superset.sql_parse import Table, ParsedQuery
from superset.sql_parse import ParsedQuery, Table
from tests.integration_tests.test_app import app


Expand Down Expand Up @@ -328,7 +328,10 @@ def test_where_latest_partition(mock_method):
columns = [{"name": "ds"}, {"name": "hour"}]
with app.app_context():
result = HiveEngineSpec.where_latest_partition(
"test_table", "test_schema", database, select(), columns
database,
Table("test_table", "test_schema"),
select(),
columns,
)
query_result = str(result.compile(compile_kwargs={"literal_binds": True}))
assert "SELECT \nWHERE ds = '01-01-19' AND hour = 1" == query_result
Expand All @@ -341,7 +344,10 @@ def test_where_latest_partition_super_method_exception(mock_method):
columns = [{"name": "ds"}, {"name": "hour"}]
with app.app_context():
result = HiveEngineSpec.where_latest_partition(
"test_table", "test_schema", database, select(), columns
database,
Table("test_table", "test_schema"),
select(),
columns,
)
assert result is None
mock_method.assert_called()
Expand All @@ -353,7 +359,9 @@ def test_where_latest_partition_no_columns_no_values(mock_method):
db = mock.Mock()
with app.app_context():
result = HiveEngineSpec.where_latest_partition(
"test_table", "test_schema", db, select()
db,
Table("test_table", "test_schema"),
select(),
)
assert result is None

Expand Down
Loading

0 comments on commit 3cf1ac6

Please sign in to comment.