Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vertex AI Model Handler Private Endpoint Support #27696

Merged
merged 6 commits into from
Jul 31, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,24 @@ def parse_known_args(argv):
type=str,
required=True,
help='GCP location for the Endpoint')
parser.add_argument(
'--endpoint_network',
dest='vpc_network',
type=str,
required=False,
help='GCP network the endpoint is peered to')
parser.add_argument(
'--experiment',
dest='experiment',
type=str,
required=False,
help='GCP experiment to pass to init')
help='Vertex AI experiment label to apply to queries')
parser.add_argument(
'--private',
dest='private',
type=bool,
default=False,
help="True if the Vertex AI endpoint is a private endpoint")
return parser.parse_known_args(argv)


Expand Down Expand Up @@ -130,7 +143,9 @@ def run(
endpoint_id=known_args.endpoint,
project=known_args.project,
location=known_args.location,
experiment=known_args.experiment)
experiment=known_args.experiment,
network=known_args.vpc_network,
private=known_args.private)

pipeline = test_pipeline
if not test_pipeline:
Expand Down
60 changes: 41 additions & 19 deletions sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from typing import Optional
from typing import Sequence

from google.api_core.exceptions import ClientError
from google.api_core.exceptions import ServerError
from google.api_core.exceptions import TooManyRequests
from google.cloud import aiplatform

Expand All @@ -41,20 +41,21 @@
# pylint: disable=line-too-long


def _retry_on_gcp_client_error(exception):
def _retry_on_appropriate_gcp_error(exception):
"""
Retry filter that returns True if a returned HTTP error code is 4xx. This is
used to retry remote requests that fail, most notably 429 (TooManyRequests.)
This is used for GCP-specific client errors.
Retry filter that returns True if a returned HTTP error code is 5xx or 429.
This is used to retry remote requests that fail, most notably 429
(TooManyRequests.)

Args:
exception: the returned exception encountered during the request/response
loop.

Returns:
boolean indication whether or not the exception is a GCP ClientError.
boolean indication whether or not the exception is a Server Error (5xx) or
a TooManyRequests (429) error.
"""
return isinstance(exception, ClientError)
return isinstance(exception, (TooManyRequests, ServerError))


