From 0ca4312212ee58b9461106d704321e0907c59e57 Mon Sep 17 00:00:00 2001
From: Shiva Raisinghani
Date: Sun, 14 Nov 2021 09:57:09 -0800
Subject: [PATCH 1/3] fix: rename to schemas_allowed_for_file_upload in
 dbs.extra (#17323)

* rename to schemas_allowed_for_file_upload in dbs.extra

* black

* I should really set up pre-commit hooks

* Apply suggestions

Co-authored-by: John Bodley <4567245+john-bodley@users.noreply.github.com>

* move changes to a separate migration

* fix spaces

* black

Co-authored-by: John Bodley <4567245+john-bodley@users.noreply.github.com>
---
 ...acd_rename_to_schemas_allowed_for_file_.py | 89 +++++++++++++++++++
 1 file changed, 89 insertions(+)
 create mode 100644 superset/migrations/versions/0ca9e5f1dacd_rename_to_schemas_allowed_for_file_.py

diff --git a/superset/migrations/versions/0ca9e5f1dacd_rename_to_schemas_allowed_for_file_.py b/superset/migrations/versions/0ca9e5f1dacd_rename_to_schemas_allowed_for_file_.py
new file mode 100644
index 0000000000000..5a2a4b94f4208
--- /dev/null
+++ b/superset/migrations/versions/0ca9e5f1dacd_rename_to_schemas_allowed_for_file_.py
@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""rename to schemas_allowed_for_file_upload in dbs.extra
+
+Revision ID: 0ca9e5f1dacd
+Revises: b92d69a6643c
+Create Date: 2021-11-11 04:18:26.171851
+
+"""
+
+# revision identifiers, used by Alembic.
+revision = "0ca9e5f1dacd" +down_revision = "b92d69a6643c" + +import json +import logging + +from alembic import op +from sqlalchemy import Column, Integer, Text +from sqlalchemy.ext.declarative import declarative_base + +from superset import db + +Base = declarative_base() + + +class Database(Base): + + __tablename__ = "dbs" + id = Column(Integer, primary_key=True) + extra = Column(Text) + + +def upgrade(): + bind = op.get_bind() + session = db.Session(bind=bind) + + for database in session.query(Database).all(): + try: + extra = json.loads(database.extra) + except json.decoder.JSONDecodeError as ex: + logging.warning(str(ex)) + continue + + if "schemas_allowed_for_csv_upload" in extra: + extra["schemas_allowed_for_file_upload"] = extra.pop( + "schemas_allowed_for_csv_upload" + ) + + database.extra = json.dumps(extra) + + session.commit() + session.close() + + +def downgrade(): + bind = op.get_bind() + session = db.Session(bind=bind) + + for database in session.query(Database).all(): + try: + extra = json.loads(database.extra) + except json.decoder.JSONDecodeError as ex: + logging.warning(str(ex)) + continue + + if "schemas_allowed_for_file_upload" in extra: + extra["schemas_allowed_for_csv_upload"] = extra.pop( + "schemas_allowed_for_file_upload" + ) + + database.extra = json.dumps(extra) + + session.commit() + session.close() From d8851c9a893aa2b231f2634d9d083e8badc0601f Mon Sep 17 00:00:00 2001 From: ofekisr <35701650+ofekisr@users.noreply.github.com> Date: Sun, 14 Nov 2021 23:35:23 +0200 Subject: [PATCH 2/3] refactor(TestChartApi): move chart data api tests into TestChartDataApi (#17407) * refactor charts api tests * move new added test * refactor charts api tests --- tests/integration_tests/charts/api_tests.py | 878 +----------------- .../integration_tests/charts/data/__init__.py | 16 + .../charts/data/api_tests.py | 822 ++++++++++++++++ 3 files changed, 852 insertions(+), 864 deletions(-) create mode 100644 tests/integration_tests/charts/data/__init__.py create mode 100644 tests/integration_tests/charts/data/api_tests.py diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py index ebcc21c72bd62..027788dba90bc 100644 --- a/tests/integration_tests/charts/api_tests.py +++ b/tests/integration_tests/charts/api_tests.py @@ -17,20 +17,10 @@ # isort:skip_file """Unit tests for Superset""" import json -import unittest from datetime import datetime from io import BytesIO -from typing import Optional, List -from unittest import mock from zipfile import is_zipfile, ZipFile -from tests.integration_tests.conftest import with_feature_flags -from superset.models.sql_lab import Query -from tests.integration_tests.insert_chart_mixin import InsertChartMixin -from tests.integration_tests.fixtures.birth_names_dashboard import ( - load_birth_names_dashboard_with_slices, -) - import humanize import prison import pytest @@ -38,36 +28,23 @@ from sqlalchemy import and_ from sqlalchemy.sql import func -from tests.integration_tests.fixtures.world_bank_dashboard import ( - load_world_bank_dashboard_with_slices, -) -from tests.integration_tests.test_app import app -from superset import security_manager -from superset.charts.commands.data import ChartDataCommand -from superset.connectors.sqla.models import SqlaTable, TableColumn -from superset.errors import SupersetErrorType -from superset.extensions import async_query_manager, cache_manager, db -from superset.models.annotations import AnnotationLayer +from superset.connectors.sqla.models import SqlaTable 
+from superset.extensions import cache_manager, db from superset.models.core import Database, FavStar, FavStarClassName from superset.models.dashboard import Dashboard from superset.models.reports import ReportSchedule, ReportScheduleType from superset.models.slice import Slice -from superset.utils.core import ( - AnnotationType, - get_example_database, - get_example_default_schema, - get_main_database, - AdhocMetricExpressionType, -) -from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType +from superset.utils.core import get_example_default_schema from tests.integration_tests.base_api_tests import ApiOwnersTestCaseMixin -from tests.integration_tests.base_tests import ( - SupersetTestCase, - post_assert_metric, - test_client, +from tests.integration_tests.base_tests import SupersetTestCase +from tests.integration_tests.insert_chart_mixin import InsertChartMixin +from tests.integration_tests.fixtures.birth_names_dashboard import ( + load_birth_names_dashboard_with_slices, +) +from tests.integration_tests.fixtures.energy_dashboard import ( + load_energy_table_with_slice, ) - from tests.integration_tests.fixtures.importexport import ( chart_config, chart_metadata_config, @@ -75,17 +52,13 @@ dataset_config, dataset_metadata_config, ) -from tests.integration_tests.fixtures.energy_dashboard import ( - load_energy_table_with_slice, -) -from tests.integration_tests.fixtures.query_context import ( - get_query_context, - ANNOTATION_LAYERS, -) from tests.integration_tests.fixtures.unicode_dashboard import ( load_unicode_dashboard_with_slice, ) -from tests.integration_tests.annotation_layers.fixtures import create_annotation_layers +from tests.integration_tests.fixtures.world_bank_dashboard import ( + load_world_bank_dashboard_with_slices, +) +from tests.integration_tests.test_app import app from tests.integration_tests.utils.get_dashboards import get_dashboards_ids CHART_DATA_URI = "api/v1/chart/data" @@ -1067,641 +1040,6 @@ def test_get_charts_no_data_access(self): data = json.loads(rv.data.decode("utf-8")) self.assertEqual(data["count"], 0) - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_chart_data_simple(self): - """ - Chart data API: Test chart data query - """ - self.login(username="admin") - request_payload = get_query_context("birth_names") - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") - self.assertEqual(rv.status_code, 200) - data = json.loads(rv.data.decode("utf-8")) - expected_row_count = self.get_expected_row_count("client_id_1") - self.assertEqual(data["result"][0]["rowcount"], expected_row_count) - - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_chart_data_get_no_query_context(self): - """ - Chart data API: Test GET endpoint when query context is null - """ - self.login(username="admin") - chart = db.session.query(Slice).filter_by(slice_name="Genders").one() - rv = self.get_assert_metric(f"api/v1/chart/{chart.id}/data/", "get_data") - data = json.loads(rv.data.decode("utf-8")) - assert data == { - "message": "Chart has no query context saved. Please save the chart again." 
- } - - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_chart_data_get(self): - """ - Chart data API: Test GET endpoint - """ - self.login(username="admin") - chart = db.session.query(Slice).filter_by(slice_name="Genders").one() - chart.query_context = json.dumps( - { - "datasource": {"id": chart.table.id, "type": "table"}, - "force": False, - "queries": [ - { - "time_range": "1900-01-01T00:00:00 : 2000-01-01T00:00:00", - "granularity": "ds", - "filters": [], - "extras": { - "time_range_endpoints": ["inclusive", "exclusive"], - "having": "", - "having_druid": [], - "where": "", - }, - "applied_time_extras": {}, - "columns": ["gender"], - "metrics": ["sum__num"], - "orderby": [["sum__num", False]], - "annotation_layers": [], - "row_limit": 50000, - "timeseries_limit": 0, - "order_desc": True, - "url_params": {}, - "custom_params": {}, - "custom_form_data": {}, - } - ], - "result_format": "json", - "result_type": "full", - } - ) - rv = self.get_assert_metric(f"api/v1/chart/{chart.id}/data/", "get_data") - data = json.loads(rv.data.decode("utf-8")) - assert data["result"][0]["status"] == "success" - assert data["result"][0]["rowcount"] == 2 - - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_chart_data_applied_time_extras(self): - """ - Chart data API: Test chart data query with applied time extras - """ - self.login(username="admin") - request_payload = get_query_context("birth_names") - request_payload["queries"][0]["applied_time_extras"] = { - "__time_range": "100 years ago : now", - "__time_origin": "now", - } - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") - self.assertEqual(rv.status_code, 200) - data = json.loads(rv.data.decode("utf-8")) - self.assertEqual( - data["result"][0]["applied_filters"], - [ - {"column": "gender"}, - {"column": "num"}, - {"column": "name"}, - {"column": "__time_range"}, - ], - ) - self.assertEqual( - data["result"][0]["rejected_filters"], - [{"column": "__time_origin", "reason": "not_druid_datasource"},], - ) - expected_row_count = self.get_expected_row_count("client_id_2") - self.assertEqual(data["result"][0]["rowcount"], expected_row_count) - - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_chart_data_limit_offset(self): - """ - Chart data API: Test chart data query with limit and offset - """ - self.login(username="admin") - request_payload = get_query_context("birth_names") - request_payload["queries"][0]["row_limit"] = 5 - request_payload["queries"][0]["row_offset"] = 0 - request_payload["queries"][0]["orderby"] = [["name", True]] - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") - response_payload = json.loads(rv.data.decode("utf-8")) - result = response_payload["result"][0] - self.assertEqual(result["rowcount"], 5) - - # TODO: fix offset for presto DB - if get_example_database().backend == "presto": - return - - # ensure that offset works properly - offset = 2 - expected_name = result["data"][offset]["name"] - request_payload["queries"][0]["row_offset"] = offset - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") - response_payload = json.loads(rv.data.decode("utf-8")) - result = response_payload["result"][0] - self.assertEqual(result["rowcount"], 5) - self.assertEqual(result["data"][0]["name"], expected_name) - - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - @mock.patch( - "superset.common.query_object.config", {**app.config, "ROW_LIMIT": 7}, - ) - def 
test_chart_data_default_row_limit(self): - """ - Chart data API: Ensure row count doesn't exceed default limit - """ - self.login(username="admin") - request_payload = get_query_context("birth_names") - del request_payload["queries"][0]["row_limit"] - - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") - response_payload = json.loads(rv.data.decode("utf-8")) - result = response_payload["result"][0] - self.assertEqual(result["rowcount"], 7) - - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - @mock.patch( - "superset.utils.core.current_app.config", {**app.config, "SQL_MAX_ROW": 10}, - ) - def test_chart_data_sql_max_row_limit(self): - """ - Chart data API: Ensure row count doesn't exceed max global row limit - """ - self.login(username="admin") - request_payload = get_query_context("birth_names") - request_payload["queries"][0]["row_limit"] = 10000000 - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") - response_payload = json.loads(rv.data.decode("utf-8")) - result = response_payload["result"][0] - self.assertEqual(result["rowcount"], 10) - - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - @mock.patch( - "superset.common.query_object.config", {**app.config, "SAMPLES_ROW_LIMIT": 5}, - ) - def test_chart_data_sample_default_limit(self): - """ - Chart data API: Ensure sample response row count defaults to config defaults - """ - self.login(username="admin") - request_payload = get_query_context("birth_names") - request_payload["result_type"] = ChartDataResultType.SAMPLES - del request_payload["queries"][0]["row_limit"] - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") - response_payload = json.loads(rv.data.decode("utf-8")) - result = response_payload["result"][0] - self.assertEqual(result["rowcount"], 5) - - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - @mock.patch( - "superset.common.query_actions.config", - {**app.config, "SAMPLES_ROW_LIMIT": 5, "SQL_MAX_ROW": 15}, - ) - def test_chart_data_sample_custom_limit(self): - """ - Chart data API: Ensure requested sample response row count is between - default and SQL max row limit - """ - self.login(username="admin") - request_payload = get_query_context("birth_names") - request_payload["result_type"] = ChartDataResultType.SAMPLES - request_payload["queries"][0]["row_limit"] = 10 - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") - response_payload = json.loads(rv.data.decode("utf-8")) - result = response_payload["result"][0] - self.assertEqual(result["rowcount"], 10) - - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - @mock.patch( - "superset.utils.core.current_app.config", {**app.config, "SQL_MAX_ROW": 5}, - ) - def test_chart_data_sql_max_row_sample_limit(self): - """ - Chart data API: Ensure requested sample response row count doesn't - exceed SQL max row limit - """ - self.login(username="admin") - request_payload = get_query_context("birth_names") - request_payload["result_type"] = ChartDataResultType.SAMPLES - request_payload["queries"][0]["row_limit"] = 10000000 - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") - response_payload = json.loads(rv.data.decode("utf-8")) - result = response_payload["result"][0] - self.assertEqual(result["rowcount"], 5) - - def test_chart_data_incorrect_result_type(self): - """ - Chart data API: Test chart data with unsupported result type - """ - self.login(username="admin") - request_payload = 
get_query_context("birth_names") - request_payload["result_type"] = "qwerty" - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") - self.assertEqual(rv.status_code, 400) - - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_chart_data_incorrect_result_format(self): - """ - Chart data API: Test chart data with unsupported result format - """ - self.login(username="admin") - request_payload = get_query_context("birth_names") - request_payload["result_format"] = "qwerty" - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") - self.assertEqual(rv.status_code, 400) - - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_chart_data_invalid_form_data(self): - """ - Chart data API: Test chart data with invalid form_data json - """ - self.login(username="admin") - data = {"form_data": "NOT VALID JSON"} - - rv = self.client.post( - CHART_DATA_URI, data=data, content_type="multipart/form-data" - ) - response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 400) - self.assertEqual(response["message"], "Request is not JSON") - - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_chart_data_query_result_type(self): - """ - Chart data API: Test chart data with query result format - """ - self.login(username="admin") - request_payload = get_query_context("birth_names") - request_payload["result_type"] = ChartDataResultType.QUERY - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") - self.assertEqual(rv.status_code, 200) - - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_chart_data_csv_result_format(self): - """ - Chart data API: Test chart data with CSV result format - """ - self.login(username="admin") - request_payload = get_query_context("birth_names") - request_payload["result_format"] = "csv" - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") - self.assertEqual(rv.status_code, 200) - - # Test chart csv without permission - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_chart_data_csv_result_format_permission_denined(self): - """ - Chart data API: Test chart data with CSV result format - """ - self.login(username="gamma_no_csv") - request_payload = get_query_context("birth_names") - request_payload["result_format"] = "csv" - - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") - self.assertEqual(rv.status_code, 403) - - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_chart_data_mixed_case_filter_op(self): - """ - Chart data API: Ensure mixed case filter operator generates valid result - """ - self.login(username="admin") - request_payload = get_query_context("birth_names") - request_payload["queries"][0]["filters"][0]["op"] = "In" - request_payload["queries"][0]["row_limit"] = 10 - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") - response_payload = json.loads(rv.data.decode("utf-8")) - result = response_payload["result"][0] - self.assertEqual(result["rowcount"], 10) - - @unittest.skip("Failing due to timezone difference") - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_chart_data_dttm_filter(self): - """ - Chart data API: Ensure temporal column filter converts epoch to dttm expression - """ - table = self.get_birth_names_dataset() - if table.database.backend == "presto": - # TODO: date handling on Presto not fully in line with other engine specs - return - - 
self.login(username="admin") - request_payload = get_query_context("birth_names") - request_payload["queries"][0]["time_range"] = "" - dttm = self.get_dttm() - ms_epoch = dttm.timestamp() * 1000 - request_payload["queries"][0]["filters"][0] = { - "col": "ds", - "op": "!=", - "val": ms_epoch, - } - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") - response_payload = json.loads(rv.data.decode("utf-8")) - result = response_payload["result"][0] - - # assert that unconverted timestamp is not present in query - assert str(ms_epoch) not in result["query"] - - # assert that converted timestamp is present in query where supported - dttm_col: Optional[TableColumn] = None - for col in table.columns: - if col.column_name == table.main_dttm_col: - dttm_col = col - if dttm_col: - dttm_expression = table.database.db_engine_spec.convert_dttm( - dttm_col.type, dttm, - ) - self.assertIn(dttm_expression, result["query"]) - else: - raise Exception("ds column not found") - - def test_chart_data_prophet(self): - """ - Chart data API: Ensure prophet post transformation works - """ - pytest.importorskip("prophet") - self.login(username="admin") - request_payload = get_query_context("birth_names") - time_grain = "P1Y" - request_payload["queries"][0]["is_timeseries"] = True - request_payload["queries"][0]["groupby"] = [] - request_payload["queries"][0]["extras"] = {"time_grain_sqla": time_grain} - request_payload["queries"][0]["granularity"] = "ds" - request_payload["queries"][0]["post_processing"] = [ - { - "operation": "prophet", - "options": { - "time_grain": time_grain, - "periods": 3, - "confidence_interval": 0.9, - }, - } - ] - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") - self.assertEqual(rv.status_code, 200) - response_payload = json.loads(rv.data.decode("utf-8")) - result = response_payload["result"][0] - row = result["data"][0] - self.assertIn("__timestamp", row) - self.assertIn("sum__num", row) - self.assertIn("sum__num__yhat", row) - self.assertIn("sum__num__yhat_upper", row) - self.assertIn("sum__num__yhat_lower", row) - self.assertEqual(result["rowcount"], 47) - - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_chart_data_query_missing_filter(self): - """ - Chart data API: Ensure filter referencing missing column is ignored - """ - self.login(username="admin") - request_payload = get_query_context("birth_names") - request_payload["queries"][0]["filters"] = [ - {"col": "non_existent_filter", "op": "==", "val": "foo"}, - ] - request_payload["result_type"] = ChartDataResultType.QUERY - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") - self.assertEqual(rv.status_code, 200) - response_payload = json.loads(rv.data.decode("utf-8")) - assert "non_existent_filter" not in response_payload["result"][0]["query"] - - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_chart_data_no_data(self): - """ - Chart data API: Test chart data with empty result - """ - self.login(username="admin") - request_payload = get_query_context("birth_names") - request_payload["queries"][0]["filters"] = [ - {"col": "gender", "op": "==", "val": "foo"} - ] - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") - self.assertEqual(rv.status_code, 200) - response_payload = json.loads(rv.data.decode("utf-8")) - result = response_payload["result"][0] - self.assertEqual(result["rowcount"], 0) - self.assertEqual(result["data"], []) - - def test_chart_data_incorrect_request(self): - """ - Chart data API: 
Test chart data with invalid SQL - """ - self.login(username="admin") - request_payload = get_query_context("birth_names") - request_payload["queries"][0]["filters"] = [] - # erroneus WHERE-clause - request_payload["queries"][0]["extras"]["where"] = "(gender abc def)" - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") - self.assertEqual(rv.status_code, 400) - - def test_chart_data_with_invalid_datasource(self): - """ - Chart data API: Test chart data query with invalid schema - """ - self.login(username="admin") - payload = get_query_context("birth_names") - payload["datasource"] = "abc" - rv = self.post_assert_metric(CHART_DATA_URI, payload, "data") - self.assertEqual(rv.status_code, 400) - - def test_chart_data_with_invalid_enum_value(self): - """ - Chart data API: Test chart data query with invalid enum value - """ - self.login(username="admin") - payload = get_query_context("birth_names") - payload["queries"][0]["extras"]["time_range_endpoints"] = [ - "abc", - "EXCLUSIVE", - ] - rv = self.client.post(CHART_DATA_URI, json=payload) - self.assertEqual(rv.status_code, 400) - - def test_query_exec_not_allowed(self): - """ - Chart data API: Test chart data query not allowed - """ - self.login(username="gamma") - payload = get_query_context("birth_names") - rv = self.post_assert_metric(CHART_DATA_URI, payload, "data") - self.assertEqual(rv.status_code, 401) - response_payload = json.loads(rv.data.decode("utf-8")) - assert ( - response_payload["errors"][0]["error_type"] - == SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR - ) - - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_chart_data_jinja_filter_request(self): - """ - Chart data API: Ensure request referencing filters via jinja renders a correct query - """ - self.login(username="admin") - request_payload = get_query_context("birth_names") - request_payload["result_type"] = ChartDataResultType.QUERY - request_payload["queries"][0]["filters"] = [ - {"col": "gender", "op": "==", "val": "boy"} - ] - request_payload["queries"][0]["extras"][ - "where" - ] = "('boy' = '{{ filter_values('gender', 'xyz' )[0] }}')" - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") - response_payload = json.loads(rv.data.decode("utf-8")) - result = response_payload["result"][0]["query"] - if get_example_database().backend != "presto": - assert "('boy' = 'boy')" in result - - @with_feature_flags(GLOBAL_ASYNC_QUERIES=True) - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_chart_data_async(self): - """ - Chart data API: Test chart data query (async) - """ - async_query_manager.init_app(app) - self.login(username="admin") - request_payload = get_query_context("birth_names") - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") - self.assertEqual(rv.status_code, 202) - data = json.loads(rv.data.decode("utf-8")) - keys = list(data.keys()) - self.assertCountEqual( - keys, ["channel_id", "job_id", "user_id", "status", "errors", "result_url"] - ) - - @with_feature_flags(GLOBAL_ASYNC_QUERIES=True) - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_chart_data_async_cached_sync_response(self): - """ - Chart data API: Test chart data query returns results synchronously - when results are already cached. 
- """ - async_query_manager.init_app(app) - self.login(username="admin") - - class QueryContext: - result_format = ChartDataResultFormat.JSON - result_type = ChartDataResultType.FULL - - cmd_run_val = { - "query_context": QueryContext(), - "queries": [{"query": "select * from foo"}], - } - - with mock.patch.object( - ChartDataCommand, "run", return_value=cmd_run_val - ) as patched_run: - request_payload = get_query_context("birth_names") - request_payload["result_type"] = ChartDataResultType.FULL - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") - self.assertEqual(rv.status_code, 200) - data = json.loads(rv.data.decode("utf-8")) - patched_run.assert_called_once_with(force_cached=True) - self.assertEqual(data, {"result": [{"query": "select * from foo"}]}) - - @with_feature_flags(GLOBAL_ASYNC_QUERIES=True) - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_chart_data_async_results_type(self): - """ - Chart data API: Test chart data query non-JSON format (async) - """ - async_query_manager.init_app(app) - self.login(username="admin") - request_payload = get_query_context("birth_names") - request_payload["result_type"] = "results" - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") - self.assertEqual(rv.status_code, 200) - - @with_feature_flags(GLOBAL_ASYNC_QUERIES=True) - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_chart_data_async_invalid_token(self): - """ - Chart data API: Test chart data query (async) - """ - async_query_manager.init_app(app) - self.login(username="admin") - request_payload = get_query_context("birth_names") - test_client.set_cookie( - "localhost", app.config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME"], "foo" - ) - rv = test_client.post(CHART_DATA_URI, json=request_payload) - self.assertEqual(rv.status_code, 401) - - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - @with_feature_flags(GLOBAL_ASYNC_QUERIES=True) - @mock.patch("superset.charts.data.api.QueryContextCacheLoader") - def test_chart_data_cache(self, cache_loader): - """ - Chart data cache API: Test chart data async cache request - """ - async_query_manager.init_app(app) - self.login(username="admin") - query_context = get_query_context("birth_names") - cache_loader.load.return_value = query_context - orig_run = ChartDataCommand.run - - def mock_run(self, **kwargs): - assert kwargs["force_cached"] == True - # override force_cached to get result from DB - return orig_run(self, force_cached=False) - - with mock.patch.object(ChartDataCommand, "run", new=mock_run): - rv = self.get_assert_metric( - f"{CHART_DATA_URI}/test-cache-key", "data_from_cache" - ) - data = json.loads(rv.data.decode("utf-8")) - - expected_row_count = self.get_expected_row_count("client_id_3") - self.assertEqual(rv.status_code, 200) - self.assertEqual(data["result"][0]["rowcount"], expected_row_count) - - @with_feature_flags(GLOBAL_ASYNC_QUERIES=True) - @mock.patch("superset.charts.data.api.QueryContextCacheLoader") - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_chart_data_cache_run_failed(self, cache_loader): - """ - Chart data cache API: Test chart data async cache request with run failure - """ - async_query_manager.init_app(app) - self.login(username="admin") - query_context = get_query_context("birth_names") - cache_loader.load.return_value = query_context - rv = self.get_assert_metric( - f"{CHART_DATA_URI}/test-cache-key", "data_from_cache" - ) - data = json.loads(rv.data.decode("utf-8")) 
- - self.assertEqual(rv.status_code, 422) - self.assertEqual(data["message"], "Error loading data from cache") - - @with_feature_flags(GLOBAL_ASYNC_QUERIES=True) - @mock.patch("superset.charts.data.api.QueryContextCacheLoader") - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_chart_data_cache_no_login(self, cache_loader): - """ - Chart data cache API: Test chart data async cache request (no login) - """ - async_query_manager.init_app(app) - query_context = get_query_context("birth_names") - cache_loader.load.return_value = query_context - orig_run = ChartDataCommand.run - - def mock_run(self, **kwargs): - assert kwargs["force_cached"] == True - # override force_cached to get result from DB - return orig_run(self, force_cached=False) - - with mock.patch.object(ChartDataCommand, "run", new=mock_run): - rv = self.client.get(f"{CHART_DATA_URI}/test-cache-key",) - - self.assertEqual(rv.status_code, 401) - - @with_feature_flags(GLOBAL_ASYNC_QUERIES=True) - def test_chart_data_cache_key_error(self): - """ - Chart data cache API: Test chart data async cache request with invalid cache key - """ - async_query_manager.init_app(app) - self.login(username="admin") - rv = self.get_assert_metric( - f"{CHART_DATA_URI}/test-cache-key", "data_from_cache" - ) - - self.assertEqual(rv.status_code, 404) - def test_export_chart(self): """ Chart API: Test export chart @@ -1902,191 +1240,3 @@ def test_import_chart_invalid(self): } ] } - - @pytest.mark.usefixtures( - "create_annotation_layers", "load_birth_names_dashboard_with_slices" - ) - def test_chart_data_annotations(self): - """ - Chart data API: Test chart data query - """ - self.login(username="admin") - request_payload = get_query_context("birth_names") - - annotation_layers = [] - request_payload["queries"][0]["annotation_layers"] = annotation_layers - - # formula - annotation_layers.append(ANNOTATION_LAYERS[AnnotationType.FORMULA]) - - # interval - interval_layer = ( - db.session.query(AnnotationLayer) - .filter(AnnotationLayer.name == "name1") - .one() - ) - interval = ANNOTATION_LAYERS[AnnotationType.INTERVAL] - interval["value"] = interval_layer.id - annotation_layers.append(interval) - - # event - event_layer = ( - db.session.query(AnnotationLayer) - .filter(AnnotationLayer.name == "name2") - .one() - ) - event = ANNOTATION_LAYERS[AnnotationType.EVENT] - event["value"] = event_layer.id - annotation_layers.append(event) - - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") - self.assertEqual(rv.status_code, 200) - data = json.loads(rv.data.decode("utf-8")) - # response should only contain interval and event data, not formula - self.assertEqual(len(data["result"][0]["annotation_data"]), 2) - - def get_expected_row_count(self, client_id: str) -> int: - start_date = datetime.now() - start_date = start_date.replace( - year=start_date.year - 100, hour=0, minute=0, second=0 - ) - - quoted_table_name = self.quote_name("birth_names") - sql = f""" - SELECT COUNT(*) AS rows_count FROM ( - SELECT name AS name, SUM(num) AS sum__num - FROM {quoted_table_name} - WHERE ds >= '{start_date.strftime("%Y-%m-%d %H:%M:%S")}' - AND gender = 'boy' - GROUP BY name - ORDER BY sum__num DESC - LIMIT 100) AS inner__query - """ - resp = self.run_sql(sql, client_id, raise_on_error=True) - db.session.query(Query).delete() - db.session.commit() - return resp["data"][0]["rows_count"] - - def quote_name(self, name: str): - if get_main_database().backend in {"presto", "hive"}: - return 
get_example_database().inspector.engine.dialect.identifier_preparer.quote_identifier( - name - ) - return name - - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_chart_data_rowcount(self): - """ - Chart data API: Query total rows - """ - self.login(username="admin") - request_payload = get_query_context("birth_names") - request_payload["queries"][0]["is_rowcount"] = True - request_payload["queries"][0]["groupby"] = ["name"] - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") - response_payload = json.loads(rv.data.decode("utf-8")) - result = response_payload["result"][0] - expected_row_count = self.get_expected_row_count("client_id_4") - self.assertEqual(result["data"][0]["rowcount"], expected_row_count) - - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_chart_data_timegrains(self): - """ - Chart data API: Query timegrains and columns - """ - self.login(username="admin") - request_payload = get_query_context("birth_names") - request_payload["queries"] = [ - {"result_type": ChartDataResultType.TIMEGRAINS}, - {"result_type": ChartDataResultType.COLUMNS}, - ] - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") - response_payload = json.loads(rv.data.decode("utf-8")) - timegrain_result = response_payload["result"][0] - column_result = response_payload["result"][1] - assert list(timegrain_result["data"][0].keys()) == [ - "name", - "function", - "duration", - ] - assert list(column_result["data"][0].keys()) == [ - "column_name", - "verbose_name", - "dtype", - ] - - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_chart_data_series_limit(self): - """ - Chart data API: Query total rows - """ - SERIES_LIMIT = 5 - self.login(username="admin") - request_payload = get_query_context("birth_names") - request_payload["queries"][0]["columns"] = ["state", "name"] - request_payload["queries"][0]["series_columns"] = ["name"] - request_payload["queries"][0]["series_limit"] = SERIES_LIMIT - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") - response_payload = json.loads(rv.data.decode("utf-8")) - data = response_payload["result"][0]["data"] - unique_names = set(row["name"] for row in data) - self.maxDiff = None - self.assertEqual(len(unique_names), SERIES_LIMIT) - self.assertEqual( - set(column for column in data[0].keys()), {"state", "name", "sum__num"} - ) - - @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_chart_data_virtual_table_with_colons(self): - """ - Chart data API: test query with literal colon characters in query, metrics, - where clause and filters - """ - self.login(username="admin") - owner = self.get_user("admin").id - user = db.session.query(security_manager.user_model).get(owner) - - table = SqlaTable( - table_name="virtual_table_1", - schema=get_example_default_schema(), - owners=[user], - database=get_example_database(), - sql="select ':foo' as foo, ':bar:' as bar, state, num from birth_names", - ) - db.session.add(table) - db.session.commit() - table.fetch_metadata() - - request_payload = get_query_context("birth_names") - request_payload["datasource"] = { - "type": "table", - "id": table.id, - } - request_payload["queries"][0]["columns"] = ["foo", "bar", "state"] - request_payload["queries"][0]["where"] = "':abc' != ':xyz:qwerty'" - request_payload["queries"][0]["orderby"] = None - request_payload["queries"][0]["metrics"] = [ - { - "expressionType": AdhocMetricExpressionType.SQL, - "sqlExpression": "sum(case 
when state = ':asdf' then 0 else 1 end)", - "label": "count", - } - ] - request_payload["queries"][0]["filters"] = [ - {"col": "foo", "op": "!=", "val": ":qwerty:",} - ] - - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") - db.session.delete(table) - db.session.commit() - assert rv.status_code == 200 - response_payload = json.loads(rv.data.decode("utf-8")) - result = response_payload["result"][0] - data = result["data"] - assert {col for col in data[0].keys()} == {"foo", "bar", "state", "count"} - # make sure results and query parameters are unescaped - assert {row["foo"] for row in data} == {":foo"} - assert {row["bar"] for row in data} == {":bar:"} - assert "':asdf'" in result["query"] - assert "':xyz:qwerty'" in result["query"] - assert "':qwerty:'" in result["query"] diff --git a/tests/integration_tests/charts/data/__init__.py b/tests/integration_tests/charts/data/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/integration_tests/charts/data/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py new file mode 100644 index 0000000000000..45d300b7381d6 --- /dev/null +++ b/tests/integration_tests/charts/data/api_tests.py @@ -0,0 +1,822 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# isort:skip_file +"""Unit tests for Superset""" +import json +import unittest +import copy +from datetime import datetime +from typing import Optional +from unittest import mock +from flask import Response +from tests.integration_tests.conftest import with_feature_flags +from superset.models.sql_lab import Query +from tests.integration_tests.base_tests import ( + SupersetTestCase, + test_client, +) +from tests.integration_tests.annotation_layers.fixtures import create_annotation_layers +from tests.integration_tests.fixtures.birth_names_dashboard import ( + load_birth_names_dashboard_with_slices, +) +from tests.integration_tests.test_app import app + +import pytest + +from superset.charts.commands.data import ChartDataCommand +from superset.connectors.sqla.models import TableColumn, SqlaTable +from superset.errors import SupersetErrorType +from superset.extensions import async_query_manager, db +from superset.models.annotations import AnnotationLayer +from superset.models.slice import Slice +from superset.utils.core import ( + AnnotationType, + get_example_database, + get_example_default_schema, + get_main_database, + AdhocMetricExpressionType, +) +from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType + + +from tests.integration_tests.fixtures.query_context import ( + get_query_context, + ANNOTATION_LAYERS, +) + + +CHART_DATA_URI = "api/v1/chart/data" +CHARTS_FIXTURE_COUNT = 10 + + +class BaseTestChartDataApi(SupersetTestCase): + query_context_payload_template = None + + def setUp(self) -> None: + self.login("admin") + if self.query_context_payload_template is None: + BaseTestChartDataApi.query_context_payload_template = get_query_context( + "birth_names" + ) + self.query_context_payload = copy.deepcopy(self.query_context_payload_template) + + def get_expected_row_count(self, client_id: str) -> int: + start_date = datetime.now() + start_date = start_date.replace( + year=start_date.year - 100, hour=0, minute=0, second=0 + ) + + quoted_table_name = self.quote_name("birth_names") + sql = f""" + SELECT COUNT(*) AS rows_count FROM ( + SELECT name AS name, SUM(num) AS sum__num + FROM {quoted_table_name} + WHERE ds >= '{start_date.strftime("%Y-%m-%d %H:%M:%S")}' + AND gender = 'boy' + GROUP BY name + ORDER BY sum__num DESC + LIMIT 100) AS inner__query + """ + resp = self.run_sql(sql, client_id, raise_on_error=True) + db.session.query(Query).delete() + db.session.commit() + return resp["data"][0]["rows_count"] + + def quote_name(self, name: str): + if get_main_database().backend in {"presto", "hive"}: + return get_example_database().inspector.engine.dialect.identifier_preparer.quote_identifier( + name + ) + return name + + +@pytest.mark.chart_data_flow +class TestPostChartDataApi(BaseTestChartDataApi): + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_with_valid_qc__data_is_returned(self): + # arrange + expected_row_count = self.get_expected_row_count("client_id_1") + # act + rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") + # assert + assert rv.status_code == 200 + self.assert_row_count(rv, expected_row_count) + + @staticmethod + def assert_row_count(rv: Response, expected_row_count: int): + assert rv.json["result"][0]["rowcount"] == expected_row_count + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_without_row_limit__row_count_as_default_row_limit(self): + # arrange + row_limit_before = app.config["ROW_LIMIT"] + expected_row_count = 7 + app.config["ROW_LIMIT"] = 
expected_row_count
+        del self.query_context_payload["queries"][0]["row_limit"]
+        # act
+        rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
+        # assert
+        self.assert_row_count(rv, expected_row_count)
+        # cleanup
+        app.config["ROW_LIMIT"] = row_limit_before
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_as_samples_without_row_limit__row_count_as_default_samples_row_limit(self):
+        # arrange
+        samples_row_limit_before = app.config["SAMPLES_ROW_LIMIT"]
+        expected_row_count = 5
+        app.config["SAMPLES_ROW_LIMIT"] = expected_row_count
+        self.query_context_payload["result_type"] = ChartDataResultType.SAMPLES
+        del self.query_context_payload["queries"][0]["row_limit"]
+
+        # act
+        rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
+
+        # assert
+        self.assert_row_count(rv, expected_row_count)
+
+        # cleanup
+        app.config["SAMPLES_ROW_LIMIT"] = samples_row_limit_before
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_with_row_limit_bigger_than_sql_max_row__rowcount_as_sql_max_row(self):
+        # arrange
+        expected_row_count = 10
+        max_row_before = app.config["SQL_MAX_ROW"]
+        app.config["SQL_MAX_ROW"] = expected_row_count
+        self.query_context_payload["queries"][0]["row_limit"] = 10000000
+
+        # act
+        rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
+
+        # assert
+        self.assert_row_count(rv, expected_row_count)
+
+        # cleanup
+        app.config["SQL_MAX_ROW"] = max_row_before
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    @mock.patch(
+        "superset.utils.core.current_app.config", {**app.config, "SQL_MAX_ROW": 5},
+    )
+    def test_as_samples_with_row_limit_bigger_than_sql_max_row__rowcount_as_sql_max_row(
+        self,
+    ):
+        expected_row_count = app.config["SQL_MAX_ROW"]
+        self.query_context_payload["result_type"] = ChartDataResultType.SAMPLES
+        self.query_context_payload["queries"][0]["row_limit"] = 10000000
+        rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
+
+        # assert
+        self.assert_row_count(rv, expected_row_count)
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    @mock.patch(
+        "superset.common.query_actions.config",
+        {**app.config, "SAMPLES_ROW_LIMIT": 5, "SQL_MAX_ROW": 15},
+    )
+    def test_with_row_limit_as_samples__rowcount_as_row_limit(self):
+
+        expected_row_count = 10
+        self.query_context_payload["result_type"] = ChartDataResultType.SAMPLES
+        self.query_context_payload["queries"][0]["row_limit"] = expected_row_count
+
+        rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
+
+        # assert
+        self.assert_row_count(rv, expected_row_count)
+
+    def test_with_incorrect_result_type__400(self):
+        self.query_context_payload["result_type"] = "qwerty"
+        rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
+
+        assert rv.status_code == 400
+
+    def test_with_incorrect_result_format__400(self):
+        self.query_context_payload["result_format"] = "qwerty"
+        rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
+        assert rv.status_code == 400
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_with_invalid_payload__400(self):
+
+        invalid_query_context = {"form_data": "NOT VALID JSON"}
+
+        rv = self.client.post(
+            CHART_DATA_URI,
+            data=invalid_query_context,
+            content_type="multipart/form-data",
+        )
+
+        assert rv.status_code == 400
+        assert rv.json["message"] == "Request is not JSON"
+
+    
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_with_query_result_type__200(self): + self.query_context_payload["result_type"] = ChartDataResultType.QUERY + rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") + assert rv.status_code == 200 + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_with_csv_result_format(self): + """ + Chart data API: Test chart data with CSV result format + """ + self.query_context_payload["result_format"] = "csv" + rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") + assert rv.status_code == 200 + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_with_csv_result_format_when_actor_not_permitted_for_csv__403(self): + """ + Chart data API: Test chart data with CSV result format + """ + self.logout() + self.login(username="gamma_no_csv") + self.query_context_payload["result_format"] = "csv" + + rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") + assert rv.status_code == 403 + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_with_row_limit_and_offset__row_limit_and_offset_were_applied(self): + """ + Chart data API: Test chart data query with limit and offset + """ + self.query_context_payload["queries"][0]["row_limit"] = 5 + self.query_context_payload["queries"][0]["row_offset"] = 0 + self.query_context_payload["queries"][0]["orderby"] = [["name", True]] + + rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") + self.assert_row_count(rv, 5) + result = rv.json["result"][0] + + # TODO: fix offset for presto DB + if get_example_database().backend == "presto": + return + + # ensure that offset works properly + offset = 2 + expected_name = result["data"][offset]["name"] + self.query_context_payload["queries"][0]["row_offset"] = offset + rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") + result = rv.json["result"][0] + assert result["rowcount"] == 5 + assert result["data"][0]["name"] == expected_name + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_chart_data_applied_time_extras(self): + """ + Chart data API: Test chart data query with applied time extras + """ + self.query_context_payload["queries"][0]["applied_time_extras"] = { + "__time_range": "100 years ago : now", + "__time_origin": "now", + } + rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") + self.assertEqual(rv.status_code, 200) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual( + data["result"][0]["applied_filters"], + [ + {"column": "gender"}, + {"column": "num"}, + {"column": "name"}, + {"column": "__time_range"}, + ], + ) + self.assertEqual( + data["result"][0]["rejected_filters"], + [{"column": "__time_origin", "reason": "not_druid_datasource"},], + ) + expected_row_count = self.get_expected_row_count("client_id_2") + self.assertEqual(data["result"][0]["rowcount"], expected_row_count) + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_with_in_op_filter__data_is_returned(self): + """ + Chart data API: Ensure mixed case filter operator generates valid result + """ + expected_row_count = 10 + self.query_context_payload["queries"][0]["filters"][0]["op"] = "In" + self.query_context_payload["queries"][0]["row_limit"] = expected_row_count + rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") + + 
self.assert_row_count(rv, expected_row_count)
+
+    @unittest.skip("Failing due to timezone difference")
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_chart_data_dttm_filter(self):
+        """
+        Chart data API: Ensure temporal column filter converts epoch to dttm expression
+        """
+        table = self.get_birth_names_dataset()
+        if table.database.backend == "presto":
+            # TODO: date handling on Presto not fully in line with other engine specs
+            return
+
+        self.query_context_payload["queries"][0]["time_range"] = ""
+        dttm = self.get_dttm()
+        ms_epoch = dttm.timestamp() * 1000
+        self.query_context_payload["queries"][0]["filters"][0] = {
+            "col": "ds",
+            "op": "!=",
+            "val": ms_epoch,
+        }
+        rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
+        response_payload = json.loads(rv.data.decode("utf-8"))
+        result = response_payload["result"][0]
+
+        # assert that unconverted timestamp is not present in query
+        assert str(ms_epoch) not in result["query"]
+
+        # assert that converted timestamp is present in query where supported
+        dttm_col: Optional[TableColumn] = None
+        for col in table.columns:
+            if col.column_name == table.main_dttm_col:
+                dttm_col = col
+        if dttm_col:
+            dttm_expression = table.database.db_engine_spec.convert_dttm(
+                dttm_col.type, dttm,
+            )
+            self.assertIn(dttm_expression, result["query"])
+        else:
+            raise Exception("ds column not found")
+
+    def test_chart_data_prophet(self):
+        """
+        Chart data API: Ensure prophet post transformation works
+        """
+        pytest.importorskip("prophet")
+        time_grain = "P1Y"
+        self.query_context_payload["queries"][0]["is_timeseries"] = True
+        self.query_context_payload["queries"][0]["groupby"] = []
+        self.query_context_payload["queries"][0]["extras"] = {
+            "time_grain_sqla": time_grain
+        }
+        self.query_context_payload["queries"][0]["granularity"] = "ds"
+        self.query_context_payload["queries"][0]["post_processing"] = [
+            {
+                "operation": "prophet",
+                "options": {
+                    "time_grain": time_grain,
+                    "periods": 3,
+                    "confidence_interval": 0.9,
+                },
+            }
+        ]
+        rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
+        self.assertEqual(rv.status_code, 200)
+        response_payload = json.loads(rv.data.decode("utf-8"))
+        result = response_payload["result"][0]
+        row = result["data"][0]
+        self.assertIn("__timestamp", row)
+        self.assertIn("sum__num", row)
+        self.assertIn("sum__num__yhat", row)
+        self.assertIn("sum__num__yhat_upper", row)
+        self.assertIn("sum__num__yhat_lower", row)
+        self.assertEqual(result["rowcount"], 47)
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_with_query_result_type_and_non_existent_filter__filter_omitted(self):
+        self.query_context_payload["queries"][0]["filters"] = [
+            {"col": "non_existent_filter", "op": "==", "val": "foo"},
+        ]
+        self.query_context_payload["result_type"] = ChartDataResultType.QUERY
+        rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
+        assert rv.status_code == 200
+        assert "non_existent_filter" not in rv.json["result"][0]["query"]
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_with_filter_supposed_to_return_empty_data__no_data_returned(self):
+        self.query_context_payload["queries"][0]["filters"] = [
+            {"col": "gender", "op": "==", "val": "foo"}
+        ]
+        rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
+
+        assert rv.status_code == 200
+        assert rv.json["result"][0]["data"] == []
+        self.assert_row_count(rv, 0)
+
+    def test_with_invalid_where_parameter__400(self):
+        self.query_context_payload["queries"][0]["filters"] = []
+        # erroneous WHERE clause
+        self.query_context_payload["queries"][0]["extras"]["where"] = "(gender abc def)"
+
+        rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
+
+        assert rv.status_code == 400
+
+    def test_with_invalid_datasource__400(self):
+        self.query_context_payload["datasource"] = "abc"
+
+        rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
+
+        assert rv.status_code == 400
+
+    def test_with_invalid_time_range_endpoints_enum_value__400(self):
+        self.query_context_payload["queries"][0]["extras"]["time_range_endpoints"] = [
+            "abc",
+            "EXCLUSIVE",
+        ]
+
+        rv = self.client.post(CHART_DATA_URI, json=self.query_context_payload)
+
+        assert rv.status_code == 400
+
+    def test_with_not_permitted_actor__401(self):
+        """
+        Chart data API: Test chart data query not allowed
+        """
+        self.logout()
+        self.login(username="gamma")
+        rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
+
+        assert rv.status_code == 401
+        assert (
+            rv.json["errors"][0]["error_type"]
+            == SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR
+        )
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_when_where_parameter_is_template_and_query_result_type__query_is_templated(
+        self,
+    ):
+
+        self.query_context_payload["result_type"] = ChartDataResultType.QUERY
+        self.query_context_payload["queries"][0]["filters"] = [
+            {"col": "gender", "op": "==", "val": "boy"}
+        ]
+        self.query_context_payload["queries"][0]["extras"][
+            "where"
+        ] = "('boy' = '{{ filter_values('gender', 'xyz' )[0] }}')"
+        rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
+        result = rv.json["result"][0]["query"]
+        if get_example_database().backend != "presto":
+            assert "('boy' = 'boy')" in result
+
+    @with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_chart_data_async(self):
+        self.logout()
+        async_query_manager.init_app(app)
+        self.login("admin")
+        rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
+        self.assertEqual(rv.status_code, 202)
+        data = json.loads(rv.data.decode("utf-8"))
+        keys = list(data.keys())
+        self.assertCountEqual(
+            keys, ["channel_id", "job_id", "user_id", "status", "errors", "result_url"]
+        )
+
+    @with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_chart_data_async_cached_sync_response(self):
+        """
+        Chart data API: Test chart data query returns results synchronously
+        when results are already cached.
+        """
+        async_query_manager.init_app(app)
+
+        class QueryContext:
+            result_format = ChartDataResultFormat.JSON
+            result_type = ChartDataResultType.FULL
+
+        cmd_run_val = {
+            "query_context": QueryContext(),
+            "queries": [{"query": "select * from foo"}],
+        }
+
+        with mock.patch.object(
+            ChartDataCommand, "run", return_value=cmd_run_val
+        ) as patched_run:
+            self.query_context_payload["result_type"] = ChartDataResultType.FULL
+            rv = self.post_assert_metric(
+                CHART_DATA_URI, self.query_context_payload, "data"
+            )
+        self.assertEqual(rv.status_code, 200)
+        data = json.loads(rv.data.decode("utf-8"))
+        patched_run.assert_called_once_with(force_cached=True)
+        self.assertEqual(data, {"result": [{"query": "select * from foo"}]})
+
+    @with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_chart_data_async_results_type(self):
+        """
+        Chart data API: Test chart data query non-JSON format (async)
+        """
+        async_query_manager.init_app(app)
+        self.query_context_payload["result_type"] = "results"
+        rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
+        self.assertEqual(rv.status_code, 200)
+
+    @with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_chart_data_async_invalid_token(self):
+        """
+        Chart data API: Test chart data query with an invalid token (async)
+        """
+        async_query_manager.init_app(app)
+        test_client.set_cookie(
+            "localhost", app.config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME"], "foo"
+        )
+        rv = test_client.post(CHART_DATA_URI, json=self.query_context_payload)
+        self.assertEqual(rv.status_code, 401)
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_chart_data_rowcount(self):
+        """
+        Chart data API: Query total rows
+        """
+        expected_row_count = self.get_expected_row_count("client_id_4")
+        self.query_context_payload["queries"][0]["is_rowcount"] = True
+        self.query_context_payload["queries"][0]["groupby"] = ["name"]
+        rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
+
+        assert rv.json["result"][0]["data"][0]["rowcount"] == expected_row_count
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_with_timegrains_and_columns_result_types(self):
+        """
+        Chart data API: Query timegrains and columns
+        """
+        self.query_context_payload["queries"] = [
+            {"result_type": ChartDataResultType.TIMEGRAINS},
+            {"result_type": ChartDataResultType.COLUMNS},
+        ]
+        result = self.post_assert_metric(
+            CHART_DATA_URI, self.query_context_payload, "data"
+        ).json["result"]
+
+        timegrain_data_keys = result[0]["data"][0].keys()
+        column_data_keys = result[1]["data"][0].keys()
+        assert list(timegrain_data_keys) == [
+            "name",
+            "function",
+            "duration",
+        ]
+        assert list(column_data_keys) == [
+            "column_name",
+            "verbose_name",
+            "dtype",
+        ]
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_with_series_limit(self):
+        SERIES_LIMIT = 5
+        self.query_context_payload["queries"][0]["columns"] = ["state", "name"]
+        self.query_context_payload["queries"][0]["series_columns"] = ["name"]
+        self.query_context_payload["queries"][0]["series_limit"] = SERIES_LIMIT
+
+        rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
+
+        data = rv.json["result"][0]["data"]
+
+        unique_names = set(row["name"] for row in data)
+        self.maxDiff = None
+        self.assertEqual(len(unique_names), SERIES_LIMIT)
+        self.assertEqual(
+            set(column for column in data[0].keys()), {"state", "name", "sum__num"}
+        )
+
+    @pytest.mark.usefixtures(
+        "create_annotation_layers", "load_birth_names_dashboard_with_slices"
+    )
+    def test_with_annotations_layers__annotations_data_returned(self):
+        """
+        Chart data API: Test chart data query with annotation layers
+        """
+
+        annotation_layers = []
+        self.query_context_payload["queries"][0][
+            "annotation_layers"
+        ] = annotation_layers
+
+        # formula
+        annotation_layers.append(ANNOTATION_LAYERS[AnnotationType.FORMULA])
+
+        # interval
+        interval_layer = (
+            db.session.query(AnnotationLayer)
+            .filter(AnnotationLayer.name == "name1")
+            .one()
+        )
+        interval = ANNOTATION_LAYERS[AnnotationType.INTERVAL]
+        interval["value"] = interval_layer.id
+        annotation_layers.append(interval)
+
+        # event
+        event_layer = (
+            db.session.query(AnnotationLayer)
+            .filter(AnnotationLayer.name == "name2")
+            .one()
+        )
+        event = ANNOTATION_LAYERS[AnnotationType.EVENT]
+        event["value"] = event_layer.id
+        annotation_layers.append(event)
+
+        rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
+        self.assertEqual(rv.status_code, 200)
+        data = json.loads(rv.data.decode("utf-8"))
+        # response should only contain interval and event data, not formula
+        self.assertEqual(len(data["result"][0]["annotation_data"]), 2)
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_with_virtual_table_with_colons_as_datasource(self):
+        """
+        Chart data API: test query with literal colon characters in query, metrics,
+        where clause and filters
+        """
+        owner = self.get_user("admin")
+        table = SqlaTable(
+            table_name="virtual_table_1",
+            schema=get_example_default_schema(),
+            owners=[owner],
+            database=get_example_database(),
+            sql="select ':foo' as foo, ':bar:' as bar, state, num from birth_names",
+        )
+        db.session.add(table)
+        db.session.commit()
+        table.fetch_metadata()
+
+        request_payload = self.query_context_payload
+        request_payload["datasource"] = {
+            "type": "table",
+            "id": table.id,
+        }
+        request_payload["queries"][0]["columns"] = ["foo", "bar", "state"]
+        request_payload["queries"][0]["where"] = "':abc' != ':xyz:qwerty'"
+        request_payload["queries"][0]["orderby"] = None
+        request_payload["queries"][0]["metrics"] = [
+            {
+                "expressionType": AdhocMetricExpressionType.SQL,
+                "sqlExpression": "sum(case when state = ':asdf' then 0 else 1 end)",
+                "label": "count",
+            }
+        ]
+        request_payload["queries"][0]["filters"] = [
+            {"col": "foo", "op": "!=", "val": ":qwerty:",}
+        ]
+
+        rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
+        db.session.delete(table)
+        db.session.commit()
+        assert rv.status_code == 200
+        result = rv.json["result"][0]
+        data = result["data"]
+        assert {col for col in data[0].keys()} == {"foo", "bar", "state", "count"}
+        # make sure results and query parameters are unescaped
+        assert {row["foo"] for row in data} == {":foo"}
+        assert {row["bar"] for row in data} == {":bar:"}
+        assert "':asdf'" in result["query"]
+        assert "':xyz:qwerty'" in result["query"]
+        assert "':qwerty:'" in result["query"]
+
+
+@pytest.mark.chart_data_flow
+class TestGetChartDataApi(BaseTestChartDataApi):
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_get_data_when_query_context_is_null(self):
+        """
+        Chart data API: Test GET endpoint when query context is null
+        """
+        chart = db.session.query(Slice).filter_by(slice_name="Genders").one()
+        rv = self.get_assert_metric(f"api/v1/chart/{chart.id}/data/", "get_data")
+        data = json.loads(rv.data.decode("utf-8"))
+        assert data == {
+            "message": "Chart has no query context saved. Please save the chart again."
+        }
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_chart_data_get(self):
+        """
+        Chart data API: Test GET endpoint
+        """
+        chart = db.session.query(Slice).filter_by(slice_name="Genders").one()
+        chart.query_context = json.dumps(
+            {
+                "datasource": {"id": chart.table.id, "type": "table"},
+                "force": False,
+                "queries": [
+                    {
+                        "time_range": "1900-01-01T00:00:00 : 2000-01-01T00:00:00",
+                        "granularity": "ds",
+                        "filters": [],
+                        "extras": {
+                            "time_range_endpoints": ["inclusive", "exclusive"],
+                            "having": "",
+                            "having_druid": [],
+                            "where": "",
+                        },
+                        "applied_time_extras": {},
+                        "columns": ["gender"],
+                        "metrics": ["sum__num"],
+                        "orderby": [["sum__num", False]],
+                        "annotation_layers": [],
+                        "row_limit": 50000,
+                        "timeseries_limit": 0,
+                        "order_desc": True,
+                        "url_params": {},
+                        "custom_params": {},
+                        "custom_form_data": {},
+                    }
+                ],
+                "result_format": "json",
+                "result_type": "full",
+            }
+        )
+        rv = self.get_assert_metric(f"api/v1/chart/{chart.id}/data/", "get_data")
+        data = json.loads(rv.data.decode("utf-8"))
+        assert data["result"][0]["status"] == "success"
+        assert data["result"][0]["rowcount"] == 2
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    @with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
+    @mock.patch("superset.charts.data.api.QueryContextCacheLoader")
+    def test_chart_data_cache(self, cache_loader):
+        """
+        Chart data cache API: Test chart data async cache request
+        """
+        async_query_manager.init_app(app)
+        cache_loader.load.return_value = self.query_context_payload
+        orig_run = ChartDataCommand.run
+
+        def mock_run(self, **kwargs):
+            assert kwargs["force_cached"] == True
+            # override force_cached to get result from DB
+            return orig_run(self, force_cached=False)
+
+        with mock.patch.object(ChartDataCommand, "run", new=mock_run):
+            rv = self.get_assert_metric(
+                f"{CHART_DATA_URI}/test-cache-key", "data_from_cache"
+            )
+            data = json.loads(rv.data.decode("utf-8"))
+
+        expected_row_count = self.get_expected_row_count("client_id_3")
+        self.assertEqual(rv.status_code, 200)
+        self.assertEqual(data["result"][0]["rowcount"], expected_row_count)
+
+    @with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
+    @mock.patch("superset.charts.data.api.QueryContextCacheLoader")
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_chart_data_cache_run_failed(self, cache_loader):
+        """
+        Chart data cache API: Test chart data async cache request with run failure
+        """
+        async_query_manager.init_app(app)
+        cache_loader.load.return_value = self.query_context_payload
+        rv = self.get_assert_metric(
+            f"{CHART_DATA_URI}/test-cache-key", "data_from_cache"
+        )
+        data = json.loads(rv.data.decode("utf-8"))
+
+        self.assertEqual(rv.status_code, 422)
+        self.assertEqual(data["message"], "Error loading data from cache")
+
+    @with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
+    @mock.patch("superset.charts.data.api.QueryContextCacheLoader")
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_chart_data_cache_no_login(self, cache_loader):
+        """
+        Chart data cache API: Test chart data async cache request (no login)
+        """
+        self.logout()
+        async_query_manager.init_app(app)
+        cache_loader.load.return_value = self.query_context_payload
+        orig_run = ChartDataCommand.run
+
+        def mock_run(self, **kwargs):
+            assert kwargs["force_cached"] == True
+            # override force_cached to get result from DB
+            return orig_run(self, force_cached=False)
+
+        with mock.patch.object(ChartDataCommand, "run", new=mock_run):
+            rv = self.client.get(f"{CHART_DATA_URI}/test-cache-key",)
+
+        self.assertEqual(rv.status_code, 401)
+
+    @with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
+    def test_chart_data_cache_key_error(self):
+        """
+        Chart data cache API: Test chart data async cache request with invalid cache key
+        """
+        async_query_manager.init_app(app)
+        rv = self.get_assert_metric(
+            f"{CHART_DATA_URI}/test-cache-key", "data_from_cache"
+        )
+
+        self.assertEqual(rv.status_code, 404)

From 5d3e1b5c2cbcec3ee7354049509796a18980dc17 Mon Sep 17 00:00:00 2001
From: ofekisr <35701650+ofekisr@users.noreply.github.com>
Date: Mon, 15 Nov 2021 01:00:08 +0200
Subject: [PATCH 3/3] refactor: ChartDataCommand into two separate commands
 (#17425)

---
 superset/charts/data/api.py                  | 18 ++++++++-----
 .../{commands/data.py => data/commands.py}   | 26 +++++++++----------
 superset/tasks/async_queries.py              |  2 +-
 .../charts/data/api_tests.py                 |  2 +-
 .../tasks/async_queries_tests.py             |  4 +--
 5 files changed, 27 insertions(+), 25 deletions(-)
 rename superset/charts/{commands/data.py => data/commands.py} (88%)

diff --git a/superset/charts/data/api.py b/superset/charts/data/api.py
index 37703339e7f13..534101bae6be1 100644
--- a/superset/charts/data/api.py
+++ b/superset/charts/data/api.py
@@ -28,11 +28,14 @@ from superset import is_feature_enabled, security_manager
 from superset.charts.api import ChartRestApi
-from superset.charts.commands.data import ChartDataCommand
 from superset.charts.commands.exceptions import (
     ChartDataCacheLoadError,
     ChartDataQueryFailedError,
 )
+from superset.charts.data.commands import (
+    ChartDataCommand,
+    CreateAsyncChartDataJobCommand,
+)
 from superset.charts.data.query_context_cache_loader import QueryContextCacheLoader
 from superset.charts.post_processing import apply_post_process
 from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
@@ -145,7 +148,7 @@ def get_data(self, pk: int) -> Response:
             and query_context.result_format == ChartDataResultFormat.JSON
             and query_context.result_type == ChartDataResultType.FULL
         ):
-            return self._run_async(command)
+            return self._run_async(json_body, command)

         try:
             form_data = json.loads(chart.params)
@@ -231,7 +234,7 @@ def data(self) -> Response:
             and query_context.result_format == ChartDataResultFormat.JSON
             and query_context.result_type == ChartDataResultType.FULL
         ):
-            return self._run_async(command)
+            return self._run_async(json_body, command)

         return self._get_data_response(command)

@@ -289,7 +292,9 @@ def data_from_cache(self, cache_key: str) -> Response:

         return self._get_data_response(command, True)

-    def _run_async(self, command: ChartDataCommand) -> Response:
+    def _run_async(
+        self, form_data: Dict[str, Any], command: ChartDataCommand
+    ) -> Response:
         """
         Execute command as an async query.
         """
@@ -309,12 +314,13 @@ def _run_async(self, command: ChartDataCommand) -> Response:
         # Clients will either poll or be notified of query completion,
         # at which point they will call the /data/<cache_key> endpoint
         # to retrieve the results.
+        async_command = CreateAsyncChartDataJobCommand()
         try:
-            command.validate_async_request(request)
+            async_command.validate(request)
         except AsyncQueryTokenException:
             return self.response_401()

-        result = command.run_async(g.user.get_id())
+        result = async_command.run(form_data, g.user.get_id())
         return self.response(202, **result)

     def _send_chart_response(
diff --git a/superset/charts/commands/data.py b/superset/charts/data/commands.py
similarity index 88%
rename from superset/charts/commands/data.py
rename to superset/charts/data/commands.py
index ec63362a5c3d0..d434f79a17101 100644
--- a/superset/charts/commands/data.py
+++ b/superset/charts/data/commands.py
@@ -35,10 +35,7 @@


 class ChartDataCommand(BaseCommand):
-    def __init__(self) -> None:
-        self._form_data: Dict[str, Any]
-        self._query_context: QueryContext
-        self._async_channel_id: str
+    _query_context: QueryContext

     def run(self, **kwargs: Any) -> Dict[str, Any]:
         # caching is handled in query_context.get_df_payload
@@ -66,26 +63,27 @@ def run(self, **kwargs: Any) -> Dict[str, Any]:

         return return_value

-    def run_async(self, user_id: Optional[str]) -> Dict[str, Any]:
-        job_metadata = async_query_manager.init_job(self._async_channel_id, user_id)
-        load_chart_data_into_cache.delay(job_metadata, self._form_data)
-
-        return job_metadata
-
     def set_query_context(self, form_data: Dict[str, Any]) -> QueryContext:
-        self._form_data = form_data
         try:
-            self._query_context = ChartDataQueryContextSchema().load(self._form_data)
+            self._query_context = ChartDataQueryContextSchema().load(form_data)
         except KeyError as ex:
             raise ValidationError("Request is incorrect") from ex
         except ValidationError as error:
             raise error
-
         return self._query_context

     def validate(self) -> None:
         self._query_context.raise_for_access()

-    def validate_async_request(self, request: Request) -> None:
+
+class CreateAsyncChartDataJobCommand:
+    _async_channel_id: str
+
+    def validate(self, request: Request) -> None:
         jwt_data = async_query_manager.parse_jwt_from_request(request)
         self._async_channel_id = jwt_data["channel"]
+
+    def run(self, form_data: Dict[str, Any], user_id: Optional[str]) -> Dict[str, Any]:
+        job_metadata = async_query_manager.init_job(self._async_channel_id, user_id)
+        load_chart_data_into_cache.delay(job_metadata, form_data)
+        return job_metadata
diff --git a/superset/tasks/async_queries.py b/superset/tasks/async_queries.py
index 18094323ec1ec..c50dbb9a94436 100644
--- a/superset/tasks/async_queries.py
+++ b/superset/tasks/async_queries.py
@@ -55,7 +55,7 @@ def load_chart_data_into_cache(
     job_metadata: Dict[str, Any], form_data: Dict[str, Any],
 ) -> None:
     # pylint: disable=import-outside-toplevel
-    from superset.charts.commands.data import ChartDataCommand
+    from superset.charts.data.commands import ChartDataCommand

     try:
         ensure_user_is_set(job_metadata.get("user_id"))
diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py
index 45d300b7381d6..1b2ade28f4360 100644
--- a/tests/integration_tests/charts/data/api_tests.py
+++ b/tests/integration_tests/charts/data/api_tests.py
@@ -37,7 +37,7 @@

 import pytest

-from superset.charts.commands.data import ChartDataCommand
+from superset.charts.data.commands import ChartDataCommand
 from superset.connectors.sqla.models import TableColumn, SqlaTable
 from superset.errors import SupersetErrorType
 from superset.extensions import async_query_manager, db
diff --git a/tests/integration_tests/tasks/async_queries_tests.py b/tests/integration_tests/tasks/async_queries_tests.py
index 3ea1c6f0ce6de..e2cf21c552624 100644
--- a/tests/integration_tests/tasks/async_queries_tests.py
+++ b/tests/integration_tests/tasks/async_queries_tests.py
@@ -22,10 +22,8 @@
 from celery.exceptions import SoftTimeLimitExceeded
 from flask import g

-from superset import db
-from superset.charts.commands.data import ChartDataCommand
 from superset.charts.commands.exceptions import ChartDataQueryFailedError
-from superset.connectors.sqla.models import SqlaTable
+from superset.charts.data.commands import ChartDataCommand
 from superset.exceptions import SupersetException
 from superset.extensions import async_query_manager, security_manager
 from superset.tasks import async_queries
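
A minimal usage sketch of the command split introduced in PATCH 3/3, mirroring
the data() / _run_async flow in superset/charts/data/api.py above. The
handle_chart_data wrapper and its arguments are hypothetical illustrations, not
part of the patch; only the two command classes and their methods come from the
diff:

    # Hypothetical glue code: shows how the two commands divide responsibilities
    # after #17425. ChartDataCommand validates and executes the query context,
    # while CreateAsyncChartDataJobCommand only checks the async-queries JWT and
    # enqueues the Celery job.
    from superset.charts.data.commands import (
        ChartDataCommand,
        CreateAsyncChartDataJobCommand,
    )

    def handle_chart_data(request, form_data, user_id, run_async=False):
        command = ChartDataCommand()
        command.set_query_context(form_data)  # parse form_data into a QueryContext
        command.validate()  # raises if the actor may not access the datasource
        if run_async:
            async_command = CreateAsyncChartDataJobCommand()
            async_command.validate(request)  # parses the async-queries JWT cookie
            return async_command.run(form_data, user_id)  # job metadata (202 payload)
        return command.run()  # synchronous result payload

Note the design choice this reflects: the synchronous command no longer carries
async-only state (_form_data, _async_channel_id), so each class holds exactly
the state its own run() needs.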