diff --git a/superset/security/manager.py b/superset/security/manager.py index a8ffb7d03b2f6..f44924c3721d8 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -1481,6 +1481,34 @@ def _delete_pvm_on_sqla_event( # pylint: disable=too-many-arguments view_menu_table.delete().where(view_menu_table.c.id == pvm.view_menu_id) ) + def _find_permission_on_sqla_event( + self, connection: Connection, name: str + ) -> Permission: + permission_table = self.permission_model.__table__ # pylint: disable=no-member + + permission_ = connection.execute( + permission_table.select().where(permission_table.c.name == name) + ).fetchone() + permission = Permission() + permission.metadata = None + permission.id = permission_.id + permission.name = permission_.name + return permission + + def _find_view_menu_on_sqla_event( + self, connection: Connection, name: str + ) -> ViewMenu: + view_menu_table = self.viewmenu_model.__table__ # pylint: disable=no-member + + view_menu_ = connection.execute( + view_menu_table.select().where(view_menu_table.c.name == name) + ).fetchone() + view_menu = ViewMenu() + view_menu.metadata = None + view_menu.id = view_menu_.id + view_menu.name = view_menu_.name + return view_menu + def _insert_pvm_on_sqla_event( self, mapper: Mapper, @@ -1511,20 +1539,36 @@ def _insert_pvm_on_sqla_event( permission = self.find_permission(permission_name) view_menu = self.find_view_menu(view_menu_name) if not permission: - connection.execute(permission_table.insert().values(name=permission_name)) - permission = self.find_permission(permission_name) + _ = connection.execute( + permission_table.insert().values(name=permission_name) + ) + permission = self._find_permission_on_sqla_event( + connection, permission_name + ) self.on_permission_after_insert(mapper, connection, permission) if not view_menu: - connection.execute(view_menu_table.insert().values(name=view_menu_name)) - view_menu = self.find_view_menu(view_menu_name) + _ = connection.execute(view_menu_table.insert().values(name=view_menu_name)) + view_menu = self._find_view_menu_on_sqla_event(connection, view_menu_name) self.on_view_menu_after_insert(mapper, connection, view_menu) connection.execute( permission_view_table.insert().values( permission_id=permission.id, view_menu_id=view_menu.id ) ) - permission = self.find_permission_view_menu(permission_name, view_menu_name) - self.on_permission_view_after_insert(mapper, connection, permission) + permission_view = connection.execute( + permission_view_table.select().where( + permission_view_table.c.permission_id == permission.id, + permission_view_table.c.view_menu_id == view_menu.id, + ) + ).fetchone() + permission_view_model = PermissionView() + permission_view_model.metadata = None + permission_view_model.id = permission_view.id + permission_view_model.permission_id = permission.id + permission_view_model.view_menu_id = view_menu.id + permission_view_model.permission = permission + permission_view_model.view_menu = view_menu + self.on_permission_view_after_insert(mapper, connection, permission_view_model) def on_role_after_update( self, mapper: Mapper, connection: Connection, target: Role diff --git a/tests/integration_tests/security_tests.py b/tests/integration_tests/security_tests.py index 26d7c6e772abd..580e0b59cae77 100644 --- a/tests/integration_tests/security_tests.py +++ b/tests/integration_tests/security_tests.py @@ -188,20 +188,13 @@ def test_after_insert_dataset(self): self.assertIsNotNone(pvm_schema) # assert on permission hooks - view_menu_dataset = security_manager.find_view_menu( - f"[tmp_db1].[tmp_perm_table](id:{table.id})" - ) - view_menu_schema = security_manager.find_view_menu(f"[tmp_db1].[tmp_schema]") - security_manager.on_view_menu_after_insert.assert_has_calls( - [ - call(ANY, ANY, view_menu_dataset), - call(ANY, ANY, view_menu_schema), - ] - ) + call_args = security_manager.on_permission_view_after_insert.call_args + assert call_args.args[2].id == pvm_schema.id + security_manager.on_permission_view_after_insert.assert_has_calls( [ - call(ANY, ANY, pvm_dataset), - call(ANY, ANY, pvm_schema), + call(ANY, ANY, ANY), + call(ANY, ANY, ANY), ] ) @@ -289,9 +282,11 @@ def test_after_insert_database(self): # Assert the hook is called security_manager.on_permission_view_after_insert.assert_has_calls( [ - call(ANY, ANY, tmp_db1_pvm), + call(ANY, ANY, ANY), ] ) + call_args = security_manager.on_permission_view_after_insert.call_args + assert call_args.args[2].id == tmp_db1_pvm.id session.delete(tmp_db1) session.commit()