Skip to content

Commit

Permalink
fix: datasource save, improve data validation (#22038)
Browse files Browse the repository at this point in the history
  • Loading branch information
dpgaspar authored Nov 7, 2022
1 parent 358a4ec commit e33a086
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 2 deletions.
3 changes: 3 additions & 0 deletions superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1328,6 +1328,9 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument
# Typically these should not be allowed.
PREVENT_UNSAFE_DB_CONNECTIONS = True

# Prevents unsafe default endpoints to be registered on datasets.
PREVENT_UNSAFE_DEFAULT_URLS_ON_DATASET = True

# Path used to store SSL certificates that are generated when using custom certs.
# Defaults to temporary directory.
# Example: SSL_CERT_PATH = "/certs"
Expand Down
19 changes: 18 additions & 1 deletion superset/utils/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import unicodedata
import urllib
from typing import Any
from urllib.parse import urlparse

from flask import current_app, url_for
from flask import current_app, request, url_for


def get_url_host(user_friendly: bool = False) -> str:
Expand Down Expand Up @@ -48,3 +50,18 @@ def modify_url_query(url: str, **kwargs: Any) -> str:

parts[3] = "&".join(f"{k}={urllib.parse.quote(v[0])}" for k, v in params.items())
return urllib.parse.urlunsplit(parts)


def is_safe_url(url: str) -> bool:
if url.startswith("///"):
return False
try:
ref_url = urlparse(request.host_url)
test_url = urlparse(url)
except ValueError:
return False
if unicodedata.category(url[0])[0] == "C":
return False
if test_url.scheme != ref_url.scheme or ref_url.netloc != test_url.netloc:
return False
return True
17 changes: 16 additions & 1 deletion superset/views/datasource/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from collections import Counter
from typing import Any

from flask import redirect, request
from flask import current_app, redirect, request
from flask_appbuilder import expose, permission_name
from flask_appbuilder.api import rison
from flask_appbuilder.security.decorators import has_access, has_access_api
Expand All @@ -40,6 +40,7 @@
from superset.models.core import Database
from superset.superset_typing import FlaskResponse
from superset.utils.core import DatasourceType
from superset.utils.urls import is_safe_url
from superset.views.base import (
api,
BaseSupersetView,
Expand Down Expand Up @@ -77,6 +78,20 @@ def save(self) -> FlaskResponse:
datasource_id = datasource_dict.get("id")
datasource_type = datasource_dict.get("type")
database_id = datasource_dict["database"].get("id")
default_endpoint = datasource_dict["default_endpoint"]
if (
default_endpoint
and not is_safe_url(default_endpoint)
and current_app.config["PREVENT_UNSAFE_DEFAULT_URLS_ON_DATASET"]
):
return json_error_response(
_(
"The submitted URL is not considered safe,"
" only use URLs with the same domain as Superset."
),
status=400,
)

orm_datasource = DatasourceDAO.get_datasource(
db.session, DatasourceType(datasource_type), datasource_id
)
Expand Down
38 changes: 38 additions & 0 deletions tests/integration_tests/datasource_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,44 @@ def test_save(self):
print(k)
self.assertEqual(resp[k], datasource_post[k])

def test_save_default_endpoint_validation_fail(self):
self.login(username="admin")
tbl_id = self.get_table(name="birth_names").id

datasource_post = get_datasource_post()
datasource_post["id"] = tbl_id
datasource_post["owners"] = [1]
datasource_post["default_endpoint"] = "http://www.google.com"
data = dict(data=json.dumps(datasource_post))
resp = self.client.post("/datasource/save/", data=data)
assert resp.status_code == 400

def test_save_default_endpoint_validation_unsafe(self):
self.app.config["PREVENT_UNSAFE_DEFAULT_URLS_ON_DATASET"] = False
self.login(username="admin")
tbl_id = self.get_table(name="birth_names").id

datasource_post = get_datasource_post()
datasource_post["id"] = tbl_id
datasource_post["owners"] = [1]
datasource_post["default_endpoint"] = "http://www.google.com"
data = dict(data=json.dumps(datasource_post))
resp = self.client.post("/datasource/save/", data=data)
assert resp.status_code == 200
self.app.config["PREVENT_UNSAFE_DEFAULT_URLS_ON_DATASET"] = True

def test_save_default_endpoint_validation_success(self):
self.login(username="admin")
tbl_id = self.get_table(name="birth_names").id

datasource_post = get_datasource_post()
datasource_post["id"] = tbl_id
datasource_post["owners"] = [1]
datasource_post["default_endpoint"] = "http://localhost/superset/1"
data = dict(data=json.dumps(datasource_post))
resp = self.client.post("/datasource/save/", data=data)
assert resp.status_code == 200

def save_datasource_from_dict(self, datasource_post):
data = dict(data=json.dumps(datasource_post))
resp = self.get_json_resp("/datasource/save/", data)
Expand Down
26 changes: 26 additions & 0 deletions tests/unit_tests/utils/urls_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest

from superset.utils.urls import modify_url_query

EXPLORE_CHART_LINK = "http://localhost:9000/explore/?form_data=%7B%22slice_id%22%3A+76%7D&standalone=true&force=false"
Expand All @@ -33,3 +35,27 @@ def test_convert_chart_link() -> None:
def test_convert_dashboard_link() -> None:
test_url = modify_url_query(EXPLORE_DASHBOARD_LINK, standalone="0")
assert test_url == "http://localhost:9000/superset/dashboard/3/?standalone=0"


@pytest.mark.parametrize(
"url,is_safe",
[
("http://localhost/", True),
("http://localhost/superset/1", True),
("https://localhost/", False),
("https://localhost/superset/1", False),
("localhost/superset/1", False),
("ftp://localhost/superset/1", False),
("http://external.com", False),
("https://external.com", False),
("external.com", False),
("///localhost", False),
("xpto://localhost:[3/1/", False),
],
)
def test_is_safe_url(url: str, is_safe: bool) -> None:
from superset import app
from superset.utils.urls import is_safe_url

with app.test_request_context("/"):
assert is_safe_url(url) == is_safe

0 comments on commit e33a086

Please sign in to comment.