Skip to content

Commit

Permalink
feat: Automatically end Experiment runs when Tensorboard CustomJob is…
Browse files Browse the repository at this point in the history
… complete

PiperOrigin-RevId: 683653996
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Oct 8, 2024
1 parent 2b8ae76 commit 30cf221
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 4 deletions.
26 changes: 23 additions & 3 deletions google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1633,14 +1633,34 @@ def _block_until_complete(self):

if isinstance(self, CustomJob):
# End the experiment run associated with the custom job, if exists.
experiment_run = self._gca_resource.job_spec.experiment_run
if experiment_run:
experiment_runs = []
if self._gca_resource.job_spec.experiment_run:
experiment_runs = [self._gca_resource.job_spec.experiment_run]
elif self._gca_resource.job_spec.tensorboard:
tensorboard_id = self._gca_resource.job_spec.tensorboard.split("/")[-1]
try:
tb_runs = aiplatform.TensorboardRun.list(
tensorboard_experiment_name=self.name,
tensorboard_id=tensorboard_id,
)
experiment_runs = [
f"{self.name}-{tb_run.name.split('/')[-1]}"
for tb_run in tb_runs
]
except (ValueError, api_exceptions.GoogleAPIError) as e:
_LOGGER.warning(
f"Failed to list experiment runs for tensorboard "
f"{tensorboard_id} due to: {e}"
)
for experiment_run in experiment_runs:
try:
# sync resource before end run
experiment_run_context = aiplatform.Context(experiment_run)
experiment_run_context.update(
metadata={
metadata_constants._STATE_KEY: gca_execution_compat.Execution.State.COMPLETE.name
metadata_constants._STATE_KEY: (
gca_execution_compat.Execution.State.COMPLETE.name
)
}
)
except (ValueError, api_exceptions.GoogleAPIError) as e:
Expand Down
130 changes: 129 additions & 1 deletion tests/unit/aiplatform/test_custom_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from google.cloud.aiplatform import jobs
from google.cloud.aiplatform.compat.types import (
custom_job as gca_custom_job_compat,
tensorboard_run as gca_tensorboard_run,
io,
)

Expand All @@ -55,7 +56,8 @@
_TEST_PARENT = test_constants.ProjectConstants._TEST_PARENT

