From 1ebdaac487ec1684050174957a1d5699912bf001 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Mon, 2 May 2022 14:50:56 -0700 Subject: [PATCH] fix: memoize primitives (#19930) --- superset/db_engine_specs/base.py | 26 ++++++++------ superset/db_engine_specs/sqlite.py | 30 +++++++++------- superset/models/core.py | 28 +++++++++------ superset/utils/cache.py | 13 ++++++- superset/views/core.py | 34 +++++++++++-------- .../unit_tests/db_engine_specs/test_sqlite.py | 4 +-- 6 files changed, 86 insertions(+), 49 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 1393fcdac5915..7d133e6245c1d 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -864,18 +864,24 @@ def get_all_datasource_names( all_datasources: List[utils.DatasourceName] = [] for schema in schemas: if datasource_type == "table": - all_datasources += database.get_all_table_names_in_schema( - schema=schema, - force=True, - cache=database.table_cache_enabled, - cache_timeout=database.table_cache_timeout, + all_datasources.extend( + utils.DatasourceName(*datasource_name) + for datasource_name in database.get_all_table_names_in_schema( + schema=schema, + force=True, + cache=database.table_cache_enabled, + cache_timeout=database.table_cache_timeout, + ) ) elif datasource_type == "view": - all_datasources += database.get_all_view_names_in_schema( - schema=schema, - force=True, - cache=database.table_cache_enabled, - cache_timeout=database.table_cache_timeout, + all_datasources.extend( + utils.DatasourceName(*datasource_name) + for datasource_name in database.get_all_view_names_in_schema( + schema=schema, + force=True, + cache=database.table_cache_enabled, + cache_timeout=database.table_cache_timeout, + ) ) else: raise Exception(f"Unsupported datasource_type: {datasource_type}") diff --git a/superset/db_engine_specs/sqlite.py b/superset/db_engine_specs/sqlite.py index 23512b3cb492f..c6edd4977c720 100644 --- a/superset/db_engine_specs/sqlite.py +++ b/superset/db_engine_specs/sqlite.py @@ -81,19 +81,25 @@ def get_all_datasource_names( ) schema = schemas[0] if datasource_type == "table": - return database.get_all_table_names_in_schema( - schema=schema, - force=True, - cache=database.table_cache_enabled, - cache_timeout=database.table_cache_timeout, - ) + return [ + utils.DatasourceName(*datasource_name) + for datasource_name in database.get_all_table_names_in_schema( + schema=schema, + force=True, + cache=database.table_cache_enabled, + cache_timeout=database.table_cache_timeout, + ) + ] if datasource_type == "view": - return database.get_all_view_names_in_schema( - schema=schema, - force=True, - cache=database.table_cache_enabled, - cache_timeout=database.table_cache_timeout, - ) + return [ + utils.DatasourceName(*datasource_name) + for datasource_name in database.get_all_view_names_in_schema( + schema=schema, + force=True, + cache=database.table_cache_enabled, + cache_timeout=database.table_cache_timeout, + ) + ] raise Exception(f"Unsupported datasource_type: {datasource_type}") @classmethod diff --git a/superset/models/core.py b/superset/models/core.py index d90aa2569625c..8d16ea39f8c5f 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -522,11 +522,16 @@ def get_all_table_names_in_database( # pylint: disable=unused-argument cache: bool = False, cache_timeout: Optional[bool] = None, force: bool = False, - ) -> List[utils.DatasourceName]: + ) -> List[Tuple[str, str]]: """Parameters need to be passed as keyword arguments.""" if not self.allow_multi_schema_metadata_fetch: return [] - return self.db_engine_spec.get_all_datasource_names(self, "table") + return [ + (datasource_name.table, datasource_name.schema) + for datasource_name in self.db_engine_spec.get_all_datasource_names( + self, "table" + ) + ] @cache_util.memoized_func( key="db:{self.id}:schema:None:view_list", @@ -537,11 +542,16 @@ def get_all_view_names_in_database( # pylint: disable=unused-argument cache: bool = False, cache_timeout: Optional[bool] = None, force: bool = False, - ) -> List[utils.DatasourceName]: + ) -> List[Tuple[str, str]]: """Parameters need to be passed as keyword arguments.""" if not self.allow_multi_schema_metadata_fetch: return [] - return self.db_engine_spec.get_all_datasource_names(self, "view") + return [ + (datasource_name.table, datasource_name.schema) + for datasource_name in self.db_engine_spec.get_all_datasource_names( + self, "view" + ) + ] @cache_util.memoized_func( key="db:{self.id}:schema:{schema}:table_list", @@ -553,7 +563,7 @@ def get_all_table_names_in_schema( # pylint: disable=unused-argument cache: bool = False, cache_timeout: Optional[int] = None, force: bool = False, - ) -> List[utils.DatasourceName]: + ) -> List[Tuple[str, str]]: """Parameters need to be passed as keyword arguments. For unused parameters, they are referenced in @@ -569,9 +579,7 @@ def get_all_table_names_in_schema( # pylint: disable=unused-argument tables = self.db_engine_spec.get_table_names( database=self, inspector=self.inspector, schema=schema ) - return [ - utils.DatasourceName(table=table, schema=schema) for table in tables - ] + return [(table, schema) for table in tables] except Exception as ex: # pylint: disable=broad-except logger.warning(ex) return [] @@ -586,7 +594,7 @@ def get_all_view_names_in_schema( # pylint: disable=unused-argument cache: bool = False, cache_timeout: Optional[int] = None, force: bool = False, - ) -> List[utils.DatasourceName]: + ) -> List[Tuple[str, str]]: """Parameters need to be passed as keyword arguments. For unused parameters, they are referenced in @@ -602,7 +610,7 @@ def get_all_view_names_in_schema( # pylint: disable=unused-argument views = self.db_engine_spec.get_view_names( database=self, inspector=self.inspector, schema=schema ) - return [utils.DatasourceName(table=view, schema=schema) for view in views] + return [(view, schema) for view in views] except Exception as ex: # pylint: disable=broad-except logger.warning(ex) return [] diff --git a/superset/utils/cache.py b/superset/utils/cache.py index d86f92398b570..cdbe34bd72662 100644 --- a/superset/utils/cache.py +++ b/superset/utils/cache.py @@ -98,7 +98,18 @@ def memoized_func( key: Optional[str] = None, cache: Cache = cache_manager.cache, ) -> Callable[..., Any]: - """Use this decorator to cache functions that have predefined first arg. + """ + Decorator with configurable key and cache backend. + + @memoized_func(key="{a}+{b}", cache=cache_manager.data_cache) + def sum(a: int, b: int) -> int: + return a + b + + In the example above the result for `1+2` will be stored under the key of name "1+2", + in the `cache_manager.data_cache` cache. + + Note: this decorator should be used only with functions that return primitives, + otherwise the deserialization might not work correctly. enable_cache is treated as True by default, except enable_cache = False is passed to the decorated function. diff --git a/superset/views/core.py b/superset/views/core.py index de5c42837d243..478cd93a739d9 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -1115,31 +1115,37 @@ def tables( # pylint: disable=too-many-locals,no-self-use,too-many-arguments substr_parsed = utils.parse_js_uri_path_item(substr, eval_undefined=True) if schema_parsed: - tables = ( - database.get_all_table_names_in_schema( + tables = [ + utils.DatasourceName(*datasource_name) + for datasource_name in database.get_all_table_names_in_schema( schema=schema_parsed, force=force_refresh_parsed, cache=database.table_cache_enabled, cache_timeout=database.table_cache_timeout, ) - or [] - ) - views = ( - database.get_all_view_names_in_schema( + ] or [] + views = [ + utils.DatasourceName(*datasource_name) + for datasource_name in database.get_all_view_names_in_schema( schema=schema_parsed, force=force_refresh_parsed, cache=database.table_cache_enabled, cache_timeout=database.table_cache_timeout, ) - or [] - ) + ] or [] else: - tables = database.get_all_table_names_in_database( - cache=True, force=False, cache_timeout=24 * 60 * 60 - ) - views = database.get_all_view_names_in_database( - cache=True, force=False, cache_timeout=24 * 60 * 60 - ) + tables = [ + utils.DatasourceName(*datasource_name) + for datasource_name in database.get_all_table_names_in_database( + cache=True, force=False, cache_timeout=24 * 60 * 60 + ) + ] + views = [ + utils.DatasourceName(*datasource_name) + for datasource_name in database.get_all_view_names_in_database( + cache=True, force=False, cache_timeout=24 * 60 * 60 + ) + ] tables = security_manager.get_datasources_accessible_by_user( database, tables, schema_parsed ) diff --git a/tests/unit_tests/db_engine_specs/test_sqlite.py b/tests/unit_tests/db_engine_specs/test_sqlite.py index 2ee8ea9a2c603..6b2da20b4fae0 100644 --- a/tests/unit_tests/db_engine_specs/test_sqlite.py +++ b/tests/unit_tests/db_engine_specs/test_sqlite.py @@ -46,7 +46,7 @@ def test_get_all_datasource_names_table(app_context: AppContext) -> None: database = mock.MagicMock() database.get_all_schema_names.return_value = ["schema1"] - table_names = ["table1", "table2"] + table_names = [("table1", "schema1"), ("table2", "schema1")] get_tables = mock.MagicMock(return_value=table_names) database.get_all_table_names_in_schema = get_tables result = SqliteEngineSpec.get_all_datasource_names(database, "table") @@ -65,7 +65,7 @@ def test_get_all_datasource_names_view(app_context: AppContext) -> None: database = mock.MagicMock() database.get_all_schema_names.return_value = ["schema1"] - views_names = ["view1", "view2"] + views_names = [("view1", "schema1"), ("view2", "schema1")] get_views = mock.MagicMock(return_value=views_names) database.get_all_view_names_in_schema = get_views result = SqliteEngineSpec.get_all_datasource_names(database, "view")