From 69d21ed392a1a59a0f26171790e91e82b6a2821d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nikola=20Forr=C3=B3?= Date: Wed, 24 Jan 2024 10:04:54 +0100 Subject: [PATCH] Try to handle PendingRollbackError MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Nikola Forró --- packit_service/models.py | 260 +++++++++++++++++++-------------------- 1 file changed, 127 insertions(+), 133 deletions(-) diff --git a/packit_service/models.py b/packit_service/models.py index 5dfe2be25..4bd92f95a 100644 --- a/packit_service/models.py +++ b/packit_service/models.py @@ -12,6 +12,7 @@ from datetime import datetime, timedelta, timezone from os import getenv from typing import ( + Any, Dict, Iterable, List, @@ -47,12 +48,14 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import ( Session as SQLASession, + Query, relationship, scoped_session, sessionmaker, ) from sqlalchemy.sql.functions import count from sqlalchemy.types import ARRAY +from sqlalchemy.exc import PendingRollbackError from packit.config import JobConfigTriggerType from packit_service.constants import ALLOWLIST_CONSTANTS @@ -128,6 +131,21 @@ def sa_session_transaction() -> SQLASession: raise +def db_query(session: SQLASession, *entities: Any) -> Query: + """ + Wrapper for SQLASession.query() that tries to access the connection first + and deal with potential PendingRollbackError. + + Inspired by this answer: https://stackoverflow.com/a/69687698 + """ + try: + session.connection() + except PendingRollbackError as ex: + logger.error(f"Pending rollback error: {ex!r}") + session.rollback() + return session.query(*entities) + + def optional_time( datetime_object: Union[datetime, None], fmt: str = "%d/%m/%Y %H:%M:%S" ) -> Union[str, None]: @@ -287,10 +305,8 @@ class BuildsAndTestsConnector: project_event_model_type: ProjectEventModelType def get_project_event_models(self) -> Iterable["ProjectEventModel"]: - return ( - sa_session() - .query(ProjectEventModel) - .filter_by(type=self.project_event_model_type, event_id=self.id) + return db_query(sa_session(), ProjectEventModel).filter_by( + type=self.project_event_model_type, event_id=self.id ) def get_runs(self) -> List["PipelineModel"]: @@ -472,7 +488,7 @@ def get_or_create( ) -> "GitProjectModel": with sa_session_transaction() as session: project = ( - session.query(GitProjectModel) + db_query(session, GitProjectModel) .filter_by( namespace=namespace, repo_name=repo_name, project_url=project_url ) @@ -487,13 +503,12 @@ def get_or_create( @classmethod def get_by_id(cls, id_: int) -> Optional["GitProjectModel"]: - return sa_session().query(GitProjectModel).filter_by(id=id_).first() + return db_query(sa_session(), GitProjectModel).filter_by(id=id_).first() @classmethod def get_range(cls, first: int, last: int) -> Iterable["GitProjectModel"]: return ( - sa_session() - .query(GitProjectModel) + db_query(sa_session(), GitProjectModel) .order_by(GitProjectModel.namespace) .slice(first, last) ) @@ -504,8 +519,7 @@ def get_by_forge( ) -> Iterable["GitProjectModel"]: """Return projects of given forge""" return ( - sa_session() - .query(GitProjectModel) + db_query(sa_session(), GitProjectModel) .filter_by(instance_url=forge) .order_by(GitProjectModel.namespace) .slice(first, last) @@ -516,10 +530,8 @@ def get_by_forge_namespace( cls, forge: str, namespace: str ) -> Iterable["GitProjectModel"]: """Return projects of given forge and namespace""" - return ( - sa_session() - .query(GitProjectModel) - .filter_by(instance_url=forge, namespace=namespace) + return db_query(sa_session(), GitProjectModel).filter_by( + instance_url=forge, namespace=namespace ) @classmethod @@ -528,8 +540,7 @@ def get_project( ) -> Optional["GitProjectModel"]: """Return one project which matches said criteria""" return ( - sa_session() - .query(cls) + db_query(sa_session(), cls) .filter_by(instance_url=forge, namespace=namespace, repo_name=repo_name) .one_or_none() ) @@ -539,8 +550,7 @@ def get_project_prs( cls, first: int, last: int, forge: str, namespace: str, repo_name: str ) -> Iterable["PullRequestModel"]: return ( - sa_session() - .query(PullRequestModel) + db_query(sa_session(), PullRequestModel) .join(PullRequestModel.project) .filter( GitProjectModel.instance_url == forge, @@ -556,8 +566,7 @@ def get_project_issues( cls, forge: str, namespace: str, repo_name: str ) -> Iterable["IssueModel"]: return ( - sa_session() - .query(IssueModel) + db_query(sa_session(), IssueModel) .join(IssueModel.project) .filter( GitProjectModel.instance_url == forge, @@ -571,8 +580,7 @@ def get_project_branches( cls, forge: str, namespace: str, repo_name: str ) -> Iterable["GitBranchModel"]: return ( - sa_session() - .query(GitBranchModel) + db_query(sa_session(), GitBranchModel) .join(GitBranchModel.project) .filter( GitProjectModel.instance_url == forge, @@ -586,8 +594,7 @@ def get_project_releases( cls, forge: str, namespace: str, repo_name: str ) -> Iterable["ProjectReleaseModel"]: return ( - sa_session() - .query(ProjectReleaseModel) + db_query(sa_session(), ProjectReleaseModel) .join(ProjectReleaseModel.project) .filter( GitProjectModel.instance_url == forge, @@ -658,7 +665,7 @@ def get_project_count( """ Number of project models in the database. """ - return sa_session().query(GitProjectModel).count() + return db_query(sa_session(), GitProjectModel).count() @classmethod @ttl_cache(maxsize=_CACHE_MAXSIZE, ttl=_CACHE_TTL) @@ -667,8 +674,8 @@ def get_instance_numbers(cls) -> Dict[str, int]: Get the number of projects per each GIT instances. """ return dict( - sa_session() - .query( + db_query( + sa_session(), GitProjectModel.instance_url, func.count(GitProjectModel.instance_url), ) @@ -690,8 +697,8 @@ def get_instance_numbers_for_active_projects( for project_event_type in ProjectEventModelType: project_event_model = MODEL_FOR_PROJECT_EVENT[project_event_type] query = ( - sa_session() - .query( + db_query( + sa_session(), GitProjectModel.instance_url, GitProjectModel.project_url, ) @@ -762,8 +769,8 @@ def get_project_event_usage_numbers( """ project_event_model = MODEL_FOR_PROJECT_EVENT[project_event_type] query = ( - sa_session() - .query( + db_query( + sa_session(), GitProjectModel.project_url, count(project_event_model.id).over( partition_by=GitProjectModel.project_url @@ -876,8 +883,8 @@ def get_job_usage_numbers( }[job_result_model] query = ( - sa_session() - .query( + db_query( + sa_session(), GitProjectModel.project_url, count(job_result_model.id).over( partition_by=GitProjectModel.project_url @@ -973,7 +980,7 @@ def get_or_create( namespace=namespace, repo_name=repo_name, project_url=project_url ) pr = ( - session.query(PullRequestModel) + db_query(session, PullRequestModel) .filter_by(pr_id=pr_id, project_id=project.id) .first() ) @@ -993,14 +1000,14 @@ def get( namespace=namespace, repo_name=repo_name, project_url=project_url ) return ( - session.query(PullRequestModel) + db_query(session, PullRequestModel) .filter_by(pr_id=pr_id, project_id=project.id) .first() ) @classmethod def get_by_id(cls, id_: int) -> Optional["PullRequestModel"]: - return sa_session().query(PullRequestModel).filter_by(id=id_).first() + return db_query(sa_session(), PullRequestModel).filter_by(id=id_).first() def __repr__(self): return f"PullRequestModel(pr_id={self.pr_id}, project={self.project})" @@ -1025,7 +1032,7 @@ def get_or_create( namespace=namespace, repo_name=repo_name, project_url=project_url ) issue = ( - session.query(IssueModel) + db_query(session, IssueModel) .filter_by(issue_id=issue_id, project_id=project.id) .first() ) @@ -1038,7 +1045,7 @@ def get_or_create( @classmethod def get_by_id(cls, id_: int) -> Optional["IssueModel"]: - return sa_session().query(IssueModel).filter_by(id=id_).first() + return db_query(sa_session(), IssueModel).filter_by(id=id_).first() def __repr__(self): return f"IssueModel(id={self.issue_id}, project={self.project})" @@ -1063,7 +1070,7 @@ def get_or_create( namespace=namespace, repo_name=repo_name, project_url=project_url ) git_branch = ( - session.query(GitBranchModel) + db_query(session, GitBranchModel) .filter_by(name=branch_name, project_id=project.id) .first() ) @@ -1076,7 +1083,7 @@ def get_or_create( @classmethod def get_by_id(cls, id_: int) -> Optional["GitBranchModel"]: - return sa_session().query(GitBranchModel).filter_by(id=id_).first() + return db_query(sa_session(), GitBranchModel).filter_by(id=id_).first() def __repr__(self): return f"GitBranchModel(name={self.name}, project={self.project})" @@ -1107,7 +1114,7 @@ def get_or_create( namespace=namespace, repo_name=repo_name, project_url=project_url ) project_release = ( - session.query(ProjectReleaseModel) + db_query(session, ProjectReleaseModel) .filter_by(tag_name=tag_name, project_id=project.id) .first() ) @@ -1121,7 +1128,7 @@ def get_or_create( @classmethod def get_by_id(cls, id_: int) -> Optional["ProjectReleaseModel"]: - return sa_session().query(ProjectReleaseModel).filter_by(id=id_).first() + return db_query(sa_session(), ProjectReleaseModel).filter_by(id=id_).first() def __repr__(self): return ( @@ -1267,7 +1274,7 @@ def get_or_create( ) -> "ProjectEventModel": with sa_session_transaction() as session: project_event = ( - session.query(ProjectEventModel) + db_query(session, ProjectEventModel) .filter_by(type=type, event_id=event_id, commit_sha=commit_sha) .first() ) @@ -1281,12 +1288,11 @@ def get_or_create( @classmethod def get_by_id(cls, id_: int) -> Optional["ProjectEventModel"]: - return sa_session().query(ProjectEventModel).filter_by(id=id_).first() + return db_query(sa_session(), ProjectEventModel).filter_by(id=id_).first() def get_project_event_object(self) -> Optional[AbstractProjectObjectDbType]: return ( - sa_session() - .query(MODEL_FOR_PROJECT_EVENT[self.type]) + db_query(sa_session(), MODEL_FOR_PROJECT_EVENT[self.type]) .filter_by(id=self.event_id) .first() ) @@ -1378,7 +1384,8 @@ def __repr__(self): @classmethod def __query_merged_runs(cls): - return sa_session().query( + return db_query( + sa_session(), func.min(PipelineModel.id).label("merged_id"), PipelineModel.srpm_build_id, func.array_agg(psql_array([PipelineModel.copr_build_group_id])).label( @@ -1430,7 +1437,7 @@ def get_merged_run(cls, first_id: int) -> Optional[Iterable["PipelineModel"]]: @classmethod def get_run(cls, id_: int) -> Optional["PipelineModel"]: - return sa_session().query(PipelineModel).filter_by(id=id_).first() + return db_query(sa_session(), PipelineModel).filter_by(id=id_).first() class CoprBuildGroupModel(ProjectAndEventsConnector, GroupModel, Base): @@ -1474,7 +1481,9 @@ def create(cls, run_model: "PipelineModel") -> "CoprBuildGroupModel": @classmethod def get_by_id(cls, group_id: int) -> Optional["CoprBuildGroupModel"]: - return sa_session().query(CoprBuildGroupModel).filter_by(id=group_id).first() + return ( + db_query(sa_session(), CoprBuildGroupModel).filter_by(id=group_id).first() + ) class BuildStatus(str, enum.Enum): @@ -1585,14 +1594,12 @@ def get_srpm_build(self) -> Optional["SRPMBuildModel"]: @classmethod def get_by_id(cls, id_: int) -> Optional["CoprBuildTargetModel"]: - return sa_session().query(CoprBuildTargetModel).filter_by(id=id_).first() + return db_query(sa_session(), CoprBuildTargetModel).filter_by(id=id_).first() @classmethod def get_all(cls) -> Iterable["CoprBuildTargetModel"]: - return ( - sa_session() - .query(CoprBuildTargetModel) - .order_by(desc(CoprBuildTargetModel.id)) + return db_query(sa_session(), CoprBuildTargetModel).order_by( + desc(CoprBuildTargetModel.id) ) @classmethod @@ -1604,8 +1611,8 @@ def get_merged_chroots( https://github.com/packit/packit-service/pull/674#discussion_r439819852 """ return ( - sa_session() - .query( + db_query( + sa_session(), # We need something to order our merged builds by, # so set new_id to be min(ids of to-be-merged rows) func.min(CoprBuildTargetModel.id).label("new_id"), @@ -1635,12 +1642,12 @@ def get_all_by_build_id( if isinstance(build_id, int): # See the comment in get_by_task_id() build_id = str(build_id) - return sa_session().query(CoprBuildTargetModel).filter_by(build_id=build_id) + return db_query(sa_session(), CoprBuildTargetModel).filter_by(build_id=build_id) @classmethod def get_all_by_status(cls, status: BuildStatus) -> Iterable["CoprBuildTargetModel"]: """Returns all builds which currently have the given status.""" - return sa_session().query(CoprBuildTargetModel).filter_by(status=status) + return db_query(sa_session(), CoprBuildTargetModel).filter_by(status=status) # returns the build matching the build_id and the target @classmethod @@ -1653,7 +1660,9 @@ def get_by_build_id( # HINT: No operator matches the given name and argument type(s). # You might need to add explicit type casts. build_id = str(build_id) - query = sa_session().query(CoprBuildTargetModel).filter_by(build_id=build_id) + query = db_query(sa_session(), CoprBuildTargetModel).filter_by( + build_id=build_id + ) if target: query = query.filter_by(target=target) return query.first() @@ -1671,8 +1680,7 @@ def get_all_by( with the given commit_sha and optional target. """ query = ( - sa_session() - .query(CoprBuildTargetModel) + db_query(sa_session(), CoprBuildTargetModel) .join( CoprBuildTargetModel.group_of_targets, ) @@ -1702,8 +1710,7 @@ def get_all_by( def get_all_by_commit(cls, commit_sha: str) -> Iterable["CoprBuildTargetModel"]: """Returns all builds that match a given commit sha""" return ( - sa_session() - .query(CoprBuildTargetModel) + db_query(sa_session(), CoprBuildTargetModel) .join( CoprBuildTargetModel.group_of_targets, ) @@ -1782,7 +1789,7 @@ def __repr__(self) -> str: @classmethod def get_by_id(cls, id_: int) -> Optional["KojiBuildGroupModel"]: - return sa_session().query(KojiBuildGroupModel).filter_by(id=id_).first() + return db_query(sa_session(), KojiBuildGroupModel).filter_by(id=id_).first() @classmethod def create(cls, run_model: "PipelineModel") -> "KojiBuildGroupModel": @@ -1868,17 +1875,16 @@ def create( @classmethod def get_by_id(cls, id_: int) -> Optional["BodhiUpdateTargetModel"]: - return sa_session().query(BodhiUpdateTargetModel).filter_by(id=id_).first() + return db_query(sa_session(), BodhiUpdateTargetModel).filter_by(id=id_).first() @classmethod def get_all(cls) -> Iterable["BodhiUpdateTargetModel"]: - return sa_session().query(BodhiUpdateTargetModel) + return db_query(sa_session(), BodhiUpdateTargetModel) @classmethod def get_range(cls, first: int, last: int) -> Iterable["BodhiUpdateTargetModel"]: return ( - sa_session() - .query(BodhiUpdateTargetModel) + db_query(sa_session(), BodhiUpdateTargetModel) .order_by(desc(BodhiUpdateTargetModel.id)) .slice(first, last) ) @@ -1905,7 +1911,7 @@ def __repr__(self) -> str: @classmethod def get_by_id(cls, id_: int) -> Optional["BodhiUpdateGroupModel"]: - return sa_session().query(BodhiUpdateGroupModel).filter_by(id=id_).first() + return db_query(sa_session(), BodhiUpdateGroupModel).filter_by(id=id_).first() @classmethod def create(cls, run_model: "PipelineModel") -> "BodhiUpdateGroupModel": @@ -2015,20 +2021,18 @@ def get_srpm_build(self) -> Optional["SRPMBuildModel"]: @classmethod def get_by_id(cls, id_: int) -> Optional["KojiBuildTargetModel"]: - return sa_session().query(KojiBuildTargetModel).filter_by(id=id_).first() + return db_query(sa_session(), KojiBuildTargetModel).filter_by(id=id_).first() @classmethod def get_all(cls) -> Iterable["KojiBuildTargetModel"]: - return sa_session().query(KojiBuildTargetModel) + return db_query(sa_session(), KojiBuildTargetModel) @classmethod def get_range( cls, first: int, last: int, scratch: bool = None ) -> Iterable["KojiBuildTargetModel"]: - query = ( - sa_session() - .query(KojiBuildTargetModel) - .order_by(desc(KojiBuildTargetModel.id)) + query = db_query(sa_session(), KojiBuildTargetModel).order_by( + desc(KojiBuildTargetModel.id) ) if scratch is not None: @@ -2049,7 +2053,7 @@ def get_by_task_id( # HINT: No operator matches the given name and argument type(s). # You might need to add explicit type casts. task_id = str(task_id) - query = sa_session().query(KojiBuildTargetModel).filter_by(task_id=task_id) + query = db_query(sa_session(), KojiBuildTargetModel).filter_by(task_id=task_id) if target: query = query.filter_by(target=target) return query.first() @@ -2163,13 +2167,12 @@ def get_by_id( cls, id_: int, ) -> Optional["SRPMBuildModel"]: - return sa_session().query(SRPMBuildModel).filter_by(id=id_).first() + return db_query(sa_session(), SRPMBuildModel).filter_by(id=id_).first() @classmethod def get_range(cls, first: int, last: int) -> Iterable["SRPMBuildModel"]: return ( - sa_session() - .query(SRPMBuildModel) + db_query(sa_session(), SRPMBuildModel) .order_by(desc(SRPMBuildModel.id)) .slice(first, last) ) @@ -2181,8 +2184,7 @@ def get_by_copr_build_id( if isinstance(copr_build_id, int): copr_build_id = str(copr_build_id) return ( - sa_session() - .query(SRPMBuildModel) + db_query(sa_session(), SRPMBuildModel) .filter_by(copr_build_id=copr_build_id) .first() ) @@ -2191,13 +2193,9 @@ def get_by_copr_build_id( def get_older_than(cls, delta: timedelta) -> Iterable["SRPMBuildModel"]: """Return builds older than delta, whose logs/artifacts haven't been discarded yet.""" delta_ago = datetime.now(timezone.utc) - delta - return ( - sa_session() - .query(SRPMBuildModel) - .filter( - SRPMBuildModel.build_submitted_time < delta_ago, - SRPMBuildModel.logs.isnot(None), - ) + return db_query(sa_session(), SRPMBuildModel).filter( + SRPMBuildModel.build_submitted_time < delta_ago, + SRPMBuildModel.logs.isnot(None), ) def set_url(self, url: Optional[str]) -> None: @@ -2303,7 +2301,11 @@ def get_namespace(cls, namespace: str) -> Optional["AllowlistModel"]: Returns: Entry that represents namespace or `None` if cannot be found. """ - return sa_session().query(AllowlistModel).filter_by(namespace=namespace).first() + return ( + db_query(sa_session(), AllowlistModel) + .filter_by(namespace=namespace) + .first() + ) @classmethod def get_by_status(cls, status: str) -> Iterable["AllowlistModel"]: @@ -2316,12 +2318,12 @@ def get_by_status(cls, status: str) -> Iterable["AllowlistModel"]: Returns: List of the namespaces with set status. """ - return sa_session().query(AllowlistModel).filter_by(status=status) + return db_query(sa_session(), AllowlistModel).filter_by(status=status) @classmethod def remove_namespace(cls, namespace: str): with sa_session_transaction() as session: - namespace_entry = session.query(AllowlistModel).filter_by( + namespace_entry = db_query(session, AllowlistModel).filter_by( namespace=namespace ) if namespace_entry.one_or_none(): @@ -2329,7 +2331,7 @@ def remove_namespace(cls, namespace: str): @classmethod def get_all(cls) -> Iterable["AllowlistModel"]: - return sa_session().query(AllowlistModel) + return db_query(sa_session(), AllowlistModel) def to_dict(self) -> Dict[str, str]: return { @@ -2416,7 +2418,9 @@ def grouped_targets(self) -> List["TFTTestRunTargetModel"]: @classmethod def get_by_id(cls, group_id: int) -> Optional["TFTTestRunGroupModel"]: - return sa_session().query(TFTTestRunGroupModel).filter_by(id=group_id).first() + return ( + db_query(sa_session(), TFTTestRunGroupModel).filter_by(id=group_id).first() + ) class TFTTestRunTargetModel(GroupAndTargetModelConnector, Base): @@ -2498,8 +2502,7 @@ def create( @classmethod def get_by_pipeline_id(cls, pipeline_id: str) -> Optional["TFTTestRunTargetModel"]: return ( - sa_session() - .query(TFTTestRunTargetModel) + db_query(sa_session(), TFTTestRunTargetModel) .filter_by(pipeline_id=pipeline_id) .first() ) @@ -2510,15 +2513,13 @@ def get_all_by_status( ) -> Iterable["TFTTestRunTargetModel"]: """Returns all runs which currently have their status set to one of the requested statuses.""" - return ( - sa_session() - .query(TFTTestRunTargetModel) - .filter(TFTTestRunTargetModel.status.in_(status)) + return db_query(sa_session(), TFTTestRunTargetModel).filter( + TFTTestRunTargetModel.status.in_(status) ) @classmethod def get_by_id(cls, id: int) -> Optional["TFTTestRunTargetModel"]: - return sa_session().query(TFTTestRunTargetModel).filter_by(id=id).first() + return db_query(sa_session(), TFTTestRunTargetModel).filter_by(id=id).first() @staticmethod def get_all_by_commit_target( @@ -2529,8 +2530,7 @@ def get_all_by_commit_target( All tests with the given commit_sha and optional target. """ query = ( - sa_session() - .query(TFTTestRunTargetModel) + db_query(sa_session(), TFTTestRunTargetModel) .join( TFTTestRunTargetModel.group_of_targets, ) @@ -2552,8 +2552,7 @@ def get_all_by_commit_target( @classmethod def get_range(cls, first: int, last: int) -> Iterable["TFTTestRunTargetModel"]: return ( - sa_session() - .query(TFTTestRunTargetModel) + db_query(sa_session(), TFTTestRunTargetModel) .order_by(desc(TFTTestRunTargetModel.id)) .slice(first, last) ) @@ -2628,7 +2627,7 @@ def set_logs(self, logs: str) -> None: @classmethod def get_by_id(cls, id_: int) -> Optional["SyncReleaseTargetModel"]: - return sa_session().query(SyncReleaseTargetModel).filter_by(id=id_).first() + return db_query(sa_session(), SyncReleaseTargetModel).filter_by(id=id_).first() class SyncReleaseStatus(str, enum.Enum): @@ -2710,11 +2709,11 @@ def set_status(self, status: SyncReleaseStatus) -> None: @classmethod def get_by_id(cls, id_: int) -> Optional["SyncReleaseModel"]: - return sa_session().query(SyncReleaseModel).filter_by(id=id_).first() + return db_query(sa_session(), SyncReleaseModel).filter_by(id=id_).first() @classmethod def get_all_by_status(cls, status: str) -> Iterable["SyncReleaseModel"]: - return sa_session().query(SyncReleaseModel).filter_by(status=status) + return db_query(sa_session(), SyncReleaseModel).filter_by(status=status) @classmethod def get_range( @@ -2724,8 +2723,7 @@ def get_range( job_type: SyncReleaseJobType = SyncReleaseJobType.propose_downstream, ) -> Iterable["SyncReleaseModel"]: return ( - sa_session() - .query(SyncReleaseModel) + db_query(sa_session(), SyncReleaseModel) .order_by(desc(SyncReleaseModel.id)) .filter_by(job_type=job_type) .slice(first, last) @@ -2761,7 +2759,7 @@ def get_project( namespace=namespace, repo_name=repo_name, project_url=project_url ) return ( - session.query(ProjectAuthenticationIssueModel) + db_query(session, ProjectAuthenticationIssueModel) .filter_by(project_id=project.id) .first() ) @@ -2815,22 +2813,21 @@ def get_project(cls, repository: str) -> "GitProjectModel": @classmethod def get_by_id(cls, id: int) -> Optional["GithubInstallationModel"]: - return sa_session().query(GithubInstallationModel).filter_by(id=id).first() + return db_query(sa_session(), GithubInstallationModel).filter_by(id=id).first() @classmethod def get_by_account_login( cls, account_login: str ) -> Optional["GithubInstallationModel"]: return ( - sa_session() - .query(GithubInstallationModel) + db_query(sa_session(), GithubInstallationModel) .filter_by(account_login=account_login) .first() ) @classmethod def get_all(cls) -> Iterable["GithubInstallationModel"]: - return sa_session().query(GithubInstallationModel) + return db_query(sa_session(), GithubInstallationModel) @classmethod def create_or_update(cls, event): @@ -2915,7 +2912,7 @@ def get_or_create( project_url=dist_git_project_url, ) rel = ( - session.query(SourceGitPRDistGitPRModel) + db_query(session, SourceGitPRDistGitPRModel) .filter_by(source_git_pull_request_id=source_git_pull_request.id) .filter_by(dist_git_pull_request_id=dist_git_pull_request.id) .one_or_none() @@ -2930,8 +2927,7 @@ def get_or_create( @classmethod def get_by_id(cls, id_: int) -> Optional["SourceGitPRDistGitPRModel"]: return ( - sa_session() - .query(SourceGitPRDistGitPRModel) + db_query(sa_session(), SourceGitPRDistGitPRModel) .filter_by(id=id_) .one_or_none() ) @@ -2939,8 +2935,7 @@ def get_by_id(cls, id_: int) -> Optional["SourceGitPRDistGitPRModel"]: @classmethod def get_by_source_git_id(cls, id_: int) -> Optional["SourceGitPRDistGitPRModel"]: return ( - sa_session() - .query(SourceGitPRDistGitPRModel) + db_query(sa_session(), SourceGitPRDistGitPRModel) .filter_by(source_git_pull_request_id=id_) .one_or_none() ) @@ -2948,8 +2943,7 @@ def get_by_source_git_id(cls, id_: int) -> Optional["SourceGitPRDistGitPRModel"] @classmethod def get_by_dist_git_id(cls, id_: int) -> Optional["SourceGitPRDistGitPRModel"]: return ( - sa_session() - .query(SourceGitPRDistGitPRModel) + db_query(sa_session(), SourceGitPRDistGitPRModel) .filter_by(dist_git_pull_request_id=id_) .one_or_none() ) @@ -3022,14 +3016,12 @@ def set_build_logs_url(self, build_logs: str): @classmethod def get_by_id(cls, id_: int) -> Optional["VMImageBuildTargetModel"]: - return sa_session().query(VMImageBuildTargetModel).filter_by(id=id_).first() + return db_query(sa_session(), VMImageBuildTargetModel).filter_by(id=id_).first() @classmethod def get_all(cls) -> Iterable["VMImageBuildTargetModel"]: - return ( - sa_session() - .query(VMImageBuildTargetModel) - .order_by(desc(VMImageBuildTargetModel.id)) + return db_query(sa_session(), VMImageBuildTargetModel).order_by( + desc(VMImageBuildTargetModel.id) ) @classmethod @@ -3040,14 +3032,16 @@ def get_all_by_build_id( if isinstance(build_id, int): # See the comment in get_by_task_id() build_id = str(build_id) - return sa_session().query(VMImageBuildTargetModel).filter_by(build_id=build_id) + return db_query(sa_session(), VMImageBuildTargetModel).filter_by( + build_id=build_id + ) @classmethod def get_all_by_status( cls, status: VMImageBuildStatus ) -> Iterable["VMImageBuildTargetModel"]: """Returns all builds which currently have the given status.""" - return sa_session().query(VMImageBuildTargetModel).filter_by(status=status) + return db_query(sa_session(), VMImageBuildTargetModel).filter_by(status=status) @classmethod def get_by_build_id( @@ -3061,7 +3055,9 @@ def get_by_build_id( # HINT: No operator matches the given name and argument type(s). # You might need to add explicit type casts. build_id = str(build_id) - query = sa_session().query(VMImageBuildTargetModel).filter_by(build_id=build_id) + query = db_query(sa_session(), VMImageBuildTargetModel).filter_by( + build_id=build_id + ) if target: query = query.filter_by(target=target) return query.first() @@ -3077,8 +3073,7 @@ def get_all_by( with the given commit_sha and optional target. """ query = ( - sa_session() - .query(VMImageBuildTargetModel) + db_query(sa_session(), VMImageBuildTargetModel) .join( PipelineModel, PipelineModel.vm_image_build_id == VMImageBuildTargetModel.id, @@ -3103,8 +3098,7 @@ def get_all_by( def get_all_by_commit(cls, commit_sha: str) -> Iterable["VMImageBuildTargetModel"]: """Returns all builds that match a given commit sha""" query = ( - sa_session() - .query(VMImageBuildTargetModel) + db_query(sa_session(), VMImageBuildTargetModel) .join( PipelineModel, PipelineModel.vm_image_build_id == VMImageBuildTargetModel.id,