Skip to content

Commit

Permalink
fix: bentoml.serve to support serving new SDK services from a bento (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
frostming committed Aug 29, 2024
1 parent c2b3032 commit c253494
Show file tree
Hide file tree
Showing 12 changed files with 275 additions and 306 deletions.
38 changes: 7 additions & 31 deletions src/_bentoml_impl/server/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
import tempfile
import typing as t

import attrs
from simple_di import Provide
from simple_di import inject

from _bentoml_sdk import Service
from bentoml._internal.container import BentoMLContainer
from bentoml._internal.utils.circus import Server
from bentoml.exceptions import BentoMLConfigException

AnyService = Service[t.Any]
Expand All @@ -26,8 +26,6 @@
from circus.sockets import CircusSocket
from circus.watcher import Watcher

from bentoml._internal.utils.circus import Arbiter

from .allocator import ResourceAllocator

POSIX = os.name == "posix"
Expand Down Expand Up @@ -104,7 +102,7 @@ def create_dependency_watcher(
working_dir: str | None = None,
env: dict[str, str] | None = None,
) -> tuple[Watcher, CircusSocket, str]:
from bentoml.serve import create_watcher
from bentoml.serving import create_watcher

num_workers, worker_envs = scheduler.get_worker_env(svc)
uri, socket = _get_server_socket(svc, uds_path, port_stack, backlog)
Expand Down Expand Up @@ -178,11 +176,11 @@ def serve_http(
from bentoml._internal.utils import reserve_free_port
from bentoml._internal.utils.analytics.usage_stats import track_serve
from bentoml._internal.utils.circus import create_standalone_arbiter
from bentoml.serve import construct_ssl_args
from bentoml.serve import construct_timeouts_args
from bentoml.serve import create_watcher
from bentoml.serve import ensure_prometheus_dir
from bentoml.serve import make_reload_plugin
from bentoml.serving import construct_ssl_args
from bentoml.serving import construct_timeouts_args
from bentoml.serving import create_watcher
from bentoml.serving import ensure_prometheus_dir
from bentoml.serving import make_reload_plugin

from ..loader import import_service
from ..loader import normalize_identifier
Expand Down Expand Up @@ -354,25 +352,3 @@ def serve_http(
except Exception:
shutil.rmtree(uds_path, ignore_errors=True)
raise


@attrs.frozen
class Server:
url: str
arbiter: Arbiter = attrs.field(repr=False)

def start(self) -> None:
pass

def stop(self) -> None:
self.arbiter.stop()

@property
def running(self) -> bool:
return self.arbiter.running

def __enter__(self) -> Server:
return self

def __exit__(self, exc_type: t.Any, exc_value: t.Any, traceback: t.Any) -> None:
self.stop()
2 changes: 1 addition & 1 deletion src/_bentoml_sdk/service/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@
T = t.TypeVar("T", bound=object)

if t.TYPE_CHECKING:
from _bentoml_impl.server.serving import Server
from bentoml._internal import external_typing as ext
from bentoml._internal.service.openapi.specification import OpenAPISpecification
from bentoml._internal.utils.circus import Server

from .dependency import Dependency

Expand Down
13 changes: 3 additions & 10 deletions src/bentoml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,11 @@
from .bentos import export_bento
from .bentos import get
from .bentos import import_bento
from .bentos import list # pylint: disable=W0622
from .bentos import list
from .bentos import pull
from .bentos import push
from .bentos import serve

# server API
from .server import GrpcServer
from .server import HTTPServer

if TYPE_CHECKING:
# Framework specific modules
from _bentoml_impl.frameworks import catboost
Expand Down Expand Up @@ -91,7 +87,6 @@
from . import client # Client API
from . import batch # Batch API
from . import exceptions # BentoML exceptions
from . import server # Server API
from . import monitoring # Monitoring API
from . import cloud # Cloud API
from . import deployment # deployment API
Expand Down Expand Up @@ -212,6 +207,8 @@
_bentoml_sdk = None

def __getattr__(name: str) -> Any:
if name in ("HTTPServer", "GrpcServer"):
return getattr(server, name)
if name not in _NEW_SDK_ATTRS + _NEW_CLIENTS:
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
if _bentoml_sdk is None:
Expand Down Expand Up @@ -239,7 +236,6 @@ def __getattr__(name: str) -> Any:
"container",
"server_context",
"client",
"server",
"io",
"Tag",
"Model",
Expand All @@ -260,9 +256,6 @@ def __getattr__(name: str) -> Any:
"serve",
"Bento",
"exceptions",
# server APIs
"HTTPServer",
"GrpcServer",
# Framework specific modules
"catboost",
"detectron",
Expand Down
23 changes: 23 additions & 0 deletions src/bentoml/_internal/utils/circus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import TYPE_CHECKING
from urllib.parse import urlparse

import attrs
from circus.arbiter import Arbiter as _Arbiter

if TYPE_CHECKING:
Expand Down Expand Up @@ -98,3 +99,25 @@ def create_standalone_arbiter(
check_delay=kwargs.pop("check_delay", 10),
**kwargs,
)


@attrs.frozen
class Server:
url: str
arbiter: Arbiter = attrs.field(repr=False)

def start(self) -> None:
pass

def stop(self) -> None:
self.arbiter.stop()

@property
def running(self) -> bool:
return self.arbiter.running

def __enter__(self) -> Server:
return self

def __exit__(self, exc_type: t.Any, exc_value: t.Any, traceback: t.Any) -> None:
self.stop()
162 changes: 101 additions & 61 deletions src/bentoml/bentos.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,18 @@
from .exceptions import InvalidArgument

if t.TYPE_CHECKING:
from _bentoml_sdk import Service as NewService

from ._internal.bento import BentoStore
from ._internal.bento.build_config import CondaOptions
from ._internal.bento.build_config import DockerOptions
from ._internal.bento.build_config import ModelSpec
from ._internal.bento.build_config import PythonOptions
from ._internal.cloud import BentoCloudClient
from .server import Server
from ._internal.service import Service
from ._internal.utils.circus import Server

Servable = str | Bento | Tag | Service | NewService[t.Any]


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -432,87 +437,122 @@ def containerize(bento_tag: Tag | str, **kwargs: t.Any) -> bool:
return False


@inject
def serve(
bento: str | Tag | Bento,
bento: Servable,
server_type: str = "http",
reload: bool = False,
production: bool = False,
production: bool = True,
env: t.Literal["conda"] | None = None,
host: str | None = None,
port: int | None = None,
working_dir: str | None = None,
api_workers: int | None = Provide[BentoMLContainer.api_server_workers],
backlog: int = Provide[BentoMLContainer.api_server_config.backlog],
ssl_certfile: str | None = Provide[BentoMLContainer.ssl.certfile],
ssl_keyfile: str | None = Provide[BentoMLContainer.ssl.keyfile],
ssl_keyfile_password: str | None = Provide[BentoMLContainer.ssl.keyfile_password],
ssl_version: int | None = Provide[BentoMLContainer.ssl.version],
ssl_cert_reqs: int | None = Provide[BentoMLContainer.ssl.cert_reqs],
ssl_ca_certs: str | None = Provide[BentoMLContainer.ssl.ca_certs],
ssl_ciphers: str | None = Provide[BentoMLContainer.ssl.ciphers],
enable_reflection: bool = Provide[BentoMLContainer.grpc.reflection.enabled],
enable_channelz: bool = Provide[BentoMLContainer.grpc.channelz.enabled],
max_concurrent_streams: int | None = Provide[
BentoMLContainer.grpc.max_concurrent_streams
],
working_dir: str = ".",
api_workers: int | None = None,
backlog: int | None = None,
ssl_certfile: str | None = None,
ssl_keyfile: str | None = None,
ssl_keyfile_password: str | None = None,
ssl_version: int | None = None,
ssl_cert_reqs: int | None = None,
ssl_ca_certs: str | None = None,
ssl_ciphers: str | None = None,
enable_reflection: bool | None = None,
enable_channelz: bool | None = None,
max_concurrent_streams: int | None = None,
grpc_protocol_version: str | None = None,
) -> Server[t.Any]:
logger.warning(
"bentoml.serve and bentoml.bentos.serve are deprecated; use bentoml.Server instead."
)
blocking: bool = False,
) -> Server:
from ._internal.log import configure_logging
from ._internal.service import Service

if isinstance(bento, Bento):
bento = str(bento.tag)
elif isinstance(bento, Tag):
bento = str(bento)

configure_logging()
if server_type == "http":
from .server import HTTPServer
from _bentoml_sdk import Service as NewService

from ._internal.service import load

if not isinstance(bento, (Service, NewService)):
svc = load(bento, working_dir=working_dir)
else:
svc = bento

if isinstance(svc, Service): # < 1.2 bento
from .serving import serve_http_production

if not isinstance(bento, str):
bento, working_dir = svc.get_service_import_origin()

return serve_http_production(
bento_identifier=bento,
reload=reload,
host=host,
port=port,
development_mode=not production,
working_dir=working_dir,
api_workers=api_workers,
backlog=backlog,
ssl_certfile=ssl_certfile,
ssl_keyfile=ssl_keyfile,
ssl_keyfile_password=ssl_keyfile_password,
ssl_version=ssl_version,
ssl_cert_reqs=ssl_cert_reqs,
ssl_ca_certs=ssl_ca_certs,
ssl_ciphers=ssl_ciphers,
threaded=not blocking,
)
else: # >= 1.2 bento
from _bentoml_impl.server.serving import serve_http

if not isinstance(bento, str):
bento = svc.import_string
working_dir = svc.working_dir

svc.inject_config()
return serve_http(
bento_identifier=bento,
working_dir=working_dir,
reload=reload,
host=host,
port=port,
backlog=backlog,
development_mode=not production,
threaded=not blocking,
ssl_certfile=ssl_certfile,
ssl_keyfile=ssl_keyfile,
ssl_keyfile_password=ssl_keyfile_password,
ssl_version=ssl_version,
ssl_cert_reqs=ssl_cert_reqs,
ssl_ca_certs=ssl_ca_certs,
ssl_ciphers=ssl_ciphers,
)
elif server_type == "grpc":
from .serving import serve_grpc_production

if host is None:
host = t.cast(str, BentoMLContainer.http.host.get())
if port is None:
port = t.cast(int, BentoMLContainer.http.port.get())
if not isinstance(bento, str):
assert isinstance(bento, Service)
bento, working_dir = bento.get_service_import_origin()

res = HTTPServer(
bento=bento,
return serve_grpc_production(
bento_identifier=bento,
reload=reload,
production=production,
env=env,
host=host,
port=port,
working_dir=working_dir,
api_workers=api_workers,
backlog=backlog,
threaded=not blocking,
development_mode=not production,
ssl_certfile=ssl_certfile,
ssl_keyfile=ssl_keyfile,
ssl_keyfile_password=ssl_keyfile_password,
ssl_version=ssl_version,
ssl_cert_reqs=ssl_cert_reqs,
ssl_ca_certs=ssl_ca_certs,
ssl_ciphers=ssl_ciphers,
)
elif server_type == "grpc":
from .server import GrpcServer

if host is None:
host = t.cast(str, BentoMLContainer.grpc.host.get())
if port is None:
port = t.cast(int, BentoMLContainer.grpc.port.get())

res = GrpcServer(
bento=bento,
reload=reload,
production=production,
env=env,
host=host,
port=port,
working_dir=working_dir,
api_workers=api_workers,
backlog=backlog,
enable_reflection=enable_reflection,
enable_channelz=enable_channelz,
max_concurrent_streams=max_concurrent_streams,
grpc_protocol_version=grpc_protocol_version,
reflection=enable_reflection,
channelz=enable_channelz,
protocol_version=grpc_protocol_version,
)
else:
raise BadInput(f"Unknown server type: '{server_type}'")

res.start()
return res
Loading

0 comments on commit c253494

Please sign in to comment.