diff --git a/flask_socketio/__init__.py b/flask_socketio/__init__.py index b4d5724a..7e7fb30d 100644 --- a/flask_socketio/__init__.py +++ b/flask_socketio/__init__.py @@ -15,7 +15,7 @@ import socketio import flask -from flask import json as flask_json +from flask import _request_ctx_stack, json as flask_json from werkzeug.debug import DebuggedApplication from werkzeug.serving import run_with_reloader @@ -583,10 +583,11 @@ def _handle_event(self, handler, message, namespace, sid, *args): return '', 400 app = self.server.environ[sid]['flask.app'] with app.request_context(self.server.environ[sid]): - if 'saved_session' in self.server.environ[sid]: - self._copy_session( - self.server.environ[sid]['saved_session'], - flask.session) + if 'saved_session' not in self.server.environ[sid]: + self.server.environ[sid]['saved_session'] = \ + dict(flask.session) + _request_ctx_stack.top.session = \ + self.server.environ[sid]['saved_session'] flask.request.sid = sid flask.request.namespace = namespace flask.request.event = {'message': message, 'args': args} @@ -602,11 +603,6 @@ def _handle_event(self, handler, message, namespace, sid, *args): raise type, value, traceback = sys.exc_info() return err_handler(value) - if flask.session.modified and sid in self.server.environ: - self.server.environ[sid]['saved_session'] = {} - self._copy_session( - flask.session, - self.server.environ[sid]['saved_session']) return ret def _copy_session(self, src, dest): diff --git a/test_socketio.py b/test_socketio.py index bfe55bf5..22afbc41 100755 --- a/test_socketio.py +++ b/test_socketio.py @@ -366,10 +366,11 @@ def test_session(self): client = socketio.test_client(app) client.get_received() client.send('echo this message back') - self.assertNotIn('saved_session', socketio.server.environ[client.sid]) + self.assertEqual(socketio.server.environ[client.sid]['saved_session'], + {}) client.send('test session') - session = socketio.server.environ[client.sid]['saved_session'] - self.assertEqual(session['a'], 'b') + self.assertEqual(socketio.server.environ[client.sid]['saved_session'], + {'a': 'b'}) def test_room(self): client1 = socketio.test_client(app)