From e273d17e3d4207383a16d3f1ae57e8dfe40b79c1 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt <33317356+villebro@users.noreply.github.com> Date: Wed, 13 Jan 2021 13:39:28 +0200 Subject: [PATCH] feat(db-engine-specs): add support for Postgres root cert (#11720) * feat(db-engine-specs): add support for Postgres root cert * remove logging of json decode exception message * fix error message * fix error message --- .../CRUD/data/database/DatabaseModal.tsx | 4 ++- superset/databases/schemas.py | 2 +- superset/db_engine_specs/druid.py | 7 +++-- superset/db_engine_specs/postgres.py | 29 ++++++++++++++++++ tests/databases/api_tests.py | 4 +-- tests/db_engine_specs/druid_tests.py | 20 +++++++++++++ tests/db_engine_specs/postgres_tests.py | 30 +++++++++++++++++++ tests/fixtures/database.py | 22 ++++++++++++++ 8 files changed, 111 insertions(+), 7 deletions(-) create mode 100644 tests/fixtures/database.py diff --git a/superset-frontend/src/views/CRUD/data/database/DatabaseModal.tsx b/superset-frontend/src/views/CRUD/data/database/DatabaseModal.tsx index e9f994246845a..d38f92e0c9be8 100644 --- a/superset-frontend/src/views/CRUD/data/database/DatabaseModal.tsx +++ b/superset-frontend/src/views/CRUD/data/database/DatabaseModal.tsx @@ -182,7 +182,9 @@ const DatabaseModal: FunctionComponent = ({ ? `${t('ERROR: ')}${ typeof error.message === 'string' ? error.message - : (error.message as Record).sqlalchemy_uri + : Object.entries(error.message as Record) + .map(([key, value]) => `(${key}) ${value.join(', ')}`) + .join('\n') }` : t('ERROR: Connection failed. '), ); diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index c046aef7ab06e..2c705f35b1fa5 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -139,7 +139,7 @@ def sqlalchemy_uri_validator(value: str) -> str: [ _( "Invalid connection string, a valid string usually follows: " - "dirver://user:password@database-host/database-name" + "driver://user:password@database-host/database-name" ) ] ) diff --git a/superset/db_engine_specs/druid.py b/superset/db_engine_specs/druid.py index 1a774983987c9..3ab2dedb4c0d1 100644 --- a/superset/db_engine_specs/druid.py +++ b/superset/db_engine_specs/druid.py @@ -20,6 +20,7 @@ from typing import Any, Dict, Optional, TYPE_CHECKING from superset.db_engine_specs.base import BaseEngineSpec +from superset.exceptions import SupersetException from superset.utils import core as utils if TYPE_CHECKING: @@ -65,12 +66,12 @@ def get_extra_params(database: "Database") -> Dict[str, Any]: :param database: database instance from which to extract extras :raises CertificateException: If certificate is not valid/unparseable + :raises SupersetException: If database extra json payload is unparseable """ try: extra = json.loads(database.extra or "{}") - except json.JSONDecodeError as ex: - logger.error(ex) - raise ex + except json.JSONDecodeError: + raise SupersetException("Unable to parse database extras") if database.server_cert: engine_params = extra.get("engine_params", {}) diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index f34128f1f2255..a63ffdd8b707e 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import json +import logging import re from datetime import datetime from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING @@ -22,11 +24,14 @@ from sqlalchemy.dialects.postgresql.base import PGInspector from superset.db_engine_specs.base import BaseEngineSpec +from superset.exceptions import SupersetException from superset.utils import core as utils if TYPE_CHECKING: from superset.models.core import Database # pragma: no cover +logger = logging.getLogger() + # Replace psycopg2.tz.FixedOffsetTimezone with pytz, which is serializable by PyArrow # https://github.com/stub42/pytz/blob/b70911542755aeeea7b5a9e066df5e1c87e8f2c8/src/pytz/reference.py#L25 @@ -115,3 +120,27 @@ def convert_dttm(cls, target_type: str, dttm: datetime) -> Optional[str]: dttm_formatted = dttm.isoformat(sep=" ", timespec="microseconds") return f"""TO_TIMESTAMP('{dttm_formatted}', 'YYYY-MM-DD HH24:MI:SS.US')""" return None + + @staticmethod + def get_extra_params(database: "Database") -> Dict[str, Any]: + """ + For Postgres, the path to a SSL certificate is placed in `connect_args`. + + :param database: database instance from which to extract extras + :raises CertificateException: If certificate is not valid/unparseable + :raises SupersetException: If database extra json payload is unparseable + """ + try: + extra = json.loads(database.extra or "{}") + except json.JSONDecodeError: + raise SupersetException("Unable to parse database extras") + + if database.server_cert: + engine_params = extra.get("engine_params", {}) + connect_args = engine_params.get("connect_args", {}) + connect_args["sslmode"] = connect_args.get("sslmode", "verify-full") + path = utils.create_ssl_cert_file(database.server_cert) + connect_args["sslrootcert"] = path + engine_params["connect_args"] = connect_args + extra["engine_params"] = engine_params + return extra diff --git a/tests/databases/api_tests.py b/tests/databases/api_tests.py index 46af4f06293c0..1c55a4cce123f 100644 --- a/tests/databases/api_tests.py +++ b/tests/databases/api_tests.py @@ -200,7 +200,7 @@ def test_create_database(self): database_data = { "database_name": "test-create-database", "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, - "server_cert": ssl_certificate, + "server_cert": None, "extra": json.dumps(extra), } @@ -761,7 +761,7 @@ def test_test_connection(self): "extra": json.dumps(extra), "impersonate_user": False, "sqlalchemy_uri": example_db.safe_sqlalchemy_uri(), - "server_cert": ssl_certificate, + "server_cert": None, } url = "api/v1/database/test_connection" rv = self.post_assert_metric(url, data, "test_connection") diff --git a/tests/db_engine_specs/druid_tests.py b/tests/db_engine_specs/druid_tests.py index fb9f3f4d34c37..3ff561640b6e7 100644 --- a/tests/db_engine_specs/druid_tests.py +++ b/tests/db_engine_specs/druid_tests.py @@ -14,10 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from unittest import mock + from sqlalchemy import column from superset.db_engine_specs.druid import DruidEngineSpec from tests.db_engine_specs.base_tests import TestDbEngineSpec +from tests.fixtures.certificates import ssl_certificate +from tests.fixtures.database import default_db_extra class TestDruidDbEngineSpec(TestDbEngineSpec): @@ -54,3 +58,19 @@ def test_timegrain_expressions(self): col=sqla_col, pdf=None, time_grain=grain ) self.assertEqual(str(actual), expected) + + def test_extras_without_ssl(self): + db = mock.Mock() + db.extra = default_db_extra + db.server_cert = None + extras = DruidEngineSpec.get_extra_params(db) + assert "connect_args" not in extras["engine_params"] + + def test_extras_with_ssl(self): + db = mock.Mock() + db.extra = default_db_extra + db.server_cert = ssl_certificate + extras = DruidEngineSpec.get_extra_params(db) + connect_args = extras["engine_params"]["connect_args"] + assert connect_args["scheme"] == "https" + assert "ssl_verify_cert" in connect_args diff --git a/tests/db_engine_specs/postgres_tests.py b/tests/db_engine_specs/postgres_tests.py index 3f45f45254062..4f362d010da42 100644 --- a/tests/db_engine_specs/postgres_tests.py +++ b/tests/db_engine_specs/postgres_tests.py @@ -22,6 +22,8 @@ from superset.db_engine_specs import engines from superset.db_engine_specs.postgres import PostgresEngineSpec from tests.db_engine_specs.base_tests import TestDbEngineSpec +from tests.fixtures.certificates import ssl_certificate +from tests.fixtures.database import default_db_extra class TestPostgresDbEngineSpec(TestDbEngineSpec): @@ -124,3 +126,31 @@ def test_engine_alias_name(self): DB Eng Specs (postgres): Test "postgres" in engine spec """ self.assertIn("postgres", engines) + + def test_extras_without_ssl(self): + db = mock.Mock() + db.extra = default_db_extra + db.server_cert = None + extras = PostgresEngineSpec.get_extra_params(db) + assert "connect_args" not in extras["engine_params"] + + def test_extras_with_ssl_default(self): + db = mock.Mock() + db.extra = default_db_extra + db.server_cert = ssl_certificate + extras = PostgresEngineSpec.get_extra_params(db) + connect_args = extras["engine_params"]["connect_args"] + assert connect_args["sslmode"] == "verify-full" + assert "sslrootcert" in connect_args + + def test_extras_with_ssl_custom(self): + db = mock.Mock() + db.extra = default_db_extra.replace( + '"engine_params": {}', + '"engine_params": {"connect_args": {"sslmode": "verify-ca"}}', + ) + db.server_cert = ssl_certificate + extras = PostgresEngineSpec.get_extra_params(db) + connect_args = extras["engine_params"]["connect_args"] + assert connect_args["sslmode"] == "verify-ca" + assert "sslrootcert" in connect_args diff --git a/tests/fixtures/database.py b/tests/fixtures/database.py new file mode 100644 index 0000000000000..dcb423e5465e0 --- /dev/null +++ b/tests/fixtures/database.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +default_db_extra = """{ + "metadata_params": {}, + "engine_params": {}, + "metadata_cache_timeout": {}, + "schemas_allowed_for_csv_upload": [] +}"""