Skip to content

Commit

Permalink
Merge branch 'master' into hack2021/adhoc-columns
Browse files Browse the repository at this point in the history
  • Loading branch information
villebro committed Nov 15, 2021
2 parents 177ce6e + 5d3e1b5 commit d426d2f
Show file tree
Hide file tree
Showing 8 changed files with 990 additions and 909 deletions.
18 changes: 12 additions & 6 deletions superset/charts/data/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,14 @@

from superset import is_feature_enabled, security_manager
from superset.charts.api import ChartRestApi
from superset.charts.commands.data import ChartDataCommand
from superset.charts.commands.exceptions import (
ChartDataCacheLoadError,
ChartDataQueryFailedError,
)
from superset.charts.data.commands import (
ChartDataCommand,
CreateAsyncChartDataJobCommand,
)
from superset.charts.data.query_context_cache_loader import QueryContextCacheLoader
from superset.charts.post_processing import apply_post_process
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
Expand Down Expand Up @@ -145,7 +148,7 @@ def get_data(self, pk: int) -> Response:
and query_context.result_format == ChartDataResultFormat.JSON
and query_context.result_type == ChartDataResultType.FULL
):
return self._run_async(command)
return self._run_async(json_body, command)

try:
form_data = json.loads(chart.params)
Expand Down Expand Up @@ -231,7 +234,7 @@ def data(self) -> Response:
and query_context.result_format == ChartDataResultFormat.JSON
and query_context.result_type == ChartDataResultType.FULL
):
return self._run_async(command)
return self._run_async(json_body, command)

return self._get_data_response(command)

Expand Down Expand Up @@ -289,7 +292,9 @@ def data_from_cache(self, cache_key: str) -> Response:

return self._get_data_response(command, True)

def _run_async(self, command: ChartDataCommand) -> Response:
def _run_async(
self, form_data: Dict[str, Any], command: ChartDataCommand
) -> Response:
"""
Execute command as an async query.
"""
Expand All @@ -309,12 +314,13 @@ def _run_async(self, command: ChartDataCommand) -> Response:
# Clients will either poll or be notified of query completion,
# at which point they will call the /data/<cache_key> endpoint
# to retrieve the results.
async_command = CreateAsyncChartDataJobCommand()
try:
command.validate_async_request(request)
async_command.validate(request)
except AsyncQueryTokenException:
return self.response_401()

result = command.run_async(g.user.get_id())
result = async_command.run(form_data, g.user.get_id())
return self.response(202, **result)

def _send_chart_response(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,7 @@


class ChartDataCommand(BaseCommand):
def __init__(self) -> None:
self._form_data: Dict[str, Any]
self._query_context: QueryContext
self._async_channel_id: str
_query_context: QueryContext

def run(self, **kwargs: Any) -> Dict[str, Any]:
# caching is handled in query_context.get_df_payload
Expand Down Expand Up @@ -66,26 +63,27 @@ def run(self, **kwargs: Any) -> Dict[str, Any]:

return return_value

def run_async(self, user_id: Optional[str]) -> Dict[str, Any]:
job_metadata = async_query_manager.init_job(self._async_channel_id, user_id)
load_chart_data_into_cache.delay(job_metadata, self._form_data)

return job_metadata

def set_query_context(self, form_data: Dict[str, Any]) -> QueryContext:
self._form_data = form_data
try:
self._query_context = ChartDataQueryContextSchema().load(self._form_data)
self._query_context = ChartDataQueryContextSchema().load(form_data)
except KeyError as ex:
raise ValidationError("Request is incorrect") from ex
except ValidationError as error:
raise error

return self._query_context

def validate(self) -> None:
self._query_context.raise_for_access()

def validate_async_request(self, request: Request) -> None:

class CreateAsyncChartDataJobCommand:
_async_channel_id: str

def validate(self, request: Request) -> None:
jwt_data = async_query_manager.parse_jwt_from_request(request)
self._async_channel_id = jwt_data["channel"]

def run(self, form_data: Dict[str, Any], user_id: Optional[str]) -> Dict[str, Any]:
job_metadata = async_query_manager.init_job(self._async_channel_id, user_id)
load_chart_data_into_cache.delay(job_metadata, form_data)
return job_metadata
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""rename to schemas_allowed_for_file_upload in dbs.extra
Revision ID: 0ca9e5f1dacd
Revises: b92d69a6643c
Create Date: 2021-11-11 04:18:26.171851
"""

# revision identifiers, used by Alembic.
revision = "0ca9e5f1dacd"
down_revision = "b92d69a6643c"

import json
import logging

from alembic import op
from sqlalchemy import Column, Integer, Text
from sqlalchemy.ext.declarative import declarative_base

from superset import db

Base = declarative_base()


class Database(Base):

__tablename__ = "dbs"
id = Column(Integer, primary_key=True)
extra = Column(Text)


def upgrade():
bind = op.get_bind()
session = db.Session(bind=bind)

for database in session.query(Database).all():
try:
extra = json.loads(database.extra)
except json.decoder.JSONDecodeError as ex:
logging.warning(str(ex))
continue

if "schemas_allowed_for_csv_upload" in extra:
extra["schemas_allowed_for_file_upload"] = extra.pop(
"schemas_allowed_for_csv_upload"
)

database.extra = json.dumps(extra)

session.commit()
session.close()


def downgrade():
bind = op.get_bind()
session = db.Session(bind=bind)

for database in session.query(Database).all():
try:
extra = json.loads(database.extra)
except json.decoder.JSONDecodeError as ex:
logging.warning(str(ex))
continue

if "schemas_allowed_for_file_upload" in extra:
extra["schemas_allowed_for_csv_upload"] = extra.pop(
"schemas_allowed_for_file_upload"
)

database.extra = json.dumps(extra)

session.commit()
session.close()
2 changes: 1 addition & 1 deletion superset/tasks/async_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def load_chart_data_into_cache(
job_metadata: Dict[str, Any], form_data: Dict[str, Any],
) -> None:
# pylint: disable=import-outside-toplevel
from superset.charts.commands.data import ChartDataCommand
from superset.charts.data.commands import ChartDataCommand

try:
ensure_user_is_set(job_metadata.get("user_id"))
Expand Down
Loading

0 comments on commit d426d2f

Please sign in to comment.