From e33a08693bf789284d21f493074263712f17116f Mon Sep 17 00:00:00 2001 From: Daniel Vaz Gaspar Date: Mon, 7 Nov 2022 10:33:24 +0000 Subject: [PATCH] fix: datasource save, improve data validation (#22038) --- superset/config.py | 3 ++ superset/utils/urls.py | 19 ++++++++++- superset/views/datasource/views.py | 17 ++++++++- tests/integration_tests/datasource_tests.py | 38 +++++++++++++++++++++ tests/unit_tests/utils/urls_tests.py | 26 ++++++++++++++ 5 files changed, 101 insertions(+), 2 deletions(-) diff --git a/superset/config.py b/superset/config.py index 36e5e547db9ef..f163997c6ee4b 100644 --- a/superset/config.py +++ b/superset/config.py @@ -1328,6 +1328,9 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument # Typically these should not be allowed. PREVENT_UNSAFE_DB_CONNECTIONS = True +# Prevents unsafe default endpoints to be registered on datasets. +PREVENT_UNSAFE_DEFAULT_URLS_ON_DATASET = True + # Path used to store SSL certificates that are generated when using custom certs. # Defaults to temporary directory. # Example: SSL_CERT_PATH = "/certs" diff --git a/superset/utils/urls.py b/superset/utils/urls.py index a8a6148813d96..c31bfb1a5103c 100644 --- a/superset/utils/urls.py +++ b/superset/utils/urls.py @@ -14,10 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import unicodedata import urllib from typing import Any +from urllib.parse import urlparse -from flask import current_app, url_for +from flask import current_app, request, url_for def get_url_host(user_friendly: bool = False) -> str: @@ -48,3 +50,18 @@ def modify_url_query(url: str, **kwargs: Any) -> str: parts[3] = "&".join(f"{k}={urllib.parse.quote(v[0])}" for k, v in params.items()) return urllib.parse.urlunsplit(parts) + + +def is_safe_url(url: str) -> bool: + if url.startswith("///"): + return False + try: + ref_url = urlparse(request.host_url) + test_url = urlparse(url) + except ValueError: + return False + if unicodedata.category(url[0])[0] == "C": + return False + if test_url.scheme != ref_url.scheme or ref_url.netloc != test_url.netloc: + return False + return True diff --git a/superset/views/datasource/views.py b/superset/views/datasource/views.py index 2c137fab79610..c2db174cb1daf 100644 --- a/superset/views/datasource/views.py +++ b/superset/views/datasource/views.py @@ -18,7 +18,7 @@ from collections import Counter from typing import Any -from flask import redirect, request +from flask import current_app, redirect, request from flask_appbuilder import expose, permission_name from flask_appbuilder.api import rison from flask_appbuilder.security.decorators import has_access, has_access_api @@ -40,6 +40,7 @@ from superset.models.core import Database from superset.superset_typing import FlaskResponse from superset.utils.core import DatasourceType +from superset.utils.urls import is_safe_url from superset.views.base import ( api, BaseSupersetView, @@ -77,6 +78,20 @@ def save(self) -> FlaskResponse: datasource_id = datasource_dict.get("id") datasource_type = datasource_dict.get("type") database_id = datasource_dict["database"].get("id") + default_endpoint = datasource_dict["default_endpoint"] + if ( + default_endpoint + and not is_safe_url(default_endpoint) + and current_app.config["PREVENT_UNSAFE_DEFAULT_URLS_ON_DATASET"] + ): + return json_error_response( + _( + "The submitted URL is not considered safe," + " only use URLs with the same domain as Superset." + ), + status=400, + ) + orm_datasource = DatasourceDAO.get_datasource( db.session, DatasourceType(datasource_type), datasource_id ) diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py index 0896971743a34..edee0028467f1 100644 --- a/tests/integration_tests/datasource_tests.py +++ b/tests/integration_tests/datasource_tests.py @@ -297,6 +297,44 @@ def test_save(self): print(k) self.assertEqual(resp[k], datasource_post[k]) + def test_save_default_endpoint_validation_fail(self): + self.login(username="admin") + tbl_id = self.get_table(name="birth_names").id + + datasource_post = get_datasource_post() + datasource_post["id"] = tbl_id + datasource_post["owners"] = [1] + datasource_post["default_endpoint"] = "http://www.google.com" + data = dict(data=json.dumps(datasource_post)) + resp = self.client.post("/datasource/save/", data=data) + assert resp.status_code == 400 + + def test_save_default_endpoint_validation_unsafe(self): + self.app.config["PREVENT_UNSAFE_DEFAULT_URLS_ON_DATASET"] = False + self.login(username="admin") + tbl_id = self.get_table(name="birth_names").id + + datasource_post = get_datasource_post() + datasource_post["id"] = tbl_id + datasource_post["owners"] = [1] + datasource_post["default_endpoint"] = "http://www.google.com" + data = dict(data=json.dumps(datasource_post)) + resp = self.client.post("/datasource/save/", data=data) + assert resp.status_code == 200 + self.app.config["PREVENT_UNSAFE_DEFAULT_URLS_ON_DATASET"] = True + + def test_save_default_endpoint_validation_success(self): + self.login(username="admin") + tbl_id = self.get_table(name="birth_names").id + + datasource_post = get_datasource_post() + datasource_post["id"] = tbl_id + datasource_post["owners"] = [1] + datasource_post["default_endpoint"] = "http://localhost/superset/1" + data = dict(data=json.dumps(datasource_post)) + resp = self.client.post("/datasource/save/", data=data) + assert resp.status_code == 200 + def save_datasource_from_dict(self, datasource_post): data = dict(data=json.dumps(datasource_post)) resp = self.get_json_resp("/datasource/save/", data) diff --git a/tests/unit_tests/utils/urls_tests.py b/tests/unit_tests/utils/urls_tests.py index f62c276f89ae6..a3893953b8ba1 100644 --- a/tests/unit_tests/utils/urls_tests.py +++ b/tests/unit_tests/utils/urls_tests.py @@ -15,6 +15,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest + from superset.utils.urls import modify_url_query EXPLORE_CHART_LINK = "http://localhost:9000/explore/?form_data=%7B%22slice_id%22%3A+76%7D&standalone=true&force=false" @@ -33,3 +35,27 @@ def test_convert_chart_link() -> None: def test_convert_dashboard_link() -> None: test_url = modify_url_query(EXPLORE_DASHBOARD_LINK, standalone="0") assert test_url == "http://localhost:9000/superset/dashboard/3/?standalone=0" + + +@pytest.mark.parametrize( + "url,is_safe", + [ + ("http://localhost/", True), + ("http://localhost/superset/1", True), + ("https://localhost/", False), + ("https://localhost/superset/1", False), + ("localhost/superset/1", False), + ("ftp://localhost/superset/1", False), + ("http://external.com", False), + ("https://external.com", False), + ("external.com", False), + ("///localhost", False), + ("xpto://localhost:[3/1/", False), + ], +) +def test_is_safe_url(url: str, is_safe: bool) -> None: + from superset import app + from superset.utils.urls import is_safe_url + + with app.test_request_context("/"): + assert is_safe_url(url) == is_safe