From 6f1ab0d1d2d497df51ed2b3760362cd1c4afff14 Mon Sep 17 00:00:00 2001 From: Jared Nance Date: Sat, 11 Jul 2020 19:59:10 -0700 Subject: [PATCH] Fix socket reconnect issue --- aws_embedded_metrics/sinks/tcp_client.py | 54 +++++++--- tests/sinks/test_tcp_client.py | 128 +++++++++++++++++++++++ 2 files changed, 167 insertions(+), 15 deletions(-) create mode 100644 tests/sinks/test_tcp_client.py diff --git a/aws_embedded_metrics/sinks/tcp_client.py b/aws_embedded_metrics/sinks/tcp_client.py index a1e3a93..5ff737d 100644 --- a/aws_embedded_metrics/sinks/tcp_client.py +++ b/aws_embedded_metrics/sinks/tcp_client.py @@ -15,6 +15,7 @@ import logging import socket import threading +import errno from urllib.parse import ParseResult log = logging.getLogger(__name__) @@ -25,24 +26,44 @@ class TcpClient(SocketClient): def __init__(self, endpoint: ParseResult): self._endpoint = endpoint - self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self._write_lock = threading.Lock() + # using reentrant lock so that we can retry through recursion + self._write_lock = threading.RLock() + self._connect_lock = threading.RLock() self._should_connect = True def connect(self) -> "TcpClient": - try: - self._sock.connect((self._endpoint.hostname, self._endpoint.port)) - self._should_connect = False - except socket.timeout as e: - log.error("Socket timeout durring connect %s" % (e,)) - self._should_connect = True - except Exception as e: - log.error("Failed to connect to the socket. 
%s" % (e,)) - self._should_connect = True - return self - - def send_message(self, message: bytes) -> None: - if self._sock._closed or self._should_connect: # type: ignore + with self._connect_lock: + try: + self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._sock.connect((self._endpoint.hostname, self._endpoint.port)) + self._should_connect = False + except socket.timeout as e: + log.error("Socket timeout durring connect %s" % (e,)) + except OSError as e: + if e.errno == errno.EISCONN: + log.debug("Socket is already connected.") + self._should_connect = False + else: + log.error("Failed to connect to the socket. %s" % (e,)) + self._should_connect = True + except Exception as e: + log.error("Failed to connect to the socket. %s" % (e,)) + self._should_connect = True + return self + + # TODO: once #21 lands, we should increase the max retries + # the reason this is only 1 is to allow for a single + # reconnect attempt in case the agent disconnects. + # Additional retries and backoff would impose back + # pressure on the caller that may not be accounted + # for. Before we do that, we need to run the I/O + # operations on a background thread. + def send_message(self, message: bytes, retry: int = 1) -> None: + if retry < 0: + log.error("Max retries exhausted, dropping message") + return + + if self._sock is None or self._sock._closed or self._should_connect: # type: ignore + self.connect() with self._write_lock: @@ -52,9 +73,12 @@ def send_message(self, message: bytes) -> None: except socket.timeout as e: log.error("Socket timeout durring send %s" % (e,)) self.connect() + self.send_message(message, retry - 1) except socket.error as e: log.error("Failed to write metrics to the socket due to socket.error. %s" % (e,)) self.connect() + self.send_message(message, retry - 1) except Exception as e: log.error("Failed to write metrics to the socket due to exception. 
%s" % (e,)) self.connect() + self.send_message(message, retry - 1) diff --git a/tests/sinks/test_tcp_client.py b/tests/sinks/test_tcp_client.py new file mode 100644 index 0000000..f4c6b07 --- /dev/null +++ b/tests/sinks/test_tcp_client.py @@ -0,0 +1,128 @@ +from aws_embedded_metrics.sinks.tcp_client import TcpClient +from urllib.parse import urlparse +import socket +import threading +import time +import logging + +log = logging.getLogger(__name__) + +test_host = '0.0.0.0' +test_port = 9999 +endpoint = urlparse("tcp://0.0.0.0:9999") +message = "_16-Byte-String_".encode('utf-8') + + +def test_can_send_message(): + # arrange + agent = InProcessAgent().start() + client = TcpClient(endpoint) + + # act + client.connect() + client.send_message(message) + + # assert + time.sleep(1) + messages = agent.messages + assert 1 == len(messages) + assert message == messages[0] + agent.shutdown() + + +def test_can_connect_concurrently_from_threads(): + # arrange + concurrency = 10 + agent = InProcessAgent().start() + client = TcpClient(endpoint) + barrier = threading.Barrier(concurrency, timeout=5) + + def run(): + barrier.wait() + client.connect() + client.send_message(message) + + def start_thread(): + thread = threading.Thread(target=run, args=()) + thread.daemon = True + thread.start() + + # act + for _ in range(concurrency): + start_thread() + + # assert + time.sleep(1) + messages = agent.messages + assert concurrency == len(messages) + for i in range(concurrency): + assert message == messages[i] + agent.shutdown() + + +def test_can_recover_from_agent_shutdown(): + # arrange + agent = InProcessAgent().start() + client = TcpClient(endpoint) + + # act + client.connect() + client.send_message(message) + agent.shutdown() + time.sleep(5) + client.send_message(message) + agent = InProcessAgent().start() + client.send_message(message) + + # assert + time.sleep(1) + messages = agent.messages + assert 1 == len(messages) + assert message == messages[0] + agent.shutdown() + + +class 
InProcessAgent(object): + """ Agent that runs on a background thread and collects + messages in memory. + """ + + def __init__(self): + self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.sock.bind((test_host, test_port)) + self.sock.listen() + self.is_shutdown = False + self.messages = [] + + def start(self) -> "InProcessAgent": + thread = threading.Thread(target=self.run, args=()) + thread.daemon = True + thread.start() + return self + + def run(self): + while not self.is_shutdown: + connection, client_address = self.sock.accept() + self.connection = connection + + try: + while not self.is_shutdown: + log.error("recv") + data = self.connection.recv(16) + if data: + self.messages.append(data) + else: + break + finally: + log.error("Exited the recv loop") + + def shutdown(self): + try: + self.is_shutdown = True + log.error("Connection closing") + self.connection.shutdown(socket.SHUT_RDWR) + self.connection.close() + self.sock.close() + except Exception as e: + log.error("Failed to shutdown %s" % (e,))