Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add job APIs to armada client #3623

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion client/python/armada_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from datetime import timedelta
import logging
from typing import Dict, Iterator, List, Optional
from typing import Dict, Iterable, Iterator, List, Optional

from google.protobuf import empty_pb2

Expand All @@ -17,6 +17,8 @@
submit_pb2,
submit_pb2_grpc,
health_pb2,
job_pb2,
job_pb2_grpc,
)
from armada_client.event import Event
from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1
Expand Down Expand Up @@ -101,6 +103,7 @@ class ArmadaClient:
def __init__(self, channel, event_timeout: timedelta = timedelta(minutes=15)):
self.submit_stub = submit_pb2_grpc.SubmitStub(channel)
self.event_stub = event_pb2_grpc.EventStub(channel)
self.jobs_stub = job_pb2_grpc.JobsStub(channel)
self.event_timeout = event_timeout

def get_job_events_stream(
Expand Down Expand Up @@ -180,6 +183,54 @@ def submit_jobs(
response = self.submit_stub.SubmitJobs(request)
return response

def get_job_status(self, job_ids: Iterable[str]) -> job_pb2.JobStatusResponse:
"""Get the statuses of armada jobs.

Uses GetJobStatus RPC to get the statuses of jobs.

:param job_ids: The job ids to get the statuses of.
:return: A JobStatusResponse object.
"""

request = job_pb2.JobStatusRequest(
job_ids=job_ids,
)
return self.jobs_stub.GetJobStatus(request)

def get_job_details(
self, job_ids: Iterable[str], expand_job_spec: bool, expand_job_run: bool
) -> job_pb2.JobDetailsResponse:
"""Get the details of armada jobs.

Uses GetJobDetails RPC to get the details of jobs.

:param job_ids: The job ids to get the details of.
:param expand_job_spec: Whether to include the job_spec field in the response.
:param expand_job_run: Whether to include the job_run field in the response.
:return: A JobDetailsResponse object.
"""
request = job_pb2.JobDetailsRequest(
job_ids=job_ids,
expand_job_spec=expand_job_spec,
expand_job_run=expand_job_run,
)
return self.jobs_stub.GetJobDetails(request)

def get_job_run_details(
self, run_ids: Iterable[str]
) -> job_pb2.JobRunDetailsResponse:
"""Get the details of armada job runs.

Uses GetJobRunDetails RPC to get the details of job runs.

:param run_ids: The job run ids to get the details of.
:return: A JobRunDetailsResponse object.
"""
request = job_pb2.JobRunDetailsRequest(
run_ids=run_ids,
)
return self.jobs_stub.GetJobRunDetails(request)

def cancel_jobs(
self,
queue: str,
Expand Down
28 changes: 28 additions & 0 deletions client/python/tests/unit/server_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
event_pb2,
event_pb2_grpc,
health_pb2,
job_pb2,
job_pb2_grpc,
)


Expand Down Expand Up @@ -101,3 +103,29 @@ def Health(self, request, context):
return health_pb2.HealthCheckResponse(
status=health_pb2.HealthCheckResponse.SERVING
)


class JobsService(job_pb2_grpc.JobsServicer):
def GetJobStatus(self, request, context):
job_states = {}
for job_id in request.job_ids:
job_states[job_id] = submit_pb2.JobState.RUNNING

return job_pb2.JobStatusResponse(job_states=job_states)

def GetJobDetails(self, request, context):
job_details = {}
for job_id in request.job_ids:
job_details[job_id] = job_pb2.JobDetails(
job_id=job_id, job_state=submit_pb2.JobState.RUNNING
)
return job_pb2.JobDetailsResponse(job_details=job_details)

def GetJobRunDetails(self, request, context):
job_run_details = {}
for run_id in request.run_ids:
job_run_details[run_id] = job_pb2.JobRunDetails(
run_id=run_id, state=job_pb2.JobRunState.RUN_STATE_RUNNING
)

return job_pb2.JobRunDetailsResponse()
21 changes: 18 additions & 3 deletions client/python/tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,15 @@
import grpc
import pytest

from server_mock import EventService, SubmitService

from armada_client.armada import event_pb2_grpc, submit_pb2_grpc, submit_pb2, health_pb2
from server_mock import EventService, SubmitService, JobsService

from armada_client.armada import (
event_pb2_grpc,
submit_pb2_grpc,
submit_pb2,
health_pb2,
job_pb2_grpc,
)
from armada_client.client import ArmadaClient
from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1
from armada_client.k8s.io.apimachinery.pkg.api.resource import (
Expand All @@ -21,6 +27,7 @@ def server_mock():
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
submit_pb2_grpc.add_SubmitServicer_to_server(SubmitService(), server)
event_pb2_grpc.add_EventServicer_to_server(EventService(), server)
job_pb2_grpc.add_JobsServicer_to_server(JobsService(), server)
server.add_insecure_port("[::]:50051")
server.start()

Expand Down Expand Up @@ -95,6 +102,14 @@ def test_submit_job():
assert resp.job_response_items[0].job_id == "job-1"


def test_get_job_status():
test_create_queue()
test_submit_job()

resp = tester.get_job_status(job_ids=["job-1"])
assert resp.job_states["job-1"] == submit_pb2.JobState.RUNNING


def test_create_queue():
queue = tester.create_queue_request(name="test", priority_factor=1)
tester.create_queue(queue)
Expand Down
79 changes: 79 additions & 0 deletions docs/python_armada_client.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,37 @@ Health check for Event Service.



#### get_job_details(job_ids, expand_job_spec, expand_job_run)
Get the details of armada jobs.

Uses GetJobDetails RPC to get the details of jobs.


* **Parameters**


* **job_ids** (*Iterable**[**str**]*) – The job ids to get the details of.


* **expand_job_spec** (*bool*) – Whether to include the job_spec field in the response.


* **expand_job_run** (*bool*) – Whether to include the job_run field in the response.



* **Returns**

A JobDetailsResponse object.



* **Return type**

armada.job_pb2.JobDetailsResponse



#### get_job_events_stream(queue, job_set_id, from_message_id=None)
Get event stream for a job set.

Expand Down Expand Up @@ -296,6 +327,54 @@ for event in events:



#### get_job_run_details(run_ids)
Get the details of armada job runs.

Uses GetJobRunDetails RPC to get the details of job runs.


* **Parameters**

**run_ids** (*Iterable**[**str**]*) – The job run ids to get the details of.



* **Returns**

A JobRunDetailsResponse object.



* **Return type**

armada.job_pb2.JobRunDetailsResponse



#### get_job_status(job_ids)
Get the statuses of armada jobs.

Uses GetJobStatus RPC to get the statuses of jobs.


* **Parameters**

**job_ids** (*Iterable**[**str**]*) – The job ids to get the statuses of.



* **Returns**

A JobStatusResponse object.



* **Return type**

armada.job_pb2.JobStatusResponse



#### get_queue(name)
Get the queue by name.

Expand Down
Loading