Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[sql_json] Ensuring the request body is JSON encoded #8256

Merged
merged 1 commit into from
Sep 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion superset/assets/cypress/integration/sqllab/query.js
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ export default () => {
cy.server();
cy.visit('/superset/sqllab');

cy.route('POST', '/superset/sql_json/**').as('sqlLabQuery');
cy.route('POST', '/superset/sql_json/').as('sqlLabQuery');
});

it('supports entering and running a query', () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ describe('async actions', () => {
});

describe('runQuery', () => {
const runQueryEndpoint = 'glob:*/superset/sql_json/*';
const runQueryEndpoint = 'glob:*/superset/sql_json/';
fetchMock.post(runQueryEndpoint, '{ "data": ' + mockBigNumber + ' }');

const makeRequest = () => {
Expand Down
6 changes: 3 additions & 3 deletions superset/assets/src/SqlLab/actions/sqlLab.js
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,9 @@ export function runQuery(query) {
};

return SupersetClient.post({
endpoint: `/superset/sql_json/${window.location.search}`,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

postPayload,
stringify: false,
endpoint: '/superset/sql_json/',
Copy link

@graceguo-supercat graceguo-supercat Sep 19, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I saw this stringify: false setting is from PR #5896. @williaster is there a special reason to turn off stringify?

if use stringify, do we still need parseMethod: 'text'?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't we really want stringify: false? I was under the impression that prevented us from sending up "null" and "false" versus null and false

And parseMethod determines how we parse the response, not the request

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you will see we use post method in many, many places, only in sql lab added stringify: false. I assume there must be reason to do extra...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine as it is, since we're setting body instead of postPayload, the stringify param is no longer used

body: JSON.stringify(postPayload),
headers: { 'Content-Type': 'application/json' },
parseMethod: 'text',
})
.then(({ text = '{}' }) => {
Expand Down
28 changes: 15 additions & 13 deletions superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2675,34 +2675,36 @@ def _sql_json_sync(
return json_success(payload)

@has_access_api
@expose("/sql_json/", methods=["POST", "GET"])
@expose("/sql_json/", methods=["POST"])
etr2460 marked this conversation as resolved.
Show resolved Hide resolved
@event_logger.log_this
def sql_json(self):
"""Runs arbitrary sql and returns and json"""
# Collect Values
database_id: int = int(request.form.get("database_id"))
schema: str = request.form.get("schema")
sql: str = request.form.get("sql")
database_id: int = request.json.get("database_id")
schema: str = request.json.get("schema")
sql: str = request.json.get("sql")
try:
template_params: dict = json.loads(request.form.get("templateParams", "{}"))
template_params: dict = json.loads(
request.json.get("templateParams") or "{}"
)
except json.decoder.JSONDecodeError:
logging.warning(
f"Invalid template parameter {request.form.get('templateParams')}"
f"Invalid template parameter {request.json.get('templateParams')}"
" specified. Defaulting to empty dict"
)
template_params = {}
limit = int(request.form.get("queryLimit", app.config.get("SQL_MAX_ROW")))
async_flag: bool = request.form.get("runAsync") == "true"
limit = request.json.get("queryLimit") or app.config.get("SQL_MAX_ROW")
async_flag: bool = request.json.get("runAsync")
if limit < 0:
logging.warning(
f"Invalid limit of {limit} specified. Defaulting to max limit."
)
limit = 0
select_as_cta: bool = request.form.get("select_as_cta") == "true"
tmp_table_name: str = request.form.get("tmp_table_name")
client_id: str = request.form.get("client_id") or utils.shortid()[:10]
sql_editor_id: str = request.form.get("sql_editor_id")
tab_name: str = request.form.get("tab")
select_as_cta: bool = request.json.get("select_as_cta")
tmp_table_name: str = request.json.get("tmp_table_name")
client_id: str = request.json.get("client_id") or utils.shortid()[:10]
sql_editor_id: str = request.json.get("sql_editor_id")
tab_name: str = request.json.get("tab")
status: bool = QueryStatus.PENDING if async_flag else QueryStatus.RUNNING

session = db.session()
Expand Down
14 changes: 10 additions & 4 deletions tests/base_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,19 +116,25 @@ def get_datasource_mock(self):
datasource.database.db_engine_spec.mutate_expression_label = lambda x: x
return datasource

def get_resp(self, url, data=None, follow_redirects=True, raise_on_error=True):
def get_resp(
self, url, data=None, follow_redirects=True, raise_on_error=True, json_=None
):
"""Shortcut to get the parsed results while following redirects"""
if data:
resp = self.client.post(url, data=data, follow_redirects=follow_redirects)
elif json_:
resp = self.client.post(url, json=json_, follow_redirects=follow_redirects)
else:
resp = self.client.get(url, follow_redirects=follow_redirects)
if raise_on_error and resp.status_code > 400:
raise Exception("http request failed with code {}".format(resp.status_code))
return resp.data.decode("utf-8")

def get_json_resp(self, url, data=None, follow_redirects=True, raise_on_error=True):
def get_json_resp(
self, url, data=None, follow_redirects=True, raise_on_error=True, json_=None
):
"""Shortcut to get the parsed results while following redirects"""
resp = self.get_resp(url, data, follow_redirects, raise_on_error)
resp = self.get_resp(url, data, follow_redirects, raise_on_error, json_)
return json.loads(resp)

def get_access_requests(self, username, ds_type, ds_id):
Expand Down Expand Up @@ -190,7 +196,7 @@ def run_sql(
resp = self.get_json_resp(
"/superset/sql_json/",
raise_on_error=False,
data=dict(
json_=dict(
database_id=dbid,
sql=sql,
select_as_create_as=False,
Expand Down
14 changes: 6 additions & 8 deletions tests/celery_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,12 @@ def tearDownClass(cls):
)

def run_sql(
self, db_id, sql, client_id=None, cta="false", tmp_table="tmp", async_="false"
self, db_id, sql, client_id=None, cta=False, tmp_table="tmp", async_=False
):
self.login()
resp = self.client.post(
"/superset/sql_json/",
data=dict(
json=dict(
database_id=db_id,
sql=sql,
runAsync=async_,
Expand All @@ -135,7 +135,7 @@ def test_run_sync_query_dont_exist(self):
main_db = get_example_database()
db_id = main_db.id
sql_dont_exist = "SELECT name FROM table_dont_exist"
result1 = self.run_sql(db_id, sql_dont_exist, "1", cta="true")
result1 = self.run_sql(db_id, sql_dont_exist, "1", cta=True)
self.assertTrue("error" in result1)

def test_run_sync_query_cta(self):
Expand All @@ -146,9 +146,7 @@ def test_run_sync_query_cta(self):
self.drop_table_if_exists(tmp_table_name, main_db)
name = "James"
sql_where = f"SELECT name FROM birth_names WHERE name='{name}' LIMIT 1"
result = self.run_sql(
db_id, sql_where, "2", tmp_table=tmp_table_name, cta="true"
)
result = self.run_sql(db_id, sql_where, "2", tmp_table=tmp_table_name, cta=True)
self.assertEqual(QueryStatus.SUCCESS, result["query"]["state"])
self.assertEqual([], result["data"])
self.assertEqual([], result["columns"])
Expand Down Expand Up @@ -190,7 +188,7 @@ def test_run_async_query(self):

sql_where = "SELECT name FROM birth_names WHERE name='James' LIMIT 10"
result = self.run_sql(
db_id, sql_where, "4", async_="true", tmp_table="tmp_async_1", cta="true"
db_id, sql_where, "4", async_=True, tmp_table="tmp_async_1", cta=True
)
assert result["query"]["state"] in (
QueryStatus.PENDING,
Expand Down Expand Up @@ -224,7 +222,7 @@ def test_run_async_query_with_lower_limit(self):

sql_where = "SELECT name FROM birth_names LIMIT 1"
result = self.run_sql(
db_id, sql_where, "5", async_="true", tmp_table=tmp_table, cta="true"
db_id, sql_where, "5", async_=True, tmp_table=tmp_table, cta=True
)
assert result["query"]["state"] in (
QueryStatus.PENDING,
Expand Down