_TEST_CUSTOM_JOB_NAME = f"{_TEST_PARENT}/customJobs/{_TEST_ID}"
_TEST_TENSORBOARD_NAME = f"{_TEST_PARENT}/tensorboards/{_TEST_ID}"
_TEST_TENSORBOARD_ID = "987654321"
_TEST_TENSORBOARD_NAME = f"{_TEST_PARENT}/tensorboards/{_TEST_TENSORBOARD_ID}"
_TEST_ENABLE_WEB_ACCESS = test_constants.TrainingJobConstants._TEST_ENABLE_WEB_ACCESS
_TEST_WEB_ACCESS_URIS = test_constants.TrainingJobConstants._TEST_WEB_ACCESS_URIS
_TEST_TRAINING_CONTAINER_IMAGE = (
Expand Down Expand Up @@ -162,6 +164,8 @@
_TEST_EXPERIMENT_RUN_CONTEXT_NAME = (
f"{_TEST_PARENT_METADATA}/contexts/{_TEST_EXECUTION_ID}"
)
# Resource-name string for a Tensorboard run: tensorboard _TEST_TENSORBOARD_ID,
# experiment segment _TEST_ID, run segment _TEST_RUN.
_TEST_TENSORBOARD_RUN_NAME = f"{_TEST_PARENT}/tensorboards/{_TEST_TENSORBOARD_ID}/experiments/{_TEST_ID}/runs/{_TEST_RUN}"
# "<_TEST_ID>-<_TEST_RUN>" identifier; presumably the experiment run context
# name the SDK derives for the run above -- TODO confirm against jobs.py.
_TEST_TENSORBOARD_RUN_CONTEXT_NAME = f"{_TEST_ID}-{_TEST_RUN}"

_EXPERIMENT_MOCK = GapicContext(
name=_TEST_CONTEXT_NAME,
Expand Down Expand Up @@ -207,6 +211,16 @@ def _get_custom_job_proto_with_experiments(state=None, name=None, error=None):
return custom_job_proto


def _get_custom_job_proto_with_tensorboard(state=None, name=None, error=None):
    """Build a custom job proto for tests with a Tensorboard attached.

    Works on a deep copy of ``_TEST_BASE_CUSTOM_JOB_PROTO`` so the shared
    base proto is never mutated, then fills in the worker pool specs, the
    caller-supplied fields and the Tensorboard resource name.

    Args:
        state: Value assigned to the proto's ``state`` field.
        name: Value assigned to the proto's ``name`` field.
        error: Value assigned to the proto's ``error`` field.

    Returns:
        The populated copy of the base custom job proto.
    """
    proto = copy.deepcopy(_TEST_BASE_CUSTOM_JOB_PROTO)
    proto.job_spec.worker_pool_specs = _TEST_WORKER_POOL_SPEC

    # Top-level fields supplied by the caller, assigned in the same order
    # as the other proto helpers in this module.
    for field_name, field_value in (
        ("name", name),
        ("state", state),
        ("error", error),
    ):
        setattr(proto, field_name, field_value)

    proto.job_spec.tensorboard = _TEST_TENSORBOARD_NAME
    return proto


def _get_custom_job_proto_with_enable_web_access(state=None, name=None, error=None):
custom_job_proto = _get_custom_job_proto(state=state, name=name, error=error)
custom_job_proto.job_spec.enable_web_access = _TEST_ENABLE_WEB_ACCESS
Expand Down Expand Up @@ -284,6 +298,28 @@ def get_custom_job_with_experiments_mock():
yield get_custom_job_mock


@pytest.fixture
def get_custom_job_with_tensorboard_mock():
    """Patch ``JobServiceClient.get_custom_job`` for a Tensorboard job run.

    Successive calls to the patched method return a PENDING proto, then a
    RUNNING proto, and finally a SUCCEEDED proto built with
    ``_get_custom_job_proto_with_tensorboard``.

    Yields:
        The mock that stands in for ``get_custom_job``.
    """
    job_state = gca_job_state_compat.JobState
    with patch.object(
        job_service_client.JobServiceClient, "get_custom_job"
    ) as mocked_get_custom_job:
        # Polling responses returned while the job is still in flight.
        in_flight_responses = [
            _get_custom_job_proto(name=_TEST_CUSTOM_JOB_NAME, state=polled_state)
            for polled_state in (
                job_state.JOB_STATE_PENDING,
                job_state.JOB_STATE_RUNNING,
            )
        ]
        # Terminal response: the only one built by the Tensorboard helper.
        terminal_response = _get_custom_job_proto_with_tensorboard(
            name=_TEST_CUSTOM_JOB_NAME,
            state=job_state.JOB_STATE_SUCCEEDED,
        )
        mocked_get_custom_job.side_effect = [
            *in_flight_responses,
            terminal_response,
        ]
        yield mocked_get_custom_job


@pytest.fixture
def get_custom_tpu_v5e_job_mock():
with patch.object(
Expand Down Expand Up @@ -822,6 +858,98 @@ def test_run_custom_job_with_experiment_run_warning(self, caplog):
in caplog.text
)

@pytest.mark.usefixtures(
    "get_experiment_run_not_found_mock",
    "get_tensorboard_run_artifact_not_found_mock",
)
def test_run_custom_job_with_tensorboard_cannot_list_experiment_runs(
    self,
    create_custom_job_mock_with_tensorboard,
    get_custom_job_with_tensorboard_mock,
    caplog,
):
    """Check the warning logged when experiment runs cannot be listed.

    Runs a CustomJob with a Tensorboard attached, waits for it to finish,
    and asserts that the captured log text contains the
    "Failed to list experiment runs for tensorboard" message.
    """
    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
        staging_bucket=_TEST_STAGING_BUCKET,
        encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
    )

    custom_job = aiplatform.CustomJob(
        display_name=_TEST_DISPLAY_NAME,
        worker_pool_specs=_TEST_WORKER_POOL_SPEC,
        base_output_dir=_TEST_BASE_OUTPUT_DIR,
        labels=_TEST_LABELS,
    )

    # Gather the run options first so the call itself stays short.
    run_options = dict(
        service_account=_TEST_SERVICE_ACCOUNT,
        tensorboard=_TEST_TENSORBOARD_NAME,
        network=_TEST_NETWORK,
        timeout=_TEST_TIMEOUT,
        restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
        create_request_timeout=None,
        disable_retries=_TEST_DISABLE_RETRIES,
        max_wait_duration=_TEST_MAX_WAIT_DURATION,
    )
    custom_job.run(**run_options)
    custom_job.wait()

    expected_warning = "Failed to list experiment runs for tensorboard"
    assert expected_warning in caplog.text

@pytest.mark.usefixtures(
    "get_experiment_run_not_found_mock",
    "get_tensorboard_run_artifact_not_found_mock",
)
def test_run_custom_job_with_tensorboard_cannot_end_experiment_run(
    self,
    create_custom_job_mock_with_tensorboard,
    get_custom_job_with_tensorboard_mock,
    caplog,
):
    """Check the warning logged when a listed run cannot be ended.

    ``aiplatform.TensorboardRun.list`` is patched to return a single run
    named ``_TEST_TENSORBOARD_RUN_NAME``. The test then runs a CustomJob
    with a Tensorboard attached and asserts that the captured log text
    contains "Failed to end experiment run" followed by
    ``_TEST_TENSORBOARD_RUN_CONTEXT_NAME``.
    """

    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
        staging_bucket=_TEST_STAGING_BUCKET,
        encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
    )

    job = aiplatform.CustomJob(
        display_name=_TEST_DISPLAY_NAME,
        worker_pool_specs=_TEST_WORKER_POOL_SPEC,
        base_output_dir=_TEST_BASE_OUTPUT_DIR,
        labels=_TEST_LABELS,
    )

    # NOTE(review): run(), wait() and the assertion are assumed to belong
    # inside this patch context so the patched list() is active while the
    # job completes -- confirm against the upstream file.
    with mock.patch.object(
        aiplatform.TensorboardRun, "list"
    ) as list_tensorboard_runs_mock:
        # Single Tensorboard run handed back by the patched list() call.
        tb_run = gca_tensorboard_run.TensorboardRun(
            name=_TEST_TENSORBOARD_RUN_NAME,
            display_name=_TEST_DISPLAY_NAME,
        )
        list_tensorboard_runs_mock.return_value = [tb_run]

        job.run(
            service_account=_TEST_SERVICE_ACCOUNT,
            tensorboard=_TEST_TENSORBOARD_NAME,
            network=_TEST_NETWORK,
            timeout=_TEST_TIMEOUT,
            restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
            create_request_timeout=None,
            disable_retries=_TEST_DISABLE_RETRIES,
            max_wait_duration=_TEST_MAX_WAIT_DURATION,
        )

        job.wait()

        # The warning must name the context derived from the listed run.
        assert (
            f"Failed to end experiment run {_TEST_TENSORBOARD_RUN_CONTEXT_NAME} due to:"
            in caplog.text
        )

@pytest.mark.parametrize("sync", [True, False])
def test_run_custom_job_with_fail_raises(
self, create_custom_job_mock, get_custom_job_mock_with_fail, sync
Expand Down

0 comments on commit 30cf221

Please sign in to comment.