refactor: migrate table chart to new API
ktmud committed Jan 26, 2021
1 parent 6bf5d2c commit 46cb0cb
Showing 16 changed files with 218 additions and 85 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
@@ -362,7 +362,7 @@ ignored-argument-names=_.*
max-locals=15

# Maximum number of return / yield for function / method body
max-returns=6
max-returns=10

# Maximum number of branch for function / method body
max-branches=12
10 changes: 10 additions & 0 deletions superset-frontend/package-lock.json

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions superset-frontend/package.json
@@ -215,6 +215,7 @@
"@types/react-dom": "^16.9.8",
"@types/react-gravatar": "^2.6.8",
"@types/react-json-tree": "^0.6.11",
"@types/react-loadable": "^5.5.4",
"@types/react-redux": "^7.1.10",
"@types/react-router-dom": "^5.1.5",
"@types/react-select": "^3.0.19",
2 changes: 1 addition & 1 deletion superset-frontend/tsconfig.json
@@ -27,7 +27,7 @@
],
// for supressing errors caused by incompatible @types/react when `npm link`
// Ref: https://github.com/Microsoft/typescript/issues/6496#issuecomment-384786222
"react": ["./node_modules/@types/react"]
"react": ["./node_modules/@types/react", "react"]
},
"skipLibCheck": true,
"sourceMap": true,
8 changes: 6 additions & 2 deletions superset/charts/api.py
@@ -65,7 +65,7 @@
from superset.commands.exceptions import CommandInvalidError
from superset.commands.importers.v1.utils import get_contents_from_bundle
from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
from superset.exceptions import SupersetSecurityException
from superset.exceptions import QueryObjectValidationError, SupersetSecurityException
from superset.extensions import event_logger
from superset.models.slice import Slice
from superset.tasks.thumbnails import cache_chart_thumbnail
@@ -566,9 +566,13 @@ def data(self) -> Response:
command = ChartDataCommand()
query_context = command.set_query_context(json_body)
command.validate()
except QueryObjectValidationError as error:
return self.response_400(message=error.message)
except ValidationError as error:
return self.response_400(
message=_("Request is incorrect: %(error)s", error=error.messages)
message=_(
"Request is incorrect: %(error)s", error=error.normalized_messages()
)
)
except SupersetSecurityException:
return self.response_401()
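Seen from an API client, the practical effect of the api.py change is that a query-level validation problem now surfaces as an ordinary 400 response with a readable message rather than an unhandled server error. Below is a hedged sketch of what that looks like against the /api/v1/chart/data endpoint; the request payload is abbreviated and illustrative and authentication setup is omitted, so treat the exact fields as assumptions rather than the full API contract.

```python
import requests

# Hypothetical request against a running Superset instance; auth setup omitted.
payload = {
    "datasource": {"id": 1, "type": "table"},
    # "count" appears twice, which the new QueryObject.validate() rejects.
    "queries": [{"columns": ["state"], "metrics": ["count", "count"]}],
    "result_format": "json",
    "result_type": "full",
}
resp = requests.post(
    "https://superset.example.com/api/v1/chart/data",
    json=payload,
    headers={"Authorization": "Bearer <access-token>"},
)
if resp.status_code == 400:
    # response_400() wraps the error as {"message": "..."}.
    print(resp.json()["message"])
```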
31 changes: 20 additions & 11 deletions superset/common/query_context.py
@@ -154,15 +154,15 @@ def get_data(self, df: pd.DataFrame,) -> Union[str, List[Dict[str, Any]]]:
return df.to_dict(orient="records")

def get_single_payload(
self, query_obj: QueryObject, **kwargs: Any
self, query_obj: QueryObject, force_cached: Optional[bool] = False,
) -> Dict[str, Any]:
"""Returns a payload of metadata and data"""
force_cached = kwargs.get("force_cached", False)
"""Return results payload for a single quey"""
if self.result_type == utils.ChartDataResultType.QUERY:
return {
"query": self.datasource.get_query_str(query_obj.to_dict()),
"language": self.datasource.query_language,
}

if self.result_type == utils.ChartDataResultType.SAMPLES:
row_limit = query_obj.row_limit or math.inf
query_obj = copy.copy(query_obj)
@@ -174,10 +174,13 @@ def get_single_payload(
query_obj.row_limit = min(row_limit, config["SAMPLES_ROW_LIMIT"])
query_obj.row_offset = 0
query_obj.columns = [o.column_name for o in self.datasource.columns]

payload = self.get_df_payload(query_obj, force_cached=force_cached)
df = payload["df"]
status = payload["status"]
if status != utils.QueryStatus.FAILED:
payload["colnames"] = list(df.columns)
payload["coltypes"] = utils.serialize_pandas_dtypes(df.dtypes)
payload["data"] = self.get_data(df)
del payload["df"]

@@ -196,13 +199,19 @@ def get_single_payload(
if col not in columns
] + rejected_time_columns

if self.result_type == utils.ChartDataResultType.RESULTS:
if (
self.result_type == utils.ChartDataResultType.RESULTS
and status != utils.QueryStatus.FAILED
):
return {"data": payload["data"]}
return payload

def get_payload(self, **kwargs: Any) -> Dict[str, Any]:
cache_query_context = kwargs.get("cache_query_context", False)
force_cached = kwargs.get("force_cached", False)
def get_payload(
self,
cache_query_context: Optional[bool] = False,
force_cached: Optional[bool] = False,
) -> Dict[str, Any]:
"""Returns the query results with both metadata and data"""

# Get all the payloads from the QueryObjects
query_results = [
@@ -343,10 +352,9 @@ def get_annotation_data(self, query_obj: QueryObject) -> Dict[str, Any]:
return annotation_data

def get_df_payload( # pylint: disable=too-many-statements,too-many-locals
self, query_obj: QueryObject, **kwargs: Any
self, query_obj: QueryObject, force_cached: Optional[bool] = False,
) -> Dict[str, Any]:
"""Handles caching around the df payload retrieval"""
force_cached = kwargs.get("force_cached", False)
cache_key = self.query_cache_key(query_obj)
logger.info("Cache key: %s", cache_key)
is_loaded = False
@@ -388,7 +396,7 @@ def get_df_payload( # pylint: disable=too-many-statements,too-many-locals
for col in query_obj.columns
+ query_obj.groupby
+ utils.get_column_names_from_metrics(query_obj.metrics)
if col not in self.datasource.column_names
if col not in self.datasource.column_names and col != DTTM_ALIAS
]
if invalid_columns:
raise QueryObjectValidationError(
@@ -447,5 +455,6 @@ def raise_for_access(self) -> None:
:raises SupersetSecurityException: If the user cannot access the resource
"""

for query in self.queries:
query.validate()
security_manager.raise_for_access(query_context=self)
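A recurring pattern in the query_context.py changes is replacing `**kwargs: Any` with explicit, defaulted keyword arguments on get_single_payload, get_payload, and get_df_payload. The sketch below uses a stand-in class (not the real QueryContext) to show why that matters: explicit keywords are type-checkable and discoverable, and a misspelled option now fails loudly instead of silently falling back to its default.

```python
from typing import Any, Dict, Optional


class QueryContextSketch:
    """Stand-in illustrating the **kwargs -> explicit keyword argument refactor."""

    # Before: options hid inside **kwargs, so a typo like force_cache=True was
    # silently dropped and the default was used instead.
    def get_payload_old(self, **kwargs: Any) -> Dict[str, Any]:
        cache_query_context = kwargs.get("cache_query_context", False)
        force_cached = kwargs.get("force_cached", False)
        return {"cache_query_context": cache_query_context, "force_cached": force_cached}

    # After: explicit, defaulted keywords that mypy and IDEs can check; an unknown
    # keyword raises TypeError right at the call site.
    def get_payload(
        self,
        cache_query_context: Optional[bool] = False,
        force_cached: Optional[bool] = False,
    ) -> Dict[str, Any]:
        return {"cache_query_context": cache_query_context, "force_cached": force_cached}


ctx = QueryContextSketch()
print(ctx.get_payload_old(force_cache=True))   # typo ignored: force_cached stays False
print(ctx.get_payload(force_cached=True))      # explicit option is applied
# ctx.get_payload(force_cache=True)            # TypeError: unexpected keyword argument
```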
59 changes: 47 additions & 12 deletions superset/common/query_object.py
@@ -28,7 +28,12 @@
from superset.exceptions import QueryObjectValidationError
from superset.typing import Metric
from superset.utils import pandas_postprocessing
from superset.utils.core import DTTM_ALIAS, get_metric_names, json_int_dttm_ser
from superset.utils.core import (
DTTM_ALIAS,
find_duplicates,
get_metric_names,
json_int_dttm_ser,
)
from superset.utils.date_parser import get_since_until, parse_human_timedelta
from superset.views.utils import get_time_range_endpoints

@@ -106,6 +111,8 @@ def __init__(
):
annotation_layers = annotation_layers or []
metrics = metrics or []
columns = columns or []
groupby = groupby or []
extras = extras or {}
is_sip_38 = is_feature_enabled("SIP_38_VIZ_REARCHITECTURE")
self.annotation_layers = [
@@ -126,19 +133,18 @@
time_range=time_range,
time_shift=time_shift,
)
# is_timeseries is True if time column is in groupby
# is_timeseries is True if time column is in either columns or groupby
# (both are dimensions)
self.is_timeseries = (
is_timeseries
if is_timeseries is not None
else (DTTM_ALIAS in groupby if groupby else False)
else DTTM_ALIAS in columns + groupby
)
self.time_range = time_range
self.time_shift = parse_human_timedelta(time_shift)
self.post_processing = [
post_proc for post_proc in post_processing or [] if post_proc
]
if not is_sip_38:
self.groupby = groupby or []

# Support metric reference/definition in the format of
# 1. 'metric_name' - name of predefined metric
@@ -162,13 +168,16 @@ def __init__(
if config["SIP_15_ENABLED"] and "time_range_endpoints" not in self.extras:
self.extras["time_range_endpoints"] = get_time_range_endpoints(form_data={})

self.columns = columns or []
if is_sip_38 and groupby:
self.columns += groupby
logger.warning(
"The field `groupby` is deprecated. Viz plugins should "
"pass all selectables via the `columns` field"
)
self.columns = columns
if is_sip_38:
if groupby:
logger.warning(
"The field `groupby` is deprecated. Viz plugins should "
"pass all selectables via the `columns` field"
)
self.columns += groupby
else:
self.groupby = groupby or []

self.orderby = orderby or []

Expand Down Expand Up @@ -214,8 +223,34 @@ def __init__(

@property
def metric_names(self) -> List[str]:
"""Return metrics names (labels), coerce adhoc metrics to strings."""
return get_metric_names(self.metrics)

@property
def column_names(self) -> List[str]:
"""Return column names (labels). Reserved for future adhoc calculated
columns."""
return self.columns

def validate(
self, raise_exceptions: Optional[bool] = True
) -> Optional[QueryObjectValidationError]:
"""Validate query object"""
error: Optional[QueryObjectValidationError] = None
all_labels = self.metric_names + self.column_names
if len(set(all_labels)) < len(all_labels):
dup_labels = find_duplicates(all_labels)
error = QueryObjectValidationError(
_(
"Duplicate column/metric labels: %(labels)s. Please make "
"sure all columns and metrics have a unique label.",
labels=", ".join(map(lambda x: f'"{x}"', dup_labels)),
)
)
if error and raise_exceptions:
raise error
return error

def to_dict(self) -> Dict[str, Any]:
query_object_dict = {
"granularity": self.granularity,
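The new QueryObject.validate() guards against a failure mode the table chart hits easily: the same label used for both a column and a metric, which would produce ambiguous result columns. Below is a minimal standalone sketch of the duplicate-label check; the find_duplicates helper here is a local approximation of the one imported from superset.utils.core, whose actual implementation may differ.

```python
from collections import Counter
from typing import Hashable, Iterable, List


def find_duplicates(items: Iterable[Hashable]) -> List[Hashable]:
    """Return the items that occur more than once, in first-seen order."""
    counts = Counter(items)
    return [item for item, count in counts.items() if count > 1]


# Mirrors validate(): metric labels and column labels share a single namespace.
metric_names = ["count", "sum__num"]
column_names = ["state", "count"]          # "count" collides with the metric label
all_labels = metric_names + column_names

if len(set(all_labels)) < len(all_labels):
    dup_labels = find_duplicates(all_labels)
    quoted = ", ".join('"{}"'.format(label) for label in dup_labels)
    # QueryObjectValidationError in Superset itself; ValueError keeps this sketch standalone.
    raise ValueError(
        f"Duplicate column/metric labels: {quoted}. "
        "Please make sure all columns and metrics have a unique label."
    )
```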
20 changes: 12 additions & 8 deletions superset/connectors/sqla/models.py
@@ -939,6 +939,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
and (is_sip_38 or (not is_sip_38 and not groupby))
):
raise QueryObjectValidationError(_("Empty query?"))

metrics_exprs: List[ColumnElement] = []
for metric in metrics:
if utils.is_adhoc_metric(metric):
@@ -950,6 +951,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
raise QueryObjectValidationError(
_("Metric '%(metric)s' does not exist", metric=metric)
)

if metrics_exprs:
main_metric_expr = metrics_exprs[0]
else:
@@ -960,14 +962,16 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
groupby_exprs_sans_timestamp = OrderedDict()

assert extras is not None
if (is_sip_38 and metrics and columns) or (not is_sip_38 and groupby):
# dedup columns while preserving order
columns_ = columns if is_sip_38 else groupby
assert columns_
groupby = list(dict.fromkeys(columns_))

# filter out the pseudo column __timestamp from columns
columns = columns or []
columns = [col for col in columns if col != utils.DTTM_ALIAS]

if (is_sip_38 and metrics and columns) or (not is_sip_38 and metrics):
# dedup columns while preserving order
columns = columns if is_sip_38 else (groupby or columns)
select_exprs = []
for selected in groupby:
for selected in columns:
# if groupby field/expr equals granularity field/expr
if selected == granularity:
time_grain = extras.get("time_grain_sqla")
@@ -979,7 +983,6 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
else:
outer = literal_column(f"({selected})")
outer = self.make_sqla_column_compatible(outer, selected)

groupby_exprs_sans_timestamp[outer.name] = outer
select_exprs.append(outer)
elif columns:
@@ -1001,7 +1004,8 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma

if is_timeseries:
timestamp = dttm_col.get_timestamp_expression(time_grain)
select_exprs += [timestamp]
# always put timestamp as the first column
select_exprs.insert(0, timestamp)
groupby_exprs_with_timestamp[timestamp.name] = timestamp

# Use main dttm column to support index with secondary dttm columns.
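Two small techniques carry most of the get_sqla_query() change: the selected columns are stripped of the `__timestamp` pseudo column and deduplicated while preserving order, and for time-series queries the timestamp expression is now inserted as the first select column instead of appended last. A short sketch follows, with plain strings standing in for SQLAlchemy column expressions; the reason for the ordering (downstream code relying on the timestamp position) is an assumption based on the inline comment.

```python
DTTM_ALIAS = "__timestamp"

# Columns as they might arrive from the chart: pseudo time column plus a repeat.
columns = ["state", DTTM_ALIAS, "gender", "state"]

# Filter out the pseudo __timestamp column; it is synthesised from the granularity column.
columns = [col for col in columns if col != DTTM_ALIAS]

# Dedup while preserving order (dict keys keep insertion order in Python 3.7+).
columns = list(dict.fromkeys(columns))              # ['state', 'gender']

# Build the select list; strings stand in for SQLAlchemy ColumnElement objects.
select_exprs = [f"col({name})" for name in columns]

# For time-series queries the timestamp expression always goes first.
is_timeseries = True
if is_timeseries:
    select_exprs.insert(0, "timestamp_expr")

print(select_exprs)    # ['timestamp_expr', 'col(state)', 'col(gender)']
```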
24 changes: 9 additions & 15 deletions superset/db_engine_specs/base.py
@@ -158,28 +158,22 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
try_remove_schema_from_table_name = True # pylint: disable=invalid-name
run_multiple_statements_as_one = False

# default matching patterns for identifying column types
db_column_types: Dict[utils.GenericDataType, Tuple[Pattern[Any], ...]] = {
# default matching patterns to convert database specific column types to
# more generic types
db_column_types: Dict[utils.GenericDataType, Tuple[Pattern[str], ...]] = {
utils.GenericDataType.NUMERIC: (
re.compile(r"BIT", re.IGNORECASE),
re.compile(r".*DOUBLE.*", re.IGNORECASE),
re.compile(r".*FLOAT.*", re.IGNORECASE),
re.compile(r".*INT.*", re.IGNORECASE),
re.compile(r".*NUMBER.*", re.IGNORECASE),
re.compile(
r".*(DOUBLE|FLOAT|INT|NUMBER|REAL|NUMERIC|DECIMAL|MONEY).*",
re.IGNORECASE,
),
re.compile(r".*LONG$", re.IGNORECASE),
re.compile(r".*REAL.*", re.IGNORECASE),
re.compile(r".*NUMERIC.*", re.IGNORECASE),
re.compile(r".*DECIMAL.*", re.IGNORECASE),
re.compile(r".*MONEY.*", re.IGNORECASE),
),
utils.GenericDataType.STRING: (
re.compile(r".*CHAR.*", re.IGNORECASE),
re.compile(r".*STRING.*", re.IGNORECASE),
re.compile(r".*TEXT.*", re.IGNORECASE),
re.compile(r".*(CHAR|STRING|TEXT).*", re.IGNORECASE),
),
utils.GenericDataType.TEMPORAL: (
re.compile(r".*DATE.*", re.IGNORECASE),
re.compile(r".*TIME.*", re.IGNORECASE),
re.compile(r".*(DATE|TIME).*", re.IGNORECASE),
),
}

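The base.py change collapses several per-keyword regexes into a single alternation pattern per generic type, shrinking the table without changing which database type names match. The sketch below shows the consolidated patterns in use; GenericDataType and the classify() helper are simplified stand-ins for the Superset equivalents, which resolve column types through the engine spec rather than a free function.

```python
import re
from enum import Enum
from typing import Dict, Optional, Pattern, Tuple


class GenericDataType(Enum):
    NUMERIC = 0
    STRING = 1
    TEMPORAL = 2


db_column_types: Dict[GenericDataType, Tuple[Pattern[str], ...]] = {
    GenericDataType.NUMERIC: (
        re.compile(r"BIT", re.IGNORECASE),
        re.compile(
            r".*(DOUBLE|FLOAT|INT|NUMBER|REAL|NUMERIC|DECIMAL|MONEY).*", re.IGNORECASE
        ),
        re.compile(r".*LONG$", re.IGNORECASE),
    ),
    GenericDataType.STRING: (re.compile(r".*(CHAR|STRING|TEXT).*", re.IGNORECASE),),
    GenericDataType.TEMPORAL: (re.compile(r".*(DATE|TIME).*", re.IGNORECASE),),
}


def classify(db_type: str) -> Optional[GenericDataType]:
    """Return the first generic type whose patterns match the database type name."""
    for generic_type, patterns in db_column_types.items():
        if any(pattern.match(db_type) for pattern in patterns):
            return generic_type
    return None


assert classify("BIGINT") is GenericDataType.NUMERIC
assert classify("VARCHAR(255)") is GenericDataType.STRING
assert classify("TIMESTAMP") is GenericDataType.TEMPORAL
assert classify("GEOMETRY") is None
```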
12 changes: 8 additions & 4 deletions superset/models/core.py
@@ -373,7 +373,11 @@ def get_df( # pylint: disable=too-many-locals
username = utils.get_username()

def needs_conversion(df_series: pd.Series) -> bool:
return not df_series.empty and isinstance(df_series[0], (list, dict))
return (
not df_series.empty
and isinstance(df_series, pd.Series)
and isinstance(df_series[0], (list, dict))
)

def _log_query(sql: str) -> None:
if log_query:
@@ -397,9 +401,9 @@ def _log_query(sql: str) -> None:
if mutator:
mutator(df)

for k, v in df.dtypes.items():
if v.type == numpy.object_ and needs_conversion(df[k]):
df[k] = df[k].apply(utils.json_dumps_w_dates)
for col, coltype in df.dtypes.to_dict().items():
if coltype == numpy.object_ and needs_conversion(df[col]):
df[col] = df[col].apply(utils.json_dumps_w_dates)

return df

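The get_df() change in models/core.py tightens the pass that converts nested values in object columns into JSON strings: iteration now goes over df.dtypes.to_dict() with descriptive names, and needs_conversion() additionally checks it was handed a real Series. Below is a runnable standalone sketch of that loop; json_dumps_w_dates is replaced by a simplified stand-in, since the real helper in superset.utils.core also handles dates and timestamps.

```python
import json

import numpy
import pandas as pd


def json_dumps_w_dates(obj):
    """Simplified stand-in for superset.utils.core.json_dumps_w_dates."""
    return json.dumps(obj, default=str)


def needs_conversion(df_series: pd.Series) -> bool:
    # Only non-empty object columns whose first value is a list or dict need to be
    # serialised to JSON strings before the result payload goes out.
    return (
        not df_series.empty
        and isinstance(df_series, pd.Series)
        and isinstance(df_series[0], (list, dict))
    )


df = pd.DataFrame(
    {
        "name": ["a", "b"],
        "payload": [{"x": 1}, [1, 2, 3]],   # nested values, e.g. from a JSON column
    }
)

for col, coltype in df.dtypes.to_dict().items():
    if coltype == numpy.object_ and needs_conversion(df[col]):
        df[col] = df[col].apply(json_dumps_w_dates)

print(df["payload"].tolist())   # ['{"x": 1}', '[1, 2, 3]']
```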