Skip to content

Commit

Permalink
[RPC] Make tracker jupyter friendly (apache#7961)
Browse files Browse the repository at this point in the history
This PR uses the PopenWorker to handle the tracker start up
and makes the tracker jupyter friendly.
  • Loading branch information
tqchen authored and Trevor Morris committed May 6, 2021
1 parent d496adc commit 570d393
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 53 deletions.
22 changes: 19 additions & 3 deletions python/tvm/contrib/popen_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,26 @@ def _start(self):
self._reader = os.fdopen(main_read, "rb")
self._writer = os.fdopen(main_write, "wb")

def join(self):
"""Join the current process worker before it terminates"""
def join(self, timeout=None):
"""Join the current process worker before it terminates.
Parameters
----------
timeout: Optional[number]
Timeout value, block at most timeout seconds if it
is a positive number.
"""
if self._proc:
try:
self._proc.wait(timeout)
except subprocess.TimeoutExpired:
pass

def is_alive(self):
"""Check if the process is alive"""
if self._proc:
self._proc.wait()
return self._proc.poll() is None
return False

def send(self, fn, args=(), kwargs=None, timeout=None):
"""Send a new function task fn(*args, **kwargs) to the subprocess.
Expand Down
24 changes: 0 additions & 24 deletions python/tvm/exec/rpc_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,8 @@
# under the License.
# pylint: disable=redefined-outer-name, invalid-name
"""Tool to start RPC tracker"""
from __future__ import absolute_import

import logging
import argparse
import multiprocessing
import sys
from ..rpc.tracker import Tracker


Expand All @@ -38,27 +34,7 @@ def main(args):
)
parser.add_argument("--port", type=int, default=9190, help="The port of the RPC")
parser.add_argument("--port-end", type=int, default=9199, help="The end search port of the RPC")
parser.add_argument(
"--no-fork",
dest="fork",
action="store_false",
help="Use spawn mode to avoid fork. This option \
is able to avoid potential fork problems with Metal, OpenCL \
and ROCM compilers.",
)
parser.add_argument("--silent", action="store_true", help="Whether run in silent mode.")

parser.set_defaults(fork=True)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
if args.fork is False:
if sys.version_info[0] < 3:
raise RuntimeError("Python3 is required for spawn mode.")
multiprocessing.set_start_method("spawn")
else:
if not args.silent:
logging.info(
"If you are running ROCM/Metal, fork will cause "
"compiler internal error. Try to launch with arg ```--no-fork```"
)
main(args)
89 changes: 63 additions & 26 deletions python/tvm/rpc/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,15 @@
"""
# pylint: disable=invalid-name

import asyncio
import heapq
import logging
import socket
import threading
import multiprocessing
import errno
import struct
import json
from tvm.contrib.popen_pool import PopenWorker

try:
from tornado import ioloop
Expand Down Expand Up @@ -362,14 +363,55 @@ def run(self):


def _tracker_server(listen_sock, stop_key):
asyncio.set_event_loop(asyncio.new_event_loop())
handler = TrackerServerHandler(listen_sock, stop_key)
handler.run()


class PopenTrackerServerState(object):
"""Internal PopenTrackerServer State"""

current = None

def __init__(self, host, port=9190, port_end=9199, silent=False):
if silent:
logger.setLevel(logging.WARN)

sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
self.port = None
self.stop_key = base.random_key("tracker")
for my_port in range(port, port_end):
try:
sock.bind((host, my_port))
self.port = my_port
break
except socket.error as sock_err:
if sock_err.errno in [errno.EADDRINUSE]:
continue
raise sock_err
if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
logger.info("bind to %s:%d", host, self.port)
sock.listen(1)
self.thread = threading.Thread(target=_tracker_server, args=(sock, self.stop_key))
self.thread.start()
self.host = host


def _popen_start_tracker_server(host, port=9190, port_end=9199, silent=False):
# This is a function that will be sent to the
# Popen worker to run on a separate process.
# Create and start the server in a different thread
state = PopenTrackerServerState(host, port, port_end, silent)
PopenTrackerServerState.current = state
# returns the port so that the main can get the port number.
return (state.port, state.stop_key)


class Tracker(object):
"""Start RPC tracker on a separate process.
Python implementation based on multi-processing.
Python implementation based on PopenWorker.
Parameters
----------
Expand All @@ -389,28 +431,20 @@ class Tracker(object):
def __init__(self, host="0.0.0.0", port=9190, port_end=9199, silent=False):
if silent:
logger.setLevel(logging.WARN)

sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
self.port = None
self.stop_key = base.random_key("tracker")
for my_port in range(port, port_end):
try:
sock.bind((host, my_port))
self.port = my_port
break
except socket.error as sock_err:
if sock_err.errno in [errno.EADDRINUSE]:
continue
raise sock_err
if not self.port:
raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
logger.info("bind to %s:%d", host, self.port)
sock.listen(1)
self.proc = multiprocessing.Process(target=_tracker_server, args=(sock, self.stop_key))
self.proc.start()
self.proc = PopenWorker()
# send the function
self.proc.send(
_popen_start_tracker_server,
[
host,
port,
port_end,
silent,
],
)
# receive the port
self.port, self.stop_key = self.proc.recv()
self.host = host
# close the socket on this process
sock.close()

def _stop_tracker(self):
sock = socket.socket(base.get_addr_family((self.host, self.port)), socket.SOCK_STREAM)
Expand All @@ -427,11 +461,14 @@ def terminate(self):
if self.proc:
if self.proc.is_alive():
self._stop_tracker()
self.proc.join(1)
self.proc.join(0.1)
if self.proc.is_alive():
logger.info("Terminating Tracker Server...")
self.proc.terminate()
self.proc.kill()
self.proc = None

def __del__(self):
self.terminate()
try:
self.terminate()
except TypeError:
pass

0 comments on commit 570d393

Please sign in to comment.