diff --git a/superset/models/slice.py b/superset/models/slice.py index b8f1d93570156..c7f198b6e9d4a 100644 --- a/superset/models/slice.py +++ b/superset/models/slice.py @@ -155,10 +155,12 @@ def datasource_edit_url(self) -> Optional[str]: @property # type: ignore @utils.memoized - def viz(self) -> BaseViz: + def viz(self) -> Optional[BaseViz]: form_data = json.loads(self.params) - viz_class = viz_types[self.viz_type] - return viz_class(datasource=self.datasource, form_data=form_data) + viz_class = viz_types.get(self.viz_type) + if viz_class: + return viz_class(datasource=self.datasource, form_data=form_data) + return None @property def description_markeddown(self) -> str: @@ -170,8 +172,9 @@ def data(self) -> Dict[str, Any]: data: Dict[str, Any] = {} self.token = "" try: - data = self.viz.data - self.token = data.get("token") # type: ignore + viz = self.viz + data = viz.data if viz else self.form_data + self.token = utils.get_form_data_token(data) except Exception as ex: # pylint: disable=broad-except logger.exception(ex) data["error"] = str(ex) diff --git a/superset/utils/core.py b/superset/utils/core.py index c464d78d60148..297127e03e787 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -1360,6 +1360,16 @@ def get_iterable(x: Any) -> List[Any]: return x if isinstance(x, list) else [x] +def get_form_data_token(form_data: Dict[str, Any]) -> str: + """ + Return the token contained within form data or generate a new one. + + :param form_data: chart form data + :return: original token if predefined, otherwise new uuid4 based token + """ + return form_data.get("token") or "token_" + uuid.uuid4().hex[:8] + + class LenientEnum(Enum): """Enums that do not raise ValueError when value is invalid""" diff --git a/superset/viz.py b/superset/viz.py index 1db8b32997773..bdf4c5351e49f 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -121,7 +121,7 @@ def __init__( self.form_data = form_data self.query = "" - self.token = self.form_data.get("token", "token_" + uuid.uuid4().hex[:8]) + self.token = utils.get_form_data_token(form_data) self.groupby: List[str] = self.form_data.get("groupby") or [] self.time_shift = timedelta() diff --git a/tests/utils_tests.py b/tests/utils_tests.py index 4c2092ae8501e..7cb32b729b1a9 100644 --- a/tests/utils_tests.py +++ b/tests/utils_tests.py @@ -22,6 +22,7 @@ import hashlib import json import os +import re from unittest.mock import Mock, patch import numpy @@ -40,6 +41,7 @@ convert_legacy_filters_into_adhoc, create_ssl_cert_file, format_timedelta, + get_form_data_token, get_iterable, get_email_address_list, get_or_create_db, @@ -1365,3 +1367,8 @@ def test_schema_one_of_case_insensitive(self): self.assertEqual("BaZ", validator("BaZ")) self.assertRaises(marshmallow.ValidationError, validator, "qwerty") self.assertRaises(marshmallow.ValidationError, validator, 4) + + def test_get_form_data_token(self): + assert get_form_data_token({"token": "token_abcdefg1"}) == "token_abcdefg1" + generated_token = get_form_data_token({}) + assert re.match(r"^token_[a-z0-9]{8}$", generated_token) is not None