class VertexAIModelHandlerJSON(ModelHandler[Any,
Expand All @@ -67,6 +68,7 @@ def __init__(
location: str,
experiment: Optional[str] = None,
network: Optional[str] = None,
private: bool = False,
**kwargs):
"""Implementation of the ModelHandler interface for Vertex AI.
**NOTE:** This API and its implementation are under development and
Expand All @@ -76,21 +78,33 @@ def __init__(
Vertex AI endpoint. In that way it functions more like a mid-pipeline
IO. Public Vertex AI endpoints have a maximum request size of 1.5 MB.
If you wish to make larger requests and use a private endpoint, provide
the Compute Engine network you wish to use.
the Compute Engine network you wish to use and set `private=True`

Args:
endpoint_id: the numerical ID of the Vertex AI endpoint to query
project: the GCP project name where the endpoint is deployed
location: the GCP location where the endpoint is deployed
experiment: optional. experiment label to apply to the
queries
queries. See
https://cloud.google.com/vertex-ai/docs/experiments/intro-vertex-ai-experiments
for more information.
network: optional. the full name of the Compute Engine
network the endpoint is deployed on; used for private
endpoints only.
endpoints. The network or subnetwork Dataflow pipeline
option must be set and match this network for pipeline
execution.
Ex: "projects/12345/global/networks/myVPC"
private: optional. if the deployed Vertex AI endpoint is
private, set to true. Requires a network to be provided
as well.
"""

self._env_vars = kwargs.get('env_vars', {})

if private and network is None:
raise ValueError(
"A VPC network must be provided to use a private endpoint.")

# TODO: support the full list of options for aiplatform.init()
# See https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform#google_cloud_aiplatform_init
aiplatform.init(
Expand All @@ -102,7 +116,9 @@ def __init__(
# Check for liveness here but don't try to actually store the endpoint
# in the class yet
self.endpoint_name = endpoint_id
_ = self._retrieve_endpoint(self.endpoint_name)
self.is_private = private

_ = self._retrieve_endpoint(self.endpoint_name, self.is_private)

# Configure AdaptiveThrottler and throttling metrics for client-side
# throttling behavior.
Expand All @@ -113,18 +129,27 @@ def __init__(
self.throttler = AdaptiveThrottler(
window_ms=1, bucket_ms=1, overload_ratio=2)

def _retrieve_endpoint(self, endpoint_id: str) -> aiplatform.Endpoint:
def _retrieve_endpoint(
self, endpoint_id: str, is_private: bool) -> aiplatform.Endpoint:
"""Retrieves an AI Platform endpoint and queries it for liveness/deployed
models.

Args:
endpoint_id: the numerical ID of the Vertex AI endpoint to retrieve.
is_private: a boolean indicating if the Vertex AI endpoint is a private
endpoint
Returns:
An aiplatform.Endpoint object
Raises:
ValueError: if endpoint is inactive or has no models deployed to it.
"""
endpoint = aiplatform.Endpoint(endpoint_name=endpoint_id)
if is_private:
endpoint: aiplatform.Endpoint = aiplatform.PrivateEndpoint(
endpoint_name=endpoint_id)
LOGGER.debug("Treating endpoint %s as private", endpoint_id)
else:
endpoint = aiplatform.Endpoint(endpoint_name=endpoint_id)
LOGGER.debug("Treating endpoint %s as public", endpoint_id)

try:
mod_list = endpoint.list_models()
Expand All @@ -133,7 +158,7 @@ def _retrieve_endpoint(self, endpoint_id: str) -> aiplatform.Endpoint:
"Failed to contact endpoint %s, got exception: %s", endpoint_id, e)

if len(mod_list) == 0:
raise ValueError("Endpoint %s has no models deployed to it.")
raise ValueError("Endpoint %s has no models deployed to it.", endpoint_id)

return endpoint

Expand All @@ -143,11 +168,11 @@ def load_model(self) -> aiplatform.Endpoint:
"""
# Check to make sure the endpoint is still active since pipeline
# construction time
ep = self._retrieve_endpoint(self.endpoint_name)
ep = self._retrieve_endpoint(self.endpoint_name, self.is_private)
return ep

@retry.with_exponential_backoff(
num_retries=5, retry_filter=_retry_on_gcp_client_error)
num_retries=5, retry_filter=_retry_on_appropriate_gcp_error)
def get_request(
self,
batch: Sequence[Any],
Expand All @@ -170,9 +195,6 @@ def get_request(
except TooManyRequests as e:
LOGGER.warning("request was limited by the service with code %i", e.code)
raise
except ClientError as e:
LOGGER.warning("request failed with error code %i", e.code)
raise
except Exception as e:
LOGGER.error("unexpected exception raised as part of request, got %s", e)
raise
Expand Down
18 changes: 15 additions & 3 deletions sdks/python/apache_beam/ml/inference/vertex_ai_inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
import unittest

try:
from apache_beam.ml.inference.vertex_ai_inference import _retry_on_gcp_client_error
from apache_beam.ml.inference.vertex_ai_inference import _retry_on_appropriate_gcp_error
from apache_beam.ml.inference.vertex_ai_inference import VertexAIModelHandlerJSON
from google.api_core.exceptions import TooManyRequests
except ImportError:
raise unittest.SkipTest('VertexAI dependencies are not installed')
Expand All @@ -28,11 +29,22 @@
class RetryOnClientErrorTest(unittest.TestCase):
def test_retry_on_client_error_positive(self):
e = TooManyRequests(message="fake service rate limiting")
self.assertTrue(_retry_on_gcp_client_error(e))
self.assertTrue(_retry_on_appropriate_gcp_error(e))

def test_retry_on_client_error_negative(self):
e = ValueError()
self.assertFalse(_retry_on_gcp_client_error(e))
self.assertFalse(_retry_on_appropriate_gcp_error(e))


class ModelHandlerArgConditions(unittest.TestCase):
def test_exception_on_private_without_network(self):
self.assertRaises(
ValueError,
VertexAIModelHandlerJSON,
endpoint_id="1",
project="testproject",
location="us-central1",
private=True)


if __name__ == '__main__':
Expand Down
Loading