diff --git a/flask_socketio/test_client.py b/flask_socketio/test_client.py index 3e092d2b..c2a70e22 100644 --- a/flask_socketio/test_client.py +++ b/flask_socketio/test_client.py @@ -53,6 +53,7 @@ def _mock_send_packet(sid, pkt): self.acks[self.sid] = None self.callback_counter = 0 self.socketio = socketio + self.connected = {} socketio.server._send_packet = _mock_send_packet socketio.server.environ[self.sid] = {} if isinstance(socketio.server.manager, PubSubManager): @@ -63,6 +64,14 @@ def _mock_send_packet(sid, pkt): self.connect(namespace=namespace, query_string=query_string, headers=headers) + def is_connected(self, namespace=None): + """Check if a namespace is connected. + + :param namespace: The namespace to check. The global namespace is + assumed if this argument is not provided. + """ + return self.connected.get(namespace or '/', False) + def connect(self, namespace=None, query_string=None, headers=None): """Connect the client. @@ -87,12 +96,17 @@ def connect(self, namespace=None, query_string=None, headers=None): if self.flask_test_client: # inject cookies from Flask self.flask_test_client.cookie_jar.inject_wsgi(environ) - self.socketio.server._handle_eio_connect(self.sid, environ) + self.connected['/'] = True + if self.socketio.server._handle_eio_connect( + self.sid, environ) is False: + del self.connected['/'] if namespace is not None and namespace != '/': + self.connected[namespace] = True pkt = packet.Packet(packet.CONNECT, namespace=namespace) with self.app.app_context(): - self.socketio.server._handle_eio_message(self.sid, - pkt.encode()) + if self.socketio.server._handle_eio_message( + self.sid, pkt.encode()) is False: + del self.connected[namespace] def disconnect(self, namespace=None): """Disconnect the client. @@ -100,9 +114,12 @@ def disconnect(self, namespace=None): :param namespace: The namespace to disconnect. The global namespace is assumed if this argument is not provided. """ + if not self.is_connected(namespace): + raise RuntimeError('not connected') pkt = packet.Packet(packet.DISCONNECT, namespace=namespace) with self.app.app_context(): self.socketio.server._handle_eio_message(self.sid, pkt.encode()) + del self.connected[namespace or '/'] def emit(self, event, *args, **kwargs): """Emit an event to the server. @@ -120,6 +137,8 @@ def emit(self, event, *args, **kwargs): assumed if this argument is not provided. """ namespace = kwargs.pop('namespace', None) + if not self.is_connected(namespace): + raise RuntimeError('not connected') callback = kwargs.pop('callback', False) id = None if callback: @@ -172,6 +191,8 @@ def get_received(self, namespace=None): namespace is assumed if this argument is not provided. """ + if not self.is_connected(namespace): + raise RuntimeError('not connected') namespace = namespace or '/' r = [pkt for pkt in self.queue[self.sid] if pkt['namespace'] == namespace] diff --git a/test_socketio.py b/test_socketio.py index 280c9ed7..b2258646 100755 --- a/test_socketio.py +++ b/test_socketio.py @@ -254,12 +254,14 @@ def tearDown(self): def test_connect(self): client = socketio.test_client(app) + self.assertTrue(client.is_connected()) received = client.get_received() self.assertEqual(len(received), 3) self.assertEqual(received[0]['args'], 'connected') self.assertEqual(received[1]['args'], '{}') self.assertEqual(received[2]['args'], '{}') client.disconnect() + self.assertFalse(client.is_connected()) def test_connect_query_string_and_headers(self): client = socketio.test_client( @@ -275,12 +277,14 @@ def test_connect_query_string_and_headers(self): def test_connect_namespace(self): client = socketio.test_client(app, namespace='/test') + self.assertTrue(client.is_connected('/test')) received = client.get_received('/test') self.assertEqual(len(received), 3) self.assertEqual(received[0]['args'], 'connected-test') self.assertEqual(received[1]['args'], '{}') self.assertEqual(received[2]['args'], '{}') client.disconnect(namespace='/test') + self.assertFalse(client.is_connected('/test')) def test_connect_namespace_query_string_and_headers(self): client = socketio.test_client(