From a849c292888f32b8912a1a06dfe1591706a53220 Mon Sep 17 00:00:00 2001
From: Maxime Beauchemin
Date: Mon, 7 Oct 2024 13:17:27 -0700
Subject: [PATCH] chore: enable lint PT009 'use regular assert over self.assert.*' (#30521)
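
Enable Ruff's PT009 rule (from flake8-pytest-style), which flags
unittest-style assertion helpers in favor of plain `assert` statements, and
migrate the integration test suite accordingly. A minimal sketch of the
rewrite this rule enforces (the test below is illustrative, not taken from
the suite):

    import unittest

    class ExampleTest(unittest.TestCase):
        def test_response(self):
            status_code, payload = 200, {"result": []}

            # Before (now flagged by PT009):
            #   self.assertEqual(status_code, 200)
            #   self.assertIn("result", payload)
            #   self.assertIsNone(payload.get("error"))

            # After: plain asserts; pytest still renders rich diffs on failure
            assert status_code == 200
            assert "result" in payload
            assert payload.get("error") is None

Helpers with no direct plain-assert equivalent (e.g. assertDictEqual,
assertCountEqual) are kept and suppressed with `# noqa: PT009`.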
---
 pyproject.toml                                |   1 +
 .../async_events/api_tests.py                 |   6 +-
 tests/integration_tests/base_api_tests.py     |  35 +-
 tests/integration_tests/cache_tests.py        |  24 +-
 tests/integration_tests/charts/api_tests.py   | 400 +++++++--------
 .../charts/commands_tests.py                  |  16 +-
 .../charts/data/api_tests.py                  |  69 ++-
 .../integration_tests/charts/schema_tests.py  |   4 +-
 tests/integration_tests/core_tests.py         | 168 +++----
 tests/integration_tests/dashboard_tests.py    |  20 +-
 .../integration_tests/dashboards/api_tests.py | 469 +++++++++--------
 .../integration_tests/dashboards/base_case.py |  10 +-
 .../integration_tests/dashboards/dao_tests.py |  30 +-
 .../security/security_dataset_tests.py        |  16 +-
 .../security/security_rbac_tests.py           |   6 +-
 .../integration_tests/databases/api_tests.py  | 381 +++++++-------
 .../databases/commands_tests.py               |   6 +-
 tests/integration_tests/datasets/api_tests.py |  75 ++-
 .../datasets/commands_tests.py                |   6 +-
 .../integration_tests/datasource/api_tests.py |  43 +-
 tests/integration_tests/datasource_tests.py   |  91 ++--
 .../db_engine_specs/ascend_tests.py           |  10 +-
 .../db_engine_specs/base_engine_spec_tests.py |  65 ++-
 .../db_engine_specs/base_tests.py             |   2 +-
 .../db_engine_specs/bigquery_tests.py         |  12 +-
 .../db_engine_specs/elasticsearch_tests.py    |   2 +-
 .../db_engine_specs/mysql_tests.py            |   6 +-
 .../db_engine_specs/pinot_tests.py            |  30 +-
 .../db_engine_specs/postgres_tests.py         |  42 +-
 .../db_engine_specs/presto_tests.py           |  46 +-
 .../dict_import_export_tests.py               |  68 ++-
 .../dynamic_plugins_tests.py                  |   4 +-
 tests/integration_tests/email_tests.py        |   4 +-
 tests/integration_tests/embedded/dao_tests.py |  10 +-
 tests/integration_tests/event_logger_tests.py | 108 ++--
 tests/integration_tests/form_tests.py         |  12 +-
 .../integration_tests/import_export_tests.py  | 198 ++++----
 tests/integration_tests/log_api_tests.py      | 137 +++--
 .../logging_configurator_tests.py             |   2 +-
 tests/integration_tests/model_tests.py        |  67 +--
 tests/integration_tests/queries/api_tests.py  |  28 +-
 .../queries/saved_queries/api_tests.py        |  18 +-
 .../integration_tests/query_context_tests.py  | 105 ++--
 tests/integration_tests/reports/api_tests.py  |   6 +-
 tests/integration_tests/result_set_tests.py   | 333 ++++++------
 tests/integration_tests/security/api_tests.py |   6 +-
 .../security/guest_token_security_tests.py    |  30 +-
 .../security/row_level_security_tests.py      |  78 +--
 tests/integration_tests/security_tests.py     | 475 +++++++++---------
 tests/integration_tests/sql_lab/api_tests.py  |  48 +-
 .../integration_tests/sql_validator_tests.py  |   4 +-
 tests/integration_tests/sqla_models_tests.py  |  34 +-
 tests/integration_tests/sqllab_tests.py       |  80 ++-
 tests/integration_tests/strategy_tests.py     |  12 +-
 tests/integration_tests/tagging_tests.py      |  68 +--
 tests/integration_tests/tags/api_tests.py     |  66 +--
 tests/integration_tests/thumbnails_tests.py   |  36 +-
 tests/integration_tests/users/api_tests.py    |  18 +-
 .../integration_tests/utils/encrypt_tests.py  |  12 +-
 .../utils/machine_auth_tests.py               |   2 +-
 tests/integration_tests/utils_tests.py        | 238 +++++----
 tests/integration_tests/viz_tests.py          | 240 ++++-----
 62 files changed, 2217 insertions(+), 2421 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index e39b9da690f03..ada303a4e9662 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -446,6 +446,7 @@ select = [
     "E7",
     "E9",
     "F",
+    "PT009",
     "TRY201",
 ]
 ignore = []
diff --git a/tests/integration_tests/async_events/api_tests.py b/tests/integration_tests/async_events/api_tests.py
index 5a8189f9a7277..8397b8cf977c8 100644
--- a/tests/integration_tests/async_events/api_tests.py
+++ b/tests/integration_tests/async_events/api_tests.py
@@ -59,7 +59,7 @@ def _test_events_logic(self, mock_cache):
             assert rv.status_code == 200
             channel_id = app.config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"] + self.UUID
             mock_xrange.assert_called_with(channel_id, "-", "+", 100)
-            self.assertEqual(response, {"result": []})
+            assert response == {"result": []}

     def _test_events_last_id_logic(self, mock_cache):
         with mock.patch.object(mock_cache, "xrange") as mock_xrange:
@@ -69,7 +69,7 @@ def _test_events_last_id_logic(self, mock_cache):
             assert rv.status_code == 200
             channel_id = app.config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"] + self.UUID
             mock_xrange.assert_called_with(channel_id, "1607471525180-1", "+", 100)
-            self.assertEqual(response, {"result": []})
+            assert response == {"result": []}

     def _test_events_results_logic(self, mock_cache):
         with mock.patch.object(mock_cache, "xrange") as mock_xrange:
@@ -115,7 +115,7 @@ def _test_events_results_logic(self, mock_cache):
                 },
             ]
         }
-        self.assertEqual(response, expected)
+        assert response == expected

     @mock.patch("uuid.uuid4", return_value=UUID)
     def test_events_redis_cache_backend(self, mock_uuid4):
diff --git a/tests/integration_tests/base_api_tests.py b/tests/integration_tests/base_api_tests.py
index de003ff945b6e..6c10b7cf26f78 100644
--- a/tests/integration_tests/base_api_tests.py
+++ b/tests/integration_tests/base_api_tests.py
@@ -69,7 +69,7 @@ def test_open_api_spec(self):
         self.login(ADMIN_USERNAME)
         uri = "api/v1/_openapi"
         rv = self.client.get(uri)
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         response = json.loads(rv.data.decode("utf-8"))
         validate_spec(response)
@@ -87,20 +87,20 @@ def test_default_missing_declaration_get(self):
         self.login(ADMIN_USERNAME)
         uri = "api/v1/model1api/"
         rv = self.client.get(uri)
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         response = json.loads(rv.data.decode("utf-8"))
-        self.assertEqual(response["list_columns"], ["id"])
+        assert response["list_columns"] == ["id"]
         for result in response["result"]:
-            self.assertEqual(list(result.keys()), ["id"])
+            assert list(result.keys()) == ["id"]

         # Check get response
         dashboard = db.session.query(Dashboard).first()
         uri = f"api/v1/model1api/{dashboard.id}"
         rv = self.client.get(uri)
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         response = json.loads(rv.data.decode("utf-8"))
-        self.assertEqual(response["show_columns"], ["id"])
-        self.assertEqual(list(response["result"].keys()), ["id"])
+        assert response["show_columns"] == ["id"]
+        assert list(response["result"].keys()) == ["id"]

     def test_default_missing_declaration_put_spec(self):
         """
@@ -113,17 +113,18 @@ def test_default_missing_declaration_put_spec(self):
         uri = "api/v1/_openapi"
         rv = self.client.get(uri)
         # dashboard model accepts all fields are null
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         response = json.loads(rv.data.decode("utf-8"))
         expected_mutation_spec = {
             "properties": {"id": {"type": "integer"}},
             "type": "object",
         }
-        self.assertEqual(
-            response["components"]["schemas"]["Model1Api.post"], expected_mutation_spec
+        assert (
+            response["components"]["schemas"]["Model1Api.post"]
+            == expected_mutation_spec
         )
-        self.assertEqual(
-            response["components"]["schemas"]["Model1Api.put"], expected_mutation_spec
+        assert (
+            response["components"]["schemas"]["Model1Api.put"] == expected_mutation_spec
         )

     def test_default_missing_declaration_post(self):
@@ -145,7 +146,7 @@ def test_default_missing_declaration_post(self):
         uri = "api/v1/model1api/"
         rv = self.client.post(uri, json=dashboard_data)
         response = json.loads(rv.data.decode("utf-8"))
-        self.assertEqual(rv.status_code, 422)
+        assert rv.status_code == 422
         expected_response = {
             "message": {
                 "css": ["Unknown field."],
@@ -156,7 +157,7 @@ def test_default_missing_declaration_post(self):
                 "slug": ["Unknown field."],
             }
         }
-        self.assertEqual(response, expected_response)
+        assert response == expected_response

     def test_refuse_invalid_format_request(self):
         """
@@ -169,7 +170,7 @@ def test_refuse_invalid_format_request(self):
         rv = self.client.post(
             uri, data="a: value\nb: 1\n", content_type="application/yaml"
         )
-        self.assertEqual(rv.status_code, 400)
+        assert rv.status_code == 400

     @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
     def test_default_missing_declaration_put(self):
@@ -185,14 +186,14 @@ def test_default_missing_declaration_put(self):
         uri = f"api/v1/model1api/{dashboard.id}"
         rv = self.client.put(uri, json=dashboard_data)
         response = json.loads(rv.data.decode("utf-8"))
-        self.assertEqual(rv.status_code, 422)
+        assert rv.status_code == 422
         expected_response = {
             "message": {
                 "dashboard_title": ["Unknown field."],
                 "slug": ["Unknown field."],
             }
         }
-        self.assertEqual(response, expected_response)
+        assert response == expected_response


 class ApiOwnersTestCaseMixin:
diff --git a/tests/integration_tests/cache_tests.py b/tests/integration_tests/cache_tests.py
index 88b20282f40a5..1356e32cd81f3 100644
--- a/tests/integration_tests/cache_tests.py
+++ b/tests/integration_tests/cache_tests.py
@@ -59,8 +59,8 @@ def test_no_data_cache(self):
         )
         # restore DATA_CACHE_CONFIG
         app.config["DATA_CACHE_CONFIG"] = data_cache_config
-        self.assertFalse(resp["is_cached"])
-        self.assertFalse(resp_from_cache["is_cached"])
+        assert not resp["is_cached"]
+        assert not resp_from_cache["is_cached"]

     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_slice_data_cache(self):
@@ -84,20 +84,20 @@ def test_slice_data_cache(self):
         resp_from_cache = self.get_json_resp(
             json_endpoint, {"form_data": json.dumps(slc.viz.form_data)}
         )
-        self.assertFalse(resp["is_cached"])
-        self.assertTrue(resp_from_cache["is_cached"])
+        assert not resp["is_cached"]
+        assert resp_from_cache["is_cached"]
         # should fallback to default cache timeout
-        self.assertEqual(resp_from_cache["cache_timeout"], 10)
-        self.assertEqual(resp_from_cache["status"], QueryStatus.SUCCESS)
-        self.assertEqual(resp["data"], resp_from_cache["data"])
-        self.assertEqual(resp["query"], resp_from_cache["query"])
+        assert resp_from_cache["cache_timeout"] == 10
+        assert resp_from_cache["status"] == QueryStatus.SUCCESS
+        assert resp["data"] == resp_from_cache["data"]
+        assert resp["query"] == resp_from_cache["query"]
         # should exists in `data_cache`
-        self.assertEqual(
-            cache_manager.data_cache.get(resp_from_cache["cache_key"])["query"],
-            resp_from_cache["query"],
+        assert (
+            cache_manager.data_cache.get(resp_from_cache["cache_key"])["query"]
+            == resp_from_cache["query"]
         )
         # should not exists in `cache`
-        self.assertIsNone(cache_manager.cache.get(resp_from_cache["cache_key"]))
+        assert cache_manager.cache.get(resp_from_cache["cache_key"]) is None

         # reset cache config
         app.config["DATA_CACHE_CONFIG"] = data_cache_config
diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py
index a8bae9b64d4a7..a99ba04f78427 100644
--- a/tests/integration_tests/charts/api_tests.py
+++ b/tests/integration_tests/charts/api_tests.py
@@ -336,9 +336,9 @@ def test_delete_chart(self):
         self.login(ADMIN_USERNAME)
         uri = f"api/v1/chart/{chart_id}"
         rv = self.delete_assert_metric(uri, "delete")
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         model = db.session.query(Slice).get(chart_id)
-        self.assertEqual(model, None)
+        assert model is None

     def test_delete_bulk_charts(self):
         """
@@ -355,13 +355,13 @@ def test_delete_bulk_charts(self):
         argument = chart_ids
         uri = f"api/v1/chart/?q={prison.dumps(argument)}"
         rv = self.delete_assert_metric(uri, "bulk_delete")
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         response = json.loads(rv.data.decode("utf-8"))
         expected_response = {"message": f"Deleted {chart_count} charts"}
-        self.assertEqual(response, expected_response)
+        assert response == expected_response
         for chart_id in chart_ids:
             model = db.session.query(Slice).get(chart_id)
-            self.assertEqual(model, None)
+            assert model is None

     def test_delete_bulk_chart_bad_request(self):
         """
@@ -372,7 +372,7 @@ def test_delete_bulk_chart_bad_request(self):
         argument = chart_ids
         uri = f"api/v1/chart/?q={prison.dumps(argument)}"
         rv = self.delete_assert_metric(uri, "bulk_delete")
-        self.assertEqual(rv.status_code, 400)
+        assert rv.status_code == 400

     def test_delete_not_found_chart(self):
         """
@@ -382,7 +382,7 @@ def test_delete_not_found_chart(self):
         chart_id = 1000
         uri = f"api/v1/chart/{chart_id}"
         rv = self.delete_assert_metric(uri, "delete")
-        self.assertEqual(rv.status_code, 404)
+        assert rv.status_code == 404

     @pytest.mark.usefixtures("create_chart_with_report")
     def test_delete_chart_with_report(self):
@@ -398,11 +398,11 @@ def test_delete_chart_with_report(self):
         uri = f"api/v1/chart/{chart.id}"
         rv = self.client.delete(uri)
         response = json.loads(rv.data.decode("utf-8"))
-        self.assertEqual(rv.status_code, 422)
+        assert rv.status_code == 422
         expected_response = {
             "message": "There are associated alerts or reports: report_with_chart"
         }
-        self.assertEqual(response, expected_response)
+        assert response == expected_response

     def test_delete_bulk_charts_not_found(self):
         """
@@ -413,7 +413,7 @@ def test_delete_bulk_charts_not_found(self):
         self.login(ADMIN_USERNAME)
         uri = f"api/v1/chart/?q={prison.dumps(chart_ids)}"
         rv = self.delete_assert_metric(uri, "bulk_delete")
-        self.assertEqual(rv.status_code, 404)
+        assert rv.status_code == 404

     @pytest.mark.usefixtures("create_chart_with_report", "create_charts")
     def test_bulk_delete_chart_with_report(self):
@@ -434,11 +434,11 @@ def test_bulk_delete_chart_with_report(self):
         uri = f"api/v1/chart/?q={prison.dumps(chart_ids)}"
         rv = self.client.delete(uri)
         response = json.loads(rv.data.decode("utf-8"))
-        self.assertEqual(rv.status_code, 422)
+        assert rv.status_code == 422
         expected_response = {
             "message": "There are associated alerts or reports: report_with_chart"
         }
-        self.assertEqual(response, expected_response)
+        assert response == expected_response

     def test_delete_chart_admin_not_owned(self):
         """
@@ -450,9 +450,9 @@ def test_delete_chart_admin_not_owned(self):
         self.login(ADMIN_USERNAME)
         uri = f"api/v1/chart/{chart_id}"
         rv = self.delete_assert_metric(uri, "delete")
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         model = db.session.query(Slice).get(chart_id)
-        self.assertEqual(model, None)
+        assert model is None

     def test_delete_bulk_chart_admin_not_owned(self):
         """
@@ -471,13 +471,13 @@ def test_delete_bulk_chart_admin_not_owned(self):
         uri = f"api/v1/chart/?q={prison.dumps(argument)}"
         rv = self.delete_assert_metric(uri, "bulk_delete")
         response = json.loads(rv.data.decode("utf-8"))
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         expected_response = {"message": f"Deleted {chart_count} charts"}
-        self.assertEqual(response, expected_response)
+        assert response == expected_response

         for chart_id in chart_ids:
             model = db.session.query(Slice).get(chart_id)
-            self.assertEqual(model, None)
+            assert model is None

     def test_delete_chart_not_owned(self):
         """
@@ -493,7 +493,7 @@ def test_delete_chart_not_owned(self):
         self.login(username="alpha2", password="password")
         uri = f"api/v1/chart/{chart.id}"
         rv = self.delete_assert_metric(uri, "delete")
-        self.assertEqual(rv.status_code, 403)
+        assert rv.status_code == 403
         db.session.delete(chart)
         db.session.delete(user_alpha1)
         db.session.delete(user_alpha2)
@@ -525,19 +525,19 @@ def test_delete_bulk_chart_not_owned(self):
         arguments = [chart.id for chart in charts]
         uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
         rv = self.delete_assert_metric(uri, "bulk_delete")
-        self.assertEqual(rv.status_code, 403)
+        assert rv.status_code == 403
         response = json.loads(rv.data.decode("utf-8"))
         expected_response = {"message": "Forbidden"}
-        self.assertEqual(response, expected_response)
+        assert response == expected_response

         # nothing is deleted in bulk with a list of owned and not owned charts
         arguments = [chart.id for chart in charts] + [owned_chart.id]
         uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
         rv = self.delete_assert_metric(uri, "bulk_delete")
-        self.assertEqual(rv.status_code, 403)
+        assert rv.status_code == 403
         response = json.loads(rv.data.decode("utf-8"))
         expected_response = {"message": "Forbidden"}
-        self.assertEqual(response, expected_response)
+        assert response == expected_response

         for chart in charts:
             db.session.delete(chart)
@@ -572,7 +572,7 @@ def test_create_chart(self):
         self.login(ADMIN_USERNAME)
         uri = "api/v1/chart/"
         rv = self.post_assert_metric(uri, chart_data, "post")
-        self.assertEqual(rv.status_code, 201)
+        assert rv.status_code == 201
         data = json.loads(rv.data.decode("utf-8"))
         model = db.session.query(Slice).get(data.get("id"))
         db.session.delete(model)
@@ -590,7 +590,7 @@ def test_create_simple_chart(self):
         self.login(ADMIN_USERNAME)
         uri = "api/v1/chart/"
         rv = self.post_assert_metric(uri, chart_data, "post")
-        self.assertEqual(rv.status_code, 201)
+        assert rv.status_code == 201
         data = json.loads(rv.data.decode("utf-8"))
         model = db.session.query(Slice).get(data.get("id"))
         db.session.delete(model)
@@ -609,10 +609,10 @@ def test_create_chart_validate_owners(self):
         self.login(ADMIN_USERNAME)
         uri = "api/v1/chart/"
         rv = self.post_assert_metric(uri, chart_data, "post")
-        self.assertEqual(rv.status_code, 422)
+        assert rv.status_code == 422
         response = json.loads(rv.data.decode("utf-8"))
         expected_response = {"message": {"owners": ["Owners are invalid"]}}
-        self.assertEqual(response, expected_response)
+        assert response == expected_response

     def test_create_chart_validate_params(self):
         """
@@ -627,7 +627,7 @@ def test_create_chart_validate_params(self):
         self.login(ADMIN_USERNAME)
         uri = "api/v1/chart/"
         rv = self.post_assert_metric(uri, chart_data, "post")
-        self.assertEqual(rv.status_code, 400)
+        assert rv.status_code == 400

     def test_create_chart_validate_datasource(self):
         """
@@ -640,29 +640,24 @@ def test_create_chart_validate_datasource(self):
             "datasource_type": "unknown",
         }
         rv = self.post_assert_metric("/api/v1/chart/", chart_data, "post")
-        self.assertEqual(rv.status_code, 400)
+        assert rv.status_code == 400
         response = json.loads(rv.data.decode("utf-8"))
-        self.assertEqual(
-            response,
-            {
-                "message": {
-                    "datasource_type": [
-                        "Must be one of: table, dataset, query, saved_query, view."
-                    ]
-                }
-            },
-        )
+        assert response == {
+            "message": {
+                "datasource_type": [
+                    "Must be one of: table, dataset, query, saved_query, view."
+                ]
+            }
+        }
         chart_data = {
             "slice_name": "title1",
             "datasource_id": 0,
             "datasource_type": "table",
         }
         rv = self.post_assert_metric("/api/v1/chart/", chart_data, "post")
-        self.assertEqual(rv.status_code, 422)
+        assert rv.status_code == 422
         response = json.loads(rv.data.decode("utf-8"))
-        self.assertEqual(
-            response, {"message": {"datasource_id": ["Datasource does not exist"]}}
-        )
+        assert response == {"message": {"datasource_id": ["Datasource does not exist"]}}

     @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
     def test_create_chart_validate_user_is_dashboard_owner(self):
@@ -682,12 +677,11 @@ def test_create_chart_validate_user_is_dashboard_owner(self):
         self.login(ALPHA_USERNAME)
         uri = "api/v1/chart/"
         rv = self.post_assert_metric(uri, chart_data, "post")
-        self.assertEqual(rv.status_code, 403)
+        assert rv.status_code == 403
         response = json.loads(rv.data.decode("utf-8"))
-        self.assertEqual(
-            response,
-            {"message": "Changing one or more of these dashboards is forbidden"},
-        )
+        assert response == {
+            "message": "Changing one or more of these dashboards is forbidden"
+        }

     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_update_chart(self):
@@ -720,23 +714,23 @@ def test_update_chart(self):
         self.login(ADMIN_USERNAME)
         uri = f"api/v1/chart/{chart_id}"
         rv = self.put_assert_metric(uri, chart_data, "put")
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         model = db.session.query(Slice).get(chart_id)
         related_dashboard = db.session.query(Dashboard).filter_by(slug="births").first()
-        self.assertEqual(model.created_by, admin)
-        self.assertEqual(model.slice_name, "title1_changed")
-        self.assertEqual(model.description, "description1")
-        self.assertNotIn(admin, model.owners)
-        self.assertIn(gamma, model.owners)
-        self.assertEqual(model.viz_type, "viz_type1")
-        self.assertEqual(model.params, """{"a": 1}""")
-        self.assertEqual(model.cache_timeout, 1000)
-        self.assertEqual(model.datasource_id, birth_names_table_id)
-        self.assertEqual(model.datasource_type, "table")
-        self.assertEqual(model.datasource_name, full_table_name)
-        self.assertEqual(model.certified_by, "Mario Rossi")
-        self.assertEqual(model.certification_details, "Edited certification")
-        self.assertIn(model.id, [slice.id for slice in related_dashboard.slices])
+        assert model.created_by == admin
+        assert model.slice_name == "title1_changed"
+        assert model.description == "description1"
+        assert admin not in model.owners
+        assert gamma in model.owners
+        assert model.viz_type == "viz_type1"
+        assert model.params == '{"a": 1}'
+        assert model.cache_timeout == 1000
+        assert model.datasource_id == birth_names_table_id
+        assert model.datasource_type == "table"
+        assert model.datasource_name == full_table_name
+        assert model.certified_by == "Mario Rossi"
+        assert model.certification_details == "Edited certification"
+        assert model.id in [slice.id for slice in related_dashboard.slices]
         db.session.delete(model)
         db.session.commit()
@@ -755,16 +749,16 @@ def test_chart_get_list_no_username(self):
         self.login(ADMIN_USERNAME)
         uri = f"api/v1/chart/{chart_id}"
         rv = self.put_assert_metric(uri, chart_data, "put")
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         model = db.session.query(Slice).get(chart_id)

         response = self.get_assert_metric("api/v1/chart/", "get_list")
         res = json.loads(response.data.decode("utf-8"))["result"]

         current_chart = [d for d in res if d["id"] == chart_id][0]
-        self.assertEqual(current_chart["slice_name"], new_name)
-        self.assertNotIn("username", current_chart["changed_by"].keys())
-        self.assertNotIn("username", current_chart["owners"][0].keys())
+        assert current_chart["slice_name"] == new_name
+        assert "username" not in current_chart["changed_by"].keys()
+        assert "username" not in current_chart["owners"][0].keys()

         db.session.delete(model)
         db.session.commit()
@@ -784,14 +778,14 @@ def test_chart_get_no_username(self):
         self.login(ADMIN_USERNAME)
         uri = f"api/v1/chart/{chart_id}"
         rv = self.put_assert_metric(uri, chart_data, "put")
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         model = db.session.query(Slice).get(chart_id)

         response = self.get_assert_metric(uri, "get")
         res = json.loads(response.data.decode("utf-8"))["result"]

-        self.assertEqual(res["slice_name"], new_name)
-        self.assertNotIn("username", res["owners"][0].keys())
+        assert res["slice_name"] == new_name
+        assert "username" not in res["owners"][0].keys()

         db.session.delete(model)
         db.session.commit()
@@ -829,10 +823,10 @@ def test_update_chart_new_owner_admin(self):
         self.login(ADMIN_USERNAME)
         uri = f"api/v1/chart/{chart_id}"
         rv = self.put_assert_metric(uri, chart_data, "put")
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         model = db.session.query(Slice).get(chart_id)
-        self.assertNotIn(admin, model.owners)
-        self.assertIn(gamma, model.owners)
+        assert admin not in model.owners
+        assert gamma in model.owners
         db.session.delete(model)
         db.session.commit()
@@ -848,8 +842,8 @@ def test_update_chart_preserve_ownership(self):
         self.login(username="admin")
         uri = f"api/v1/chart/{self.chart.id}"
         rv = self.put_assert_metric(uri, chart_data, "put")
-        self.assertEqual(rv.status_code, 200)
-        self.assertEqual([admin], self.chart.owners)
+        assert rv.status_code == 200
+        assert [admin] == self.chart.owners

     @pytest.mark.usefixtures("add_dashboard_to_chart")
     def test_update_chart_clear_owner_list(self):
@@ -861,8 +855,8 @@ def test_update_chart_clear_owner_list(self):
         self.login(username="admin")
         uri = f"api/v1/chart/{self.chart.id}"
         rv = self.put_assert_metric(uri, chart_data, "put")
-        self.assertEqual(rv.status_code, 200)
-        self.assertEqual([], self.chart.owners)
+        assert rv.status_code == 200
+        assert [] == self.chart.owners

     def test_update_chart_populate_owner(self):
         """
@@ -873,15 +867,15 @@ def test_update_chart_populate_owner(self):
         admin = self.get_user("admin")
         chart_id = self.insert_chart("title", [], 1).id
         model = db.session.query(Slice).get(chart_id)
-        self.assertEqual(model.owners, [])
+        assert model.owners == []
         chart_data = {"owners": [gamma.id]}
         self.login(username="admin")
         uri = f"api/v1/chart/{chart_id}"
         rv = self.put_assert_metric(uri, chart_data, "put")
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         model_updated = db.session.query(Slice).get(chart_id)
-        self.assertNotIn(admin, model_updated.owners)
-        self.assertIn(gamma, model_updated.owners)
+        assert admin not in model_updated.owners
+        assert gamma in model_updated.owners
         db.session.delete(model_updated)
         db.session.commit()
@@ -897,9
uri = f"api/v1/chart/?q={prison.dumps(arguments)}" rv = self.get_assert_metric(uri, "get_list") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) result = data["result"] assert len(result) == 1 @@ -1206,26 +1195,20 @@ def test_get_charts_tag_filters(self): # Filter by tag ID filter_params = get_filter_params("chart_tag_id", tag.id) response_by_id = self.get_list("chart", filter_params) - self.assertEqual(response_by_id.status_code, 200) + assert response_by_id.status_code == 200 data_by_id = json.loads(response_by_id.data.decode("utf-8")) # Filter by tag name filter_params = get_filter_params("chart_tags", tag.name) response_by_name = self.get_list("chart", filter_params) - self.assertEqual(response_by_name.status_code, 200) + assert response_by_name.status_code == 200 data_by_name = json.loads(response_by_name.data.decode("utf-8")) # Compare results - self.assertEqual( - data_by_id["count"], - data_by_name["count"], - len(expected_charts), - ) - self.assertEqual( - set(chart["id"] for chart in data_by_id["result"]), - set(chart["id"] for chart in data_by_name["result"]), - set(chart.id for chart in expected_charts), - ) + assert data_by_id["count"] == data_by_name["count"], len(expected_charts) + assert set(chart["id"] for chart in data_by_id["result"]) == set( + chart["id"] for chart in data_by_name["result"] + ), set(chart.id for chart in expected_charts) def test_get_charts_changed_on(self): """ @@ -1243,7 +1226,7 @@ def test_get_charts_changed_on(self): uri = f"api/v1/chart/?q={prison.dumps(arguments)}" rv = self.get_assert_metric(uri, "get_list") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) assert data["result"][0]["changed_on_delta_humanized"] in ( "now", @@ -1266,9 +1249,9 @@ def test_get_charts_filter(self): arguments = {"filters": [{"col": "slice_name", "opr": "sw", "value": "G"}]} uri = f"api/v1/chart/?q={prison.dumps(arguments)}" rv = self.get_assert_metric(uri, "get_list") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(data["count"], 5) + assert data["count"] == 5 @pytest.fixture() def load_energy_charts(self): @@ -1323,9 +1306,9 @@ def test_get_charts_custom_filter(self): self.login(ADMIN_USERNAME) uri = f"api/v1/chart/?q={prison.dumps(arguments)}" rv = self.get_assert_metric(uri, "get_list") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(data["count"], 4) + assert data["count"] == 4 expected_response = [ {"description": "ZY_bar", "slice_name": "foo_a", "viz_type": None}, @@ -1334,11 +1317,9 @@ def test_get_charts_custom_filter(self): {"description": "desc1", "slice_name": "zy_foo", "viz_type": None}, ] for index, item in enumerate(data["result"]): - self.assertEqual( - item["description"], expected_response[index]["description"] - ) - self.assertEqual(item["slice_name"], expected_response[index]["slice_name"]) - self.assertEqual(item["viz_type"], expected_response[index]["viz_type"]) + assert item["description"] == expected_response[index]["description"] + assert item["slice_name"] == expected_response[index]["slice_name"] + assert item["viz_type"] == expected_response[index]["viz_type"] @pytest.mark.usefixtures("load_energy_table_with_slice", "load_energy_charts") def test_admin_gets_filtered_energy_slices(self): @@ -1390,9 +1371,9 @@ def 
test_gets_certified_charts_filter(self): uri = f"api/v1/chart/?q={prison.dumps(arguments)}" rv = self.get_assert_metric(uri, "get_list") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(data["count"], CHARTS_FIXTURE_COUNT) + assert data["count"] == CHARTS_FIXTURE_COUNT @pytest.mark.usefixtures("create_charts") def test_gets_not_certified_charts_filter(self): @@ -1411,9 +1392,9 @@ def test_gets_not_certified_charts_filter(self): uri = f"api/v1/chart/?q={prison.dumps(arguments)}" rv = self.get_assert_metric(uri, "get_list") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(data["count"], 17) + assert data["count"] == 17 @pytest.mark.usefixtures("load_energy_charts") def test_user_gets_none_filtered_energy_slices(self): @@ -1433,9 +1414,9 @@ def test_user_gets_none_filtered_energy_slices(self): self.login(GAMMA_USERNAME) uri = f"api/v1/chart/?q={prison.dumps(arguments)}" rv = self.get_assert_metric(uri, "get_list") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(data["count"], 0) + assert data["count"] == 0 @pytest.mark.usefixtures("load_energy_charts") def test_user_gets_all_charts(self): @@ -1445,12 +1426,12 @@ def test_user_gets_all_charts(self): def count_charts(): uri = "api/v1/chart/" rv = self.client.get(uri, "get_list") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = rv.get_json() return data["count"] with self.temporary_user(gamma_user, login=True): - self.assertEqual(count_charts(), 0) + assert count_charts() == 0 perm = ("all_database_access", "all_database_access") with self.temporary_user(gamma_user, extra_pvms=[perm], login=True): @@ -1462,7 +1443,7 @@ def count_charts(): # Back to normal with self.temporary_user(gamma_user, login=True): - self.assertEqual(count_charts(), 0) + assert count_charts() == 0 @pytest.mark.usefixtures("create_charts") def test_get_charts_favorite_filter(self): @@ -1645,7 +1626,7 @@ def test_get_time_range(self): uri = f"api/v1/time_range/?q={prison.dumps(humanize_time_range)}" rv = self.client.get(uri) data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 assert "since" in data["result"][0] assert "until" in data["result"][0] assert "timeRange" in data["result"][0] @@ -1686,10 +1667,10 @@ def test_query_form_data(self): uri = f"api/v1/form_data/?slice_id={slice.id if slice else None}" rv = self.client.get(uri) data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 200) - self.assertEqual(rv.content_type, "application/json") + assert rv.status_code == 200 + assert rv.content_type == "application/json" if slice: - self.assertEqual(data["slice_id"], slice.id) + assert data["slice_id"] == slice.id @pytest.mark.usefixtures( "load_unicode_dashboard_with_slice", @@ -1706,16 +1687,16 @@ def test_get_charts_page(self): arguments = {"page_size": 10, "page": 0} uri = f"api/v1/chart/?q={prison.dumps(arguments)}" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(len(data["result"]), 10) + assert len(data["result"]) == 10 arguments = {"page_size": 10, "page": 3} uri = f"api/v1/chart/?q={prison.dumps(arguments)}" rv = self.get_assert_metric(uri, "get_list") - self.assertEqual(rv.status_code, 200) + 
assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(len(data["result"]), 3) + assert len(data["result"]) == 3 def test_get_charts_no_data_access(self): """ @@ -1724,9 +1705,9 @@ def test_get_charts_no_data_access(self): self.login(GAMMA_USERNAME) uri = "api/v1/chart/" rv = self.get_assert_metric(uri, "get_list") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(data["count"], 0) + assert data["count"] == 0 def test_export_chart(self): """ @@ -1940,9 +1921,9 @@ def test_gets_created_by_user_charts_filter(self): uri = f"api/v1/chart/?q={prison.dumps(arguments)}" rv = self.get_assert_metric(uri, "get_list") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(data["count"], 8) + assert data["count"] == 8 def test_gets_not_created_by_user_charts_filter(self): arguments = { @@ -1954,9 +1935,9 @@ def test_gets_not_created_by_user_charts_filter(self): uri = f"api/v1/chart/?q={prison.dumps(arguments)}" rv = self.get_assert_metric(uri, "get_list") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(data["count"], 8) + assert data["count"] == 8 @pytest.mark.usefixtures("create_charts") def test_gets_owned_created_favorited_by_me_filter(self): @@ -1978,7 +1959,7 @@ def test_gets_owned_created_favorited_by_me_filter(self): "page_size": 25, } rv = self.client.get(f"api/v1/chart/?q={prison.dumps(arguments)}") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) assert data["result"][0]["slice_name"] == "name0" @@ -1995,13 +1976,12 @@ def test_warm_up_cache(self, slice_name): self.login(ADMIN_USERNAME) slc = self.get_slice(slice_name) rv = self.client.put("/api/v1/chart/warm_up_cache", json={"chart_id": slc.id}) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) - self.assertEqual( - data["result"], - [{"chart_id": slc.id, "viz_error": None, "viz_status": "success"}], - ) + assert data["result"] == [ + {"chart_id": slc.id, "viz_error": None, "viz_status": "success"} + ] dashboard = self.get_dash_by_slug("births") @@ -2009,12 +1989,11 @@ def test_warm_up_cache(self, slice_name): "/api/v1/chart/warm_up_cache", json={"chart_id": slc.id, "dashboard_id": dashboard.id}, ) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) - self.assertEqual( - data["result"], - [{"chart_id": slc.id, "viz_error": None, "viz_status": "success"}], - ) + assert data["result"] == [ + {"chart_id": slc.id, "viz_error": None, "viz_status": "success"} + ] rv = self.client.put( "/api/v1/chart/warm_up_cache", @@ -2026,29 +2005,25 @@ def test_warm_up_cache(self, slice_name): ), }, ) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) - self.assertEqual( - data["result"], - [{"chart_id": slc.id, "viz_error": None, "viz_status": "success"}], - ) + assert data["result"] == [ + {"chart_id": slc.id, "viz_error": None, "viz_status": "success"} + ] def test_warm_up_cache_chart_id_required(self): self.login(ADMIN_USERNAME) rv = self.client.put("/api/v1/chart/warm_up_cache", json={"dashboard_id": 1}) - self.assertEqual(rv.status_code, 400) + assert rv.status_code == 400 data = json.loads(rv.data.decode("utf-8")) - 
self.assertEqual( - data, - {"message": {"chart_id": ["Missing data for required field."]}}, - ) + assert data == {"message": {"chart_id": ["Missing data for required field."]}} def test_warm_up_cache_chart_not_found(self): self.login(ADMIN_USERNAME) rv = self.client.put("/api/v1/chart/warm_up_cache", json={"chart_id": 99999}) - self.assertEqual(rv.status_code, 404) + assert rv.status_code == 404 data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(data, {"message": "Chart not found"}) + assert data == {"message": "Chart not found"} def test_warm_up_cache_payload_validation(self): self.login(ADMIN_USERNAME) @@ -2056,18 +2031,15 @@ def test_warm_up_cache_payload_validation(self): "/api/v1/chart/warm_up_cache", json={"chart_id": "id", "dashboard_id": "id", "extra_filters": 4}, ) - self.assertEqual(rv.status_code, 400) + assert rv.status_code == 400 data = json.loads(rv.data.decode("utf-8")) - self.assertEqual( - data, - { - "message": { - "chart_id": ["Not a valid integer."], - "dashboard_id": ["Not a valid integer."], - "extra_filters": ["Not a valid string."], - } - }, - ) + assert data == { + "message": { + "chart_id": ["Not a valid integer."], + "dashboard_id": ["Not a valid integer."], + "extra_filters": ["Not a valid string."], + } + } @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_warm_up_cache_error(self) -> None: @@ -2167,12 +2139,12 @@ def test_update_chart_add_tags_can_write_on_tag(self): uri = f"api/v1/chart/{chart.id}" rv = self.put_assert_metric(uri, update_payload, "put") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 model = db.session.query(Slice).get(chart.id) # Clean up system tags tag_list = [tag.id for tag in model.tags if tag.type == TagType.custom] - self.assertEqual(tag_list, new_tags) + assert tag_list == new_tags @pytest.mark.usefixtures("create_chart_with_tag") def test_update_chart_remove_tags_can_write_on_tag(self): @@ -2194,12 +2166,12 @@ def test_update_chart_remove_tags_can_write_on_tag(self): uri = f"api/v1/chart/{chart.id}" rv = self.put_assert_metric(uri, update_payload, "put") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 model = db.session.query(Slice).get(chart.id) # Clean up system tags tag_list = [tag.id for tag in model.tags if tag.type == TagType.custom] - self.assertEqual(tag_list, new_tags) + assert tag_list == new_tags @pytest.mark.usefixtures("create_chart_with_tag") def test_update_chart_add_tags_can_tag_on_chart(self): @@ -2226,12 +2198,12 @@ def test_update_chart_add_tags_can_tag_on_chart(self): uri = f"api/v1/chart/{chart.id}" rv = self.put_assert_metric(uri, update_payload, "put") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 model = db.session.query(Slice).get(chart.id) # Clean up system tags tag_list = [tag.id for tag in model.tags if tag.type == TagType.custom] - self.assertEqual(tag_list, new_tags) + assert tag_list == new_tags security_manager.add_permission_role(alpha_role, write_tags_perm) @@ -2256,12 +2228,12 @@ def test_update_chart_remove_tags_can_tag_on_chart(self): uri = f"api/v1/chart/{chart.id}" rv = self.put_assert_metric(uri, update_payload, "put") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 model = db.session.query(Slice).get(chart.id) # Clean up system tags tag_list = [tag.id for tag in model.tags if tag.type == TagType.custom] - self.assertEqual(tag_list, []) + assert tag_list == [] security_manager.add_permission_role(alpha_role, write_tags_perm) @@ -2291,10 +2263,9 @@ def 
test_update_chart_add_tags_missing_permission(self): uri = f"api/v1/chart/{chart.id}" rv = self.put_assert_metric(uri, update_payload, "put") - self.assertEqual(rv.status_code, 403) - self.assertEqual( - rv.json["message"], - "You do not have permission to manage tags on charts", + assert rv.status_code == 403 + assert ( + rv.json["message"] == "You do not have permission to manage tags on charts" ) security_manager.add_permission_role(alpha_role, write_tags_perm) @@ -2322,10 +2293,9 @@ def test_update_chart_remove_tags_missing_permission(self): uri = f"api/v1/chart/{chart.id}" rv = self.put_assert_metric(uri, update_payload, "put") - self.assertEqual(rv.status_code, 403) - self.assertEqual( - rv.json["message"], - "You do not have permission to manage tags on charts", + assert rv.status_code == 403 + assert ( + rv.json["message"] == "You do not have permission to manage tags on charts" ) security_manager.add_permission_role(alpha_role, write_tags_perm) @@ -2353,7 +2323,7 @@ def test_update_chart_no_tag_changes(self): uri = f"api/v1/chart/{chart.id}" rv = self.put_assert_metric(uri, update_payload, "put") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 security_manager.add_permission_role(alpha_role, write_tags_perm) security_manager.add_permission_role(alpha_role, tag_charts_perm) diff --git a/tests/integration_tests/charts/commands_tests.py b/tests/integration_tests/charts/commands_tests.py index 950c7cbc888d5..d66980585e2d9 100644 --- a/tests/integration_tests/charts/commands_tests.py +++ b/tests/integration_tests/charts/commands_tests.py @@ -424,15 +424,19 @@ def test_warm_up_cache_command_chart_not_found(self): def test_warm_up_cache(self): slc = self.get_slice("Top 10 Girl Name Share") result = ChartWarmUpCacheCommand(slc.id, None, None).run() - self.assertEqual( - result, {"chart_id": slc.id, "viz_error": None, "viz_status": "success"} - ) + assert result == { + "chart_id": slc.id, + "viz_error": None, + "viz_status": "success", + } # can just pass in chart as well result = ChartWarmUpCacheCommand(slc, None, None).run() - self.assertEqual( - result, {"chart_id": slc.id, "viz_error": None, "viz_status": "success"} - ) + assert result == { + "chart_id": slc.id, + "viz_error": None, + "viz_status": "success", + } class TestFavoriteChartCommand(SupersetTestCase): diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py index 9aeac84cc5f9e..b922f16cb296e 100644 --- a/tests/integration_tests/charts/data/api_tests.py +++ b/tests/integration_tests/charts/data/api_tests.py @@ -471,19 +471,16 @@ def test_chart_data_applied_time_extras(self): "__time_origin": "now", } rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) - self.assertEqual( - data["result"][0]["applied_filters"], - [ - {"column": "gender"}, - {"column": "num"}, - {"column": "name"}, - {"column": "__time_range"}, - ], - ) + assert data["result"][0]["applied_filters"] == [ + {"column": "gender"}, + {"column": "num"}, + {"column": "name"}, + {"column": "__time_range"}, + ] expected_row_count = self.get_expected_row_count("client_id_2") - self.assertEqual(data["result"][0]["rowcount"], expected_row_count) + assert data["result"][0]["rowcount"] == expected_row_count @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_in_op_filter__data_is_returned(self): @@ -533,7 +530,7 @@ def 
test_chart_data_dttm_filter(self): dttm_col.type, dttm, ) - self.assertIn(dttm_expression, result["query"]) + assert dttm_expression in result["query"] else: raise Exception("ds column not found") @@ -563,16 +560,16 @@ def test_chart_data_prophet(self): } ] rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 response_payload = json.loads(rv.data.decode("utf-8")) result = response_payload["result"][0] row = result["data"][0] - self.assertIn("__timestamp", row) - self.assertIn("sum__num", row) - self.assertIn("sum__num__yhat", row) - self.assertIn("sum__num__yhat_upper", row) - self.assertIn("sum__num__yhat_lower", row) - self.assertEqual(result["rowcount"], 103) + assert "__timestamp" in row + assert "sum__num" in row + assert "sum__num__yhat" in row + assert "sum__num__yhat_upper" in row + assert "sum__num__yhat_lower" in row + assert result["rowcount"] == 103 @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_invalid_post_processing(self): @@ -730,11 +727,11 @@ def test_chart_data_async(self): time.sleep(1) rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") time.sleep(1) - self.assertEqual(rv.status_code, 202) + assert rv.status_code == 202 time.sleep(1) data = json.loads(rv.data.decode("utf-8")) keys = list(data.keys()) - self.assertCountEqual( + self.assertCountEqual( # noqa: PT009 keys, ["channel_id", "job_id", "user_id", "status", "errors", "result_url"] ) @@ -764,10 +761,10 @@ class QueryContext: rv = self.post_assert_metric( CHART_DATA_URI, self.query_context_payload, "data" ) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) patched_run.assert_called_once_with(force_cached=True) - self.assertEqual(data, {"result": [{"query": "select * from foo"}]}) + assert data == {"result": [{"query": "select * from foo"}]} @with_feature_flags(GLOBAL_ASYNC_QUERIES=True) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @@ -779,7 +776,7 @@ def test_chart_data_async_results_type(self): async_query_manager_factory.init_app(app) self.query_context_payload["result_type"] = "results" rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 @with_feature_flags(GLOBAL_ASYNC_QUERIES=True) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @@ -793,7 +790,7 @@ def test_chart_data_async_invalid_token(self): app.config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME"], "foo" ) rv = test_client.post(CHART_DATA_URI, json=self.query_context_payload) - self.assertEqual(rv.status_code, 401) + assert rv.status_code == 401 @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_rowcount(self): @@ -846,10 +843,8 @@ def test_with_series_limit(self): unique_names = {row["name"] for row in data} self.maxDiff = None - self.assertEqual(len(unique_names), SERIES_LIMIT) - self.assertEqual( - {column for column in data[0].keys()}, {"state", "name", "sum__num"} - ) + assert len(unique_names) == SERIES_LIMIT + assert {column for column in data[0].keys()} == {"state", "name", "sum__num"} @pytest.mark.usefixtures( "create_annotation_layers", "load_birth_names_dashboard_with_slices" @@ -888,10 +883,10 @@ def test_with_annotations_layers__annotations_data_returned(self): annotation_layers.append(event) rv = self.post_assert_metric(CHART_DATA_URI, 
self.query_context_payload, "data") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) # response should only contain interval and event data, not formula - self.assertEqual(len(data["result"][0]["annotation_data"]), 2) + assert len(data["result"][0]["annotation_data"]) == 2 @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_virtual_table_with_colons_as_datasource(self): @@ -1184,8 +1179,8 @@ def mock_run(self, **kwargs): data = json.loads(rv.data.decode("utf-8")) expected_row_count = self.get_expected_row_count("client_id_3") - self.assertEqual(rv.status_code, 200) - self.assertEqual(data["result"][0]["rowcount"], expected_row_count) + assert rv.status_code == 200 + assert data["result"][0]["rowcount"] == expected_row_count @with_feature_flags(GLOBAL_ASYNC_QUERIES=True) @mock.patch("superset.charts.data.api.QueryContextCacheLoader") @@ -1202,8 +1197,8 @@ def test_chart_data_cache_run_failed(self, cache_loader): ) data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 422) - self.assertEqual(data["message"], "Error loading data from cache") + assert rv.status_code == 422 + assert data["message"] == "Error loading data from cache" @with_feature_flags(GLOBAL_ASYNC_QUERIES=True) @mock.patch("superset.charts.data.api.QueryContextCacheLoader") @@ -1231,7 +1226,7 @@ def mock_run(self, **kwargs): f"{CHART_DATA_URI}/test-cache-key", ) - self.assertEqual(rv.status_code, 401) + assert rv.status_code == 401 @with_feature_flags(GLOBAL_ASYNC_QUERIES=True) def test_chart_data_cache_key_error(self): @@ -1244,7 +1239,7 @@ def test_chart_data_cache_key_error(self): f"{CHART_DATA_URI}/test-cache-key", "data_from_cache" ) - self.assertEqual(rv.status_code, 404) + assert rv.status_code == 404 @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_with_adhoc_column(self): diff --git a/tests/integration_tests/charts/schema_tests.py b/tests/integration_tests/charts/schema_tests.py index 8f74d9de293c8..46ba792e52839 100644 --- a/tests/integration_tests/charts/schema_tests.py +++ b/tests/integration_tests/charts/schema_tests.py @@ -46,8 +46,8 @@ def test_query_context_limit_and_offset(self): payload["queries"][0]["row_offset"] = -1 with self.assertRaises(ValidationError) as context: _ = ChartDataQueryContextSchema().load(payload) - self.assertIn("row_limit", context.exception.messages["queries"][0]) - self.assertIn("row_offset", context.exception.messages["queries"][0]) + assert "row_limit" in context.exception.messages["queries"][0] + assert "row_offset" in context.exception.messages["queries"][0] @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_query_context_null_timegrain(self): diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py index 44b7ef26e64cd..4f989611b2a29 100644 --- a/tests/integration_tests/core_tests.py +++ b/tests/integration_tests/core_tests.py @@ -114,15 +114,15 @@ def insert_dashboard_created_by_gamma(self): def test_login(self): resp = self.get_resp("/login/", data=dict(username="admin", password="general")) - self.assertNotIn("User confirmation needed", resp) + assert "User confirmation needed" not in resp resp = self.get_resp("/logout/", follow_redirects=True) - self.assertIn("User confirmation needed", resp) + assert "User confirmation needed" in resp resp = self.get_resp( "/login/", data=dict(username="admin", password="wrongPassword") ) - self.assertIn("User confirmation needed", 
resp) + assert "User confirmation needed" in resp def test_dashboard_endpoint(self): self.login(ADMIN_USERNAME) @@ -146,20 +146,17 @@ def test_viz_cache_key(self): qobj["groupby"] = [] cache_key_with_groupby = viz.cache_key(qobj) - self.assertNotEqual(cache_key, cache_key_with_groupby) + assert cache_key != cache_key_with_groupby - self.assertNotEqual( - viz.cache_key(qobj), viz.cache_key(qobj, time_compare="12 weeks") - ) + assert viz.cache_key(qobj) != viz.cache_key(qobj, time_compare="12 weeks") - self.assertNotEqual( - viz.cache_key(qobj, time_compare="28 days"), - viz.cache_key(qobj, time_compare="12 weeks"), + assert viz.cache_key(qobj, time_compare="28 days") != viz.cache_key( + qobj, time_compare="12 weeks" ) qobj["inner_from_dttm"] = datetime.datetime(1901, 1, 1) - self.assertEqual(cache_key_with_groupby, viz.cache_key(qobj)) + assert cache_key_with_groupby == viz.cache_key(qobj) def test_admin_only_menu_views(self): def assert_admin_view_menus_in(role_name, assert_func): @@ -205,9 +202,9 @@ def test_save_slice(self): new_slice_id = resp.json["form_data"]["slice_id"] slc = db.session.query(Slice).filter_by(id=new_slice_id).one() - self.assertEqual(slc.slice_name, copy_name) + assert slc.slice_name == copy_name form_data.pop("slice_id") # We don't save the slice id when saving as - self.assertEqual(slc.viz.form_data, form_data) + assert slc.viz.form_data == form_data form_data = { "adhoc_filters": [], @@ -224,8 +221,8 @@ def test_save_slice(self): data={"form_data": json.dumps(form_data)}, ) slc = db.session.query(Slice).filter_by(id=new_slice_id).one() - self.assertEqual(slc.slice_name, new_slice_name) - self.assertEqual(slc.viz.form_data, form_data) + assert slc.slice_name == new_slice_name + assert slc.viz.form_data == form_data # Cleanup slices = ( @@ -261,21 +258,21 @@ def test_slices(self): logger.info(f"[{name}]/[{method}]: {url}") print(f"[{name}]/[{method}]: {url}") resp = self.client.get(url) - self.assertEqual(resp.status_code, 200) + assert resp.status_code == 200 def test_add_slice(self): self.login(ADMIN_USERNAME) # assert that /chart/add responds with 200 url = "/chart/add" resp = self.client.get(url) - self.assertEqual(resp.status_code, 200) + assert resp.status_code == 200 def test_get_user_slices(self): self.login(ADMIN_USERNAME) userid = security_manager.find_user("admin").id url = f"/sliceasync/api/read?_flt_0_created_by={userid}" resp = self.client.get(url) - self.assertEqual(resp.status_code, 200) + assert resp.status_code == 200 @pytest.mark.usefixtures("load_energy_table_with_slice") def test_slices_V2(self): @@ -339,7 +336,7 @@ def test_databaseview_edit(self): data["sqlalchemy_uri"] = database.safe_sqlalchemy_uri() self.client.post(url, data=data) database = superset.utils.database.get_example_database() - self.assertEqual(sqlalchemy_uri_decrypted, database.sqlalchemy_uri_decrypted) + assert sqlalchemy_uri_decrypted == database.sqlalchemy_uri_decrypted # Need to clean up after ourselves database.impersonate_user = False @@ -355,9 +352,9 @@ def test_warm_up_cache(self): self.login(ADMIN_USERNAME) slc = self.get_slice("Top 10 Girl Name Share") data = self.get_json_resp(f"/superset/warm_up_cache?slice_id={slc.id}") - self.assertEqual( - data, [{"slice_id": slc.id, "viz_error": None, "viz_status": "success"}] - ) + assert data == [ + {"slice_id": slc.id, "viz_error": None, "viz_status": "success"} + ] data = self.get_json_resp( "/superset/warm_up_cache?table_name=energy_usage&db_name=main" @@ -415,29 +412,29 @@ def test_kv_disabled(self): 
self.login(ADMIN_USERNAME) resp = self.client.get("/kv/10001/") - self.assertEqual(404, resp.status_code) + assert 404 == resp.status_code value = json.dumps({"data": "this is a test"}) resp = self.client.post("/kv/store/", data=dict(data=value)) - self.assertEqual(resp.status_code, 404) + assert resp.status_code == 404 @with_feature_flags(KV_STORE=True) def test_kv_enabled(self): self.login(ADMIN_USERNAME) resp = self.client.get("/kv/10001/") - self.assertEqual(404, resp.status_code) + assert 404 == resp.status_code value = json.dumps({"data": "this is a test"}) resp = self.client.post("/kv/store/", data=dict(data=value)) - self.assertEqual(resp.status_code, 200) + assert resp.status_code == 200 kv = db.session.query(models.KeyValue).first() kv_value = kv.value - self.assertEqual(json.loads(value), json.loads(kv_value)) + assert json.loads(value) == json.loads(kv_value) resp = self.client.get(f"/kv/{kv.id}/") - self.assertEqual(resp.status_code, 200) - self.assertEqual(json.loads(value), json.loads(resp.data.decode("utf-8"))) + assert resp.status_code == 200 + assert json.loads(value) == json.loads(resp.data.decode("utf-8")) def test_gamma(self): self.login(GAMMA_USERNAME) @@ -451,7 +448,7 @@ def test_templated_sql_json(self): self.login(ADMIN_USERNAME) sql = "SELECT '{{ 1+1 }}' as test" data = self.run_sql(sql, "fdaklj3ws") - self.assertEqual(data["data"][0]["test"], "2") + assert data["data"][0]["test"] == "2" def test_fetch_datasource_metadata(self): self.login(ADMIN_USERNAME) @@ -466,7 +463,7 @@ def test_fetch_datasource_metadata(self): "id", ] for k in keys: - self.assertIn(k, resp.keys()) + assert k in resp.keys() @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_slice_id_is_always_logged_correctly_on_web_request(self): @@ -475,7 +472,7 @@ def test_slice_id_is_always_logged_correctly_on_web_request(self): slc = db.session.query(Slice).filter_by(slice_name="Girls").one() qry = db.session.query(models.Log).filter_by(slice_id=slc.id) self.get_resp(slc.slice_url) - self.assertEqual(1, qry.count()) + assert 1 == qry.count() def create_sample_csvfile(self, filename: str, content: list[str]) -> None: with open(filename, "w+") as test_file: @@ -490,7 +487,7 @@ def enable_csv_upload(self, database: models.Database) -> None: database.allow_file_upload = True db.session.commit() add_datasource_page = self.get_resp("/databaseview/list/") - self.assertIn("Upload a CSV", add_datasource_page) + assert "Upload a CSV" in add_datasource_page def test_dataframe_timezone(self): tz = pytz.FixedOffset(60) @@ -502,15 +499,15 @@ def test_dataframe_timezone(self): df = results.to_pandas_df() data = dataframe.df_to_records(df) json_str = json.dumps(data, default=json.pessimistic_json_iso_dttm_ser) - self.assertDictEqual( + self.assertDictEqual( # noqa: PT009 data[0], {"data": pd.Timestamp("2017-11-18 21:53:00.219225+0100", tz=tz)} ) - self.assertDictEqual( + self.assertDictEqual( # noqa: PT009 data[1], {"data": pd.Timestamp("2017-11-18 22:06:30+0100", tz=tz)} ) - self.assertEqual( - json_str, - '[{"data": "2017-11-18T21:53:00.219225+01:00"}, {"data": "2017-11-18T22:06:30+01:00"}]', + assert ( + json_str + == '[{"data": "2017-11-18T21:53:00.219225+01:00"}, {"data": "2017-11-18T22:06:30+01:00"}]' ) def test_mssql_engine_spec_pymssql(self): @@ -524,11 +521,12 @@ def test_mssql_engine_spec_pymssql(self): ) df = results.to_pandas_df() data = dataframe.df_to_records(df) - self.assertEqual(len(data), 2) - self.assertEqual( - data[0], - {"col1": 1, "col2": 1, "col3": 
pd.Timestamp("2017-10-19 23:39:16.660000")}, - ) + assert len(data) == 2 + assert data[0] == { + "col1": 1, + "col2": 1, + "col3": pd.Timestamp("2017-10-19 23:39:16.660000"), + } def test_comments_in_sqlatable_query(self): clean_query = "SELECT\n '/* val 1 */' AS c1,\n '-- val 2' AS c2\nFROM tbl" @@ -554,9 +552,9 @@ def test_slice_payload_no_datasource(self): ) data = json.loads(rv.data.decode("utf-8")) - self.assertEqual( - data["errors"][0]["message"], - "The dataset associated with this chart no longer exists", + assert ( + data["errors"][0]["message"] + == "The dataset associated with this chart no longer exists" ) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @@ -579,8 +577,8 @@ def test_explore_json(self): ) data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 200) - self.assertEqual(data["rowcount"], 2) + assert rv.status_code == 200 + assert data["rowcount"] == 2 @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_explore_json_dist_bar_order(self): @@ -741,7 +739,7 @@ def test_explore_json_async_results_format(self): "/superset/explore_json/?results=true", data={"form_data": json.dumps(form_data)}, ) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @mock.patch( @@ -780,8 +778,8 @@ def set(self): rv = self.client.get("/superset/explore_json/data/valid-cache-key") data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 200) - self.assertEqual(data["rowcount"], 2) + assert rv.status_code == 200 + assert data["rowcount"] == 2 @mock.patch( "superset.utils.cache_manager.CacheManager.cache", @@ -814,7 +812,7 @@ def set(self): mock_cache.return_value = MockCache() rv = self.client.get("/superset/explore_json/data/valid-cache-key") - self.assertEqual(rv.status_code, 403) + assert rv.status_code == 403 def test_explore_json_data_invalid_cache_key(self): self.login(ADMIN_USERNAME) @@ -822,8 +820,8 @@ def test_explore_json_data_invalid_cache_key(self): rv = self.client.get(f"/superset/explore_json/data/{cache_key}") data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 404) - self.assertEqual(data["error"], "Cached data not found") + assert rv.status_code == 404 + assert data["error"] == "Cached data not found" def test_results_default_deserialization(self): use_new_deserialization = False @@ -863,14 +861,14 @@ def test_results_default_deserialization(self): serialized_payload = sql_lab._serialize_payload( payload, use_new_deserialization ) - self.assertIsInstance(serialized_payload, str) + assert isinstance(serialized_payload, str) query_mock = mock.Mock() deserialized_payload = superset.views.utils._deserialize_results_payload( serialized_payload, query_mock, use_new_deserialization ) - self.assertDictEqual(deserialized_payload, payload) + self.assertDictEqual(deserialized_payload, payload) # noqa: PT009 query_mock.assert_not_called() def test_results_msgpack_deserialization(self): @@ -911,7 +909,7 @@ def test_results_msgpack_deserialization(self): serialized_payload = sql_lab._serialize_payload( payload, use_new_deserialization ) - self.assertIsInstance(serialized_payload, bytes) + assert isinstance(serialized_payload, bytes) with mock.patch.object( db_engine_spec, "expand_data", wraps=db_engine_spec.expand_data @@ -925,7 +923,7 @@ def test_results_msgpack_deserialization(self): df = results.to_pandas_df() payload["data"] = dataframe.df_to_records(df) - 
self.assertDictEqual(deserialized_payload, payload) + self.assertDictEqual(deserialized_payload, payload) # noqa: PT009 expand_data.assert_called_once() @mock.patch.dict( @@ -960,7 +958,7 @@ def test_feature_flag_serialization(self): ] for url in urls: data = self.get_resp(url) - self.assertTrue(html_string in data) + assert html_string in data @mock.patch.dict( "superset.extensions.feature_flag_manager._feature_flags", @@ -991,7 +989,7 @@ def test_tabstate_with_name(self): tab_state_id = resp["id"] payload = self.get_json_resp(f"/tabstateview/{tab_state_id}") - self.assertEqual(payload["label"], "Untitled Query foo") + assert payload["label"] == "Untitled Query foo" def test_tabstate_update(self): self.login(ADMIN_USERNAME) @@ -1014,87 +1012,87 @@ def test_tabstate_update(self): client_id = "asdfasdf" data = {"sql": json.dumps("select 1"), "latest_query_id": json.dumps(client_id)} response = self.client.put(f"/tabstateview/{tab_state_id}", data=data) - self.assertEqual(response.status_code, 400) - self.assertEqual(response.json["error"], "Bad request") + assert response.status_code == 400 + assert response.json["error"] == "Bad request" # generate query db.session.add(Query(client_id=client_id, database_id=1)) db.session.commit() # update tab state with a valid client_id response = self.client.put(f"/tabstateview/{tab_state_id}", data=data) - self.assertEqual(response.status_code, 200) + assert response.status_code == 200 # nulls should be ok too data["latest_query_id"] = "null" response = self.client.put(f"/tabstateview/{tab_state_id}", data=data) - self.assertEqual(response.status_code, 200) + assert response.status_code == 200 def test_virtual_table_explore_visibility(self): # test that default visibility it set to True database = superset.utils.database.get_example_database() - self.assertEqual(database.allows_virtual_table_explore, True) + assert database.allows_virtual_table_explore is True # test that visibility is disabled when extra is set to False extra = database.get_extra() extra["allows_virtual_table_explore"] = False database.extra = json.dumps(extra) - self.assertEqual(database.allows_virtual_table_explore, False) + assert database.allows_virtual_table_explore is False # test that visibility is enabled when extra is set to True extra = database.get_extra() extra["allows_virtual_table_explore"] = True database.extra = json.dumps(extra) - self.assertEqual(database.allows_virtual_table_explore, True) + assert database.allows_virtual_table_explore is True # test that visibility is not broken with bad values extra = database.get_extra() extra["allows_virtual_table_explore"] = "trash value" database.extra = json.dumps(extra) - self.assertEqual(database.allows_virtual_table_explore, True) + assert database.allows_virtual_table_explore is True def test_data_preview_visibility(self): # test that default visibility is allowed database = utils.get_example_database() - self.assertEqual(database.disable_data_preview, False) + assert database.disable_data_preview is False # test that visibility is disabled when extra is set to true extra = database.get_extra() extra["disable_data_preview"] = True database.extra = json.dumps(extra) - self.assertEqual(database.disable_data_preview, True) + assert database.disable_data_preview is True # test that visibility is enabled when extra is set to false extra = database.get_extra() extra["disable_data_preview"] = False database.extra = json.dumps(extra) - self.assertEqual(database.disable_data_preview, False) + assert database.disable_data_preview 
is False # test that visibility is not broken with bad values extra = database.get_extra() extra["disable_data_preview"] = "trash value" database.extra = json.dumps(extra) - self.assertEqual(database.disable_data_preview, False) + assert database.disable_data_preview is False def test_disable_drill_to_detail(self): # test that disable_drill_to_detail is False by default database = utils.get_example_database() - self.assertEqual(database.disable_drill_to_detail, False) + assert database.disable_drill_to_detail is False # test that disable_drill_to_detail can be set to True extra = database.get_extra() extra["disable_drill_to_detail"] = True database.extra = json.dumps(extra) - self.assertEqual(database.disable_drill_to_detail, True) + assert database.disable_drill_to_detail is True # test that disable_drill_to_detail can be set to False extra = database.get_extra() extra["disable_drill_to_detail"] = False database.extra = json.dumps(extra) - self.assertEqual(database.disable_drill_to_detail, False) + assert database.disable_drill_to_detail is False # test that disable_drill_to_detail is not broken with bad values extra = database.get_extra() extra["disable_drill_to_detail"] = "trash value" database.extra = json.dumps(extra) - self.assertEqual(database.disable_drill_to_detail, False) + assert database.disable_drill_to_detail is False def test_explore_database_id(self): database = superset.utils.database.get_example_database() @@ -1102,13 +1100,13 @@ def test_explore_database_id(self): # test that explore_database_id is the regular database # id if none is set in the extra - self.assertEqual(database.explore_database_id, database.id) + assert database.explore_database_id == database.id # test that explore_database_id is correct if the extra is set extra = database.get_extra() extra["explore_database_id"] = explore_database.id database.extra = json.dumps(extra) - self.assertEqual(database.explore_database_id, explore_database.id) + assert database.explore_database_id == explore_database.id def test_get_column_names_from_metric(self): simple_metric = { @@ -1146,7 +1144,7 @@ def test_explore_injected_exceptions(self, mock_db_connection_mutator): self.login(ADMIN_USERNAME) data = self.get_resp(url) - self.assertIn("Error message", data) + assert "Error message" in data # Assert we can handle a driver exception at the mutator level exception = SQLAlchemyError("Error message") @@ -1156,7 +1154,7 @@ def test_explore_injected_exceptions(self, mock_db_connection_mutator): self.login(ADMIN_USERNAME) data = self.get_resp(url) - self.assertIn("Error message", data) + assert "Error message" in data @pytest.mark.skip( "TODO This test was wrong - 'Error message' was in the language pack" @@ -1176,7 +1174,7 @@ def test_dashboard_injected_exceptions(self, mock_db_connection_mutator): self.login(ADMIN_USERNAME) data = self.get_resp(url) - self.assertIn("Error message", data) + assert "Error message" in data # Assert we can handle a driver exception at the mutator level exception = SQLAlchemyError("Error message") @@ -1186,7 +1184,7 @@ def test_dashboard_injected_exceptions(self, mock_db_connection_mutator): self.login(ADMIN_USERNAME) data = self.get_resp(url) - self.assertIn("Error message", data) + assert "Error message" in data @pytest.mark.usefixtures("load_energy_table_with_slice") @mock.patch("superset.commands.explore.form_data.create.CreateFormDataCommand.run") @@ -1200,9 +1198,7 @@ def test_explore_redirect(self, mock_command: mock.Mock): rv = self.client.get( 
f"/superset/explore/?form_data={quote(json.dumps(form_data))}" ) - self.assertEqual( - rv.headers["Location"], f"/explore/?form_data_key={random_key}" - ) + assert rv.headers["Location"] == f"/explore/?form_data_key={random_key}" @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_has_table(self): @@ -1223,7 +1219,7 @@ def test_dashboard_permalink(self, get_dashboard_permalink_mock, request_mock): expected_url = "/superset/dashboard/1?permalink_key=123&standalone=3" - self.assertEqual(resp.headers["Location"], expected_url) + assert resp.headers["Location"] == expected_url assert resp.status_code == 302 diff --git a/tests/integration_tests/dashboard_tests.py b/tests/integration_tests/dashboard_tests.py index bee8de7a5e064..021fa0e1b5b93 100644 --- a/tests/integration_tests/dashboard_tests.py +++ b/tests/integration_tests/dashboard_tests.py @@ -117,7 +117,7 @@ def test_new_dashboard(self): url = "/dashboard/new/" response = self.client.get(url, follow_redirects=False) dash_count_after = db.session.query(func.count(Dashboard.id)).first()[0] - self.assertEqual(dash_count_before + 1, dash_count_after) + assert dash_count_before + 1 == dash_count_after group = re.match( r"\/superset\/dashboard\/([0-9]*)\/\?edit=true", response.headers["Location"], @@ -145,25 +145,25 @@ def test_public_user_dashboard_access(self): self.logout() resp = self.get_resp("/api/v1/chart/") - self.assertNotIn("birth_names", resp) + assert "birth_names" not in resp resp = self.get_resp("/api/v1/dashboard/") - self.assertNotIn("/superset/dashboard/births/", resp) + assert "/superset/dashboard/births/" not in resp self.grant_public_access_to_table(table) # Try access after adding appropriate permissions. - self.assertIn("birth_names", self.get_resp("/api/v1/chart/")) + assert "birth_names" in self.get_resp("/api/v1/chart/") resp = self.get_resp("/api/v1/dashboard/") - self.assertIn("/superset/dashboard/births/", resp) + assert "/superset/dashboard/births/" in resp # Confirm that public doesn't have access to other datasets. 
resp = self.get_resp("/api/v1/chart/") - self.assertNotIn("wb_health_population", resp) + assert "wb_health_population" not in resp resp = self.get_resp("/api/v1/dashboard/") - self.assertNotIn("/superset/dashboard/world_health/", resp) + assert "/superset/dashboard/world_health/" not in resp # Cleanup self.revoke_public_access_to_table(table) @@ -224,8 +224,8 @@ def test_users_can_view_own_dashboard(self): db.session.delete(hidden_dash) db.session.commit() - self.assertIn(f"/superset/dashboard/{my_dash_slug}/", resp) - self.assertNotIn(f"/superset/dashboard/{not_my_dash_slug}/", resp) + assert f"/superset/dashboard/{my_dash_slug}/" in resp + assert f"/superset/dashboard/{not_my_dash_slug}/" not in resp def test_user_can_not_view_unpublished_dash(self): admin_user = security_manager.find_user("admin") @@ -247,7 +247,7 @@ def test_user_can_not_view_unpublished_dash(self): db.session.delete(dash) db.session.commit() - self.assertNotIn(f"/superset/dashboard/{slug}/", resp) + assert f"/superset/dashboard/{slug}/" not in resp if __name__ == "__main__": diff --git a/tests/integration_tests/dashboards/api_tests.py b/tests/integration_tests/dashboards/api_tests.py index e1f3d457a4208..f88bdccc0eccf 100644 --- a/tests/integration_tests/dashboards/api_tests.py +++ b/tests/integration_tests/dashboards/api_tests.py @@ -275,15 +275,15 @@ def test_get_dashboard_datasets(self, logger_mock): self.login(ADMIN_USERNAME) uri = "api/v1/dashboard/world_health/datasets" response = self.get_assert_metric(uri, "get_datasets") - self.assertEqual(response.status_code, 200) + assert response.status_code == 200 data = json.loads(response.data.decode("utf-8")) dashboard = Dashboard.get("world_health") expected_dataset_ids = {s.datasource_id for s in dashboard.slices} result = data["result"] actual_dataset_ids = {dataset["id"] for dataset in result} - self.assertEqual(actual_dataset_ids, expected_dataset_ids) + assert actual_dataset_ids == expected_dataset_ids expected_values = [0, 1] if backend() == "presto" else [0, 1, 2] - self.assertEqual(result[0]["column_types"], expected_values) + assert result[0]["column_types"] == expected_values logger_mock.warning.assert_not_called() @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices") @@ -293,13 +293,13 @@ def test_get_dashboard_datasets_as_guest(self, is_guest_user, has_guest_access): self.login(ADMIN_USERNAME) uri = "api/v1/dashboard/world_health/datasets" response = self.get_assert_metric(uri, "get_datasets") - self.assertEqual(response.status_code, 200) + assert response.status_code == 200 data = json.loads(response.data.decode("utf-8")) dashboard = Dashboard.get("world_health") expected_dataset_ids = {s.datasource_id for s in dashboard.slices} result = data["result"] actual_dataset_ids = {dataset["id"] for dataset in result} - self.assertEqual(actual_dataset_ids, expected_dataset_ids) + assert actual_dataset_ids == expected_dataset_ids for dataset in result: for excluded_key in ["database", "owners"]: assert excluded_key not in dataset @@ -310,7 +310,7 @@ def test_get_dashboard_datasets_not_found(self, logger_mock): self.login(ALPHA_USERNAME) uri = "api/v1/dashboard/not_found/datasets" response = self.get_assert_metric(uri, "get_datasets") - self.assertEqual(response.status_code, 404) + assert response.status_code == 404 logger_mock.warning.assert_called_once_with( "Dashboard not found.", exc_info=True ) @@ -325,7 +325,7 @@ def test_get_dashboard_datasets_invalid_schema( self.login(ADMIN_USERNAME) uri = "api/v1/dashboard/world_health/datasets" response = 
self.get_assert_metric(uri, "get_datasets") - self.assertEqual(response.status_code, 422) + assert response.status_code == 422 logger_mock.warning.assert_called_once_with( "Dataset schema is invalid, caused by: Invalid schema", exc_info=True ) @@ -367,9 +367,9 @@ def get_dashboard_by_slug(self): dashboard = self.dashboards[0] uri = f"api/v1/dashboard/{dashboard.slug}" response = self.get_assert_metric(uri, "get") - self.assertEqual(response.status_code, 200) + assert response.status_code == 200 data = json.loads(response.data.decode("utf-8")) - self.assertEqual(data["id"], dashboard.id) + assert data["id"] == dashboard.id @pytest.mark.usefixtures("create_dashboards") def get_dashboard_by_bad_slug(self): @@ -377,7 +377,7 @@ def get_dashboard_by_bad_slug(self): dashboard = self.dashboards[0] uri = f"api/v1/dashboard/{dashboard.slug}-bad-slug" response = self.get_assert_metric(uri, "get") - self.assertEqual(response.status_code, 404) + assert response.status_code == 404 @pytest.mark.usefixtures("create_dashboards") def get_draft_dashboard_by_slug(self): @@ -388,7 +388,7 @@ def get_draft_dashboard_by_slug(self): dashboard = self.dashboards[0] uri = f"api/v1/dashboard/{dashboard.slug}" response = self.get_assert_metric(uri, "get") - self.assertEqual(response.status_code, 200) + assert response.status_code == 200 @pytest.mark.usefixtures("create_dashboards") def test_get_dashboard_charts(self): @@ -399,7 +399,7 @@ def test_get_dashboard_charts(self): dashboard = self.dashboards[0] uri = f"api/v1/dashboard/{dashboard.id}/charts" response = self.get_assert_metric(uri, "get_charts") - self.assertEqual(response.status_code, 200) + assert response.status_code == 200 data = json.loads(response.data.decode("utf-8")) assert len(data["result"]) == 1 result = data["result"][0] @@ -427,12 +427,10 @@ def test_get_dashboard_charts_by_slug(self): dashboard = self.dashboards[0] uri = f"api/v1/dashboard/{dashboard.slug}/charts" response = self.get_assert_metric(uri, "get_charts") - self.assertEqual(response.status_code, 200) + assert response.status_code == 200 data = json.loads(response.data.decode("utf-8")) - self.assertEqual(len(data["result"]), 1) - self.assertEqual( - data["result"][0]["slice_name"], dashboard.slices[0].slice_name - ) + assert len(data["result"]) == 1 + assert data["result"][0]["slice_name"] == dashboard.slices[0].slice_name @pytest.mark.usefixtures("create_dashboards") def test_get_dashboard_charts_not_found(self): @@ -443,14 +441,14 @@ def test_get_dashboard_charts_not_found(self): bad_id = self.get_nonexistent_numeric_id(Dashboard) uri = f"api/v1/dashboard/{bad_id}/charts" response = self.get_assert_metric(uri, "get_charts") - self.assertEqual(response.status_code, 404) + assert response.status_code == 404 @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices") def test_get_dashboard_datasets_not_allowed(self): self.login(GAMMA_USERNAME) uri = "api/v1/dashboard/world_health/datasets" response = self.get_assert_metric(uri, "get_datasets") - self.assertEqual(response.status_code, 404) + assert response.status_code == 404 @pytest.mark.usefixtures("create_dashboards") def test_get_gamma_dashboard_charts(self): @@ -493,9 +491,9 @@ def test_get_dashboard_charts_empty(self): # the fixture setup assigns no charts to the second half of dashboards uri = f"api/v1/dashboard/{self.dashboards[-1].id}/charts" response = self.get_assert_metric(uri, "get_charts") - self.assertEqual(response.status_code, 200) + assert response.status_code == 200 data = json.loads(response.data.decode("utf-8")) 
- self.assertEqual(data["result"], []) + assert data["result"] == [] def test_get_dashboard(self): """ @@ -508,7 +506,7 @@ def test_get_dashboard(self): self.login(ADMIN_USERNAME) uri = f"api/v1/dashboard/{dashboard.id}" rv = self.get_assert_metric(uri, "get") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 with override_user(admin): expected_result = { "certified_by": None, @@ -543,9 +541,9 @@ def test_get_dashboard(self): "is_managed_externally": False, } data = json.loads(rv.data.decode("utf-8")) - self.assertIn("changed_on", data["result"]) - self.assertIn("changed_on_delta_humanized", data["result"]) - self.assertIn("created_on_delta_humanized", data["result"]) + assert "changed_on" in data["result"] + assert "changed_on_delta_humanized" in data["result"] + assert "created_on_delta_humanized" in data["result"] for key, value in data["result"].items(): # We can't assert timestamp values if key not in ( @@ -553,7 +551,7 @@ def test_get_dashboard(self): "changed_on_delta_humanized", "created_on_delta_humanized", ): - self.assertEqual(value, expected_result[key]) + assert value == expected_result[key] # rollback changes db.session.delete(dashboard) db.session.commit() @@ -573,7 +571,7 @@ def test_get_dashboard_as_guest(self, is_guest_user, has_guest_access): self.login(ADMIN_USERNAME) uri = f"api/v1/dashboard/{dashboard.id}" rv = self.get_assert_metric(uri, "get") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) for excluded_key in ["changed_by", "changed_by_name", "owners"]: assert excluded_key not in data["result"] @@ -588,7 +586,7 @@ def test_info_dashboard(self): self.login(ADMIN_USERNAME) uri = "api/v1/dashboard/_info" rv = self.get_assert_metric(uri, "info") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 def test_info_security_dashboard(self): """ @@ -619,7 +617,7 @@ def test_get_dashboard_not_found(self): self.login(ADMIN_USERNAME) uri = f"api/v1/dashboard/{bad_id}" rv = self.get_assert_metric(uri, "get") - self.assertEqual(rv.status_code, 404) + assert rv.status_code == 404 def test_get_dashboard_no_data_access(self): """ @@ -656,12 +654,11 @@ def test_get_dashboards_changed_on(self): uri = f"api/v1/dashboard/?q={prison.dumps(arguments)}" rv = self.get_assert_metric(uri, "get_list") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) - self.assertEqual( - data["result"][0]["changed_on_delta_humanized"], - humanize.naturaltime(datetime.now()), - ) + assert data["result"][0][ + "changed_on_delta_humanized" + ] == humanize.naturaltime(datetime.now()) # rollback changes db.session.delete(dashboard) @@ -683,9 +680,9 @@ def test_get_dashboards_filter(self): uri = f"api/v1/dashboard/?q={prison.dumps(arguments)}" rv = self.get_assert_metric(uri, "get_list") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(data["count"], 1) + assert data["count"] == 1 arguments = { "filters": [ @@ -694,9 +691,9 @@ def test_get_dashboards_filter(self): } uri = f"api/v1/dashboard/?q={prison.dumps(arguments)}" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(data["count"], 1) + assert data["count"] == 1 # rollback changes db.session.delete(dashboard) @@ -720,9 +717,9 @@ def test_get_dashboards_title_or_slug_filter(self): self.login(ADMIN_USERNAME) 
uri = f"api/v1/dashboard/?q={prison.dumps(arguments)}" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(data["count"], 1) + assert data["count"] == 1 expected_response = [ {"slug": "slug1", "dashboard_title": "title1"}, @@ -733,9 +730,9 @@ def test_get_dashboards_title_or_slug_filter(self): arguments["filters"][0]["value"] = "slug2" uri = f"api/v1/dashboard/?q={prison.dumps(arguments)}" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(data["count"], 1) + assert data["count"] == 1 expected_response = [ {"slug": "slug2", "dashboard_title": "title2"}, @@ -746,9 +743,9 @@ def test_get_dashboards_title_or_slug_filter(self): self.login(GAMMA_USERNAME) uri = f"api/v1/dashboard/?q={prison.dumps(arguments)}" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(data["count"], 0) + assert data["count"] == 0 @pytest.mark.usefixtures("create_dashboards") def test_get_dashboards_favorite_filter(self): @@ -813,26 +810,22 @@ def test_get_dashboards_tag_filters(self): # Filter by tag ID filter_params = get_filter_params("dashboard_tag_id", tag.id) response_by_id = self.get_list("dashboard", filter_params) - self.assertEqual(response_by_id.status_code, 200) + assert response_by_id.status_code == 200 data_by_id = json.loads(response_by_id.data.decode("utf-8")) # Filter by tag name filter_params = get_filter_params("dashboard_tags", tag.name) response_by_name = self.get_list("dashboard", filter_params) - self.assertEqual(response_by_name.status_code, 200) + assert response_by_name.status_code == 200 data_by_name = json.loads(response_by_name.data.decode("utf-8")) # Compare results - self.assertEqual( - data_by_id["count"], - data_by_name["count"], - len(expected_dashboards), - ) - self.assertEqual( - set(chart["id"] for chart in data_by_id["result"]), - set(chart["id"] for chart in data_by_name["result"]), - set(chart.id for chart in expected_dashboards), + assert data_by_id["count"] == data_by_name["count"], len( + expected_dashboards ) + assert set(chart["id"] for chart in data_by_id["result"]) == set( + chart["id"] for chart in data_by_name["result"] + ), set(chart.id for chart in expected_dashboards) @pytest.mark.usefixtures("create_dashboards") def test_get_current_user_favorite_status(self): @@ -982,9 +975,9 @@ def test_gets_certified_dashboards_filter(self): uri = f"api/v1/dashboard/?q={prison.dumps(arguments)}" rv = self.get_assert_metric(uri, "get_list") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(data["count"], DASHBOARDS_FIXTURE_COUNT) + assert data["count"] == DASHBOARDS_FIXTURE_COUNT @pytest.mark.usefixtures("create_dashboards") def test_gets_not_certified_dashboards_filter(self): @@ -1003,9 +996,9 @@ def test_gets_not_certified_dashboards_filter(self): uri = f"api/v1/dashboard/?q={prison.dumps(arguments)}" rv = self.get_assert_metric(uri, "get_list") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(data["count"], 0) + assert data["count"] == 0 @pytest.mark.usefixtures("create_created_by_gamma_dashboards") def test_get_dashboards_created_by_me(self): @@ -1206,8 +1199,8 @@ def test_get_dashboard_tabs(self): 
], } } - self.assertEqual(rv.status_code, 200) - self.assertEqual(response, expected_response) + assert rv.status_code == 200 + assert response == expected_response db.session.delete(dashboard) db.session.commit() @@ -1220,7 +1213,7 @@ def test_get_dashboard_tabs_not_found(self): self.login(ADMIN_USERNAME) uri = f"api/v1/dashboard/{bad_id}/tabs" rv = self.get_assert_metric(uri, "get_tabs") - self.assertEqual(rv.status_code, 404) + assert rv.status_code == 404 def create_dashboard_import(self): buf = BytesIO() @@ -1261,9 +1254,9 @@ def test_delete_dashboard(self): self.login(ADMIN_USERNAME) uri = f"api/v1/dashboard/{dashboard_id}" rv = self.delete_assert_metric(uri, "delete") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 model = db.session.query(Dashboard).get(dashboard_id) - self.assertEqual(model, None) + assert model is None def test_delete_bulk_dashboards(self): """ @@ -1284,13 +1277,13 @@ def test_delete_bulk_dashboards(self): argument = dashboard_ids uri = f"api/v1/dashboard/?q={prison.dumps(argument)}" rv = self.delete_assert_metric(uri, "bulk_delete") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 response = json.loads(rv.data.decode("utf-8")) expected_response = {"message": f"Deleted {dashboard_count} dashboards"} - self.assertEqual(response, expected_response) + assert response == expected_response for dashboard_id in dashboard_ids: model = db.session.query(Dashboard).get(dashboard_id) - self.assertEqual(model, None) + assert model is None def test_delete_bulk_embedded_dashboards(self): """ @@ -1316,21 +1309,21 @@ def test_delete_bulk_embedded_dashboards(self): {"allowed_domains": allowed_domains}, "set_embedded", ) - self.assertEqual(resp.status_code, 200) + assert resp.status_code == 200 result = json.loads(resp.data.decode("utf-8"))["result"] - self.assertIsNotNone(result["uuid"]) - self.assertNotEqual(result["uuid"], "") - self.assertEqual(result["allowed_domains"], allowed_domains) + assert result["uuid"] is not None + assert result["uuid"] != "" + assert result["allowed_domains"] == allowed_domains argument = dashboard_ids uri = f"api/v1/dashboard/?q={prison.dumps(argument)}" rv = self.delete_assert_metric(uri, "bulk_delete") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 response = json.loads(rv.data.decode("utf-8")) expected_response = {"message": f"Deleted {dashboard_count} dashboards"} - self.assertEqual(response, expected_response) + assert response == expected_response for dashboard_id in dashboard_ids: model = db.session.query(Dashboard).get(dashboard_id) - self.assertEqual(model, None) + assert model is None def test_delete_bulk_dashboards_bad_request(self): """ @@ -1341,7 +1334,7 @@ def test_delete_bulk_dashboards_bad_request(self): argument = dashboard_ids uri = f"api/v1/dashboard/?q={prison.dumps(argument)}" rv = self.client.delete(uri) - self.assertEqual(rv.status_code, 400) + assert rv.status_code == 400 def test_delete_not_found_dashboard(self): """ @@ -1351,7 +1344,7 @@ def test_delete_not_found_dashboard(self): dashboard_id = 1000 uri = f"api/v1/dashboard/{dashboard_id}" rv = self.client.delete(uri) - self.assertEqual(rv.status_code, 404) + assert rv.status_code == 404 @pytest.mark.usefixtures("create_dashboard_with_report") def test_delete_dashboard_with_report(self): @@ -1367,11 +1360,11 @@ def test_delete_dashboard_with_report(self): uri = f"api/v1/dashboard/{dashboard.id}" rv = self.client.delete(uri) response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 
422) + assert rv.status_code == 422 expected_response = { "message": "There are associated alerts or reports: report_with_dashboard" } - self.assertEqual(response, expected_response) + assert response == expected_response def test_delete_bulk_dashboards_not_found(self): """ @@ -1382,7 +1375,7 @@ def test_delete_bulk_dashboards_not_found(self): argument = dashboard_ids uri = f"api/v1/dashboard/?q={prison.dumps(argument)}" rv = self.client.delete(uri) - self.assertEqual(rv.status_code, 404) + assert rv.status_code == 404 @pytest.mark.usefixtures("create_dashboard_with_report", "create_dashboards") def test_delete_bulk_dashboard_with_report(self): @@ -1406,11 +1399,11 @@ def test_delete_bulk_dashboard_with_report(self): uri = f"api/v1/dashboard/?q={prison.dumps(dashboard_ids)}" rv = self.client.delete(uri) response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 422) + assert rv.status_code == 422 expected_response = { "message": "There are associated alerts or reports: report_with_dashboard" } - self.assertEqual(response, expected_response) + assert response == expected_response def test_delete_dashboard_admin_not_owned(self): """ @@ -1422,9 +1415,9 @@ def test_delete_dashboard_admin_not_owned(self): self.login(ADMIN_USERNAME) uri = f"api/v1/dashboard/{dashboard_id}" rv = self.client.delete(uri) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 model = db.session.query(Dashboard).get(dashboard_id) - self.assertEqual(model, None) + assert model is None def test_delete_bulk_dashboard_admin_not_owned(self): """ @@ -1447,13 +1440,13 @@ def test_delete_bulk_dashboard_admin_not_owned(self): uri = f"api/v1/dashboard/?q={prison.dumps(argument)}" rv = self.client.delete(uri) response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 expected_response = {"message": f"Deleted {dashboard_count} dashboards"} - self.assertEqual(response, expected_response) + assert response == expected_response for dashboard_id in dashboard_ids: model = db.session.query(Dashboard).get(dashboard_id) - self.assertEqual(model, None) + assert model is None @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_delete_dashboard_not_owned(self): @@ -1475,7 +1468,7 @@ def test_delete_dashboard_not_owned(self): self.login(username="alpha2", password="password") uri = f"api/v1/dashboard/{dashboard.id}" rv = self.client.delete(uri) - self.assertEqual(rv.status_code, 403) + assert rv.status_code == 403 db.session.delete(dashboard) db.session.delete(user_alpha1) db.session.delete(user_alpha2) @@ -1523,19 +1516,19 @@ def test_delete_bulk_dashboard_not_owned(self): arguments = [dashboard.id for dashboard in dashboards] uri = f"api/v1/dashboard/?q={prison.dumps(arguments)}" rv = self.client.delete(uri) - self.assertEqual(rv.status_code, 403) + assert rv.status_code == 403 response = json.loads(rv.data.decode("utf-8")) expected_response = {"message": "Forbidden"} - self.assertEqual(response, expected_response) + assert response == expected_response # nothing is deleted in bulk with a list of owned and not owned dashboards arguments = [dashboard.id for dashboard in dashboards] + [owned_dashboard.id] uri = f"api/v1/dashboard/?q={prison.dumps(arguments)}" rv = self.client.delete(uri) - self.assertEqual(rv.status_code, 403) + assert rv.status_code == 403 response = json.loads(rv.data.decode("utf-8")) expected_response = {"message": "Forbidden"} - self.assertEqual(response, expected_response) + assert response == 
expected_response for dashboard in dashboards: db.session.delete(dashboard) @@ -1561,7 +1554,7 @@ def test_create_dashboard(self): self.login(ADMIN_USERNAME) uri = "api/v1/dashboard/" rv = self.post_assert_metric(uri, dashboard_data, "post") - self.assertEqual(rv.status_code, 201) + assert rv.status_code == 201 data = json.loads(rv.data.decode("utf-8")) model = db.session.query(Dashboard).get(data.get("id")) db.session.delete(model) @@ -1575,7 +1568,7 @@ def test_create_simple_dashboard(self): self.login(ADMIN_USERNAME) uri = "api/v1/dashboard/" rv = self.client.post(uri, json=dashboard_data) - self.assertEqual(rv.status_code, 201) + assert rv.status_code == 201 data = json.loads(rv.data.decode("utf-8")) model = db.session.query(Dashboard).get(data.get("id")) db.session.delete(model) @@ -1589,7 +1582,7 @@ def test_create_dashboard_empty(self): self.login(ADMIN_USERNAME) uri = "api/v1/dashboard/" rv = self.client.post(uri, json=dashboard_data) - self.assertEqual(rv.status_code, 201) + assert rv.status_code == 201 data = json.loads(rv.data.decode("utf-8")) model = db.session.query(Dashboard).get(data.get("id")) db.session.delete(model) @@ -1599,7 +1592,7 @@ def test_create_dashboard_empty(self): self.login(ADMIN_USERNAME) uri = "api/v1/dashboard/" rv = self.client.post(uri, json=dashboard_data) - self.assertEqual(rv.status_code, 201) + assert rv.status_code == 201 data = json.loads(rv.data.decode("utf-8")) model = db.session.query(Dashboard).get(data.get("id")) db.session.delete(model) @@ -1613,12 +1606,12 @@ def test_create_dashboard_validate_title(self): self.login(ADMIN_USERNAME) uri = "api/v1/dashboard/" rv = self.post_assert_metric(uri, dashboard_data, "post") - self.assertEqual(rv.status_code, 400) + assert rv.status_code == 400 response = json.loads(rv.data.decode("utf-8")) expected_response = { "message": {"dashboard_title": ["Length must be between 0 and 500."]} } - self.assertEqual(response, expected_response) + assert response == expected_response def test_create_dashboard_validate_slug(self): """ @@ -1632,19 +1625,19 @@ def test_create_dashboard_validate_slug(self): dashboard_data = {"dashboard_title": "title2", "slug": "slug1"} uri = "api/v1/dashboard/" rv = self.client.post(uri, json=dashboard_data) - self.assertEqual(rv.status_code, 422) + assert rv.status_code == 422 response = json.loads(rv.data.decode("utf-8")) expected_response = {"message": {"slug": ["Must be unique"]}} - self.assertEqual(response, expected_response) + assert response == expected_response # Check for slug max size dashboard_data = {"dashboard_title": "title2", "slug": "a" * 256} uri = "api/v1/dashboard/" rv = self.client.post(uri, json=dashboard_data) - self.assertEqual(rv.status_code, 400) + assert rv.status_code == 400 response = json.loads(rv.data.decode("utf-8")) expected_response = {"message": {"slug": ["Length must be between 1 and 255."]}} - self.assertEqual(response, expected_response) + assert response == expected_response db.session.delete(dashboard) db.session.commit() @@ -1657,10 +1650,10 @@ def test_create_dashboard_validate_owners(self): self.login(ADMIN_USERNAME) uri = "api/v1/dashboard/" rv = self.client.post(uri, json=dashboard_data) - self.assertEqual(rv.status_code, 422) + assert rv.status_code == 422 response = json.loads(rv.data.decode("utf-8")) expected_response = {"message": {"owners": ["Owners are invalid"]}} - self.assertEqual(response, expected_response) + assert response == expected_response def test_create_dashboard_validate_roles(self): """ @@ -1670,10 +1663,10 @@ def 
test_create_dashboard_validate_roles(self): self.login(ADMIN_USERNAME) uri = "api/v1/dashboard/" rv = self.client.post(uri, json=dashboard_data) - self.assertEqual(rv.status_code, 422) + assert rv.status_code == 422 response = json.loads(rv.data.decode("utf-8")) expected_response = {"message": {"roles": ["Some roles do not exist"]}} - self.assertEqual(response, expected_response) + assert response == expected_response def test_create_dashboard_validate_json(self): """ @@ -1683,13 +1676,13 @@ def test_create_dashboard_validate_json(self): self.login(ADMIN_USERNAME) uri = "api/v1/dashboard/" rv = self.client.post(uri, json=dashboard_data) - self.assertEqual(rv.status_code, 400) + assert rv.status_code == 400 dashboard_data = {"dashboard_title": "title1", "json_metadata": '{"A:"a"}'} self.login(ADMIN_USERNAME) uri = "api/v1/dashboard/" rv = self.client.post(uri, json=dashboard_data) - self.assertEqual(rv.status_code, 400) + assert rv.status_code == 400 dashboard_data = { "dashboard_title": "title1", @@ -1698,7 +1691,7 @@ def test_create_dashboard_validate_json(self): self.login(ADMIN_USERNAME) uri = "api/v1/dashboard/" rv = self.client.post(uri, json=dashboard_data) - self.assertEqual(rv.status_code, 400) + assert rv.status_code == 400 def test_update_dashboard(self): """ @@ -1712,16 +1705,16 @@ def test_update_dashboard(self): self.login(ADMIN_USERNAME) uri = f"api/v1/dashboard/{dashboard_id}" rv = self.put_assert_metric(uri, self.dashboard_data, "put") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 model = db.session.query(Dashboard).get(dashboard_id) - self.assertEqual(model.dashboard_title, self.dashboard_data["dashboard_title"]) - self.assertEqual(model.slug, self.dashboard_data["slug"]) - self.assertEqual(model.position_json, self.dashboard_data["position_json"]) - self.assertEqual(model.css, self.dashboard_data["css"]) - self.assertEqual(model.json_metadata, self.dashboard_data["json_metadata"]) - self.assertEqual(model.published, self.dashboard_data["published"]) - self.assertEqual(model.owners, [admin]) - self.assertEqual(model.roles, [admin_role]) + assert model.dashboard_title == self.dashboard_data["dashboard_title"] + assert model.slug == self.dashboard_data["slug"] + assert model.position_json == self.dashboard_data["position_json"] + assert model.css == self.dashboard_data["css"] + assert model.json_metadata == self.dashboard_data["json_metadata"] + assert model.published == self.dashboard_data["published"] + assert model.owners == [admin] + assert model.roles == [admin_role] db.session.delete(model) db.session.commit() @@ -1740,15 +1733,15 @@ def test_dashboard_get_list_no_username(self): uri = f"api/v1/dashboard/{dashboard_id}" dashboard_data = {"dashboard_title": "title2"} rv = self.client.put(uri, json=dashboard_data) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 response = self.get_assert_metric("api/v1/dashboard/", "get_list") res = json.loads(response.data.decode("utf-8"))["result"] current_dash = [d for d in res if d["id"] == dashboard_id][0] - self.assertEqual(current_dash["dashboard_title"], "title2") - self.assertNotIn("username", current_dash["changed_by"].keys()) - self.assertNotIn("username", current_dash["owners"][0].keys()) + assert current_dash["dashboard_title"] == "title2" + assert "username" not in current_dash["changed_by"].keys() + assert "username" not in current_dash["owners"][0].keys() db.session.delete(model) db.session.commit() @@ -1767,14 +1760,14 @@ def test_dashboard_get_no_username(self): uri = 
f"api/v1/dashboard/{dashboard_id}" dashboard_data = {"dashboard_title": "title2"} rv = self.client.put(uri, json=dashboard_data) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 response = self.get_assert_metric(uri, "get") res = json.loads(response.data.decode("utf-8"))["result"] - self.assertEqual(res["dashboard_title"], "title2") - self.assertNotIn("username", res["changed_by"].keys()) - self.assertNotIn("username", res["owners"][0].keys()) + assert res["dashboard_title"] == "title2" + assert "username" not in res["changed_by"].keys() + assert "username" not in res["owners"][0].keys() db.session.delete(model) db.session.commit() @@ -1822,11 +1815,11 @@ def test_update_dashboard_chart_owners_propagation(self): self.login(ADMIN_USERNAME) uri = f"api/v1/dashboard/{dashboard.id}" rv = self.client.put(uri, json=dashboard_data) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 # Check that chart named Boys does not contain alpha 1 in its owners boys = db.session.query(Slice).filter_by(slice_name="Boys").one() - self.assertNotIn(user_alpha1, boys.owners) + assert user_alpha1 not in boys.owners # Revert owners on slice for slice in slices: @@ -1849,20 +1842,20 @@ def test_update_partial_dashboard(self): rv = self.client.put( uri, json={"json_metadata": self.dashboard_data["json_metadata"]} ) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 rv = self.client.put( uri, json={"dashboard_title": self.dashboard_data["dashboard_title"]} ) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 rv = self.client.put(uri, json={"slug": self.dashboard_data["slug"]}) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 model = db.session.query(Dashboard).get(dashboard_id) - self.assertEqual(model.json_metadata, self.dashboard_data["json_metadata"]) - self.assertEqual(model.dashboard_title, self.dashboard_data["dashboard_title"]) - self.assertEqual(model.slug, self.dashboard_data["slug"]) + assert model.json_metadata == self.dashboard_data["json_metadata"] + assert model.dashboard_title == self.dashboard_data["dashboard_title"] + assert model.slug == self.dashboard_data["slug"] db.session.delete(model) db.session.commit() @@ -1878,13 +1871,13 @@ def test_update_dashboard_new_owner_not_admin(self): self.login(ALPHA_USERNAME) uri = f"api/v1/dashboard/{dashboard_id}" rv = self.client.put(uri, json=dashboard_data) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 model = db.session.query(Dashboard).get(dashboard_id) - self.assertIn(gamma, model.owners) - self.assertIn(alpha, model.owners) + assert gamma in model.owners + assert alpha in model.owners for slc in model.slices: - self.assertIn(gamma, slc.owners) - self.assertIn(alpha, slc.owners) + assert gamma in slc.owners + assert alpha in slc.owners db.session.delete(model) db.session.commit() @@ -1899,13 +1892,13 @@ def test_update_dashboard_new_owner_admin(self): self.login(ADMIN_USERNAME) uri = f"api/v1/dashboard/{dashboard_id}" rv = self.client.put(uri, json=dashboard_data) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 model = db.session.query(Dashboard).get(dashboard_id) - self.assertIn(gamma, model.owners) - self.assertNotIn(admin, model.owners) + assert gamma in model.owners + assert admin not in model.owners for slc in model.slices: - self.assertIn(gamma, slc.owners) - self.assertNotIn(admin, slc.owners) + assert gamma in slc.owners + assert admin not in slc.owners db.session.delete(model) db.session.commit() @@ 
-1919,9 +1912,9 @@ def test_update_dashboard_clear_owner_list(self): uri = f"api/v1/dashboard/{dashboard_id}" dashboard_data = {"owners": []} rv = self.client.put(uri, json=dashboard_data) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 model = db.session.query(Dashboard).get(dashboard_id) - self.assertEqual([], model.owners) + assert [] == model.owners db.session.delete(model) db.session.commit() @@ -1940,9 +1933,9 @@ def test_update_dashboard_populate_owner(self): uri = f"api/v1/dashboard/{dashboard.id}" dashboard_data = {"owners": [gamma.id]} rv = self.client.put(uri, json=dashboard_data) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 model = db.session.query(Dashboard).get(dashboard.id) - self.assertEqual([gamma], model.owners) + assert [gamma] == model.owners db.session.delete(model) db.session.commit() @@ -1956,10 +1949,10 @@ def test_update_dashboard_slug_formatting(self): self.login(ADMIN_USERNAME) uri = f"api/v1/dashboard/{dashboard_id}" rv = self.client.put(uri, json=dashboard_data) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 model = db.session.query(Dashboard).get(dashboard_id) - self.assertEqual(model.dashboard_title, "title1_changed") - self.assertEqual(model.slug, "slug1-changed") + assert model.dashboard_title == "title1_changed" + assert model.slug == "slug1-changed" db.session.delete(model) db.session.commit() @@ -1976,10 +1969,10 @@ def test_update_dashboard_validate_slug(self): dashboard_data = {"dashboard_title": "title2", "slug": "slug 1"} uri = f"api/v1/dashboard/{dashboard2.id}" rv = self.client.put(uri, json=dashboard_data) - self.assertEqual(rv.status_code, 422) + assert rv.status_code == 422 response = json.loads(rv.data.decode("utf-8")) expected_response = {"message": {"slug": ["Must be unique"]}} - self.assertEqual(response, expected_response) + assert response == expected_response db.session.delete(dashboard1) db.session.delete(dashboard2) @@ -1992,7 +1985,7 @@ def test_update_dashboard_validate_slug(self): dashboard_data = {"dashboard_title": "title2_changed", "slug": ""} uri = f"api/v1/dashboard/{dashboard2.id}" rv = self.client.put(uri, json=dashboard_data) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 db.session.delete(dashboard1) db.session.delete(dashboard2) @@ -2010,13 +2003,13 @@ def test_update_published(self): self.login(ADMIN_USERNAME) uri = f"api/v1/dashboard/{dashboard.id}" rv = self.client.put(uri, json=dashboard_data) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 model = db.session.query(Dashboard).get(dashboard.id) - self.assertEqual(model.published, True) - self.assertEqual(model.slug, "slug1") - self.assertIn(admin, model.owners) - self.assertIn(gamma, model.owners) + assert model.published is True + assert model.slug == "slug1" + assert admin in model.owners + assert gamma in model.owners db.session.delete(model) db.session.commit() @@ -2041,7 +2034,7 @@ def test_update_dashboard_not_owned(self): dashboard_data = {"dashboard_title": "title1_changed", "slug": "slug1 changed"} uri = f"api/v1/dashboard/{dashboard.id}" rv = self.put_assert_metric(uri, dashboard_data, "put") - self.assertEqual(rv.status_code, 403) + assert rv.status_code == 403 db.session.delete(dashboard) db.session.delete(user_alpha1) db.session.delete(user_alpha2) @@ -2074,7 +2067,7 @@ def test_export_not_found(self): argument = [1000] uri = f"api/v1/dashboard/export/?q={prison.dumps(argument)}" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 
404) + assert rv.status_code == 404 def test_export_not_allowed(self): """ @@ -2087,7 +2080,7 @@ def test_export_not_allowed(self): argument = [dashboard.id] uri = f"api/v1/dashboard/export/?q={prison.dumps(argument)}" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 404) + assert rv.status_code == 404 db.session.delete(dashboard) db.session.commit() @@ -2429,7 +2422,7 @@ def test_embedded_dashboards(self): # initial get should return 404 resp = self.get_assert_metric(uri, "get_embedded") - self.assertEqual(resp.status_code, 404) + assert resp.status_code == 404 # post succeeds and returns value allowed_domains = ["test.example", "embedded.example"] @@ -2438,46 +2431,46 @@ def test_embedded_dashboards(self): {"allowed_domains": allowed_domains}, "set_embedded", ) - self.assertEqual(resp.status_code, 200) + assert resp.status_code == 200 result = json.loads(resp.data.decode("utf-8"))["result"] - self.assertIsNotNone(result["uuid"]) - self.assertNotEqual(result["uuid"], "") - self.assertEqual(result["allowed_domains"], allowed_domains) + assert result["uuid"] is not None + assert result["uuid"] != "" + assert result["allowed_domains"] == allowed_domains # get returns value resp = self.get_assert_metric(uri, "get_embedded") - self.assertEqual(resp.status_code, 200) + assert resp.status_code == 200 result = json.loads(resp.data.decode("utf-8"))["result"] - self.assertIsNotNone(result["uuid"]) - self.assertNotEqual(result["uuid"], "") - self.assertEqual(result["allowed_domains"], allowed_domains) + assert result["uuid"] is not None + assert result["uuid"] != "" + assert result["allowed_domains"] == allowed_domains # save uuid for later original_uuid = result["uuid"] # put succeeds and returns value resp = self.post_assert_metric(uri, {"allowed_domains": []}, "set_embedded") - self.assertEqual(resp.status_code, 200) + assert resp.status_code == 200 result = json.loads(resp.data.decode("utf-8"))["result"] - self.assertEqual(resp.status_code, 200) - self.assertIsNotNone(result["uuid"]) - self.assertNotEqual(result["uuid"], "") - self.assertEqual(result["allowed_domains"], []) + assert resp.status_code == 200 + assert result["uuid"] is not None + assert result["uuid"] != "" + assert result["allowed_domains"] == [] # get returns changed value resp = self.get_assert_metric(uri, "get_embedded") - self.assertEqual(resp.status_code, 200) + assert resp.status_code == 200 result = json.loads(resp.data.decode("utf-8"))["result"] - self.assertEqual(result["uuid"], original_uuid) - self.assertEqual(result["allowed_domains"], []) + assert result["uuid"] == original_uuid + assert result["allowed_domains"] == [] # delete succeeds resp = self.delete_assert_metric(uri, "delete_embedded") - self.assertEqual(resp.status_code, 200) + assert resp.status_code == 200 # get returns 404 resp = self.get_assert_metric(uri, "get_embedded") - self.assertEqual(resp.status_code, 404) + assert resp.status_code == 404 @pytest.mark.usefixtures("create_created_by_gamma_dashboards") def test_gets_created_by_user_dashboards_filter(self): @@ -2498,9 +2491,9 @@ def test_gets_created_by_user_dashboards_filter(self): uri = f"api/v1/dashboard/?q={prison.dumps(arguments)}" rv = self.get_assert_metric(uri, "get_list") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(data["count"], len(expected_models)) + assert data["count"] == len(expected_models) def test_gets_not_created_by_user_dashboards_filter(self): dashboard = 
self.insert_dashboard("title", "slug", []) # noqa: F541 @@ -2519,9 +2512,9 @@ def test_gets_not_created_by_user_dashboards_filter(self): uri = f"api/v1/dashboard/?q={prison.dumps(arguments)}" rv = self.get_assert_metric(uri, "get_list") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(data["count"], len(expected_models)) + assert data["count"] == len(expected_models) db.session.delete(dashboard) db.session.commit() @@ -2547,23 +2540,23 @@ def test_copy_dashboard(self): pk = original_dash.id uri = f"api/v1/dashboard/{pk}/copy/" rv = self.client.post(uri, json=data) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(response, {"result": {"id": ANY, "last_modified_time": ANY}}) + assert response == {"result": {"id": ANY, "last_modified_time": ANY}} dash = ( db.session.query(Dashboard) .filter(Dashboard.id == response["result"]["id"]) .one() ) - self.assertNotEqual(dash.id, original_dash.id) - self.assertEqual(len(dash.position), len(original_dash.position)) - self.assertEqual(dash.dashboard_title, "copied dash") - self.assertEqual(dash.css, "") - self.assertEqual(dash.owners, [security_manager.find_user("admin")]) - self.assertCountEqual(dash.slices, original_dash.slices) - self.assertEqual(dash.params_dict["color_namespace"], "Color Namespace Test") - self.assertEqual(dash.params_dict["color_scheme"], "Color Scheme Test") + assert dash.id != original_dash.id + assert len(dash.position) == len(original_dash.position) + assert dash.dashboard_title == "copied dash" + assert dash.css == "" + assert dash.owners == [security_manager.find_user("admin")] + self.assertCountEqual(dash.slices, original_dash.slices) # noqa: PT009 + assert dash.params_dict["color_namespace"] == "Color Namespace Test" + assert dash.params_dict["color_scheme"] == "Color Scheme Test" db.session.delete(dash) db.session.commit() @@ -2590,26 +2583,26 @@ def test_copy_dashboard_duplicate_slices(self): pk = original_dash.id uri = f"api/v1/dashboard/{pk}/copy/" rv = self.client.post(uri, json=data) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(response, {"result": {"id": ANY, "last_modified_time": ANY}}) + assert response == {"result": {"id": ANY, "last_modified_time": ANY}} dash = ( db.session.query(Dashboard) .filter(Dashboard.id == response["result"]["id"]) .one() ) - self.assertNotEqual(dash.id, original_dash.id) - self.assertEqual(len(dash.position), len(original_dash.position)) - self.assertEqual(dash.dashboard_title, "copied dash") - self.assertEqual(dash.css, "") - self.assertEqual(dash.owners, [security_manager.find_user("admin")]) - self.assertEqual(dash.params_dict["color_namespace"], "Color Namespace Test") - self.assertEqual(dash.params_dict["color_scheme"], "Color Scheme Test") - self.assertEqual(len(dash.slices), len(original_dash.slices)) + assert dash.id != original_dash.id + assert len(dash.position) == len(original_dash.position) + assert dash.dashboard_title == "copied dash" + assert dash.css == "" + assert dash.owners == [security_manager.find_user("admin")] + assert dash.params_dict["color_namespace"] == "Color Namespace Test" + assert dash.params_dict["color_scheme"] == "Color Scheme Test" + assert len(dash.slices) == len(original_dash.slices) for original_slc in original_dash.slices: for slc in dash.slices: - self.assertNotEqual(slc.id, 
original_slc.id) + assert slc.id != original_slc.id for slc in dash.slices: db.session.delete(slc) @@ -2639,12 +2632,12 @@ def test_update_dashboard_add_tags_can_write_on_tag(self): uri = f"api/v1/dashboard/{dashboard.id}" rv = self.put_assert_metric(uri, update_payload, "put") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 model = db.session.query(Dashboard).get(dashboard.id) # Clean up system tags tag_list = [tag.id for tag in model.tags if tag.type == TagType.custom] - self.assertEqual(sorted(tag_list), sorted(new_tags)) + assert sorted(tag_list) == sorted(new_tags) @pytest.mark.usefixtures("create_dashboard_with_tag") def test_update_dashboard_remove_tags_can_write_on_tag(self): @@ -2668,12 +2661,12 @@ def test_update_dashboard_remove_tags_can_write_on_tag(self): uri = f"api/v1/dashboard/{dashboard.id}" rv = self.put_assert_metric(uri, update_payload, "put") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 model = db.session.query(Dashboard).get(dashboard.id) # Clean up system tags tag_list = [tag.id for tag in model.tags if tag.type == TagType.custom] - self.assertEqual(tag_list, new_tags) + assert tag_list == new_tags @pytest.mark.usefixtures("create_dashboard_with_tag") def test_update_dashboard_add_tags_can_tag_on_dashboard(self): @@ -2701,12 +2694,12 @@ def test_update_dashboard_add_tags_can_tag_on_dashboard(self): uri = f"api/v1/dashboard/{dashboard.id}" rv = self.put_assert_metric(uri, update_payload, "put") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 model = db.session.query(Dashboard).get(dashboard.id) # Clean up system tags tag_list = [tag.id for tag in model.tags if tag.type == TagType.custom] - self.assertEqual(sorted(tag_list), sorted(new_tags)) + assert sorted(tag_list) == sorted(new_tags) security_manager.add_permission_role(gamma_role, write_tags_perm) @@ -2732,12 +2725,12 @@ def test_update_dashboard_remove_tags_can_tag_on_dashboard(self): uri = f"api/v1/dashboard/{dashboard.id}" rv = self.put_assert_metric(uri, update_payload, "put") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 model = db.session.query(Dashboard).get(dashboard.id) # Clean up system tags tag_list = [tag.id for tag in model.tags if tag.type == TagType.custom] - self.assertEqual(tag_list, []) + assert tag_list == [] security_manager.add_permission_role(gamma_role, write_tags_perm) @@ -2770,10 +2763,10 @@ def test_update_dashboard_add_tags_missing_permission(self): uri = f"api/v1/dashboard/{dashboard.id}" rv = self.put_assert_metric(uri, update_payload, "put") - self.assertEqual(rv.status_code, 403) - self.assertEqual( - rv.json["message"], - "You do not have permission to manage tags on dashboards", + assert rv.status_code == 403 + assert ( + rv.json["message"] + == "You do not have permission to manage tags on dashboards" ) security_manager.add_permission_role(gamma_role, write_tags_perm) @@ -2804,10 +2797,10 @@ def test_update_dashboard_remove_tags_missing_permission(self): uri = f"api/v1/dashboard/{dashboard.id}" rv = self.put_assert_metric(uri, update_payload, "put") - self.assertEqual(rv.status_code, 403) - self.assertEqual( - rv.json["message"], - "You do not have permission to manage tags on dashboards", + assert rv.status_code == 403 + assert ( + rv.json["message"] + == "You do not have permission to manage tags on dashboards" ) security_manager.add_permission_role(gamma_role, write_tags_perm) @@ -2838,7 +2831,7 @@ def test_update_dashboard_no_tag_changes(self): uri = 
f"api/v1/dashboard/{dashboard.id}" rv = self.put_assert_metric(uri, update_payload, "put") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 security_manager.add_permission_role(gamma_role, write_tags_perm) security_manager.add_permission_role(gamma_role, tag_dashboards_perm) @@ -2862,7 +2855,7 @@ def test_cache_dashboard_screenshot_success(self): .first() ) response = self._cache_screenshot(dashboard.id) - self.assertEqual(response.status_code, 202) + assert response.status_code == 202 @pytest.mark.usefixtures("create_dashboard_with_tag") def test_cache_dashboard_screenshot_dashboard_validation(self): @@ -2879,13 +2872,13 @@ def test_cache_dashboard_screenshot_dashboard_validation(self): "urlParams": "should be a list", } response = self._cache_screenshot(dashboard.id, invalid_payload) - self.assertEqual(response.status_code, 400) + assert response.status_code == 400 def test_cache_dashboard_screenshot_dashboard_not_found(self): self.login(ADMIN_USERNAME) non_existent_id = 999 response = self._cache_screenshot(non_existent_id) - self.assertEqual(response.status_code, 404) + assert response.status_code == 404 @pytest.mark.usefixtures("create_dashboard_with_tag") @patch("superset.dashboards.api.cache_dashboard_screenshot") @@ -2904,13 +2897,13 @@ def test_screenshot_success_png(self, mock_get_cache, mock_cache_task): .first() ) cache_resp = self._cache_screenshot(dashboard.id) - self.assertEqual(cache_resp.status_code, 202) + assert cache_resp.status_code == 202 cache_key = json.loads(cache_resp.data.decode("utf-8"))["cache_key"] response = self._get_screenshot(dashboard.id, cache_key, "png") - self.assertEqual(response.status_code, 200) - self.assertEqual(response.mimetype, "image/png") - self.assertEqual(response.data, b"fake image data") + assert response.status_code == 200 + assert response.mimetype == "image/png" + assert response.data == b"fake image data" @pytest.mark.usefixtures("create_dashboard_with_tag") @patch("superset.dashboards.api.cache_dashboard_screenshot") @@ -2933,13 +2926,13 @@ def test_screenshot_success_pdf( .first() ) cache_resp = self._cache_screenshot(dashboard.id) - self.assertEqual(cache_resp.status_code, 202) + assert cache_resp.status_code == 202 cache_key = json.loads(cache_resp.data.decode("utf-8"))["cache_key"] response = self._get_screenshot(dashboard.id, cache_key, "pdf") - self.assertEqual(response.status_code, 200) - self.assertEqual(response.mimetype, "application/pdf") - self.assertEqual(response.data, b"fake pdf data") + assert response.status_code == 200 + assert response.mimetype == "application/pdf" + assert response.data == b"fake pdf data" @pytest.mark.usefixtures("create_dashboard_with_tag") @patch("superset.dashboards.api.cache_dashboard_screenshot") @@ -2955,17 +2948,17 @@ def test_screenshot_not_in_cache(self, mock_get_cache, mock_cache_task): .first() ) cache_resp = self._cache_screenshot(dashboard.id) - self.assertEqual(cache_resp.status_code, 202) + assert cache_resp.status_code == 202 cache_key = json.loads(cache_resp.data.decode("utf-8"))["cache_key"] response = self._get_screenshot(dashboard.id, cache_key, "pdf") - self.assertEqual(response.status_code, 404) + assert response.status_code == 404 def test_screenshot_dashboard_not_found(self): self.login(ADMIN_USERNAME) non_existent_id = 999 response = self._get_screenshot(non_existent_id, "some_cache_key", "png") - self.assertEqual(response.status_code, 404) + assert response.status_code == 404 @pytest.mark.usefixtures("create_dashboard_with_tag") 
@patch("superset.dashboards.api.cache_dashboard_screenshot") @@ -2982,8 +2975,8 @@ def test_screenshot_invalid_download_format(self, mock_get_cache, mock_cache_tas ) cache_resp = self._cache_screenshot(dashboard.id) - self.assertEqual(cache_resp.status_code, 202) + assert cache_resp.status_code == 202 cache_key = json.loads(cache_resp.data.decode("utf-8"))["cache_key"] response = self._get_screenshot(dashboard.id, cache_key, "invalid") - self.assertEqual(response.status_code, 404) + assert response.status_code == 404 diff --git a/tests/integration_tests/dashboards/base_case.py b/tests/integration_tests/dashboards/base_case.py index 8b9b8e95ed6b6..8646200140452 100644 --- a/tests/integration_tests/dashboards/base_case.py +++ b/tests/integration_tests/dashboards/base_case.py @@ -63,19 +63,19 @@ def delete_dashboard(self, dashboard_id: int) -> Response: def assert_permission_was_created(self, dashboard): view_menu = security_manager.find_view_menu(dashboard.view_name) - self.assertIsNotNone(view_menu) - self.assertEqual(len(security_manager.find_permissions_view_menu(view_menu)), 1) + assert view_menu is not None + assert len(security_manager.find_permissions_view_menu(view_menu)) == 1 def assert_permission_kept_and_changed(self, updated_dashboard, excepted_view_id): view_menu_after_title_changed = security_manager.find_view_menu( updated_dashboard.view_name ) - self.assertIsNotNone(view_menu_after_title_changed) - self.assertEqual(view_menu_after_title_changed.id, excepted_view_id) + assert view_menu_after_title_changed is not None + assert view_menu_after_title_changed.id == excepted_view_id def assert_permissions_were_deleted(self, deleted_dashboard): view_menu = security_manager.find_view_menu(deleted_dashboard.view_name) - self.assertIsNone(view_menu) + assert view_menu is None def clean_created_objects(self): with app.test_request_context(): diff --git a/tests/integration_tests/dashboards/dao_tests.py b/tests/integration_tests/dashboards/dao_tests.py index 83ef02730b0a8..86197c0d5aef7 100644 --- a/tests/integration_tests/dashboards/dao_tests.py +++ b/tests/integration_tests/dashboards/dao_tests.py @@ -87,12 +87,12 @@ def test_copy_dashboard(self, mock_g): "duplicate_slices": False, } dash = DashboardDAO.copy_dashboard(original_dash, dash_data) - self.assertNotEqual(dash.id, original_dash.id) - self.assertEqual(len(dash.position), len(original_dash.position)) - self.assertEqual(dash.dashboard_title, "copied dash") - self.assertEqual(dash.css, "") - self.assertEqual(dash.owners, [security_manager.find_user("admin")]) - self.assertCountEqual(dash.slices, original_dash.slices) + assert dash.id != original_dash.id + assert len(dash.position) == len(original_dash.position) + assert dash.dashboard_title == "copied dash" + assert dash.css == "" + assert dash.owners == [security_manager.find_user("admin")] + self.assertCountEqual(dash.slices, original_dash.slices) # noqa: PT009 db.session.delete(dash) db.session.commit() @@ -118,9 +118,7 @@ def test_copy_dashboard_copies_native_filters(self, mock_g): "duplicate_slices": False, } dash = DashboardDAO.copy_dashboard(original_dash, dash_data) - self.assertEqual( - dash.params_dict["native_filter_configuration"], [{"mock": "filter"}] - ) + assert dash.params_dict["native_filter_configuration"] == [{"mock": "filter"}] db.session.delete(dash) db.session.commit() @@ -141,15 +139,15 @@ def test_copy_dashboard_duplicate_slices(self, mock_g): "duplicate_slices": True, } dash = DashboardDAO.copy_dashboard(original_dash, dash_data) - 
self.assertNotEqual(dash.id, original_dash.id) - self.assertEqual(len(dash.position), len(original_dash.position)) - self.assertEqual(dash.dashboard_title, "copied dash") - self.assertEqual(dash.css, "") - self.assertEqual(dash.owners, [security_manager.find_user("admin")]) - self.assertEqual(len(dash.slices), len(original_dash.slices)) + assert dash.id != original_dash.id + assert len(dash.position) == len(original_dash.position) + assert dash.dashboard_title == "copied dash" + assert dash.css == "" + assert dash.owners == [security_manager.find_user("admin")] + assert len(dash.slices) == len(original_dash.slices) for original_slc in original_dash.slices: for slc in dash.slices: - self.assertNotEqual(slc.id, original_slc.id) + assert slc.id != original_slc.id for slc in dash.slices: db.session.delete(slc) diff --git a/tests/integration_tests/dashboards/security/security_dataset_tests.py b/tests/integration_tests/dashboards/security/security_dataset_tests.py index 17a5c477e6b85..cf2275680eb4a 100644 --- a/tests/integration_tests/dashboards/security/security_dataset_tests.py +++ b/tests/integration_tests/dashboards/security/security_dataset_tests.py @@ -109,8 +109,8 @@ def test_get_dashboards__users_are_dashboards_owners(self): get_dashboards_response = self.get_resp(DASHBOARDS_API_URL) # noqa: F405 # assert - self.assertIn(my_owned_dashboard.url, get_dashboards_response) - self.assertNotIn(not_my_owned_dashboard.url, get_dashboards_response) + assert my_owned_dashboard.url in get_dashboards_response + assert not_my_owned_dashboard.url not in get_dashboards_response def test_get_dashboards__owners_can_view_empty_dashboard(self): # arrange @@ -123,7 +123,7 @@ def test_get_dashboards__owners_can_view_empty_dashboard(self): get_dashboards_response = self.get_resp(DASHBOARDS_API_URL) # noqa: F405 # assert - self.assertNotIn(dashboard_url, get_dashboards_response) + assert dashboard_url not in get_dashboards_response def test_get_dashboards__user_can_not_view_unpublished_dash(self): # arrange @@ -139,9 +139,7 @@ def test_get_dashboards__user_can_not_view_unpublished_dash(self): get_dashboards_response_as_gamma = self.get_resp(DASHBOARDS_API_URL) # noqa: F405 # assert - self.assertNotIn( - admin_and_draft_dashboard.url, get_dashboards_response_as_gamma - ) + assert admin_and_draft_dashboard.url not in get_dashboards_response_as_gamma @pytest.mark.usefixtures("load_energy_table_with_slice", "load_dashboard") def test_get_dashboards__users_can_view_permitted_dashboard(self): @@ -172,8 +170,8 @@ def test_get_dashboards__users_can_view_permitted_dashboard(self): get_dashboards_response = self.get_resp(DASHBOARDS_API_URL) # noqa: F405 # assert - self.assertIn(second_dash.url, get_dashboards_response) - self.assertIn(first_dash.url, get_dashboards_response) + assert second_dash.url in get_dashboards_response + assert first_dash.url in get_dashboards_response finally: self.revoke_public_access_to_table(accessed_table) @@ -193,5 +191,5 @@ def test_get_dashboards_api_no_data_access(self): rv = self.client.get(uri) self.assert200(rv) data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(0, data["count"]) + assert data["count"] == 0 DashboardDAO.delete([dashboard]) diff --git a/tests/integration_tests/dashboards/security/security_rbac_tests.py b/tests/integration_tests/dashboards/security/security_rbac_tests.py index 797212d516908..4ecc2e3e386b4 100644 --- a/tests/integration_tests/dashboards/security/security_rbac_tests.py +++ b/tests/integration_tests/dashboards/security/security_rbac_tests.py @@
-207,7 +207,7 @@ def test_get_dashboard_view__user_access_with_dashboard_permission(self): request_payload = get_query_context("birth_names") rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") - self.assertEqual(rv.status_code, 403) + assert rv.status_code == 403 # post revoke_access_to_dashboard(dashboard_to_access, new_role) # noqa: F405 @@ -480,12 +480,12 @@ def test_copy_dashboard_via_api(self): self.login(GAMMA_USERNAME) rv = self.client.post(uri, json=data) - self.assertEqual(rv.status_code, 403) + assert rv.status_code == 403 self.logout() self.login(ADMIN_USERNAME) rv = self.client.post(uri, json=data) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 response = json.loads(rv.data.decode("utf-8")) target = ( diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index c5cd20b5dc347..e4fbbc6e48e37 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -187,7 +187,7 @@ def test_get_items(self): self.login(ADMIN_USERNAME) uri = "api/v1/database/" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 response = json.loads(rv.data.decode("utf-8")) expected_columns = [ "allow_ctas", @@ -216,8 +216,8 @@ def test_get_items(self): "uuid", ] - self.assertGreater(response["count"], 0) - self.assertEqual(list(response["result"][0].keys()), expected_columns) + assert response["count"] > 0 + assert list(response["result"][0].keys()) == expected_columns def test_get_items_filter(self): """ @@ -241,8 +241,8 @@ def test_get_items_filter(self): uri = f"api/v1/database/?q={prison.dumps(arguments)}" rv = self.client.get(uri) response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 200) - self.assertEqual(response["count"], len(dbs)) + assert rv.status_code == 200 + assert response["count"] == len(dbs) # Cleanup db.session.delete(test_database) @@ -255,9 +255,9 @@ def test_get_items_not_allowed(self): self.login(GAMMA_USERNAME) uri = "api/v1/database/" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(response["count"], 0) + assert response["count"] == 0 def test_create_database(self): """ @@ -284,7 +284,7 @@ def test_create_database(self): uri = "api/v1/database/" rv = self.client.post(uri, json=database_data) response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 201) + assert rv.status_code == 201 # Cleanup model = db.session.query(Database).get(response.get("id")) assert model.configuration_method == ConfigurationMethod.SQLALCHEMY_FORM @@ -326,14 +326,14 @@ def test_create_database_with_ssh_tunnel( uri = "api/v1/database/" rv = self.client.post(uri, json=database_data) response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 201) + assert rv.status_code == 201 model_ssh_tunnel = ( db.session.query(SSHTunnel) .filter(SSHTunnel.database_id == response.get("id")) .one() ) - self.assertEqual(response.get("result")["ssh_tunnel"]["password"], "XXXXXXXXXX") - self.assertEqual(model_ssh_tunnel.database_id, response.get("id")) + assert response.get("result")["ssh_tunnel"]["password"] == "XXXXXXXXXX" + assert model_ssh_tunnel.database_id == response.get("id") # Cleanup model = db.session.query(Database).get(response.get("id")) db.session.delete(model) @@ -385,10 +385,10 @@ def 
test_create_database_with_missing_port_raises_error( uri = "api/v1/database/" rv = self.client.post(uri, json=database_data_with_ssh_tunnel) response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 400) - self.assertEqual( - response.get("message"), - "A database port is required when connecting via SSH Tunnel.", + assert rv.status_code == 400 + assert ( + response.get("message") + == "A database port is required when connecting via SSH Tunnel." ) @mock.patch( @@ -434,19 +434,19 @@ def test_update_database_with_ssh_tunnel( uri = "api/v1/database/" rv = self.client.post(uri, json=database_data) response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 201) + assert rv.status_code == 201 uri = "api/v1/database/{}".format(response.get("id")) rv = self.client.put(uri, json=database_data_with_ssh_tunnel) response_update = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 model_ssh_tunnel = ( db.session.query(SSHTunnel) .filter(SSHTunnel.database_id == response_update.get("id")) .one() ) - self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id")) + assert model_ssh_tunnel.database_id == response_update.get("id") # Cleanup model = db.session.query(Database).get(response.get("id")) db.session.delete(model) @@ -500,15 +500,15 @@ def test_update_database_with_missing_port_raises_error( uri = "api/v1/database/" rv = self.client.post(uri, json=database_data) response_create = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 201) + assert rv.status_code == 201 uri = "api/v1/database/{}".format(response_create.get("id")) rv = self.client.put(uri, json=database_data_with_ssh_tunnel) response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 400) - self.assertEqual( - response.get("message"), - "A database port is required when connecting via SSH Tunnel.", + assert rv.status_code == 400 + assert ( + response.get("message") + == "A database port is required when connecting via SSH Tunnel." 
) # Cleanup @@ -563,19 +563,19 @@ def test_delete_ssh_tunnel( uri = "api/v1/database/" rv = self.client.post(uri, json=database_data) response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 201) + assert rv.status_code == 201 uri = "api/v1/database/{}".format(response.get("id")) rv = self.client.put(uri, json=database_data_with_ssh_tunnel) response_update = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 model_ssh_tunnel = ( db.session.query(SSHTunnel) .filter(SSHTunnel.database_id == response_update.get("id")) .one() ) - self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id")) + assert model_ssh_tunnel.database_id == response_update.get("id") database_data_with_ssh_tunnel_null = { "database_name": "test-db-with-ssh-tunnel", @@ -585,7 +585,7 @@ def test_delete_ssh_tunnel( rv = self.client.put(uri, json=database_data_with_ssh_tunnel_null) response_update = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 model_ssh_tunnel = ( db.session.query(SSHTunnel) @@ -651,30 +651,28 @@ def test_update_ssh_tunnel_via_database_api( uri = "api/v1/database/" rv = self.client.post(uri, json=database_data_with_ssh_tunnel) response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 201) + assert rv.status_code == 201 model_ssh_tunnel = ( db.session.query(SSHTunnel) .filter(SSHTunnel.database_id == response.get("id")) .one() ) - self.assertEqual(model_ssh_tunnel.database_id, response.get("id")) - self.assertEqual(model_ssh_tunnel.username, "foo") + assert model_ssh_tunnel.database_id == response.get("id") + assert model_ssh_tunnel.username == "foo" uri = "api/v1/database/{}".format(response.get("id")) rv = self.client.put(uri, json=database_data_with_ssh_tunnel_update) response_update = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 model_ssh_tunnel = ( db.session.query(SSHTunnel) .filter(SSHTunnel.database_id == response_update.get("id")) .one() ) - self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id")) - self.assertEqual( - response_update.get("result")["ssh_tunnel"]["password"], "XXXXXXXXXX" - ) - self.assertEqual(model_ssh_tunnel.username, "Test") - self.assertEqual(model_ssh_tunnel.server_address, "123.132.123.1") - self.assertEqual(model_ssh_tunnel.server_port, 8080) + assert model_ssh_tunnel.database_id == response_update.get("id") + assert response_update.get("result")["ssh_tunnel"]["password"] == "XXXXXXXXXX" + assert model_ssh_tunnel.username == "Test" + assert model_ssh_tunnel.server_address == "123.132.123.1" + assert model_ssh_tunnel.server_port == 8080 # Cleanup model = db.session.query(Database).get(response.get("id")) db.session.delete(model) @@ -715,13 +713,13 @@ def test_cascade_delete_ssh_tunnel( uri = "api/v1/database/" rv = self.client.post(uri, json=database_data) response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 201) + assert rv.status_code == 201 model_ssh_tunnel = ( db.session.query(SSHTunnel) .filter(SSHTunnel.database_id == response.get("id")) .one() ) - self.assertEqual(model_ssh_tunnel.database_id, response.get("id")) + assert model_ssh_tunnel.database_id == response.get("id") # Cleanup model = db.session.query(Database).get(response.get("id")) db.session.delete(model) @@ -769,7 +767,7 @@ def test_do_not_create_database_if_ssh_tunnel_creation_fails( uri = "api/v1/database/" rv = 
self.client.post(uri, json=database_data) response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 422) + assert rv.status_code == 422 model_ssh_tunnel = ( db.session.query(SSHTunnel) @@ -777,7 +775,7 @@ def test_do_not_create_database_if_ssh_tunnel_creation_fails( .one_or_none() ) assert model_ssh_tunnel is None - self.assertEqual(response, fail_message) + assert response == fail_message # Check that rollback was called mock_rollback.assert_called() @@ -824,14 +822,14 @@ def test_get_database_returns_related_ssh_tunnel( uri = "api/v1/database/" rv = self.client.post(uri, json=database_data) response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 201) + assert rv.status_code == 201 model_ssh_tunnel = ( db.session.query(SSHTunnel) .filter(SSHTunnel.database_id == response.get("id")) .one() ) - self.assertEqual(model_ssh_tunnel.database_id, response.get("id")) - self.assertEqual(response.get("result")["ssh_tunnel"], response_ssh_tunnel) + assert model_ssh_tunnel.database_id == response.get("id") + assert response.get("result")["ssh_tunnel"] == response_ssh_tunnel # Cleanup model = db.session.query(Database).get(response.get("id")) db.session.delete(model) @@ -866,8 +864,8 @@ def test_if_ssh_tunneling_flag_is_not_active_it_raises_new_exception( uri = "api/v1/database/" rv = self.client.post(uri, json=database_data) response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 400) - self.assertEqual(response, {"message": "SSH Tunneling is not enabled"}) + assert rv.status_code == 400 + assert response == {"message": "SSH Tunneling is not enabled"} model_ssh_tunnel = ( db.session.query(SSHTunnel) .filter(SSHTunnel.database_id == response.get("id")) @@ -897,7 +895,7 @@ def test_get_table_details_with_slash_in_table_name(self): uri = f"api/v1/database/{database.id}/table/{table_name}/null/" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 def test_create_database_invalid_configuration_method(self): """ @@ -959,7 +957,7 @@ def test_create_database_no_configuration_method(self): rv = self.client.post(uri, json=database_data) response = json.loads(rv.data.decode("utf-8")) assert rv.status_code == 201 - self.assertIn("sqlalchemy_form", response["result"]["configuration_method"]) + assert "sqlalchemy_form" in response["result"]["configuration_method"] def test_create_database_server_cert_validate(self): """ @@ -981,8 +979,8 @@ def test_create_database_server_cert_validate(self): rv = self.client.post(uri, json=database_data) response = json.loads(rv.data.decode("utf-8")) expected_response = {"message": {"server_cert": ["Invalid certificate"]}} - self.assertEqual(rv.status_code, 400) - self.assertEqual(response, expected_response) + assert rv.status_code == 400 + assert response == expected_response def test_create_database_json_validate(self): """ @@ -1016,8 +1014,8 @@ def test_create_database_json_validate(self): ], } } - self.assertEqual(rv.status_code, 400) - self.assertEqual(response, expected_response) + assert rv.status_code == 400 + assert response == expected_response def test_create_database_extra_metadata_validate(self): """ @@ -1052,8 +1050,8 @@ def test_create_database_extra_metadata_validate(self): ] } } - self.assertEqual(rv.status_code, 400) - self.assertEqual(response, expected_response) + assert rv.status_code == 400 + assert response == expected_response def test_create_database_unique_validate(self): """ @@ -1078,8 +1076,8 @@ def 
test_create_database_unique_validate(self): "database_name": "A database with the same name already exists." } } - self.assertEqual(rv.status_code, 422) - self.assertEqual(response, expected_response) + assert rv.status_code == 422 + assert response == expected_response def test_create_database_uri_validate(self): """ @@ -1095,11 +1093,8 @@ def test_create_database_uri_validate(self): uri = "api/v1/database/" rv = self.client.post(uri, json=database_data) response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 400) - self.assertIn( - "Invalid connection string", - response["message"]["sqlalchemy_uri"][0], - ) + assert rv.status_code == 400 + assert "Invalid connection string" in response["message"]["sqlalchemy_uri"][0] @mock.patch( "superset.views.core.app.config", @@ -1127,8 +1122,8 @@ def test_create_database_fail_sqlite(self): ] } } - self.assertEqual(response_data, expected_response) - self.assertEqual(response.status_code, 400) + assert response_data == expected_response + assert response.status_code == 400 def test_create_database_conn_fail(self): """ @@ -1192,11 +1187,11 @@ def test_create_database_conn_fail(self): expected_response_postgres = { "errors": [dataclasses.asdict(superset_error_postgres)] } - self.assertEqual(response.status_code, 500) + assert response.status_code == 500 if example_db.backend == "mysql": - self.assertEqual(response_data, expected_response_mysql) + assert response_data == expected_response_mysql else: - self.assertEqual(response_data, expected_response_postgres) + assert response_data == expected_response_postgres def test_update_database(self): """ @@ -1213,7 +1208,7 @@ def test_update_database(self): } uri = f"api/v1/database/{test_database.id}" rv = self.client.put(uri, json=database_data) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 # Cleanup model = db.session.query(Database).get(test_database.id) db.session.delete(model) @@ -1242,8 +1237,8 @@ def test_update_database_conn_fail(self): expected_response = { "message": "Connection failed, please check your connection settings" } - self.assertEqual(rv.status_code, 422) - self.assertEqual(response, expected_response) + assert rv.status_code == 422 + assert response == expected_response # Cleanup model = db.session.query(Database).get(test_database.id) db.session.delete(model) @@ -1271,8 +1266,8 @@ def test_update_database_uniqueness(self): "database_name": "A database with the same name already exists." 
} } - self.assertEqual(rv.status_code, 422) - self.assertEqual(response, expected_response) + assert rv.status_code == 422 + assert response == expected_response # Cleanup db.session.delete(test_database1) db.session.delete(test_database2) @@ -1286,7 +1281,7 @@ def test_update_database_invalid(self): database_data = {"database_name": "test-database-updated"} uri = "api/v1/database/invalid" rv = self.client.put(uri, json=database_data) - self.assertEqual(rv.status_code, 404) + assert rv.status_code == 404 def test_update_database_uri_validate(self): """ @@ -1305,11 +1300,8 @@ def test_update_database_uri_validate(self): uri = f"api/v1/database/{test_database.id}" rv = self.client.put(uri, json=database_data) response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 400) - self.assertIn( - "Invalid connection string", - response["message"]["sqlalchemy_uri"][0], - ) + assert rv.status_code == 400 + assert "Invalid connection string" in response["message"]["sqlalchemy_uri"][0] db.session.delete(test_database) db.session.commit() @@ -1369,9 +1361,9 @@ def test_delete_database(self): self.login(ADMIN_USERNAME) uri = f"api/v1/database/{database_id}" rv = self.delete_assert_metric(uri, "delete") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 model = db.session.query(Database).get(database_id) - self.assertEqual(model, None) + assert model is None def test_delete_database_not_found(self): """ @@ -1381,7 +1373,7 @@ def test_delete_database_not_found(self): self.login(ADMIN_USERNAME) uri = f"api/v1/database/{max_id + 1}" rv = self.delete_assert_metric(uri, "delete") - self.assertEqual(rv.status_code, 404) + assert rv.status_code == 404 @pytest.mark.usefixtures("create_database_with_dataset") def test_delete_database_with_datasets(self): @@ -1391,7 +1383,7 @@ def test_delete_database_with_datasets(self): self.login(ADMIN_USERNAME) uri = f"api/v1/database/{self._database.id}" rv = self.delete_assert_metric(uri, "delete") - self.assertEqual(rv.status_code, 422) + assert rv.status_code == 422 @pytest.mark.usefixtures("create_database_with_report") def test_delete_database_with_report(self): @@ -1407,11 +1399,11 @@ def test_delete_database_with_report(self): uri = f"api/v1/database/{database.id}" rv = self.client.delete(uri) response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 422) + assert rv.status_code == 422 expected_response = { "message": "There are associated alerts or reports: report_with_database" } - self.assertEqual(response, expected_response) + assert response == expected_response @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_get_table_metadata(self): @@ -1422,12 +1414,12 @@ def test_get_table_metadata(self): self.login(ADMIN_USERNAME) uri = f"api/v1/database/{example_db.id}/table/birth_names/null/" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(response["name"], "birth_names") - self.assertIsNone(response["comment"]) - self.assertTrue(len(response["columns"]) > 5) - self.assertTrue(response.get("selectStar").startswith("SELECT")) + assert response["name"] == "birth_names" + assert response["comment"] is None + assert len(response["columns"]) > 5 + assert response.get("selectStar").startswith("SELECT") def test_info_security_database(self): """ @@ -1456,11 +1448,11 @@ def test_get_invalid_database_table_metadata(self): self.login(ADMIN_USERNAME) uri = 
f"api/v1/database/{database_id}/table/some_table/some_schema/" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 404) + assert rv.status_code == 404 uri = "api/v1/database/some_database/table/some_table/some_schema/" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 404) + assert rv.status_code == 404 def test_get_invalid_table_table_metadata(self): """ @@ -1472,25 +1464,22 @@ def test_get_invalid_table_table_metadata(self): rv = self.client.get(uri) data = json.loads(rv.data.decode("utf-8")) if example_db.backend == "sqlite": - self.assertEqual(rv.status_code, 200) - self.assertEqual( - data, - { - "columns": [], - "comment": None, - "foreignKeys": [], - "indexes": [], - "name": "wrong_table", - "primaryKey": {"constrained_columns": None, "name": None}, - "selectStar": "SELECT\n *\nFROM wrong_table\nLIMIT 100\nOFFSET 0", - }, - ) + assert rv.status_code == 200 + assert data == { + "columns": [], + "comment": None, + "foreignKeys": [], + "indexes": [], + "name": "wrong_table", + "primaryKey": {"constrained_columns": None, "name": None}, + "selectStar": "SELECT\n *\nFROM wrong_table\nLIMIT 100\nOFFSET 0", + } elif example_db.backend == "mysql": - self.assertEqual(rv.status_code, 422) - self.assertEqual(data, {"message": "`wrong_table`"}) + assert rv.status_code == 422 + assert data == {"message": "`wrong_table`"} else: - self.assertEqual(rv.status_code, 422) - self.assertEqual(data, {"message": "wrong_table"}) + assert rv.status_code == 422 + assert data == {"message": "wrong_table"} def test_get_table_metadata_no_db_permission(self): """ @@ -1500,7 +1489,7 @@ def test_get_table_metadata_no_db_permission(self): example_db = get_example_database() uri = f"api/v1/database/{example_db.id}/birth_names/null/" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 404) + assert rv.status_code == 404 @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_get_table_extra_metadata_deprecated(self): @@ -1511,9 +1500,9 @@ def test_get_table_extra_metadata_deprecated(self): self.login(ADMIN_USERNAME) uri = f"api/v1/database/{example_db.id}/table_extra/birth_names/null/" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(response, {}) + assert response == {} def test_get_invalid_database_table_extra_metadata_deprecated(self): """ @@ -1523,11 +1512,11 @@ def test_get_invalid_database_table_extra_metadata_deprecated(self): self.login(ADMIN_USERNAME) uri = f"api/v1/database/{database_id}/table_extra/some_table/some_schema/" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 404) + assert rv.status_code == 404 uri = "api/v1/database/some_database/table_extra/some_table/some_schema/" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 404) + assert rv.status_code == 404 def test_get_invalid_table_table_extra_metadata_deprecated(self): """ @@ -1539,8 +1528,8 @@ def test_get_invalid_table_table_extra_metadata_deprecated(self): rv = self.client.get(uri) data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 200) - self.assertEqual(data, {}) + assert rv.status_code == 200 + assert data == {} @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_get_select_star(self): @@ -1551,7 +1540,7 @@ def test_get_select_star(self): example_db = get_example_database() uri = f"api/v1/database/{example_db.id}/select_star/birth_names/" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 200) + 
assert rv.status_code == 200 def test_get_select_star_not_allowed(self): """ @@ -1561,7 +1550,7 @@ def test_get_select_star_not_allowed(self): example_db = get_example_database() uri = f"api/v1/database/{example_db.id}/select_star/birth_names/" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 404) + assert rv.status_code == 404 def test_get_select_star_not_found_database(self): """ @@ -1571,7 +1560,7 @@ def test_get_select_star_not_found_database(self): max_id = db.session.query(func.max(Database.id)).scalar() uri = f"api/v1/database/{max_id + 1}/select_star/birth_names/" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 404) + assert rv.status_code == 404 def test_get_select_star_not_found_table(self): """ @@ -1585,7 +1574,7 @@ def test_get_select_star_not_found_table(self): uri = f"api/v1/database/{example_db.id}/select_star/table_does_not_exist/" rv = self.client.get(uri) # TODO(bkyryliuk): investigate why presto returns 500 - self.assertEqual(rv.status_code, 404 if example_db.backend != "presto" else 500) + assert rv.status_code == (404 if example_db.backend != "presto" else 500) def test_get_allow_file_upload_filter(self): """ @@ -1952,13 +1941,13 @@ def test_database_schemas(self): rv = self.client.get(f"api/v1/database/{database.id}/schemas/") response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(schemas, set(response["result"])) + assert schemas == set(response["result"]) rv = self.client.get( f"api/v1/database/{database.id}/schemas/?q={prison.dumps({'force': True})}" ) response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(schemas, set(response["result"])) + assert schemas == set(response["result"]) def test_database_schemas_not_found(self): """ @@ -1968,7 +1957,7 @@ def test_database_schemas_not_found(self): example_db = get_example_database() uri = f"api/v1/database/{example_db.id}/schemas/" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 404) + assert rv.status_code == 404 def test_database_schemas_invalid_query(self): """ @@ -1979,7 +1968,7 @@ def test_database_schemas_invalid_query(self): rv = self.client.get( f"api/v1/database/{database.id}/schemas/?q={prison.dumps({'force': 'nop'})}" ) - self.assertEqual(rv.status_code, 400) + assert rv.status_code == 400 def test_database_tables(self): """ @@ -1993,17 +1982,17 @@ def test_database_tables(self): f"api/v1/database/{database.id}/tables/?q={prison.dumps({'schema_name': schema_name})}" ) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 if database.backend == "postgresql": response = json.loads(rv.data.decode("utf-8")) schemas = [ s[0] for s in database.get_all_table_names_in_schema(None, schema_name) ] - self.assertEqual(response["count"], len(schemas)) + assert response["count"] == len(schemas) for option in response["result"]: - self.assertEqual(option["extra"], None) - self.assertEqual(option["type"], "table") - self.assertTrue(option["value"] in schemas) + assert option["extra"] is None + assert option["type"] == "table" + assert option["value"] in schemas @patch("superset.utils.log.logger") def test_database_tables_not_found(self, logger_mock): @@ -2014,7 +2003,7 @@ def test_database_tables_not_found(self, logger_mock): example_db = get_example_database() uri = f"api/v1/database/{example_db.id}/tables/?q={prison.dumps({'schema_name': 'non_existent'})}" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 404) + assert rv.status_code == 404 logger_mock.warning.assert_called_once_with( "Database not found.", exc_info=True ) @@ -2028,7 
+2017,7 @@ def test_database_tables_invalid_query(self): rv = self.client.get( f"api/v1/database/{database.id}/tables/?q={prison.dumps({'force': 'nop'})}" ) - self.assertEqual(rv.status_code, 400) + assert rv.status_code == 400 @patch("superset.utils.log.logger") @mock.patch("superset.security.manager.SupersetSecurityManager.can_access_database") @@ -2046,7 +2035,7 @@ def test_database_tables_unexpected_error( rv = self.client.get( f"api/v1/database/{database.id}/tables/?q={prison.dumps({'schema_name': 'main'})}" ) - self.assertEqual(rv.status_code, 422) + assert rv.status_code == 422 logger_mock.warning.assert_called_once_with("Test Error", exc_info=True) def test_test_connection(self): @@ -2074,8 +2063,8 @@ def test_test_connection(self): } url = "api/v1/database/test_connection/" rv = self.post_assert_metric(url, data, "test_connection") - self.assertEqual(rv.status_code, 200) - self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8") + assert rv.status_code == 200 + assert rv.headers["Content-Type"] == "application/json; charset=utf-8" # validate that the endpoint works with the decrypted sqlalchemy uri data = { @@ -2086,8 +2075,8 @@ def test_test_connection(self): "server_cert": None, } rv = self.post_assert_metric(url, data, "test_connection") - self.assertEqual(rv.status_code, 200) - self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8") + assert rv.status_code == 200 + assert rv.headers["Content-Type"] == "application/json; charset=utf-8" def test_test_connection_failed(self): """ @@ -2103,8 +2092,8 @@ def test_test_connection_failed(self): } url = "api/v1/database/test_connection/" rv = self.post_assert_metric(url, data, "test_connection") - self.assertEqual(rv.status_code, 422) - self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8") + assert rv.status_code == 422 + assert rv.headers["Content-Type"] == "application/json; charset=utf-8" response = json.loads(rv.data.decode("utf-8")) expected_response = { "errors": [ @@ -2123,7 +2112,7 @@ def test_test_connection_failed(self): } ] } - self.assertEqual(response, expected_response) + assert response == expected_response data = { "sqlalchemy_uri": "mssql+pymssql://url", @@ -2132,8 +2121,8 @@ def test_test_connection_failed(self): "server_cert": None, } rv = self.post_assert_metric(url, data, "test_connection") - self.assertEqual(rv.status_code, 422) - self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8") + assert rv.status_code == 422 + assert rv.headers["Content-Type"] == "application/json; charset=utf-8" response = json.loads(rv.data.decode("utf-8")) expected_response = { "errors": [ @@ -2152,7 +2141,7 @@ def test_test_connection_failed(self): } ] } - self.assertEqual(response, expected_response) + assert response == expected_response def test_test_connection_unsafe_uri(self): """ @@ -2169,7 +2158,7 @@ def test_test_connection_unsafe_uri(self): } url = "api/v1/database/test_connection/" rv = self.post_assert_metric(url, data, "test_connection") - self.assertEqual(rv.status_code, 400) + assert rv.status_code == 400 response = json.loads(rv.data.decode("utf-8")) expected_response = { "message": { @@ -2178,7 +2167,7 @@ def test_test_connection_unsafe_uri(self): ] } } - self.assertEqual(response, expected_response) + assert response == expected_response app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = False @@ -2250,10 +2239,10 @@ def test_get_database_related_objects(self): database = get_example_database() uri = 
f"api/v1/database/{database.id}/related_objects/" rv = self.get_assert_metric(uri, "related_objects") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(response["charts"]["count"], 33) - self.assertEqual(response["dashboards"]["count"], 3) + assert response["charts"]["count"] == 33 + assert response["dashboards"]["count"] == 3 def test_get_database_related_objects_not_found(self): """ @@ -2265,13 +2254,13 @@ def test_get_database_related_objects_not_found(self): uri = f"api/v1/database/{invalid_id}/related_objects/" self.login(ADMIN_USERNAME) rv = self.get_assert_metric(uri, "related_objects") - self.assertEqual(rv.status_code, 404) + assert rv.status_code == 404 self.logout() self.login(GAMMA_USERNAME) database = get_example_database() uri = f"api/v1/database/{database.id}/related_objects/" rv = self.get_assert_metric(uri, "related_objects") - self.assertEqual(rv.status_code, 404) + assert rv.status_code == 404 def test_export_database(self): """ @@ -2679,7 +2668,7 @@ def test_import_database_masked_ssh_tunnel_password_provided( .filter(SSHTunnel.database_id == database.id) .one() ) - self.assertEqual(model_ssh_tunnel.password, "TEST") + assert model_ssh_tunnel.password == "TEST" db.session.delete(database) db.session.commit() @@ -2797,8 +2786,8 @@ def test_import_database_masked_ssh_tunnel_private_key_and_password_provided( .filter(SSHTunnel.database_id == database.id) .one() ) - self.assertEqual(model_ssh_tunnel.private_key, "TestPrivateKey") - self.assertEqual(model_ssh_tunnel.private_key_password, "TEST") + assert model_ssh_tunnel.private_key == "TestPrivateKey" + assert model_ssh_tunnel.private_key_password == "TEST" db.session.delete(database) db.session.commit() @@ -3852,8 +3841,8 @@ def test_validate_sql(self): uri = f"api/v1/database/{example_db.id}/validate_sql/" rv = self.client.post(uri, json=request_payload) response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 200) - self.assertEqual(response["result"], []) + assert rv.status_code == 200 + assert response["result"] == [] @mock.patch.dict( "superset.config.SQL_VALIDATORS_BY_ENGINE", @@ -3878,18 +3867,15 @@ def test_validate_sql_errors(self): uri = f"api/v1/database/{example_db.id}/validate_sql/" rv = self.client.post(uri, json=request_payload) response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 200) - self.assertEqual( - response["result"], - [ - { - "end_column": None, - "line_number": 1, - "message": 'ERROR: syntax error at or near "table1"', - "start_column": None, - } - ], - ) + assert rv.status_code == 200 + assert response["result"] == [ + { + "end_column": None, + "line_number": 1, + "message": 'ERROR: syntax error at or near "table1"', + "start_column": None, + } + ] @mock.patch.dict( "superset.config.SQL_VALIDATORS_BY_ENGINE", @@ -3910,7 +3896,7 @@ def test_validate_sql_not_found(self): f"api/v1/database/{self.get_nonexistent_numeric_id(Database)}/validate_sql/" ) rv = self.client.post(uri, json=request_payload) - self.assertEqual(rv.status_code, 404) + assert rv.status_code == 404 @mock.patch.dict( "superset.config.SQL_VALIDATORS_BY_ENGINE", @@ -3932,8 +3918,8 @@ def test_validate_sql_validation_fails(self): ) rv = self.client.post(uri, json=request_payload) response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 400) - self.assertEqual(response, {"message": {"sql": ["Field may not be null."]}}) + assert rv.status_code == 400 + assert response == 
{"message": {"sql": ["Field may not be null."]}} @mock.patch.dict( "superset.config.SQL_VALIDATORS_BY_ENGINE", @@ -3956,29 +3942,26 @@ def test_validate_sql_endpoint_noconfig(self): uri = f"api/v1/database/{example_db.id}/validate_sql/" rv = self.client.post(uri, json=request_payload) response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 422) - self.assertEqual( - response, - { - "errors": [ - { - "message": f"no SQL validator is configured for " - f"{example_db.backend}", - "error_type": "GENERIC_DB_ENGINE_ERROR", - "level": "error", - "extra": { - "issue_codes": [ - { - "code": 1002, - "message": "Issue 1002 - The database returned an " - "unexpected error.", - } - ] - }, - } - ] - }, - ) + assert rv.status_code == 422 + assert response == { + "errors": [ + { + "message": f"no SQL validator is configured for " + f"{example_db.backend}", + "error_type": "GENERIC_DB_ENGINE_ERROR", + "level": "error", + "extra": { + "issue_codes": [ + { + "code": 1002, + "message": "Issue 1002 - The database returned an " + "unexpected error.", + } + ] + }, + } + ] + } @mock.patch("superset.commands.database.validate_sql.get_validator_by_name") @mock.patch.dict( @@ -4013,8 +3996,8 @@ def test_validate_sql_endpoint_failure(self, get_validator_by_name): # TODO(bkyryliuk): properly handle hive error if get_example_database().backend == "hive": return - self.assertEqual(rv.status_code, 422) - self.assertIn("Kaboom!", response["errors"][0]["message"]) + assert rv.status_code == 422 + assert "Kaboom!" in response["errors"][0]["message"] def test_get_databases_with_extra_filters(self): """ @@ -4048,14 +4031,14 @@ def test_get_databases_with_extra_filters(self): uri, json={**database_data, "database_name": "dyntest-create-database-1"} ) first_response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 201) + assert rv.status_code == 201 uri = "api/v1/database/" rv = self.client.post( uri, json={**database_data, "database_name": "create-database-2"} ) second_response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 201) + assert rv.status_code == 201 # The filter function def _base_filter(query): @@ -4074,11 +4057,11 @@ def _base_filter(query): rv = self.client.get(uri) data = json.loads(rv.data.decode("utf-8")) # All databases must be returned if no filter is present - self.assertEqual(data["count"], len(dbs)) + assert data["count"] == len(dbs) database_names = [item["database_name"] for item in data["result"]] database_names.sort() # All Databases because we are an admin - self.assertEqual(database_names, expected_names) + assert database_names == expected_names assert rv.status_code == 200 # Our filter function wasn't get called base_filter_mock.assert_not_called() @@ -4092,10 +4075,10 @@ def _base_filter(query): rv = self.client.get(uri) data = json.loads(rv.data.decode("utf-8")) # Only one database start with dyntest - self.assertEqual(data["count"], 1) + assert data["count"] == 1 database_names = [item["database_name"] for item in data["result"]] # Only the database that starts with tests, even if we are an admin - self.assertEqual(database_names, ["dyntest-create-database-1"]) + assert database_names == ["dyntest-create-database-1"] assert rv.status_code == 200 # The filter function is called now that it's defined in our config base_filter_mock.assert_called() diff --git a/tests/integration_tests/databases/commands_tests.py b/tests/integration_tests/databases/commands_tests.py index 692c728a2bc3b..3bd0cfce22512 100644 --- 
a/tests/integration_tests/databases/commands_tests.py
+++ b/tests/integration_tests/databases/commands_tests.py
@@ -740,7 +740,7 @@ def test_import_v1_database_with_ssh_tunnel_password(
             .filter(SSHTunnel.database_id == database.id)
             .one()
         )
-        self.assertEqual(model_ssh_tunnel.password, "TEST")
+        assert model_ssh_tunnel.password == "TEST"
 
         db.session.delete(database)
         db.session.commit()
@@ -787,8 +787,8 @@ def test_import_v1_database_with_ssh_tunnel_private_key_and_password(
             .filter(SSHTunnel.database_id == database.id)
             .one()
         )
-        self.assertEqual(model_ssh_tunnel.private_key, "TestPrivateKey")
-        self.assertEqual(model_ssh_tunnel.private_key_password, "TEST")
+        assert model_ssh_tunnel.private_key == "TestPrivateKey"
+        assert model_ssh_tunnel.private_key_password == "TEST"
 
         db.session.delete(database)
         db.session.commit()
diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py
index 0887185619552..49110277bf328
--- a/tests/integration_tests/datasets/api_tests.py
+++ b/tests/integration_tests/datasets/api_tests.py
@@ -192,7 +192,7 @@ def test_user_gets_all_datasets(self):
         def count_datasets():
             uri = "api/v1/chart/"
             rv = self.client.get(uri, "get_list")
-            self.assertEqual(rv.status_code, 200)
+            assert rv.status_code == 200
             data = rv.get_json()
             return data["count"]
 
@@ -1342,14 +1342,14 @@ def test_dataset_get_list_no_username(self):
         table_data = {"description": "changed_description"}
         uri = f"api/v1/dataset/{dataset.id}"
         rv = self.client.put(uri, json=table_data)
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
 
         response = self.get_assert_metric("api/v1/dataset/", "get_list")
         res = json.loads(response.data.decode("utf-8"))["result"]
 
         current_dataset = [d for d in res if d["id"] == dataset.id][0]
-        self.assertEqual(current_dataset["description"], "changed_description")
-        self.assertNotIn("username", current_dataset["changed_by"].keys())
+        assert current_dataset["description"] == "changed_description"
+        assert "username" not in current_dataset["changed_by"].keys()
 
         db.session.delete(dataset)
         db.session.commit()
@@ -1364,13 +1364,13 @@ def test_dataset_get_no_username(self):
         table_data = {"description": "changed_description"}
         uri = f"api/v1/dataset/{dataset.id}"
         rv = self.client.put(uri, json=table_data)
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
 
         response = self.get_assert_metric(uri, "get")
         res = json.loads(response.data.decode("utf-8"))["result"]
 
-        self.assertEqual(res["description"], "changed_description")
-        self.assertNotIn("username", res["changed_by"].keys())
+        assert res["description"] == "changed_description"
+        assert "username" not in res["changed_by"].keys()
 
         db.session.delete(dataset)
         db.session.commit()
@@ -2311,14 +2311,14 @@ def test_get_or_create_dataset_already_exists(self):
                 "database_id": get_example_database().id,
             },
         )
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         response = json.loads(rv.data.decode("utf-8"))
         dataset = (
             db.session.query(SqlaTable)
             .filter(SqlaTable.table_name == "virtual_dataset")
             .one()
         )
-        self.assertEqual(response["result"], {"table_id": dataset.id})
+        assert response["result"] == {"table_id": dataset.id}
 
     def test_get_or_create_dataset_database_not_found(self):
         """
@@ -2329,9 +2329,9 @@ def test_get_or_create_dataset_database_not_found(self):
             "api/v1/dataset/get_or_create/",
             json={"table_name": "virtual_dataset", "database_id": 999},
         )
-        self.assertEqual(rv.status_code, 422)
+        assert rv.status_code == 422
         response = json.loads(rv.data.decode("utf-8"))
-        self.assertEqual(response["message"], {"database": ["Database does not exist"]})
+        assert response["message"] == {"database": ["Database does not exist"]}
 
     @patch("superset.commands.dataset.create.CreateDatasetCommand.run")
     def test_get_or_create_dataset_create_fails(self, command_run_mock):
@@ -2347,9 +2347,9 @@ def test_get_or_create_dataset_create_fails(self, command_run_mock):
                 "database_id": get_example_database().id,
             },
         )
-        self.assertEqual(rv.status_code, 422)
+        assert rv.status_code == 422
         response = json.loads(rv.data.decode("utf-8"))
-        self.assertEqual(response["message"], "Dataset could not be created.")
+        assert response["message"] == "Dataset could not be created."
 
     def test_get_or_create_dataset_creates_table(self):
         """
@@ -2370,7 +2370,7 @@ def test_get_or_create_dataset_creates_table(self):
                 "template_params": '{"param": 1}',
             },
         )
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         response = json.loads(rv.data.decode("utf-8"))
         table = (
             db.session.query(SqlaTable)
@@ -2410,12 +2410,9 @@ def test_warm_up_cache(self):
                 "db_name": get_example_database().database_name,
             },
         )
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         data = json.loads(rv.data.decode("utf-8"))
-        self.assertEqual(
-            len(data["result"]),
-            len(energy_charts),
-        )
+        assert len(data["result"]) == len(energy_charts)
         for chart_result in data["result"]:
             assert "chart_id" in chart_result
             assert "viz_error" in chart_result
@@ -2439,12 +2436,9 @@ def test_warm_up_cache(self):
                 "dashboard_id": dashboard.id,
             },
         )
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         data = json.loads(rv.data.decode("utf-8"))
-        self.assertEqual(
-            len(data["result"]),
-            len(birth_charts),
-        )
+        assert len(data["result"]) == len(birth_charts)
         for chart_result in data["result"]:
             assert "chart_id" in chart_result
            assert "viz_error" in chart_result
@@ -2462,12 +2456,9 @@ def test_warm_up_cache(self):
                 ),
             },
         )
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         data = json.loads(rv.data.decode("utf-8"))
-        self.assertEqual(
-            len(data["result"]),
-            len(birth_charts),
-        )
+        assert len(data["result"]) == len(birth_charts)
         for chart_result in data["result"]:
             assert "chart_id" in chart_result
             assert "viz_error" in chart_result
 
     def test_warm_up_cache_db_and_table_name_required(self):
         self.login(ADMIN_USERNAME)
         rv = self.client.put("/api/v1/dataset/warm_up_cache", json={"dashboard_id": 1})
-        self.assertEqual(rv.status_code, 400)
+        assert rv.status_code == 400
         data = json.loads(rv.data.decode("utf-8"))
-        self.assertEqual(
-            data,
-            {
-                "message": {
-                    "db_name": ["Missing data for required field."],
-                    "table_name": ["Missing data for required field."],
-                }
-            },
-        )
+        assert data == {
+            "message": {
+                "db_name": ["Missing data for required field."],
+                "table_name": ["Missing data for required field."],
+            }
+        }
 
     def test_warm_up_cache_table_not_found(self):
         self.login(ADMIN_USERNAME)
@@ -2494,9 +2482,8 @@ def test_warm_up_cache_table_not_found(self):
             "/api/v1/dataset/warm_up_cache",
             json={"table_name": "not_here", "db_name": "abc"},
         )
-        self.assertEqual(rv.status_code, 404)
+        assert rv.status_code == 404
         data = json.loads(rv.data.decode("utf-8"))
-        self.assertEqual(
-            data,
-            {"message": "The provided table was not found in the provided database"},
-        )
+        assert data == {
+            "message": "The provided table was not found in the provided database"
+        }
diff --git a/tests/integration_tests/datasets/commands_tests.py b/tests/integration_tests/datasets/commands_tests.py
index 66a15e2e61d52..f85951c4535e5 100644
--- a/tests/integration_tests/datasets/commands_tests.py
+++ b/tests/integration_tests/datasets/commands_tests.py
@@ -587,8 +587,8 @@ def test_create_dataset_command(self):
             .filter_by(table_name="test_create_dataset_command")
             .one()
         )
-        self.assertEqual(table, fetched_table)
-        self.assertEqual([owner.username for owner in table.owners], ["admin"])
+        assert table == fetched_table
+        assert [owner.username for owner in table.owners] == ["admin"]
 
         db.session.delete(table)
         with examples_db.get_sqla_engine() as engine:
@@ -626,7 +626,7 @@ def test_warm_up_cache(self):
         results = DatasetWarmUpCacheCommand(
             get_example_database().database_name, "birth_names", None, None
         ).run()
-        self.assertEqual(len(results), len(birth_charts))
+        assert len(results) == len(birth_charts)
         for chart_result in results:
             assert "chart_id" in chart_result
             assert "viz_error" in chart_result
diff --git a/tests/integration_tests/datasource/api_tests.py b/tests/integration_tests/datasource/api_tests.py
index e810e02ee5716..4c285caeb69e6 100644
--- a/tests/integration_tests/datasource/api_tests.py
+++ b/tests/integration_tests/datasource/api_tests.py
@@ -41,7 +41,7 @@ def test_get_column_values_ints(self):
         self.login(ADMIN_USERNAME)
         table = self.get_virtual_dataset()
         rv = self.client.get(f"api/v1/datasource/table/{table.id}/column/col1/values/")
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         response = json.loads(rv.data.decode("utf-8"))
         for val in range(10):
             assert val in response["result"]
@@ -51,7 +51,7 @@ def test_get_column_values_strs(self):
         self.login(ADMIN_USERNAME)
         table = self.get_virtual_dataset()
         rv = self.client.get(f"api/v1/datasource/table/{table.id}/column/col2/values/")
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         response = json.loads(rv.data.decode("utf-8"))
         for val in ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"]:
             assert val in response["result"]
@@ -61,7 +61,7 @@ def test_get_column_values_floats(self):
         self.login(ADMIN_USERNAME)
         table = self.get_virtual_dataset()
         rv = self.client.get(f"api/v1/datasource/table/{table.id}/column/col3/values/")
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         response = json.loads(rv.data.decode("utf-8"))
         for val in [1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9]:
             assert val in response["result"]
@@ -71,16 +71,16 @@ def test_get_column_values_nulls(self):
         self.login(ADMIN_USERNAME)
         table = self.get_virtual_dataset()
         rv = self.client.get(f"api/v1/datasource/table/{table.id}/column/col4/values/")
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         response = json.loads(rv.data.decode("utf-8"))
-        self.assertEqual(response["result"], [None])
+        assert response["result"] == [None]
 
     @pytest.mark.usefixtures("app_context", "virtual_dataset")
     def test_get_column_values_integers_with_nulls(self):
         self.login(ADMIN_USERNAME)
         table = self.get_virtual_dataset()
         rv = self.client.get(f"api/v1/datasource/table/{table.id}/column/col6/values/")
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         response = json.loads(rv.data.decode("utf-8"))
         for val in [1, None, 3, 4, 5, 6, 7, 8, 9, 10]:
             assert val in response["result"]
@@ -92,27 +92,27 @@ def test_get_column_values_invalid_datasource_type(self):
         rv = self.client.get(
             f"api/v1/datasource/not_table/{table.id}/column/col1/values/"
         )
-        self.assertEqual(rv.status_code, 400)
+        assert rv.status_code == 400
         response = json.loads(rv.data.decode("utf-8"))
-        self.assertEqual(response["message"], "Invalid datasource type: not_table")
+        assert response["message"] == "Invalid datasource type: not_table"
 
     @patch("superset.datasource.api.DatasourceDAO.get_datasource")
     def test_get_column_values_datasource_type_not_supported(self, get_datasource_mock):
         get_datasource_mock.side_effect = DatasourceTypeNotSupportedError
         self.login(ADMIN_USERNAME)
         rv = self.client.get("api/v1/datasource/table/1/column/col1/values/")
-        self.assertEqual(rv.status_code, 400)
+        assert rv.status_code == 400
         response = json.loads(rv.data.decode("utf-8"))
-        self.assertEqual(
-            response["message"], "DAO datasource query source type is not supported"
+        assert (
+            response["message"] == "DAO datasource query source type is not supported"
         )
 
     def test_get_column_values_datasource_not_found(self):
         self.login(ADMIN_USERNAME)
         rv = self.client.get("api/v1/datasource/table/999/column/col1/values/")
-        self.assertEqual(rv.status_code, 404)
+        assert rv.status_code == 404
         response = json.loads(rv.data.decode("utf-8"))
-        self.assertEqual(response["message"], "Datasource does not exist")
+        assert response["message"] == "Datasource does not exist"
 
     @pytest.mark.usefixtures("app_context", "virtual_dataset")
     def test_get_column_values_no_datasource_access(self):
@@ -126,12 +126,11 @@ def test_get_column_values_no_datasource_access(self):
         self.login(GAMMA_USERNAME)
         table = self.get_virtual_dataset()
         rv = self.client.get(f"api/v1/datasource/table/{table.id}/column/col1/values/")
-        self.assertEqual(rv.status_code, 403)
+        assert rv.status_code == 403
         response = json.loads(rv.data.decode("utf-8"))
-        self.assertEqual(
-            response["message"],
-            f"This endpoint requires the datasource {table.id}, "
-            "database or `all_datasource_access` permission",
+        assert (
+            response["message"] == f"This endpoint requires the datasource {table.id}, "
+            "database or `all_datasource_access` permission"
         )
 
     @pytest.mark.usefixtures("app_context", "virtual_dataset")
@@ -188,9 +187,9 @@ def test_get_column_values_with_rls(self):
             rv = self.client.get(
                 f"api/v1/datasource/table/{table.id}/column/col2/values/"
             )
-            self.assertEqual(rv.status_code, 200)
+            assert rv.status_code == 200
             response = json.loads(rv.data.decode("utf-8"))
-            self.assertEqual(response["result"], ["b"])
+            assert response["result"] == ["b"]
 
     @pytest.mark.usefixtures("app_context", "virtual_dataset")
     def test_get_column_values_with_rls_no_values(self):
@@ -202,6 +201,6 @@ def test_get_column_values_with_rls_no_values(self):
             rv = self.client.get(
                 f"api/v1/datasource/table/{table.id}/column/col2/values/"
             )
-            self.assertEqual(rv.status_code, 200)
+            assert rv.status_code == 200
             response = json.loads(rv.data.decode("utf-8"))
-            self.assertEqual(response["result"], [])
+            assert response["result"] == []
diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py
index a6c316fd1e742..ec45c8c57e882 100644
--- a/tests/integration_tests/datasource_tests.py
+++ b/tests/integration_tests/datasource_tests.py
@@ -101,9 +101,15 @@ def test_external_metadata_for_physical_table(self):
         url = f"/datasource/external_metadata/table/{tbl.id}/"
         resp = self.get_json_resp(url)
         col_names = {o.get("column_name") for o in resp}
-        self.assertEqual(
-            col_names, {"num_boys", "num", "gender", "name", "ds", "state", "num_girls"}
-        )
+        assert col_names == {
+            "num_boys",
+            "num",
+            "gender",
+            "name",
+            "ds",
+            "state",
+            "num_girls",
+        }
 
     def test_always_filter_main_dttm(self):
        database = get_example_database()
@@ -175,9 +181,15 @@ def test_external_metadata_by_name_for_physical_table(self):
         url = f"/datasource/external_metadata_by_name/?q={params}"
         resp = self.get_json_resp(url)
         col_names = {o.get("column_name") for o in resp}
-        self.assertEqual(
-            col_names, {"num_boys", "num", "gender", "name", "ds", "state", "num_girls"}
-        )
+        assert col_names == {
+            "num_boys",
+            "num",
+            "gender",
+            "name",
+            "ds",
+            "state",
+            "num_girls",
+        }
 
     def test_external_metadata_by_name_for_virtual_table(self):
         self.login(ADMIN_USERNAME)
@@ -235,7 +247,7 @@ def test_external_metadata_by_name_from_sqla_inspector(self):
         url = f"/datasource/external_metadata_by_name/?q={params}"
         resp = self.get_json_resp(url)
         col_names = {o.get("column_name") for o in resp}
-        self.assertEqual(col_names, {"first", "second"})
+        assert col_names == {"first", "second"}
 
         # No databases found
         params = prison.dumps(
@@ -249,10 +261,10 @@ def test_external_metadata_by_name_from_sqla_inspector(self):
         )
         url = f"/datasource/external_metadata_by_name/?q={params}"
         resp = self.client.get(url)
-        self.assertEqual(resp.status_code, DatasetNotFoundError.status)
-        self.assertEqual(
-            json.loads(resp.data.decode("utf-8")).get("error"),
-            DatasetNotFoundError.message,
+        assert resp.status_code == DatasetNotFoundError.status
+        assert (
+            json.loads(resp.data.decode("utf-8")).get("error")
+            == DatasetNotFoundError.message
         )
 
         # No table found
@@ -267,10 +279,10 @@ def test_external_metadata_by_name_from_sqla_inspector(self):
         )
         url = f"/datasource/external_metadata_by_name/?q={params}"
         resp = self.client.get(url)
-        self.assertEqual(resp.status_code, DatasetNotFoundError.status)
-        self.assertEqual(
-            json.loads(resp.data.decode("utf-8")).get("error"),
-            DatasetNotFoundError.message,
+        assert resp.status_code == DatasetNotFoundError.status
+        assert (
+            json.loads(resp.data.decode("utf-8")).get("error")
+            == DatasetNotFoundError.message
        )
 
         # invalid query params
@@ -281,7 +293,7 @@ def test_external_metadata_by_name_from_sqla_inspector(self):
         )
         url = f"/datasource/external_metadata_by_name/?q={params}"
         resp = self.get_json_resp(url)
-        self.assertIn("error", resp)
+        assert "error" in resp
 
     def test_external_metadata_for_virtual_table_template_params(self):
         self.login(ADMIN_USERNAME)
@@ -308,7 +320,7 @@ def test_external_metadata_for_malicious_virtual_table(self):
         with db_insert_temp_object(table):
             url = f"/datasource/external_metadata/table/{table.id}/"
             resp = self.get_json_resp(url)
-            self.assertEqual(resp["error"], "Only `SELECT` statements are allowed")
+            assert resp["error"] == "Only `SELECT` statements are allowed"
 
     def test_external_metadata_for_multistatement_virtual_table(self):
         self.login(ADMIN_USERNAME)
@@ -322,7 +334,7 @@ def test_external_metadata_for_multistatement_virtual_table(self):
         with db_insert_temp_object(table):
             url = f"/datasource/external_metadata/table/{table.id}/"
             resp = self.get_json_resp(url)
-            self.assertEqual(resp["error"], "Only single queries supported")
+            assert resp["error"] == "Only single queries supported"
 
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     @mock.patch("superset.connectors.sqla.models.SqlaTable.external_metadata")
@@ -350,7 +362,7 @@ def compare_lists(self, l1, l2, key):
             obj2 = l2_lookup.get(obj1.get(key))
             for k in obj1:
                 if k not in "id" and obj1.get(k):
-                    self.assertEqual(obj1.get(k), obj2.get(k))
+                    assert obj1.get(k) == obj2.get(k)
 
     def test_save(self):
         self.login(ADMIN_USERNAME)
@@ -367,11 +379,11 @@ def test_save(self):
             elif k == "metrics":
                self.compare_lists(datasource_post[k], resp[k], "metric_name")
             elif k == "database":
-                self.assertEqual(resp[k]["id"], datasource_post[k]["id"])
+                assert resp[k]["id"] == datasource_post[k]["id"]
             elif k == "owners":
-                self.assertEqual([o["id"] for o in resp[k]], datasource_post["owners"])
+                assert [o["id"] for o in resp[k]] == datasource_post["owners"]
             else:
-                self.assertEqual(resp[k], datasource_post[k])
+                assert resp[k] == datasource_post[k]
 
     def test_save_default_endpoint_validation_success(self):
         self.login(ADMIN_USERNAME)
@@ -404,11 +416,11 @@ def test_change_database(self):
         new_db = self.create_fake_db()
         datasource_post["database"]["id"] = new_db.id
         resp = self.save_datasource_from_dict(datasource_post)
-        self.assertEqual(resp["database"]["id"], new_db.id)
+        assert resp["database"]["id"] == new_db.id
 
         datasource_post["database"]["id"] = db_id
         resp = self.save_datasource_from_dict(datasource_post)
-        self.assertEqual(resp["database"]["id"], db_id)
+        assert resp["database"]["id"] == db_id
 
         self.delete_fake_db()
 
@@ -440,7 +452,7 @@ def test_save_duplicate_key(self):
         )
         data = dict(data=json.dumps(datasource_post))
         resp = self.get_json_resp("/datasource/save/", data, raise_on_error=False)
-        self.assertIn("Duplicate column name(s): ", resp["error"])
+        assert "Duplicate column name(s): " in resp["error"]
 
     def test_get_datasource(self):
         admin_user = self.get_user("admin")
@@ -454,21 +466,18 @@ def test_get_datasource(self):
         self.get_json_resp("/datasource/save/", data)
         url = f"/datasource/get/{tbl.type}/{tbl.id}/"
         resp = self.get_json_resp(url)
-        self.assertEqual(resp.get("type"), "table")
+        assert resp.get("type") == "table"
         col_names = {o.get("column_name") for o in resp["columns"]}
-        self.assertEqual(
-            col_names,
-            {
-                "num_boys",
-                "num",
-                "gender",
-                "name",
-                "ds",
-                "state",
-                "num_girls",
-                "num_california",
-            },
-        )
+        assert col_names == {
+            "num_boys",
+            "num",
+            "gender",
+            "name",
+            "ds",
+            "state",
+            "num_girls",
+            "num_california",
+        }
 
     def test_get_datasource_with_health_check(self):
         def my_check(datasource):
@@ -491,7 +500,7 @@ def test_get_datasource_failed(self):
 
         self.login(ADMIN_USERNAME)
         resp = self.get_json_resp("/datasource/get/table/500000/", raise_on_error=False)
-        self.assertEqual(resp.get("error"), "Datasource does not exist")
+        assert resp.get("error") == "Datasource does not exist"
 
     def test_get_datasource_invalid_datasource_failed(self):
         from superset.daos.datasource import DatasourceDAO
@@ -503,7 +512,7 @@ def test_get_datasource_invalid_datasource_failed(self):
 
         self.login(ADMIN_USERNAME)
         resp = self.get_json_resp("/datasource/get/druid/500000/", raise_on_error=False)
-        self.assertEqual(resp.get("error"), "'druid' is not a valid DatasourceType")
+        assert resp.get("error") == "'druid' is not a valid DatasourceType"
 
 
 def test_get_samples(test_client, login_as_admin, virtual_dataset):
diff --git a/tests/integration_tests/db_engine_specs/ascend_tests.py b/tests/integration_tests/db_engine_specs/ascend_tests.py
index ff12656743818..cd1fa37285874 100644
--- a/tests/integration_tests/db_engine_specs/ascend_tests.py
+++ b/tests/integration_tests/db_engine_specs/ascend_tests.py
@@ -22,11 +22,11 @@ class TestAscendDbEngineSpec(TestDbEngineSpec):
     def test_convert_dttm(self):
         dttm = self.get_dttm()
 
-        self.assertEqual(
-            AscendEngineSpec.convert_dttm("DATE", dttm), "CAST('2019-01-02' AS DATE)"
+        assert (
+            AscendEngineSpec.convert_dttm("DATE", dttm) == "CAST('2019-01-02' AS DATE)"
         )
 
-        self.assertEqual(
-            AscendEngineSpec.convert_dttm("TIMESTAMP", dttm),
-            "CAST('2019-01-02T03:04:05.678900' AS TIMESTAMP)",
+        assert (
+            AscendEngineSpec.convert_dttm("TIMESTAMP", dttm)
+            == "CAST('2019-01-02T03:04:05.678900' AS TIMESTAMP)"
         )
diff --git a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py
index 215a01f58538e..715657e4f316b 100644
--- a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py
+++ b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py
@@ -61,18 +61,18 @@ def test_extract_limit_from_query(self, engine_spec_class=BaseEngineSpec):
         q10 = "select * from mytable limit 20, x"
         q11 = "select * from mytable limit x offset 20"
 
-        self.assertEqual(engine_spec_class.get_limit_from_sql(q0), None)
-        self.assertEqual(engine_spec_class.get_limit_from_sql(q1), 10)
-        self.assertEqual(engine_spec_class.get_limit_from_sql(q2), 20)
-        self.assertEqual(engine_spec_class.get_limit_from_sql(q3), None)
-        self.assertEqual(engine_spec_class.get_limit_from_sql(q4), 20)
-        self.assertEqual(engine_spec_class.get_limit_from_sql(q5), 10)
-        self.assertEqual(engine_spec_class.get_limit_from_sql(q6), 10)
-        self.assertEqual(engine_spec_class.get_limit_from_sql(q7), None)
-        self.assertEqual(engine_spec_class.get_limit_from_sql(q8), None)
-        self.assertEqual(engine_spec_class.get_limit_from_sql(q9), None)
-        self.assertEqual(engine_spec_class.get_limit_from_sql(q10), None)
-        self.assertEqual(engine_spec_class.get_limit_from_sql(q11), None)
+        assert engine_spec_class.get_limit_from_sql(q0) is None
+        assert engine_spec_class.get_limit_from_sql(q1) == 10
+        assert engine_spec_class.get_limit_from_sql(q2) == 20
+        assert engine_spec_class.get_limit_from_sql(q3) is None
+        assert engine_spec_class.get_limit_from_sql(q4) == 20
+        assert engine_spec_class.get_limit_from_sql(q5) == 10
+        assert engine_spec_class.get_limit_from_sql(q6) == 10
+        assert engine_spec_class.get_limit_from_sql(q7) is None
+        assert engine_spec_class.get_limit_from_sql(q8) is None
+        assert engine_spec_class.get_limit_from_sql(q9) is None
+        assert engine_spec_class.get_limit_from_sql(q10) is None
+        assert engine_spec_class.get_limit_from_sql(q11) is None
 
     def test_wrapped_semi_tabs(self):
         self.sql_limit_regex(
@@ -141,7 +141,7 @@ def test_limit_expr_and_semicolon(self):
         )
 
     def test_get_datatype(self):
-        self.assertEqual("VARCHAR", BaseEngineSpec.get_datatype("VARCHAR"))
+        assert "VARCHAR" == BaseEngineSpec.get_datatype("VARCHAR")
 
     def test_limit_with_implicit_offset(self):
         self.sql_limit_regex(
@@ -198,29 +198,26 @@ def test_engine_time_grain_validity(self):
         for engine in load_engine_specs():
             if engine is not BaseEngineSpec:
                 # make sure time grain functions have been defined
-                self.assertGreater(len(engine.get_time_grain_expressions()), 0)
+                assert len(engine.get_time_grain_expressions()) > 0
                 # make sure all defined time grains are supported
                 defined_grains = {grain.duration for grain in engine.get_time_grains()}
                 intersection = time_grains.intersection(defined_grains)
-                self.assertSetEqual(defined_grains, intersection, engine)
+                self.assertSetEqual(defined_grains, intersection, engine)  # noqa: PT009
 
     def test_get_time_grain_expressions(self):
         time_grains = MySQLEngineSpec.get_time_grain_expressions()
-        self.assertEqual(
-            list(time_grains.keys()),
-            [
-                None,
-                "PT1S",
-                "PT1M",
-                "PT1H",
-                "P1D",
-                "P1W",
-                "P1M",
-                "P3M",
-                "P1Y",
-                "1969-12-29T00:00:00Z/P1W",
-            ],
-        )
+        assert list(time_grains.keys()) == [
+            None,
+            "PT1S",
+            "PT1M",
+            "PT1H",
+            "P1D",
+            "P1W",
+            "P1M",
+            "P3M",
+            "P1Y",
+            "1969-12-29T00:00:00Z/P1W",
+        ]
 
     def test_get_table_names(self):
         inspector = mock.Mock()
@@ -255,11 +252,11 @@ def test_column_datatype_to_string(self):
             expected = ["STRING", "STRING", "FLOAT"]
         else:
             expected = ["VARCHAR(255)", "VARCHAR(255)", "FLOAT"]
-        self.assertEqual(col_names, expected)
+        assert col_names == expected
 
     def test_convert_dttm(self):
         dttm = self.get_dttm()
-        self.assertIsNone(BaseEngineSpec.convert_dttm("", dttm, db_extra=None))
+        assert BaseEngineSpec.convert_dttm("", dttm, db_extra=None) is None
 
     def test_pyodbc_rows_to_tuples(self):
         # Test for case when pyodbc.Row is returned (odbc driver)
@@ -272,7 +269,7 @@ def test_pyodbc_rows_to_tuples(self):
             (2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000)),
         ]
         result = BaseEngineSpec.pyodbc_rows_to_tuples(data)
-        self.assertListEqual(result, expected)
+        self.assertListEqual(result, expected)  # noqa: PT009
 
     def test_pyodbc_rows_to_tuples_passthrough(self):
         # Test for case when tuples are returned
@@ -281,7 +278,7 @@ def test_pyodbc_rows_to_tuples_passthrough(self):
             (2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000)),
         ]
         result = BaseEngineSpec.pyodbc_rows_to_tuples(data)
-        self.assertListEqual(result, data)
+        self.assertListEqual(result, data)  # noqa: PT009
 
     @mock.patch("superset.models.core.Database.db_engine_spec", BaseEngineSpec)
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
diff --git a/tests/integration_tests/db_engine_specs/base_tests.py b/tests/integration_tests/db_engine_specs/base_tests.py
index c30c8a0f11c41..c836e71b689be 100644
--- a/tests/integration_tests/db_engine_specs/base_tests.py
+++ b/tests/integration_tests/db_engine_specs/base_tests.py
@@ -33,4 +33,4 @@ def sql_limit_regex(
     ):
         main = Database(database_name="test_database", sqlalchemy_uri="sqlite://")
         limited = engine_spec_class.apply_limit_to_sql(sql, limit, main, force)
-        self.assertEqual(expected_sql, limited)
+        assert expected_sql == limited
diff --git a/tests/integration_tests/db_engine_specs/bigquery_tests.py b/tests/integration_tests/db_engine_specs/bigquery_tests.py
index 3518d2cf8a989..fa10bd2ce14bd 100644
--- a/tests/integration_tests/db_engine_specs/bigquery_tests.py
+++ b/tests/integration_tests/db_engine_specs/bigquery_tests.py
@@ -45,7 +45,7 @@ def test_bigquery_sqla_column_label(self):
         }
         for original, expected in test_cases.items():
             actual = BigQueryEngineSpec.make_label_compatible(column(original).name)
-            self.assertEqual(actual, expected)
+            assert actual == expected
 
     def test_timegrain_expressions(self):
         """
@@ -63,7 +63,7 @@ def test_timegrain_expressions(self):
             actual = BigQueryEngineSpec.get_timestamp_expr(
                 col=col, pdf=None, time_grain="PT1H"
             )
-            self.assertEqual(str(actual), expected)
+            assert str(actual) == expected
 
     def test_custom_minute_timegrain_expressions(self):
         """
@@ -104,12 +104,12 @@ def values(self):
         data1 = [(1, "foo")]
         with mock.patch.object(BaseEngineSpec, "fetch_data", return_value=data1):
             result = BigQueryEngineSpec.fetch_data(None, 0)
-        self.assertEqual(result, data1)
+        assert result == data1
 
         data2 = [Row(1), Row(2)]
         with mock.patch.object(BaseEngineSpec, "fetch_data", return_value=data2):
             result = BigQueryEngineSpec.fetch_data(None, 0)
-        self.assertEqual(result, [1, 2])
+        assert result == [1, 2]
 
     def test_get_extra_table_metadata(self):
         """
@@ -122,7 +122,7 @@ def test_get_extra_table_metadata(self):
             database,
Table("some_table", "some_schema"), ) - self.assertEqual(result, expected_result) + assert result == expected_result def test_get_indexes(self): database = mock.Mock() diff --git a/tests/integration_tests/db_engine_specs/elasticsearch_tests.py b/tests/integration_tests/db_engine_specs/elasticsearch_tests.py index 8b07b2ebdd785..8027c031a5d71 100644 --- a/tests/integration_tests/db_engine_specs/elasticsearch_tests.py +++ b/tests/integration_tests/db_engine_specs/elasticsearch_tests.py @@ -40,4 +40,4 @@ def test_time_grain_expressions(self, time_grain, expected_time_grain_expression actual = ElasticSearchEngineSpec.get_timestamp_expr( col=col, pdf=None, time_grain=time_grain ) - self.assertEqual(str(actual), expected_time_grain_expression) + assert str(actual) == expected_time_grain_expression diff --git a/tests/integration_tests/db_engine_specs/mysql_tests.py b/tests/integration_tests/db_engine_specs/mysql_tests.py index 5f32059484098..e935b99e03d38 100644 --- a/tests/integration_tests/db_engine_specs/mysql_tests.py +++ b/tests/integration_tests/db_engine_specs/mysql_tests.py @@ -30,8 +30,8 @@ class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec): ) def test_get_datatype_mysql(self): """Tests related to datatype mapping for MySQL""" - self.assertEqual("TINY", MySQLEngineSpec.get_datatype(1)) - self.assertEqual("VARCHAR", MySQLEngineSpec.get_datatype(15)) + assert "TINY" == MySQLEngineSpec.get_datatype(1) + assert "VARCHAR" == MySQLEngineSpec.get_datatype(15) def test_column_datatype_to_string(self): test_cases = ( @@ -49,7 +49,7 @@ def test_column_datatype_to_string(self): actual = MySQLEngineSpec.column_datatype_to_string( original, mysql.dialect() ) - self.assertEqual(actual, expected) + assert actual == expected def test_extract_error_message(self): from MySQLdb._exceptions import OperationalError diff --git a/tests/integration_tests/db_engine_specs/pinot_tests.py b/tests/integration_tests/db_engine_specs/pinot_tests.py index c8deef6fc42b5..40793494eaa45 100755 --- a/tests/integration_tests/db_engine_specs/pinot_tests.py +++ b/tests/integration_tests/db_engine_specs/pinot_tests.py @@ -32,20 +32,14 @@ def test_pinot_time_expression_sec_one_1d_grain(self): + "DATETIMECONVERT(tstamp, '1:SECONDS:EPOCH', " + "'1:SECONDS:EPOCH', '1:SECONDS') AS TIMESTAMP)) AS TIMESTAMP)" ) - self.assertEqual( - result, - expected, - ) + assert result == expected def test_pinot_time_expression_simple_date_format_1d_grain(self): col = column("tstamp") expr = PinotEngineSpec.get_timestamp_expr(col, "%Y-%m-%d %H:%M:%S", "P1D") result = str(expr.compile()) expected = "CAST(DATE_TRUNC('day', CAST(tstamp AS TIMESTAMP)) AS TIMESTAMP)" - self.assertEqual( - result, - expected, - ) + assert result == expected def test_pinot_time_expression_simple_date_format_10m_grain(self): col = column("tstamp") @@ -55,20 +49,14 @@ def test_pinot_time_expression_simple_date_format_10m_grain(self): "CAST(ROUND(DATE_TRUNC('minute', CAST(tstamp AS " + "TIMESTAMP)), 600000) AS TIMESTAMP)" ) - self.assertEqual( - result, - expected, - ) + assert result == expected def test_pinot_time_expression_simple_date_format_1w_grain(self): col = column("tstamp") expr = PinotEngineSpec.get_timestamp_expr(col, "%Y-%m-%d %H:%M:%S", "P1W") result = str(expr.compile()) expected = "CAST(DATE_TRUNC('week', CAST(tstamp AS TIMESTAMP)) AS TIMESTAMP)" - self.assertEqual( - result, - expected, - ) + assert result == expected def test_pinot_time_expression_sec_one_1m_grain(self): col = column("tstamp") @@ -79,10 +67,7 @@ def 
test_pinot_time_expression_sec_one_1m_grain(self): + "DATETIMECONVERT(tstamp, '1:SECONDS:EPOCH', " + "'1:SECONDS:EPOCH', '1:SECONDS') AS TIMESTAMP)) AS TIMESTAMP)" ) - self.assertEqual( - result, - expected, - ) + assert result == expected def test_pinot_time_expression_millisec_one_1m_grain(self): col = column("tstamp") @@ -93,10 +78,7 @@ def test_pinot_time_expression_millisec_one_1m_grain(self): + "DATETIMECONVERT(tstamp, '1:MILLISECONDS:EPOCH', " + "'1:MILLISECONDS:EPOCH', '1:MILLISECONDS') AS TIMESTAMP)) AS TIMESTAMP)" ) - self.assertEqual( - result, - expected, - ) + assert result == expected def test_invalid_get_time_expression_arguments(self): with self.assertRaises(NotImplementedError): diff --git a/tests/integration_tests/db_engine_specs/postgres_tests.py b/tests/integration_tests/db_engine_specs/postgres_tests.py index 4c4261ff57261..e4f9462d63069 100644 --- a/tests/integration_tests/db_engine_specs/postgres_tests.py +++ b/tests/integration_tests/db_engine_specs/postgres_tests.py @@ -57,7 +57,7 @@ def test_time_exp_literal_no_grain(self): col = literal_column("COALESCE(a, b)") expr = PostgresEngineSpec.get_timestamp_expr(col, None, None) result = str(expr.compile(None, dialect=postgresql.dialect())) - self.assertEqual(result, "COALESCE(a, b)") + assert result == "COALESCE(a, b)" def test_time_exp_literal_1y_grain(self): """ @@ -66,7 +66,7 @@ def test_time_exp_literal_1y_grain(self): col = literal_column("COALESCE(a, b)") expr = PostgresEngineSpec.get_timestamp_expr(col, None, "P1Y") result = str(expr.compile(None, dialect=postgresql.dialect())) - self.assertEqual(result, "DATE_TRUNC('year', COALESCE(a, b))") + assert result == "DATE_TRUNC('year', COALESCE(a, b))" def test_time_ex_lowr_col_no_grain(self): """ @@ -75,7 +75,7 @@ def test_time_ex_lowr_col_no_grain(self): col = column("lower_case") expr = PostgresEngineSpec.get_timestamp_expr(col, None, None) result = str(expr.compile(None, dialect=postgresql.dialect())) - self.assertEqual(result, "lower_case") + assert result == "lower_case" def test_time_exp_lowr_col_sec_1y(self): """ @@ -84,10 +84,9 @@ def test_time_exp_lowr_col_sec_1y(self): col = column("lower_case") expr = PostgresEngineSpec.get_timestamp_expr(col, "epoch_s", "P1Y") result = str(expr.compile(None, dialect=postgresql.dialect())) - self.assertEqual( - result, - "DATE_TRUNC('year', " - "(timestamp 'epoch' + lower_case * interval '1 second'))", + assert ( + result == "DATE_TRUNC('year', " + "(timestamp 'epoch' + lower_case * interval '1 second'))" ) def test_time_exp_mixed_case_col_1y(self): @@ -97,7 +96,7 @@ def test_time_exp_mixed_case_col_1y(self): col = column("MixedCase") expr = PostgresEngineSpec.get_timestamp_expr(col, None, "P1Y") result = str(expr.compile(None, dialect=postgresql.dialect())) - self.assertEqual(result, "DATE_TRUNC('year', \"MixedCase\")") + assert result == "DATE_TRUNC('year', \"MixedCase\")" def test_empty_dbapi_cursor_description(self): """ @@ -107,7 +106,7 @@ def test_empty_dbapi_cursor_description(self): # empty description mean no columns, this mocks the following SQL: "SELECT" cursor.description = [] results = PostgresEngineSpec.fetch_data(cursor, 1000) - self.assertEqual(results, []) + assert results == [] def test_engine_alias_name(self): """ @@ -158,13 +157,7 @@ def test_estimate_statement_cost_select_star(self): ) sql = "SELECT * FROM birth_names" results = PostgresEngineSpec.estimate_statement_cost(sql, cursor) - self.assertEqual( - results, - { - "Start-up cost": 0.00, - "Total cost": 1537.91, - }, - ) + assert results == 
{"Start-up cost": 0.0, "Total cost": 1537.91} def test_estimate_statement_invalid_syntax(self): """ @@ -199,19 +192,10 @@ def test_query_cost_formatter_example_costs(self): }, ] result = PostgresEngineSpec.query_cost_formatter(raw_cost) - self.assertEqual( - result, - [ - { - "Start-up cost": "0.0", - "Total cost": "1537.91", - }, - { - "Start-up cost": "10.0", - "Total cost": "1537.0", - }, - ], - ) + assert result == [ + {"Start-up cost": "0.0", "Total cost": "1537.91"}, + {"Start-up cost": "10.0", "Total cost": "1537.0"}, + ] def test_extract_errors(self): """ diff --git a/tests/integration_tests/db_engine_specs/presto_tests.py b/tests/integration_tests/db_engine_specs/presto_tests.py index b49405765a421..798e31ee431a4 100644 --- a/tests/integration_tests/db_engine_specs/presto_tests.py +++ b/tests/integration_tests/db_engine_specs/presto_tests.py @@ -33,7 +33,7 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec): @skipUnless(TestDbEngineSpec.is_module_installed("pyhive"), "pyhive not installed") def test_get_datatype_presto(self): - self.assertEqual("STRING", PrestoEngineSpec.get_datatype("string")) + assert "STRING" == PrestoEngineSpec.get_datatype("string") def test_get_view_names_with_schema(self): database = mock.MagicMock() @@ -86,10 +86,10 @@ def verify_presto_column(self, column, expected_results): row.Column, row.Type, row.Null = column inspector.bind.execute.return_value.fetchall = mock.Mock(return_value=[row]) results = PrestoEngineSpec.get_columns(inspector, Table("", "")) - self.assertEqual(len(expected_results), len(results)) + assert len(expected_results) == len(results) for expected_result, result in zip(expected_results, results): - self.assertEqual(expected_result[0], result["column_name"]) - self.assertEqual(expected_result[1], str(result["type"])) + assert expected_result[0] == result["column_name"] + assert expected_result[1] == str(result["type"]) def test_presto_get_column(self): presto_column = ("column_name", "boolean", "") @@ -192,8 +192,8 @@ def test_presto_get_fields(self): }, ] for actual_result, expected_result in zip(actual_results, expected_results): - self.assertEqual(actual_result.element.name, expected_result["column_name"]) - self.assertEqual(actual_result.name, expected_result["label"]) + assert actual_result.element.name == expected_result["column_name"] + assert actual_result.name == expected_result["label"] @mock.patch.dict( "superset.extensions.feature_flag_manager._feature_flags", @@ -260,9 +260,9 @@ def test_presto_expand_data_with_simple_structural_columns(self): "is_dttm": False, } ] - self.assertEqual(actual_cols, expected_cols) - self.assertEqual(actual_data, expected_data) - self.assertEqual(actual_expanded_cols, expected_expanded_cols) + assert actual_cols == expected_cols + assert actual_data == expected_data + assert actual_expanded_cols == expected_expanded_cols @mock.patch.dict( "superset.extensions.feature_flag_manager._feature_flags", @@ -343,9 +343,9 @@ def test_presto_expand_data_with_complex_row_columns(self): "is_dttm": False, }, ] - self.assertEqual(actual_cols, expected_cols) - self.assertEqual(actual_data, expected_data) - self.assertEqual(actual_expanded_cols, expected_expanded_cols) + assert actual_cols == expected_cols + assert actual_data == expected_data + assert actual_expanded_cols == expected_expanded_cols @mock.patch.dict( "superset.extensions.feature_flag_manager._feature_flags", @@ -427,9 +427,9 @@ def test_presto_expand_data_with_complex_row_columns_and_null_values(self): "is_dttm": False, }, ] - 
self.assertEqual(actual_cols, expected_cols) - self.assertEqual(actual_data, expected_data) - self.assertEqual(actual_expanded_cols, expected_expanded_cols) + assert actual_cols == expected_cols + assert actual_data == expected_data + assert actual_expanded_cols == expected_expanded_cols @mock.patch.dict( "superset.extensions.feature_flag_manager._feature_flags", @@ -548,9 +548,9 @@ def test_presto_expand_data_with_complex_array_columns(self): "is_dttm": False, }, ] - self.assertEqual(actual_cols, expected_cols) - self.assertEqual(actual_data, expected_data) - self.assertEqual(actual_expanded_cols, expected_expanded_cols) + assert actual_cols == expected_cols + assert actual_data == expected_data + assert actual_expanded_cols == expected_expanded_cols def test_presto_get_extra_table_metadata(self): database = mock.Mock() @@ -582,7 +582,7 @@ def test_presto_where_latest_partition(self): columns, ) query_result = str(result.compile(compile_kwargs={"literal_binds": True})) - self.assertEqual("SELECT \nWHERE ds = '01-01-19' AND hour = 1", query_result) + assert "SELECT \nWHERE ds = '01-01-19' AND hour = 1" == query_result def test_query_cost_formatter(self): raw_cost = [ @@ -645,7 +645,7 @@ def test_query_cost_formatter(self): "Network cost": "354 G", } ] - self.assertEqual(formatted_cost, expected) + assert formatted_cost == expected @mock.patch.dict( "superset.extensions.feature_flag_manager._feature_flags", @@ -752,9 +752,9 @@ def test_presto_expand_data_array(self): }, ] - self.assertEqual(actual_cols, expected_cols) - self.assertEqual(actual_data, expected_data) - self.assertEqual(actual_expanded_cols, expected_expanded_cols) + assert actual_cols == expected_cols + assert actual_data == expected_data + assert actual_expanded_cols == expected_expanded_cols @mock.patch("superset.db_engine_specs.base.BaseEngineSpec.get_table_names") @mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.get_view_names") diff --git a/tests/integration_tests/dict_import_export_tests.py b/tests/integration_tests/dict_import_export_tests.py index 2db17a77b50b6..116487882aeb8 100644 --- a/tests/integration_tests/dict_import_export_tests.py +++ b/tests/integration_tests/dict_import_export_tests.py @@ -91,36 +91,32 @@ def create_table( def yaml_compare(self, obj_1, obj_2): obj_1_str = yaml.safe_dump(obj_1, default_flow_style=False) obj_2_str = yaml.safe_dump(obj_2, default_flow_style=False) - self.assertEqual(obj_1_str, obj_2_str) + assert obj_1_str == obj_2_str def assert_table_equals(self, expected_ds, actual_ds): - self.assertEqual(expected_ds.table_name, actual_ds.table_name) - self.assertEqual(expected_ds.main_dttm_col, actual_ds.main_dttm_col) - self.assertEqual(expected_ds.schema, actual_ds.schema) - self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics)) - self.assertEqual(len(expected_ds.columns), len(actual_ds.columns)) - self.assertEqual( - {c.column_name for c in expected_ds.columns}, - {c.column_name for c in actual_ds.columns}, - ) - self.assertEqual( - {m.metric_name for m in expected_ds.metrics}, - {m.metric_name for m in actual_ds.metrics}, - ) + assert expected_ds.table_name == actual_ds.table_name + assert expected_ds.main_dttm_col == actual_ds.main_dttm_col + assert expected_ds.schema == actual_ds.schema + assert len(expected_ds.metrics) == len(actual_ds.metrics) + assert len(expected_ds.columns) == len(actual_ds.columns) + assert {c.column_name for c in expected_ds.columns} == { + c.column_name for c in actual_ds.columns + } + assert {m.metric_name for m in 
expected_ds.metrics} == { + m.metric_name for m in actual_ds.metrics + } def assert_datasource_equals(self, expected_ds, actual_ds): - self.assertEqual(expected_ds.datasource_name, actual_ds.datasource_name) - self.assertEqual(expected_ds.main_dttm_col, actual_ds.main_dttm_col) - self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics)) - self.assertEqual(len(expected_ds.columns), len(actual_ds.columns)) - self.assertEqual( - {c.column_name for c in expected_ds.columns}, - {c.column_name for c in actual_ds.columns}, - ) - self.assertEqual( - {m.metric_name for m in expected_ds.metrics}, - {m.metric_name for m in actual_ds.metrics}, - ) + assert expected_ds.datasource_name == actual_ds.datasource_name + assert expected_ds.main_dttm_col == actual_ds.main_dttm_col + assert len(expected_ds.metrics) == len(actual_ds.metrics) + assert len(expected_ds.columns) == len(actual_ds.columns) + assert {c.column_name for c in expected_ds.columns} == { + c.column_name for c in actual_ds.columns + } + assert {m.metric_name for m in expected_ds.metrics} == { + m.metric_name for m in actual_ds.metrics + } def test_import_table_no_metadata(self): table, dict_table = self.create_table("pure_table", id=ID_PREFIX + 1) @@ -143,8 +139,8 @@ def test_import_table_1_col_1_met(self): db.session.commit() imported = self.get_table_by_id(imported_table.id) self.assert_table_equals(table, imported) - self.assertEqual( - {DBREF: ID_PREFIX + 2, "database_name": "main"}, json.loads(imported.params) + assert {DBREF: ID_PREFIX + 2, "database_name": "main"} == json.loads( + imported.params ) self.yaml_compare(table.export_to_dict(), imported.export_to_dict()) @@ -178,7 +174,7 @@ def test_import_table_override_append(self): db.session.commit() imported_over = self.get_table_by_id(imported_over_table.id) - self.assertEqual(imported_table.id, imported_over.id) + assert imported_table.id == imported_over.id expected_table, _ = self.create_table( "table_override", id=ID_PREFIX + 3, @@ -209,7 +205,7 @@ def test_import_table_override_sync(self): db.session.commit() imported_over = self.get_table_by_id(imported_over_table.id) - self.assertEqual(imported_table.id, imported_over.id) + assert imported_table.id == imported_over.id expected_table, _ = self.create_table( "table_override", id=ID_PREFIX + 3, @@ -239,7 +235,7 @@ def test_import_table_override_identical(self): ) imported_copy_table = SqlaTable.import_from_dict(dict_copy_table) db.session.commit() - self.assertEqual(imported_table.id, imported_copy_table.id) + assert imported_table.id == imported_copy_table.id self.assert_table_equals(copy_table, self.get_table_by_id(imported_table.id)) self.yaml_compare( imported_copy_table.export_to_dict(), imported_table.export_to_dict() @@ -259,12 +255,12 @@ def test_export_datasource_ui_cli(self): "/databaseview/action_post", {"action": "yaml_export", "rowid": 1} ) ui_export = yaml.safe_load(resp) - self.assertEqual( - ui_export["databases"][0]["database_name"], - cli_export["databases"][0]["database_name"], + assert ( + ui_export["databases"][0]["database_name"] + == cli_export["databases"][0]["database_name"] ) - self.assertEqual( - ui_export["databases"][0]["tables"], cli_export["databases"][0]["tables"] + assert ( + ui_export["databases"][0]["tables"] == cli_export["databases"][0]["tables"] ) diff --git a/tests/integration_tests/dynamic_plugins_tests.py b/tests/integration_tests/dynamic_plugins_tests.py index 37b77c1d8d71b..d8f5aab7f44eb 100644 --- a/tests/integration_tests/dynamic_plugins_tests.py +++ 
b/tests/integration_tests/dynamic_plugins_tests.py @@ -28,7 +28,7 @@ def test_dynamic_plugins_disabled(self): self.login(ADMIN_USERNAME) uri = "/dynamic-plugins/api" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 404) + assert rv.status_code == 404 @with_feature_flags(DYNAMIC_PLUGINS=True) def test_dynamic_plugins_enabled(self): @@ -38,4 +38,4 @@ def test_dynamic_plugins_enabled(self): self.login(ADMIN_USERNAME) uri = "/dynamic-plugins/api" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 diff --git a/tests/integration_tests/email_tests.py b/tests/integration_tests/email_tests.py index 12c2fc676b01e..d7afe3551bf47 100644 --- a/tests/integration_tests/email_tests.py +++ b/tests/integration_tests/email_tests.py @@ -222,7 +222,7 @@ def test_send_mime_ssl_server_auth(self, mock_smtp, mock_smtp_ssl): app.config["SMTP_HOST"], app.config["SMTP_PORT"], context=mock.ANY ) called_context = mock_smtp_ssl.call_args.kwargs["context"] - self.assertEqual(called_context.verify_mode, ssl.CERT_REQUIRED) + assert called_context.verify_mode == ssl.CERT_REQUIRED @mock.patch("smtplib.SMTP") def test_send_mime_tls_server_auth(self, mock_smtp): @@ -233,7 +233,7 @@ def test_send_mime_tls_server_auth(self, mock_smtp): utils.send_mime_email("from", "to", MIMEMultipart(), app.config, dryrun=False) mock_smtp.return_value.starttls.assert_called_with(context=mock.ANY) called_context = mock_smtp.return_value.starttls.call_args.kwargs["context"] - self.assertEqual(called_context.verify_mode, ssl.CERT_REQUIRED) + assert called_context.verify_mode == ssl.CERT_REQUIRED @mock.patch("smtplib.SMTP_SSL") @mock.patch("smtplib.SMTP") diff --git a/tests/integration_tests/embedded/dao_tests.py b/tests/integration_tests/embedded/dao_tests.py index eed161581fe71..6949462b79fc3 100644 --- a/tests/integration_tests/embedded/dao_tests.py +++ b/tests/integration_tests/embedded/dao_tests.py @@ -36,13 +36,13 @@ def test_upsert(self): EmbeddedDashboardDAO.upsert(dash, ["test.example.com"]) db.session.flush() assert dash.embedded - self.assertEqual(dash.embedded[0].allowed_domains, ["test.example.com"]) + assert dash.embedded[0].allowed_domains == ["test.example.com"] original_uuid = dash.embedded[0].uuid - self.assertIsNotNone(original_uuid) + assert original_uuid is not None EmbeddedDashboardDAO.upsert(dash, []) db.session.flush() - self.assertEqual(dash.embedded[0].allowed_domains, []) - self.assertEqual(dash.embedded[0].uuid, original_uuid) + assert dash.embedded[0].allowed_domains == [] + assert dash.embedded[0].uuid == original_uuid @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices") def test_get_by_uuid(self): @@ -51,4 +51,4 @@ def test_get_by_uuid(self): db.session.flush() uuid = str(dash.embedded[0].uuid) embedded = EmbeddedDashboardDAO.find_by_id(uuid) - self.assertIsNotNone(embedded) + assert embedded is not None diff --git a/tests/integration_tests/event_logger_tests.py b/tests/integration_tests/event_logger_tests.py index 3f0acc30c00ed..4de1258884031 100644 --- a/tests/integration_tests/event_logger_tests.py +++ b/tests/integration_tests/event_logger_tests.py @@ -39,7 +39,7 @@ def test_correct_config_object(self): # unmodified object obj = DBEventLogger() res = get_event_logger_from_cfg_value(obj) - self.assertIs(obj, res) + assert obj is res def test_config_class_deprecation(self): # test that assignment of a class object to EVENT_LOGGER is correctly @@ -51,7 +51,7 @@ def test_config_class_deprecation(self): res = 
get_event_logger_from_cfg_value(DBEventLogger) # class is instantiated and returned - self.assertIsInstance(res, DBEventLogger) + assert isinstance(res, DBEventLogger) def test_raises_typeerror_if_not_abc(self): # test that assignment of non AbstractEventLogger derived type raises @@ -71,19 +71,16 @@ def test_func(): with app.test_request_context("/superset/dashboard/1/?myparam=foo"): result = test_func() payload = mock_log.call_args[1] - self.assertEqual(result, 1) - self.assertEqual( - payload["records"], - [ - { - "myparam": "foo", - "path": "/superset/dashboard/1/", - "url_rule": "/superset/dashboard//", - "object_ref": test_func.__qualname__, - } - ], - ) - self.assertGreaterEqual(payload["duration_ms"], 50) + assert result == 1 + assert payload["records"] == [ + { + "myparam": "foo", + "path": "/superset/dashboard/1/", + "url_rule": "/superset/dashboard//", + "object_ref": test_func.__qualname__, + } + ] + assert payload["duration_ms"] >= 50 @patch.object(DBEventLogger, "log") def test_log_this_with_extra_payload(self, mock_log): @@ -98,19 +95,16 @@ def test_func(arg1, add_extra_log_payload, karg1=1): with app.test_request_context(): result = test_func(1, karg1=2) # pylint: disable=no-value-for-parameter payload = mock_log.call_args[1] - self.assertEqual(result, 2) - self.assertEqual( - payload["records"], - [ - { - "foo": "bar", - "path": "/", - "karg1": 2, - "object_ref": test_func.__qualname__, - } - ], - ) - self.assertGreaterEqual(payload["duration_ms"], 100) + assert result == 2 + assert payload["records"] == [ + { + "foo": "bar", + "path": "/", + "karg1": 2, + "object_ref": test_func.__qualname__, + } + ] + assert payload["duration_ms"] >= 100 @patch("superset.utils.core.g", spec={}) @freeze_time("Jan 14th, 2020", auto_tick_seconds=15) @@ -141,19 +135,16 @@ def log( with logger(action="foo", engine="bar"): pass - self.assertEquals( - logger.records, - [ - { - "records": [{"path": "/", "engine": "bar"}], - "database_id": None, - "user_id": 2, - "duration": 15000, - "curated_payload": {}, - "curated_form_data": {}, - } - ], - ) + assert logger.records == [ + { + "records": [{"path": "/", "engine": "bar"}], + "database_id": None, + "user_id": 2, + "duration": 15000, + "curated_payload": {}, + "curated_form_data": {}, + } + ] @patch("superset.utils.core.g", spec={}) def test_context_manager_log_with_context(self, mock_g): @@ -188,25 +179,22 @@ def log( payload_override={"engine": "sqlite"}, ) - self.assertEquals( - logger.records, - [ - { - "records": [ - { - "path": "/", - "object_ref": {"baz": "food"}, - "payload_override": {"engine": "sqlite"}, - } - ], - "database_id": None, - "user_id": 2, - "duration": 5558756000, - "curated_payload": {}, - "curated_form_data": {}, - } - ], - ) + assert logger.records == [ + { + "records": [ + { + "path": "/", + "object_ref": {"baz": "food"}, + "payload_override": {"engine": "sqlite"}, + } + ], + "database_id": None, + "user_id": 2, + "duration": 5558756000, + "curated_payload": {}, + "curated_form_data": {}, + } + ] @patch("superset.utils.core.g", spec={}) def test_log_with_context_user_null(self, mock_g): diff --git a/tests/integration_tests/form_tests.py b/tests/integration_tests/form_tests.py index 078a9866ee975..4dfbc361386dc 100644 --- a/tests/integration_tests/form_tests.py +++ b/tests/integration_tests/form_tests.py @@ -24,13 +24,13 @@ class TestForm(SupersetTestCase): def test_comma_separated_list_field(self): field = CommaSeparatedListField().bind(Form(), "foo") field.process_formdata([""]) - self.assertEqual(field.data, [""]) + 
assert field.data == [""] field.process_formdata(["a,comma,separated,list"]) - self.assertEqual(field.data, ["a", "comma", "separated", "list"]) + assert field.data == ["a", "comma", "separated", "list"] def test_filter_not_empty_values(self): - self.assertEqual(filter_not_empty_values(None), None) - self.assertEqual(filter_not_empty_values([]), None) - self.assertEqual(filter_not_empty_values([""]), None) - self.assertEqual(filter_not_empty_values(["hi"]), ["hi"]) + assert filter_not_empty_values(None) is None + assert filter_not_empty_values([]) is None + assert filter_not_empty_values([""]) is None + assert filter_not_empty_values(["hi"]) == ["hi"] diff --git a/tests/integration_tests/import_export_tests.py b/tests/integration_tests/import_export_tests.py index e4c9bff51e4c3..702acf4b03492 100644 --- a/tests/integration_tests/import_export_tests.py +++ b/tests/integration_tests/import_export_tests.py @@ -148,52 +148,48 @@ def assert_dash_equals( self, expected_dash, actual_dash, check_position=True, check_slugs=True ): if check_slugs: - self.assertEqual(expected_dash.slug, actual_dash.slug) - self.assertEqual(expected_dash.dashboard_title, actual_dash.dashboard_title) - self.assertEqual(len(expected_dash.slices), len(actual_dash.slices)) + assert expected_dash.slug == actual_dash.slug + assert expected_dash.dashboard_title == actual_dash.dashboard_title + assert len(expected_dash.slices) == len(actual_dash.slices) expected_slices = sorted(expected_dash.slices, key=lambda s: s.slice_name or "") actual_slices = sorted(actual_dash.slices, key=lambda s: s.slice_name or "") for e_slc, a_slc in zip(expected_slices, actual_slices): self.assert_slice_equals(e_slc, a_slc) if check_position: - self.assertEqual(expected_dash.position_json, actual_dash.position_json) + assert expected_dash.position_json == actual_dash.position_json def assert_table_equals(self, expected_ds, actual_ds): - self.assertEqual(expected_ds.table_name, actual_ds.table_name) - self.assertEqual(expected_ds.main_dttm_col, actual_ds.main_dttm_col) - self.assertEqual(expected_ds.schema, actual_ds.schema) - self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics)) - self.assertEqual(len(expected_ds.columns), len(actual_ds.columns)) - self.assertEqual( - {c.column_name for c in expected_ds.columns}, - {c.column_name for c in actual_ds.columns}, - ) - self.assertEqual( - {m.metric_name for m in expected_ds.metrics}, - {m.metric_name for m in actual_ds.metrics}, - ) + assert expected_ds.table_name == actual_ds.table_name + assert expected_ds.main_dttm_col == actual_ds.main_dttm_col + assert expected_ds.schema == actual_ds.schema + assert len(expected_ds.metrics) == len(actual_ds.metrics) + assert len(expected_ds.columns) == len(actual_ds.columns) + assert {c.column_name for c in expected_ds.columns} == { + c.column_name for c in actual_ds.columns + } + assert {m.metric_name for m in expected_ds.metrics} == { + m.metric_name for m in actual_ds.metrics + } def assert_datasource_equals(self, expected_ds, actual_ds): - self.assertEqual(expected_ds.datasource_name, actual_ds.datasource_name) - self.assertEqual(expected_ds.main_dttm_col, actual_ds.main_dttm_col) - self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics)) - self.assertEqual(len(expected_ds.columns), len(actual_ds.columns)) - self.assertEqual( - {c.column_name for c in expected_ds.columns}, - {c.column_name for c in actual_ds.columns}, - ) - self.assertEqual( - {m.metric_name for m in expected_ds.metrics}, - {m.metric_name for m in actual_ds.metrics}, - ) 
+        assert expected_ds.datasource_name == actual_ds.datasource_name
+        assert expected_ds.main_dttm_col == actual_ds.main_dttm_col
+        assert len(expected_ds.metrics) == len(actual_ds.metrics)
+        assert len(expected_ds.columns) == len(actual_ds.columns)
+        assert {c.column_name for c in expected_ds.columns} == {
+            c.column_name for c in actual_ds.columns
+        }
+        assert {m.metric_name for m in expected_ds.metrics} == {
+            m.metric_name for m in actual_ds.metrics
+        }
 
     def assert_slice_equals(self, expected_slc, actual_slc):
         # to avoid bad slice data (no slice_name)
         expected_slc_name = expected_slc.slice_name or ""
         actual_slc_name = actual_slc.slice_name or ""
-        self.assertEqual(expected_slc_name, actual_slc_name)
-        self.assertEqual(expected_slc.datasource_type, actual_slc.datasource_type)
-        self.assertEqual(expected_slc.viz_type, actual_slc.viz_type)
+        assert expected_slc_name == actual_slc_name
+        assert expected_slc.datasource_type == actual_slc.datasource_type
+        assert expected_slc.viz_type == actual_slc.viz_type
         exp_params = json.loads(expected_slc.params)
         actual_params = json.loads(actual_slc.params)
         diff_params_keys = (
@@ -208,7 +204,7 @@ def assert_slice_equals(self, expected_slc, actual_slc):
                 actual_params.pop(k)
             if k in exp_params:
                 exp_params.pop(k)
-        self.assertEqual(exp_params, actual_params)
+        assert exp_params == actual_params
 
     def assert_only_exported_slc_fields(self, expected_dash, actual_dash):
         """only exported json has this params
@@ -218,9 +214,9 @@ def assert_only_exported_slc_fields(self, expected_dash, actual_dash):
         actual_slices = sorted(actual_dash.slices, key=lambda s: s.slice_name or "")
         for e_slc, a_slc in zip(expected_slices, actual_slices):
             params = a_slc.params_dict
-            self.assertEqual(e_slc.datasource.name, params["datasource_name"])
-            self.assertEqual(e_slc.datasource.schema, params["schema"])
-            self.assertEqual(e_slc.datasource.database.name, params["database_name"])
+            assert e_slc.datasource.name == params["datasource_name"]
+            assert e_slc.datasource.schema == params["schema"]
+            assert e_slc.datasource.database.name == params["database_name"]
 
     @unittest.skip("Schema needs to be updated")
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@@ -237,17 +233,17 @@ def test_export_1_dashboard(self):
         birth_dash = self.get_dash_by_slug("births")
         self.assert_only_exported_slc_fields(birth_dash, exported_dashboards[0])
         self.assert_dash_equals(birth_dash, exported_dashboards[0])
-        self.assertEqual(
-            id_,
-            json.loads(
+        assert (
+            id_
+            == json.loads(
                 exported_dashboards[0].json_metadata, object_hook=decode_dashboards
-            )["remote_id"],
+            )["remote_id"]
         )
 
         exported_tables = json.loads(
             resp.data.decode("utf-8"), object_hook=decode_dashboards
         )["datasources"]
-        self.assertEqual(1, len(exported_tables))
+        assert 1 == len(exported_tables)
         self.assert_table_equals(self.get_table(name="birth_names"), exported_tables[0])
 
     @unittest.skip("Schema needs to be updated")
@@ -269,27 +265,28 @@ def test_export_2_dashboards(self):
         exported_dashboards = sorted(
             resp_data.get("dashboards"), key=lambda d: d.dashboard_title
         )
-        self.assertEqual(2, len(exported_dashboards))
+        assert 2 == len(exported_dashboards)
 
         birth_dash = self.get_dash_by_slug("births")
         self.assert_only_exported_slc_fields(birth_dash, exported_dashboards[0])
         self.assert_dash_equals(birth_dash, exported_dashboards[0])
-        self.assertEqual(
-            birth_dash.id, json.loads(exported_dashboards[0].json_metadata)["remote_id"]
+        assert (
+            birth_dash.id
+            == json.loads(exported_dashboards[0].json_metadata)["remote_id"]
        )
        world_health_dash = self.get_dash_by_slug("world_health")
         self.assert_only_exported_slc_fields(world_health_dash, exported_dashboards[1])
         self.assert_dash_equals(world_health_dash, exported_dashboards[1])
-        self.assertEqual(
-            world_health_dash.id,
-            json.loads(exported_dashboards[1].json_metadata)["remote_id"],
+        assert (
+            world_health_dash.id
+            == json.loads(exported_dashboards[1].json_metadata)["remote_id"]
         )
 
         exported_tables = sorted(
             resp_data.get("datasources"), key=lambda t: t.table_name
         )
-        self.assertEqual(2, len(exported_tables))
+        assert 2 == len(exported_tables)
         self.assert_table_equals(self.get_table(name="birth_names"), exported_tables[0])
         self.assert_table_equals(
             self.get_table(name="wb_health_population"), exported_tables[1]
@@ -302,11 +299,11 @@ def test_import_1_slice(self):
         )
         slc_id = import_chart(expected_slice, None, import_time=1989)
         slc = self.get_slice(slc_id)
-        self.assertEqual(slc.datasource.perm, slc.perm)
+        assert slc.datasource.perm == slc.perm
         self.assert_slice_equals(expected_slice, slc)
 
         table_id = self.get_table(name="wb_health_population").id
-        self.assertEqual(table_id, self.get_slice(slc_id).datasource_id)
+        assert table_id == self.get_slice(slc_id).datasource_id
 
     @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
     def test_import_2_slices_for_same_table(self):
@@ -323,13 +320,13 @@ def test_import_2_slices_for_same_table(self):
         imported_slc_1 = self.get_slice(slc_id_1)
         imported_slc_2 = self.get_slice(slc_id_2)
 
-        self.assertEqual(table_id, imported_slc_1.datasource_id)
+        assert table_id == imported_slc_1.datasource_id
         self.assert_slice_equals(slc_1, imported_slc_1)
-        self.assertEqual(imported_slc_1.datasource.perm, imported_slc_1.perm)
+        assert imported_slc_1.datasource.perm == imported_slc_1.perm
 
-        self.assertEqual(table_id, imported_slc_2.datasource_id)
+        assert table_id == imported_slc_2.datasource_id
         self.assert_slice_equals(slc_2, imported_slc_2)
-        self.assertEqual(imported_slc_2.datasource.perm, imported_slc_2.perm)
+        assert imported_slc_2.datasource.perm == imported_slc_2.perm
 
     def test_import_slices_override(self):
         schema = get_example_default_schema()
@@ -339,7 +336,7 @@ def test_import_slices_override(self):
         imported_slc_1 = self.get_slice(slc_1_id)
         slc_2 = self.create_slice("Import Me New", id=10005, schema=schema)
         slc_2_id = import_chart(slc_2, imported_slc_1, import_time=1990)
-        self.assertEqual(slc_1_id, slc_2_id)
+        assert slc_1_id == slc_2_id
         imported_slc_2 = self.get_slice(slc_2_id)
         self.assert_slice_equals(slc, imported_slc_2)
 
@@ -379,21 +376,18 @@ def test_import_dashboard_1_slice(self):
         self.assert_dash_equals(
             expected_dash, imported_dash, check_position=False, check_slugs=False
         )
-        self.assertEqual(
-            {
-                "remote_id": 10002,
-                "import_time": 1990,
-                "native_filter_configuration": [],
-            },
-            json.loads(imported_dash.json_metadata),
-        )
+        assert {
+            "remote_id": 10002,
+            "import_time": 1990,
+            "native_filter_configuration": [],
+        } == json.loads(imported_dash.json_metadata)
 
         expected_position = dash_with_1_slice.position
         # new slice id (auto-incremental) assigned on insert
         # id from json is used only for updating position with new id
         meta = expected_position["DASHBOARD_CHART_TYPE-10006"]["meta"]
         meta["chartId"] = imported_dash.slices[0].id
-        self.assertEqual(expected_position, imported_dash.position)
+        assert expected_position == imported_dash.position
 
     @pytest.mark.usefixtures("load_energy_table_with_slice")
     def test_import_dashboard_2_slices(self):
@@ -444,9 +438,7 @@ def test_import_dashboard_2_slices(self):
             },
"native_filter_configuration": [], } - self.assertEqual( - expected_json_metadata, json.loads(imported_dash.json_metadata) - ) + assert expected_json_metadata == json.loads(imported_dash.json_metadata) @pytest.mark.usefixtures("load_energy_table_with_slice") def test_import_override_dashboard_2_slices(self): @@ -478,7 +470,7 @@ def test_import_override_dashboard_2_slices(self): imported_dash_id_2 = import_dashboard(dash_to_import_override, import_time=1992) # override doesn't change the id - self.assertEqual(imported_dash_id_1, imported_dash_id_2) + assert imported_dash_id_1 == imported_dash_id_2 expected_dash = self.create_dashboard( "override_dashboard_new", slcs=[e_slc, b_slc, c_slc], id=10004 ) @@ -487,20 +479,17 @@ def test_import_override_dashboard_2_slices(self): self.assert_dash_equals( expected_dash, imported_dash, check_position=False, check_slugs=False ) - self.assertEqual( - { - "remote_id": 10004, - "import_time": 1992, - "native_filter_configuration": [], - }, - json.loads(imported_dash.json_metadata), - ) + assert { + "remote_id": 10004, + "import_time": 1992, + "native_filter_configuration": [], + } == json.loads(imported_dash.json_metadata) def test_import_new_dashboard_slice_reset_ownership(self): admin_user = security_manager.find_user(username="admin") - self.assertTrue(admin_user) + assert admin_user gamma_user = security_manager.find_user(username="gamma") - self.assertTrue(gamma_user) + assert gamma_user g.user = gamma_user dash_with_1_slice = self._create_dashboard_for_import(id_=10200) @@ -511,35 +500,35 @@ def test_import_new_dashboard_slice_reset_ownership(self): imported_dash_id = import_dashboard(dash_with_1_slice) imported_dash = self.get_dash(imported_dash_id) - self.assertEqual(imported_dash.created_by, gamma_user) - self.assertEqual(imported_dash.changed_by, gamma_user) - self.assertEqual(imported_dash.owners, [gamma_user]) + assert imported_dash.created_by == gamma_user + assert imported_dash.changed_by == gamma_user + assert imported_dash.owners == [gamma_user] imported_slc = imported_dash.slices[0] - self.assertEqual(imported_slc.created_by, gamma_user) - self.assertEqual(imported_slc.changed_by, gamma_user) - self.assertEqual(imported_slc.owners, [gamma_user]) + assert imported_slc.created_by == gamma_user + assert imported_slc.changed_by == gamma_user + assert imported_slc.owners == [gamma_user] @pytest.mark.skip def test_import_override_dashboard_slice_reset_ownership(self): admin_user = security_manager.find_user(username="admin") - self.assertTrue(admin_user) + assert admin_user gamma_user = security_manager.find_user(username="gamma") - self.assertTrue(gamma_user) + assert gamma_user g.user = gamma_user dash_with_1_slice = self._create_dashboard_for_import(id_=10300) imported_dash_id = import_dashboard(dash_with_1_slice) imported_dash = self.get_dash(imported_dash_id) - self.assertEqual(imported_dash.created_by, gamma_user) - self.assertEqual(imported_dash.changed_by, gamma_user) - self.assertEqual(imported_dash.owners, [gamma_user]) + assert imported_dash.created_by == gamma_user + assert imported_dash.changed_by == gamma_user + assert imported_dash.owners == [gamma_user] imported_slc = imported_dash.slices[0] - self.assertEqual(imported_slc.created_by, gamma_user) - self.assertEqual(imported_slc.changed_by, gamma_user) - self.assertEqual(imported_slc.owners, [gamma_user]) + assert imported_slc.created_by == gamma_user + assert imported_slc.changed_by == gamma_user + assert imported_slc.owners == [gamma_user] # re-import with another user 
shouldn't change the permissions g.user = admin_user @@ -547,14 +536,14 @@ def test_import_override_dashboard_slice_reset_ownership(self): imported_dash_id = import_dashboard(dash_with_1_slice) imported_dash = self.get_dash(imported_dash_id) - self.assertEqual(imported_dash.created_by, gamma_user) - self.assertEqual(imported_dash.changed_by, gamma_user) - self.assertEqual(imported_dash.owners, [gamma_user]) + assert imported_dash.created_by == gamma_user + assert imported_dash.changed_by == gamma_user + assert imported_dash.owners == [gamma_user] imported_slc = imported_dash.slices[0] - self.assertEqual(imported_slc.created_by, gamma_user) - self.assertEqual(imported_slc.changed_by, gamma_user) - self.assertEqual(imported_slc.owners, [gamma_user]) + assert imported_slc.created_by == gamma_user + assert imported_slc.changed_by == gamma_user + assert imported_slc.owners == [gamma_user] def _create_dashboard_for_import(self, id_=10100): slc = self.create_slice( @@ -600,10 +589,11 @@ def test_import_table_1_col_1_met(self): imported_id = import_dataset(table, db_id, import_time=1990) imported = self.get_table_by_id(imported_id) self.assert_table_equals(table, imported) - self.assertEqual( - {"remote_id": 10002, "import_time": 1990, "database_name": "examples"}, - json.loads(imported.params), - ) + assert { + "remote_id": 10002, + "import_time": 1990, + "database_name": "examples", + } == json.loads(imported.params) def test_import_table_2_col_2_met(self): schema = get_example_default_schema() @@ -642,7 +632,7 @@ def test_import_table_override(self): imported_over_id = import_dataset(table_over, db_id, import_time=1992) imported_over = self.get_table_by_id(imported_over_id) - self.assertEqual(imported_id, imported_over.id) + assert imported_id == imported_over.id expected_table = self.create_table( "table_override", id=10003, @@ -673,7 +663,7 @@ def test_import_table_override_identical(self): ) imported_id_copy = import_dataset(copy_table, db_id, import_time=1994) - self.assertEqual(imported_id, imported_id_copy) + assert imported_id == imported_id_copy self.assert_table_equals(copy_table, self.get_table_by_id(imported_id)) diff --git a/tests/integration_tests/log_api_tests.py b/tests/integration_tests/log_api_tests.py index fae09754aa9b5..0ed588d50be86 100644 --- a/tests/integration_tests/log_api_tests.py +++ b/tests/integration_tests/log_api_tests.py @@ -82,7 +82,7 @@ def test_not_enabled(self): arguments = {"filters": [{"col": "action", "opr": "sw", "value": "some_"}]} uri = f"api/v1/log/?q={prison.dumps(arguments)}" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 404) + assert rv.status_code == 404 def test_get_list(self): """ @@ -94,11 +94,11 @@ def test_get_list(self): arguments = {"filters": [{"col": "action", "opr": "sw", "value": "some_"}]} uri = f"api/v1/log/?q={prison.dumps(arguments)}" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(list(response["result"][0].keys()), EXPECTED_COLUMNS) - self.assertEqual(response["result"][0]["action"], "some_action") - self.assertEqual(response["result"][0]["user"], {"username": "admin"}) + assert list(response["result"][0].keys()) == EXPECTED_COLUMNS + assert response["result"][0]["action"] == "some_action" + assert response["result"][0]["user"] == {"username": "admin"} db.session.delete(log) db.session.commit() @@ -111,10 +111,10 @@ def test_get_list_not_allowed(self): self.login(GAMMA_USERNAME) uri = "api/v1/log/" rv = 
self.client.get(uri) - self.assertEqual(rv.status_code, 403) + assert rv.status_code == 403 self.login(ALPHA_USERNAME) rv = self.client.get(uri) - self.assertEqual(rv.status_code, 403) + assert rv.status_code == 403 db.session.delete(log) db.session.commit() @@ -127,12 +127,12 @@ def test_get_item(self): self.login(ADMIN_USERNAME) uri = f"api/v1/log/{log.id}" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(list(response["result"].keys()), EXPECTED_COLUMNS) - self.assertEqual(response["result"]["action"], "some_action") - self.assertEqual(response["result"]["user"], {"username": "admin"}) + assert list(response["result"].keys()) == EXPECTED_COLUMNS + assert response["result"]["action"] == "some_action" + assert response["result"]["user"] == {"username": "admin"} db.session.delete(log) db.session.commit() @@ -145,7 +145,7 @@ def test_delete_log(self): self.login(ADMIN_USERNAME) uri = f"api/v1/log/{log.id}" rv = self.client.delete(uri) - self.assertEqual(rv.status_code, 405) + assert rv.status_code == 405 db.session.delete(log) db.session.commit() @@ -160,7 +160,7 @@ def test_update_log(self): log_data = {"action": "some_action"} uri = f"api/v1/log/{log.id}" rv = self.client.put(uri, json=log_data) - self.assertEqual(rv.status_code, 405) + assert rv.status_code == 405 db.session.delete(log) db.session.commit() @@ -176,7 +176,7 @@ def test_get_recent_activity(self): uri = f"api/v1/log/recent_activity/" # noqa: F541 rv = self.client.get(uri) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 response = json.loads(rv.data.decode("utf-8")) db.session.delete(log1) @@ -184,21 +184,18 @@ def test_get_recent_activity(self): db.session.delete(dash) db.session.commit() - self.assertEqual( - response, - { - "result": [ - { - "action": "dashboard", - "item_type": "dashboard", - "item_url": "/superset/dashboard/dash_slug/", - "item_title": "dash_title", - "time": ANY, - "time_delta_humanized": ANY, - } - ] - }, - ) + assert response == { + "result": [ + { + "action": "dashboard", + "item_type": "dashboard", + "item_url": "/superset/dashboard/dash_slug/", + "item_title": "dash_title", + "time": ANY, + "time_delta_humanized": ANY, + } + ] + } def test_get_recent_activity_actions_filter(self): """ @@ -219,9 +216,9 @@ def test_get_recent_activity_actions_filter(self): db.session.delete(dash) db.session.commit() - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(len(response["result"]), 1) + assert len(response["result"]) == 1 def test_get_recent_activity_distinct_false(self): """ @@ -243,9 +240,9 @@ def test_get_recent_activity_distinct_false(self): db.session.delete(log2) db.session.delete(dash) db.session.commit() - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(len(response["result"]), 2) + assert len(response["result"]) == 2 def test_get_recent_activity_pagination(self): """ @@ -269,31 +266,28 @@ def test_get_recent_activity_pagination(self): uri = f"api/v1/log/recent_activity/?q={prison.dumps(arguments)}" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 response = json.loads(rv.data.decode("utf-8")) - self.assertEqual( - response, - { - "result": [ - { - "action": "dashboard", - "item_type": "dashboard", - "item_url": "/superset/dashboard/dash3_slug/", - 
"item_title": "dash3_title", - "time": ANY, - "time_delta_humanized": ANY, - }, - { - "action": "dashboard", - "item_type": "dashboard", - "item_url": "/superset/dashboard/dash2_slug/", - "item_title": "dash2_title", - "time": ANY, - "time_delta_humanized": ANY, - }, - ] - }, - ) + assert response == { + "result": [ + { + "action": "dashboard", + "item_type": "dashboard", + "item_url": "/superset/dashboard/dash3_slug/", + "item_title": "dash3_title", + "time": ANY, + "time_delta_humanized": ANY, + }, + { + "action": "dashboard", + "item_type": "dashboard", + "item_url": "/superset/dashboard/dash2_slug/", + "item_title": "dash2_title", + "time": ANY, + "time_delta_humanized": ANY, + }, + ] + } arguments = {"page": 1, "page_size": 2} uri = f"api/v1/log/recent_activity/?q={prison.dumps(arguments)}" @@ -307,20 +301,17 @@ def test_get_recent_activity_pagination(self): db.session.delete(dash3) db.session.commit() - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 response = json.loads(rv.data.decode("utf-8")) - self.assertEqual( - response, - { - "result": [ - { - "action": "dashboard", - "item_type": "dashboard", - "item_url": "/superset/dashboard/dash_slug/", - "item_title": "dash_title", - "time": ANY, - "time_delta_humanized": ANY, - } - ] - }, - ) + assert response == { + "result": [ + { + "action": "dashboard", + "item_type": "dashboard", + "item_url": "/superset/dashboard/dash_slug/", + "item_title": "dash_title", + "time": ANY, + "time_delta_humanized": ANY, + } + ] + } diff --git a/tests/integration_tests/logging_configurator_tests.py b/tests/integration_tests/logging_configurator_tests.py index 60e0ded692962..9b4d88c8530ec 100644 --- a/tests/integration_tests/logging_configurator_tests.py +++ b/tests/integration_tests/logging_configurator_tests.py @@ -52,4 +52,4 @@ def configure_logging(self, app_config, debug_mode): cfg.configure_logging(MagicMock(), True) logging.info("test", extra={"testattr": "foo"}) - self.assertTrue(handler.received) + assert handler.received diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py index fb22f40fb2216..dec05776a477c 100644 --- a/tests/integration_tests/model_tests.py +++ b/tests/integration_tests/model_tests.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
# isort:skip_file +import re from superset.utils.core import DatasourceType from superset.utils import json import unittest @@ -59,22 +60,22 @@ def test_database_schema_presto(self): with model.get_sqla_engine() as engine: db = make_url(engine.url).database - self.assertEqual("hive/default", db) + assert "hive/default" == db with model.get_sqla_engine(schema="core_db") as engine: db = make_url(engine.url).database - self.assertEqual("hive/core_db", db) + assert "hive/core_db" == db sqlalchemy_uri = "presto://presto.airbnb.io:8080/hive" model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri) with model.get_sqla_engine() as engine: db = make_url(engine.url).database - self.assertEqual("hive", db) + assert "hive" == db with model.get_sqla_engine(schema="core_db") as engine: db = make_url(engine.url).database - self.assertEqual("hive/core_db", db) + assert "hive/core_db" == db def test_database_schema_postgres(self): sqlalchemy_uri = "postgresql+psycopg2://postgres.airbnb.io:5439/prod" @@ -82,11 +83,11 @@ def test_database_schema_postgres(self): with model.get_sqla_engine() as engine: db = make_url(engine.url).database - self.assertEqual("prod", db) + assert "prod" == db with model.get_sqla_engine(schema="foo") as engine: db = make_url(engine.url).database - self.assertEqual("prod", db) + assert "prod" == db @unittest.skipUnless( SupersetTestCase.is_module_installed("thrift"), "thrift not installed" @@ -100,11 +101,11 @@ def test_database_schema_hive(self): with model.get_sqla_engine() as engine: db = make_url(engine.url).database - self.assertEqual("default", db) + assert "default" == db with model.get_sqla_engine(schema="core_db") as engine: db = make_url(engine.url).database - self.assertEqual("core_db", db) + assert "core_db" == db @unittest.skipUnless( SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed" @@ -115,11 +116,11 @@ def test_database_schema_mysql(self): with model.get_sqla_engine() as engine: db = make_url(engine.url).database - self.assertEqual("superset", db) + assert "superset" == db with model.get_sqla_engine(schema="staging") as engine: db = make_url(engine.url).database - self.assertEqual("staging", db) + assert "staging" == db @unittest.skipUnless( SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed" @@ -133,12 +134,12 @@ def test_database_impersonate_user(self): model.impersonate_user = True with model.get_sqla_engine() as engine: username = make_url(engine.url).username - self.assertEqual(example_user.username, username) + assert example_user.username == username model.impersonate_user = False with model.get_sqla_engine() as engine: username = make_url(engine.url).username - self.assertNotEqual(example_user.username, username) + assert example_user.username != username @mock.patch("superset.models.core.create_engine") def test_impersonate_user_presto(self, mocked_create_engine): @@ -344,20 +345,20 @@ def test_single_statement(self): if main_db.backend == "mysql": df = main_db.get_df("SELECT 1", None, None) - self.assertEqual(df.iat[0, 0], 1) + assert df.iat[0, 0] == 1 df = main_db.get_df("SELECT 1;", None, None) - self.assertEqual(df.iat[0, 0], 1) + assert df.iat[0, 0] == 1 def test_multi_statement(self): main_db = get_example_database() if main_db.backend == "mysql": df = main_db.get_df("USE superset; SELECT 1", None, None) - self.assertEqual(df.iat[0, 0], 1) + assert df.iat[0, 0] == 1 df = main_db.get_df("USE superset; SELECT ';';", None, None) - self.assertEqual(df.iat[0, 0], ";") + assert 
df.iat[0, 0] == ";"
 
     @mock.patch("superset.models.core.create_engine")
     def test_get_sqla_engine(self, mocked_create_engine):
@@ -404,20 +405,20 @@ def test_get_timestamp_expression_epoch(self):
         sqla_literal = ds_col.get_timestamp_expression(None)
         compiled = f"{sqla_literal.compile()}"
         if tbl.database.backend == "mysql":
-            self.assertEqual(compiled, "from_unixtime(ds)")
+            assert compiled == "from_unixtime(ds)"
 
         ds_col.python_date_format = "epoch_s"
         sqla_literal = ds_col.get_timestamp_expression("P1D")
         compiled = f"{sqla_literal.compile()}"
         if tbl.database.backend == "mysql":
-            self.assertEqual(compiled, "DATE(from_unixtime(ds))")
+            assert compiled == "DATE(from_unixtime(ds))"
 
         prev_ds_expr = ds_col.expression
         ds_col.expression = "DATE_ADD(ds, 1)"
         sqla_literal = ds_col.get_timestamp_expression("P1D")
         compiled = f"{sqla_literal.compile()}"
         if tbl.database.backend == "mysql":
-            self.assertEqual(compiled, "DATE(from_unixtime(DATE_ADD(ds, 1)))")
+            assert compiled == "DATE(from_unixtime(DATE_ADD(ds, 1)))"
         ds_col.expression = prev_ds_expr
 
     def query_with_expr_helper(self, is_timeseries, inner_join=True):
@@ -448,16 +449,16 @@ def query_with_expr_helper(self, is_timeseries, inner_join=True):
             series_limit=15 if inner_join and is_timeseries else None,
         )
         qr = tbl.query(query_obj)
-        self.assertEqual(qr.status, QueryStatus.SUCCESS)
+        assert qr.status == QueryStatus.SUCCESS
         sql = qr.query
-        self.assertIn(arbitrary_gby, sql)
-        self.assertIn("name", sql)
+        assert arbitrary_gby in sql
+        assert "name" in sql
         if inner_join and is_timeseries:
-            self.assertIn("JOIN", sql.upper())
+            assert "JOIN" in sql.upper()
         else:
-            self.assertNotIn("JOIN", sql.upper())
+            assert "JOIN" not in sql.upper()
         spec.allows_joins = old_inner_join
-        self.assertFalse(qr.df.empty)
+        assert not qr.df.empty
         return qr.df
 
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@@ -475,7 +476,7 @@ def canonicalize_df(df):
         name_list1 = canonicalize_df(df1).name.values.tolist()
         df2 = self.query_with_expr_helper(is_timeseries=True, inner_join=False)
         name_list2 = canonicalize_df(df1).name.values.tolist()
-        self.assertFalse(df2.empty)
+        assert not df2.empty
 
         assert name_list2 == name_list1
 
@@ -498,14 +499,14 @@ def test_sql_mutator(self):
             extras={},
         )
         sql = tbl.get_query_str(query_obj)
-        self.assertNotIn("-- COMMENT", sql)
+        assert "-- COMMENT" not in sql
 
         def mutator(*args, **kwargs):
             return "-- COMMENT\n" + args[0]
 
         app.config["SQL_QUERY_MUTATOR"] = mutator
         sql = tbl.get_query_str(query_obj)
-        self.assertIn("-- COMMENT", sql)
+        assert "-- COMMENT" in sql
 
         app.config["SQL_QUERY_MUTATOR"] = None
 
@@ -524,15 +525,15 @@ def test_sql_mutator_different_params(self):
             extras={},
         )
         sql = tbl.get_query_str(query_obj)
-        self.assertNotIn("-- COMMENT", sql)
+        assert "-- COMMENT" not in sql
 
         def mutator(sql, database=None, **kwargs):
             return "-- COMMENT\n--" + "\n" + str(database) + "\n" + sql
 
         app.config["SQL_QUERY_MUTATOR"] = mutator
         mutated_sql = tbl.get_query_str(query_obj)
-        self.assertIn("-- COMMENT", mutated_sql)
-        self.assertIn(tbl.database.name, mutated_sql)
+        assert "-- COMMENT" in mutated_sql
+        assert tbl.database.name in mutated_sql
 
         app.config["SQL_QUERY_MUTATOR"] = None
 
@@ -554,7 +555,7 @@ def test_query_with_non_existent_metrics(self):
         with self.assertRaises(Exception) as context:
             tbl.get_query_str(query_obj)
 
-        self.assertTrue("Metric 'invalid' does not exist", context.exception)
+        assert "Metric 'invalid' does not exist" in str(context.exception)
 
     def test_query_label_without_group_by(self):
         tbl = self.get_table(name="birth_names")
@@ -577,7 +578,7 @@ 
def test_query_label_without_group_by(self):
             )
 
         sql = tbl.get_query_str(query_obj)
-        self.assertRegex(sql, r'name AS ["`]?Given Name["`]?')
+        assert re.search(r'name AS ["`]?Given Name["`]?', sql)
 
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_data_for_slices_with_no_query_context(self):
diff --git a/tests/integration_tests/queries/api_tests.py b/tests/integration_tests/queries/api_tests.py
index 2819c23b4141d..92a1f47fcba8e 100644
--- a/tests/integration_tests/queries/api_tests.py
+++ b/tests/integration_tests/queries/api_tests.py
@@ -138,7 +138,7 @@ def test_get_query(self):
         self.login(ADMIN_USERNAME)
         uri = f"api/v1/query/{query.id}"
         rv = self.client.get(uri)
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
 
         expected_result = {
             "database": {"id": example_db.id},
@@ -163,7 +163,7 @@ def test_get_query(self):
             "tracking_url": None,
         }
         data = json.loads(rv.data.decode("utf-8"))
-        self.assertIn("changed_on", data["result"])
+        assert "changed_on" in data["result"]
         for key, value in data["result"].items():
             # We can't assert timestamp
             if key not in (
@@ -173,7 +173,7 @@ def test_get_query(self):
                 "start_time",
                 "id",
             ):
-                self.assertEqual(value, expected_result[key])
+                assert value == expected_result[key]
         # rollback changes
         db.session.delete(query)
         db.session.commit()
@@ -189,7 +189,7 @@ def test_get_query_not_found(self):
         self.login(ADMIN_USERNAME)
         uri = f"api/v1/query/{max_id + 1}"
         rv = self.client.get(uri)
-        self.assertEqual(rv.status_code, 404)
+        assert rv.status_code == 404
 
         db.session.delete(query)
         db.session.commit()
@@ -222,30 +222,30 @@ def test_get_query_no_data_access(self):
         self.login(username="gamma_1", password="password")
         uri = f"api/v1/query/{query_gamma2.id}"
         rv = self.client.get(uri)
-        self.assertEqual(rv.status_code, 404)
+        assert rv.status_code == 404
         uri = f"api/v1/query/{query_gamma1.id}"
         rv = self.client.get(uri)
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
 
         # Gamma2 user, only sees their own queries
         self.logout()
         self.login(username="gamma_2", password="password")
         uri = f"api/v1/query/{query_gamma1.id}"
         rv = self.client.get(uri)
-        self.assertEqual(rv.status_code, 404)
+        assert rv.status_code == 404
         uri = f"api/v1/query/{query_gamma2.id}"
         rv = self.client.get(uri)
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
 
         # Admin's have the "all query access" permission
         self.logout()
         self.login(ADMIN_USERNAME)
         uri = f"api/v1/query/{query_gamma1.id}"
         rv = self.client.get(uri)
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         uri = f"api/v1/query/{query_gamma2.id}"
         rv = self.client.get(uri)
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
 
         # rollback changes
         db.session.delete(query_gamma1)
@@ -262,7 +262,7 @@ def test_get_list_query(self):
         self.login(ADMIN_USERNAME)
         uri = "api/v1/query/"
         rv = self.client.get(uri)
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         data = json.loads(rv.data.decode("utf-8"))
         assert data["count"] == QUERIES_FIXTURE_COUNT
         # check expected columns
@@ -433,11 +433,11 @@ def test_get_updated_since(self):
         timestamp = datetime.timestamp(now - timedelta(days=2)) * 1000
         uri = f"api/v1/query/updated_since?q={prison.dumps({'last_updated_ms': timestamp})}"
         rv = self.client.get(uri)
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
 
         expected_result = updated_query.to_dict()
         data = json.loads(rv.data.decode("utf-8"))
-        self.assertEqual(len(data["result"]), 1)
+        assert 
len(data["result"]) == 1
         for key, value in data["result"][0].items():
             # We can't assert timestamp
             if key not in (
@@ -447,7 +447,7 @@ def test_get_updated_since(self):
                 "start_time",
                 "id",
             ):
-                self.assertEqual(value, expected_result[key])
+                assert value == expected_result[key]
         # rollback changes
         db.session.delete(old_query)
         db.session.delete(updated_query)
diff --git a/tests/integration_tests/queries/saved_queries/api_tests.py b/tests/integration_tests/queries/saved_queries/api_tests.py
index 4ce0a79dac9a4..aa2c931104e4e 100644
--- a/tests/integration_tests/queries/saved_queries/api_tests.py
+++ b/tests/integration_tests/queries/saved_queries/api_tests.py
@@ -467,26 +467,26 @@ def test_get_saved_queries_tag_filters(self):
         # Filter by tag ID
         filter_params = get_filter_params("saved_query_tag_id", tag.id)
         response_by_id = self.get_list("saved_query", filter_params)
-        self.assertEqual(response_by_id.status_code, 200)
+        assert response_by_id.status_code == 200
         data_by_id = json.loads(response_by_id.data.decode("utf-8"))
 
         # Filter by tag name
         filter_params = get_filter_params("saved_query_tags", tag.name)
         response_by_name = self.get_list("saved_query", filter_params)
-        self.assertEqual(response_by_name.status_code, 200)
+        assert response_by_name.status_code == 200
         data_by_name = json.loads(response_by_name.data.decode("utf-8"))
 
         # Compare results
-        self.assertEqual(
-            data_by_id["count"],
-            data_by_name["count"],
-            len(expected_saved_queries),
-        )
-        self.assertEqual(
-            set(query["id"] for query in data_by_id["result"]),
-            set(query["id"] for query in data_by_name["result"]),
-            set(query.id for query in expected_saved_queries),
+        assert (
+            data_by_id["count"]
+            == data_by_name["count"]
+            == len(expected_saved_queries)
         )
+        assert (
+            {query["id"] for query in data_by_id["result"]}
+            == {query["id"] for query in data_by_name["result"]}
+            == {query.id for query in expected_saved_queries}
+        )
 
     @pytest.mark.usefixtures("create_saved_queries")
     def test_get_saved_query_favorite_filter(self):
diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py
index 9df1b3a8723a1..4822b690ed84c 100644
--- a/tests/integration_tests/query_context_tests.py
+++ b/tests/integration_tests/query_context_tests.py
@@ -70,15 +70,15 @@ def test_schema_deserialization(self):
         payload = get_query_context("birth_names", add_postprocessing_operations=True)
         query_context = ChartDataQueryContextSchema().load(payload)
-        self.assertEqual(len(query_context.queries), len(payload["queries"]))
+        assert len(query_context.queries) == len(payload["queries"])
 
         for query_idx, query in enumerate(query_context.queries):
             payload_query = payload["queries"][query_idx]
 
             # check basic properties
-            self.assertEqual(query.extras, payload_query["extras"])
-            self.assertEqual(query.filter, payload_query["filters"])
-            self.assertEqual(query.columns, payload_query["columns"])
+            assert query.extras == payload_query["extras"]
+            assert query.filter == payload_query["filters"]
+            assert query.columns == payload_query["columns"]
 
             # metrics are mutated during creation
             for metric_idx, metric in enumerate(query.metrics):
@@ -88,16 +88,16 @@ def test_schema_deserialization(self):
                     if "expressionType" in payload_metric
                     else payload_metric["label"]
                 )
-                self.assertEqual(metric, payload_metric)
+                assert metric == payload_metric
 
-            self.assertEqual(query.orderby, payload_query["orderby"])
-            self.assertEqual(query.time_range, payload_query["time_range"])
+            assert query.orderby == payload_query["orderby"]
+            assert query.time_range == 
payload_query["time_range"] # check post processing operation properties for post_proc_idx, post_proc in enumerate(query.post_processing): payload_post_proc = payload_query["post_processing"][post_proc_idx] - self.assertEqual(post_proc["operation"], payload_post_proc["operation"]) - self.assertEqual(post_proc["options"], payload_post_proc["options"]) + assert post_proc["operation"] == payload_post_proc["operation"] + assert post_proc["options"] == payload_post_proc["options"] @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_cache(self): @@ -128,12 +128,12 @@ def test_cache(self): rehydrated_qo = rehydrated_qc.queries[0] rehydrated_query_cache_key = rehydrated_qc.query_cache_key(rehydrated_qo) - self.assertEqual(rehydrated_qc.datasource, query_context.datasource) - self.assertEqual(len(rehydrated_qc.queries), 1) - self.assertEqual(query_cache_key, rehydrated_query_cache_key) - self.assertEqual(rehydrated_qc.result_type, query_context.result_type) - self.assertEqual(rehydrated_qc.result_format, query_context.result_format) - self.assertFalse(rehydrated_qc.force) + assert rehydrated_qc.datasource == query_context.datasource + assert len(rehydrated_qc.queries) == 1 + assert query_cache_key == rehydrated_query_cache_key + assert rehydrated_qc.result_type == query_context.result_type + assert rehydrated_qc.result_format == query_context.result_format + assert not rehydrated_qc.force def test_query_cache_key_changes_when_datasource_is_updated(self): payload = get_query_context("birth_names") @@ -164,7 +164,7 @@ def test_query_cache_key_changes_when_datasource_is_updated(self): cache_key_new = query_context.query_cache_key(query_object) # the new cache_key should be different due to updated datasource - self.assertNotEqual(cache_key_original, cache_key_new) + assert cache_key_original != cache_key_new def test_query_cache_key_changes_when_metric_is_updated(self): payload = get_query_context("birth_names") @@ -198,7 +198,7 @@ def test_query_cache_key_changes_when_metric_is_updated(self): db.session.commit() # the new cache_key should be different due to updated datasource - self.assertNotEqual(cache_key_original, cache_key_new) + assert cache_key_original != cache_key_new def test_query_cache_key_does_not_change_for_non_existent_or_null(self): payload = get_query_context("birth_names", add_postprocessing_operations=True) @@ -228,14 +228,14 @@ def test_query_cache_key_changes_when_post_processing_is_updated(self): query_context = ChartDataQueryContextSchema().load(payload) query_object = query_context.queries[0] cache_key = query_context.query_cache_key(query_object) - self.assertEqual(cache_key_original, cache_key) + assert cache_key_original == cache_key # ensure query without post processing operation is different payload["queries"][0].pop("post_processing") query_context = ChartDataQueryContextSchema().load(payload) query_object = query_context.queries[0] cache_key = query_context.query_cache_key(query_object) - self.assertNotEqual(cache_key_original, cache_key) + assert cache_key_original != cache_key def test_query_cache_key_changes_when_time_offsets_is_updated(self): payload = get_query_context("birth_names", add_time_offsets=True) @@ -248,7 +248,7 @@ def test_query_cache_key_changes_when_time_offsets_is_updated(self): query_context = ChartDataQueryContextSchema().load(payload) query_object = query_context.queries[0] cache_key = query_context.query_cache_key(query_object) - self.assertNotEqual(cache_key_original, cache_key) + assert cache_key_original != cache_key 
def test_handle_metrics_field(self): """ @@ -265,7 +265,7 @@ def test_handle_metrics_field(self): payload["queries"][0]["metrics"] = ["sum__num", {"label": "abc"}, adhoc_metric] query_context = ChartDataQueryContextSchema().load(payload) query_object = query_context.queries[0] - self.assertEqual(query_object.metrics, ["sum__num", "abc", adhoc_metric]) + assert query_object.metrics == ["sum__num", "abc", adhoc_metric] def test_convert_deprecated_fields(self): """ @@ -280,12 +280,12 @@ def test_convert_deprecated_fields(self): payload["queries"][0]["granularity_sqla"] = "timecol" payload["queries"][0]["having_filters"] = [{"col": "a", "op": "==", "val": "b"}] query_context = ChartDataQueryContextSchema().load(payload) - self.assertEqual(len(query_context.queries), 1) + assert len(query_context.queries) == 1 query_object = query_context.queries[0] - self.assertEqual(query_object.granularity, "timecol") - self.assertEqual(query_object.columns, columns) - self.assertEqual(query_object.series_limit, 99) - self.assertEqual(query_object.series_limit_metric, "sum__num") + assert query_object.granularity == "timecol" + assert query_object.columns == columns + assert query_object.series_limit == 99 + assert query_object.series_limit_metric == "sum__num" @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_csv_response_format(self): @@ -297,10 +297,10 @@ def test_csv_response_format(self): payload["queries"][0]["row_limit"] = 10 query_context: QueryContext = ChartDataQueryContextSchema().load(payload) responses = query_context.get_payload() - self.assertEqual(len(responses), 1) + assert len(responses) == 1 data = responses["queries"][0]["data"] - self.assertIn("name,sum__num\n", data) - self.assertEqual(len(data.split("\n")), 12) + assert "name,sum__num\n" in data + assert len(data.split("\n")) == 12 def test_sql_injection_via_groupby(self): """ @@ -352,11 +352,11 @@ def test_samples_response_type(self): payload["queries"][0]["row_limit"] = 5 query_context = ChartDataQueryContextSchema().load(payload) responses = query_context.get_payload() - self.assertEqual(len(responses), 1) + assert len(responses) == 1 data = responses["queries"][0]["data"] - self.assertIsInstance(data, list) - self.assertEqual(len(data), 5) - self.assertNotIn("sum__num", data[0]) + assert isinstance(data, list) + assert len(data) == 5 + assert "sum__num" not in data[0] @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_query_response_type(self): @@ -489,7 +489,7 @@ def test_query_object_unknown_fields(self): query_context = ChartDataQueryContextSchema().load(payload) responses = query_context.get_payload() new_cache_key = responses["queries"][0]["cache_key"] - self.assertEqual(orig_cache_key, new_cache_key) + assert orig_cache_key == new_cache_key @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_time_offsets_in_query_object(self): @@ -505,21 +505,18 @@ def test_time_offsets_in_query_object(self): payload["queries"][0]["time_range"] = "1990 : 1991" query_context = ChartDataQueryContextSchema().load(payload) responses = query_context.get_payload() - self.assertEqual( - responses["queries"][0]["colnames"], - [ - "__timestamp", - "name", - "sum__num", - "sum__num__1 year ago", - "sum__num__1 year later", - ], - ) + assert responses["queries"][0]["colnames"] == [ + "__timestamp", + "name", + "sum__num", + "sum__num__1 year ago", + "sum__num__1 year later", + ] sqls = [ sql for sql in responses["queries"][0]["query"].split(";") if sql.strip() ] - 
self.assertEqual(len(sqls), 3) + assert len(sqls) == 3 # 1 year ago assert re.search(r"1989-01-01.+1990-01-01", sqls[1], re.S) assert re.search(r"1990-01-01.+1991-01-01", sqls[1], re.S) @@ -560,9 +557,9 @@ def test_processing_time_offsets_cache(self): cache_keys = rv["cache_keys"] cache_keys__1_year_ago = cache_keys[0] cache_keys__1_year_later = cache_keys[1] - self.assertIsNotNone(cache_keys__1_year_ago) - self.assertIsNotNone(cache_keys__1_year_later) - self.assertNotEqual(cache_keys__1_year_ago, cache_keys__1_year_later) + assert cache_keys__1_year_ago is not None + assert cache_keys__1_year_later is not None + assert cache_keys__1_year_ago != cache_keys__1_year_later # swap offsets payload["queries"][0]["time_offsets"] = ["1 year later", "1 year ago"] @@ -570,8 +567,8 @@ def test_processing_time_offsets_cache(self): query_object = query_context.queries[0] rv = query_context.processing_time_offsets(df.copy(), query_object) cache_keys = rv["cache_keys"] - self.assertEqual(cache_keys__1_year_ago, cache_keys[1]) - self.assertEqual(cache_keys__1_year_later, cache_keys[0]) + assert cache_keys__1_year_ago == cache_keys[1] + assert cache_keys__1_year_later == cache_keys[0] # remove all offsets payload["queries"][0]["time_offsets"] = [] @@ -582,9 +579,9 @@ def test_processing_time_offsets_cache(self): query_object, ) - self.assertEqual(rv["df"].shape, df.shape) - self.assertEqual(rv["queries"], []) - self.assertEqual(rv["cache_keys"], []) + assert rv["df"].shape == df.shape + assert rv["queries"] == [] + assert rv["cache_keys"] == [] @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_time_offsets_sql(self): @@ -732,7 +729,7 @@ def test_time_offsets_in_query_object_no_limit(self, query_result_mock): row_limit_pattern_with_config_value = r"LIMIT " + re.escape( str(row_limit_value) ) - self.assertEqual(len(sqls), 2) + assert len(sqls) == 2 # 1 year ago assert re.search(r"1989-01-01.+1990-01-01", sqls[0], re.S) assert not re.search(r"LIMIT 100", sqls[0], re.S) diff --git a/tests/integration_tests/reports/api_tests.py b/tests/integration_tests/reports/api_tests.py index 180ac81e2a5fe..7664dc4584e0e 100644 --- a/tests/integration_tests/reports/api_tests.py +++ b/tests/integration_tests/reports/api_tests.py @@ -1673,7 +1673,7 @@ def test_update_report_not_owned(self): } uri = f"api/v1/report/{report_schedule.id}" rv = self.put_assert_metric(uri, report_schedule_data, "put") - self.assertEqual(rv.status_code, 403) + assert rv.status_code == 403 @pytest.mark.usefixtures("create_report_schedules") def test_update_report_preserve_ownership(self): @@ -1819,7 +1819,7 @@ def test_delete_report_not_owned(self): self.login(username="alpha2", password="password") uri = f"api/v1/report/{report_schedule.id}" rv = self.delete_assert_metric(uri, "delete") - self.assertEqual(rv.status_code, 403) + assert rv.status_code == 403 @pytest.mark.usefixtures("create_report_schedules") def test_bulk_delete_report_schedule(self): @@ -1876,7 +1876,7 @@ def test_bulk_delete_report_not_owned(self): self.login(username="alpha2", password="password") uri = f"api/v1/report/?q={prison.dumps(report_schedules_ids)}" rv = self.delete_assert_metric(uri, "bulk_delete") - self.assertEqual(rv.status_code, 403) + assert rv.status_code == 403 @pytest.mark.usefixtures("create_report_schedules") def test_get_list_report_schedule_logs(self): diff --git a/tests/integration_tests/result_set_tests.py b/tests/integration_tests/result_set_tests.py index e58e16f07c057..fcdbd19d5a763 100644 --- 
a/tests/integration_tests/result_set_tests.py +++ b/tests/integration_tests/result_set_tests.py @@ -28,74 +28,77 @@ class TestSupersetResultSet(SupersetTestCase): def test_dedup(self): - self.assertEqual(dedup(["foo", "bar"]), ["foo", "bar"]) - self.assertEqual( - dedup(["foo", "bar", "foo", "bar", "Foo"]), - ["foo", "bar", "foo__1", "bar__1", "Foo"], - ) - self.assertEqual( - dedup(["foo", "bar", "bar", "bar", "Bar"]), - ["foo", "bar", "bar__1", "bar__2", "Bar"], - ) - self.assertEqual( - dedup(["foo", "bar", "bar", "bar", "Bar"], case_sensitive=False), - ["foo", "bar", "bar__1", "bar__2", "Bar__3"], - ) + assert dedup(["foo", "bar"]) == ["foo", "bar"] + assert dedup(["foo", "bar", "foo", "bar", "Foo"]) == [ + "foo", + "bar", + "foo__1", + "bar__1", + "Foo", + ] + assert dedup(["foo", "bar", "bar", "bar", "Bar"]) == [ + "foo", + "bar", + "bar__1", + "bar__2", + "Bar", + ] + assert dedup(["foo", "bar", "bar", "bar", "Bar"], case_sensitive=False) == [ + "foo", + "bar", + "bar__1", + "bar__2", + "Bar__3", + ] def test_get_columns_basic(self): data = [("a1", "b1", "c1"), ("a2", "b2", "c2")] cursor_descr = (("a", "string"), ("b", "string"), ("c", "string")) results = SupersetResultSet(data, cursor_descr, BaseEngineSpec) - self.assertEqual( - results.columns, - [ - { - "is_dttm": False, - "type": "STRING", - "type_generic": GenericDataType.STRING, - "column_name": "a", - "name": "a", - }, - { - "is_dttm": False, - "type": "STRING", - "type_generic": GenericDataType.STRING, - "column_name": "b", - "name": "b", - }, - { - "is_dttm": False, - "type": "STRING", - "type_generic": GenericDataType.STRING, - "column_name": "c", - "name": "c", - }, - ], - ) + assert results.columns == [ + { + "is_dttm": False, + "type": "STRING", + "type_generic": GenericDataType.STRING, + "column_name": "a", + "name": "a", + }, + { + "is_dttm": False, + "type": "STRING", + "type_generic": GenericDataType.STRING, + "column_name": "b", + "name": "b", + }, + { + "is_dttm": False, + "type": "STRING", + "type_generic": GenericDataType.STRING, + "column_name": "c", + "name": "c", + }, + ] def test_get_columns_with_int(self): data = [("a1", 1), ("a2", 2)] cursor_descr = (("a", "string"), ("b", "int")) results = SupersetResultSet(data, cursor_descr, BaseEngineSpec) - self.assertEqual( - results.columns, - [ - { - "is_dttm": False, - "type": "STRING", - "type_generic": GenericDataType.STRING, - "column_name": "a", - "name": "a", - }, - { - "is_dttm": False, - "type": "INT", - "type_generic": GenericDataType.NUMERIC, - "column_name": "b", - "name": "b", - }, - ], - ) + assert results.columns == [ + { + "is_dttm": False, + "type": "STRING", + "type_generic": GenericDataType.STRING, + "column_name": "a", + "name": "a", + }, + { + "is_dttm": False, + "type": "INT", + "type_generic": GenericDataType.NUMERIC, + "column_name": "b", + "name": "b", + }, + ] def test_get_columns_type_inference(self): data = [ @@ -104,72 +107,69 @@ def test_get_columns_type_inference(self): ] cursor_descr = (("a", None), ("b", None), ("c", None), ("d", None), ("e", None)) results = SupersetResultSet(data, cursor_descr, BaseEngineSpec) - self.assertEqual( - results.columns, - [ - { - "is_dttm": False, - "type": "FLOAT", - "type_generic": GenericDataType.NUMERIC, - "column_name": "a", - "name": "a", - }, - { - "is_dttm": False, - "type": "INT", - "type_generic": GenericDataType.NUMERIC, - "column_name": "b", - "name": "b", - }, - { - "is_dttm": False, - "type": "STRING", - "type_generic": GenericDataType.STRING, - "column_name": "c", - "name": "c", - }, - { 
- "is_dttm": True,
-                "type": "DATETIME",
-                "type_generic": GenericDataType.TEMPORAL,
-                "column_name": "d",
-                "name": "d",
-            },
-            {
-                "is_dttm": False,
-                "type": "BOOL",
-                "type_generic": GenericDataType.BOOLEAN,
-                "column_name": "e",
-                "name": "e",
-            },
-            ],
-        )
+        assert results.columns == [
+            {
+                "is_dttm": False,
+                "type": "FLOAT",
+                "type_generic": GenericDataType.NUMERIC,
+                "column_name": "a",
+                "name": "a",
+            },
+            {
+                "is_dttm": False,
+                "type": "INT",
+                "type_generic": GenericDataType.NUMERIC,
+                "column_name": "b",
+                "name": "b",
+            },
+            {
+                "is_dttm": False,
+                "type": "STRING",
+                "type_generic": GenericDataType.STRING,
+                "column_name": "c",
+                "name": "c",
+            },
+            {
+                "is_dttm": True,
+                "type": "DATETIME",
+                "type_generic": GenericDataType.TEMPORAL,
+                "column_name": "d",
+                "name": "d",
+            },
+            {
+                "is_dttm": False,
+                "type": "BOOL",
+                "type_generic": GenericDataType.BOOLEAN,
+                "column_name": "e",
+                "name": "e",
+            },
+        ]
 
     def test_is_date(self):
         data = [("a", 1), ("a", 2)]
         cursor_descr = (("a", "string"), ("a", "string"))
         results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
-        self.assertEqual(results.is_temporal("DATE"), True)
-        self.assertEqual(results.is_temporal("DATETIME"), True)
-        self.assertEqual(results.is_temporal("TIME"), True)
-        self.assertEqual(results.is_temporal("TIMESTAMP"), True)
-        self.assertEqual(results.is_temporal("STRING"), False)
-        self.assertEqual(results.is_temporal(""), False)
-        self.assertEqual(results.is_temporal(None), False)
+        assert results.is_temporal("DATE") is True
+        assert results.is_temporal("DATETIME") is True
+        assert results.is_temporal("TIME") is True
+        assert results.is_temporal("TIMESTAMP") is True
+        assert results.is_temporal("STRING") is False
+        assert results.is_temporal("") is False
+        assert results.is_temporal(None) is False
 
     def test_dedup_with_data(self):
         data = [("a", 1), ("a", 2)]
         cursor_descr = (("a", "string"), ("a", "string"))
         results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
         column_names = [col["column_name"] for col in results.columns]
-        self.assertListEqual(column_names, ["a", "a__1"])
+        assert column_names == ["a", "a__1"]
 
     def test_int64_with_missing_data(self):
         data = [(None,), (1239162456494753670,), (None,), (None,), (None,), (None,)]
         cursor_descr = [("user_id", "bigint", None, None, None, None, True)]
         results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
-        self.assertEqual(results.columns[0]["type"], "BIGINT")
-        self.assertEqual(results.columns[0]["type_generic"], GenericDataType.NUMERIC)
+        assert results.columns[0]["type"] == "BIGINT"
+        assert results.columns[0]["type_generic"] == GenericDataType.NUMERIC
 
     def test_data_as_list_of_lists(self):
         data = [[1, "a"], [2, "b"]]
@@ -179,29 +179,26 @@ def test_data_as_list_of_lists(self):
         ]
         results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
         df = results.to_pandas_df()
-        self.assertEqual(
-            df_to_records(df),
-            [{"user_id": 1, "username": "a"}, {"user_id": 2, "username": "b"}],
-        )
+        assert df_to_records(df) == [
+            {"user_id": 1, "username": "a"},
+            {"user_id": 2, "username": "b"},
+        ]
 
     def test_nullable_bool(self):
         data = [(None,), (True,), (None,), (None,), (None,), (None,)]
         cursor_descr = [("is_test", "bool", None, None, None, None, True)]
         results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
-        self.assertEqual(results.columns[0]["type"], "BOOL")
-        self.assertEqual(results.columns[0]["type_generic"], GenericDataType.BOOLEAN)
+        assert results.columns[0]["type"] == "BOOL"
+        assert 
results.columns[0]["type_generic"] == GenericDataType.BOOLEAN df = results.to_pandas_df() - self.assertEqual( - df_to_records(df), - [ - {"is_test": None}, - {"is_test": True}, - {"is_test": None}, - {"is_test": None}, - {"is_test": None}, - {"is_test": None}, - ], - ) + assert df_to_records(df) == [ + {"is_test": None}, + {"is_test": True}, + {"is_test": None}, + {"is_test": None}, + {"is_test": None}, + {"is_test": None}, + ] def test_nested_types(self): data = [ @@ -220,32 +217,29 @@ def test_nested_types(self): ] cursor_descr = [("id",), ("dict_arr",), ("num_arr",), ("map_col",)] results = SupersetResultSet(data, cursor_descr, BaseEngineSpec) - self.assertEqual(results.columns[0]["type"], "INT") - self.assertEqual(results.columns[0]["type_generic"], GenericDataType.NUMERIC) - self.assertEqual(results.columns[1]["type"], "STRING") - self.assertEqual(results.columns[1]["type_generic"], GenericDataType.STRING) - self.assertEqual(results.columns[2]["type"], "STRING") - self.assertEqual(results.columns[2]["type_generic"], GenericDataType.STRING) - self.assertEqual(results.columns[3]["type"], "STRING") - self.assertEqual(results.columns[3]["type_generic"], GenericDataType.STRING) + assert results.columns[0]["type"] == "INT" + assert results.columns[0]["type_generic"] == GenericDataType.NUMERIC + assert results.columns[1]["type"] == "STRING" + assert results.columns[1]["type_generic"] == GenericDataType.STRING + assert results.columns[2]["type"] == "STRING" + assert results.columns[2]["type_generic"] == GenericDataType.STRING + assert results.columns[3]["type"] == "STRING" + assert results.columns[3]["type_generic"] == GenericDataType.STRING df = results.to_pandas_df() - self.assertEqual( - df_to_records(df), - [ - { - "id": 4, - "dict_arr": '[{"table_name": "unicode_test", "database_id": 1}]', - "num_arr": "[1, 2, 3]", - "map_col": "{'chart_name': 'scatter'}", - }, - { - "id": 3, - "dict_arr": '[{"table_name": "birth_names", "database_id": 1}]', - "num_arr": "[4, 5, 6]", - "map_col": "{'chart_name': 'plot'}", - }, - ], - ) + assert df_to_records(df) == [ + { + "id": 4, + "dict_arr": '[{"table_name": "unicode_test", "database_id": 1}]', + "num_arr": "[1, 2, 3]", + "map_col": "{'chart_name': 'scatter'}", + }, + { + "id": 3, + "dict_arr": '[{"table_name": "birth_names", "database_id": 1}]', + "num_arr": "[4, 5, 6]", + "map_col": "{'chart_name': 'plot'}", + }, + ] def test_single_column_multidim_nested_types(self): data = [ @@ -270,35 +264,30 @@ def test_single_column_multidim_nested_types(self): ] cursor_descr = [("metadata",)] results = SupersetResultSet(data, cursor_descr, BaseEngineSpec) - self.assertEqual(results.columns[0]["type"], "STRING") - self.assertEqual(results.columns[0]["type_generic"], GenericDataType.STRING) + assert results.columns[0]["type"] == "STRING" + assert results.columns[0]["type_generic"] == GenericDataType.STRING df = results.to_pandas_df() - self.assertEqual( - df_to_records(df), - [ - { - "metadata": '["test", [["foo", 123456, [[["test"], 3432546, 7657658766], [["fake"], 656756765, 324324324324]]]], ["test2", 43, 765765765], null, null]' - } - ], - ) + assert df_to_records(df) == [ + { + "metadata": '["test", [["foo", 123456, [[["test"], 3432546, 7657658766], [["fake"], 656756765, 324324324324]]]], ["test2", 43, 765765765], null, null]' + } + ] def test_nested_list_types(self): data = [([{"TestKey": [123456, "foo"]}],)] cursor_descr = [("metadata",)] results = SupersetResultSet(data, cursor_descr, BaseEngineSpec) - self.assertEqual(results.columns[0]["type"], 
"STRING") - self.assertEqual(results.columns[0]["type_generic"], GenericDataType.STRING) + assert results.columns[0]["type"] == "STRING" + assert results.columns[0]["type_generic"] == GenericDataType.STRING df = results.to_pandas_df() - self.assertEqual( - df_to_records(df), [{"metadata": '[{"TestKey": [123456, "foo"]}]'}] - ) + assert df_to_records(df) == [{"metadata": '[{"TestKey": [123456, "foo"]}]'}] def test_empty_datetime(self): data = [(None,)] cursor_descr = [("ds", "timestamp", None, None, None, None, True)] results = SupersetResultSet(data, cursor_descr, BaseEngineSpec) - self.assertEqual(results.columns[0]["type"], "TIMESTAMP") - self.assertEqual(results.columns[0]["type_generic"], GenericDataType.TEMPORAL) + assert results.columns[0]["type"] == "TIMESTAMP" + assert results.columns[0]["type_generic"] == GenericDataType.TEMPORAL def test_no_type_coercion(self): data = [("a", 1), ("b", 2)] @@ -307,10 +296,10 @@ def test_no_type_coercion(self): ("two", "int", None, None, None, None, True), ] results = SupersetResultSet(data, cursor_descr, BaseEngineSpec) - self.assertEqual(results.columns[0]["type"], "VARCHAR") - self.assertEqual(results.columns[0]["type_generic"], GenericDataType.STRING) - self.assertEqual(results.columns[1]["type"], "INT") - self.assertEqual(results.columns[1]["type_generic"], GenericDataType.NUMERIC) + assert results.columns[0]["type"] == "VARCHAR" + assert results.columns[0]["type_generic"] == GenericDataType.STRING + assert results.columns[1]["type"] == "INT" + assert results.columns[1]["type_generic"] == GenericDataType.NUMERIC def test_empty_data(self): data = [] @@ -319,4 +308,4 @@ def test_empty_data(self): ("emptytwo", "int", None, None, None, None, True), ] results = SupersetResultSet(data, cursor_descr, BaseEngineSpec) - self.assertEqual(results.columns, []) + assert results.columns == [] diff --git a/tests/integration_tests/security/api_tests.py b/tests/integration_tests/security/api_tests.py index 67aecd73b0912..49c2064e5db78 100644 --- a/tests/integration_tests/security/api_tests.py +++ b/tests/integration_tests/security/api_tests.py @@ -43,7 +43,7 @@ def _assert_get_csrf_token(self): response = self.client.get(uri) self.assert200(response) data = json.loads(response.data.decode("utf-8")) - self.assertEqual(generate_csrf(), data["result"]) + assert generate_csrf() == data["result"] def test_get_csrf_token(self): """ @@ -120,8 +120,8 @@ def test_post_guest_token_authorized(self): audience=get_url_host(), algorithms=["HS256"], ) - self.assertEqual(user, decoded_token["user"]) - self.assertEqual(resource, decoded_token["resources"][0]) + assert user == decoded_token["user"] + assert resource == decoded_token["resources"][0] @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_post_guest_token_bad_resources(self): diff --git a/tests/integration_tests/security/guest_token_security_tests.py b/tests/integration_tests/security/guest_token_security_tests.py index b8bab0390949f..5dcfd1357745f 100644 --- a/tests/integration_tests/security/guest_token_security_tests.py +++ b/tests/integration_tests/security/guest_token_security_tests.py @@ -55,15 +55,15 @@ def authorized_guest(self): def test_is_guest_user__regular_user(self): is_guest = security_manager.is_guest_user(security_manager.find_user("admin")) - self.assertFalse(is_guest) + assert not is_guest def test_is_guest_user__anonymous(self): is_guest = security_manager.is_guest_user(security_manager.get_anonymous_user()) - self.assertFalse(is_guest) + assert not is_guest def 
test_is_guest_user__guest_user(self): is_guest = security_manager.is_guest_user(self.authorized_guest()) - self.assertTrue(is_guest) + assert is_guest @patch.dict( "superset.extensions.feature_flag_manager._feature_flags", @@ -71,34 +71,34 @@ def test_is_guest_user__guest_user(self): ) def test_is_guest_user__flag_off(self): is_guest = security_manager.is_guest_user(self.authorized_guest()) - self.assertFalse(is_guest) + assert not is_guest def test_get_guest_user__regular_user(self): g.user = security_manager.find_user("admin") guest_user = security_manager.get_current_guest_user_if_guest() - self.assertIsNone(guest_user) + assert guest_user is None def test_get_guest_user__anonymous_user(self): g.user = security_manager.get_anonymous_user() guest_user = security_manager.get_current_guest_user_if_guest() - self.assertIsNone(guest_user) + assert guest_user is None def test_get_guest_user__guest_user(self): g.user = self.authorized_guest() guest_user = security_manager.get_current_guest_user_if_guest() - self.assertEqual(guest_user, g.user) + assert guest_user == g.user def test_get_guest_user_roles_explicit(self): guest = self.authorized_guest() roles = security_manager.get_user_roles(guest) - self.assertEqual(guest.roles, roles) + assert guest.roles == roles def test_get_guest_user_roles_implicit(self): guest = self.authorized_guest() g.user = guest roles = security_manager.get_user_roles() - self.assertEqual(guest.roles, roles) + assert guest.roles == roles @patch.dict( @@ -142,17 +142,17 @@ def setUp(self) -> None: def test_has_guest_access__regular_user(self): g.user = security_manager.find_user("admin") has_guest_access = security_manager.has_guest_access(self.dash) - self.assertFalse(has_guest_access) + assert not has_guest_access def test_has_guest_access__anonymous_user(self): g.user = security_manager.get_anonymous_user() has_guest_access = security_manager.has_guest_access(self.dash) - self.assertFalse(has_guest_access) + assert not has_guest_access def test_has_guest_access__authorized_guest_user(self): g.user = self.authorized_guest has_guest_access = security_manager.has_guest_access(self.dash) - self.assertTrue(has_guest_access) + assert has_guest_access def test_has_guest_access__authorized_guest_user__non_zero_resource_index(self): # set up a user who has authorized access, plus another resource @@ -163,7 +163,7 @@ def test_has_guest_access__authorized_guest_user__non_zero_resource_index(self): g.user = guest has_guest_access = security_manager.has_guest_access(self.dash) - self.assertTrue(has_guest_access) + assert has_guest_access def test_has_guest_access__unauthorized_guest_user__different_resource_id(self): g.user = security_manager.get_guest_user_from_token( @@ -173,14 +173,14 @@ def test_has_guest_access__unauthorized_guest_user__different_resource_id(self): } ) has_guest_access = security_manager.has_guest_access(self.dash) - self.assertFalse(has_guest_access) + assert not has_guest_access def test_has_guest_access__unauthorized_guest_user__different_resource_type(self): g.user = security_manager.get_guest_user_from_token( {"user": {}, "resources": [{"type": "dirt", "id": self.embedded.uuid}]} ) has_guest_access = security_manager.has_guest_access(self.dash) - self.assertFalse(has_guest_access) + assert not has_guest_access def test_raise_for_dashboard_access_as_guest(self): g.user = self.authorized_guest diff --git a/tests/integration_tests/security/row_level_security_tests.py b/tests/integration_tests/security/row_level_security_tests.py index 
ffd38bd533745..05c353fdec047 100644 --- a/tests/integration_tests/security/row_level_security_tests.py +++ b/tests/integration_tests/security/row_level_security_tests.py @@ -188,7 +188,7 @@ def test_model_view_rls_add_success(self): "clause": "client_id=1", }, ) - self.assertEqual(rv.status_code, 201) + assert rv.status_code == 201 rls1 = ( db.session.query(RowLevelSecurityFilter).filter_by(name="rls1") ).one_or_none() @@ -214,7 +214,7 @@ def test_model_view_rls_add_name_unique(self): "clause": "client_id=1", }, ) - self.assertEqual(rv.status_code, 422) + assert rv.status_code == 422 @pytest.mark.usefixtures("create_dataset") def test_model_view_rls_add_tables_required(self): @@ -231,7 +231,7 @@ def test_model_view_rls_add_tables_required(self): "clause": "client_id=1", }, ) - self.assertEqual(rv.status_code, 400) + assert rv.status_code == 400 data = json.loads(rv.data.decode("utf-8")) assert data["message"] == {"tables": ["Shorter than minimum length 1."]} @@ -326,8 +326,8 @@ def test_invalid_role_failure(self): } rv = self.client.post("/api/v1/rowlevelsecurity/", json=payload) status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8")) - self.assertEqual(status_code, 422) - self.assertEqual(data["message"], "[l'Some roles do not exist']") + assert status_code == 422 + assert data["message"] == "[l'Some roles do not exist']" def test_invalid_table_failure(self): self.login(ADMIN_USERNAME) @@ -340,8 +340,8 @@ def test_invalid_table_failure(self): } rv = self.client.post("/api/v1/rowlevelsecurity/", json=payload) status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8")) - self.assertEqual(status_code, 422) - self.assertEqual(data["message"], "[l'Datasource does not exist']") + assert status_code == 422 + assert data["message"] == "[l'Datasource does not exist']" @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_post_success(self): @@ -357,7 +357,7 @@ def test_post_success(self): rv = self.client.post("/api/v1/rowlevelsecurity/", json=payload) status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8")) - self.assertEqual(status_code, 201) + assert status_code == 201 rls = ( db.session.query(RowLevelSecurityFilter) @@ -366,11 +366,11 @@ def test_post_success(self): ) assert rls - self.assertEqual(rls.name, "rls 1") - self.assertEqual(rls.clause, "1=1") - self.assertEqual(rls.filter_type, "Base") - self.assertEqual(rls.tables[0].id, table.id) - self.assertEqual(rls.roles[0].id, 1) + assert rls.name == "rls 1" + assert rls.clause == "1=1" + assert rls.filter_type == "Base" + assert rls.tables[0].id == table.id + assert rls.roles[0].id == 1 db.session.delete(rls) db.session.commit() @@ -388,8 +388,8 @@ def test_invalid_id_failure(self): } rv = self.client.put("/api/v1/rowlevelsecurity/99999999", json=payload) status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8")) - self.assertEqual(status_code, 404) - self.assertEqual(data["message"], "Not found") + assert status_code == 404 + assert data["message"] == "Not found" @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_invalid_role_failure(self): @@ -410,8 +410,8 @@ def test_invalid_role_failure(self): } rv = self.client.put(f"/api/v1/rowlevelsecurity/{rls.id}", json=payload) status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8")) - self.assertEqual(status_code, 422) - self.assertEqual(data["message"], "[l'Some roles do not exist']") + assert status_code == 422 + assert data["message"] == "[l'Some roles do not exist']" 
db.session.delete(rls) db.session.commit() @@ -439,8 +439,8 @@ def test_invalid_table_failure(self): } rv = self.client.put(f"/api/v1/rowlevelsecurity/{rls.id}", json=payload) status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8")) - self.assertEqual(status_code, 422) - self.assertEqual(data["message"], "[l'Datasource does not exist']") + assert status_code == 422 + assert data["message"] == "[l'Datasource does not exist']" db.session.delete(rls) db.session.commit() @@ -472,7 +472,7 @@ def test_put_success(self): rv = self.client.put(f"/api/v1/rowlevelsecurity/{rls.id}", json=payload) status_code, _data = rv.status_code, json.loads(rv.data.decode("utf-8")) # noqa: F841 - self.assertEqual(status_code, 201) + assert status_code == 201 rls = ( db.session.query(RowLevelSecurityFilter) @@ -480,11 +480,11 @@ def test_put_success(self): .one_or_none() ) - self.assertEqual(rls.name, "rls put success") - self.assertEqual(rls.clause, "2=2") - self.assertEqual(rls.filter_type, "Base") - self.assertEqual(rls.tables[0].id, tables[1].id) - self.assertEqual(rls.roles[0].id, roles[1].id) + assert rls.name == "rls put success" + assert rls.clause == "2=2" + assert rls.filter_type == "Base" + assert rls.tables[0].id == tables[1].id + assert rls.roles[0].id == roles[1].id db.session.delete(rls) db.session.commit() @@ -498,8 +498,8 @@ def test_invalid_id_failure(self): rv = self.client.delete(f"/api/v1/rowlevelsecurity/?q={ids_to_delete}") status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8")) - self.assertEqual(status_code, 404) - self.assertEqual(data["message"], "Not found") + assert status_code == 404 + assert data["message"] == "Not found" @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @pytest.mark.usefixtures("load_energy_table_with_slice") @@ -530,8 +530,8 @@ def test_bulk_delete_success(self): rv = self.client.delete(f"/api/v1/rowlevelsecurity/?q={ids_to_delete}") status_code, data = rv.status_code, json.loads(rv.data.decode("utf-8")) - self.assertEqual(status_code, 200) - self.assertEqual(data["message"], "Deleted 2 rules") + assert status_code == 200 + assert data["message"] == "Deleted 2 rules" class TestRowLevelSecurityWithRelatedAPI(SupersetTestCase): @@ -543,7 +543,7 @@ def test_rls_tables_related_api(self): params = prison.dumps({"page": 0, "page_size": 100}) rv = self.client.get(f"/api/v1/rowlevelsecurity/related/tables?q={params}") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) result = data["result"] @@ -561,7 +561,7 @@ def test_rls_roles_related_api(self): params = prison.dumps({"page": 0, "page_size": 100}) rv = self.client.get(f"/api/v1/rowlevelsecurity/related/roles?q={params}") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) result = data["result"] @@ -584,7 +584,7 @@ def test_table_related_filter(self): params = prison.dumps({"page": 0, "page_size": 10}) rv = self.client.get(f"/api/v1/rowlevelsecurity/related/tables?q={params}") - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) result = data["result"] received_tables = {table["text"].split(".")[-1] for table in result} @@ -664,7 +664,7 @@ def test_rls_filter_alters_query(self): tbl = self.get_table(name="birth_names") sql = tbl.get_query_str(self.query_obj) - self.assertRegex(sql, RLS_ALICE_REGEX) + assert re.search(RLS_ALICE_REGEX, sql) 
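The row-level-security assertions in this file follow the same regex pattern as the model_tests.py change: assertRegex(sql, pattern) becomes assert re.search(pattern, sql), and assertNotRegex becomes assert not re.search(pattern, sql), trading on re.search returning a Match object (truthy) or None (falsy). A small sketch with an illustrative pattern and SQL strings standing in for RLS_ALICE_REGEX and the fixtures used here:

    import re

    RLS_REGEX = r"name = 'Alice'"  # stand-in; the real constant lives in the test module
    sql_with_rls = "SELECT * FROM birth_names WHERE name = 'Alice'"
    sql_without_rls = "SELECT * FROM energy_usage"

    assert re.search(RLS_REGEX, sql_with_rls)  # was: self.assertRegex(...)
    assert not re.search(RLS_REGEX, sql_without_rls)  # was: self.assertNotRegex(...)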
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_rls_filter_does_not_alter_unrelated_query(self): @@ -679,7 +679,7 @@ def test_rls_filter_does_not_alter_unrelated_query(self): tbl = self.get_table(name="birth_names") sql = tbl.get_query_str(self.query_obj) - self.assertNotRegex(sql, RLS_ALICE_REGEX) + assert not re.search(RLS_ALICE_REGEX, sql) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_multiple_rls_filters_are_unionized(self): @@ -695,8 +695,8 @@ def test_multiple_rls_filters_are_unionized(self): tbl = self.get_table(name="birth_names") sql = tbl.get_query_str(self.query_obj) - self.assertRegex(sql, RLS_ALICE_REGEX) - self.assertRegex(sql, RLS_GENDER_REGEX) + assert re.search(RLS_ALICE_REGEX, sql) + assert re.search(RLS_GENDER_REGEX, sql) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @pytest.mark.usefixtures("load_energy_table_with_slice") @@ -709,8 +709,8 @@ def test_rls_filter_for_all_datasets(self): births_sql = births.get_query_str(self.query_obj) energy_sql = energy.get_query_str(self.query_obj) - self.assertRegex(births_sql, RLS_ALICE_REGEX) - self.assertRegex(energy_sql, RLS_ALICE_REGEX) + assert re.search(RLS_ALICE_REGEX, births_sql) + assert re.search(RLS_ALICE_REGEX, energy_sql) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_dataset_id_can_be_string(self): @@ -721,4 +721,4 @@ def test_dataset_id_can_be_string(self): ) sql = dataset.get_query_str(self.query_obj) - self.assertRegex(sql, RLS_ALICE_REGEX) + assert re.search(RLS_ALICE_REGEX, sql) diff --git a/tests/integration_tests/security_tests.py b/tests/integration_tests/security_tests.py index bd76448d4899f..199c1328f1081 100644 --- a/tests/integration_tests/security_tests.py +++ b/tests/integration_tests/security_tests.py @@ -173,7 +173,7 @@ def test_after_insert_dataset(self): db.session.commit() table = db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table").one() - self.assertEqual(table.perm, f"[tmp_db1].[tmp_perm_table](id:{table.id})") + assert table.perm == f"[tmp_db1].[tmp_perm_table](id:{table.id})" pvm_dataset = security_manager.find_permission_view_menu( "datasource_access", table.perm @@ -183,10 +183,10 @@ def test_after_insert_dataset(self): ) # Assert dataset permission is created and local perms are ok - self.assertIsNotNone(pvm_dataset) - self.assertEqual(table.perm, f"[tmp_db1].[tmp_perm_table](id:{table.id})") - self.assertEqual(table.schema_perm, "[tmp_db1].[tmp_schema]") - self.assertIsNotNone(pvm_schema) + assert pvm_dataset is not None + assert table.perm == f"[tmp_db1].[tmp_perm_table](id:{table.id})" + assert table.schema_perm == "[tmp_db1].[tmp_schema]" + assert pvm_schema is not None # assert on permission hooks call_args = security_manager.on_permission_view_after_insert.call_args @@ -220,18 +220,18 @@ def test_after_insert_dataset_rollback(self): pvm_dataset = security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db1].[tmp_table](id:{table.id})" ) - self.assertIsNotNone(pvm_dataset) + assert pvm_dataset is not None table_id = table.id db.session.rollback() table = ( db.session.query(SqlaTable).filter_by(table_name="tmp_table").one_or_none() ) - self.assertIsNone(table) + assert table is None pvm_dataset = security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db1].[tmp_table](id:{table_id})" ) - self.assertIsNone(pvm_dataset) + assert pvm_dataset is None db.session.delete(tmp_db1) db.session.commit() @@ -250,16 +250,18 @@ def 
test_after_insert_dataset_table_none(self): db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table").one() ) # Assert permission is created - self.assertIsNotNone( + assert ( security_manager.find_permission_view_menu( "datasource_access", stored_table.perm ) + is not None ) # Assert no bogus permission is created - self.assertIsNone( + assert ( security_manager.find_permission_view_menu( "datasource_access", f"[None].[tmp_perm_table](id:{stored_table.id})" ) + is None ) # Cleanup @@ -273,11 +275,11 @@ def test_after_insert_database(self): db.session.add(tmp_db1) tmp_db1 = db.session.query(Database).filter_by(database_name="tmp_db1").one() - self.assertEqual(tmp_db1.perm, f"[tmp_db1].(id:{tmp_db1.id})") + assert tmp_db1.perm == f"[tmp_db1].(id:{tmp_db1.id})" tmp_db1_pvm = security_manager.find_permission_view_menu( "database_access", tmp_db1.perm ) - self.assertIsNotNone(tmp_db1_pvm) + assert tmp_db1_pvm is not None # Assert the hook is called security_manager.on_permission_view_after_insert.assert_has_calls( @@ -298,13 +300,13 @@ def test_after_insert_database_rollback(self): pvm_database = security_manager.find_permission_view_menu( "database_access", f"[tmp_db1].(id:{tmp_db1.id})" ) - self.assertIsNotNone(pvm_database) + assert pvm_database is not None db.session.rollback() pvm_database = security_manager.find_permission_view_menu( "database_access", f"[tmp_db1](id:{tmp_db1.id})" ) - self.assertIsNone(pvm_database) + assert pvm_database is None def test_after_update_database__perm_database_access(self): security_manager.on_view_menu_after_update = Mock() @@ -314,24 +316,27 @@ def test_after_update_database__perm_database_access(self): db.session.commit() tmp_db1 = db.session.query(Database).filter_by(database_name="tmp_db1").one() - self.assertIsNotNone( + assert ( security_manager.find_permission_view_menu("database_access", tmp_db1.perm) + is not None ) tmp_db1.database_name = "tmp_db2" db.session.commit() # Assert that the old permission was updated - self.assertIsNone( + assert ( security_manager.find_permission_view_menu( "database_access", f"[tmp_db1].(id:{tmp_db1.id})" ) + is None ) # Assert that the db permission was updated - self.assertIsNotNone( + assert ( security_manager.find_permission_view_menu( "database_access", f"[tmp_db2].(id:{tmp_db1.id})" ) + is not None ) # Assert the hook is called @@ -353,37 +358,42 @@ def test_after_update_database_rollback(self): db.session.commit() tmp_db1 = db.session.query(Database).filter_by(database_name="tmp_db1").one() - self.assertIsNotNone( + assert ( security_manager.find_permission_view_menu("database_access", tmp_db1.perm) + is not None ) tmp_db1.database_name = "tmp_db2" db.session.flush() # Assert that the old permission was updated - self.assertIsNone( + assert ( security_manager.find_permission_view_menu( "database_access", f"[tmp_db1].(id:{tmp_db1.id})" ) + is None ) # Assert that the db permission was updated - self.assertIsNotNone( + assert ( security_manager.find_permission_view_menu( "database_access", f"[tmp_db2].(id:{tmp_db1.id})" ) + is not None ) db.session.rollback() - self.assertIsNotNone( + assert ( security_manager.find_permission_view_menu( "database_access", f"[tmp_db1].(id:{tmp_db1.id})" ) + is not None ) # Assert that the db permission was updated - self.assertIsNone( + assert ( security_manager.find_permission_view_menu( "database_access", f"[tmp_db2].(id:{tmp_db1.id})" ) + is None ) db.session.delete(tmp_db1) @@ -402,24 +412,27 @@ def test_after_update_database__perm_database_access_exists(self): 
"database_access", f"[tmp_db2].(id:{tmp_db1.id})" ) - self.assertIsNotNone( + assert ( security_manager.find_permission_view_menu("database_access", tmp_db1.perm) + is not None ) tmp_db1.database_name = "tmp_db2" db.session.commit() # Assert that the old permission was updated - self.assertIsNone( + assert ( security_manager.find_permission_view_menu( "database_access", f"[tmp_db1].(id:{tmp_db1.id})" ) + is None ) # Assert that the db permission was updated - self.assertIsNotNone( + assert ( security_manager.find_permission_view_menu( "database_access", f"[tmp_db2].(id:{tmp_db1.id})" ) + is not None ) security_manager.on_permission_view_after_delete.assert_has_calls( @@ -464,19 +477,21 @@ def test_after_update_database__perm_datasource_access(self): table2 = db.session.query(SqlaTable).filter_by(table_name="tmp_table2").one() # assert initial perms - self.assertIsNotNone( + assert ( security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db1].[tmp_table1](id:{table1.id})" ) + is not None ) - self.assertIsNotNone( + assert ( security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db1].[tmp_table2](id:{table2.id})" ) + is not None ) - self.assertEqual(slice1.perm, f"[tmp_db1].[tmp_table1](id:{table1.id})") - self.assertEqual(table1.perm, f"[tmp_db1].[tmp_table1](id:{table1.id})") - self.assertEqual(table2.perm, f"[tmp_db1].[tmp_table2](id:{table2.id})") + assert slice1.perm == f"[tmp_db1].[tmp_table1](id:{table1.id})" + assert table1.perm == f"[tmp_db1].[tmp_table1](id:{table1.id})" + assert table2.perm == f"[tmp_db1].[tmp_table2](id:{table2.id})" # Refresh and update the database name tmp_db1 = db.session.query(Database).filter_by(database_name="tmp_db1").one() @@ -484,31 +499,35 @@ def test_after_update_database__perm_datasource_access(self): db.session.commit() # Assert that the old permissions were updated - self.assertIsNone( + assert ( security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db1].[tmp_table1](id:{table1.id})" ) + is None ) - self.assertIsNone( + assert ( security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db1].[tmp_table2](id:{table2.id})" ) + is None ) # Assert that the db permission was updated - self.assertIsNotNone( + assert ( security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db2].[tmp_table1](id:{table1.id})" ) + is not None ) - self.assertIsNotNone( + assert ( security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db2].[tmp_table2](id:{table2.id})" ) + is not None ) - self.assertEqual(slice1.perm, f"[tmp_db2].[tmp_table1](id:{table1.id})") - self.assertEqual(table1.perm, f"[tmp_db2].[tmp_table1](id:{table1.id})") - self.assertEqual(table2.perm, f"[tmp_db2].[tmp_table2](id:{table2.id})") + assert slice1.perm == f"[tmp_db2].[tmp_table1](id:{table1.id})" + assert table1.perm == f"[tmp_db2].[tmp_table1](id:{table1.id})" + assert table2.perm == f"[tmp_db2].[tmp_table2](id:{table2.id})" # Assert hooks are called tmp_db1_view_menu = security_manager.find_view_menu( @@ -543,7 +562,7 @@ def test_after_delete_database(self): database_pvm = security_manager.find_permission_view_menu( "database_access", tmp_db1.perm ) - self.assertIsNotNone(database_pvm) + assert database_pvm is not None role1 = Role(name="tmp_role1") role1.permissions.append(database_pvm) db.session.add(role1) @@ -554,13 +573,14 @@ def test_after_delete_database(self): # Assert that PVM is removed from Role role1 = security_manager.find_role("tmp_role1") - self.assertEqual(role1.permissions, []) + 
assert role1.permissions == [] # Assert that the old permission was updated - self.assertIsNone( + assert ( security_manager.find_permission_view_menu( "database_access", f"[tmp_db1].(id:{tmp_db1.id})" ) + is None ) # Cleanup @@ -576,7 +596,7 @@ def test_after_delete_database_rollback(self): database_pvm = security_manager.find_permission_view_menu( "database_access", tmp_db1.perm ) - self.assertIsNotNone(database_pvm) + assert database_pvm is not None role1 = Role(name="tmp_role1") role1.permissions.append(database_pvm) db.session.add(role1) @@ -586,12 +606,13 @@ def test_after_delete_database_rollback(self): db.session.flush() role1 = security_manager.find_role("tmp_role1") - self.assertEqual(role1.permissions, []) + assert role1.permissions == [] - self.assertIsNone( + assert ( security_manager.find_permission_view_menu( "database_access", f"[tmp_db1].(id:{tmp_db1.id})" ) + is None ) db.session.rollback() @@ -602,7 +623,7 @@ def test_after_delete_database_rollback(self): ) role1 = security_manager.find_role("tmp_role1") - self.assertEqual(role1.permissions, [database_pvm]) + assert role1.permissions == [database_pvm] # Cleanup db.session.delete(role1) @@ -627,7 +648,7 @@ def test_after_delete_dataset(self): table1_pvm = security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db].[tmp_table1](id:{table1.id})" ) - self.assertIsNotNone(table1_pvm) + assert table1_pvm is not None role1 = Role(name="tmp_role1") role1.permissions.append(table1_pvm) @@ -642,16 +663,16 @@ def test_after_delete_dataset(self): db.session.commit() role1 = security_manager.find_role("tmp_role1") - self.assertEqual(role1.permissions, []) + assert role1.permissions == [] table1_pvm = security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db].[tmp_table1](id:{table1.id})" ) - self.assertIsNone(table1_pvm) + assert table1_pvm is None table1_view_menu = security_manager.find_view_menu( f"[tmp_db].[tmp_table1](id:{table1.id})" ) - self.assertIsNone(table1_view_menu) + assert table1_view_menu is None # Assert the hook is called security_manager.on_permission_view_after_delete.assert_has_calls( @@ -681,7 +702,7 @@ def test_after_delete_dataset_rollback(self): table1_pvm = security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db].[tmp_table1](id:{table1.id})" ) - self.assertIsNotNone(table1_pvm) + assert table1_pvm is not None role1 = Role(name="tmp_role1") role1.permissions.append(table1_pvm) @@ -696,12 +717,12 @@ def test_after_delete_dataset_rollback(self): db.session.flush() role1 = security_manager.find_role("tmp_role1") - self.assertEqual(role1.permissions, []) + assert role1.permissions == [] table1_pvm = security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db].[tmp_table1](id:{table1.id})" ) - self.assertIsNone(table1_pvm) + assert table1_pvm is None # Test rollback, permissions exist everything is correctly rollback db.session.rollback() @@ -709,8 +730,8 @@ def test_after_delete_dataset_rollback(self): table1_pvm = security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db].[tmp_table1](id:{table1.id})" ) - self.assertIsNotNone(table1_pvm) - self.assertEqual(role1.permissions, [table1_pvm]) + assert table1_pvm is not None + assert role1.permissions == [table1_pvm] # cleanup db.session.delete(table1) @@ -745,7 +766,7 @@ def test_after_update_dataset__name_changes(self): table1_pvm = security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db].[tmp_table1](id:{table1.id})" ) - self.assertIsNotNone(table1_pvm) + assert 
table1_pvm is not None # refresh table1 = db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one() @@ -757,25 +778,23 @@ def test_after_update_dataset__name_changes(self): old_table1_pvm = security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db].[tmp_table1](id:{table1.id})" ) - self.assertIsNone(old_table1_pvm) + assert old_table1_pvm is None # Test new permission exist new_table1_pvm = security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db].[tmp_table1_changed](id:{table1.id})" ) - self.assertIsNotNone(new_table1_pvm) + assert new_table1_pvm is not None # test dataset permission changed changed_table1 = ( db.session.query(SqlaTable).filter_by(table_name="tmp_table1_changed").one() ) - self.assertEqual( - changed_table1.perm, f"[tmp_db].[tmp_table1_changed](id:{table1.id})" - ) + assert changed_table1.perm == f"[tmp_db].[tmp_table1_changed](id:{table1.id})" # Test Chart permission changed slice1 = db.session.query(Slice).filter_by(slice_name="tmp_slice1").one() - self.assertEqual(slice1.perm, f"[tmp_db].[tmp_table1_changed](id:{table1.id})") + assert slice1.perm == f"[tmp_db].[tmp_table1_changed](id:{table1.id})" # Assert hook is called view_menu_dataset = security_manager.find_view_menu( @@ -824,13 +843,13 @@ def test_after_update_dataset_rollback(self): old_table1_pvm = security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db].[tmp_table1](id:{table1.id})" ) - self.assertIsNone(old_table1_pvm) + assert old_table1_pvm is None # Test new permission exist new_table1_pvm = security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db].[tmp_table1_changed](id:{table1.id})" ) - self.assertIsNotNone(new_table1_pvm) + assert new_table1_pvm is not None # Test rollback db.session.rollback() @@ -838,7 +857,7 @@ def test_after_update_dataset_rollback(self): old_table1_pvm = security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db].[tmp_table1](id:{table1.id})" ) - self.assertIsNotNone(old_table1_pvm) + assert old_table1_pvm is not None # cleanup db.session.delete(slice1) @@ -873,7 +892,7 @@ def test_after_update_dataset__db_changes(self): table1_pvm = security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db1].[tmp_table1](id:{table1.id})" ) - self.assertIsNotNone(table1_pvm) + assert table1_pvm is not None # refresh table1 = db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one() @@ -885,25 +904,25 @@ def test_after_update_dataset__db_changes(self): table1_pvm = security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db1].[tmp_table1](id:{table1.id})" ) - self.assertIsNone(table1_pvm) + assert table1_pvm is None # Test new permission exist table1_pvm = security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db2].[tmp_table1](id:{table1.id})" ) - self.assertIsNotNone(table1_pvm) + assert table1_pvm is not None # test dataset permission and schema permission changed changed_table1 = ( db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one() ) - self.assertEqual(changed_table1.perm, f"[tmp_db2].[tmp_table1](id:{table1.id})") - self.assertEqual(changed_table1.schema_perm, "[tmp_db2].[tmp_schema]") # noqa: F541 + assert changed_table1.perm == f"[tmp_db2].[tmp_table1](id:{table1.id})" + assert changed_table1.schema_perm == "[tmp_db2].[tmp_schema]" # noqa: F541 # Test Chart permission changed slice1 = db.session.query(Slice).filter_by(slice_name="tmp_slice1").one() - self.assertEqual(slice1.perm, 
f"[tmp_db2].[tmp_table1](id:{table1.id})") - self.assertEqual(slice1.schema_perm, f"[tmp_db2].[tmp_schema]") # noqa: F541 + assert slice1.perm == f"[tmp_db2].[tmp_table1](id:{table1.id})" + assert slice1.schema_perm == f"[tmp_db2].[tmp_schema]" # noqa: F541 # cleanup db.session.delete(slice1) @@ -937,7 +956,7 @@ def test_after_update_dataset__schema_changes(self): table1_pvm = security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db1].[tmp_table1](id:{table1.id})" ) - self.assertIsNotNone(table1_pvm) + assert table1_pvm is not None # refresh table1 = db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one() @@ -949,19 +968,19 @@ def test_after_update_dataset__schema_changes(self): table1_pvm = security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db1].[tmp_table1](id:{table1.id})" ) - self.assertIsNotNone(table1_pvm) + assert table1_pvm is not None # test dataset schema permission changed changed_table1 = ( db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one() ) - self.assertEqual(changed_table1.perm, f"[tmp_db1].[tmp_table1](id:{table1.id})") - self.assertEqual(changed_table1.schema_perm, "[tmp_db1].[tmp_schema_changed]") # noqa: F541 + assert changed_table1.perm == f"[tmp_db1].[tmp_table1](id:{table1.id})" + assert changed_table1.schema_perm == "[tmp_db1].[tmp_schema_changed]" # noqa: F541 # Test Chart schema permission changed slice1 = db.session.query(Slice).filter_by(slice_name="tmp_slice1").one() - self.assertEqual(slice1.perm, f"[tmp_db1].[tmp_table1](id:{table1.id})") - self.assertEqual(slice1.schema_perm, "[tmp_db1].[tmp_schema_changed]") # noqa: F541 + assert slice1.perm == f"[tmp_db1].[tmp_table1](id:{table1.id})" + assert slice1.schema_perm == "[tmp_db1].[tmp_schema_changed]" # noqa: F541 # cleanup db.session.delete(slice1) @@ -994,7 +1013,7 @@ def test_after_update_dataset__schema_none(self): table1_pvm = security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db1].[tmp_table1](id:{table1.id})" ) - self.assertIsNotNone(table1_pvm) + assert table1_pvm is not None # refresh table1 = db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one() @@ -1005,8 +1024,8 @@ def test_after_update_dataset__schema_none(self): # refresh table1 = db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one() - self.assertEqual(table1.perm, f"[tmp_db1].[tmp_table1](id:{table1.id})") - self.assertIsNone(table1.schema_perm) + assert table1.perm == f"[tmp_db1].[tmp_table1](id:{table1.id})" + assert table1.schema_perm is None # cleanup db.session.delete(slice1) @@ -1041,7 +1060,7 @@ def test_after_update_dataset__name_db_changes(self): table1_pvm = security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db1].[tmp_table1](id:{table1.id})" ) - self.assertIsNotNone(table1_pvm) + assert table1_pvm is not None # refresh table1 = db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one() @@ -1054,27 +1073,25 @@ def test_after_update_dataset__name_db_changes(self): table1_pvm = security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db1].[tmp_table1](id:{table1.id})" ) - self.assertIsNone(table1_pvm) + assert table1_pvm is None # Test new permission exist table1_pvm = security_manager.find_permission_view_menu( "datasource_access", f"[tmp_db2].[tmp_table1_changed](id:{table1.id})" ) - self.assertIsNotNone(table1_pvm) + assert table1_pvm is not None # test dataset permission and schema permission changed changed_table1 = ( 
db.session.query(SqlaTable).filter_by(table_name="tmp_table1_changed").one() ) - self.assertEqual( - changed_table1.perm, f"[tmp_db2].[tmp_table1_changed](id:{table1.id})" - ) - self.assertEqual(changed_table1.schema_perm, "[tmp_db2].[tmp_schema]") # noqa: F541 + assert changed_table1.perm == f"[tmp_db2].[tmp_table1_changed](id:{table1.id})" + assert changed_table1.schema_perm == "[tmp_db2].[tmp_schema]" # noqa: F541 # Test Chart permission changed slice1 = db.session.query(Slice).filter_by(slice_name="tmp_slice1").one() - self.assertEqual(slice1.perm, f"[tmp_db2].[tmp_table1_changed](id:{table1.id})") - self.assertEqual(slice1.schema_perm, f"[tmp_db2].[tmp_schema]") # noqa: F541 + assert slice1.perm == f"[tmp_db2].[tmp_table1_changed](id:{table1.id})" + assert slice1.schema_perm == f"[tmp_db2].[tmp_schema]" # noqa: F541 # cleanup db.session.delete(slice1) @@ -1100,9 +1117,9 @@ def test_hybrid_perm_database(self): .one() ) - self.assertEqual(record.get_perm(), record.perm) - self.assertEqual(record.id, id_) - self.assertEqual(record.database_name, "tmp_database3") + assert record.get_perm() == record.perm + assert record.id == id_ + assert record.database_name == "tmp_database3" db.session.delete(database) db.session.commit() @@ -1124,10 +1141,10 @@ def test_set_perm_slice(self): db.session.commit() slice = db.session.query(Slice).filter_by(slice_name="slice_name").one() - self.assertEqual(slice.perm, table.perm) - self.assertEqual(slice.perm, f"[tmp_database].[tmp_perm_table](id:{table.id})") - self.assertEqual(slice.schema_perm, table.schema_perm) - self.assertIsNone(slice.schema_perm) + assert slice.perm == table.perm + assert slice.perm == f"[tmp_database].[tmp_perm_table](id:{table.id})" + assert slice.schema_perm == table.schema_perm + assert slice.schema_perm is None table.schema = "tmp_perm_schema" table.table_name = "tmp_perm_table_v2" @@ -1135,15 +1152,11 @@ def test_set_perm_slice(self): table = ( db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one() ) - self.assertEqual(slice.perm, table.perm) - self.assertEqual( - slice.perm, f"[tmp_database].[tmp_perm_table_v2](id:{table.id})" - ) - self.assertEqual( - table.perm, f"[tmp_database].[tmp_perm_table_v2](id:{table.id})" - ) - self.assertEqual(slice.schema_perm, table.schema_perm) - self.assertEqual(slice.schema_perm, "[tmp_database].[tmp_perm_schema]") + assert slice.perm == table.perm + assert slice.perm == f"[tmp_database].[tmp_perm_table_v2](id:{table.id})" + assert table.perm == f"[tmp_database].[tmp_perm_table_v2](id:{table.id})" + assert slice.schema_perm == table.schema_perm + assert slice.schema_perm == "[tmp_database].[tmp_perm_schema]" db.session.delete(slice) db.session.delete(table) @@ -1160,7 +1173,7 @@ def test_schemas_accessible_by_user_admin(self, mock_sm_g, mock_g): schemas = security_manager.get_schemas_accessible_by_user( database, None, {"1", "2", "3"} ) - self.assertEqual(schemas, {"1", "2", "3"}) # no changes + assert schemas == {"1", "2", "3"} # no changes @patch("superset.utils.core.g") @patch("superset.security.manager.g") @@ -1174,7 +1187,7 @@ def test_schemas_accessible_by_user_schema_access(self, mock_sm_g, mock_g): database, None, {"1", "2", "3"} ) # temp_schema is not passed in the params - self.assertEqual(schemas, {"1"}) + assert schemas == {"1"} delete_schema_perm("[examples].[1]") def test_schemas_accessible_by_user_datasource_access(self): @@ -1185,7 +1198,7 @@ def test_schemas_accessible_by_user_datasource_access(self): schemas = 
security_manager.get_schemas_accessible_by_user( database, None, {"temp_schema", "2", "3"} ) - self.assertEqual(schemas, {"temp_schema"}) + assert schemas == {"temp_schema"} def test_schemas_accessible_by_user_datasource_and_schema_access(self): # User has schema access to the datasource temp_schema.wb_health_population in examples DB. @@ -1196,11 +1209,11 @@ def test_schemas_accessible_by_user_datasource_and_schema_access(self): schemas = security_manager.get_schemas_accessible_by_user( database, None, {"temp_schema", "2", "3"} ) - self.assertEqual(schemas, {"temp_schema", "2"}) + assert schemas == {"temp_schema", "2"} vm = security_manager.find_permission_view_menu( "schema_access", "[examples].[2]" ) - self.assertIsNotNone(vm) + assert vm is not None delete_schema_perm("[examples].[2]") @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices") @@ -1211,8 +1224,8 @@ def test_gamma_user_schema_access_to_dashboards(self): self.login(GAMMA_USERNAME) data = str(self.client.get("api/v1/dashboard/").data) - self.assertIn("/superset/dashboard/world_health/", data) - self.assertNotIn("/superset/dashboard/births/", data) + assert "/superset/dashboard/world_health/" in data + assert "/superset/dashboard/births/" not in data @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @pytest.mark.usefixtures("public_role_like_gamma") @@ -1276,40 +1289,40 @@ def test_sqllab_gamma_user_schema_access_to_sqllab(self): NEW_FLASK_GET_SQL_DBS_REQUEST = f"/api/v1/database/?q={prison.dumps(arguments)}" self.login(GAMMA_USERNAME) databases_json = self.client.get(NEW_FLASK_GET_SQL_DBS_REQUEST).json - self.assertEqual(databases_json["count"], 1) + assert databases_json["count"] == 1 def assert_can_read(self, view_menu, permissions_set): if view_menu in NEW_SECURITY_CONVERGE_VIEWS: - self.assertIn(("can_read", view_menu), permissions_set) + assert ("can_read", view_menu) in permissions_set else: - self.assertIn(("can_list", view_menu), permissions_set) + assert ("can_list", view_menu) in permissions_set def assert_can_write(self, view_menu, permissions_set): if view_menu in NEW_SECURITY_CONVERGE_VIEWS: - self.assertIn(("can_write", view_menu), permissions_set) + assert ("can_write", view_menu) in permissions_set else: - self.assertIn(("can_add", view_menu), permissions_set) - self.assertIn(("can_delete", view_menu), permissions_set) - self.assertIn(("can_edit", view_menu), permissions_set) + assert ("can_add", view_menu) in permissions_set + assert ("can_delete", view_menu) in permissions_set + assert ("can_edit", view_menu) in permissions_set def assert_cannot_write(self, view_menu, permissions_set): if view_menu in NEW_SECURITY_CONVERGE_VIEWS: - self.assertNotIn(("can_write", view_menu), permissions_set) + assert ("can_write", view_menu) not in permissions_set else: - self.assertNotIn(("can_add", view_menu), permissions_set) - self.assertNotIn(("can_delete", view_menu), permissions_set) - self.assertNotIn(("can_edit", view_menu), permissions_set) - self.assertNotIn(("can_save", view_menu), permissions_set) + assert ("can_add", view_menu) not in permissions_set + assert ("can_delete", view_menu) not in permissions_set + assert ("can_edit", view_menu) not in permissions_set + assert ("can_save", view_menu) not in permissions_set def assert_can_all(self, view_menu, permissions_set): self.assert_can_read(view_menu, permissions_set) self.assert_can_write(view_menu, permissions_set) def assert_can_menu(self, view_menu, permissions_set): - self.assertIn(("menu_access", view_menu), permissions_set) + 
assert ("menu_access", view_menu) in permissions_set def assert_cannot_menu(self, view_menu, permissions_set): - self.assertNotIn(("menu_access", view_menu), permissions_set) + assert ("menu_access", view_menu) not in permissions_set def assert_cannot_gamma(self, perm_set): self.assert_cannot_write("Annotation", perm_set) @@ -1323,7 +1336,7 @@ def assert_cannot_gamma(self, perm_set): self.assert_cannot_menu("Upload a CSV", perm_set) self.assert_cannot_menu("ReportSchedule", perm_set) self.assert_cannot_menu("Alerts & Report", perm_set) - self.assertNotIn(("can_csv_upload", "Database"), perm_set) + assert ("can_csv_upload", "Database") not in perm_set def assert_can_gamma(self, perm_set): self.assert_can_read("Dataset", perm_set) @@ -1331,16 +1344,16 @@ def assert_can_gamma(self, perm_set): # make sure that user can create slices and dashboards self.assert_can_all("Dashboard", perm_set) self.assert_can_all("Chart", perm_set) - self.assertIn(("can_csv", "Superset"), perm_set) - self.assertIn(("can_dashboard", "Superset"), perm_set) - self.assertIn(("can_explore", "Superset"), perm_set) - self.assertIn(("can_share_chart", "Superset"), perm_set) - self.assertIn(("can_share_dashboard", "Superset"), perm_set) - self.assertIn(("can_explore_json", "Superset"), perm_set) - self.assertIn(("can_explore_json", "Superset"), perm_set) - self.assertIn(("can_userinfo", "UserDBModelView"), perm_set) - self.assertIn(("can_view_chart_as_table", "Dashboard"), perm_set) - self.assertIn(("can_view_query", "Dashboard"), perm_set) + assert ("can_csv", "Superset") in perm_set + assert ("can_dashboard", "Superset") in perm_set + assert ("can_explore", "Superset") in perm_set + assert ("can_share_chart", "Superset") in perm_set + assert ("can_share_dashboard", "Superset") in perm_set + assert ("can_explore_json", "Superset") in perm_set + assert ("can_explore_json", "Superset") in perm_set + assert ("can_userinfo", "UserDBModelView") in perm_set + assert ("can_view_chart_as_table", "Dashboard") in perm_set + assert ("can_view_query", "Dashboard") in perm_set self.assert_can_menu("Databases", perm_set) self.assert_can_menu("Datasets", perm_set) self.assert_can_menu("Data", perm_set) @@ -1352,11 +1365,11 @@ def assert_can_alpha(self, perm_set): self.assert_can_all("CssTemplate", perm_set) self.assert_can_all("Dataset", perm_set) self.assert_can_read("Database", perm_set) - self.assertIn(("can_csv_upload", "Database"), perm_set) + assert ("can_csv_upload", "Database") in perm_set self.assert_can_menu("Manage", perm_set) self.assert_can_menu("Annotation Layers", perm_set) self.assert_can_menu("CSS Templates", perm_set) - self.assertIn(("all_datasource_access", "all_datasource_access"), perm_set) + assert ("all_datasource_access", "all_datasource_access") in perm_set def assert_cannot_alpha(self, perm_set): self.assert_cannot_write("Queries", perm_set) @@ -1368,76 +1381,56 @@ def assert_can_admin(self, perm_set): self.assert_can_all("Database", perm_set) self.assert_can_all("RoleModelView", perm_set) self.assert_can_all("UserDBModelView", perm_set) - self.assertIn(("all_database_access", "all_database_access"), perm_set) + assert ("all_database_access", "all_database_access") in perm_set self.assert_can_menu("Security", perm_set) self.assert_can_menu("List Users", perm_set) self.assert_can_menu("List Roles", perm_set) def test_is_admin_only(self): - self.assertFalse( - security_manager._is_admin_only( - security_manager.find_permission_view_menu("can_read", "Dataset") - ) + assert not security_manager._is_admin_only( + 
security_manager.find_permission_view_menu("can_read", "Dataset") ) - self.assertFalse( - security_manager._is_admin_only( - security_manager.find_permission_view_menu( - "all_datasource_access", "all_datasource_access" - ) + assert not security_manager._is_admin_only( + security_manager.find_permission_view_menu( + "all_datasource_access", "all_datasource_access" ) ) log_permissions = ["can_read"] for log_permission in log_permissions: - self.assertTrue( - security_manager._is_admin_only( - security_manager.find_permission_view_menu(log_permission, "Log") - ) + assert security_manager._is_admin_only( + security_manager.find_permission_view_menu(log_permission, "Log") ) - self.assertTrue( - security_manager._is_admin_only( - security_manager.find_permission_view_menu( - "can_edit", "UserDBModelView" - ) - ) + assert security_manager._is_admin_only( + security_manager.find_permission_view_menu("can_edit", "UserDBModelView") ) @unittest.skipUnless( SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed" ) def test_is_alpha_only(self): - self.assertFalse( - security_manager._is_alpha_only( - security_manager.find_permission_view_menu("can_read", "Dataset") - ) + assert not security_manager._is_alpha_only( + security_manager.find_permission_view_menu("can_read", "Dataset") ) - self.assertTrue( - security_manager._is_alpha_only( - security_manager.find_permission_view_menu("can_write", "Dataset") - ) + assert security_manager._is_alpha_only( + security_manager.find_permission_view_menu("can_write", "Dataset") ) - self.assertTrue( - security_manager._is_alpha_only( - security_manager.find_permission_view_menu( - "all_datasource_access", "all_datasource_access" - ) + assert security_manager._is_alpha_only( + security_manager.find_permission_view_menu( + "all_datasource_access", "all_datasource_access" ) ) - self.assertTrue( - security_manager._is_alpha_only( - security_manager.find_permission_view_menu( - "all_database_access", "all_database_access" - ) + assert security_manager._is_alpha_only( + security_manager.find_permission_view_menu( + "all_database_access", "all_database_access" ) ) def test_is_gamma_pvm(self): - self.assertTrue( - security_manager._is_gamma_pvm( - security_manager.find_permission_view_menu("can_read", "Dataset") - ) + assert security_manager._is_gamma_pvm( + security_manager.find_permission_view_menu("can_read", "Dataset") ) def test_gamma_permissions_basic(self): @@ -1457,8 +1450,8 @@ def test_alpha_permissions(self): self.assert_can_gamma(alpha_perm_tuples) self.assert_can_alpha(alpha_perm_tuples) self.assert_cannot_alpha(alpha_perm_tuples) - self.assertNotIn(("can_this_form_get", "UserInfoEditView"), alpha_perm_tuples) - self.assertNotIn(("can_this_form_post", "UserInfoEditView"), alpha_perm_tuples) + assert ("can_this_form_get", "UserInfoEditView") not in alpha_perm_tuples + assert ("can_this_form_post", "UserInfoEditView") not in alpha_perm_tuples @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices") def test_admin_permissions(self): @@ -1471,34 +1464,31 @@ def test_admin_permissions(self): def test_sql_lab_permissions(self): sql_lab_set = get_perm_tuples("sql_lab") - self.assertEqual( - sql_lab_set, - { - ("can_activate", "TabStateView"), - ("can_csv", "Superset"), - ("can_delete_query", "TabStateView"), - ("can_delete", "TabStateView"), - ("can_execute_sql_query", "SQLLab"), - ("can_export", "SavedQuery"), - ("can_export_csv", "SQLLab"), - ("can_get", "TabStateView"), - ("can_get_results", "SQLLab"), - ("can_migrate_query", 
"TabStateView"), - ("can_sqllab", "Superset"), - ("can_sqllab_history", "Superset"), - ("can_put", "TabStateView"), - ("can_post", "TabStateView"), - ("can_write", "SavedQuery"), - ("can_read", "Query"), - ("can_read", "Database"), - ("can_read", "SQLLab"), - ("can_read", "SavedQuery"), - ("menu_access", "Query Search"), - ("menu_access", "Saved Queries"), - ("menu_access", "SQL Editor"), - ("menu_access", "SQL Lab"), - }, - ) + assert sql_lab_set == { + ("can_activate", "TabStateView"), + ("can_csv", "Superset"), + ("can_delete_query", "TabStateView"), + ("can_delete", "TabStateView"), + ("can_execute_sql_query", "SQLLab"), + ("can_export", "SavedQuery"), + ("can_export_csv", "SQLLab"), + ("can_get", "TabStateView"), + ("can_get_results", "SQLLab"), + ("can_migrate_query", "TabStateView"), + ("can_sqllab", "Superset"), + ("can_sqllab_history", "Superset"), + ("can_put", "TabStateView"), + ("can_post", "TabStateView"), + ("can_write", "SavedQuery"), + ("can_read", "Query"), + ("can_read", "Database"), + ("can_read", "SQLLab"), + ("can_read", "SavedQuery"), + ("menu_access", "Query Search"), + ("menu_access", "Saved Queries"), + ("menu_access", "SQL Editor"), + ("menu_access", "SQL Lab"), + } self.assert_cannot_alpha(sql_lab_set) @@ -1519,15 +1509,15 @@ def test_gamma_permissions(self): self.assert_cannot_write("UserDBModelView", gamma_perm_set) self.assert_cannot_write("RoleModelView", gamma_perm_set) - self.assertIn(("can_csv", "Superset"), gamma_perm_set) - self.assertIn(("can_dashboard", "Superset"), gamma_perm_set) - self.assertIn(("can_explore", "Superset"), gamma_perm_set) - self.assertIn(("can_share_chart", "Superset"), gamma_perm_set) - self.assertIn(("can_share_dashboard", "Superset"), gamma_perm_set) - self.assertIn(("can_explore_json", "Superset"), gamma_perm_set) - self.assertIn(("can_userinfo", "UserDBModelView"), gamma_perm_set) - self.assertIn(("can_view_chart_as_table", "Dashboard"), gamma_perm_set) - self.assertIn(("can_view_query", "Dashboard"), gamma_perm_set) + assert ("can_csv", "Superset") in gamma_perm_set + assert ("can_dashboard", "Superset") in gamma_perm_set + assert ("can_explore", "Superset") in gamma_perm_set + assert ("can_share_chart", "Superset") in gamma_perm_set + assert ("can_share_dashboard", "Superset") in gamma_perm_set + assert ("can_explore_json", "Superset") in gamma_perm_set + assert ("can_userinfo", "UserDBModelView") in gamma_perm_set + assert ("can_view_chart_as_table", "Dashboard") in gamma_perm_set + assert ("can_view_query", "Dashboard") in gamma_perm_set def test_views_are_secured(self): """Preventing the addition of unsecured views without has_access decorator""" @@ -1583,7 +1573,7 @@ def test_can_access_datasource(self, mock_raise_for_access): datasource = self.get_datasource_mock() mock_raise_for_access.return_value = None - self.assertTrue(security_manager.can_access_datasource(datasource=datasource)) + assert security_manager.can_access_datasource(datasource=datasource) mock_raise_for_access.side_effect = SupersetSecurityException( SupersetError( @@ -1593,7 +1583,7 @@ def test_can_access_datasource(self, mock_raise_for_access): ) ) - self.assertFalse(security_manager.can_access_datasource(datasource=datasource)) + assert not security_manager.can_access_datasource(datasource=datasource) @patch("superset.security.SupersetSecurityManager.raise_for_access") def test_can_access_table(self, mock_raise_for_access): @@ -1601,7 +1591,7 @@ def test_can_access_table(self, mock_raise_for_access): table = Table("bar", "foo") 
mock_raise_for_access.return_value = None - self.assertTrue(security_manager.can_access_table(database, table)) + assert security_manager.can_access_table(database, table) mock_raise_for_access.side_effect = SupersetSecurityException( SupersetError( @@ -1609,7 +1599,7 @@ def test_can_access_table(self, mock_raise_for_access): ) ) - self.assertFalse(security_manager.can_access_table(database, table)) + assert not security_manager.can_access_table(database, table) @patch("superset.security.SupersetSecurityManager.is_owner") @patch("superset.security.SupersetSecurityManager.can_access") @@ -1883,12 +1873,12 @@ def test_get_user_roles(self): admin = security_manager.find_user("admin") with override_user(admin): roles = security_manager.get_user_roles() - self.assertEqual(admin.roles, roles) + assert admin.roles == roles def test_get_anonymous_roles(self): with override_user(security_manager.get_anonymous_user()): roles = security_manager.get_user_roles() - self.assertEqual([security_manager.get_public_role()], roles) + assert [security_manager.get_public_role()] == roles def test_all_database_access(self): gamma_user = security_manager.find_user(username="gamma") @@ -2011,14 +2001,13 @@ def test_create_guest_access_token(self, get_time_mock): audience=aud, ) - self.assertEqual(user, decoded_token["user"]) - self.assertEqual(resources, decoded_token["resources"]) - self.assertEqual(now, decoded_token["iat"]) - self.assertEqual(aud, decoded_token["aud"]) - self.assertEqual("guest", decoded_token["type"]) - self.assertEqual( - now + (self.app.config["GUEST_TOKEN_JWT_EXP_SECONDS"]), - decoded_token["exp"], + assert user == decoded_token["user"] + assert resources == decoded_token["resources"] + assert now == decoded_token["iat"] + assert aud == decoded_token["aud"] + assert "guest" == decoded_token["type"] + assert ( + now + self.app.config["GUEST_TOKEN_JWT_EXP_SECONDS"] == decoded_token["exp"] ) def test_get_guest_user(self): @@ -2028,8 +2017,8 @@ def test_get_guest_user(self): guest_user = security_manager.get_guest_user_from_request(fake_request) - self.assertIsNotNone(guest_user) - self.assertEqual("test_guest", guest_user.username) + assert guest_user is not None + assert "test_guest" == guest_user.username def test_get_guest_user_with_request_form(self): token = self.create_guest_token() @@ -2039,8 +2028,8 @@ def test_get_guest_user_with_request_form(self): guest_user = security_manager.get_guest_user_from_request(fake_request) - self.assertIsNotNone(guest_user) - self.assertEqual("test_guest", guest_user.username) + assert guest_user is not None + assert "test_guest" == guest_user.username @patch("superset.security.SupersetSecurityManager._get_current_epoch_time") def test_get_guest_user_expired_token(self, get_time_mock): @@ -2054,7 +2043,7 @@ def test_get_guest_user_expired_token(self, get_time_mock): guest_user = security_manager.get_guest_user_from_request(fake_request) - self.assertIsNone(guest_user) + assert guest_user is None def test_get_guest_user_no_user(self): user = None @@ -2065,7 +2054,7 @@ def test_get_guest_user_no_user(self): fake_request.headers[current_app.config["GUEST_TOKEN_HEADER_NAME"]] = token guest_user = security_manager.get_guest_user_from_request(fake_request) - self.assertIsNone(guest_user) + assert guest_user is None self.assertRaisesRegex(ValueError, "Guest token does not contain a user claim") def test_get_guest_user_no_resource(self): @@ -2105,7 +2094,7 @@ def test_get_guest_user_not_guest_type(self): 
fake_request.headers[current_app.config["GUEST_TOKEN_HEADER_NAME"]] = token guest_user = security_manager.get_guest_user_from_request(fake_request) - self.assertIsNone(guest_user) + assert guest_user is None self.assertRaisesRegex(ValueError, "This is not a guest token.") def test_get_guest_user_bad_audience(self): @@ -2133,7 +2122,7 @@ def test_get_guest_user_bad_audience(self): guest_user = security_manager.get_guest_user_from_request(fake_request) self.assertRaisesRegex(jwt.exceptions.InvalidAudienceError, "Invalid audience") - self.assertIsNone(guest_user) + assert guest_user is None @patch("superset.security.SupersetSecurityManager._get_current_epoch_time") def test_create_guest_access_token_callable_audience(self, get_time_mock): @@ -2153,6 +2142,6 @@ def test_create_guest_access_token_callable_audience(self, get_time_mock): audience="cool_code", ) app.config["GUEST_TOKEN_JWT_AUDIENCE"].assert_called_once() - self.assertEqual("cool_code", decoded_token["aud"]) - self.assertEqual("guest", decoded_token["type"]) + assert "cool_code" == decoded_token["aud"] + assert "guest" == decoded_token["type"] app.config["GUEST_TOKEN_JWT_AUDIENCE"] = None diff --git a/tests/integration_tests/sql_lab/api_tests.py b/tests/integration_tests/sql_lab/api_tests.py index e6c1f686b13c7..19d6e56fb6441 100644 --- a/tests/integration_tests/sql_lab/api_tests.py +++ b/tests/integration_tests/sql_lab/api_tests.py @@ -67,7 +67,7 @@ def test_get_from_empty_bootstrap_data(self): result = data.get("result") assert result["active_tab"] is None # noqa: E711 assert result["tab_state_ids"] == [] - self.assertEqual(len(result["databases"]), 0) + assert len(result["databases"]) == 0 @mock.patch.dict( "superset.extensions.feature_flag_manager._feature_flags", @@ -126,7 +126,7 @@ def test_get_from_bootstrap_data_with_latest_query(self): # associated with any tabs resp = self.get_json_resp("/api/v1/sqllab/") result = resp["result"] - self.assertEqual(result["active_tab"]["id"], tab_state_id) + assert result["active_tab"]["id"] == tab_state_id @mock.patch.dict( "superset.extensions.feature_flag_manager._feature_flags", @@ -220,8 +220,8 @@ def test_estimate_required_params(self): } } resp_data = json.loads(rv.data.decode("utf-8")) - self.assertDictEqual(resp_data, failed_resp) - self.assertEqual(rv.status_code, 400) + self.assertDictEqual(resp_data, failed_resp) # noqa: PT009 + assert rv.status_code == 400 data = {"sql": "SELECT 1"} rv = self.client.post( @@ -230,8 +230,8 @@ def test_estimate_required_params(self): ) failed_resp = {"message": {"database_id": ["Missing data for required field."]}} resp_data = json.loads(rv.data.decode("utf-8")) - self.assertDictEqual(resp_data, failed_resp) - self.assertEqual(rv.status_code, 400) + self.assertDictEqual(resp_data, failed_resp) # noqa: PT009 + assert rv.status_code == 400 data = {"database_id": 1} rv = self.client.post( @@ -240,8 +240,8 @@ def test_estimate_required_params(self): ) failed_resp = {"message": {"sql": ["Missing data for required field."]}} resp_data = json.loads(rv.data.decode("utf-8")) - self.assertDictEqual(resp_data, failed_resp) - self.assertEqual(rv.status_code, 400) + self.assertDictEqual(resp_data, failed_resp) # noqa: PT009 + assert rv.status_code == 400 def test_estimate_valid_request(self): self.login(ADMIN_USERNAME) @@ -270,8 +270,8 @@ def test_estimate_valid_request(self): success_resp = {"result": formatter_response} resp_data = json.loads(rv.data.decode("utf-8")) - self.assertDictEqual(resp_data, success_resp) - self.assertEqual(rv.status_code, 200) + 
self.assertDictEqual(resp_data, success_resp) # noqa: PT009 + assert rv.status_code == 200 def test_format_sql_request(self): self.login(ADMIN_USERNAME) @@ -283,8 +283,8 @@ def test_format_sql_request(self): ) success_resp = {"result": "SELECT\n 1\nFROM my_table"} resp_data = json.loads(rv.data.decode("utf-8")) - self.assertDictEqual(resp_data, success_resp) - self.assertEqual(rv.status_code, 200) + self.assertDictEqual(resp_data, success_resp) # noqa: PT009 + assert rv.status_code == 200 @mock.patch("superset.commands.sql_lab.results.results_backend_use_msgpack", False) def test_execute_required_params(self): @@ -303,8 +303,8 @@ def test_execute_required_params(self): } } resp_data = json.loads(rv.data.decode("utf-8")) - self.assertDictEqual(resp_data, failed_resp) - self.assertEqual(rv.status_code, 400) + self.assertDictEqual(resp_data, failed_resp) # noqa: PT009 + assert rv.status_code == 400 data = {"sql": "SELECT 1", "client_id": client_id} rv = self.client.post( @@ -313,8 +313,8 @@ def test_execute_required_params(self): ) failed_resp = {"message": {"database_id": ["Missing data for required field."]}} resp_data = json.loads(rv.data.decode("utf-8")) - self.assertDictEqual(resp_data, failed_resp) - self.assertEqual(rv.status_code, 400) + self.assertDictEqual(resp_data, failed_resp) # noqa: PT009 + assert rv.status_code == 400 data = {"database_id": 1, "client_id": client_id} rv = self.client.post( @@ -323,8 +323,8 @@ def test_execute_required_params(self): ) failed_resp = {"message": {"sql": ["Missing data for required field."]}} resp_data = json.loads(rv.data.decode("utf-8")) - self.assertDictEqual(resp_data, failed_resp) - self.assertEqual(rv.status_code, 400) + self.assertDictEqual(resp_data, failed_resp) # noqa: PT009 + assert rv.status_code == 400 @mock.patch("superset.commands.sql_lab.results.results_backend_use_msgpack", False) def test_execute_valid_request(self) -> None: @@ -342,8 +342,8 @@ def test_execute_valid_request(self) -> None: json=data, ) resp_data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(resp_data.get("status"), "success") - self.assertEqual(rv.status_code, 200) + assert resp_data.get("status") == "success" + assert rv.status_code == 200 @mock.patch( "tests.integration_tests.superset_test_custom_template_processors.datetime" @@ -366,7 +366,7 @@ def test_execute_custom_templated(self, sql_lab_mock, mock_dt) -> None: "/api/v1/sqllab/execute/", raise_on_error=False, json_=json_payload ) assert sql_lab_mock.called - self.assertEqual(sql_lab_mock.call_args[0][1], "SELECT '1970-01-01' as test") + assert sql_lab_mock.call_args[0][1] == "SELECT '1970-01-01' as test" self.delete_fake_db_for_macros() @@ -419,8 +419,8 @@ def test_get_results_with_display_limit(self): self.get_resp(f"/api/v1/sqllab/results/?q={prison.dumps(arguments)}") ) - self.assertEqual(result_key, expected_key) - self.assertEqual(result_limited, expected_limited) + assert result_key == expected_key + assert result_limited == expected_limited app.config["RESULTS_BACKEND_USE_MSGPACK"] = use_msgpack @@ -454,6 +454,6 @@ def test_export_results(self, get_df_mock: mock.Mock) -> None: data = csv.reader(io.StringIO(resp)) expected_data = csv.reader(io.StringIO("foo\n1\n2")) - self.assertEqual(list(expected_data), list(data)) + assert list(expected_data) == list(data) db.session.delete(query_obj) db.session.commit() diff --git a/tests/integration_tests/sql_validator_tests.py b/tests/integration_tests/sql_validator_tests.py index c286bf3a438bd..901d667810e7b 100644 --- 
a/tests/integration_tests/sql_validator_tests.py +++ b/tests/integration_tests/sql_validator_tests.py @@ -58,7 +58,7 @@ def test_validator_success(self, flask_g): errors = self.validator.validate(sql, None, schema, self.database) - self.assertEqual([], errors) + assert [] == errors @patch("superset.utils.core.g") def test_validator_db_error(self, flask_g): @@ -95,7 +95,7 @@ def test_validator_query_error(self, flask_g): errors = self.validator.validate(sql, None, schema, self.database) - self.assertEqual(1, len(errors)) + assert 1 == len(errors) class TestPostgreSQLValidator(SupersetTestCase): diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py index 1b5245568813c..fb03f37e62c7b 100644 --- a/tests/integration_tests/sqla_models_tests.py +++ b/tests/integration_tests/sqla_models_tests.py @@ -83,12 +83,12 @@ def test_is_time_druid_time_col(self): database = Database(database_name="druid_db", sqlalchemy_uri="druid://db") tbl = SqlaTable(table_name="druid_tbl", database=database) col = TableColumn(column_name="__time", type="INTEGER", table=tbl) - self.assertEqual(col.is_dttm, None) + assert col.is_dttm is None DruidEngineSpec.alter_new_orm_column(col) - self.assertEqual(col.is_dttm, True) + assert col.is_dttm is True col = TableColumn(column_name="__not_time", type="INTEGER", table=tbl) - self.assertEqual(col.is_temporal, False) + assert col.is_temporal is False def test_temporal_varchar(self): """Ensure a column with is_dttm set to true evaluates to is_temporal == True""" @@ -125,13 +125,13 @@ def test_db_column_types(self): tbl = SqlaTable(table_name="col_type_test_tbl", database=get_example_database()) for str_type, db_col_type in test_cases.items(): col = TableColumn(column_name="foo", type=str_type, table=tbl) - self.assertEqual(col.is_temporal, db_col_type == GenericDataType.TEMPORAL) - self.assertEqual(col.is_numeric, db_col_type == GenericDataType.NUMERIC) - self.assertEqual(col.is_string, db_col_type == GenericDataType.STRING) + assert col.is_temporal == (db_col_type == GenericDataType.TEMPORAL) + assert col.is_numeric == (db_col_type == GenericDataType.NUMERIC) + assert col.is_string == (db_col_type == GenericDataType.STRING) for str_type, db_col_type in test_cases.items(): col = TableColumn(column_name="foo", type=str_type, table=tbl, is_dttm=True) - self.assertTrue(col.is_temporal) + assert col.is_temporal @patch("superset.jinja_context.get_user_id", return_value=1) @patch("superset.jinja_context.get_username", return_value="abc") @@ -161,7 +161,7 @@ def test_extra_cache_keys(self, mock_user_email, mock_username, mock_user_id): query_obj = dict(**base_query_obj, extras={}) extra_cache_keys = table1.get_extra_cache_keys(query_obj) - self.assertTrue(table1.has_extra_cache_key_calls(query_obj)) + assert table1.has_extra_cache_key_calls(query_obj) assert set(extra_cache_keys) == {1, "abc", "abc@test.com"} # Table with Jinja callable disabled. @@ -177,8 +177,8 @@ def test_extra_cache_keys(self, mock_user_email, mock_username, mock_user_id): ) query_obj = dict(**base_query_obj, extras={}) extra_cache_keys = table2.get_extra_cache_keys(query_obj) - self.assertTrue(table2.has_extra_cache_key_calls(query_obj)) - self.assertListEqual(extra_cache_keys, []) + assert table2.has_extra_cache_key_calls(query_obj) + self.assertListEqual(extra_cache_keys, []) # noqa: PT009 # Table with no Jinja callable. 
query = "SELECT 'abc' as user" @@ -190,15 +190,15 @@ def test_extra_cache_keys(self, mock_user_email, mock_username, mock_user_id): query_obj = dict(**base_query_obj, extras={"where": "(user != 'abc')"}) extra_cache_keys = table3.get_extra_cache_keys(query_obj) - self.assertFalse(table3.has_extra_cache_key_calls(query_obj)) - self.assertListEqual(extra_cache_keys, []) + assert not table3.has_extra_cache_key_calls(query_obj) + self.assertListEqual(extra_cache_keys, []) # noqa: PT009 # With Jinja callable in SQL expression. query_obj = dict( **base_query_obj, extras={"where": "(user != '{{ current_username() }}')"} ) extra_cache_keys = table3.get_extra_cache_keys(query_obj) - self.assertTrue(table3.has_extra_cache_key_calls(query_obj)) + assert table3.has_extra_cache_key_calls(query_obj) assert extra_cache_keys == ["abc"] @patch("superset.jinja_context.get_username", return_value="abc") @@ -393,11 +393,9 @@ def test_where_operators(self): sqla_query = table.get_sqla_query(**query_obj) sql = table.database.compile_sqla_query(sqla_query.sqla_query) if isinstance(filter_.expected, list): - self.assertTrue( - any([candidate in sql for candidate in filter_.expected]) - ) + assert any([candidate in sql for candidate in filter_.expected]) else: - self.assertIn(filter_.expected, sql) + assert filter_.expected in sql @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_boolean_type_where_operators(self): @@ -434,7 +432,7 @@ def test_boolean_type_where_operators(self): # https://github.com/sqlalchemy/sqlalchemy/blob/master/lib/sqlalchemy/dialects/mysql/base.py if not dialect.supports_native_boolean and dialect.name != "mysql": operand = "(1, 0)" - self.assertIn(f"IN {operand}", sql) + assert f"IN {operand}" in sql def test_incorrect_jinja_syntax_raises_correct_exception(self): query_obj = { diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py index cc1a813c93b30..964e3a9ceb0ee 100644 --- a/tests/integration_tests/sqllab_tests.py +++ b/tests/integration_tests/sqllab_tests.py @@ -91,7 +91,7 @@ def test_sql_json(self): self.login(ADMIN_USERNAME) data = self.run_sql("SELECT * FROM birth_names LIMIT 10", "1") - self.assertLess(0, len(data["data"])) + assert 0 < len(data["data"]) data = self.run_sql("SELECT * FROM nonexistent_table", "2") if backend() == "presto": @@ -220,8 +220,8 @@ def test_sql_json_cta_dynamic_db(self, ctas_method): names_count = engine.execute( f"SELECT COUNT(*) FROM birth_names" # noqa: F541 ).first() - self.assertEqual( - names_count[0], len(data) + assert names_count[0] == len( + data ) # SQL_MAX_ROW not applied due to the SQLLAB_CTAS_NO_LIMIT set to True # cleanup @@ -238,14 +238,14 @@ def test_multi_sql(self): SELECT * FROM birth_names LIMIT 2; """ data = self.run_sql(multi_sql, "2234") - self.assertLess(0, len(data["data"])) + assert 0 < len(data["data"]) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_explain(self): self.login(ADMIN_USERNAME) data = self.run_sql("EXPLAIN SELECT * FROM birth_names", "1") - self.assertLess(0, len(data["data"])) + assert 0 < len(data["data"]) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_sql_json_has_access(self): @@ -261,21 +261,21 @@ def test_sql_json_has_access(self): data = self.run_sql(QUERY_1, "1", username="Gagarin") db.session.query(Query).delete() db.session.commit() - self.assertLess(0, len(data["data"])) + assert 0 < len(data["data"]) def test_sqllab_has_access(self): for username in (ADMIN_USERNAME, 
GAMMA_SQLLAB_USERNAME): self.login(username) for endpoint in ("/sqllab/", "/sqllab/history/"): resp = self.client.get(endpoint) - self.assertEqual(200, resp.status_code) + assert 200 == resp.status_code def test_sqllab_no_access(self): self.login(GAMMA_USERNAME) for endpoint in ("/sqllab/", "/sqllab/history/"): resp = self.client.get(endpoint) # Redirects to the main page - self.assertEqual(302, resp.status_code) + assert 302 == resp.status_code def test_sql_json_schema_access(self): examples_db = get_example_database() @@ -311,7 +311,7 @@ def test_sql_json_schema_access(self): data = self.run_sql( f"SELECT * FROM {CTAS_SCHEMA_NAME}.test_table", "3", username="SchemaUser" ) - self.assertEqual(1, len(data["data"])) + assert 1 == len(data["data"]) data = self.run_sql( f"SELECT * FROM {CTAS_SCHEMA_NAME}.test_table", @@ -319,7 +319,7 @@ def test_sql_json_schema_access(self): username="SchemaUser", schema=CTAS_SCHEMA_NAME, ) - self.assertEqual(1, len(data["data"])) + assert 1 == len(data["data"]) # postgres needs a schema as a part of the table name. if db_backend == "mysql": @@ -329,7 +329,7 @@ def test_sql_json_schema_access(self): username="SchemaUser", schema=CTAS_SCHEMA_NAME, ) - self.assertEqual(1, len(data["data"])) + assert 1 == len(data["data"]) db.session.query(Query).delete() with get_example_database().get_sqla_engine() as engine: @@ -349,77 +349,75 @@ def test_ps_conversion_no_dict(self): data = [["a", 4, 4.0]] results = SupersetResultSet(data, cols, BaseEngineSpec) - self.assertEqual(len(data), results.size) - self.assertEqual(len(cols), len(results.columns)) + assert len(data) == results.size + assert len(cols) == len(results.columns) def test_pa_conversion_tuple(self): cols = ["string_col", "int_col", "list_col", "float_col"] data = [("Text", 111, [123], 1.0)] results = SupersetResultSet(data, cols, BaseEngineSpec) - self.assertEqual(len(data), results.size) - self.assertEqual(len(cols), len(results.columns)) + assert len(data) == results.size + assert len(cols) == len(results.columns) def test_pa_conversion_dict(self): cols = ["string_col", "dict_col", "int_col"] data = [["a", {"c1": 1, "c2": 2, "c3": 3}, 4]] results = SupersetResultSet(data, cols, BaseEngineSpec) - self.assertEqual(len(data), results.size) - self.assertEqual(len(cols), len(results.columns)) + assert len(data) == results.size + assert len(cols) == len(results.columns) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_sql_limit(self): self.login(ADMIN_USERNAME) test_limit = 1 data = self.run_sql("SELECT * FROM birth_names", client_id="sql_limit_1") - self.assertGreater(len(data["data"]), test_limit) + assert len(data["data"]) > test_limit data = self.run_sql( "SELECT * FROM birth_names", client_id="sql_limit_2", query_limit=test_limit ) - self.assertEqual(len(data["data"]), test_limit) + assert len(data["data"]) == test_limit data = self.run_sql( f"SELECT * FROM birth_names LIMIT {test_limit}", client_id="sql_limit_3", query_limit=test_limit + 1, ) - self.assertEqual(len(data["data"]), test_limit) - self.assertEqual(data["query"]["limitingFactor"], LimitingFactor.QUERY) + assert len(data["data"]) == test_limit + assert data["query"]["limitingFactor"] == LimitingFactor.QUERY data = self.run_sql( f"SELECT * FROM birth_names LIMIT {test_limit + 1}", client_id="sql_limit_4", query_limit=test_limit, ) - self.assertEqual(len(data["data"]), test_limit) - self.assertEqual(data["query"]["limitingFactor"], LimitingFactor.DROPDOWN) + assert len(data["data"]) == test_limit + assert 
data["query"]["limitingFactor"] == LimitingFactor.DROPDOWN data = self.run_sql( f"SELECT * FROM birth_names LIMIT {test_limit}", client_id="sql_limit_5", query_limit=test_limit, ) - self.assertEqual(len(data["data"]), test_limit) - self.assertEqual( - data["query"]["limitingFactor"], LimitingFactor.QUERY_AND_DROPDOWN - ) + assert len(data["data"]) == test_limit + assert data["query"]["limitingFactor"] == LimitingFactor.QUERY_AND_DROPDOWN data = self.run_sql( "SELECT * FROM birth_names", client_id="sql_limit_6", query_limit=10000, ) - self.assertEqual(len(data["data"]), 1200) - self.assertEqual(data["query"]["limitingFactor"], LimitingFactor.NOT_LIMITED) + assert len(data["data"]) == 1200 + assert data["query"]["limitingFactor"] == LimitingFactor.NOT_LIMITED data = self.run_sql( "SELECT * FROM birth_names", client_id="sql_limit_7", query_limit=1200, ) - self.assertEqual(len(data["data"]), 1200) - self.assertEqual(data["query"]["limitingFactor"], LimitingFactor.NOT_LIMITED) + assert len(data["data"]) == 1200 + assert data["query"]["limitingFactor"] == LimitingFactor.NOT_LIMITED @pytest.mark.usefixtures("load_birth_names_data") def test_query_api_filter(self) -> None: @@ -434,7 +432,7 @@ def test_query_api_filter(self) -> None: data = self.get_json_resp(url) admin = security_manager.find_user("admin") gamma_sqllab = security_manager.find_user("gamma_sqllab") - self.assertEqual(3, len(data["result"])) + assert 3 == len(data["result"]) user_queries = [ result.get("user").get("first_name") for result in data["result"] ] @@ -461,7 +459,7 @@ def test_query_api_can_access_all_queries(self) -> None: self.login(GAMMA_SQLLAB_USERNAME) url = "/api/v1/query/" data = self.get_json_resp(url) - self.assertEqual(3, len(data["result"])) + assert 3 == len(data["result"]) # Remove all_query_access from gamma sqllab all_queries_view = security_manager.find_permission_view_menu( @@ -521,10 +519,9 @@ def test_query_api_can_access_sql_editor_id_associated_queries(self) -> None: ] } url = f"/api/v1/query/?q={prison.dumps(arguments)}" - self.assertEqual( - {"SELECT 1", "SELECT 2"}, - {r.get("sql") for r in self.get_json_resp(url)["result"]}, - ) + assert {"SELECT 1", "SELECT 2"} == { + r.get("sql") for r in self.get_json_resp(url)["result"] + } @pytest.mark.usefixtures("load_birth_names_data") def test_query_admin_can_access_all_queries(self) -> None: @@ -537,7 +534,7 @@ def test_query_admin_can_access_all_queries(self) -> None: url = "/api/v1/query/" data = self.get_json_resp(url) - self.assertEqual(3, len(data["result"])) + assert 3 == len(data["result"]) def test_api_database(self): self.login(ADMIN_USERNAME) @@ -555,10 +552,9 @@ def test_api_database(self): } url = f"api/v1/database/?q={prison.dumps(arguments)}" - self.assertEqual( - {"examples", "fake_db_100", "main"}, - {r.get("database_name") for r in self.get_json_resp(url)["result"]}, - ) + assert {"examples", "fake_db_100", "main"} == { + r.get("database_name") for r in self.get_json_resp(url)["result"] + } self.delete_fake_db() @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") diff --git a/tests/integration_tests/strategy_tests.py b/tests/integration_tests/strategy_tests.py index 07aa7f5b7a785..e5901b5b82cb7 100644 --- a/tests/integration_tests/strategy_tests.py +++ b/tests/integration_tests/strategy_tests.py @@ -86,7 +86,7 @@ def test_top_n_dashboards_strategy(self): expected = [ {"chart_id": chart.id, "dashboard_id": dash.id} for chart in dash.slices ] - self.assertCountEqual(result, expected) + self.assertCountEqual(result, expected) # 
noqa: PT009 def reset_tag(self, tag): """Remove associated object from tag, used to reset tests""" @@ -106,7 +106,7 @@ def test_dashboard_tags_strategy(self): strategy = DashboardTagsStrategy(["tag1"]) result = strategy.get_payloads() expected = [] - self.assertEqual(result, expected) + assert result == expected # tag dashboard 'births' with `tag1` tag1 = get_tag("tag1", db.session, TagType.custom) @@ -118,7 +118,7 @@ def test_dashboard_tags_strategy(self): db.session.add(tagged_object) db.session.commit() - self.assertCountEqual(strategy.get_payloads(), tag1_urls) + self.assertCountEqual(strategy.get_payloads(), tag1_urls) # noqa: PT009 strategy = DashboardTagsStrategy(["tag2"]) tag2 = get_tag("tag2", db.session, TagType.custom) @@ -126,7 +126,7 @@ def test_dashboard_tags_strategy(self): result = strategy.get_payloads() expected = [] - self.assertEqual(result, expected) + assert result == expected # tag first slice dash = self.get_dash_by_slug("unicode-test") @@ -140,10 +140,10 @@ def test_dashboard_tags_strategy(self): db.session.commit() result = strategy.get_payloads() - self.assertCountEqual(result, tag2_urls) + self.assertCountEqual(result, tag2_urls) # noqa: PT009 strategy = DashboardTagsStrategy(["tag1", "tag2"]) result = strategy.get_payloads() expected = tag1_urls + tag2_urls - self.assertCountEqual(result, expected) + self.assertCountEqual(result, expected) # noqa: PT009 diff --git a/tests/integration_tests/tagging_tests.py b/tests/integration_tests/tagging_tests.py index 011227a972604..fe42dd4a5edd8 100644 --- a/tests/integration_tests/tagging_tests.py +++ b/tests/integration_tests/tagging_tests.py @@ -55,7 +55,7 @@ def test_dataset_tagging(self): self.clear_tagged_object_table() # Test to make sure nothing is in the tagged_object table - self.assertEqual([], self.query_tagged_object_table()) + assert [] == self.query_tagged_object_table() # Create a dataset and add it to the db test_dataset = SqlaTable( @@ -71,16 +71,16 @@ def test_dataset_tagging(self): # Test to make sure that a dataset tag was added to the tagged_object table tags = self.query_tagged_object_table() - self.assertEqual(1, len(tags)) - self.assertEqual("ObjectType.dataset", str(tags[0].object_type)) - self.assertEqual(test_dataset.id, tags[0].object_id) + assert 1 == len(tags) + assert "ObjectType.dataset" == str(tags[0].object_type) + assert test_dataset.id == tags[0].object_id # Cleanup the db db.session.delete(test_dataset) db.session.commit() # Test to make sure the tag is deleted when the associated object is deleted - self.assertEqual([], self.query_tagged_object_table()) + assert [] == self.query_tagged_object_table() @pytest.mark.usefixtures("with_tagging_system_feature") def test_chart_tagging(self): @@ -94,7 +94,7 @@ def test_chart_tagging(self): self.clear_tagged_object_table() # Test to make sure nothing is in the tagged_object table - self.assertEqual([], self.query_tagged_object_table()) + assert [] == self.query_tagged_object_table() # Create a chart and add it to the db test_chart = Slice( @@ -109,16 +109,16 @@ def test_chart_tagging(self): # Test to make sure that a chart tag was added to the tagged_object table tags = self.query_tagged_object_table() - self.assertEqual(1, len(tags)) - self.assertEqual("ObjectType.chart", str(tags[0].object_type)) - self.assertEqual(test_chart.id, tags[0].object_id) + assert 1 == len(tags) + assert "ObjectType.chart" == str(tags[0].object_type) + assert test_chart.id == tags[0].object_id # Cleanup the db db.session.delete(test_chart) db.session.commit() # Test 
to make sure the tag is deleted when the associated object is deleted - self.assertEqual([], self.query_tagged_object_table()) + assert [] == self.query_tagged_object_table() @pytest.mark.usefixtures("with_tagging_system_feature") def test_dashboard_tagging(self): @@ -132,7 +132,7 @@ def test_dashboard_tagging(self): self.clear_tagged_object_table() # Test to make sure nothing is in the tagged_object table - self.assertEqual([], self.query_tagged_object_table()) + assert [] == self.query_tagged_object_table() # Create a dashboard and add it to the db test_dashboard = Dashboard() @@ -145,16 +145,16 @@ def test_dashboard_tagging(self): # Test to make sure that a dashboard tag was added to the tagged_object table tags = self.query_tagged_object_table() - self.assertEqual(1, len(tags)) - self.assertEqual("ObjectType.dashboard", str(tags[0].object_type)) - self.assertEqual(test_dashboard.id, tags[0].object_id) + assert 1 == len(tags) + assert "ObjectType.dashboard" == str(tags[0].object_type) + assert test_dashboard.id == tags[0].object_id # Cleanup the db db.session.delete(test_dashboard) db.session.commit() # Test to make sure the tag is deleted when the associated object is deleted - self.assertEqual([], self.query_tagged_object_table()) + assert [] == self.query_tagged_object_table() @pytest.mark.usefixtures("with_tagging_system_feature") def test_saved_query_tagging(self): @@ -168,7 +168,7 @@ def test_saved_query_tagging(self): self.clear_tagged_object_table() # Test to make sure nothing is in the tagged_object table - self.assertEqual([], self.query_tagged_object_table()) + assert [] == self.query_tagged_object_table() # Create a saved query and add it to the db test_saved_query = SavedQuery(id=1, label="test saved query") @@ -178,24 +178,24 @@ def test_saved_query_tagging(self): # Test to make sure that a saved query tag was added to the tagged_object table tags = self.query_tagged_object_table() - self.assertEqual(2, len(tags)) + assert 2 == len(tags) - self.assertEqual("ObjectType.query", str(tags[0].object_type)) - self.assertEqual("owner:None", str(tags[0].tag.name)) - self.assertEqual("TagType.owner", str(tags[0].tag.type)) - self.assertEqual(test_saved_query.id, tags[0].object_id) + assert "ObjectType.query" == str(tags[0].object_type) + assert "owner:None" == str(tags[0].tag.name) + assert "TagType.owner" == str(tags[0].tag.type) + assert test_saved_query.id == tags[0].object_id - self.assertEqual("ObjectType.query", str(tags[1].object_type)) - self.assertEqual("type:query", str(tags[1].tag.name)) - self.assertEqual("TagType.type", str(tags[1].tag.type)) - self.assertEqual(test_saved_query.id, tags[1].object_id) + assert "ObjectType.query" == str(tags[1].object_type) + assert "type:query" == str(tags[1].tag.name) + assert "TagType.type" == str(tags[1].tag.type) + assert test_saved_query.id == tags[1].object_id # Cleanup the db db.session.delete(test_saved_query) db.session.commit() # Test to make sure the tag is deleted when the associated object is deleted - self.assertEqual([], self.query_tagged_object_table()) + assert [] == self.query_tagged_object_table() @pytest.mark.usefixtures("with_tagging_system_feature") def test_favorite_tagging(self): @@ -209,7 +209,7 @@ def test_favorite_tagging(self): self.clear_tagged_object_table() # Test to make sure nothing is in the tagged_object table - self.assertEqual([], self.query_tagged_object_table()) + assert [] == self.query_tagged_object_table() # Create a favorited object and add it to the db test_saved_query = FavStar(user_id=1, 
class_name="slice", obj_id=1) @@ -218,16 +218,16 @@ def test_favorite_tagging(self): # Test to make sure that a favorited object tag was added to the tagged_object table tags = self.query_tagged_object_table() - self.assertEqual(1, len(tags)) - self.assertEqual("ObjectType.chart", str(tags[0].object_type)) - self.assertEqual(test_saved_query.obj_id, tags[0].object_id) + assert 1 == len(tags) + assert "ObjectType.chart" == str(tags[0].object_type) + assert test_saved_query.obj_id == tags[0].object_id # Cleanup the db db.session.delete(test_saved_query) db.session.commit() # Test to make sure the tag is deleted when the associated object is deleted - self.assertEqual([], self.query_tagged_object_table()) + assert [] == self.query_tagged_object_table() @with_feature_flags(TAGGING_SYSTEM=False) def test_tagging_system(self): @@ -240,7 +240,7 @@ def test_tagging_system(self): self.clear_tagged_object_table() # Test to make sure nothing is in the tagged_object table - self.assertEqual([], self.query_tagged_object_table()) + assert [] == self.query_tagged_object_table() # Create a dataset and add it to the db test_dataset = SqlaTable( @@ -282,7 +282,7 @@ def test_tagging_system(self): # Test to make sure that no tags were added to the tagged_object table tags = self.query_tagged_object_table() - self.assertEqual(0, len(tags)) + assert 0 == len(tags) # Cleanup the db db.session.delete(test_dataset) @@ -293,4 +293,4 @@ def test_tagging_system(self): db.session.commit() # Test to make sure all the tags are deleted when the associated objects are deleted - self.assertEqual([], self.query_tagged_object_table()) + assert [] == self.query_tagged_object_table() diff --git a/tests/integration_tests/tags/api_tests.py b/tests/integration_tests/tags/api_tests.py index 3f6e499449d53..3de8b67fb8a4f 100644 --- a/tests/integration_tests/tags/api_tests.py +++ b/tests/integration_tests/tags/api_tests.py @@ -135,7 +135,7 @@ def test_get_tag(self): self.login(ADMIN_USERNAME) uri = f"api/v1/tag/{tag.id}" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 expected_result = { "changed_by": None, "changed_on_delta_humanized": "now", @@ -146,7 +146,7 @@ def test_get_tag(self): } data = json.loads(rv.data.decode("utf-8")) for key, value in expected_result.items(): - self.assertEqual(value, data["result"][key]) + assert value == data["result"][key] # rollback changes db.session.delete(tag) db.session.commit() @@ -160,7 +160,7 @@ def test_get_tag_not_found(self): self.login(ADMIN_USERNAME) uri = f"api/v1/tag/{max_id + 1}" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 404) + assert rv.status_code == 404 # cleanup db.session.delete(tag) db.session.commit() @@ -173,7 +173,7 @@ def test_get_list_tag(self): self.login(ADMIN_USERNAME) uri = "api/v1/tag/" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) assert data["count"] == TAGS_FIXTURE_COUNT # check expected columns @@ -211,7 +211,7 @@ def test_get_list_tag_filtered(self): } uri = f"api/v1/tag/?{parse.urlencode({'q': prison.dumps(query)})}" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 200) + assert rv.status_code == 200 data = json.loads(rv.data.decode("utf-8")) assert data["count"] == 2 @@ -219,7 +219,7 @@ def test_get_list_tag_filtered(self): query["filters"][0]["value"] = False uri = f"api/v1/tag/?{parse.urlencode({'q': prison.dumps(query)})}" rv = self.client.get(uri) - self.assertEqual(rv.status_code, 200) 
+        assert rv.status_code == 200
         data = json.loads(rv.data.decode("utf-8"))
         assert data["count"] == 3
@@ -249,10 +249,10 @@ def test_add_tagged_objects(self):
         data = {"properties": {"tags": example_tag_names}}
         rv = self.client.post(uri, json=data, follow_redirects=True)
         # successful request
-        self.assertEqual(rv.status_code, 201)
+        assert rv.status_code == 201
         # check that tags were created in database
         tags = db.session.query(Tag).filter(Tag.name.in_(example_tag_names))
-        self.assertEqual(tags.count(), 2)
+        assert tags.count() == 2
         # check that tagged objects were created
         tag_ids = [tags[0].id, tags[1].id]
         tagged_objects = db.session.query(TaggedObject).filter(
@@ -308,7 +308,7 @@ def test_delete_tagged_objects(self):
         uri = f"api/v1/tag/{dashboard_type.value}/{dashboard_id}/{tags.first().name}"
         rv = self.client.delete(uri, follow_redirects=True)
         # successful request
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         # ensure that tagged object no longer exists
         tagged_object = (
             db.session.query(TaggedObject)
@@ -358,14 +358,14 @@ def test_get_objects_by_tag(self):
             TaggedObject.object_id == dashboard_id,
             TaggedObject.object_type == dashboard_type.name,
         )
-        self.assertEqual(tagged_objects.count(), 2)
+        assert tagged_objects.count() == 2
         uri = f'api/v1/tag/get_objects/?tags={",".join(tag_names)}'
         rv = self.client.get(uri)
         # successful request
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         fetched_objects = rv.json["result"]
-        self.assertEqual(len(fetched_objects), 1)
-        self.assertEqual(fetched_objects[0]["id"], dashboard_id)
+        assert len(fetched_objects) == 1
+        assert fetched_objects[0]["id"] == dashboard_id
         # clean up tagged object
         tagged_objects.delete()
@@ -394,12 +394,12 @@ def test_get_all_objects(self):
             TaggedObject.object_id == dashboard_id,
             TaggedObject.object_type == dashboard_type.name,
         )
-        self.assertEqual(tagged_objects.count(), 2)
-        self.assertEqual(tagged_objects.first().object_id, dashboard_id)
+        assert tagged_objects.count() == 2
+        assert tagged_objects.first().object_id == dashboard_id
         uri = "api/v1/tag/get_objects/"
         rv = self.client.get(uri)
         # successful request
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         fetched_objects = rv.json["result"]
         # check that the dashboard object was fetched
         assert dashboard_id in [obj["id"] for obj in fetched_objects]
@@ -413,25 +413,25 @@ def test_delete_tags(self):
         # check that tags exist in the database
         example_tag_names = ["example_tag_1", "example_tag_2", "example_tag_3"]
         tags = db.session.query(Tag).filter(Tag.name.in_(example_tag_names))
-        self.assertEqual(tags.count(), 3)
+        assert tags.count() == 3
         # delete the first tag
         uri = f"api/v1/tag/?q={prison.dumps(example_tag_names[:1])}"
         rv = self.client.delete(uri, follow_redirects=True)
         # successful request
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         # check that tag does not exist in the database
         tag = db.session.query(Tag).filter(Tag.name == example_tag_names[0]).first()
         assert tag is None
         tags = db.session.query(Tag).filter(Tag.name.in_(example_tag_names))
-        self.assertEqual(tags.count(), 2)
+        assert tags.count() == 2
         # delete multiple tags
         uri = f"api/v1/tag/?q={prison.dumps(example_tag_names[1:])}"
         rv = self.client.delete(uri, follow_redirects=True)
         # successful request
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         # check that tags are all gone
         tags = db.session.query(Tag).filter(Tag.name.in_(example_tag_names))
-        self.assertEqual(tags.count(), 0)
+        assert tags.count() == 0

     @pytest.mark.usefixtures("create_tags")
     def test_delete_favorite_tag(self):
@@ -442,7 +442,7 @@ def test_delete_favorite_tag(self):
         tag = db.session.query(Tag).first()

         rv = self.client.post(uri, follow_redirects=True)
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         from sqlalchemy import and_  # noqa: F811
         from superset.tags.models import user_favorite_tag_table  # noqa: F811
         from flask import g  # noqa: F401, F811
@@ -463,7 +463,7 @@ def test_delete_favorite_tag(self):

         uri = f"api/v1/tag/{tag.id}/favorites/"
         rv = self.client.delete(uri, follow_redirects=True)
-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         association_row = (
             db.session.query(user_favorite_tag_table)
             .filter(
@@ -483,7 +483,7 @@ def test_add_tag_not_found(self):
         uri = "api/v1/tag/123/favorites/"  # noqa: F541
         rv = self.client.post(uri, follow_redirects=True)
-        self.assertEqual(rv.status_code, 404)
+        assert rv.status_code == 404

     @pytest.mark.usefixtures("create_tags")
     def test_delete_favorite_tag_not_found(self):
@@ -491,7 +491,7 @@ def test_delete_favorite_tag_not_found(self):
         uri = "api/v1/tag/123/favorites/"  # noqa: F541
         rv = self.client.delete(uri, follow_redirects=True)
-        self.assertEqual(rv.status_code, 404)
+        assert rv.status_code == 404

     @pytest.mark.usefixtures("create_tags")
     @patch("superset.daos.tag.g")
@@ -501,7 +501,7 @@ def test_add_tag_user_not_found(self, flask_g):
         uri = "api/v1/tag/123/favorites/"  # noqa: F541
         rv = self.client.post(uri, follow_redirects=True)
-        self.assertEqual(rv.status_code, 422)
+        assert rv.status_code == 422

     @pytest.mark.usefixtures("create_tags")
     @patch("superset.daos.tag.g")
@@ -511,7 +511,7 @@ def test_delete_favorite_tag_user_not_found(self, flask_g):
         uri = "api/v1/tag/123/favorites/"  # noqa: F541
         rv = self.client.delete(uri, follow_redirects=True)
-        self.assertEqual(rv.status_code, 422)
+        assert rv.status_code == 422

     @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
     def test_post_tag(self):
@@ -527,7 +527,7 @@ def test_post_tag(self):
             json={"name": "my_tag", "objects_to_tag": [["dashboard", dashboard.id]]},
         )

-        self.assertEqual(rv.status_code, 201)
+        assert rv.status_code == 201
         self.get_user(username="admin").get_id()  # noqa: F841
         tag = (
             db.session.query(Tag)
@@ -550,7 +550,7 @@ def test_post_tag_no_name_400(self):
             json={"name": "", "objects_to_tag": [["dashboard", dashboard.id]]},
         )

-        self.assertEqual(rv.status_code, 400)
+        assert rv.status_code == 400

     @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
     @pytest.mark.usefixtures("create_tags")
@@ -563,7 +563,7 @@ def test_put_tag(self):
             uri, json={"name": "new_name", "description": "new description"}
         )

-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         tag = (
             db.session.query(Tag)
@@ -581,7 +581,7 @@ def test_failed_put_tag(self):
         uri = f"api/v1/tag/{tag_to_update.id}"
         rv = self.client.put(uri, json={"foo": "bar"})

-        self.assertEqual(rv.status_code, 400)
+        assert rv.status_code == 400

     @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
     def test_post_bulk_tag(self):
@@ -617,7 +617,7 @@ def test_post_bulk_tag(self):
             },
         )

-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         result = TagDAO.get_tagged_objects_for_tags(tags, ["dashboard"])
         assert len(result) == 1
@@ -686,7 +686,7 @@ def test_post_bulk_tag_skipped_tags_perm(self):
             },
         )

-        self.assertEqual(rv.status_code, 200)
+        assert rv.status_code == 200
         result = rv.json["result"]
         assert len(result["objects_tagged"]) == 2
         assert len(result["objects_skipped"]) == 1
diff --git a/tests/integration_tests/thumbnails_tests.py b/tests/integration_tests/thumbnails_tests.py
index cbab3f84f92ba..ccb9a9c734c13 100644
--- a/tests/integration_tests/thumbnails_tests.py
+++ b/tests/integration_tests/thumbnails_tests.py
@@ -73,7 +73,7 @@ def test_get_async_dashboard_screenshot(self):
                 "admin",
                 thumbnail_url,
             )
-            self.assertEqual(response.getcode(), 202)
+            assert response.getcode() == 202

 class TestWebDriverScreenshotErrorDetector(SupersetTestCase):
@@ -217,7 +217,7 @@ def test_dashboard_thumbnail_disabled(self):
         self.login(ADMIN_USERNAME)
         _, thumbnail_url = self._get_id_and_thumbnail_url(DASHBOARD_URL)
         rv = self.client.get(thumbnail_url)
-        self.assertEqual(rv.status_code, 404)
+        assert rv.status_code == 404

     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     @with_feature_flags(THUMBNAILS=False)
@@ -228,7 +228,7 @@ def test_chart_thumbnail_disabled(self):
         self.login(ADMIN_USERNAME)
         _, thumbnail_url = self._get_id_and_thumbnail_url(CHART_URL)
         rv = self.client.get(thumbnail_url)
-        self.assertEqual(rv.status_code, 404)
+        assert rv.status_code == 404

     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     @with_feature_flags(THUMBNAILS=True)
@@ -255,7 +255,7 @@ def test_get_async_dashboard_screenshot_as_selenium(self):
                 assert mock_adjust_string.call_args[0][2] == "admin"

                 rv = self.client.get(thumbnail_url)
-                self.assertEqual(rv.status_code, 202)
+                assert rv.status_code == 202

     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     @with_feature_flags(THUMBNAILS=True)
@@ -283,7 +283,7 @@ def test_get_async_dashboard_screenshot_as_current_user(self):
                 assert mock_adjust_string.call_args[0][2] == username

                 rv = self.client.get(thumbnail_url)
-                self.assertEqual(rv.status_code, 202)
+                assert rv.status_code == 202

     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     @with_feature_flags(THUMBNAILS=True)
@@ -295,7 +295,7 @@ def test_get_async_dashboard_notfound(self):
         self.login(ADMIN_USERNAME)
         uri = f"api/v1/dashboard/{max_id + 1}/thumbnail/1234/"
         rv = self.client.get(uri)
-        self.assertEqual(rv.status_code, 404)
+        assert rv.status_code == 404

     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
@@ -306,7 +306,7 @@ def test_get_async_dashboard_not_allowed(self):
         self.login(ADMIN_USERNAME)
         _, thumbnail_url = self._get_id_and_thumbnail_url(DASHBOARD_URL)
         rv = self.client.get(thumbnail_url)
-        self.assertEqual(rv.status_code, 404)
+        assert rv.status_code == 404

     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     @with_feature_flags(THUMBNAILS=True)
@@ -333,7 +333,7 @@ def test_get_async_chart_screenshot_as_selenium(self):
                 assert mock_adjust_string.call_args[0][2] == "admin"

                 rv = self.client.get(thumbnail_url)
-                self.assertEqual(rv.status_code, 202)
+                assert rv.status_code == 202

     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     @with_feature_flags(THUMBNAILS=True)
@@ -361,7 +361,7 @@ def test_get_async_chart_screenshot_as_current_user(self):
                 assert mock_adjust_string.call_args[0][2] == username

                 rv = self.client.get(thumbnail_url)
-                self.assertEqual(rv.status_code, 202)
+                assert rv.status_code == 202

     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     @with_feature_flags(THUMBNAILS=True)
@@ -373,7 +373,7 @@ def test_get_async_chart_notfound(self):
         self.login(ADMIN_USERNAME)
         uri = f"api/v1/chart/{max_id + 1}/thumbnail/1234/"
         rv = self.client.get(uri)
-        self.assertEqual(rv.status_code, 404)
+        assert rv.status_code == 404

     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     @with_feature_flags(THUMBNAILS=True)
@@ -387,8 +387,8 @@ def test_get_cached_chart_wrong_digest(self):
         self.login(ADMIN_USERNAME)
         id_, thumbnail_url = self._get_id_and_thumbnail_url(CHART_URL)
         rv = self.client.get(f"api/v1/chart/{id_}/thumbnail/1234/")
-        self.assertEqual(rv.status_code, 302)
-        self.assertEqual(rv.headers["Location"], thumbnail_url)
+        assert rv.status_code == 302
+        assert rv.headers["Location"] == thumbnail_url

     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     @with_feature_flags(THUMBNAILS=True)
@@ -402,8 +402,8 @@ def test_get_cached_dashboard_screenshot(self):
         self.login(ADMIN_USERNAME)
         _, thumbnail_url = self._get_id_and_thumbnail_url(DASHBOARD_URL)
         rv = self.client.get(thumbnail_url)
-        self.assertEqual(rv.status_code, 200)
-        self.assertEqual(rv.data, self.mock_image)
+        assert rv.status_code == 200
+        assert rv.data == self.mock_image

     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     @with_feature_flags(THUMBNAILS=True)
@@ -417,8 +417,8 @@ def test_get_cached_chart_screenshot(self):
         self.login(ADMIN_USERNAME)
         id_, thumbnail_url = self._get_id_and_thumbnail_url(CHART_URL)
         rv = self.client.get(thumbnail_url)
-        self.assertEqual(rv.status_code, 200)
-        self.assertEqual(rv.data, self.mock_image)
+        assert rv.status_code == 200
+        assert rv.data == self.mock_image

     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     @with_feature_flags(THUMBNAILS=True)
@@ -432,5 +432,5 @@ def test_get_cached_dashboard_wrong_digest(self):
         self.login(ADMIN_USERNAME)
         id_, thumbnail_url = self._get_id_and_thumbnail_url(DASHBOARD_URL)
         rv = self.client.get(f"api/v1/dashboard/{id_}/thumbnail/1234/")
-        self.assertEqual(rv.status_code, 302)
-        self.assertEqual(rv.headers["Location"], thumbnail_url)
+        assert rv.status_code == 302
+        assert rv.headers["Location"] == thumbnail_url
diff --git a/tests/integration_tests/users/api_tests.py b/tests/integration_tests/users/api_tests.py
index 7894b1856c1ef..416e2be572cd1 100644
--- a/tests/integration_tests/users/api_tests.py
+++ b/tests/integration_tests/users/api_tests.py
@@ -35,36 +35,36 @@ def test_get_me_logged_in(self):

         rv = self.client.get(meUri)

-        self.assertEqual(200, rv.status_code)
+        assert 200 == rv.status_code
         response = json.loads(rv.data.decode("utf-8"))
-        self.assertEqual("admin", response["result"]["username"])
-        self.assertEqual(True, response["result"]["is_active"])
-        self.assertEqual(False, response["result"]["is_anonymous"])
+        assert "admin" == response["result"]["username"]
+        assert True is response["result"]["is_active"]
+        assert False is response["result"]["is_anonymous"]

     def test_get_me_with_roles(self):
         self.login(ADMIN_USERNAME)

         rv = self.client.get(meUri + "roles/")
-        self.assertEqual(200, rv.status_code)
+        assert 200 == rv.status_code
         response = json.loads(rv.data.decode("utf-8"))
         roles = list(response["result"]["roles"].keys())
-        self.assertEqual("Admin", roles.pop())
+        assert "Admin" == roles.pop()

     @patch("superset.security.manager.g")
     def test_get_my_roles_anonymous(self, mock_g):
         mock_g.user = security_manager.get_anonymous_user

         rv = self.client.get(meUri + "roles/")
-        self.assertEqual(401, rv.status_code)
+        assert 401 == rv.status_code

     def test_get_me_unauthorized(self):
         rv = self.client.get(meUri)
-        self.assertEqual(401, rv.status_code)
+        assert 401 == rv.status_code

     @patch("superset.security.manager.g")
     def test_get_me_anonymous(self, mock_g):
         mock_g.user = security_manager.get_anonymous_user
         rv = self.client.get(meUri)
-        self.assertEqual(401, rv.status_code)
+        assert 401 == rv.status_code

 class TestUserApi(SupersetTestCase):
diff --git a/tests/integration_tests/utils/encrypt_tests.py b/tests/integration_tests/utils/encrypt_tests.py
index cc882ee64b074..dccfc54a77344 100644
--- a/tests/integration_tests/utils/encrypt_tests.py
+++ b/tests/integration_tests/utils/encrypt_tests.py
@@ -53,8 +53,8 @@ def setUp(self) -> None:

     def test_create_field(self):
         field = encrypted_field_factory.create(String(1024))
-        self.assertTrue(isinstance(field, EncryptedType))
-        self.assertEqual(self.app.config["SECRET_KEY"], field.key)
+        assert isinstance(field, EncryptedType)
+        assert self.app.config["SECRET_KEY"] == field.key

     def test_custom_adapter(self):
         self.app.config["SQLALCHEMY_ENCRYPTED_FIELD_TYPE_ADAPTER"] = (
@@ -62,10 +62,10 @@ def test_custom_adapter(self):
         )
         encrypted_field_factory.init_app(self.app)
         field = encrypted_field_factory.create(String(1024))
-        self.assertTrue(isinstance(field, StringEncryptedType))
-        self.assertFalse(isinstance(field, EncryptedType))
-        self.assertTrue(getattr(field, "__created_by_enc_field_adapter__"))
-        self.assertEqual(self.app.config["SECRET_KEY"], field.key)
+        assert isinstance(field, StringEncryptedType)
+        assert not isinstance(field, EncryptedType)
+        assert getattr(field, "__created_by_enc_field_adapter__")
+        assert self.app.config["SECRET_KEY"] == field.key

     def test_ensure_encrypted_field_factory_is_used(self):
         """
diff --git a/tests/integration_tests/utils/machine_auth_tests.py b/tests/integration_tests/utils/machine_auth_tests.py
index 0dc8d4a1249f6..40de236b6a901 100644
--- a/tests/integration_tests/utils/machine_auth_tests.py
+++ b/tests/integration_tests/utils/machine_auth_tests.py
@@ -25,7 +25,7 @@ class MachineAuthProviderTests(SupersetTestCase):
     def test_get_auth_cookies(self):
         user = self.get_user("admin")
         auth_cookies = machine_auth_provider_factory.instance.get_auth_cookies(user)
-        self.assertIsNotNone(auth_cookies["session"])
+        assert auth_cookies["session"] is not None

     @patch("superset.utils.machine_auth.MachineAuthProvider.get_auth_cookies")
     def test_auth_driver_user(self, get_auth_cookies):
diff --git a/tests/integration_tests/utils_tests.py b/tests/integration_tests/utils_tests.py
index cbdea3f60b623..202a5b84d45c9 100644
--- a/tests/integration_tests/utils_tests.py
+++ b/tests/integration_tests/utils_tests.py
@@ -132,14 +132,14 @@ def test_zlib_compression(self):
         json_str = '{"test": 1}'
         blob = zlib_compress(json_str)
         got_str = zlib_decompress(blob)
-        self.assertEqual(json_str, got_str)
+        assert json_str == got_str

     def test_merge_extra_filters(self):
         # does nothing if no extra filters
         form_data = {"A": 1, "B": 2, "c": "test"}
         expected = {**form_data, "adhoc_filters": [], "applied_time_extras": {}}
         merge_extra_filters(form_data)
-        self.assertEqual(form_data, expected)
+        assert form_data == expected
         # empty extra_filters
         form_data = {"A": 1, "B": 2, "c": "test", "extra_filters": []}
         expected = {
@@ -150,7 +150,7 @@ def test_merge_extra_filters(self):
             "applied_time_extras": {},
         }
         merge_extra_filters(form_data)
-        self.assertEqual(form_data, expected)
+        assert form_data == expected
         # copy over extra filters into empty filters
         form_data = {
             "extra_filters": [
@@ -182,7 +182,7 @@ def test_merge_extra_filters(self):
             "applied_time_extras": {},
         }
         merge_extra_filters(form_data)
-        self.assertEqual(form_data, expected)
+        assert form_data == expected
         # adds extra filters to existing filters
         form_data = {
             "extra_filters": [
@@ -230,7 +230,7 @@ def test_merge_extra_filters(self):
             "applied_time_extras": {},
         }
         merge_extra_filters(form_data)
-        self.assertEqual(form_data, expected)
+        assert form_data == expected
         # adds extra filters to existing filters and sets time options
         form_data = {
             "extra_filters": [
@@ -262,7 +262,7 @@ def test_merge_extra_filters(self):
             },
         }
         merge_extra_filters(form_data)
-        self.assertEqual(form_data, expected)
+        assert form_data == expected

     def test_merge_extra_filters_ignores_empty_filters(self):
         form_data = {
@@ -273,7 +273,7 @@ def test_merge_extra_filters_ignores_empty_filters(self):
         }
         expected = {"adhoc_filters": [], "applied_time_extras": {}}
         merge_extra_filters(form_data)
-        self.assertEqual(form_data, expected)
+        assert form_data == expected

     def test_merge_extra_filters_ignores_nones(self):
         form_data = {
@@ -301,7 +301,7 @@ def test_merge_extra_filters_ignores_nones(self):
             "applied_time_extras": {},
         }
         merge_extra_filters(form_data)
-        self.assertEqual(form_data, expected)
+        assert form_data == expected

     def test_merge_extra_filters_ignores_equal_filters(self):
         form_data = {
@@ -361,7 +361,7 @@ def test_merge_extra_filters_ignores_equal_filters(self):
             "applied_time_extras": {},
         }
         merge_extra_filters(form_data)
-        self.assertEqual(form_data, expected)
+        assert form_data == expected

     def test_merge_extra_filters_merges_different_val_types(self):
         form_data = {
@@ -415,7 +415,7 @@ def test_merge_extra_filters_merges_different_val_types(self):
             "applied_time_extras": {},
         }
         merge_extra_filters(form_data)
-        self.assertEqual(form_data, expected)
+        assert form_data == expected
         form_data = {
             "extra_filters": [
                 {"col": "a", "op": "in", "val": "someval"},
@@ -467,7 +467,7 @@ def test_merge_extra_filters_merges_different_val_types(self):
             "applied_time_extras": {},
         }
         merge_extra_filters(form_data)
-        self.assertEqual(form_data, expected)
+        assert form_data == expected

     def test_merge_extra_filters_adds_unequal_lists(self):
         form_data = {
@@ -530,27 +530,24 @@ def test_merge_extra_filters_adds_unequal_lists(self):
             "applied_time_extras": {},
         }
         merge_extra_filters(form_data)
-        self.assertEqual(form_data, expected)
+        assert form_data == expected

     def test_merge_extra_filters_when_applied_time_extras_predefined(self):
         form_data = {"applied_time_extras": {"__time_range": "Last week"}}
         merge_extra_filters(form_data)
-        self.assertEqual(
-            form_data,
-            {
-                "applied_time_extras": {"__time_range": "Last week"},
-                "adhoc_filters": [],
-            },
-        )
+        assert form_data == {
+            "applied_time_extras": {"__time_range": "Last week"},
+            "adhoc_filters": [],
+        }

     def test_merge_request_params_when_url_params_undefined(self):
         form_data = {"since": "2000", "until": "now"}
         url_params = {"form_data": form_data, "dashboard_ids": "(1,2,3,4,5)"}
         merge_request_params(form_data, url_params)
-        self.assertIn("url_params", form_data.keys())
-        self.assertIn("dashboard_ids", form_data["url_params"])
-        self.assertNotIn("form_data", form_data.keys())
+        assert "url_params" in form_data.keys()
+        assert "dashboard_ids" in form_data["url_params"]
+        assert "form_data" not in form_data.keys()

     def test_merge_request_params_when_url_params_predefined(self):
         form_data = {
@@ -560,30 +557,26 @@ def test_merge_request_params_when_url_params_predefined(self):
         }
         url_params = {"form_data": form_data, "dashboard_ids": "(1,2,3,4,5)"}
         merge_request_params(form_data, url_params)
-        self.assertIn("url_params", form_data.keys())
-        self.assertIn("abc", form_data["url_params"])
-        self.assertEqual(
-            url_params["dashboard_ids"], form_data["url_params"]["dashboard_ids"]
-        )
+        assert "url_params" in form_data.keys()
+        assert "abc" in form_data["url_params"]
+        assert url_params["dashboard_ids"] == form_data["url_params"]["dashboard_ids"]

     def test_format_timedelta(self):
-        self.assertEqual(json.format_timedelta(timedelta(0)), "0:00:00")
-        self.assertEqual(json.format_timedelta(timedelta(days=1)), "1 day, 0:00:00")
-        self.assertEqual(json.format_timedelta(timedelta(minutes=-6)), "-0:06:00")
-        self.assertEqual(
-            json.format_timedelta(timedelta(0) - timedelta(days=1, hours=5, minutes=6)),
-            "-1 day, 5:06:00",
+        assert json.format_timedelta(timedelta(0)) == "0:00:00"
+        assert json.format_timedelta(timedelta(days=1)) == "1 day, 0:00:00"
+        assert json.format_timedelta(timedelta(minutes=-6)) == "-0:06:00"
+        assert (
+            json.format_timedelta(timedelta(0) - timedelta(days=1, hours=5, minutes=6))
+            == "-1 day, 5:06:00"
        )
-        self.assertEqual(
-            json.format_timedelta(
-                timedelta(0) - timedelta(days=16, hours=4, minutes=3)
-            ),
-            "-16 days, 4:03:00",
+        assert (
+            json.format_timedelta(timedelta(0) - timedelta(days=16, hours=4, minutes=3))
+            == "-16 days, 4:03:00"
         )

     def test_validate_json(self):
         valid = '{"a": 5, "b": [1, 5, ["g", "h"]]}'
-        self.assertIsNone(json.validate_json(valid))
+        assert json.validate_json(valid) is None
         invalid = '{"a": 5, "b": [1, 5, ["g", "h]]}'
         with self.assertRaises(json.JSONDecodeError):
             json.validate_json(invalid)
@@ -601,7 +594,7 @@ def test_convert_legacy_filters_into_adhoc_where(self):
             ]
         }
         convert_legacy_filters_into_adhoc(form_data)
-        self.assertEqual(form_data, expected)
+        assert form_data == expected

     def test_convert_legacy_filters_into_adhoc_filters(self):
         form_data = {"filters": [{"col": "a", "op": "in", "val": "someval"}]}
@@ -618,7 +611,7 @@ def test_convert_legacy_filters_into_adhoc_filters(self):
             ]
         }
         convert_legacy_filters_into_adhoc(form_data)
-        self.assertEqual(form_data, expected)
+        assert form_data == expected

     def test_convert_legacy_filters_into_adhoc_present_and_empty(self):
         form_data = {"adhoc_filters": [], "where": "a = 1"}
@@ -633,7 +626,7 @@ def test_convert_legacy_filters_into_adhoc_present_and_empty(self):
             ]
         }
         convert_legacy_filters_into_adhoc(form_data)
-        self.assertEqual(form_data, expected)
+        assert form_data == expected

     def test_convert_legacy_filters_into_adhoc_having(self):
         form_data = {"having": "COUNT(1) = 1"}
@@ -648,7 +641,7 @@ def test_convert_legacy_filters_into_adhoc_having(self):
             ]
         }
         convert_legacy_filters_into_adhoc(form_data)
-        self.assertEqual(form_data, expected)
+        assert form_data == expected

     def test_convert_legacy_filters_into_adhoc_present_and_nonempty(self):
         form_data = {
@@ -664,23 +657,23 @@ def test_convert_legacy_filters_into_adhoc_present_and_nonempty(self):
             ]
         }
         convert_legacy_filters_into_adhoc(form_data)
-        self.assertEqual(form_data, expected)
+        assert form_data == expected

     def test_parse_js_uri_path_items_eval_undefined(self):
-        self.assertIsNone(parse_js_uri_path_item("undefined", eval_undefined=True))
-        self.assertIsNone(parse_js_uri_path_item("null", eval_undefined=True))
-        self.assertEqual("undefined", parse_js_uri_path_item("undefined"))
-        self.assertEqual("null", parse_js_uri_path_item("null"))
+        assert parse_js_uri_path_item("undefined", eval_undefined=True) is None
+        assert parse_js_uri_path_item("null", eval_undefined=True) is None
+        assert "undefined" == parse_js_uri_path_item("undefined")
+        assert "null" == parse_js_uri_path_item("null")

     def test_parse_js_uri_path_items_unquote(self):
-        self.assertEqual("slashed/name", parse_js_uri_path_item("slashed%2fname"))
-        self.assertEqual(
-            "slashed%2fname", parse_js_uri_path_item("slashed%2fname", unquote=False)
+        assert "slashed/name" == parse_js_uri_path_item("slashed%2fname")
+        assert "slashed%2fname" == parse_js_uri_path_item(
+            "slashed%2fname", unquote=False
         )

     def test_parse_js_uri_path_items_item_optional(self):
-        self.assertIsNone(parse_js_uri_path_item(None))
-        self.assertIsNotNone(parse_js_uri_path_item("item"))
+        assert parse_js_uri_path_item(None) is None
+        assert parse_js_uri_path_item("item") is not None

     def test_get_stacktrace(self):
         app.config["SHOW_STACKTRACE"] = True
@@ -688,7 +681,7 @@ def test_get_stacktrace(self):
             raise Exception("NONONO!")
         except Exception:
             stacktrace = get_stacktrace()
-            self.assertIn("NONONO", stacktrace)
+            assert "NONONO" in stacktrace

         app.config["SHOW_STACKTRACE"] = False
         try:
@@ -698,31 +691,31 @@ def test_get_stacktrace(self):
             assert stacktrace is None

     def test_split(self):
-        self.assertEqual(list(split("a b")), ["a", "b"])
-        self.assertEqual(list(split("a,b", delimiter=",")), ["a", "b"])
-        self.assertEqual(list(split("a,(b,a)", delimiter=",")), ["a", "(b,a)"])
-        self.assertEqual(
-            list(split('a,(b,a),"foo , bar"', delimiter=",")),
-            ["a", "(b,a)", '"foo , bar"'],
-        )
-        self.assertEqual(
-            list(split("a,'b,c'", delimiter=",", quote="'")), ["a", "'b,c'"]
-        )
-        self.assertEqual(list(split('a "b c"')), ["a", '"b c"'])
-        self.assertEqual(list(split(r'a "b \" c"')), ["a", r'"b \" c"'])
+        assert list(split("a b")) == ["a", "b"]
+        assert list(split("a,b", delimiter=",")) == ["a", "b"]
+        assert list(split("a,(b,a)", delimiter=",")) == ["a", "(b,a)"]
+        assert list(split('a,(b,a),"foo , bar"', delimiter=",")) == [
+            "a",
+            "(b,a)",
+            '"foo , bar"',
+        ]
+        assert list(split("a,'b,c'", delimiter=",", quote="'")) == ["a", "'b,c'"]
+        assert list(split('a "b c"')) == ["a", '"b c"']
+        assert list(split('a "b \\" c"')) == ["a", '"b \\" c"']

     def test_get_or_create_db(self):
         get_or_create_db("test_db", "sqlite:///superset.db")
         database = db.session.query(Database).filter_by(database_name="test_db").one()
-        self.assertIsNotNone(database)
-        self.assertEqual(database.sqlalchemy_uri, "sqlite:///superset.db")
-        self.assertIsNotNone(
+        assert database is not None
+        assert database.sqlalchemy_uri == "sqlite:///superset.db"
+        assert (
             security_manager.find_permission_view_menu("database_access", database.perm)
+            is not None
         )
         # Test change URI
         get_or_create_db("test_db", "sqlite:///changed.db")
         database = db.session.query(Database).filter_by(database_name="test_db").one()
-        self.assertEqual(database.sqlalchemy_uri, "sqlite:///changed.db")
+        assert database.sqlalchemy_uri == "sqlite:///changed.db"
         db.session.delete(database)
         db.session.commit()
@@ -738,22 +731,16 @@ def test_get_or_create_db_existing_invalid_uri(self):
         assert database.sqlalchemy_uri == "sqlite:///superset.db"

     def test_as_list(self):
-        self.assertListEqual(as_list(123), [123])
-        self.assertListEqual(as_list([123]), [123])
-        self.assertListEqual(as_list("foo"), ["foo"])
+        self.assertListEqual(as_list(123), [123])  # noqa: PT009
+        self.assertListEqual(as_list([123]), [123])  # noqa: PT009
+        self.assertListEqual(as_list("foo"), ["foo"])  # noqa: PT009

     def test_merge_extra_filters_with_no_extras(self):
         form_data = {
             "time_range": "Last 10 days",
         }
         merge_extra_form_data(form_data)
-        self.assertEqual(
-            form_data,
-            {
-                "time_range": "Last 10 days",
-                "adhoc_filters": [],
-            },
-        )
+        assert form_data == {"time_range": "Last 10 days", "adhoc_filters": []}

     def test_merge_extra_filters_with_unset_legacy_time_range(self):
         """
@@ -767,14 +754,11 @@ def test_merge_extra_filters_with_unset_legacy_time_range(self):
             "extra_form_data": {"time_range": "Last year"},
         }
         merge_extra_filters(form_data)
-        self.assertEqual(
-            form_data,
-            {
-                "time_range": "Last year",
-                "applied_time_extras": {},
-                "adhoc_filters": [],
-            },
-        )
+        assert form_data == {
+            "time_range": "Last year",
+            "applied_time_extras": {},
+            "adhoc_filters": [],
+        }

     def test_merge_extra_filters_with_extras(self):
         form_data = {
@@ -817,41 +801,45 @@ def test_merge_extra_filters_with_extras(self):

     def test_ssl_certificate_parse(self):
         parsed_certificate = parse_ssl_cert(ssl_certificate)
-        self.assertEqual(parsed_certificate.serial_number, 12355228710836649848)
+        assert parsed_certificate.serial_number == 12355228710836649848

     def test_ssl_certificate_file_creation(self):
         path = create_ssl_cert_file(ssl_certificate)
         expected_filename = md5_sha_from_str(ssl_certificate)
-        self.assertIn(expected_filename, path)
-        self.assertTrue(os.path.exists(path))
+        assert expected_filename in path
+        assert os.path.exists(path)

     def test_get_email_address_list(self):
-        self.assertEqual(get_email_address_list("a@a"), ["a@a"])
-        self.assertEqual(get_email_address_list(" a@a "), ["a@a"])
-        self.assertEqual(get_email_address_list("a@a\n"), ["a@a"])
-        self.assertEqual(get_email_address_list(",a@a;"), ["a@a"])
-        self.assertEqual(
-            get_email_address_list(",a@a; b@b c@c a-c@c; d@d, f@f"),
-            ["a@a", "b@b", "c@c", "a-c@c", "d@d", "f@f"],
-        )
+        assert get_email_address_list("a@a") == ["a@a"]
+        assert get_email_address_list(" a@a ") == ["a@a"]
+        assert get_email_address_list("a@a\n") == ["a@a"]
+        assert get_email_address_list(",a@a;") == ["a@a"]
+        assert get_email_address_list(",a@a; b@b c@c a-c@c; d@d, f@f") == [
+            "a@a",
+            "b@b",
+            "c@c",
+            "a-c@c",
+            "d@d",
+            "f@f",
+        ]

     def test_get_form_data_default(self) -> None:
         form_data, slc = get_form_data()
-        self.assertEqual(slc, None)
+        assert slc is None

     def test_get_form_data_request_args(self) -> None:
         with app.test_request_context(
             query_string={"form_data": json.dumps({"foo": "bar"})}
         ):
             form_data, slc = get_form_data()
-        self.assertEqual(form_data, {"foo": "bar"})
-        self.assertEqual(slc, None)
+        assert form_data == {"foo": "bar"}
+        assert slc is None

     def test_get_form_data_request_form(self) -> None:
         with app.test_request_context(data={"form_data": json.dumps({"foo": "bar"})}):
             form_data, slc = get_form_data()
-        self.assertEqual(form_data, {"foo": "bar"})
-        self.assertEqual(slc, None)
+        assert form_data == {"foo": "bar"}
+        assert slc is None

     def test_get_form_data_request_form_with_queries(self) -> None:
         # the CSV export uses for requests, even when sending requests to
@@ -862,8 +850,8 @@ def test_get_form_data_request_form_with_queries(self) -> None:
             }
         ):
             form_data, slc = get_form_data()
-        self.assertEqual(form_data, {"url_params": {"foo": "bar"}})
-        self.assertEqual(slc, None)
+        assert form_data == {"url_params": {"foo": "bar"}}
+        assert slc is None

     def test_get_form_data_request_args_and_form(self) -> None:
         with app.test_request_context(
@@ -871,16 +859,16 @@ def test_get_form_data_request_args_and_form(self) -> None:
             query_string={"form_data": json.dumps({"baz": "bar"})},
         ):
             form_data, slc = get_form_data()
-        self.assertEqual(form_data, {"baz": "bar", "foo": "bar"})
-        self.assertEqual(slc, None)
+        assert form_data == {"baz": "bar", "foo": "bar"}
+        assert slc is None

     def test_get_form_data_globals(self) -> None:
         with app.test_request_context():
             g.form_data = {"foo": "bar"}
             form_data, slc = get_form_data()
             delattr(g, "form_data")
-        self.assertEqual(form_data, {"foo": "bar"})
-        self.assertEqual(slc, None)
+        assert form_data == {"foo": "bar"}
+        assert slc is None

     def test_get_form_data_corrupted_json(self) -> None:
         with app.test_request_context(
@@ -888,8 +876,8 @@ def test_get_form_data_corrupted_json(self) -> None:
             query_string={"form_data": '{"baz": "bar"'},
         ):
             form_data, slc = get_form_data()
-        self.assertEqual(form_data, {})
-        self.assertEqual(slc, None)
+        assert form_data == {}
+        assert slc is None

     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_log_this(self) -> None:
@@ -912,29 +900,29 @@ def test_log_this(self) -> None:
             .first()
         )

-        self.assertEqual(record.dashboard_id, dashboard_id)
-        self.assertEqual(json.loads(record.json)["dashboard_id"], str(dashboard_id))
-        self.assertEqual(json.loads(record.json)["form_data"]["slice_id"], slc.id)
+        assert record.dashboard_id == dashboard_id
+        assert json.loads(record.json)["dashboard_id"] == str(dashboard_id)
+        assert json.loads(record.json)["form_data"]["slice_id"] == slc.id

-        self.assertEqual(
-            json.loads(record.json)["form_data"]["viz_type"],
-            slc.viz.form_data["viz_type"],
+        assert (
+            json.loads(record.json)["form_data"]["viz_type"]
+            == slc.viz.form_data["viz_type"]
         )

     def test_schema_validate_json(self):
         valid = '{"a": 5, "b": [1, 5, ["g", "h"]]}'
-        self.assertIsNone(schema.validate_json(valid))
+        assert schema.validate_json(valid) is None
         invalid = '{"a": 5, "b": [1, 5, ["g", "h]]}'
         self.assertRaises(marshmallow.ValidationError, schema.validate_json, invalid)

     def test_schema_one_of_case_insensitive(self):
         validator = schema.OneOfCaseInsensitive(choices=[1, 2, 3, "FoO", "BAR", "baz"])
-        self.assertEqual(1, validator(1))
-        self.assertEqual(2, validator(2))
-        self.assertEqual("FoO", validator("FoO"))
-        self.assertEqual("FOO", validator("FOO"))
-        self.assertEqual("bar", validator("bar"))
-        self.assertEqual("BaZ", validator("BaZ"))
+        assert 1 == validator(1)
+        assert 2 == validator(2)
+        assert "FoO" == validator("FoO")
+        assert "FOO" == validator("FOO")
+        assert "bar" == validator("bar")
+        assert "BaZ" == validator("BaZ")
         self.assertRaises(marshmallow.ValidationError, validator, "qwerty")
         self.assertRaises(marshmallow.ValidationError, validator, 4)
diff --git a/tests/integration_tests/viz_tests.py b/tests/integration_tests/viz_tests.py
index 86f3853a47deb..872c178bfa547 100644
--- a/tests/integration_tests/viz_tests.py
+++ b/tests/integration_tests/viz_tests.py
@@ -82,8 +82,8 @@ def test_process_metrics(self):
             "SUM(SP_URB_TOTL)",
             "count",
         ]
-        self.assertEqual(test_viz.metric_labels, expect_metric_labels)
-        self.assertEqual(test_viz.all_metrics, expect_metric_labels)
+        assert test_viz.metric_labels == expect_metric_labels
+        assert test_viz.all_metrics == expect_metric_labels

     def test_get_df_returns_empty_df(self):
         form_data = {"dummy": 123}
@@ -91,8 +91,8 @@ def test_get_df_returns_empty_df(self):
         datasource = self.get_datasource_mock()
         test_viz = viz.BaseViz(datasource, form_data)
         result = test_viz.get_df(query_obj)
-        self.assertEqual(type(result), pd.DataFrame)
-        self.assertTrue(result.empty)
+        assert type(result) == pd.DataFrame
+        assert result.empty

     def test_get_df_handles_dttm_col(self):
         form_data = {"dummy": 123}
@@ -148,31 +148,31 @@ def test_cache_timeout(self):
         datasource = self.get_datasource_mock()
         datasource.cache_timeout = 0
         test_viz = viz.BaseViz(datasource, form_data={})
-        self.assertEqual(0, test_viz.cache_timeout)
+        assert 0 == test_viz.cache_timeout
datasource.cache_timeout = 156 test_viz = viz.BaseViz(datasource, form_data={}) - self.assertEqual(156, test_viz.cache_timeout) + assert 156 == test_viz.cache_timeout datasource.cache_timeout = None datasource.database.cache_timeout = 0 - self.assertEqual(0, test_viz.cache_timeout) + assert 0 == test_viz.cache_timeout datasource.database.cache_timeout = 1666 - self.assertEqual(1666, test_viz.cache_timeout) + assert 1666 == test_viz.cache_timeout datasource.database.cache_timeout = None test_viz = viz.BaseViz(datasource, form_data={}) - self.assertEqual( - app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"], - test_viz.cache_timeout, + assert ( + app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] + == test_viz.cache_timeout ) data_cache_timeout = app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] = None datasource.database.cache_timeout = None test_viz = viz.BaseViz(datasource, form_data={}) - self.assertEqual(app.config["CACHE_DEFAULT_TIMEOUT"], test_viz.cache_timeout) + assert app.config["CACHE_DEFAULT_TIMEOUT"] == test_viz.cache_timeout # restore DATA_CACHE_CONFIG timeout app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] = data_cache_timeout @@ -195,14 +195,14 @@ def test_groupby_nulls(self): ) test_viz = viz.DistributionBarViz(datasource, form_data) data = test_viz.get_data(df)[0] - self.assertEqual("votes", data["key"]) + assert "votes" == data["key"] expected_values = [ {"x": "pepperoni", "y": 5}, {"x": "cheese", "y": 3}, {"x": NULL_STRING, "y": 2}, {"x": "anchovies", "y": 1}, ] - self.assertEqual(expected_values, data["values"]) + assert expected_values == data["values"] def test_groupby_nans(self): form_data = { @@ -216,7 +216,7 @@ def test_groupby_nans(self): df = pd.DataFrame({"beds": [0, 1, nan, 2], "count": [30, 42, 3, 29]}) test_viz = viz.DistributionBarViz(datasource, form_data) data = test_viz.get_data(df)[0] - self.assertEqual("count", data["key"]) + assert "count" == data["key"] expected_values = [ {"x": "1.0", "y": 42}, {"x": "0.0", "y": 30}, @@ -224,7 +224,7 @@ def test_groupby_nans(self): {"x": NULL_STRING, "y": 3}, ] - self.assertEqual(expected_values, data["values"]) + assert expected_values == data["values"] def test_column_nulls(self): form_data = { @@ -254,7 +254,7 @@ def test_column_nulls(self): "values": [{"x": "pepperoni", "y": 5}, {"x": "cheese", "y": 3}], }, ] - self.assertEqual(expected, data) + assert expected == data def test_column_metrics_in_order(self): form_data = { @@ -292,7 +292,7 @@ def test_column_metrics_in_order(self): }, ] - self.assertEqual(expected, data) + assert expected == data def test_column_metrics_in_order_with_breakdowns(self): form_data = { @@ -342,7 +342,7 @@ def test_column_metrics_in_order_with_breakdowns(self): }, ] - self.assertEqual(expected, data) + assert expected == data class TestPairedTTest(SupersetTestCase): @@ -445,7 +445,7 @@ def test_get_data_transforms_dataframe(self): }, ], } - self.assertEqual(data, expected) + assert data == expected def test_get_data_empty_null_keys(self): form_data = {"groupby": [], "metrics": [""]} @@ -472,7 +472,7 @@ def test_get_data_empty_null_keys(self): } ], } - self.assertEqual(data, expected) + assert data == expected form_data = {"groupby": [], "metrics": [None]} with self.assertRaises(ValueError): @@ -487,10 +487,10 @@ def test_query_obj_time_series_option(self, super_query_obj): test_viz = viz.PartitionViz(datasource, form_data) super_query_obj.return_value = {} query_obj = test_viz.query_obj() - 
self.assertFalse(query_obj["is_timeseries"]) + assert not query_obj["is_timeseries"] test_viz.form_data["time_series_option"] = "agg_sum" query_obj = test_viz.query_obj() - self.assertTrue(query_obj["is_timeseries"]) + assert query_obj["is_timeseries"] def test_levels_for_computes_levels(self): raw = {} @@ -506,37 +506,37 @@ def test_levels_for_computes_levels(self): time_op = "agg_sum" test_viz = viz.PartitionViz(Mock(), {}) levels = test_viz.levels_for(time_op, groups, df) - self.assertEqual(4, len(levels)) + assert 4 == len(levels) expected = {DTTM_ALIAS: 1800, "metric1": 45, "metric2": 450, "metric3": 4500} - self.assertEqual(expected, levels[0].to_dict()) + assert expected == levels[0].to_dict() expected = { DTTM_ALIAS: {"a1": 600, "b1": 600, "c1": 600}, "metric1": {"a1": 6, "b1": 15, "c1": 24}, "metric2": {"a1": 60, "b1": 150, "c1": 240}, "metric3": {"a1": 600, "b1": 1500, "c1": 2400}, } - self.assertEqual(expected, levels[1].to_dict()) - self.assertEqual(["groupA", "groupB"], levels[2].index.names) - self.assertEqual(["groupA", "groupB", "groupC"], levels[3].index.names) + assert expected == levels[1].to_dict() + assert ["groupA", "groupB"] == levels[2].index.names + assert ["groupA", "groupB", "groupC"] == levels[3].index.names time_op = "agg_mean" levels = test_viz.levels_for(time_op, groups, df) - self.assertEqual(4, len(levels)) + assert 4 == len(levels) expected = { DTTM_ALIAS: 200.0, "metric1": 5.0, "metric2": 50.0, "metric3": 500.0, } - self.assertEqual(expected, levels[0].to_dict()) + assert expected == levels[0].to_dict() expected = { DTTM_ALIAS: {"a1": 200, "c1": 200, "b1": 200}, "metric1": {"a1": 2, "b1": 5, "c1": 8}, "metric2": {"a1": 20, "b1": 50, "c1": 80}, "metric3": {"a1": 200, "b1": 500, "c1": 800}, } - self.assertEqual(expected, levels[1].to_dict()) - self.assertEqual(["groupA", "groupB"], levels[2].index.names) - self.assertEqual(["groupA", "groupB", "groupC"], levels[3].index.names) + assert expected == levels[1].to_dict() + assert ["groupA", "groupB"] == levels[2].index.names + assert ["groupA", "groupB", "groupC"] == levels[3].index.names def test_levels_for_diff_computes_difference(self): raw = {} @@ -553,15 +553,15 @@ def test_levels_for_diff_computes_difference(self): time_op = "point_diff" levels = test_viz.levels_for_diff(time_op, groups, df) expected = {"metric1": 6, "metric2": 60, "metric3": 600} - self.assertEqual(expected, levels[0].to_dict()) + assert expected == levels[0].to_dict() expected = { "metric1": {"a1": 2, "b1": 2, "c1": 2}, "metric2": {"a1": 20, "b1": 20, "c1": 20}, "metric3": {"a1": 200, "b1": 200, "c1": 200}, } - self.assertEqual(expected, levels[1].to_dict()) - self.assertEqual(4, len(levels)) - self.assertEqual(["groupA", "groupB", "groupC"], levels[3].index.names) + assert expected == levels[1].to_dict() + assert 4 == len(levels) + assert ["groupA", "groupB", "groupC"] == levels[3].index.names def test_levels_for_time_calls_process_data_and_drops_cols(self): raw = {} @@ -581,16 +581,16 @@ def return_args(df_drop, aggregate): test_viz.process_data = Mock(side_effect=return_args) levels = test_viz.levels_for_time(groups, df) - self.assertEqual(4, len(levels)) + assert 4 == len(levels) cols = [DTTM_ALIAS, "metric1", "metric2", "metric3"] - self.assertEqual(sorted(cols), sorted(levels[0].columns.tolist())) + assert sorted(cols) == sorted(levels[0].columns.tolist()) cols += ["groupA"] - self.assertEqual(sorted(cols), sorted(levels[1].columns.tolist())) + assert sorted(cols) == sorted(levels[1].columns.tolist()) cols += ["groupB"] - 
self.assertEqual(sorted(cols), sorted(levels[2].columns.tolist())) + assert sorted(cols) == sorted(levels[2].columns.tolist()) cols += ["groupC"] - self.assertEqual(sorted(cols), sorted(levels[3].columns.tolist())) - self.assertEqual(4, len(test_viz.process_data.mock_calls)) + assert sorted(cols) == sorted(levels[3].columns.tolist()) + assert 4 == len(test_viz.process_data.mock_calls) def test_nest_values_returns_hierarchy(self): raw = {} @@ -605,12 +605,12 @@ def test_nest_values_returns_hierarchy(self): groups = ["groupA", "groupB", "groupC"] levels = test_viz.levels_for("agg_sum", groups, df) nest = test_viz.nest_values(levels) - self.assertEqual(3, len(nest)) + assert 3 == len(nest) for i in range(0, 3): - self.assertEqual("metric" + str(i + 1), nest[i]["name"]) - self.assertEqual(3, len(nest[0]["children"])) - self.assertEqual(1, len(nest[0]["children"][0]["children"])) - self.assertEqual(1, len(nest[0]["children"][0]["children"][0]["children"])) + assert "metric" + str(i + 1) == nest[i]["name"] + assert 3 == len(nest[0]["children"]) + assert 1 == len(nest[0]["children"][0]["children"]) + assert 1 == len(nest[0]["children"][0]["children"][0]["children"]) def test_nest_procs_returns_hierarchy(self): raw = {} @@ -633,15 +633,15 @@ def test_nest_procs_returns_hierarchy(self): ) procs[i] = pivot nest = test_viz.nest_procs(procs) - self.assertEqual(3, len(nest)) + assert 3 == len(nest) for i in range(0, 3): - self.assertEqual("metric" + str(i + 1), nest[i]["name"]) - self.assertEqual(None, nest[i].get("val")) - self.assertEqual(3, len(nest[0]["children"])) - self.assertEqual(3, len(nest[0]["children"][0]["children"])) - self.assertEqual(1, len(nest[0]["children"][0]["children"][0]["children"])) - self.assertEqual( - 1, len(nest[0]["children"][0]["children"][0]["children"][0]["children"]) + assert "metric" + str(i + 1) == nest[i]["name"] + assert None is nest[i].get("val") + assert 3 == len(nest[0]["children"]) + assert 3 == len(nest[0]["children"][0]["children"]) + assert 1 == len(nest[0]["children"][0]["children"][0]["children"]) + assert 1 == len( + nest[0]["children"][0]["children"][0]["children"][0]["children"] ) def test_get_data_calls_correct_method(self): @@ -662,33 +662,33 @@ def test_get_data_calls_correct_method(self): test_viz.form_data["groupby"] = ["groups"] test_viz.form_data["time_series_option"] = "not_time" test_viz.get_data(df) - self.assertEqual("agg_sum", test_viz.levels_for.mock_calls[0][1][0]) + assert "agg_sum" == test_viz.levels_for.mock_calls[0][1][0] test_viz.form_data["time_series_option"] = "agg_sum" test_viz.get_data(df) - self.assertEqual("agg_sum", test_viz.levels_for.mock_calls[1][1][0]) + assert "agg_sum" == test_viz.levels_for.mock_calls[1][1][0] test_viz.form_data["time_series_option"] = "agg_mean" test_viz.get_data(df) - self.assertEqual("agg_mean", test_viz.levels_for.mock_calls[2][1][0]) + assert "agg_mean" == test_viz.levels_for.mock_calls[2][1][0] test_viz.form_data["time_series_option"] = "point_diff" test_viz.levels_for_diff = Mock(return_value=1) test_viz.get_data(df) - self.assertEqual("point_diff", test_viz.levels_for_diff.mock_calls[0][1][0]) + assert "point_diff" == test_viz.levels_for_diff.mock_calls[0][1][0] test_viz.form_data["time_series_option"] = "point_percent" test_viz.get_data(df) - self.assertEqual("point_percent", test_viz.levels_for_diff.mock_calls[1][1][0]) + assert "point_percent" == test_viz.levels_for_diff.mock_calls[1][1][0] test_viz.form_data["time_series_option"] = "point_factor" test_viz.get_data(df) - 
@@ -724,7 +724,7 @@ def test_rose_vis_get_data(self):
                 {"time": t3, "value": 9, "key": ("c1",), "name": ("c1",)},
             ],
         }
-        self.assertEqual(expected, res)
+        assert expected == res
 
 
 class TestTimeSeriesTableViz(SupersetTestCase):
@@ -741,13 +741,13 @@ def test_get_data_metrics(self):
         test_viz = viz.TimeTableViz(datasource, form_data)
         data = test_viz.get_data(df)
         # Check method correctly transforms data
-        self.assertEqual({"count", "sum__A"}, set(data["columns"]))
+        assert {"count", "sum__A"} == set(data["columns"])
         time_format = "%Y-%m-%d %H:%M:%S"
         expected = {
             t1.strftime(time_format): {"sum__A": 15, "count": 6},
             t2.strftime(time_format): {"sum__A": 20, "count": 7},
         }
-        self.assertEqual(expected, data["records"])
+        assert expected == data["records"]
 
     def test_get_data_group_by(self):
         form_data = {"metrics": ["sum__A"], "groupby": ["groupby1"]}
@@ -762,13 +762,13 @@ def test_get_data_group_by(self):
         test_viz = viz.TimeTableViz(datasource, form_data)
         data = test_viz.get_data(df)
         # Check method correctly transforms data
-        self.assertEqual({"a1", "a2", "a3"}, set(data["columns"]))
+        assert {"a1", "a2", "a3"} == set(data["columns"])
         time_format = "%Y-%m-%d %H:%M:%S"
         expected = {
             t1.strftime(time_format): {"a1": 15, "a2": 20, "a3": 25},
             t2.strftime(time_format): {"a1": 30, "a2": 35, "a3": 40},
         }
-        self.assertEqual(expected, data["records"])
+        assert expected == data["records"]
 
     @patch("superset.viz.BaseViz.query_obj")
     def test_query_obj_throws_metrics_and_groupby(self, super_query_obj):
@@ -788,7 +788,7 @@ def test_query_obj_order_by(self):
             self.get_datasource_mock(), {"metrics": ["sum__A", "count"], "groupby": []}
         )
         query_obj = test_viz.query_obj()
-        self.assertEqual(query_obj["orderby"], [("sum__A", False)])
+        assert query_obj["orderby"] == [("sum__A", False)]
 
 
 class TestBaseDeckGLViz(SupersetTestCase):
@@ -838,7 +838,7 @@ def test_get_properties(self):
         with self.assertRaises(NotImplementedError) as context:
             test_viz_deckgl.get_properties(mock_d)
 
-        self.assertTrue("" in str(context.exception))
+        assert "" in str(context.exception)
 
     def test_process_spatial_query_obj(self):
         form_data = load_fixture("deck_path_form_data.json")
@@ -850,7 +850,7 @@ def test_process_spatial_query_obj(self):
         with self.assertRaises(ValueError) as context:
             test_viz_deckgl.process_spatial_query_obj(mock_key, mock_gb)
 
-        self.assertTrue("Bad spatial key" in str(context.exception))
+        assert "Bad spatial key" in str(context.exception)
 
         test_form_data = {
             "latlong_key": {"type": "latlong", "lonCol": "lon", "latCol": "lat"},
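
Note: the self.assertRaises(...) context managers in the hunks above are left
untouched, since PT009 covers only the assertion helpers; the exception check
and the message check therefore remain two separate steps. If this suite ever
moved to pytest-native idioms, the two steps could merge. A self-contained
sketch, where process() is a hypothetical stand-in for
process_spatial_query_obj:

    import pytest

    def process(key):
        raise ValueError(f"Bad spatial key: {key}")

    # `match` performs a regex search against str() of the raised exception,
    # combining the raises check and the message check in one step
    with pytest.raises(ValueError, match="Bad spatial key"):
        process("bogus")
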
@@ -886,14 +886,14 @@ def test_parse_coordinates(self):
         viz_instance = viz.BaseDeckGLViz(datasource, form_data)
 
         coord = viz_instance.parse_coordinates("1.23, 3.21")
-        self.assertEqual(coord, (1.23, 3.21))
+        assert coord == (1.23, 3.21)
 
         coord = viz_instance.parse_coordinates("1.23 3.21")
-        self.assertEqual(coord, (1.23, 3.21))
+        assert coord == (1.23, 3.21)
 
-        self.assertEqual(viz_instance.parse_coordinates(None), None)
+        assert viz_instance.parse_coordinates(None) is None
 
-        self.assertEqual(viz_instance.parse_coordinates(""), None)
+        assert viz_instance.parse_coordinates("") is None
 
     def test_parse_coordinates_raises(self):
         form_data = load_fixture("deck_path_form_data.json")
@@ -1001,7 +1001,7 @@ def test_timeseries_unicode_data(self):
                 "key": ("Real Madrid C.F.\U0001f1fa\U0001f1f8\U0001f1ec\U0001f1e7",),
             },
         ]
-        self.assertEqual(expected, viz_data)
+        assert expected == viz_data
 
     def test_process_data_resample(self):
         datasource = self.get_datasource_mock()
@@ -1015,15 +1015,10 @@ def test_process_data_resample(self):
             }
         )
 
-        self.assertEqual(
-            viz.NVD3TimeSeriesViz(
-                datasource,
-                {"metrics": ["y"], "resample_method": "sum", "resample_rule": "1D"},
-            )
-            .process_data(df)["y"]
-            .tolist(),
-            [1.0, 2.0, 0.0, 0.0, 5.0, 0.0, 7.0],
-        )
+        assert viz.NVD3TimeSeriesViz(
+            datasource,
+            {"metrics": ["y"], "resample_method": "sum", "resample_rule": "1D"},
+        ).process_data(df)["y"].tolist() == [1.0, 2.0, 0.0, 0.0, 5.0, 0.0, 7.0]
 
         np.testing.assert_equal(
             viz.NVD3TimeSeriesViz(
@@ -1043,48 +1038,33 @@ def test_apply_rolling(self):
             ),
             data={"y": [1.0, 2.0, 3.0, 4.0]},
         )
-        self.assertEqual(
-            viz.NVD3TimeSeriesViz(
-                datasource,
-                {
-                    "metrics": ["y"],
-                    "rolling_type": "cumsum",
-                    "rolling_periods": 0,
-                    "min_periods": 0,
-                },
-            )
-            .apply_rolling(df)["y"]
-            .tolist(),
-            [1.0, 3.0, 6.0, 10.0],
-        )
-        self.assertEqual(
-            viz.NVD3TimeSeriesViz(
-                datasource,
-                {
-                    "metrics": ["y"],
-                    "rolling_type": "sum",
-                    "rolling_periods": 2,
-                    "min_periods": 0,
-                },
-            )
-            .apply_rolling(df)["y"]
-            .tolist(),
-            [1.0, 3.0, 5.0, 7.0],
-        )
-        self.assertEqual(
-            viz.NVD3TimeSeriesViz(
-                datasource,
-                {
-                    "metrics": ["y"],
-                    "rolling_type": "mean",
-                    "rolling_periods": 10,
-                    "min_periods": 0,
-                },
-            )
-            .apply_rolling(df)["y"]
-            .tolist(),
-            [1.0, 1.5, 2.0, 2.5],
-        )
+        assert viz.NVD3TimeSeriesViz(
+            datasource,
+            {
+                "metrics": ["y"],
+                "rolling_type": "cumsum",
+                "rolling_periods": 0,
+                "min_periods": 0,
+            },
+        ).apply_rolling(df)["y"].tolist() == [1.0, 3.0, 6.0, 10.0]
+        assert viz.NVD3TimeSeriesViz(
+            datasource,
+            {
+                "metrics": ["y"],
+                "rolling_type": "sum",
+                "rolling_periods": 2,
+                "min_periods": 0,
+            },
+        ).apply_rolling(df)["y"].tolist() == [1.0, 3.0, 5.0, 7.0]
+        assert viz.NVD3TimeSeriesViz(
+            datasource,
+            {
+                "metrics": ["y"],
+                "rolling_type": "mean",
+                "rolling_periods": 10,
+                "min_periods": 0,
+            },
+        ).apply_rolling(df)["y"].tolist() == [1.0, 1.5, 2.0, 2.5]
 
     def test_apply_rolling_without_data(self):
         datasource = self.get_datasource_mock()