diff --git a/superset/charts/api.py b/superset/charts/api.py index dd11d3bb1a5ce..e59c3a41bb733 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -19,7 +19,7 @@ from datetime import datetime from io import BytesIO from typing import Any, Optional -from zipfile import ZipFile +from zipfile import is_zipfile, ZipFile from flask import g, redirect, request, Response, send_file, url_for from flask_appbuilder.api import expose, protect, rison, safe @@ -64,7 +64,10 @@ screenshot_query_schema, thumbnail_query_schema, ) -from superset.commands.importers.exceptions import NoValidFilesFoundError +from superset.commands.importers.exceptions import ( + IncorrectFormatError, + NoValidFilesFoundError, +) from superset.commands.importers.v1.utils import get_contents_from_bundle from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod from superset.extensions import event_logger @@ -868,6 +871,8 @@ def import_(self) -> Response: upload = request.files.get("formData") if not upload: return self.response_400() + if not is_zipfile(upload): + raise IncorrectFormatError("Not a ZIP file") with ZipFile(upload) as bundle: contents = get_contents_from_bundle(bundle) diff --git a/superset/databases/api.py b/superset/databases/api.py index 1afa71c6f056b..b9fe4ca3d7e36 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -20,7 +20,7 @@ from datetime import datetime from io import BytesIO from typing import Any, Dict, List, Optional -from zipfile import ZipFile +from zipfile import is_zipfile, ZipFile from flask import g, request, Response, send_file from flask_appbuilder.api import expose, protect, rison, safe @@ -29,7 +29,10 @@ from sqlalchemy.exc import NoSuchTableError, OperationalError, SQLAlchemyError from superset import app, event_logger -from superset.commands.importers.exceptions import NoValidFilesFoundError +from superset.commands.importers.exceptions import ( + IncorrectFormatError, + NoValidFilesFoundError, +) from superset.commands.importers.v1.utils import get_contents_from_bundle from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod from superset.databases.commands.create import CreateDatabaseCommand @@ -958,6 +961,8 @@ def import_(self) -> Response: upload = request.files.get("formData") if not upload: return self.response_400() + if not is_zipfile(upload): + raise IncorrectFormatError("Not a ZIP file") with ZipFile(upload) as bundle: contents = get_contents_from_bundle(bundle) diff --git a/superset/queries/saved_queries/api.py b/superset/queries/saved_queries/api.py index 04df2345c1c2f..19d4191f257b7 100644 --- a/superset/queries/saved_queries/api.py +++ b/superset/queries/saved_queries/api.py @@ -19,14 +19,17 @@ from datetime import datetime from io import BytesIO from typing import Any -from zipfile import ZipFile +from zipfile import is_zipfile, ZipFile from flask import g, request, Response, send_file from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_babel import ngettext -from superset.commands.importers.exceptions import NoValidFilesFoundError +from superset.commands.importers.exceptions import ( + IncorrectFormatError, + NoValidFilesFoundError, +) from superset.commands.importers.v1.utils import get_contents_from_bundle from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod from superset.databases.filters import DatabaseFilter @@ -325,6 +328,8 @@ def import_(self) -> Response: upload = request.files.get("formData") if not upload: return self.response_400() + if not is_zipfile(upload): + raise IncorrectFormatError("Not a ZIP file") with ZipFile(upload) as bundle: contents = get_contents_from_bundle(bundle) diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py new file mode 100644 index 0000000000000..d6f8897c4a090 --- /dev/null +++ b/tests/unit_tests/databases/api_test.py @@ -0,0 +1,193 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=unused-argument, import-outside-toplevel, line-too-long + +import json +from io import BytesIO +from typing import Any +from uuid import UUID + +import pytest +from pytest_mock import MockFixture +from sqlalchemy.orm.session import Session + + +def test_post_with_uuid( + session: Session, + client: Any, + full_api_access: None, +) -> None: + """ + Test that we can set the database UUID when creating it. + """ + from superset.models.core import Database + + # create table for databases + Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member + + response = client.post( + "/api/v1/database/", + json={ + "database_name": "my_db", + "sqlalchemy_uri": "sqlite://", + "uuid": "7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb", + }, + ) + assert response.status_code == 201 + + database = session.query(Database).one() + assert database.uuid == UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb") + + +def test_password_mask( + mocker: MockFixture, + app: Any, + session: Session, + client: Any, + full_api_access: None, +) -> None: + """ + Test that sensitive information is masked. + """ + from superset.databases.api import DatabaseRestApi + from superset.models.core import Database + + DatabaseRestApi.datamodel.session = session + + # create table for databases + Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member + + database = Database( + database_name="my_database", + sqlalchemy_uri="gsheets://", + encrypted_extra=json.dumps( + { + "service_account_info": { + "type": "service_account", + "project_id": "black-sanctum-314419", + "private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173", + "private_key": "SECRET", + "client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com", + "client_id": "114567578578109757129", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com", + }, + } + ), + ) + session.add(database) + session.commit() + + # mock the lookup so that we don't need to include the driver + mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets") + mocker.patch("superset.utils.log.DBEventLogger.log") + + response = client.get("/api/v1/database/1") + assert ( + response.json["result"]["parameters"]["service_account_info"]["private_key"] + == "XXXXXXXXXX" + ) + assert "encrypted_extra" not in response.json["result"] + + +@pytest.mark.skip(reason="Works locally but fails on CI") +def test_update_with_password_mask( + app: Any, + session: Session, + client: Any, + full_api_access: None, +) -> None: + """ + Test that an update with a masked password doesn't overwrite the existing password. + """ + from superset.databases.api import DatabaseRestApi + from superset.models.core import Database + + DatabaseRestApi.datamodel.session = session + + # create table for databases + Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member + + database = Database( + database_name="my_database", + sqlalchemy_uri="gsheets://", + encrypted_extra=json.dumps( + { + "service_account_info": { + "project_id": "black-sanctum-314419", + "private_key": "SECRET", + }, + } + ), + ) + session.add(database) + session.commit() + + client.put( + "/api/v1/database/1", + json={ + "encrypted_extra": json.dumps( + { + "service_account_info": { + "project_id": "yellow-unicorn-314419", + "private_key": "XXXXXXXXXX", + }, + } + ), + }, + ) + database = session.query(Database).one() + assert ( + database.encrypted_extra + == '{"service_account_info": {"project_id": "yellow-unicorn-314419", "private_key": "SECRET"}}' + ) + + +def test_non_zip_import(client: Any, full_api_access: None) -> None: + """ + Test that non-ZIP imports are not allowed. + """ + buf = BytesIO(b"definitely_not_a_zip_file") + form_data = { + "formData": (buf, "evil.pdf"), + } + response = client.post( + "/api/v1/database/import/", + data=form_data, + content_type="multipart/form-data", + ) + assert response.status_code == 422 + assert response.json == { + "errors": [ + { + "message": "Not a ZIP file", + "error_type": "GENERIC_COMMAND_ERROR", + "level": "warning", + "extra": { + "issue_codes": [ + { + "code": 1010, + "message": "Issue 1010 - Superset encountered an error while running a command.", + } + ] + }, + } + ] + } diff --git a/tests/unit_tests/importexport/api_test.py b/tests/unit_tests/importexport/api_test.py index e5dee975d8cd8..1ef4309fc21b9 100644 --- a/tests/unit_tests/importexport/api_test.py +++ b/tests/unit_tests/importexport/api_test.py @@ -14,7 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name, import-outside-toplevel + +# pylint: disable=invalid-name, import-outside-toplevel, unused-argument import json from io import BytesIO