Skip to content

Commit

Permalink
Fix socket reconnect issue
Browse files Browse the repository at this point in the history
  • Loading branch information
jaredcnance committed Jul 12, 2020
1 parent ac27573 commit 6f1ab0d
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 15 deletions.
54 changes: 39 additions & 15 deletions aws_embedded_metrics/sinks/tcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import logging
import socket
import threading
import errno
from urllib.parse import ParseResult

log = logging.getLogger(__name__)
Expand All @@ -25,24 +26,44 @@
class TcpClient(SocketClient):
def __init__(self, endpoint: ParseResult):
self._endpoint = endpoint
self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._write_lock = threading.Lock()
# using reentrant lock so that we can retry through recursion
self._write_lock = threading.RLock()
self._connect_lock = threading.RLock()
self._should_connect = True

def connect(self) -> "TcpClient":
try:
self._sock.connect((self._endpoint.hostname, self._endpoint.port))
self._should_connect = False
except socket.timeout as e:
log.error("Socket timeout durring connect %s" % (e,))
self._should_connect = True
except Exception as e:
log.error("Failed to connect to the socket. %s" % (e,))
self._should_connect = True
return self

def send_message(self, message: bytes) -> None:
if self._sock._closed or self._should_connect: # type: ignore
with self._connect_lock:
try:
self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._sock.connect((self._endpoint.hostname, self._endpoint.port))
self._should_connect = False
except socket.timeout as e:
log.error("Socket timeout durring connect %s" % (e,))
except OSError as e:
if e.errno == errno.EISCONN:
log.debug("Socket is already connected.")
self._should_connect = False
else:
log.error("Failed to connect to the socket. %s" % (e,))
self._should_connect = True
except Exception as e:
log.error("Failed to connect to the socket. %s" % (e,))
self._should_connect = True
return self

# TODO: once #21 lands, we should increase the max retries
# the reason this is only 1 is to allow for a single
# reconnect attempt in case the agent disconnects
# additional retries and backoff would impose back
# pressure on the caller that may not be accounted
# for. Before we do that, we need to run the I/O
# operations on a background thread.s
def send_message(self, message: bytes, retry: int = 1) -> None:
if retry < 0:
log.error("Max retries exhausted, dropping message")
return

if self._sock is None or self._sock._closed or self._should_connect: # type: ignore
self.connect()

with self._write_lock:
Expand All @@ -52,9 +73,12 @@ def send_message(self, message: bytes) -> None:
except socket.timeout as e:
log.error("Socket timeout durring send %s" % (e,))
self.connect()
self.send_message(message, retry - 1)
except socket.error as e:
log.error("Failed to write metrics to the socket due to socket.error. %s" % (e,))
self.connect()
self.send_message(message, retry - 1)
except Exception as e:
log.error("Failed to write metrics to the socket due to exception. %s" % (e,))
self.connect()
self.send_message(message, retry - 1)
128 changes: 128 additions & 0 deletions tests/sinks/test_tcp_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from aws_embedded_metrics.sinks.tcp_client import TcpClient
from urllib.parse import urlparse
import socket
import threading
import time
import logging

log = logging.getLogger(__name__)

test_host = '0.0.0.0'
test_port = 9999
endpoint = urlparse("tcp://0.0.0.0:9999")
message = "_16-Byte-String_".encode('utf-8')


def test_can_send_message():
# arrange
agent = InProcessAgent().start()
client = TcpClient(endpoint)

# act
client.connect()
client.send_message(message)

# assert
time.sleep(1)
messages = agent.messages
assert 1 == len(messages)
assert message == messages[0]
agent.shutdown()


def test_can_connect_concurrently_from_threads():
# arrange
concurrency = 10
agent = InProcessAgent().start()
client = TcpClient(endpoint)
barrier = threading.Barrier(concurrency, timeout=5)

def run():
barrier.wait()
client.connect()
client.send_message(message)

def start_thread():
thread = threading.Thread(target=run, args=())
thread.daemon = True
thread.start()

# act
for _ in range(concurrency):
start_thread()

# assert
time.sleep(1)
messages = agent.messages
assert concurrency == len(messages)
for i in range(concurrency):
assert message == messages[i]
agent.shutdown()


def test_can_recover_from_agent_shutdown():
# arrange
agent = InProcessAgent().start()
client = TcpClient(endpoint)

# act
client.connect()
client.send_message(message)
agent.shutdown()
time.sleep(5)
client.send_message(message)
agent = InProcessAgent().start()
client.send_message(message)

# assert
time.sleep(1)
messages = agent.messages
assert 1 == len(messages)
assert message == messages[0]
agent.shutdown()


class InProcessAgent(object):
""" Agent that runs on a background thread and collects
messages in memory.
"""

def __init__(self):
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.sock.bind((test_host, test_port))
self.sock.listen()
self.is_shutdown = False
self.messages = []

def start(self) -> "InProcessAgent":
thread = threading.Thread(target=self.run, args=())
thread.daemon = True
thread.start()
return self

def run(self):
while not self.is_shutdown:
connection, client_address = self.sock.accept()
self.connection = connection

try:
while not self.is_shutdown:
log.error("recv")
data = self.connection.recv(16)
if data:
self.messages.append(data)
else:
break
finally:
log.error("Exited the recv loop")

def shutdown(self):
try:
self.is_shutdown = True
log.error("Connection closing")
self.connection.shutdown(socket.SHUT_RDWR)
self.connection.close()
self.sock.close()
except Exception as e:
log.error("Failed to shutdown %s" % (e,))

0 comments on commit 6f1ab0d

Please sign in to comment.