Skip to content

Commit

Permalink
[PYTHON][RPC] Make rpc proxy jupyter friendly via PopenWorker.
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Mar 27, 2021
1 parent 14f829a commit 03e3b7c
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 51 deletions.
10 changes: 9 additions & 1 deletion python/tvm/contrib/popen_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,10 @@ def kill(self):
except IOError:
pass
# kill all child processes recurisvely
kill_child_processes(self._proc.pid)
try:
kill_child_processes(self._proc.pid)
except TypeError:
pass
try:
self._proc.kill()
except OSError:
Expand Down Expand Up @@ -149,6 +152,11 @@ def _start(self):
self._reader = os.fdopen(main_read, "rb")
self._writer = os.fdopen(main_write, "wb")

def join(self):
"""Join the current process worker before it terminates"""
if self._proc:
self._proc.wait()

def send(self, fn, args=(), kwargs=None, timeout=None):
"""Send a new function task fn(*args, **kwargs) to the subprocess.
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/exec/popen_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import threading
import traceback
import pickle
import logging
import cloudpickle

from tvm.contrib.popen_pool import StatusKind
Expand Down Expand Up @@ -49,6 +50,8 @@ def main():
reader = os.fdopen(int(sys.argv[1]), "rb")
writer = os.fdopen(int(sys.argv[2]), "wb")

logging.basicConfig(level=logging.INFO)

lock = threading.Lock()

def _respond(ret_value):
Expand Down
24 changes: 1 addition & 23 deletions python/tvm/exec/rpc_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,10 @@
# under the License.
# pylint: disable=redefined-outer-name, invalid-name
"""RPC web proxy, allows redirect to websocket based RPC servers(browsers)"""
from __future__ import absolute_import

import logging
import argparse
import multiprocessing
import sys
import os
from ..rpc.proxy import Proxy
from tvm.rpc.proxy import Proxy


def find_example_resource():
Expand Down Expand Up @@ -82,24 +78,6 @@ def main(args):
"--example-rpc", type=bool, default=False, help="Whether to switch on example rpc mode"
)
parser.add_argument("--tracker", type=str, default="", help="Report to RPC tracker")
parser.add_argument(
"--no-fork",
dest="fork",
action="store_false",
help="Use spawn mode to avoid fork. This option \
is able to avoid potential fork problems with Metal, OpenCL \
and ROCM compilers.",
)
parser.set_defaults(fork=True)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
if args.fork is False:
if sys.version_info[0] < 3:
raise RuntimeError("Python3 is required for spawn mode.")
multiprocessing.set_start_method("spawn")
else:
logging.info(
"If you are running ROCM/Metal, \
fork with cause compiler internal error. Try to launch with arg ```--no-fork```"
)
main(args)
125 changes: 98 additions & 27 deletions python/tvm/rpc/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,11 @@
the proxy server will forward the message between the client and server.
"""
# pylint: disable=unused-variable, unused-argument
from __future__ import absolute_import

import os
import asyncio
import logging
import socket
import multiprocessing
import threading
import errno
import struct
import time
Expand All @@ -43,6 +42,7 @@
"RPCProxy module requires tornado package %s. Try 'pip install tornado'." % error_msg
)

from tvm.contrib.popen_pool import PopenWorker
from . import _ffi_api
from . import base
from .base import TrackerCode
Expand Down Expand Up @@ -261,6 +261,7 @@ def __init__(
logging.info(pair)
self.app = tornado.web.Application(handlers)
self.app.listen(web_port)

self.sock = sock
self.sock.setblocking(0)
self.loop = ioloop.IOLoop.current()
Expand Down Expand Up @@ -471,6 +472,7 @@ def _proxy_server(
index_page,
resource_files,
):
asyncio.set_event_loop(asyncio.new_event_loop())
handler = ProxyServerHandler(
listen_sock,
listen_port,
Expand All @@ -484,6 +486,87 @@ def _proxy_server(
handler.run()


class PopenProxyServerState(object):
"""Internal PopenProxy State for Popen"""

current = None

def __init__(
self,
host,
port=9091,
port_end=9199,
web_port=0,
timeout_client=600,
timeout_server=600,
tracker_addr=None,
index_page=None,
resource_files=None,
):

sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
self.port = None
for my_port in range(port, port_end):
try:
sock.bind((host, my_port))
self.port = my_port
break
except socket.error as sock_err:
if sock_err.errno in [98, 48]:
continue
raise sock_err
if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
logging.info("RPCProxy: client port bind to %s:%d", host, self.port)
sock.listen(1)
self.thread = threading.Thread(
target=_proxy_server,
args=(
sock,
self.port,
web_port,
timeout_client,
timeout_server,
tracker_addr,
index_page,
resource_files,
),
)
# start the server in a different thread
# so we can return the port directly
self.thread.start()


def _popen_start_server(
host,
port=9091,
port_end=9199,
web_port=0,
timeout_client=600,
timeout_server=600,
tracker_addr=None,
index_page=None,
resource_files=None,
):
# This is a function that will be sent to the
# Popen worker to run on a separate process.
# Create and start the server in a different thread
state = PopenProxyServerState(
host,
port,
port_end,
web_port,
timeout_client,
timeout_server,
tracker_addr,
index_page,
resource_files,
)
PopenProxyServerState.current = state
# returns the port so that the main can get the port number.
return state.port


class Proxy(object):
"""Start RPC proxy server on a seperate process.
Expand Down Expand Up @@ -532,43 +615,31 @@ def __init__(
index_page=None,
resource_files=None,
):
sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
self.port = None
for my_port in range(port, port_end):
try:
sock.bind((host, my_port))
self.port = my_port
break
except socket.error as sock_err:
if sock_err.errno in [98, 48]:
continue
raise sock_err
if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
logging.info("RPCProxy: client port bind to %s:%d", host, self.port)
sock.listen(1)
self.proc = multiprocessing.Process(
target=_proxy_server,
args=(
sock,
self.port,
self.proc = PopenWorker()
# send the function
self.proc.send(
_popen_start_server,
[
host,
port,
port_end,
web_port,
timeout_client,
timeout_server,
tracker_addr,
index_page,
resource_files,
),
],
)
self.proc.start()
sock.close()
# receive the port
self.port = self.proc.recv()
self.host = host

def terminate(self):
"""Terminate the server process"""
if self.proc:
logging.info("Terminating Proxy Server...")
self.proc.terminate()
self.proc.kill()
self.proc = None

def __del__(self):
Expand Down

0 comments on commit 03e3b7c

Please sign in to comment.