ZK remains primary job registry #632
partial revert of d75bd14
bossie committed Jan 29, 2024
1 parent c7c2465 commit dc3a2b0
Showing 1 changed file with 75 additions and 44 deletions.
119 changes: 75 additions & 44 deletions openeogeotrellis/job_registry.py
@@ -23,6 +23,7 @@
JobRegistryInterface,
JobDict,
)
from openeo_driver.util.logging import just_log_exceptions
from openeogeotrellis import sentinel_hub
from openeogeotrellis.configparams import ConfigParams
from openeogeotrellis.testing import KazooClientMock
@@ -745,6 +746,16 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.zk_job_registry = None
self._lock.release()

def _just_log_errors(
self, name: str, job_id: Optional[str] = None, extra: Optional[dict] = None
):
"""Context manager to just log exceptions"""
if job_id:
extra = dict(extra or {}, job_id=job_id)
return just_log_exceptions(
log=self._log.warning, name=f"DoubleJobRegistry.{name}", extra=extra
)

def create_job(
self,
job_id: str,
@@ -770,15 +781,16 @@ def create_job
description=description,
)
if self.elastic_job_registry:
ejr_job_info = self.elastic_job_registry.create_job(
process=process,
user_id=user_id,
job_id=job_id,
title=title,
description=description,
api_version=api_version,
job_options=job_options,
)
with self._just_log_errors("create_job", job_id=job_id):
ejr_job_info = self.elastic_job_registry.create_job(
process=process,
user_id=user_id,
job_id=job_id,
title=title,
description=description,
api_version=api_version,
job_options=job_options,
)
if zk_job_info is None and ejr_job_info is None:
raise DoubleJobRegistryException(f"None of ZK/EJR registered {job_id=}")
return zk_job_info or ejr_job_info
@@ -790,8 +802,9 @@ def get_job(self, job_id: str, user_id: str) -> dict:
with contextlib.suppress(JobNotFoundException):
zk_job = self.zk_job_registry.get_job(job_id=job_id, user_id=user_id)
if self.elastic_job_registry:
with contextlib.suppress(JobNotFoundException):
ejr_job = self.elastic_job_registry.get_job(job_id=job_id)
with self._just_log_errors("get_job", job_id=job_id):
with contextlib.suppress(JobNotFoundException):
ejr_job = self.elastic_job_registry.get_job(job_id=job_id)

self._check_zk_ejr_job_info(job_id=job_id, zk_job_info=zk_job, ejr_job_info=ejr_job)
return zk_job or ejr_job
@@ -804,9 +817,10 @@ def get_job_metadata(self, job_id: str, user_id: str) -> BatchJobMetadata:
with contextlib.suppress(JobNotFoundException):
zk_job_info = self.zk_job_registry.get_job(job_id=job_id, user_id=user_id)
if self.elastic_job_registry:
with TimingLogger(f"self.elastic_job_registry.get_job({job_id=})", logger=_log.debug):
with contextlib.suppress(JobNotFoundException):
ejr_job_info = self.elastic_job_registry.get_job(job_id=job_id)
with self._just_log_errors("get_job_metadata", job_id=job_id):
with TimingLogger(f"self.elastic_job_registry.get_job({job_id=})", logger=_log.debug):
with contextlib.suppress(JobNotFoundException):
ejr_job_info = self.elastic_job_registry.get_job(job_id=job_id)

self._check_zk_ejr_job_info(job_id=job_id, zk_job_info=zk_job_info, ejr_job_info=ejr_job_info)
job_metadata = zk_job_info_to_metadata(zk_job_info) if zk_job_info else ejr_job_info_to_metadata(ejr_job_info)
@@ -830,13 +844,15 @@ def set_status(self, job_id: str, user_id: str, status: str,
self.zk_job_registry.set_status(job_id=job_id, user_id=user_id, status=status, started=started,
finished=finished)
if self.elastic_job_registry:
self.elastic_job_registry.set_status(job_id=job_id, status=status, started=started, finished=finished)
with self._just_log_errors("set_status", job_id=job_id):
self.elastic_job_registry.set_status(job_id=job_id, status=status, started=started, finished=finished)

def delete_job(self, job_id: str, user_id: str) -> None:
if self.zk_job_registry:
self.zk_job_registry.delete(job_id=job_id, user_id=user_id)
if self.elastic_job_registry:
self.elastic_job_registry.delete_job(job_id=job_id)
with self._just_log_errors("delete", job_id=job_id):
self.elastic_job_registry.delete_job(job_id=job_id)

# Legacy alias
delete = delete_job
@@ -847,15 +863,17 @@ def set_dependencies(
if self.zk_job_registry:
self.zk_job_registry.set_dependencies(job_id=job_id, user_id=user_id, dependencies=dependencies)
if self.elastic_job_registry:
self.elastic_job_registry.set_dependencies(
job_id=job_id, dependencies=dependencies
)
with self._just_log_errors("set_dependencies", job_id=job_id):
self.elastic_job_registry.set_dependencies(
job_id=job_id, dependencies=dependencies
)

def remove_dependencies(self, job_id: str, user_id: str):
if self.zk_job_registry:
self.zk_job_registry.remove_dependencies(job_id=job_id, user_id=user_id)
if self.elastic_job_registry:
self.elastic_job_registry.remove_dependencies(job_id=job_id)
with self._just_log_errors("remove_dependencies", job_id=job_id):
self.elastic_job_registry.remove_dependencies(job_id=job_id)

def set_dependency_status(
self, job_id: str, user_id: str, dependency_status: str
@@ -865,38 +883,42 @@ def set_dependency_status(
job_id=job_id, user_id=user_id, dependency_status=dependency_status
)
if self.elastic_job_registry:
self.elastic_job_registry.set_dependency_status(
job_id=job_id, dependency_status=dependency_status
)
with self._just_log_errors("set_dependency_status", job_id=job_id):
self.elastic_job_registry.set_dependency_status(
job_id=job_id, dependency_status=dependency_status
)

def set_dependency_usage(
self, job_id: str, user_id: str, dependency_usage: Decimal
):
if self.zk_job_registry:
self.zk_job_registry.set_dependency_usage(job_id=job_id, user_id=user_id, processing_units=dependency_usage)
if self.elastic_job_registry:
self.elastic_job_registry.set_dependency_usage(
job_id=job_id, dependency_usage=dependency_usage
)
with self._just_log_errors("set_dependency_usage", job_id=job_id):
self.elastic_job_registry.set_dependency_usage(
job_id=job_id, dependency_usage=dependency_usage
)

def set_proxy_user(self, job_id: str, user_id: str, proxy_user: str):
# TODO: add dedicated method
if self.zk_job_registry:
self.zk_job_registry.patch(job_id=job_id, user_id=user_id, proxy_user=proxy_user)
if self.elastic_job_registry:
self.elastic_job_registry.set_proxy_user(
job_id=job_id, proxy_user=proxy_user
)
with self._just_log_errors("set_proxy_user", job_id=job_id):
self.elastic_job_registry.set_proxy_user(
job_id=job_id, proxy_user=proxy_user
)

def set_application_id(
self, job_id: str, user_id: str, application_id: str
) -> None:
if self.zk_job_registry:
self.zk_job_registry.set_application_id(job_id=job_id, user_id=user_id, application_id=application_id)
if self.elastic_job_registry:
self.elastic_job_registry.set_application_id(
job_id=job_id, application_id=application_id
)
with self._just_log_errors("set_application_id", job_id=job_id):
self.elastic_job_registry.set_application_id(
job_id=job_id, application_id=application_id
)

def mark_ongoing(self, job_id: str, user_id: str) -> None:
if self.zk_job_registry:
@@ -910,10 +932,11 @@ def get_user_jobs(
if self.zk_job_registry:
zk_jobs = [zk_job_info_to_metadata(j) for j in self.zk_job_registry.get_user_jobs(user_id)]
if self.elastic_job_registry:
ejr_jobs = [
ejr_job_info_to_metadata(j)
for j in self.elastic_job_registry.list_user_jobs(user_id=user_id, fields=fields)
]
with self._just_log_errors("get_user_jobs"):
ejr_jobs = [
ejr_job_info_to_metadata(j)
for j in self.elastic_job_registry.list_user_jobs(user_id=user_id, fields=fields)
]

# TODO: more insightful comparison? (e.g. only consider recent jobs)
self._log.log(
@@ -946,13 +969,20 @@ def get_all_jobs_before(
return jobs

def get_active_jobs(self) -> Iterator[Dict]:
zk_jobs = None
ejr_jobs = None

if self.zk_job_registry:
yield from self.zk_job_registry.get_running_jobs(parse_specification=True)
elif self.elastic_job_registry:
yield from self.elastic_job_registry.list_trackable_jobs(fields=[
"job_id", "user_id", "application_id", "status", "created", "title", "job_options", "dependencies",
"dependency_usage",
])
zk_jobs = [j for j in self.zk_job_registry.get_running_jobs(parse_specification=True)]

if self.elastic_job_registry:
with self._just_log_errors("get_active_jobs"):
ejr_jobs = self.elastic_job_registry.list_trackable_jobs(fields=[
"job_id", "user_id", "application_id", "status", "created", "title", "job_options", "dependencies",
"dependency_usage",
])

yield from (zk_jobs or ejr_jobs or [])

def set_results_metadata(self, job_id, user_id, costs: Optional[float], usage: dict,
results_metadata: Dict[str, Any]):
Expand All @@ -961,5 +991,6 @@ def set_results_metadata(self, job_id, user_id, costs: Optional[float], usage: d
**dict(results_metadata, costs=costs, usage=usage))

if self.elastic_job_registry:
self.elastic_job_registry.set_results_metadata(job_id=job_id, costs=costs, usage=usage,
results_metadata=results_metadata)
with self._just_log_errors("set_results_metadata"):
self.elastic_job_registry.set_results_metadata(job_id=job_id, costs=costs, usage=usage,
results_metadata=results_metadata)
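
Note on the restored pattern: with this change, ZooKeeper stays the authoritative job registry and every Elastic Job Registry (EJR) call is wrapped in _just_log_errors, so an EJR failure only produces a warning instead of failing the request. The snippet below is a minimal, self-contained sketch of how a just_log_exceptions-style context manager can work; it is an assumption for illustration, not the actual implementation shipped in openeo_driver.util.logging, and the broken_ejr_client and job id in the usage example are hypothetical.

import contextlib
import logging
from typing import Callable, Optional

_log = logging.getLogger(__name__)


@contextlib.contextmanager
def just_log_exceptions(log: Callable[..., None] = _log.warning, name: str = "untitled", extra: Optional[dict] = None):
    """Sketch: run the body, and if it raises, log the exception instead of propagating it."""
    try:
        yield
    except Exception as e:
        log("In context '%s': caught %r" % (name, e), extra=extra)


# Hypothetical usage mirroring the diff above: the EJR write is best-effort,
# so the surrounding (ZK-first) code path keeps running even if it fails.
broken_ejr_client = None  # stand-in for an ElasticJobRegistry client that is unavailable
with just_log_exceptions(name="DoubleJobRegistry.set_status", extra={"job_id": "j-20240129-demo"}):
    broken_ejr_client.set_status(job_id="j-20240129-demo", status="running")  # AttributeError is caught and logged
print("still running: the failure above was swallowed and logged")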
