From 84910b6ac8b640473c1c3db2f5cf4c03561d4109 Mon Sep 17 00:00:00 2001 From: Elizabeth Thompson Date: Fri, 8 Jul 2022 17:30:19 -0700 Subject: [PATCH] use existing row when id is found --- superset/connectors/sqla/models.py | 35 +++-- .../commands/importers/v1/import_test.py | 123 ++++++++++++++++++ 2 files changed, 144 insertions(+), 14 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 67a2f97c841d0..b5bfe1ffda3bf 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -453,20 +453,21 @@ def to_sl_column( if value: extra_json[attr] = value + # column id is primary key, so make sure that we check uuid against + # the id as well if not column.id: with session.no_autoflush: - saved_column = ( + saved_column: NewColumn = ( session.query(NewColumn).filter_by(uuid=self.uuid).one_or_none() ) - if saved_column: + if saved_column is not None: logger.warning( - "sl_column already exists. Assigning existing id %s", self + "sl_column already exists. "
+ "Using this row for db update %s", + self, ) - # uuid isn't a primary key, so add the id of the existing column to - # ensure that the column is modified instead of created - # in order to avoid a uuid collision - column.id = saved_column.id + # overwrite the existing column instead of creating a new one + column = saved_column column.uuid = self.uuid column.created_on = self.created_on @@ -534,6 +535,9 @@ class SqlMetric(Model, BaseMetric, CertificationMixin): update_from_object_fields = list(s for s in export_fields if s != "table_id") export_parent = "table" + def __repr__(self) -> str: + return str(self.metric_name) + def get_sqla_col(self, label: Optional[str] = None) -> Column: label = label or self.metric_name tp = self.table.get_template_processor() @@ -585,19 +589,22 @@ def to_sl_column( self.metric_type and self.metric_type.lower() in ADDITIVE_METRIC_TYPES_LOWER ) + # column id is primary key, so make sure that we check uuid against + # the id as well if not column.id: with session.no_autoflush: - saved_column = ( + saved_column: NewColumn = ( session.query(NewColumn).filter_by(uuid=self.uuid).one_or_none() ) - if saved_column: + + if saved_column is not None: logger.warning( - "sl_column already exists. Assigning existing id %s", self + "sl_column already exists. "
+ "Using this row for db update %s", + self, ) - - column.id = saved_column.id + + # overwrite the existing column instead of creating a new one + column = saved_column column.uuid = self.uuid column.name = self.metric_name diff --git a/tests/unit_tests/datasets/commands/importers/v1/import_test.py b/tests/unit_tests/datasets/commands/importers/v1/import_test.py index 07ea8c49d04d9..996c0d3c41ad3 100644 --- a/tests/unit_tests/datasets/commands/importers/v1/import_test.py +++ b/tests/unit_tests/datasets/commands/importers/v1/import_test.py @@ -137,6 +137,129 @@ def test_import_dataset(app_context: None, session: Session) -> None: assert sqla_table.database.id == database.id + +def test_import_dataset_duplicate_column(app_context: None, session: Session) -> None: + """ + Test importing a dataset with a column that already exists. + """ + from superset.columns.models import Column as NewColumn + from superset.connectors.sqla.models import SqlaTable, TableColumn + from superset.datasets.commands.importers.v1.utils import import_dataset + from superset.models.core import Database + + engine = session.get_bind() + SqlaTable.metadata.create_all(engine) # pylint: disable=no-member + + dataset_uuid = uuid.uuid4() + + database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + + session.add(database) + session.flush() + + dataset = SqlaTable( + uuid=dataset_uuid, table_name="existing_dataset", database_id=database.id ) + column = TableColumn(column_name="existing_column") + session.add(dataset) + session.add(column) + session.flush() + + config = { + "table_name": dataset.table_name, + "main_dttm_col": "ds", + "description": "This is the description", + "default_endpoint": None, + "offset": -8, + "cache_timeout": 3600, + "schema": "my_schema", + "sql": None, + "params": { + "remote_id": 64,
+ "database_name": "examples", + "import_time": 1606677834, + }, + "template_params": { + "answer": "42", + }, + "filter_select_enabled": True, + "fetch_values_predicate": "foo IN (1, 2)", + "extra": {"warning_markdown": "*WARNING*"}, + "uuid": dataset_uuid, + "metrics": [ + { + "metric_name": "cnt", + "verbose_name": None, + "metric_type": None, + "expression": "COUNT(*)", + "description": None, + "d3format": None, + "extra": {"warning_markdown": None}, + "warning_text": None, + } + ], + "columns": [ + { + "column_name": column.column_name, + "verbose_name": None, + "is_dttm": None, + "is_active": None, + "type": "INTEGER", + "groupby": None, + "filterable": None, + "expression": "revenue-expenses", + "description": None, + "python_date_format": None, + "extra": { + "certified_by": "User", + }, + } + ], + "database_uuid": database.uuid, + "database_id": database.id, + } + + sqla_table = import_dataset(session, config, overwrite=True) + assert sqla_table.table_name == dataset.table_name + assert sqla_table.main_dttm_col == "ds" + assert sqla_table.description == "This is the description" + assert sqla_table.default_endpoint is None + assert sqla_table.offset == -8 + assert sqla_table.cache_timeout == 3600 + assert sqla_table.schema == "my_schema" + assert sqla_table.sql is None + assert sqla_table.params == json.dumps( + {"remote_id": 64, "database_name": "examples", "import_time": 1606677834} + ) + assert sqla_table.template_params == json.dumps({"answer": "42"}) + assert sqla_table.filter_select_enabled is True + assert sqla_table.fetch_values_predicate == "foo IN (1, 2)" + assert sqla_table.extra == '{"warning_markdown": "*WARNING*"}' + assert sqla_table.uuid == dataset_uuid + assert len(sqla_table.metrics) == 1 + assert sqla_table.metrics[0].metric_name == "cnt" + assert sqla_table.metrics[0].verbose_name is None + assert sqla_table.metrics[0].metric_type is None + assert sqla_table.metrics[0].expression == "COUNT(*)" + assert sqla_table.metrics[0].description is None
+ assert sqla_table.metrics[0].d3format is None + assert sqla_table.metrics[0].extra == '{"warning_markdown": null}' + assert sqla_table.metrics[0].warning_text is None + assert len(sqla_table.columns) == 1 + assert sqla_table.columns[0].column_name == column.column_name + assert sqla_table.columns[0].verbose_name is None + assert sqla_table.columns[0].is_dttm is False + assert sqla_table.columns[0].is_active is True + assert sqla_table.columns[0].type == "INTEGER" + assert sqla_table.columns[0].groupby is True + assert sqla_table.columns[0].filterable is True + assert sqla_table.columns[0].expression == "revenue-expenses" + assert sqla_table.columns[0].description is None + assert sqla_table.columns[0].python_date_format is None + assert sqla_table.columns[0].extra == '{"certified_by": "User"}' + assert sqla_table.database.uuid == database.uuid + assert sqla_table.database.id == database.id + + def test_import_column_extra_is_string(app_context: None, session: Session) -> None: """ Test importing a dataset when the column extra is a string.