Skip to content

Commit

Permalink
keep connected status in test client
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Feb 16, 2019
1 parent 12aa746 commit 5e399d5
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
27 changes: 24 additions & 3 deletions flask_socketio/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def _mock_send_packet(sid, pkt):
self.acks[self.sid] = None
self.callback_counter = 0
self.socketio = socketio
self.connected = {}
socketio.server._send_packet = _mock_send_packet
socketio.server.environ[self.sid] = {}
if isinstance(socketio.server.manager, PubSubManager):
Expand All @@ -63,6 +64,14 @@ def _mock_send_packet(sid, pkt):
self.connect(namespace=namespace, query_string=query_string,
headers=headers)

def is_connected(self, namespace=None):
"""Check if a namespace is connected.
:param namespace: The namespace to check. The global namespace is
assumed if this argument is not provided.
"""
return self.connected.get(namespace or '/', False)

def connect(self, namespace=None, query_string=None, headers=None):
"""Connect the client.
Expand All @@ -87,22 +96,30 @@ def connect(self, namespace=None, query_string=None, headers=None):
if self.flask_test_client:
# inject cookies from Flask
self.flask_test_client.cookie_jar.inject_wsgi(environ)
self.socketio.server._handle_eio_connect(self.sid, environ)
self.connected['/'] = True
if self.socketio.server._handle_eio_connect(
self.sid, environ) is False:
del self.connected['/']
if namespace is not None and namespace != '/':
self.connected[namespace] = True
pkt = packet.Packet(packet.CONNECT, namespace=namespace)
with self.app.app_context():
self.socketio.server._handle_eio_message(self.sid,
pkt.encode())
if self.socketio.server._handle_eio_message(
self.sid, pkt.encode()) is False:
del self.connected[namespace]

def disconnect(self, namespace=None):
"""Disconnect the client.
:param namespace: The namespace to disconnect. The global namespace is
assumed if this argument is not provided.
"""
if not self.is_connected(namespace):
raise RuntimeError('not connected')
pkt = packet.Packet(packet.DISCONNECT, namespace=namespace)
with self.app.app_context():
self.socketio.server._handle_eio_message(self.sid, pkt.encode())
del self.connected[namespace or '/']

def emit(self, event, *args, **kwargs):
"""Emit an event to the server.
Expand All @@ -120,6 +137,8 @@ def emit(self, event, *args, **kwargs):
assumed if this argument is not provided.
"""
namespace = kwargs.pop('namespace', None)
if not self.is_connected(namespace):
raise RuntimeError('not connected')
callback = kwargs.pop('callback', False)
id = None
if callback:
Expand Down Expand Up @@ -172,6 +191,8 @@ def get_received(self, namespace=None):
namespace is assumed if this argument is not
provided.
"""
if not self.is_connected(namespace):
raise RuntimeError('not connected')
namespace = namespace or '/'
r = [pkt for pkt in self.queue[self.sid]
if pkt['namespace'] == namespace]
Expand Down
4 changes: 4 additions & 0 deletions test_socketio.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,12 +254,14 @@ def tearDown(self):

def test_connect(self):
client = socketio.test_client(app)
self.assertTrue(client.is_connected())
received = client.get_received()
self.assertEqual(len(received), 3)
self.assertEqual(received[0]['args'], 'connected')
self.assertEqual(received[1]['args'], '{}')
self.assertEqual(received[2]['args'], '{}')
client.disconnect()
self.assertFalse(client.is_connected())

def test_connect_query_string_and_headers(self):
client = socketio.test_client(
Expand All @@ -275,12 +277,14 @@ def test_connect_query_string_and_headers(self):

def test_connect_namespace(self):
client = socketio.test_client(app, namespace='/test')
self.assertTrue(client.is_connected('/test'))
received = client.get_received('/test')
self.assertEqual(len(received), 3)
self.assertEqual(received[0]['args'], 'connected-test')
self.assertEqual(received[1]['args'], '{}')
self.assertEqual(received[2]['args'], '{}')
client.disconnect(namespace='/test')
self.assertFalse(client.is_connected('/test'))

def test_connect_namespace_query_string_and_headers(self):
client = socketio.test_client(
Expand Down

0 comments on commit 5e399d5

Please sign in to comment.