
Commit: Small fixes
betodealmeida committed Apr 22, 2024
1 parent 239f827 commit 1f73f2f
Showing 14 changed files with 70 additions and 51 deletions.
@@ -47,9 +47,10 @@ jest.mock(
       <div data-test="mock-column-element">{column.name}</div>
     ),
   );
-const getTableMetadataEndpoint = 'glob:**/api/v1/database/*/table_metadata/*';
+const getTableMetadataEndpoint =
+  /\/api\/v1\/database\/\d+\/table_metadata\/(?:\?.*)?$/;
 const getExtraTableMetadataEndpoint =
-  'glob:**/api/v1/database/*/table_metadata/extra/*';
+  /\/api\/v1\/database\/\d+\/table_metadata\/extra\/(?:\?.*)?$/;
 const updateTableSchemaEndpoint = 'glob:*/tableschemaview/*/expanded';

 beforeEach(() => {
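
The glob-to-regex switch matters because the endpoint now receives its arguments as a query string, which the old path-segment globs did not reliably match. The same patterns can be sanity-checked with Python's re module (illustration only; the actual mocks are fetch-mock matchers in Jest):

import re

# The Jest mock's pattern, transcribed for illustration.
TABLE_METADATA = re.compile(r"/api/v1/database/\d+/table_metadata/(?:\?.*)?$")

# Matches the endpoint with or without a query string...
assert TABLE_METADATA.search("/api/v1/database/1/table_metadata/?name=t&schema=s")
assert TABLE_METADATA.search("/api/v1/database/42/table_metadata/")
# ...but not the "extra" endpoint, which gets its own pattern.
assert not TABLE_METADATA.search("/api/v1/database/1/table_metadata/extra/")
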
@@ -75,7 +75,7 @@ const DatasetPanelWrapper = ({
     setLoading(true);
     setHasColumns?.(false);
     const path = schema
-      ? `/api/v1/database/${dbId}/table_metadata/?name=${tableName}&schema=${schema}/`
+      ? `/api/v1/database/${dbId}/table_metadata/?name=${tableName}&schema=${schema}`
       : `/api/v1/database/${dbId}/table_metadata/?name=${tableName}`;
     try {
       const response = await SupersetClient.get({
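
The one-character fix above removes a trailing slash that was appended after the schema value, so the server saw schema=main/ instead of schema=main. A minimal check with Python's urllib (hypothetical values):

from urllib.parse import parse_qs, urlsplit

buggy = "/api/v1/database/1/table_metadata/?name=my_table&schema=main/"
fixed = "/api/v1/database/1/table_metadata/?name=my_table&schema=main"

# The stray slash ends up inside the schema value itself.
assert parse_qs(urlsplit(buggy).query)["schema"] == ["main/"]
assert parse_qs(urlsplit(fixed).query)["schema"] == ["main"]
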
20 changes: 11 additions & 9 deletions superset/connectors/sqla/models.py
@@ -331,7 +331,7 @@ def short_data(self) -> dict[str, Any]:
             "edit_url": self.url,
             "id": self.id,
             "uid": self.uid,
-            "schema": self.schema,
+            "schema": self.schema or None,
             "name": self.name,
             "type": self.type,
             "connection": self.connection,
@@ -385,7 +385,7 @@ def data(self) -> dict[str, Any]:
             "datasource_name": self.datasource_name,
             "table_name": self.datasource_name,
             "type": self.type,
-            "schema": self.schema,
+            "schema": self.schema or None,
             "offset": self.offset,
             "cache_timeout": self.cache_timeout,
             "params": self.params,
@@ -1266,7 +1266,7 @@ def link(self) -> Markup:

     def get_schema_perm(self) -> str | None:
         """Returns schema permission if present, database one otherwise."""
-        return security_manager.get_schema_perm(self.database, self.schema)
+        return security_manager.get_schema_perm(self.database, self.schema or None)

     def get_perm(self) -> str:
         """
@@ -1323,7 +1323,7 @@ def external_metadata(self) -> list[ResultSetColumnType]:
             return get_virtual_table_metadata(dataset=self)
         return get_physical_table_metadata(
             database=self.database,
-            table=Table(self.table_name, self.schema, self.catalog),
+            table=Table(self.table_name, self.schema or None, self.catalog),
             normalize_columns=self.normalize_columns,
         )

@@ -1339,7 +1339,7 @@ def select_star(self) -> str | None:
         # show_cols and latest_partition set to false to avoid
         # the expensive cost of inspecting the DB
         return self.database.select_star(
-            Table(self.table_name, self.schema, self.catalog),
+            Table(self.table_name, self.schema or None, self.catalog),
             show_cols=False,
             latest_partition=False,
         )
@@ -1545,7 +1545,7 @@ def adhoc_column_to_sqla(  # pylint: disable=too-many-locals
             col_desc = get_columns_description(
                 self.database,
                 self.catalog,
-                self.schema,
+                self.schema or None,
                 sql,
             )
             if not col_desc:
@@ -1752,7 +1752,9 @@ def assign_column_label(df: pd.DataFrame) -> pd.DataFrame | None:
             return df

         try:
-            df = self.database.get_df(sql, self.schema, mutator=assign_column_label)
+            df = self.database.get_df(
+                sql, self.schema or None, mutator=assign_column_label
+            )
         except (SupersetErrorException, SupersetErrorsException) as ex:
             # SupersetError(s) exception should not be captured; instead, they should
             # bubble up to the Flask error handler so they are returned as proper SIP-40
@@ -1789,7 +1791,7 @@ def get_sqla_table_object(self) -> Table:
         return self.database.get_table(
             Table(
                 self.table_name,
-                self.schema,
+                self.schema or None,
                 self.catalog,
             )
         )
@@ -1807,7 +1809,7 @@ def fetch_metadata(self, commit: bool = True) -> MetadataResult:
             for metric in self.database.get_metrics(
                 Table(
                     self.table_name,
-                    self.schema,
+                    self.schema or None,
                     self.catalog,
                 )
             )
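
All of the models.py changes apply the same normalization: the schema column can hold either None or an empty string, and `or None` collapses both spellings to None before the value reaches permission checks or Table(...). In isolation:

def normalize_schema(schema: str | None) -> str | None:
    """Collapse the two "no schema" spellings ("" and None) into None."""
    return schema or None

assert normalize_schema(None) is None
assert normalize_schema("") is None           # empty string stored in the column
assert normalize_schema("public") == "public"
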
6 changes: 3 additions & 3 deletions superset/db_engine_specs/base.py
@@ -640,11 +640,11 @@ def supports_backend(cls, backend: str, driver: str | None = None) -> bool:
         return driver in cls.drivers

     @classmethod
-    def get_default_schema(cls, database: Database) -> str | None:
+    def get_default_schema(cls, database: Database, catalog: str | None) -> str | None:
         """
         Return the default schema in a given database.
         """
-        with database.get_inspector() as inspector:
+        with database.get_inspector(catalog=catalog) as inspector:
             return inspector.default_schema_name

     @classmethod
@@ -699,7 +699,7 @@ def get_default_schema_for_query(
             return schema

         # return the default schema of the database
-        return cls.get_default_schema(database)
+        return cls.get_default_schema(database, query.catalog)

     @classmethod
     def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]:
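
With the new parameter, the default schema is resolved relative to a specific catalog rather than whatever catalog the connection happens to land on. A rough sketch of the idea in plain SQLAlchemy (here the catalog would be baked into the URL, whereas Superset passes it to database.get_inspector(catalog=...)):

from sqlalchemy import create_engine, inspect

def default_schema_for(url: str) -> str | None:
    # Illustration only: build an inspector for a given connection and read
    # the dialect's notion of the default schema.
    engine = create_engine(url)
    return inspect(engine).default_schema_name

print(default_schema_for("sqlite://"))  # "main" for SQLite
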
9 changes: 6 additions & 3 deletions superset/db_engine_specs/trino.py
@@ -40,12 +40,12 @@
 )
 from superset.db_engine_specs.presto import PrestoBaseEngineSpec
 from superset.models.sql_lab import Query
+from superset.sql_parse import Table
 from superset.superset_typing import ResultSetColumnType
 from superset.utils import core as utils

 if TYPE_CHECKING:
     from superset.models.core import Database
-    from superset.sql_parse import Table

 with contextlib.suppress(ImportError):  # trino may not be installed
     from trino.dbapi import Cursor
@@ -96,8 +96,11 @@ def get_extra_table_metadata(
             ),
         }

-        if database.has_view_by_name(table.table, table.schema):
-            with database.get_inspector() as inspector:
+        if database.has_view(Table(table.table, table.schema)):
+            with database.get_inspector(
+                catalog=table.catalog,
+                schema=table.schema,
+            ) as inspector:
                 metadata["view"] = inspector.get_view_definition(
                     table.table,
                     table.schema,
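
Note the import shuffle at the top of the file: Table used to live in the `if TYPE_CHECKING:` block, which is erased at runtime and is therefore only valid for annotations. Since the method now instantiates Table(...), the import has to move to module level. The general rule, sketched:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imports here exist only for the type checker; they vanish at runtime.
    from superset.sql_parse import Table

def describe(table: "Table") -> str:  # OK: the annotation is just a string
    return f"table={table}"

# But calling Table(...) in a function body would raise NameError at runtime
# unless the import is hoisted out of TYPE_CHECKING, as this commit does.
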
@@ -17,14 +17,14 @@
 """Add catalog column

 Revision ID: 5f57af97bc3f
-Revises: 5ad7321c2169
+Revises: d60591c5515f
 Create Date: 2024-04-11 15:41:34.663989

 """

 # revision identifiers, used by Alembic.
 revision = "5f57af97bc3f"
-down_revision = "5ad7321c2169"
+down_revision = "d60591c5515f"

 import sqlalchemy as sa
 from alembic import op
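
Re-pointing Revises/down_revision from 5ad7321c2169 to d60591c5515f rebases this migration onto the current head, keeping Alembic's history linear after another migration landed in the meantime. Assuming that intervening migration itself revises 5ad7321c2169 (only the two identifiers in the diff are taken from the source), the chain becomes:

# Hypothetical sketch of the rebased chain (newest first):
#
#   5f57af97bc3f  "Add catalog column"    down_revision = "d60591c5515f"
#   d60591c5515f  intervening migration   down_revision = "5ad7321c2169"  (assumed)
#   5ad7321c2169  previous parent
#
# With the chain linear again, `alembic heads` reports a single head.
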
4 changes: 2 additions & 2 deletions tests/integration_tests/databases/commands/csv_upload_test.py
@@ -85,7 +85,7 @@ def _setup_csv_upload(allowed_schemas: list[str] | None = None):
     yield

     upload_db = get_upload_db()
-    with upload_db.get_sqla_engine_with_context() as engine:
+    with upload_db.get_sqla_engine() as engine:
         engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE}")
         engine.execute(f"DROP TABLE IF EXISTS {CSV_UPLOAD_TABLE_W_SCHEMA}")
     db.session.delete(upload_db)
@@ -221,7 +221,7 @@ def test_csv_upload_options(csv_data, options, table_data):
         create_csv_file(csv_data),
         options=options,
     ).run()
-    with upload_database.get_sqla_engine_with_context() as engine:
+    with upload_database.get_sqla_engine() as engine:
         data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE}").fetchall()
     assert data == table_data
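
Here and in the Excel tests below, get_sqla_engine_with_context is renamed to get_sqla_engine; both are context managers that yield a SQLAlchemy engine, so only the call sites change. A rough stand-in for how such a helper behaves (the real method builds the engine from the Database model's URI):

from contextlib import contextmanager

from sqlalchemy import create_engine

@contextmanager
def get_sqla_engine(uri: str):
    # Yield an engine for the duration of the block, then release its pool.
    engine = create_engine(uri)
    try:
        yield engine
    finally:
        engine.dispose()

with get_sqla_engine("sqlite://") as engine:
    engine.execute("SELECT 1")  # SQLAlchemy 1.x style, as used by these tests
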
@@ -68,7 +68,7 @@ def _setup_excel_upload(allowed_schemas: list[str] | None = None):
     yield

     upload_db = get_upload_db()
-    with upload_db.get_sqla_engine_with_context() as engine:
+    with upload_db.get_sqla_engine() as engine:
         engine.execute(f"DROP TABLE IF EXISTS {EXCEL_UPLOAD_TABLE}")
         engine.execute(f"DROP TABLE IF EXISTS {EXCEL_UPLOAD_TABLE_W_SCHEMA}")
     db.session.delete(upload_db)
@@ -192,7 +192,7 @@ def test_excel_upload_options(excel_data, options, table_data):
         create_excel_file(excel_data),
         options=options,
     ).run()
-    with upload_database.get_sqla_engine_with_context() as engine:
+    with upload_database.get_sqla_engine() as engine:
         data = engine.execute(f"SELECT * from {EXCEL_UPLOAD_TABLE}").fetchall()
     assert data == table_data
6 changes: 3 additions & 3 deletions tests/integration_tests/datasets/api_tests.py
@@ -773,13 +773,13 @@ def test_create_dataset_validate_tables_exists(self):

     @patch("superset.models.core.Database.get_columns")
     @patch("superset.models.core.Database.has_table")
-    @patch("superset.models.core.Database.has_view_by_name")
+    @patch("superset.models.core.Database.has_view")
     @patch("superset.models.core.Database.get_table")
     def test_create_dataset_validate_view_exists(
         self,
         mock_get_table,
         mock_has_table,
-        mock_has_view_by_name,
+        mock_has_view,
         mock_get_columns,
     ):
         """
@@ -796,7 +796,7 @@ def test_create_dataset_validate_view_exists(
         ]

         mock_has_table.return_value = False
-        mock_has_view_by_name.return_value = True
+        mock_has_view.return_value = True
         mock_get_table.return_value = None

         example_db = get_example_database()
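
The test has to follow the has_view_by_name → has_view rename because unittest.mock.patch refuses to patch an attribute that no longer exists, so stale patch targets fail loudly rather than silently. For example:

from unittest import mock

class Database:
    def has_view(self, table):  # renamed from has_view_by_name
        raise NotImplementedError

# Patching the old name now raises AttributeError...
try:
    with mock.patch.object(Database, "has_view_by_name", return_value=True):
        pass
except AttributeError as exc:
    print(exc)

# ...while the new name patches cleanly.
with mock.patch.object(Database, "has_view", return_value=True):
    assert Database().has_view(None) is True
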
@@ -30,7 +30,7 @@
 from superset.db_engine_specs.mysql import MySQLEngineSpec
 from superset.db_engine_specs.sqlite import SqliteEngineSpec
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
-from superset.sql_parse import ParsedQuery
+from superset.sql_parse import ParsedQuery, Table
 from superset.utils.database import get_example_database
 from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
 from tests.integration_tests.test_app import app
@@ -238,7 +238,7 @@ def test_get_table_names(self):
     @pytest.mark.usefixtures("load_energy_table_with_slice")
     def test_column_datatype_to_string(self):
         example_db = get_example_database()
-        sqla_table = example_db.get_table("energy_usage")
+        sqla_table = example_db.get_table(Table("energy_usage"))
         dialect = example_db.get_dialect()

         # TODO: fix column type conversion for presto.
@@ -540,8 +540,7 @@ def test_get_indexes():
         BaseEngineSpec.get_indexes(
             database=mock.Mock(),
             inspector=inspector,
-            table_name="bar",
-            schema="foo",
+            table=Table("bar", "foo"),
         )
         == indexes
     )
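
Here and in the BigQuery and Hive tests below, engine-spec methods stop taking loose table_name/schema strings and accept a single Table value instead, which also carries the catalog. Roughly (a stand-in for superset.sql_parse.Table, whose exact definition may differ):

from __future__ import annotations

from dataclasses import dataclass

@dataclass(frozen=True)
class Table:
    # Rough stand-in for superset.sql_parse.Table.
    table: str
    schema: str | None = None
    catalog: str | None = None

# Old: get_indexes(database, inspector, table_name="bar", schema="foo")
# New: one self-describing argument that cannot be swapped by accident.
t = Table("bar", "foo")
assert (t.table, t.schema, t.catalog) == ("bar", "foo", None)
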
9 changes: 3 additions & 6 deletions tests/integration_tests/db_engine_specs/bigquery_tests.py
@@ -165,8 +165,7 @@ def test_get_indexes(self):
             BigQueryEngineSpec.get_indexes(
                 database,
                 inspector,
-                table_name,
-                schema,
+                Table(table_name, schema),
             )
             == []
         )
@@ -184,8 +183,7 @@ def test_get_indexes(self):
         assert BigQueryEngineSpec.get_indexes(
             database,
             inspector,
-            table_name,
-            schema,
+            Table(table_name, schema),
         ) == [
             {
                 "name": "partition",
@@ -207,8 +205,7 @@ def test_get_indexes(self):
         assert BigQueryEngineSpec.get_indexes(
             database,
             inspector,
-            table_name,
-            schema,
+            Table(table_name, schema),
         ) == [
             {
                 "name": "partition",
16 changes: 12 additions & 4 deletions tests/integration_tests/db_engine_specs/hive_tests.py
@@ -23,7 +23,7 @@

 from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3
 from superset.exceptions import SupersetException
-from superset.sql_parse import Table, ParsedQuery
+from superset.sql_parse import ParsedQuery, Table
 from tests.integration_tests.test_app import app


@@ -344,7 +344,10 @@ def test_where_latest_partition(mock_method):
     columns = [{"name": "ds"}, {"name": "hour"}]
     with app.app_context():
         result = HiveEngineSpec.where_latest_partition(
-            "test_table", "test_schema", database, select(), columns
+            database,
+            Table("test_table", "test_schema"),
+            select(),
+            columns,
         )
     query_result = str(result.compile(compile_kwargs={"literal_binds": True}))
     assert "SELECT \nWHERE ds = '01-01-19' AND hour = 1" == query_result
@@ -357,7 +360,10 @@ def test_where_latest_partition_super_method_exception(mock_method):
     columns = [{"name": "ds"}, {"name": "hour"}]
     with app.app_context():
         result = HiveEngineSpec.where_latest_partition(
-            "test_table", "test_schema", database, select(), columns
+            database,
+            Table("test_table", "test_schema"),
+            select(),
+            columns,
        )
     assert result is None
     mock_method.assert_called()
@@ -369,7 +375,9 @@ def test_where_latest_partition_no_columns_no_values(mock_method):
     db = mock.Mock()
     with app.app_context():
         result = HiveEngineSpec.where_latest_partition(
-            "test_table", "test_schema", db, select()
+            db,
+            Table("test_table", "test_schema"),
+            select(),
         )
     assert result is None

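
The where_latest_partition updates follow the same pattern, with the database handle moved to the front of the argument list. The string the first test asserts can be reproduced with a few lines of SQLAlchemy 1.4 (the version these tests ran against at the time), showing how the spec narrows an empty select() to the latest partition values:

from sqlalchemy import column, select

latest = {"ds": "01-01-19", "hour": 1}
query = select()
for name, value in latest.items():
    query = query.where(column(name) == value)

# Prints the string the test asserts: SELECT \nWHERE ds = '01-01-19' AND hour = 1
print(query.compile(compile_kwargs={"literal_binds": True}))
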